<a href="https://colab.research.google.com/github/Paraskevi-KIvroglou/rl-pong-agent/blob/main/Atari_Agent_Async.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install torch torchvision
!pip install gymnasium
!pip install gymnasium[atari]
!pip install gymnasium[accept-rom-license]
!pip install wandb
!pip install matplotlib

Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch)
  Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch)
  Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==8.9.2.26 (from torch)
  Using cached nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.1.3.1 (from torch)
  Using cached nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.0.2.54 (from torch)
  Using cached nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.2.106 (from torch)
  Using cached nvidia_curand_cu12-10.3.2.106-py3-

In [None]:
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import random
from collections import deque
from torchvision import transforms
import wandb

  self.hub = sentry_sdk.Hub(client)


In [None]:
# Open a Weights & Biases run for this experiment (prompts for an API key on first login).
wandb.init(entity="paraskevikivroglou", project="dqn-pong")

  return LooseVersion(v) >= LooseVersion(check)


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

 ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


In [None]:
class DQN(nn.Module):
    """Convolutional Q-network: three conv layers, then a two-layer MLP head.

    Args:
        input_shape: (channels, height, width) of one observation.
        n_actions: number of discrete actions, i.e. the size of the output layer.
    """

    def __init__(self, input_shape, n_actions):
        super().__init__()
        # The `conv` / `fc` attribute names and their Sequential indices form the
        # state_dict keys, and DQNAgent.act reads `fc[-1]`; keep them as they are.
        self.conv = nn.Sequential(
            nn.Conv2d(input_shape[0], 32, kernel_size=8, stride=4), nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2), nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1), nn.ReLU(),
        )
        flat_features = self._get_conv_out(input_shape)
        self.fc = nn.Sequential(
            nn.Linear(flat_features, 512), nn.ReLU(),
            nn.Linear(512, n_actions),
        )

    def _get_conv_out(self, shape):
        """Return how many features `self.conv` emits for a single input of `shape`."""
        probe = torch.zeros(1, *shape)
        return int(self.conv(probe).numel())

    def forward(self, x):
        """Map a (batch, C, H, W) tensor to (batch, n_actions) Q-values."""
        features = self.conv(x)
        flattened = features.view(x.size(0), -1)
        return self.fc(flattened)

In [None]:
# Experience Replay buffer
class ReplayBuffer:
    """Fixed-capacity FIFO store of (state, action, reward, next_state, done) transitions.

    Once `capacity` transitions are held, pushing another evicts the oldest one.
    """

    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        """Append a single transition."""
        transition = (state, action, reward, next_state, done)
        self.buffer.append(transition)

    def sample(self, batch_size):
        """Draw `batch_size` distinct transitions uniformly at random.

        Returns five tuples (states, actions, rewards, next_states, dones) whose
        i-th entries all belong to the same transition. `random.sample` raises
        ValueError when fewer than `batch_size` transitions are stored.
        """
        picked = random.sample(self.buffer, k=batch_size)
        states, actions, rewards, next_states, dones = zip(*picked)
        return states, actions, rewards, next_states, dones

    def __len__(self):
        """Number of transitions currently held."""
        return len(self.buffer)

In [None]:
import cv2
import numpy as np

def preprocess_frame(frame):
    """Convert an RGB frame into an 84x84 grayscale image scaled to [0, 1].

    Args:
        frame: RGB image array from the environment (presumably uint8 HxWx3 - TODO confirm).

    Returns:
        ndarray of shape (84, 84); dividing by the Python float 255.0 yields a floating dtype.
    """
    gray = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
    small = cv2.resize(gray, (84, 84))
    return small / 255.0

