<a href="https://colab.research.google.com/github/S1R3S1D/SAIDL_Spring_2022_Assignment/blob/main/Reinforcement_Learning/DQN_with_SimCLR.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#Runtime Dependencies
%pip install -U gym>=0.21.0
%pip install -U gym[atari,accept-rom-license]

In [15]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import numpy as np
import gym
import os
import PIL
from PIL import Image
from collections import deque
from tqdm import tqdm
import torch.nn.functional as F
import copy
import random
import matplotlib.pyplot as plt

In [3]:
#Projection Head

class projection_head(nn.Module):
    """Two-layer MLP mapping encoder embeddings into the contrastive-loss space.

    Layout: Linear(embed_dim, 2048) -> ReLU -> Linear(2048, output_dim).

    Args:
        embed_dim: Width of the incoming embedding.
        output_dim: Width of the projected vector.
    """

    def __init__(self, embed_dim=1024, output_dim=128):
        super().__init__()
        self.embed_dim, self.output_dim = embed_dim, output_dim

        hidden = 2048
        self.projection = nn.Sequential(
            nn.Linear(embed_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, output_dim),
        )

    def forward(self, x):
        """Project ``x`` (..., embed_dim) to (..., output_dim)."""
        return self.projection(x)

In [4]:
#Model

class ConvNN(nn.Module):
    """Convolutional encoder mapping a stack of four frames to an embedding.

    Three conv layers (8x8 stride 4; 4x4 stride 2 + BatchNorm; 3x3 stride 1),
    each followed by ReLU, then a (64*7*7) -> 512 -> ``emb_dim`` MLP. The
    64*7*7 flatten size is what an 84x84 input produces after the convs.

    Args:
        emb_dim: Size of the output embedding.
    """

    def __init__(self, emb_dim):
        super().__init__()

        conv_layers = [
            nn.Conv2d(4, 32, 8, 4),
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, 2),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, 1),
            nn.ReLU(),
        ]
        self.conv_net = nn.Sequential(*conv_layers)
        self.flatten = nn.Flatten(start_dim=1)
        self.fc_layer = nn.Sequential(
            nn.Linear(64 * 7 * 7, 512),
            nn.ReLU(),
            nn.Linear(512, emb_dim),
        )

    def forward(self, x):
        """Encode ``x`` of shape (N, 4, H, W) into an (N, emb_dim) tensor."""
        features = self.flatten(self.conv_net(x))
        return self.fc_layer(features)

In [5]:
#Wrappers

class cropped84x84_grayscale(gym.ObservationWrapper):
    """Crop, resize and grayscale raw frames into 84x84 uint8 images.

    Each frame is cropped to the PIL box (left=0, top=35, right=160,
    bottom=190), resized to 84x84 with Lanczos filtering and converted to a
    single grayscale channel. (The crop presumably trims the score bar and
    bottom border of a 210x160 Atari screen — confirm for games other than
    Breakout.)
    """

    def __init__(self, env):
        super(cropped84x84_grayscale, self).__init__(env)
        # Advertise the transformed observation: downstream wrappers such as
        # gym.wrappers.FrameStack derive their own observation_space from this
        # one, which otherwise would still describe the raw RGB screen.
        self.observation_space = gym.spaces.Box(
            low=0, high=255, shape=(84, 84), dtype=np.uint8)

    def observation(self, observation):
        """Return the preprocessed (84, 84) uint8 array for one RGB frame."""
        image = Image.fromarray(observation)
        image = image.crop((0, 35, 160, 190))
        image = image.resize((84, 84), Image.LANCZOS)
        image = image.convert('L')
        return np.asarray(image)

