<a href="https://colab.research.google.com/github/Loki-33/RL-Algos/blob/main/OpenAI_gym%5BAtari%5D_DQN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
from collections import deque
import random
import gym
import cv2
from gym.wrappers import AtariPreprocessing, FrameStack

In [None]:
def preprocess_frame(frame):
  """Turn an RGB frame into an 84x84 grayscale image."""
  # INTER_AREA is used for the downscale because it preserves structure and motion.
  return cv2.resize(
      cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY),
      (84, 84),
      interpolation=cv2.INTER_AREA,
  )

def stack_frames(stacked_frames, frames, is_new_episode):
  """Preprocess a raw frame and push it onto a rolling window of four frames.

  On a new episode the window is rebuilt from four copies of the preprocessed
  frame; otherwise the frame is appended and the oldest one falls off.

  Returns a tuple ``(stacked_array, stacked_frames)`` where ``stacked_array``
  is the window stacked along axis 0 and ``stacked_frames`` is the deque.
  """
  processed = preprocess_frame(frames)
  if not is_new_episode:
    stacked_frames.append(processed)
  else:
    # Start the episode with the first frame repeated in every slot.
    stacked_frames = deque((processed for _ in range(4)), maxlen=4)
  stacked_array = np.stack(stacked_frames, axis=0)
  return stacked_array, stacked_frames

In [None]:
# We do not need the manual preprocessing above; gym's built-in wrappers can do it for us.

In [None]:
# Install gym together with its `atari` and `accept-rom-license` extras.
!pip install gym[atari,accept-rom-license]

In [None]:
# Build the Breakout environment: grayscale (unscaled) preprocessing, then a
# stack of the four most recent observations.
# NOTE(review): frame_skip=1 is passed explicitly on a NoFrameskip env —
# confirm that training without frame skipping is intended.
env = FrameStack(
    AtariPreprocessing(
        gym.make("BreakoutNoFrameskip-v4", render_mode='rgb_array'),
        frame_skip=1,
        grayscale_obs=True,
        scale_obs=False,
    ),
    num_stack=4,
)

In [None]:
# Reset the environment and keep the first observation; the second returned value is discarded.
state, _ = env.reset()

In [None]:
# Display the initial stacked observation.
state

In [None]:
# Print the environment's action space; the string below lists what each of the four action ids means.
print(env.action_space)
'''
0: DO NOTHING
1: FIRE
2: MOVE RIGHT
3: MOVE LEFT
'''

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Show the four frames of the stacked observation side by side in grayscale.
fig = plt.figure(figsize=(16, 4))
for i in range(4):
  ax = fig.add_subplot(1, 4, i + 1)
  ax.imshow(state[i], cmap="gray")
  ax.axis("off")

In [None]:
class DQNetwork(nn.Module):
  """Convolutional Q-network mapping a 4-channel image stack to one value per action.

  The first linear layer expects 7*7*64 flattened features, which is what the
  three convolutions produce for 84x84 inputs.
  """

  def __init__(self, action_dim):
    super().__init__()
    # Three conv layers shrink the image while widening the channels.
    feature_layers = [
        nn.Conv2d(4, 32, kernel_size=8, stride=4),
        nn.ReLU(),
        nn.Conv2d(32, 64, kernel_size=4, stride=2),
        nn.ReLU(),
        nn.Conv2d(64, 64, kernel_size=3, stride=1),
        nn.ReLU(),
    ]
    # Fully connected head producing `action_dim` outputs.
    head_layers = [
        nn.Flatten(),
        nn.Linear(7 * 7 * 64, 512),
        nn.ReLU(),
        nn.Linear(512, action_dim),
    ]
    self.net = nn.Sequential(*feature_layers, *head_layers)

  def forward(self, x):
    """Return the network output for input batch `x`."""
    q_values = self.net(x)
    return q_values

In [None]:
# Run on the GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
  device = torch.device("cuda")
else:
  device = torch.device("cpu")

In [None]:
# Display which device was selected.
device

In [None]:
# Size the Q-value head from the environment rather than hard-coding it, so the
# networks stay consistent with whatever action space `env` exposes.
action_dim = int(env.action_space.n)
# Online network: the one that is trained and used to pick actions.
q_net = DQNetwork(action_dim).to(device)
# Target network: starts as an exact copy of the online network.
target_net = DQNetwork(action_dim).to(device)
target_net.load_state_dict(q_net.state_dict())

In [None]:
# Hyperparameters for the training loop.
gamma = 0.99  # discount factor applied to the bootstrapped next-state value
epsilon = 1.0  # initial probability of taking a random action
epsilon_decay = 0.995  # epsilon is multiplied by this after every episode
epsilon_min = 0.05  # lower bound that epsilon is never decayed below
lr = 1e-3  # learning rate passed to the Adam optimizer
batch_size = 32  # transitions sampled per update; also the minimum buffer fill before learning starts
buffer_size = 1000  # maximum number of transitions kept in the replay buffer
target_update_freq = 20  # target network is synced every this many episodes
episodes = 500  # number of training episodes

