In [None]:
!pip install gymnasium
# !pip install torchviz
!pip install pygame

In [None]:
# %matplotlib inline
# %load_ext tensorboard
# %tensorboard --logdir logs

In [None]:
from itertools import count
import random
import gymnasium as gym
import matplotlib
import matplotlib.pyplot as plt
import torch
from torchvision import transforms
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from tqdm.notebook import tqdm
from torch.utils.tensorboard import SummaryWriter
import math
from collections import deque, namedtuple
import random
from torchvision.models import resnet18
import numpy as np

In [None]:
def render_environment(env):
    """Show the environment's current rendered frame as an axis-free 6x4 figure."""
    fig, ax = plt.subplots(figsize=(6, 4))
    frame = env.render()
    ax.imshow(frame)
    ax.axis('off')
    plt.show()

In [None]:
# CartPole rendered to RGB arrays so that frames can be fed to the network.
env = gym.make("CartPole-v1", render_mode="rgb_array")

# Module-level TensorBoard writer rooted at "logs".
writer = SummaryWriter("logs")

# Device preference: Apple MPS first, then CUDA, otherwise CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Only pull in IPython's display helpers when running with an inline backend.
is_ipython = "inline" in matplotlib.get_backend()
if is_ipython:
    from IPython import display

# Interactive mode on.
plt.ion()

In [None]:
# HYPER-PARAMETERS (module-level defaults; the cell that launches the run
# re-assigns them before calling single_run)
# BATCH_SIZE is the number of transitions sampled from the replay buffer per optimization step
# GAMMA is the discount factor applied to the bootstrapped next-state value
# EPS_START is the starting value of epsilon
# EPS_END is the final value of epsilon
# EPS_DECAY controls the rate of exponential decay of epsilon, higher means a slower decay
# TAU is the soft-update rate used to blend policy-network weights into the target network
# LR is the learning rate of the ``AdamW`` optimizer
# REPLAY_BUFFER_SIZE is the maximum number of transitions the replay buffer retains

BATCH_SIZE = 512
GAMMA = 0.4
EPS_START = 0.9
EPS_END = 0.05
EPS_DECAY = 1000
TAU = 0.005
LR = 1e-4
REPLAY_BUFFER_SIZE = 10000

In [None]:
# One environment step as stored in replay memory.
Transition = namedtuple("Transition", ("state", "action", "next_state", "reward"))


class ReplayBuffer:
    """Bounded FIFO store of ``Transition`` records.

    Once ``capacity`` records are held, pushing another one discards the oldest.
    """

    def __init__(self, capacity):
        self.memory = deque(maxlen=capacity)

    def __len__(self):
        return len(self.memory)

    def push(self, *args):
        """Build a ``Transition`` from the positional fields and store it."""
        record = Transition(*args)
        self.memory.append(record)

    def sample(self, batch_size):
        """Return ``batch_size`` distinct stored transitions picked at random."""
        return random.sample(self.memory, batch_size)

In [None]:
# resnet_model = resnet18(pretrained=True)

In [None]:
# class ResNetFeatureExtractor(nn.Module):
#     def __init__(self):
#         super(ResNetFeatureExtractor, self).__init__()
#         self.features = nn.Sequential(*list(resnet_model.children())[:-1])
#         for param in self.features.parameters():
#             param.requires_grad = False
#     def forward(self, x):
#         x = self.features(x)
#         return x