class MaxAndSkipEnv(gym.Wrapper):
    """Repeat each action ``skip`` times and return the max over the last two frames.

    Rewards of the repeated steps are summed, and repetition stops early once
    the wrapped environment reports ``done``. The returned observation is the
    pixel-wise maximum of the two most recent raw observations (presumably to
    suppress Atari sprite flicker).

    Args:
        env: Environment using the 4-tuple ``step`` API.
        skip: Number of times each action is repeated; must be >= 1.

    Raises:
        ValueError: If ``skip`` is smaller than 1.
    """

    def __init__(self, env, skip = 4):
        super(MaxAndSkipEnv, self).__init__(env)
        if skip < 1:
            # With no inner step there is no reward/done/info to return.
            raise ValueError(f"skip must be >= 1, got {skip}")
        # Only the two most recent observations are needed for the max.
        self._obs_buffer = deque(maxlen=2)
        self._skip = skip

    def step(self, action):
        """Step ``skip`` times (or until done); return (max_frame, summed_reward, done, info)."""
        total_reward = 0.0
        for _ in range(self._skip):
            obs, reward, done, info = self.env.step(action)
            self._obs_buffer.append(obs)
            total_reward += reward
            if done:
                break
        max_frame = np.max(np.stack(self._obs_buffer), axis=0)
        return max_frame, total_reward, done, info

    def reset(self, **kwargs):
        """Reset the wrapped env (forwarding keyword arguments) and seed the frame buffer."""
        self._obs_buffer.clear()
        obs = self.env.reset(**kwargs)
        self._obs_buffer.append(obs)
        return obs


In [6]:
#Dataset

class OpenAIGymData(Dataset):
  """Random-policy dataset of preprocessed, 4-frame-stacked observations.

  Construction wraps ``env`` with ``cropped84x84_grayscale``, ``MaxAndSkipEnv``
  and ``gym.wrappers.FrameStack(env, 4)``, steps it with uniformly random
  actions, scales every observation to float64 values in [0, 1], and saves
  them as ``n_files`` shards (``0.npy``, ``1.npy``, ...) inside
  ``data_folder``. Items are read back lazily from those shards.

  Args:
    env: Raw gym environment using the 4-tuple ``step`` API.
    fire: If True, take action 1 after every reset (presumably FIRE, to
      launch the ball in Breakout — confirm for other games).
    n_obs: Number of observations to collect. Only
      ``n_files * (n_obs // n_files)`` are stored; any remainder is dropped.
    data_folder: Directory for the shards; created if it does not exist.
    n_files: Number of shards.

  Raises:
    ValueError: If ``n_obs`` is smaller than ``n_files``.
  """

  def __init__(self, env, fire = False, n_obs = 2560,
               data_folder = os.path.join('Data', 'AtariEnvImageData'),
               n_files = 16):
    per_file = n_obs // n_files
    if per_file < 1:
      raise ValueError(
          f"n_obs ({n_obs}) must be at least n_files ({n_files})")

    self.n_files = n_files
    self.per_file = per_file
    # Report only what is actually stored so __len__ never points past the shards.
    self.n_obs = per_file * n_files
    self.data_folder = data_folder
    os.makedirs(data_folder, exist_ok=True)

    env = cropped84x84_grayscale(env)
    env = MaxAndSkipEnv(env)
    env = gym.wrappers.FrameStack(env, 4)

    self._reset(env, fire)
    for file_idx in range(n_files):
      samples = []
      for _ in range(per_file):
        action = np.random.randint(0, env.action_space.n)
        sample, _, done, _ = env.step(action)
        if done:
          self._reset(env, fire)
        samples.append(np.array(sample).astype(float) / 255.0)
      np.save(self._shard_path(file_idx), np.array(samples))

  @staticmethod
  def _reset(env, fire):
    """Reset ``env`` and, if ``fire``, take action 1 once."""
    env.reset()
    if fire:
      env.step(1)

  def _shard_path(self, file_idx):
    """Return the path of shard number ``file_idx``."""
    return os.path.join(self.data_folder, f'{file_idx}.npy')

  def __len__(self):
    return self.n_obs

  def __getitem__(self, index):
    if index < 0:
      index += self.n_obs
    if not 0 <= index < self.n_obs:
      raise IndexError(f"index {index} out of range for {self.n_obs} samples")

    file_no, in_file_index = divmod(index, self.per_file)
    # Memory-map the shard so only the requested sample is read from disk,
    # instead of loading the whole shard for every single item.
    shard = np.load(self._shard_path(file_no), mmap_mode='r')
    return np.array(shard[in_file_index])