In [None]:
# DQN Agent
class DQNAgent:
    """Epsilon-greedy DQN agent with an online network, a target network and replay memory.

    Two update entry points share the same TD-target computation:

    * ``train()`` -- one Huber-loss step on ``self.batch_size`` transitions,
      gradients clamped to [-1, 1], a target-network sync every
      ``update_target_steps`` calls, and one epsilon decay per call.
    * ``replay(batch_size)`` -- one MSE step on ``batch_size`` transitions
      followed by one epsilon decay; returns the loss so the caller can
      record it.

    Args:
        state_shape: observation shape passed to ``DQN`` (e.g. (4, 84, 84)).
        n_actions: number of discrete actions.
        device: torch device holding the networks and training batches.
    """

    def __init__(self, state_shape, n_actions, device):
        self.device = device
        self.dqn = DQN(state_shape, n_actions).to(device)
        self.target_dqn = DQN(state_shape, n_actions).to(device)
        self.target_dqn.load_state_dict(self.dqn.state_dict())
        self.optimizer = optim.RMSprop(self.dqn.parameters(), lr=0.00025, alpha=0.95, eps=0.01)
        self.memory = ReplayBuffer(1000000)
        self.batch_size = 4
        self.gamma = 0.99
        self.epsilon = 1.0
        self.epsilon_min = 0.1
        # Decay horizon used by update_epsilon (a divisor, not a multiplier).
        self.epsilon_decay = 1000000
        self.update_target_steps = 10000
        self.steps = 0
        self.rewards = []  # per-episode returns appended by the training loop
        self.losses = []

    def remember(self, state, action, reward, next_state, done):
        """Store one transition in replay memory."""
        self.memory.push(state, action, reward, next_state, done)

    def act(self, state):
        """Return a random action with probability epsilon, otherwise argmax_a Q(state, a)."""
        if random.random() < self.epsilon:
            return random.randrange(self.dqn.fc[-1].out_features)
        with torch.no_grad():
            state = torch.FloatTensor(state).unsqueeze(0).to(self.device)
            q_values = self.dqn(state)
            return q_values.argmax().item()

    def update_target_network(self):
        """Copy the online network's weights into the target network."""
        self.target_dqn.load_state_dict(self.dqn.state_dict())

    def update_epsilon(self):
        """Move epsilon 1/epsilon_decay of the remaining distance towards epsilon_min."""
        self.epsilon = max(self.epsilon_min, self.epsilon - (self.epsilon - self.epsilon_min) / self.epsilon_decay)

    def _to_batch_tensors(self, states, actions, rewards, next_states, dones):
        """Stack the per-field tuples of a sampled batch into tensors on `self.device`."""
        # np.stack first: building a tensor from a list of ndarrays is very slow.
        return (
            torch.as_tensor(np.stack(states), dtype=torch.float32, device=self.device),
            torch.as_tensor(actions, dtype=torch.long, device=self.device),
            torch.as_tensor(rewards, dtype=torch.float32, device=self.device),
            torch.as_tensor(np.stack(next_states), dtype=torch.float32, device=self.device),
            torch.as_tensor(dones, dtype=torch.float32, device=self.device),
        )

    def _td_targets(self, rewards, next_states, dones):
        """One-step targets r + gamma * (1 - done) * max_a' Q_target(s', a'), without autograd."""
        with torch.no_grad():
            next_q = self.target_dqn(next_states).max(1)[0]
        return rewards + (1 - dones) * self.gamma * next_q

    def train(self):
        """Run one batched Huber-loss update; no-op (returns None) until memory holds a batch.

        Returns:
            The loss value as a float when an update was performed.
        """
        if len(self.memory) < self.batch_size:
            return
        # ReplayBuffer.sample already returns the five per-field tuples; unpack
        # them directly (zipping them again would transpose back into rows).
        states, actions, rewards, next_states, dones = self._to_batch_tensors(
            *self.memory.sample(self.batch_size)
        )

        current_q_values = self.dqn(states).gather(1, actions.unsqueeze(1))
        target_q_values = self._td_targets(rewards, next_states, dones)
        loss = nn.functional.smooth_l1_loss(current_q_values, target_q_values.unsqueeze(1))

        self.optimizer.zero_grad()
        loss.backward()
        for param in self.dqn.parameters():
            if param.grad is not None:
                param.grad.data.clamp_(-1, 1)
        self.optimizer.step()

        self.steps += 1
        if self.steps % self.update_target_steps == 0:
            self.update_target_network()

        self.update_epsilon()
        return loss.item()

    def replay(self, batch_size):
        """Run one batched MSE update on `batch_size` uniformly sampled transitions.

        Epsilon is decayed through `update_epsilon` (the previous code multiplied
        it by `epsilon_decay`, which is 1e6, so it grew without bound). The loss
        is returned rather than stored, so the caller decides where to log it.

        Returns:
            The loss as a float, or None if memory holds fewer than `batch_size` transitions.
        """
        if len(self.memory) < batch_size:
            return None
        states, actions, rewards, next_states, dones = self._to_batch_tensors(
            *self.memory.sample(batch_size)
        )

        q_taken = self.dqn(states).gather(1, actions.unsqueeze(1)).squeeze(1)
        targets = self._td_targets(rewards, next_states, dones)
        loss = nn.functional.mse_loss(q_taken, targets)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        self.update_epsilon()
        return loss.item()