In [None]:
class DQN(nn.Module):
    """Convolutional Q-network: image batch in, one Q-value per action out.

    Three 3x3 conv + batch-norm + ReLU stages (16, 32, 64 channels) are followed
    by adaptive average pooling to 4x4 and a three-layer MLP head.

    Args:
        n_observations: Unused by this image-based network; kept so the
            constructor signature matches existing callers.
        n_actions: Number of discrete actions, i.e. the width of the output.
    """

    def __init__(self, n_observations, n_actions):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1)

        self.bn1 = nn.BatchNorm2d(16)
        self.bn2 = nn.BatchNorm2d(32)
        self.bn3 = nn.BatchNorm2d(64)

        self.relu = nn.ReLU(inplace=True)

        # Adaptive pooling fixes the spatial size fed to the head, so the
        # network accepts any input resolution.
        pooled_side = 4
        self.avgpool = nn.AdaptiveAvgPool2d((pooled_side, pooled_side))

        # 64 channels * 4 * 4 pooled positions = 1024 flattened features.
        self.layer1 = nn.Linear(64 * pooled_side * pooled_side, 128)
        self.layer2 = nn.Linear(128, 64)
        self.layer3 = nn.Linear(64, n_actions)

    def forward(self, x):
        """Return Q-values of shape (batch, n_actions) for a 3-channel image batch.

        The input is moved to the device that holds this module's weights, so
        the forward pass does not depend on any module-level ``device`` global.
        """
        x = x.to(self.conv1.weight.device)

        x = self.relu(self.bn1(self.conv1(x)))
        x = self.relu(self.bn2(self.conv2(x)))
        x = self.relu(self.bn3(self.conv3(x)))

        x = self.avgpool(x)
        x = F.relu(self.layer1(x.flatten(1)))
        x = F.relu(self.layer2(x))
        return self.layer3(x)

In [None]:
# Reset once with a fixed seed, then read the problem dimensions off the environment.
state, info = env.reset(seed=42)
n_observations = len(state)
n_actions = env.action_space.n

# Show the initial frame.
render_environment(env)

In [None]:
# Frame pipeline applied to every rendered image before it reaches the network.
# NOTE(review): the 448 crop presumably relates to the render size of the
# environment; torchvision pads with zeros on any edge shorter than the crop — confirm.
preprocess = transforms.Compose([
  transforms.ToPILImage(),  # raw rendered frame -> PIL image
  transforms.CenterCrop(448),  # centered 448x448 window
  transforms.Resize((224, 224)),  # downscale to the 224x224 network input size
  transforms.ToTensor(),  # PIL image -> float tensor, channels first
])

In [None]:
def select_action(state, policy_net):
    """Choose an action epsilon-greedily and advance the global step counter.

    Epsilon decays exponentially from the module-level ``EPS_START`` towards
    ``EPS_END`` as ``steps_done`` grows, at a rate set by ``EPS_DECAY``.

    Returns:
        A ``(1, 1)`` long tensor holding the chosen action index.
    """
    global steps_done
    draw = random.random()
    decay = math.exp(-1.0 * steps_done / EPS_DECAY)
    epsilon = EPS_END + (EPS_START - EPS_END) * decay
    steps_done += 1

    if draw <= epsilon:
        # Explore: uniformly random action from the environment's action space.
        random_action = env.action_space.sample()
        return torch.tensor([[random_action]], device=device, dtype=torch.long)

    # Exploit: index of the largest Q-value in the (single) row.
    with torch.no_grad():
        q_row = policy_net(state)
        return q_row.max(1).indices.view(1, 1)

In [None]:
# Global count of actions selected so far; select_action reads and increments
# it to drive the epsilon decay schedule.
steps_done = 0

In [None]:
def _render_state(env):
    """Render the current frame, preprocess it, and return it with a leading batch dim on ``device``."""
    frame = preprocess(env.render())
    # `preprocess` already yields a tensor; `.to` avoids the copy-construct
    # warning that `torch.tensor(tensor)` emits.
    return frame.to(device=device, dtype=torch.float32).unsqueeze(0)


