In [None]:
!pip install gym[atari,accept-rom-license] atari-py
!pip install moviepy
# If you get the error "cannot import name 'NotRequired' from 'typing_extensions'", run these:
# !pip uninstall tensorflow-gpu
#!pip install --upgrade typing-extensions

In [None]:
%matplotlib inline

In [None]:
# Hyperparameters
BATCH_SIZE = 128    # transitions sampled from the replay buffer per optimisation step
GAMMA = 0.99        # discount factor applied to the bootstrapped next-state value
TAU = 0.005         # soft-update weight of the policy net when blending into the target net
LR = 1e-4           # learning rate of the AdamW optimizer
EPS_START = 0.9     # initial epsilon for epsilon-greedy exploration
EPS_END = 0.05      # asymptotic epsilon once exploration has decayed
EPS_DECAY = 1000    # exponential decay time constant, in steps (larger = slower decay)

In [None]:
import gym
import math
import random
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple, deque
from itertools import count

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

# Pretrained model utils
from play import *

# Note: Depending on what we need in the assignment, we can use various gym wrappers, Monitor wrapper, video recording wrapper, etc
from gym.wrappers import AtariPreprocessing, FrameStack

# AtariPreprocessing does its own frame skipping, so it must wrap a NoFrameskip environment instead of the usual Breakout-v4
# env_ = 
# env = AtariPreprocessing(env_, scale_obs=False)  # auto skips 4 frames, converts grayscale

num_envs = 6  # CHANGE THIS FOR NUMBER OF ENVIRONMENTS ------------------


def _make_breakout_env():
    """Build one preprocessed Breakout environment for the vector env.

    AtariPreprocessing wraps the NoFrameskip variant (frame skipping and
    grayscale conversion); scale_obs=False keeps raw 0-255 pixel values.
    FrameStack keeps the last 4 frames so one observation carries motion
    information (velocity, acceleration) that a single frame cannot.
    """
    base_env = gym.make("BreakoutNoFrameskip-v4", render_mode='rgb_array')
    preprocessed = AtariPreprocessing(base_env, scale_obs=False)
    return FrameStack(preprocessed, num_stack=4)


# num_envs copies stepped asynchronously; every copy comes from the same factory.
# Optional wrappers (episode statistics, RecordVideo with an episode_trigger)
# can be layered around `env` here when needed.
env = gym.vector.AsyncVectorEnv([_make_breakout_env] * num_envs)

# The inline backend means we are inside Jupyter/IPython: enable cell-output refresh.
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    from IPython import display

plt.ion()

# Train on the GPU whenever one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# One environment transition as stored in the replay buffer.
Interaction = namedtuple('Interaction', ['state', 'action', 'next_state', 'reward'])


class ReplayMemory:
    """Bounded FIFO buffer of `Interaction` tuples used for experience replay."""

    def __init__(self, capacity):
        # A deque with maxlen silently evicts the oldest transition once full.
        self.memory = deque(maxlen=capacity)

    def push(self, state, action, next_state, reward):
        """Append a single transition to the buffer."""
        transition = Interaction(state=state, action=action, next_state=next_state, reward=reward)
        self.memory.append(transition)

    def sample(self, batch_size):
        """Return `batch_size` distinct stored transitions, drawn uniformly at random."""
        return random.sample(self.memory, batch_size)

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.memory)

In [None]:
# Number of discrete actions of one sub-environment (all copies share the space).
n_actions = env.action_space[0].n
# Reset once to inspect the observation layout; `states` holds one stacked obs per env.
states, infos = env.reset()
# NOTE(review): with FrameStack this is the stack depth, not a flat feature count.
n_observations = len(states[0])


def getEpsilon():
    """Current epsilon for epsilon-greedy exploration.

    Decays exponentially from EPS_START toward EPS_END as the global
    `steps_done` counter grows; EPS_DECAY is the decay time constant in steps.
    """
    return EPS_END + (EPS_START - EPS_END) * \
        math.exp(-1. * steps_done / EPS_DECAY)