# Main training loop (assuming you have an environment)
def train(env, agent, num_episodes, batch_size=32):
    """Run `num_episodes` episodes of frame-stacked DQN training on `env`.

    Each step stores the transition. Once memory holds more than `batch_size`
    transitions, each step also calls `agent.replay(batch_size)` and records
    the returned loss in `agent.losses`. After every episode, the return is
    appended to `agent.rewards` and the target network is synced. Every 10th
    episode is logged to Weights & Biases and printed.

    Args:
        env: Gymnasium environment (`reset` -> (obs, info); `step` -> 5-tuple).
        agent: DQNAgent-like object exposing act / remember / replay /
            update_target_network, plus `memory`, `losses`, `rewards`, `epsilon`.
        num_episodes: number of episodes to run.
        batch_size: replay minibatch size (default 32, the previous hard-coded value).
    """
    for episode in range(num_episodes):
        obs, _ = env.reset()
        frame = preprocess_frame(obs)
        state = np.stack([frame] * 4, axis=0)
        done = False
        total_reward = 0

        while not done:
            action = agent.act(state)
            obs, reward, terminated, truncated, _ = env.step(action)
            # Either flag ends the episode, but only a real termination should
            # cut off bootstrapping in the stored transition.
            done = terminated or truncated
            frame = preprocess_frame(obs)
            next_state = np.append(state[1:], np.expand_dims(frame, axis=0), axis=0)
            agent.remember(state, action, reward, next_state, terminated)
            state = next_state
            total_reward += reward

            if len(agent.memory) > batch_size:
                loss = agent.replay(batch_size)
                if loss is not None:
                    agent.losses.append(loss)

        agent.rewards.append(total_reward)
        agent.update_target_network()

        if episode % 10 == 0:
            wandb.log({"episode": episode, "total_reward": total_reward, "Epsilon": agent.epsilon})
            print(f"Episode {episode}, Total Reward: {total_reward}, Epsilon: {agent.epsilon}")

In [None]:
def worker(agent, env_name, max_episodes, epsilon, epsilon_decay, epsilon_min, global_rewards, global_losses):
    """Early threaded worker: run `max_episodes` episodes through `agent.train`.

    Appends each episode's (reward, loss) pair to the shared `global_rewards` /
    `global_losses` lists and prints progress every 10 episodes.

    NOTE(review): this expects `agent.train(env, max_steps=..., epsilon=...,
    epsilon_decay=..., epsilon_min=...)` to return `(total_reward, loss)`.
    Neither DQNAgent.train() nor Agent in this notebook has that signature, and
    a later cell redefines `worker`, so this version is superseded.
    """
    import threading  # the notebook only imports threading in a later cell

    env = gym.make(env_name)
    try:
        for episode in range(max_episodes):
            total_reward, loss = agent.train(env, max_steps=1000, epsilon=epsilon, epsilon_decay=epsilon_decay, epsilon_min=epsilon_min)
            global_rewards.append(total_reward)
            global_losses.append(loss)
            if episode % 10 == 0:
                print(f"Thread {threading.current_thread().name}, Episode {episode}, Reward: {total_reward}, Loss: {loss}")
    finally:
        # Release the emulator even if training raises.
        env.close()


In [None]:
# Per-thread training metrics keyed by thread id; filled by worker() and read by plot_metrics().
thread_metrics = dict()