def _optimize_model(policy_net, target_net, optimizer, memory, batch_size, gamma):
    """Take one gradient step on a sampled minibatch and return the loss as a float."""
    transitions = memory.sample(batch_size)
    # Transpose a list of Transitions into one Transition of tuples.
    batch = Transition(*zip(*transitions))

    non_final_mask = torch.tensor(
        [s is not None for s in batch.next_state], device=device, dtype=torch.bool
    )
    state_batch = torch.cat(batch.state).to(device)
    action_batch = torch.cat(batch.action).to(device)
    reward_batch = torch.cat(batch.reward).to(device)

    # Q(s, a) for the actions that were actually taken.
    state_action_values = policy_net(state_batch).gather(1, action_batch)

    # max_a' Q_target(s', a'); stays 0 for terminal transitions.
    next_state_values = torch.zeros(batch_size, device=device)
    if non_final_mask.any():  # torch.cat raises on an empty list
        non_final_next_states = torch.cat(
            [s for s in batch.next_state if s is not None]
        ).to(device)
        with torch.no_grad():
            next_state_values[non_final_mask] = (
                target_net(non_final_next_states).max(1).values
            )

    expected_state_action_values = (next_state_values * gamma) + reward_batch

    criterion = nn.SmoothL1Loss()
    loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))

    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_value_(policy_net.parameters(), 100)
    optimizer.step()
    return loss.item()


def _soft_update(target_net, policy_net, tau):
    """Blend policy weights into the target network: target <- tau*policy + (1-tau)*target."""
    target_state = target_net.state_dict()
    policy_state = policy_net.state_dict()
    for key in policy_state:
        target_state[key] = policy_state[key] * tau + target_state[key] * (1 - tau)
    target_net.load_state_dict(target_state)


def single_run(BATCH_SIZE, GAMMA, EPS_START, EPS_END, EPS_DECAY, TAU, LR, REPLAY_BUFFER_SIZE, EXP_NAME):
    """Train an image-based DQN on the module-level ``env`` and save its weights.

    Logs per-episode reward, mean Q-value and loss to TensorBoard under a
    directory derived from ``EXP_NAME`` and the hyper-parameters, records a
    final hparams summary, and writes the policy weights to
    ``dqn_cnn_model.pth``. Also resets the global ``steps_done`` counter so
    every run starts its epsilon schedule from the beginning.

    Args:
        BATCH_SIZE: Transitions sampled per optimization step.
        GAMMA: Discount factor for the bootstrapped target.
        EPS_START, EPS_END, EPS_DECAY: Recorded in the log directory name and
            hparams only. NOTE: ``select_action`` reads the module-level
            constants of the same names, not these arguments.
        TAU: Soft-update rate for the target network.
        LR: AdamW learning rate.
        REPLAY_BUFFER_SIZE: Replay buffer capacity.
        EXP_NAME: Experiment name used as the first log sub-directory.

    Returns:
        None.
    """
    # The original assigned a *local* steps_done here, which left the counter
    # used by select_action untouched; declare it global so the reset is real.
    global steps_done

    writer = SummaryWriter(f'logs/{EXP_NAME}/CNN_BASED/RESNET/BATCH_SIZE-{BATCH_SIZE}-GAMMA-{GAMMA}-EPS_START-{EPS_START}-EPS_END-{EPS_END}-EPS_DECAY-{EPS_DECAY}-TAU-{TAU}-LR-{LR}-REPLAY_BUFFEER_SIZE-{REPLAY_BUFFER_SIZE}')
    policy_net = DQN(n_observations, n_actions).to(device)
    target_net = DQN(n_observations, n_actions).to(device)
    target_net.load_state_dict(policy_net.state_dict())
    memory = ReplayBuffer(REPLAY_BUFFER_SIZE)
    optimizer = optim.AdamW(policy_net.parameters(), lr=LR, amsgrad=True)
    steps_done = 0

    # Run far fewer episodes when no accelerator is available.
    has_accelerator = torch.cuda.is_available() or torch.backends.mps.is_available()
    num_episodes = 500 if has_accelerator else 50

    # Defaults so the final summary is well-defined even if no optimization
    # step (or no episode) ever ran.
    last_loss = float("nan")
    episode_reward = 0.0
    average_q_value = float("nan")

    try:
        for i_episode in tqdm(range(num_episodes)):
            torch.cuda.empty_cache()
            env.reset()
            state = _render_state(env)
            episode_reward = 0.0
            epoch_q_value = 0.0

            for t in count():
                action = select_action(state, policy_net)
                _, reward, terminated, truncated, _ = env.step(action.item())
                reward = torch.tensor([reward], device=device, dtype=torch.float32)
                episode_reward += reward.item()

                # Q-value of the chosen action, for logging only: no graph needed.
                with torch.no_grad():
                    epoch_q_value += policy_net(state)[0, action.item()].item()

                done = terminated or truncated

                # A terminal state has no successor; a truncated one still does.
                next_state = None if terminated else _render_state(env)

                # Keep replay data on the CPU so the buffer does not use accelerator memory.
                memory.push(
                    state.detach().cpu(),
                    action.detach().cpu(),
                    next_state.detach().cpu() if next_state is not None else None,
                    reward.detach().cpu(),
                )
                state = next_state

                # One optimization step on the policy network once enough data exists.
                if len(memory) >= BATCH_SIZE:
                    last_loss = _optimize_model(
                        policy_net, target_net, optimizer, memory, BATCH_SIZE, GAMMA
                    )
                    writer.add_scalar("Loss", last_loss, i_episode)

                _soft_update(target_net, policy_net, TAU)

                if done:
                    average_q_value = epoch_q_value / (t + 1)
                    writer.add_scalar("Average Reward", episode_reward, i_episode)
                    writer.add_scalar("Average Q-value", average_q_value, i_episode)
                    break

        # Summary metrics come from the final episode; the reward is the
        # episode return, not the last single-step reward.
        writer.add_hparams(
            {
                'LR': LR,
                'BATCH_SIZE': BATCH_SIZE,
                'GAMMA': GAMMA,
                'EPS_START': EPS_START,
                'EPS_END': EPS_END,
                'EPS_DECAY': EPS_DECAY,
                'TAU': TAU,
                'REPLAY_BUFFER_SIZE': REPLAY_BUFFER_SIZE,
            },
            {
                'hparam/loss': last_loss,
                'hparam/reward': episode_reward,
                'average_q_value': average_q_value,
            },
        )
        torch.save(policy_net.state_dict(), 'dqn_cnn_model.pth')
    finally:
        # Flush pending events and release the file handle even if training aborts.
        writer.close()