def select_action(state):
    """Choose one action per environment with an epsilon-greedy policy.

    Args:
        state: batch of observations, one row per environment (a tensor or
            array-like; presumably (num_envs, 4, 84, 84) — TODO confirm).

    Returns:
        LongTensor of shape (batch, 1) on `device`, one action per row.

    Side effects:
        Increments the global `steps_done` counter (drives epsilon decay).
    """
    global steps_done
    # Use the observation passed in by the caller. Reading the global `states`
    # instead (as before) fed the network a stale reset observation.
    state = torch.as_tensor(state, dtype=torch.float32, device=device)
    batch_size = state.shape[0]
    steps_done += 1
    if random.random() > getEpsilon():
        with torch.no_grad():
            # Greedy: index of the highest Q-value for each row.
            return policy_net(state).max(1)[1].unsqueeze(1)
    # Explore: the vector env's action space samples one action per sub-env.
    random_actions = torch.tensor(env.action_space.sample(), device=device, dtype=torch.long)
    return random_actions.unsqueeze(1).expand(batch_size, 1)

def plot_durations(show_result=False):
    """Plot episode durations in figure 1, plus a 100-episode rolling mean.

    show_result=False: live-training mode. The figure is cleared first and, under
    IPython, the previous cell output is replaced by the refreshed plot.
    show_result=True: final-result mode. The plot is drawn without clearing and
    displayed.
    """
    plt.figure(1)
    lengths = torch.tensor(episode_durations, dtype=torch.float)
    if not show_result:
        plt.clf()
    plt.title('Result' if show_result else 'Training...')
    plt.xlabel('Episode')
    plt.ylabel('Duration')
    plt.plot(lengths.numpy())
    # The rolling mean only exists once at least 100 episodes are recorded;
    # left-pad it with 99 zeros so it lines up with the episode axis.
    if len(lengths) >= 100:
        window_means = lengths.unfold(0, 100, 1).mean(1).view(-1)
        plt.plot(torch.cat((torch.zeros(99), window_means)).numpy())

    plt.pause(0.001)  # give the GUI event loop a moment so the plot refreshes
    if is_ipython:
        display.display(plt.gcf())
        if not show_result:
            display.clear_output(wait=True)
            
def _plot_episode_metric(values, ylabel, title, show_result, save_path, max_episodes):
    """Draw one value per episode plus its 100-episode rolling mean in figure 1.

    Args:
        values: per-episode numbers to plot.
        ylabel: y-axis label.
        title: figure title.
        show_result: when True and running under IPython, display the finished figure.
        save_path: if not None, also save the figure to this path.
        max_episodes: only the first ``max_episodes`` values are plotted.
    """
    fig = plt.figure(1)
    # Always start from a blank canvas. Figure 1 is also used by plot_durations,
    # whose final plot would otherwise be drawn over.
    fig.clf()
    series = torch.tensor(list(values), dtype=torch.float)[:max_episodes]
    plt.title(title)
    plt.xlabel('Episode')
    plt.ylabel(ylabel)
    plt.plot(series.numpy())
    # Only plot a rolling mean once there are enough episodes to compute one.
    if len(series) >= 100:
        rolling = series.unfold(0, 100, 1).mean(1).view(-1)
        # Left-pad with 99 zeros so point k is the mean of episodes k-99..k.
        plt.plot(torch.cat((torch.zeros(99), rolling)).numpy())
    if save_path is not None:
        fig.savefig(save_path)
    # Display BEFORE closing. After plt.close(), plt.gcf() would create and show
    # a new, empty figure instead of this one.
    if is_ipython and show_result:
        display.display(fig)
    plt.close(fig)


def plot_rewards(show_result=False, clipAt=10000, saveFig=False, max_episodes=450,
                 title='SpaceInvaders Within-Game RL'):
    """Plot per-episode rewards from the global ``rewards_episodes``.

    Args:
        show_result: display the figure in the notebook output (IPython only).
        clipAt: rewards at or above this value are drawn as ``clipAt``.
        saveFig: also save the figure to ``{EXP_NAME}/rewards.jpg``.
        max_episodes: plot at most this many (leading) episodes.
        title: figure title. NOTE(review): the default says SpaceInvaders but the
            environment built in this notebook is Breakout; confirm the label.
    """
    clipped = [x if x < clipAt else clipAt for x in rewards_episodes]
    save_path = f"{EXP_NAME}/rewards.jpg" if saveFig else None
    _plot_episode_metric(clipped, 'Rewards', title, show_result, save_path, max_episodes)


def plot_losses(show_result=False, saveFig=False, max_episodes=450,
                title='SpaceInvaders Within-Game RL'):
    """Plot per-episode losses from the global ``episode_losses``.

    Args:
        show_result: display the figure in the notebook output (IPython only).
        saveFig: also save the figure to ``{EXP_NAME}/losses.jpg``.
        max_episodes: plot at most this many (leading) episodes.
        title: figure title (see the NOTE in plot_rewards about the default).
    """
    save_path = f"{EXP_NAME}/losses.jpg" if saveFig else None
    _plot_episode_metric(list(episode_losses), 'Loss', title, show_result, save_path, max_episodes)