In [7]:
#Data
# Delete output from a previous run: OpenAIGymData calls os.mkdir('Data'),
# which fails if the directory already exists.
!rm -rf Data/
# Raw Breakout environment. NOTE: this global `env` itself stays unwrapped —
# OpenAIGymData builds its crop/max-skip/frame-stack wrappers around it in a
# local variable and does not store them here.
env = gym.make('BreakoutNoFrameskip-v4')
# Collects the default n_obs=2560 random-policy observations and saves them as .npy files.
breakout_dataset = OpenAIGymData(env)


In [8]:
#DataLoader
# drop_last=True: the contrastive loss is instantiated as ntxent(32, ...), so
# every batch is kept at exactly 32 samples; a short final batch is discarded.
breakout_dataloader = DataLoader(breakout_dataset, batch_size=32, drop_last=True)

In [9]:
#Loss Function

class ntxent(nn.Module):
    """NT-Xent (normalized temperature-scaled cross-entropy) loss from SimCLR.

    Given two views of N samples (2N projections), the loss for anchor i with
    positive p(i) is

        -log( exp(sim(i, p(i)) / T) / sum_{k != i} exp(sim(i, k) / T) )

    where sim is cosine similarity. The positive stays in the denominator,
    so every term is non-negative. The result is the mean over all 2N anchors.

    Args:
        batch_size: Expected number of samples per view. Kept for interface
            compatibility; forward() derives the size from its inputs.
        temperature: Softmax temperature T.
    """

    def __init__(self, batch_size, temperature):
        super(ntxent, self).__init__()
        self.batch_size = batch_size
        # Retained so attributes / state_dict keys are unchanged; forward()
        # builds a mask sized to the actual batch instead of relying on it.
        self.register_buffer("neg_eye", (~torch.eye(batch_size * 2, batch_size * 2, dtype=bool)).float())
        self.register_buffer("temperature", torch.tensor(temperature))

    def forward(self, emb_i, emb_j):
        """Return the scalar NT-Xent loss for paired views ``emb_i``/``emb_j`` of shape (N, D)."""
        n = emb_i.shape[0]
        z = torch.cat([F.normalize(emb_i), F.normalize(emb_j)], 0)

        sim_mat = F.cosine_similarity(z.unsqueeze(1), z.unsqueeze(0), dim=2)
        logits = sim_mat / self.temperature

        # Exclude each anchor's similarity with itself from its denominator.
        self_mask = torch.eye(2 * n, dtype=torch.bool, device=logits.device)
        logits = logits.masked_fill(self_mask, float('-inf'))

        # Row i's positive is n columns right (first view) or left (second view).
        positives = torch.cat([torch.diag(logits, n), torch.diag(logits, -n)], 0)

        # logsumexp form of -log(softmax): stable even for small temperatures.
        return (torch.logsumexp(logits, dim=1) - positives).mean()

In [10]:
#Hyperparameters and rest
#
# Contrastive pre-training setup: encoder, projection head, joint optimizer
# and NT-Xent loss.

lr = 0.001
epochs = 2

model = ConvNN(1024)                       # encoder -> 1024-d embedding
projection = projection_head(1024, 128)    # 1024-d embedding -> 128-d projection

# Encoder and projection head are optimised jointly by one Adam instance.
params = [*model.parameters(), *projection.parameters()]
optimizer = torch.optim.Adam(params, lr=lr)

# NT-Xent with batch size 32 (the DataLoader's batch size) and temperature 0.07.
loss_fn = ntxent(32, 0.07)

In [None]:
#Training Model for contrastive learning
#
# SimCLR needs two *different* random views of every sample: encoding the
# same tensor twice makes each positive pair identical and the objective
# trivial. Views are produced by an independent random shift per sample
# (replicate-pad, then crop back to the original size).

def _random_shift(x, pad=4):
  """Return a copy of ``x`` (N, C, H, W) with each sample shifted by up to ``pad`` pixels."""
  n, _, h, w = x.shape
  padded = F.pad(x, (pad, pad, pad, pad), mode='replicate')
  offsets = torch.randint(0, 2 * pad + 1, (n, 2))
  return torch.stack([
      padded[k, :, dy:dy + h, dx:dx + w]
      for k, (dy, dx) in enumerate(offsets.tolist())
  ])