In [None]:
# Adam optimizes only the online network's parameters.
optimizer = optim.Adam(q_net.parameters(), lr=lr)
# Mean-squared-error loss, moved to the selected device.
loss_fn = nn.MSELoss().to(device)

In [None]:
# Replay buffer: a bounded deque, so the oldest transitions drop out once buffer_size is reached.
buffer = deque(maxlen=buffer_size)

In [None]:
# Total reward collected in each training episode, in order.
all_rewards = []

In [None]:
# Train the online network with epsilon-greedy exploration, experience replay
# and Double-DQN targets.
for ep in range(episodes):
  state, _ = env.reset()
  done = False
  episode_reward = 0
  while not done:
    # Epsilon-greedy action selection: explore with probability epsilon,
    # otherwise act greedily with respect to the online network.
    if random.random() < epsilon:
      action = env.action_space.sample()
    else:
      with torch.no_grad():
        state_tensor = torch.tensor(np.array(state), dtype=torch.float32, device=device).unsqueeze(0) / 255.0
        q_values = q_net(state_tensor)
        action = torch.argmax(q_values, dim=1).item()
    next_state, reward, done, truncated, _ = env.step(action)

    # Log the raw reward, but learn from the clipped one. The clip must happen
    # BEFORE the transition is stored; clipping afterwards never reaches the
    # replay buffer and therefore has no effect on training.
    episode_reward += reward
    clipped_reward = float(np.clip(reward, -1, 1))
    buffer.append((state, action, clipped_reward, next_state, done))
    state = next_state

    # Learn only once the buffer can supply a full minibatch.
    if len(buffer) >= batch_size:
      batch = random.sample(buffer, batch_size)
      states, actions, rewards, next_states, dones = zip(*batch)

      # Pixel observations are scaled to [0, 1] by dividing by 255.
      s = torch.tensor(np.array(states), dtype=torch.float32, device=device) / 255.0
      a = torch.tensor(actions, dtype=torch.int64, device=device).unsqueeze(1)
      r = torch.tensor(rewards, dtype=torch.float32, device=device).unsqueeze(1)
      s2 = torch.tensor(np.array(next_states), dtype=torch.float32, device=device) / 255.0
      d = torch.tensor(dones, dtype=torch.float32, device=device).unsqueeze(1)

      # Q-value of the action that was actually taken.
      q_values = q_net(s).gather(1, a)

      with torch.no_grad():
        # Double DQN: the online network chooses the next action, the target
        # network evaluates it.
        next_action = q_net(s2).max(1, keepdim=True)[1]
        next_q_values = target_net(s2).gather(1, next_action)
        # (1 - d) removes the bootstrap term for terminal transitions.
        target_q_values = r + gamma * next_q_values * (1 - d)

      loss = loss_fn(q_values, target_q_values)
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

    # `done` already ends the loop; a truncated episode has to be left explicitly.
    if done or truncated:
      break

  # Decay exploration once per episode, never below epsilon_min.
  epsilon = max(epsilon_min, epsilon * epsilon_decay)

  # Periodically sync the target network with the online network.
  if ep % target_update_freq == 0:
    target_net.load_state_dict(q_net.state_dict())

  all_rewards.append(episode_reward)

  if ep % 50 == 0:
    avg_reward = np.mean(all_rewards[-50:])
    print(f"Episode {ep}, Reward: {episode_reward}, Avg(50): {avg_reward:.2f}, Epsilon: {epsilon:.3f}")



In [None]:
import time
import matplotlib.pyplot as plt
from IPython.display import clear_output

In [None]:
def plot_frames(frame):
  """Show one rendered frame inline, replacing the previously displayed output.

  `frame` may be a single image array, or a list/tuple/4-D batch of frames, in
  which case only the first frame is shown.
  """
  # Only index into `frame` when it really is a collection of frames. Indexing
  # a single (H, W, C) image with [0] would leave just its first pixel row.
  if isinstance(frame, (list, tuple)) or np.ndim(frame) == 4:
    frame = frame[0]
  # Clear the current figure so repeated calls do not pile images onto the same axes.
  plt.clf()
  plt.imshow(frame)
  plt.axis("off")
  clear_output(wait=True)
  display(plt.gcf())

In [None]:
# Start a fresh episode for the rollout below.
done = False
state, _ = env.reset()

In [None]:
# Roll out one episode with the greedy policy, rendering every step.
while not done:
  frame = env.render()
  plot_frames(frame)
  time.sleep(0.01)
  with torch.no_grad():
    # Convert through np.array first, the same way the training loop does,
    # so the stacked observation becomes one contiguous array.
    state_tensor = torch.tensor(np.array(state), dtype=torch.float32, device=device).unsqueeze(0) / 255.0
    q_values = q_net(state_tensor)
    action = torch.argmax(q_values, dim=1).item()
  next_state, reward, terminated, truncated, _ = env.step(action)
  # Stop on truncation as well as termination; otherwise the loop would keep
  # stepping an environment that has already ended the episode.
  done = terminated or truncated
  state = next_state
env.close()