In [None]:
def optimize_model():
    """Run one DQN gradient step on a random minibatch from the global `memory`.

    Does nothing until the buffer holds at least BATCH_SIZE transitions.
    Transitions whose next_state is None are terminal and bootstrap from 0.
    Uses the globals policy_net, target_net, optimizer, GAMMA and device.
    """
    if len(memory) < BATCH_SIZE:
        return
    transitions = memory.sample(BATCH_SIZE)
    # Transpose a list of Interactions into one Interaction of per-field tuples.
    batch = Interaction(*zip(*transitions))

    non_final_mask = torch.tensor([s is not None for s in batch.next_state],
                                  device=device, dtype=torch.bool)
    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action).unsqueeze(1)  # (B, 1) index for gather
    # Keep rewards 1-D (B,) so they add element-wise to next_state_values (B,).
    # A (B, 1) reward here broadcasts the TD target to (B, B) and corrupts the loss.
    reward_batch = torch.cat(batch.reward)

    # Q(s_t, a_t) for the actions actually taken: (B, 1)
    state_action_values = policy_net(state_batch).gather(1, action_batch)

    # max_a Q_target(s_{t+1}, a); stays 0 for terminal transitions.
    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    if non_final_mask.any():  # torch.cat([]) would raise if every sample is terminal
        non_final_next_states = torch.cat([s for s in batch.next_state if s is not None])
        with torch.no_grad():
            next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0]
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch  # (B,)

    # Huber loss between (B, 1) predictions and (B, 1) targets.
    criterion = nn.SmoothL1Loss()
    loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))

    optimizer.zero_grad()
    loss.backward()
    # In-place gradient clipping
    torch.nn.utils.clip_grad_value_(policy_net.parameters(), 100)
    optimizer.step()

In [None]:
def _load_checkpoint(fpath, device="cpu"):
    fpath = Path(fpath)
    with fpath.open("rb") as file:
        with GzipFile(fileobj=file) as inflated:
            return torch.load(inflated, map_location=device)

path = "models/DQN_modern/Breakout/2/model_50000000.gz"
# Checkpoints whose path contains "C51_" hold distributional (C51) networks.
pretrained_model = AtariNet(
    env.action_space[0].n,
    distributional=("C51_" in path),
)
ckpt = _load_checkpoint(path)
pretrained_model.load_state_dict(ckpt["estimator_state"])

In [None]:
class DQN_Atari(nn.Module):
    """ Estimator used by DQN-style algorithms for ATARI games.
        Works with DQN, M-DQN and C51.

    A frozen pretrained convolutional trunk feeds a trainable two-layer MLP head.

    Args:
        action_no: number of discrete actions.
        distributional: if True, the head predicts a categorical distribution over
            51 atoms on [-10, 10] per action, and forward() returns its expectation.
        features: convolutional feature extractor to reuse. Defaults to the trunk
            of the global ``pretrained_model`` (loaded in an earlier cell). Its
            output must flatten to 64 * 7 * 7 values per sample to match the head
            (presumably from 4x84x84 inputs; TODO confirm for other trunks). The
            module is shared, not copied, and its parameters are frozen in place.
    """
    def __init__(self, action_no, distributional=False, features=None):
        super().__init__()

        self.action_no = out_size = action_no
        self.distributional = distributional

        # configure the support if distributional
        if distributional:
            support = torch.linspace(-10, 10, 51)
            self.__support = nn.Parameter(support, requires_grad=False)
            out_size = action_no * len(self.__support)

        # Reuse (and freeze) a pretrained feature extractor; only the head is trained.
        if features is None:
            features = pretrained_model._AtariNet__features
        self.__features = features
        for param in self.__features.parameters():
            param.requires_grad = False

        self.__head = nn.Sequential(
            nn.Linear(64 * 7 * 7, 512), nn.ReLU(inplace=True), nn.Linear(512, out_size),
        )

    def forward(self, x):
        """Return Q-values of shape (batch, action_no) for a batch of raw-pixel states."""
        # Quantise to the 0-255 byte range, then rescale to [0, 1].
        # (The former uint8 assert here could never fail after the cast.)
        x = x.clamp(0, 255).to(torch.uint8).float().div(255)

        x = self.__features(x)
        qs = self.__head(x.view(x.size(0), -1))

        if self.distributional:
            # Expected value of the per-action categorical distribution over the support.
            logits = qs.view(qs.shape[0], self.action_no, len(self.__support))
            qs_probs = torch.softmax(logits, dim=2)
            return torch.mul(qs_probs, self.__support.expand_as(qs_probs)).sum(2)
        return qs