In [None]:
def check_and_initialize_optimizer_state(optimizer):
    """Make sure every parameter of an Adam/AdamW optimizer has complete Adam state.

    For each parameter, any missing buffer is created as zeros shaped like the
    parameter: `exp_avg`, `exp_avg_sq`, and `max_exp_avg_sq` when the group
    uses `amsgrad`. A missing `step` counter is created as a 0-dim float
    tensor, which is the layout Adam builds lazily on its first step. Keys
    that already exist are left untouched.

    Other optimizers are returned unchanged. They keep different state keys
    (e.g. RMSprop's `square_avg`), and filling their state with Adam keys
    would make them skip their own lazy initialisation and fail with KeyError.

    Args:
        optimizer: a torch.optim optimizer.
    """
    if not isinstance(optimizer, (torch.optim.Adam, torch.optim.AdamW)):
        return
    for param_group in optimizer.param_groups:
        needs_max = bool(param_group.get('amsgrad', False))
        # capturable/fused Adam keeps `step` on the parameter's device; otherwise on CPU.
        step_on_param_device = bool(param_group.get('capturable', False) or param_group.get('fused', False))
        for param in param_group['params']:
            state = optimizer.state[param]
            if 'exp_avg' not in state:
                state['exp_avg'] = torch.zeros_like(param.data)
            if 'exp_avg_sq' not in state:
                state['exp_avg_sq'] = torch.zeros_like(param.data)
            if needs_max and 'max_exp_avg_sq' not in state:
                state['max_exp_avg_sq'] = torch.zeros_like(param.data)
            if 'step' not in state:
                state['step'] = torch.tensor(0.0, device=param.device if step_on_param_device else None)

In [None]:
import threading
import torch
import torch.nn as nn
import torch.optim as optim
import gymnasium as gym
import numpy as np
from collections import deque
import random
import matplotlib.pyplot as plt
import os
import cv2


def preprocess_state(state, device):
    """Turn an RGB frame into a (1, 84, 84) float32 tensor in [0, 1] on `device`.

    The leading singleton axis acts as the channel axis, so several of these
    can be concatenated along dim 0 to build a frame stack.
    """
    gray = cv2.cvtColor(state, cv2.COLOR_RGB2GRAY)
    resized = cv2.resize(gray, (84, 84))
    scaled = resized.astype(np.float32) / 255.0  # normalise pixel values
    with_channel = scaled[np.newaxis, :, :]
    return torch.tensor(with_channel, device=device)


# Replay Buffer class
class ReplayBuffer:
    """Bounded experience-replay memory used by the threaded Agent.

    NOTE(review): same behaviour as the ReplayBuffer defined in an earlier
    cell; whichever cell ran last provides the definition in effect.
    """

    def __init__(self, capacity):
        # deque(maxlen=...) silently drops the oldest transition once full.
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        """Store one (state, action, reward, next_state, done) transition."""
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        """Return `batch_size` random distinct transitions as five parallel tuples."""
        rows = random.sample(self.buffer, k=batch_size)
        columns = zip(*rows)
        states, actions, rewards, next_states, dones = columns
        return states, actions, rewards, next_states, dones

    def __len__(self):
        return len(self.buffer)


# Define the Q-network
class QNetwork(nn.Module):
    """Three ReLU conv layers followed by two fully connected layers, one output per action.

    Args:
        input_shape: (channels, height, width) of one observation, e.g. (4, 84, 84).
        action_size: number of discrete actions.
    """

    def __init__(self, input_shape, action_size):
        super().__init__()
        in_channels = input_shape[0]
        # Attribute names define the state_dict keys used by save_model/load_model.
        self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=8, stride=4)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)
        flat = self.feature_size(input_shape)
        self.fc1 = nn.Linear(flat, 512)
        self.fc2 = nn.Linear(512, action_size)

    def forward(self, x):
        """Map a (batch, C, H, W) tensor to (batch, action_size) Q-values."""
        for conv in (self.conv1, self.conv2, self.conv3):
            x = torch.relu(conv(x))
        hidden = torch.relu(self.fc1(x.view(x.size(0), -1)))
        return self.fc2(hidden)

    def feature_size(self, input_shape):
        """Length of the flattened conv output for one input of `input_shape`."""
        with torch.no_grad():
            dummy = torch.zeros(1, *input_shape)
            conv_out = self.conv3(self.conv2(self.conv1(dummy)))
            return conv_out.view(1, -1).size(1)


