In [54]:
import torch
import torch.nn as nn
import torch.optim as optim
import random
import numpy as np
from collections import deque
from pettingzoo.atari import space_invaders_v2
import imageio

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [55]:
# Hyperparameters
EPISODES = 1  # Number of episodes for training
BATCH_SIZE = 32  # Batch size for replay buffer sampling
GAMMA = 0.99  # Discount factor
LEARNING_RATE = 1e-3  # Learning rate for optimizer
BUFFER_SIZE = 100  # Size of replay buffer
TARGET_UPDATE = 10  # Frequency to update target network

In [56]:
# Q-Network architecture
class DQN(nn.Module):
    def __init__(self, action_size):
        super(DQN, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(3 * 210 * 160, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, action_size),
        )

    def forward(self, x):  # x.shape = (batch_size, 3, 210, 160)
        return self.net(x.reshape(x.size(0), -1))

In [57]:
# Replay Buffer
class ReplayBuffer:
    def __init__(self, size):
        self.buffer = deque(maxlen=size)

    def add(self, experience):
        self.buffer.append(experience)

    def sample(self, batch_size):
        return random.sample(self.buffer, batch_size)

In [58]:
# DQN Agent
class DQNAgent:
    def __init__(self, action_size):
        self.action_size = action_size
        self.q_network = DQN(action_size).to(device)
        self.target_network = DQN(action_size).to(device)
        self.target_network.load_state_dict(self.q_network.state_dict())
        self.optimizer = optim.Adam(self.q_network.parameters(), lr=LEARNING_RATE)
        self.replay_buffer = ReplayBuffer(BUFFER_SIZE)
        self.steps_done = 0

    def select_action(self, state, epsilon=0.1):
        if random.random() < epsilon:
            return random.randint(0, self.action_size - 1)
        else:
            with torch.no_grad():
                state = torch.tensor(
                    state, device=device, dtype=torch.float32
                ).unsqueeze(0)
                q_values = self.q_network(state)
                return q_values.argmax().item()

    def train(self):
        if len(self.replay_buffer.buffer) < BATCH_SIZE:
            return  # Wait until buffer has enough samples

        # Sample batch from replay buffer
        batch = self.replay_buffer.sample(BATCH_SIZE)
        states, actions, rewards, next_states, dones = zip(*batch)

        states = torch.tensor(np.array(states), device=device, dtype=torch.float32)
        actions = torch.tensor(actions, device=device, dtype=torch.long)
        rewards = torch.tensor(rewards, device=device, dtype=torch.float32)
        next_states = torch.tensor(
            np.array(next_states), device=device, dtype=torch.float32
        )
        dones = torch.tensor(dones, device=device, dtype=torch.float32)

        # Compute Q-values for current states
        q_values = self.q_network(states).gather(1, actions.unsqueeze(1)).squeeze()

        # Compute target Q-values for next states
        with torch.no_grad():
            next_q_values = self.target_network(next_states).max(1)[0]
            target_q_values = rewards + GAMMA * next_q_values * (1 - dones)

        # Compute loss and optimize
        loss = nn.MSELoss()(q_values, target_q_values)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Update target network periodically
        if self.steps_done % TARGET_UPDATE == 0:
            self.target_network.load_state_dict(self.q_network.state_dict())

In [59]:
# Initialize environment and agent
env = space_invaders_v2.env(render_mode="rgb_array")  # Render mode for capturing frames
agent = DQNAgent(action_size=env.action_space(env.possible_agents[0]).n)

In [60]:
# Training loop
for episode in range(EPISODES):
    env.reset(seed=42)
    total_reward = 0

    # Initialize video writer
    frames = []

    for agent_id in env.agent_iter():
        observation, reward, termination, truncation, info = env.last()
        done = termination or truncation

        # Preprocess observation (resize and normalize)
        if observation is not None:
            observation = (
                np.transpose(observation, (2, 0, 1)) / 255.0
            )  # Normalize pixel values
            observation = torch.tensor(observation, device=device, dtype=torch.float32)

        # Select action
        action = agent.select_action(observation) if not done else None

        # Take action in environment
        env.step(action)
        next_observation, reward, _, _, _ = env.last()
        total_reward += reward

        # Capture frame for video
        frame = env.render()
        frames.append(frame)

        # Process next observation
        if next_observation is not None:
            next_observation = np.transpose(next_observation, (2, 0, 1)) / 255.0
            next_observation = torch.tensor(
                next_observation, device=device, dtype=torch.float32
            )

        # Store experience in replay buffer
        agent.replay_buffer.add(
            (
                observation.cpu().numpy(),
                action,
                reward,
                (
                    next_observation.cpu().numpy()
                    if next_observation is not None
                    else None
                ),
                done,
            )
        )

        # Train the agent
        agent.train()
        agent.steps_done += 1

        # End the loop if the game is over
        if done:
            break

    # Save episode as video
    video_filename = f"space_invaders_episode_{episode + 1}.mp4"
    with imageio.get_writer(video_filename, fps=30) as video:
        for frame in frames:
            video.append_data(frame)
    print(
        f"Episode {episode + 1}/{EPISODES}, Total Reward: {total_reward}, Video saved as {video_filename}"
    )

env.close()

  state = torch.tensor(


TypeError: 'NoneType' object cannot be interpreted as an integer