# Train mode for BatchNorm; keep all modules on the encoder's device so the
# loop still works if a later cell has moved the encoder.
model.train()
projection.train()
model_device = next(model.parameters()).device
projection.to(model_device)
loss_fn.to(model_device)

contrastive_losses = []
for epoch in tqdm(range(epochs)):
  for images in breakout_dataloader:
    images = images.float().to(model_device)

    proj1 = projection(model(_random_shift(images)))
    proj2 = projection(model(_random_shift(images)))

    loss = loss_fn(proj1, proj2)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    contrastive_losses.append(loss.item())

In [11]:
#Training DQN with model embeddings

#Deep Q Network Architecture

class DQN(nn.Module):
    """MLP Q-network over 1024-d embeddings.

    Layout: 1024 -> 2048 -> 512 -> ``n_actions``, with ReLU between layers;
    the output holds one Q-value per action.

    Args:
        n_actions: Number of discrete actions.
    """

    def __init__(self, n_actions):
        super().__init__()
        self.n_actions = n_actions
        self.FC = nn.Sequential(
            nn.Linear(1024, 2048), nn.ReLU(),
            nn.Linear(2048, 512), nn.ReLU(),
            nn.Linear(512, n_actions),
        )

    def forward(self, x):
        """Return Q-values (N, n_actions) for embeddings ``x`` of shape (N, 1024)."""
        return self.FC(x)

In [None]:
#Hyperparameters, loss function, optimizer and others for Deep Q learning

lr = 0.001

# The DQN acts on encoder embeddings, so the agent must see the same
# preprocessing the encoder was trained on (crop/resize/grayscale,
# max-and-skip, 4-frame stack). The `env` created for the dataset is the raw
# environment, so build a freshly wrapped one here.
env = gym.wrappers.FrameStack(
    MaxAndSkipEnv(cropped84x84_grayscale(gym.make('BreakoutNoFrameskip-v4'))), 4)

dqn = DQN(env.action_space.n)

# Separate copy used for bootstrap targets; re-synced every `sync_freq` updates.
target_network = copy.deepcopy(dqn)

sync_freq = 5000

loss_fn_rl = nn.MSELoss()

optimizer_rl = torch.optim.Adam(dqn.parameters(), lr = lr)

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

dqn.to(device)
target_network.to(device)
# The pretrained encoder produces the DQN's inputs, so it lives on the same
# device; eval() stops BatchNorm from updating running statistics while the
# encoder is used as a fixed feature extractor.
model.to(device)
model.eval()

In [16]:
#DQN testing module

def test_dqn(env, render = False):
  """Run one evaluation episode and print its rewards.

  The first action is uniformly random. After that, each observation is
  embedded with the module-level encoder ``model`` and scored by ``dqn``,
  which is the same input the DQN is trained on. The policy is greedy except
  for a 10% chance of a random action.

  Args:
    env: Environment yielding 4-frame stacks via the 4-tuple ``step`` API.
    render: If True, plot the first frame of each observation.

  Returns:
    Tuple ``(total_positive_reward, total_reward)`` for the episode.
  """
  enc_device = next(model.parameters()).device
  q_device = next(dqn.parameters()).device
  was_training = model.training
  # eval(): BatchNorm uses, and does not update, its running statistics.
  model.eval()
  try:
    with torch.no_grad():
      print("Testing Dqn:")
      env.reset()
      tot_pos_rew = 0
      tot_reward = 0
      action = np.random.randint(0, env.action_space.n)
      while True:
        (state, reward, done, info) = env.step(action)
        if render:
          plt.imshow(np.array(state)[0])
          plt.show()

        if reward > 0:
          tot_pos_rew += reward
        tot_reward += reward
        if done:
          break

        obs = torch.tensor(np.expand_dims(np.array(state), axis=0)).float() / 255.0
        qval = dqn(model(obs.to(enc_device)).to(q_device)).cpu().squeeze().numpy()

        if np.random.random() < 0.1:
          action = np.random.randint(0, env.action_space.n)
        else:
          action = np.argmax(qval)
  finally:
    model.train(was_training)

  print("Total Positive Reward :", tot_pos_rew)
  print("Total Reward :", tot_reward)
  return tot_pos_rew, tot_reward