# Define the Agent
class Agent:
    """DQN agent whose networks, optimizer and replay buffer are shared by the worker threads.

    Holds:
      * an online `q_network` and a `target_network`, synced by `update_target_network`;
      * an Adam optimizer and a replay buffer;
      * epsilon-greedy exploration state;
      * diagnostic lists (`losses`, `q_values`) that the worker drains after each episode.

    Args:
        state_size: observation shape passed to QNetwork, e.g. (4, 84, 84).
        action_size: number of discrete actions.
        device: torch device for both networks.
        gamma: discount factor for TD targets (default 0.99).
    """

    def __init__(self, state_size, action_size, device, gamma=0.99):
        self.state_size = state_size
        self.action_size = action_size
        self.device = device
        # Honour the constructor argument (it used to be overwritten with 0.99 below).
        self.gamma = gamma
        self.q_network = QNetwork(state_size, action_size).to(device)
        self.target_network = QNetwork(state_size, action_size).to(device)
        self.target_network.load_state_dict(self.q_network.state_dict())
        self.optimizer = optim.Adam(self.q_network.parameters(), lr=0.001)

        self.replay_buffer = ReplayBuffer(10000)
        self.batch_size = 64

        self.epsilon = 0.5
        self.epsilon_min = 0.1
        self.decay_episodes = 200
        # Multiplier that takes epsilon from its start value to epsilon_min in decay_episodes multiplications.
        self.epsilon_decay = (self.epsilon_min / self.epsilon) ** (1 / self.decay_episodes)
        self.episode_count = 0

        self.losses = []
        self.q_values = []
        self.epsilons = []
        self.steps = []

    def choose_action(self, state):
        """Epsilon-greedy action for a batched state tensor (the worker passes shape (1, 4, 84, 84)).

        Greedy choices also append their max Q-value to `self.q_values`.
        """
        if np.random.rand() <= self.epsilon:
            return random.choice(range(self.action_size))
        with torch.no_grad():
            q_values = self.q_network(state)
        self.q_values.append(torch.max(q_values).item())
        return torch.argmax(q_values).item()

    def replay_experience(self):
        """One Adam step on a random minibatch with MSE TD loss.

        Does nothing until the buffer holds `batch_size` transitions. Appends the
        loss value to `self.losses`.
        """
        if len(self.replay_buffer) < self.batch_size:
            return

        states, actions, rewards, next_states, dones = self.replay_buffer.sample(self.batch_size)
        states = torch.stack([torch.as_tensor(s, dtype=torch.float32, device=self.device) for s in states])
        next_states = torch.stack([torch.as_tensor(s, dtype=torch.float32, device=self.device) for s in next_states])
        actions = torch.as_tensor(actions, dtype=torch.long, device=self.device)
        rewards = torch.as_tensor(rewards, dtype=torch.float32, device=self.device)
        dones = torch.as_tensor(dones, dtype=torch.float32, device=self.device)

        q_value = self.q_network(states).gather(1, actions.unsqueeze(1)).squeeze(1)
        # Targets come from the frozen target network. Without no_grad, backward()
        # would also compute (and endlessly accumulate) gradients for it.
        with torch.no_grad():
            next_q_value = self.target_network(next_states).max(1)[0]
        expected_q_value = rewards + (1 - dones) * self.gamma * next_q_value

        loss = nn.MSELoss()(q_value, expected_q_value)

        check_and_initialize_optimizer_state(self.optimizer)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        self.losses.append(loss.item())

    def update_target_network(self):
        """Copy the online network's weights into the target network."""
        self.target_network.load_state_dict(self.q_network.state_dict())

    def save_model(self, filename):
        """Save the online network's state_dict to `filename`."""
        torch.save(self.q_network.state_dict(), filename)

    def load_model(self, filename):
        """Load weights from `filename` into both networks; print a notice if the file is missing."""
        if os.path.isfile(filename):
            # map_location lets a checkpoint saved on GPU load on a CPU-only machine (and vice versa).
            state_dict = torch.load(filename, map_location=self.device)
            self.q_network.load_state_dict(state_dict)
            self.target_network.load_state_dict(self.q_network.state_dict())
            print(f"Model loaded from {filename}")
        else:
            print(f"No model file found at {filename}")