In [None]:
# Re-select the device right before the run, choosing only between CUDA and CPU.
# This replaces the earlier module-level `device`, which also considered MPS.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Hyper-parameters for this experiment; these re-assign the module-level defaults.
# NOTE(review): assuming each stored state is a 1x3x224x224 float32 frame
# (~0.6 MB), a 100000-entry buffer could need tens of GB of host RAM if it
# ever fills — confirm this capacity is intended.
REPLAY_BUFFER_SIZE = 100000
BATCH_SIZE = 128
GAMMA = 0.4
EPS_START = 0.9
EPS_END = 0.05
EPS_DECAY = 1000
TAU = 0.005
LR = 1e-4
# Launch one training run; logs go under logs/REPLAY_BUFFER/...
single_run(BATCH_SIZE, GAMMA, EPS_START, EPS_END, EPS_DECAY, TAU, LR,REPLAY_BUFFER_SIZE,'REPLAY_BUFFER')

In [None]:
# from tensorboard import notebook
# notebook.list()
# notebook.display(port=6006, height=500)

In [None]:
# policy_net = DQN(n_observations, n_actions).to(device)
# policy_net.load_state_dict(torch.load('dqn_cnn_model.pth'))
# policy_net.eval() 

In [None]:
# !pip install imageio
# import imageio
# from IPython.display import Video

In [None]:
# state,_ = env.reset()
# image = env.render()
# image = preprocess(image)
# state = torch.tensor(image, dtype=torch.float32, device=device).unsqueeze(0)

# done = False
# frames = []
# while not done:
#   frames.append(env.render())
#   with torch.no_grad():
#         image = env.render()
#         image = preprocess(image)
#         state = torch.tensor(image, dtype=torch.float32, device=device).unsqueeze(0)
#         action = policy_net(state).argmax(dim=1).item()
#   state, reward, terminated, truncated,_ = env.step(action)
#   done = terminated or truncated
# imageio.mimsave('output.mp4', frames)
# env.close()

# Video("output.mp4", embed=True)