In [17]:
def train_dqn_w_trgnet_expreplay(n_eps):
  """Train ``dqn`` on encoder embeddings with experience replay and a target network.

  Relies on the module-level ``env``, ``model`` (the contrastively
  pre-trained encoder), ``dqn``, ``target_network``, ``optimizer_rl``,
  ``loss_fn_rl``, ``sync_freq`` and ``test_dqn``.

  Each training "episode" is a fixed block of 100000 environment steps; the
  environment is reset whenever it reports ``done``. Epsilon restarts at 1
  for every block and decays by 1/50000 per step until it drops to about 0.1.
  ``test_dqn`` runs after every second block.

  Only ``dqn`` is updated (``optimizer_rl`` holds just ``dqn.parameters()``).
  The encoder runs in eval mode without gradients, so replay entries are plain
  tensors rather than pieces of an autograd graph.

  Args:
    n_eps: Number of 100000-step training blocks.

  Returns:
    np.ndarray with the TD loss of every gradient update, in order.
  """
  replay = deque(maxlen=30000)
  mini_batch_size = 32
  gamma = 0.99
  losses = []

  enc_device = next(model.parameters()).device
  q_device = next(dqn.parameters()).device

  @torch.no_grad()
  def _embed(obs):
    """Embed one frame-stack observation as a (1, emb_dim) tensor on the DQN's device."""
    x = torch.from_numpy(np.expand_dims(np.array(obs), axis=0)).float() / 255.0
    return model(x.to(enc_device)).to(q_device)

  was_training = model.training
  model.eval()
  try:
    for episode in range(n_eps):
      print("Training Episode :", episode)
      epsilon = 1
      epoch = 0
      env.reset()
      (state1, _, _, _) = env.step(np.random.randint(0, env.action_space.n))
      emb1 = _embed(state1)

      for _ in tqdm(range(100000)):
        with torch.no_grad():
          qval = dqn(emb1).cpu().squeeze().numpy()

        if random.random() < epsilon:
          action = np.random.randint(0, env.action_space.n)
        else:
          action = int(np.argmax(qval))

        if epsilon > 0.1:
          epsilon -= 1/50000

        (state2, reward2, done2, _) = env.step(action)
        emb2 = _embed(state2)
        replay.append((emb1, action, emb2, reward2, done2))

        # After a terminal step, the next decision uses the new episode's
        # first observation, not the terminal frame.
        emb1 = _embed(env.reset()) if done2 else emb2

        if len(replay) > 7000:
          mini_batch = random.sample(replay, mini_batch_size)
          s1, a, s2, r, d = zip(*mini_batch)

          state1_batch = torch.cat(s1)
          state2_batch = torch.cat(s2)
          action_batch = torch.tensor(a, dtype=torch.long, device=q_device)
          reward_batch = torch.tensor(r, dtype=torch.float32, device=q_device)
          done_batch = torch.tensor(d, dtype=torch.float32, device=q_device)

          Q1 = dqn(state1_batch)
          with torch.no_grad():
            Q2 = target_network(state2_batch)

          # Bootstrapped target; terminal transitions get no future value.
          Y = reward_batch + gamma * ((1 - done_batch) * Q2.max(dim=1)[0])
          X = Q1.gather(dim=1, index=action_batch.unsqueeze(1)).squeeze(1)

          loss = loss_fn_rl(X, Y)

          optimizer_rl.zero_grad()
          loss.backward()
          losses.append(loss.item())
          optimizer_rl.step()
          epoch += 1
          if epoch % sync_freq == 0:
            target_network.load_state_dict(dqn.state_dict())

      if (episode + 1) % 2 == 0:
        test_dqn(env)
  finally:
    model.train(was_training)

  return np.array(losses)