# Define the worker
def worker(agent, env_name, max_episodes, lock, thread_id, thread_metrics):
    """Play `max_episodes` episodes of `env_name`, training the shared `agent`.

    Observations become (1, 84, 84) frames via preprocess_state and are stacked
    four at a time. Every step does three things:
      * pushes the transition into the shared replay buffer;
      * runs `agent.replay_experience()` under `lock`;
      * syncs the target network every 100 of this thread's steps.

    After each episode:
      * epsilon is decayed (on this thread's every 10th episode);
      * the agent's accumulated `losses` / `q_values` are moved into
        `thread_metrics[thread_id]`.

    Args:
        agent: shared Agent instance; all threads train the same networks.
        env_name: Gymnasium environment id; each thread makes its own env.
        max_episodes: number of episodes this thread runs.
        lock: threading.Lock guarding mutations of the shared agent state.
        thread_id: key under which this thread's metrics are stored.
        thread_metrics: dict shared between threads, filled in place.
    """
    env = gym.make(env_name)
    total_steps = 0

    metrics = {
        'epsilons': [],
        'q_values': [],
        'losses': [],
        'steps': []
    }
    thread_metrics[thread_id] = metrics

    try:
        for episode in range(max_episodes):
            obs, _ = env.reset()
            state_stack = [preprocess_state(obs, agent.device)] * 4  # four copies of the first frame

            total_reward = 0
            done = False

            while not done:
                state_tensor = torch.cat(state_stack, dim=0)  # (4, 84, 84)
                action = agent.choose_action(state_tensor.unsqueeze(0))
                obs, reward, terminated, truncated, _ = env.step(action)
                # Either flag ends the episode; only a real termination stops bootstrapping.
                done = terminated or truncated

                state_stack.pop(0)  # drop the oldest frame
                state_stack.append(preprocess_state(obs, agent.device))
                next_state_tensor = torch.cat(state_stack, dim=0)
                agent.replay_buffer.push(state_tensor.cpu().numpy(), action, reward, next_state_tensor.cpu().numpy(), terminated)

                total_reward += reward

                with lock:
                    agent.replay_experience()
                    if total_steps % 100 == 0:
                        agent.update_target_network()

                total_steps += 1

            with lock:
                if episode % 10 == 0:
                    agent.epsilon = max(agent.epsilon_min, agent.epsilon * agent.epsilon_decay)
                agent.episode_count += 1
                epsilon = agent.epsilon
                # Swap the shared lists out instead of extend()+clear(): values that
                # other threads append in between are kept rather than lost.
                losses, agent.losses = agent.losses, []
                q_values, agent.q_values = agent.q_values, []

            metrics['epsilons'].append(epsilon)
            metrics['q_values'].extend(q_values)
            metrics['losses'].extend(losses)
            metrics['steps'].extend([total_steps] * len(losses))

            print(f"Thread: {thread_id}, Episode: {episode}, Total Reward: {total_reward}, Epsilon: {epsilon}")
    finally:
        # Release the emulator even if training raises inside the thread.
        env.close()