In [None]:
# Online network (trained) and target network (slowly tracking copy), both on `device`.
policy_net, target_net = (DQN_Atari(n_actions).to(device) for _ in range(2))
target_net.load_state_dict(policy_net.state_dict())

optimizer = optim.AdamW(policy_net.parameters(), lr=LR, amsgrad=True)

# Global step counter (drives epsilon decay) and per-episode logs.
steps_done = 0
rewards_episodes, episode_durations = [], []

# Incremented before each episode by the training loop, so the first episode is 0.
i_episode = -1

In [None]:
import sys
if torch.cuda.is_available():
    total_episodes = 10000  # Default was 600
else:
    total_episodes = 50

MEMORY_SIZE = 10000
memory = ReplayMemory(MEMORY_SIZE)

# NOTE(review): `i_episode` starts at -1 (earlier cell) and is incremented before
# each episode, so this runs episode indices 0..total_episodes inclusive,
# i.e. total_episodes + 1 episodes.
while i_episode < total_episodes:
    i_episode += 1
    # Initialize the environment and get its state
    states, infos = env.reset()
    print(f"In {i_episode}th episode--------------> ")
    state = torch.tensor(states, dtype=torch.float32, device=device)
    episode_reward = torch.zeros(num_envs, device=device)
    dones = [False] * num_envs
    for t in count():
        # One action per sub-environment as a (num_envs, 1) LongTensor on `device`.
        action = select_action(state)
        actions = action.view(-1, 1)

        # The vector env splits the batch into one action per sub-env. Hand it a
        # host-side integer array of shape (num_envs,) rather than a (num_envs, 1)
        # torch tensor that may live on the GPU.
        obs, rewards, terminated, truncated, infos = env.step(actions.squeeze(1).cpu().numpy())
        obs = torch.tensor(obs, dtype=torch.float32, device=device)
        rewards = torch.tensor(rewards, dtype=torch.float32, device=device)
        old_dones = dones
        # Sticky per-env flag: once an env terminates or truncates, it stays done.
        dones = [bool(term or trunc or prev)
                 for term, trunc, prev in zip(terminated, truncated, old_dones)]

        episode_reward += rewards
        # Store transitions only for envs that were still running before this step.
        for i in range(num_envs):
            if not old_dones[i]:
                next_state = None if dones[i] else obs[i].unsqueeze(0)
                memory.push(state[i].unsqueeze(0), actions[i], next_state, rewards[i].view(1))

        # Move to the next state
        state = obs

        # Perform one step of the optimization (on the policy network)
        optimize_model()

        # Soft update of the target network's weights:
        # target <- TAU * policy + (1 - TAU) * target
        target_net_state_dict = target_net.state_dict()
        policy_net_state_dict = policy_net.state_dict()
        for key, policy_value in policy_net_state_dict.items():
            target_value = target_net_state_dict[key]
            # Skip tensors both nets share (e.g. a feature extractor used by both).
            # Blending such a tensor with itself is wasted work and slowly
            # accumulates float rounding error into the shared weights.
            if policy_value.data_ptr() == target_value.data_ptr():
                continue
            target_net_state_dict[key] = policy_value * TAU + target_value * (1 - TAU)
        target_net.load_state_dict(target_net_state_dict)

        # Record statistics for envs whose game ended on exactly this step.
        for i, done in enumerate(dones):
            if done and not old_dones[i]:
                episode_durations.append(t + 1)
                rewards_episodes.append(episode_reward[i].item())
                print(f"Episode {i_episode}/{total_episodes}, Env {i+1}/{num_envs}, Duration {t+1}, "
                          f"Reward {episode_reward[i].item():.2f}")
        if all(dones):
            # Save the model weights
            if i_episode % 100 == 0:
                torch.save(target_net.state_dict(), "target_net.pth")
                torch.save(policy_net.state_dict(), "policy_net.pth")
            if i_episode < 20 or i_episode % 100 == 0:
                plot_rewards(show_result=True)
            break

In [None]:
# Final summary figures once training has finished.
for _plot_fn in (plot_durations, plot_rewards):
    _plot_fn(show_result=True)