# Main function to initialize agents and run threads
def main(agent, env_name='PongNoFrameskip-v4', max_episodes=300, load_model_file=None, save_model_file=None,
         num_threads=4, metrics=None):
    """Train one shared `agent` with `num_threads` worker threads, then optionally save it.

    Args:
        agent: shared Agent trained by every thread.
        env_name: Gymnasium environment id.
        max_episodes: episodes run by each thread.
        load_model_file: optional checkpoint to load before training.
        save_model_file: optional path to save the q_network weights afterwards.
        num_threads: number of worker threads (default 4, the previous fixed value).
        metrics: dict that receives per-thread metrics. Defaults to the
            notebook-global `thread_metrics`, which is created if it does not
            exist yet, so a later `plot_metrics(thread_metrics)` works.
    """
    if load_model_file:
        agent.load_model(load_model_file)

    if metrics is None:
        metrics = globals().setdefault('thread_metrics', {})

    lock = threading.Lock()
    threads = [
        threading.Thread(target=worker, args=(agent, env_name, max_episodes, lock, thread_id, metrics))
        for thread_id in range(num_threads)
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()

    if save_model_file:
        agent.save_model(save_model_file)

    print("Training completed.")



In [None]:
def plot_metrics(thread_metrics):
    """Draw per-thread epsilon, Q-value and loss curves in three stacked panels.

    Args:
        thread_metrics: mapping thread_id -> dict with 'epsilons', 'q_values'
            and 'losses' lists (the structure worker() fills in).
    """
    # (metrics key, x label, y label, panel title) for panels 1..3, top to bottom.
    panels = (
        ('epsilons', 'Episode', 'Epsilon', 'Epsilon Decay'),
        ('q_values', 'Step', 'Q Value', 'Q Values'),
        ('losses', 'Step', 'Loss', 'Losses'),
    )
    plt.figure(figsize=(15, 12))

    for thread_id, metrics in thread_metrics.items():
        label = f'Thread {thread_id}'
        for position, (key, x_label, y_label, title) in enumerate(panels, start=1):
            plt.subplot(3, 1, position)
            plt.plot(metrics[key], label=label)
            plt.xlabel(x_label)
            plt.ylabel(y_label)
            plt.title(title)

    for position in range(1, len(panels) + 1):
        plt.subplot(3, 1, position)
        plt.legend()
    plt.tight_layout()
    plt.show()

In [None]:
env_name = 'PongNoFrameskip-v4'
state_size = (4, 84, 84)  # four stacked 84x84 grayscale frames, as built by preprocess_state + worker
# Probe the action count with a throwaway env and close it, so it does not stay open all session.
probe_env = gym.make(env_name)
action_size = probe_env.action_space.n
probe_env.close()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(state_size)
print(action_size)

agent = Agent(state_size, action_size, device)
main(agent, save_model_file='dqn_model.pth')

(4, 84, 84)
6
Thread: 2, Episode: 0, Total Reward: -21.0, Epsilon: 0.49599255119493924
Thread: 3, Episode: 0, Total Reward: -20.0, Epsilon: 0.49201722168172884
Thread: 1, Episode: 0, Total Reward: -20.0, Epsilon: 0.48807375402753334
Thread: 0, Episode: 0, Total Reward: -21.0, Epsilon: 0.484161892862815
Thread: 2, Episode: 1, Total Reward: -20.0, Epsilon: 0.484161892862815
Thread: 3, Episode: 1, Total Reward: -21.0, Epsilon: 0.484161892862815
Thread: 0, Episode: 1, Total Reward: -21.0, Epsilon: 0.484161892862815
Thread: 1, Episode: 1, Total Reward: -21.0, Epsilon: 0.484161892862815
Thread: 2, Episode: 2, Total Reward: -21.0, Epsilon: 0.484161892862815
Thread: 3, Episode: 2, Total Reward: -21.0, Epsilon: 0.484161892862815
Thread: 0, Episode: 2, Total Reward: -21.0, Epsilon: 0.484161892862815
Thread: 1, Episode: 2, Total Reward: -18.0, Epsilon: 0.484161892862815
Thread: 2, Episode: 3, Total Reward: -21.0, Epsilon: 0.484161892862815
Thread: 0, Episode: 3, Total Reward: -21.0, Epsilon: 0.48

In [None]:
# `thread_metrics` is created in an earlier cell and filled by main(); run those cells first on a fresh kernel.
plot_metrics(thread_metrics=thread_metrics)

NameError: name 'thread_metrics' is not defined

In [None]:
def evaluate_model(env_name, global_network, num_episodes=10):
    env = gym.make(env_name)
    total_rewards = []

    for episode in range(num_episodes):
        state = env.reset()
        state = preprocess_frame(state)
        state = np.stack([state] * 4, axis=0)
        done = False
        total_reward = 0

        while not done:
            state = torch.FloatTensor(state).unsqueeze(0)
            q_values = global_network(state)
            action = torch.argmax(q_values).item()
            next_state, reward, done, _ = env.step(action)
            next_state = preprocess_frame(next_state)
            next_state = np.append(state[1:], np.expand_dims(next_state, axis=0), axis=0)
            total_reward += reward
            state = next_state

        total_rewards