<a href="https://colab.research.google.com/github/nerdk312/Model-based-RL/blob/master/Embed_2_Contrast_170520.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# GitHub no longer serves the unauthenticated git:// protocol; install over HTTPS.
# %pip (rather than !pip) installs into the environment of the running kernel.
%pip install git+https://github.com/openai/baselines
%pip install wandb

Collecting git+git://github.com/openai/baselines
  Cloning git://github.com/openai/baselines to /tmp/pip-req-build-fr8a_r4y
  Running command git clone -q git://github.com/openai/baselines /tmp/pip-req-build-fr8a_r4y
Collecting gym<0.16.0,>=0.15.4
[?25l  Downloading https://files.pythonhosted.org/packages/e0/01/8771e8f914a627022296dab694092a11a7d417b6c8364f0a44a8debca734/gym-0.15.7.tar.gz (1.6MB)
[K     |████████████████████████████████| 1.6MB 8.6MB/s 
Building wheels for collected packages: baselines, gym
  Building wheel for baselines (setup.py) ... [?25l[?25hdone
  Created wheel for baselines: filename=baselines-0.1.6-cp36-none-any.whl size=220664 sha256=f0aa208484c6e70a3d304159b1e426c220fc5ec398814453bc793e39ff8e4e73
  Stored in directory: /tmp/pip-ephem-wheel-cache-is7fdxb7/wheels/42/1c/91/28314e0cd1d2cc57cf8dd18b20c4c9a0f39ae518adc13caf24
  Building wheel for gym (setup.py) ... [?25l[?25hdone
  Created wheel for gym: filename=gym-0.15.7-cp36-none-any.whl size=1648840 sha256

In [0]:
import os
from __future__ import print_function
import pickle
import sys
sys.path.append('/content/drive/My Drive')
import wandb

import random
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
import gym
import time
import matplotlib.pyplot as plt

from Embed_2_Contrast.custom_wrappers import custom_wrapper
from Embed_2_Contrast.encoder import make_encoder
from Embed_2_Contrast.EarlyStopping import EarlyStopping_loss
from Embed_2_Contrast.GeneralFunctions import General_functions
from Embed_2_Contrast.utils import make_dir, random_crop,center_crop_image, soft_update_params, weight_init, random_color_jitter
from torch.autograd import Variable
from Embed_2_Contrast.DataCollection import Data_collection
from Embed_2_Contrast.models import CURL
from Embed_2_Contrast.replay_buffer import ReplayBuffer

# Needed to create dataloaders
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

In [0]:
# Authenticate with Weights & Biases without embedding the API key in the notebook:
# wandb.login() uses the WANDB_API_KEY environment variable or prompts interactively.
wandb.login()

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[32mSuccessfully logged in to Weights & Biases![0m


In [0]:
class CurlAgent(object):
    '''CURL contrastive representation learning agent.

    Wraps a ``CURL`` module that holds an online encoder (``encoder``) and a
    target encoder (``encoder_target``). It is trained with a cross-entropy loss
    over anchor/positive similarity logits. For row ``i`` of the logits the
    correct label is ``i``, so every other positive in the batch acts as a negative.

    Args:
        obs_shape: shape of one observation; ``obs_shape[-2]`` is stored as
            ``image_size`` (assumed to be the image height — TODO confirm).
        device: torch device that the model and batches are moved to.
        frames: number of stacked frames, forwarded to ``random_color_jitter``.
        encoder_feature_dim: size of the embedding produced by the encoder.
        encoder_lr: Adam learning rate for all CURL parameters.
        encoder_tau: soft-update coefficient passed to ``soft_update_params``.
        num_layers: number of conv layers in the encoder.
        num_filters: number of conv filters per layer.
        cpc_update_freq: apply a contrastive update every N training batches.
        encoder_update_freq: soft-update the target encoder every N batches.
        random_jitter: apply ``random_color_jitter`` to anchors and positives.
        detach_encoder: if True and a dynamics model has been attached as
            ``self.Model``, copy the CURL conv weights into its encoder.
    '''
    def __init__(
        self,
        obs_shape,
        device,
        frames,
        encoder_feature_dim = 50,
        encoder_lr = 1e-4,
        encoder_tau = 0.001,
        num_layers=4,
        num_filters = 32,
        cpc_update_freq=1,
        encoder_update_freq = 1,
        random_jitter = True,
        detach_encoder = True
    ):
        self.device = device
        self.cpc_update_freq = cpc_update_freq
        self.image_size = obs_shape[-2]
        self.frames = frames

        self.encoder_tau = encoder_tau
        self.epoch_step = 0
        self.encoder_update_freq = encoder_update_freq
        self.random_jitter = random_jitter
        # Must be stored before it is read below.
        self.detach_encoder = detach_encoder

        self.CURL = CURL(obs_shape, encoder_feature_dim,
                         encoder_feature_dim, num_layers, num_filters).to(self.device)

        # This agent does not build a dynamics model itself. Only tie its encoder
        # to the contrastive encoder when one has been attached as ``self.Model``.
        dynamics_model = getattr(self, 'Model', None)
        if self.detach_encoder and dynamics_model is not None:
            dynamics_model.encoder.copy_conv_weights_from(self.CURL.encoder)

        self.cpc_optimizer = torch.optim.Adam(
                self.CURL.parameters(), lr=encoder_lr
            )

        self.cross_entropy_loss = nn.CrossEntropyLoss()
        self.train()

    def train(self, training = True):
        '''Set the training flag on the agent and on the CURL module.'''
        self.training = training
        self.CURL.train(training)

    def update(self, train_dataloader,val_dataloader,early_stopper):
        '''Run one epoch of contrastive training, then one validation pass.

        Returns immediately if ``early_stopper.early_stop`` is already set.
        Each batch is unpacked as ``(obs, actions, next_obs, cpc_kwargs)``, and
        ``cpc_kwargs`` must provide ``"obs_anchor"`` and ``"obs_pos"``.
        '''
        if early_stopper.early_stop:
            print('early stopping')
            return

        for step, (_obs, _actions, _next_obs, cpc_kwargs) in enumerate(train_dataloader):
            if step % self.encoder_update_freq == 0:
                # Soft-update encoder_target from encoder with coefficient encoder_tau.
                soft_update_params(
                    self.CURL.encoder, self.CURL.encoder_target,
                    self.encoder_tau
                )
            if step % self.cpc_update_freq == 0:
                self.update_cpc(cpc_kwargs["obs_anchor"], cpc_kwargs["obs_pos"])

        self.validation(val_dataloader,early_stopper)

    def _maybe_jitter(self, obs_anchor, obs_pos):
        '''Apply ``random_color_jitter`` to both views when ``random_jitter`` is set.'''
        if self.random_jitter:
            obs_anchor = random_color_jitter(obs_anchor, batch_size=obs_anchor.shape[0], frames=self.frames)
            obs_pos = random_color_jitter(obs_pos, batch_size=obs_pos.shape[0], frames=self.frames)
        return obs_anchor, obs_pos

    def _contrastive_loss(self, obs_anchor, obs_pos):
        '''Return the contrastive cross-entropy loss for a batch of anchor/positive pairs.

        Anchors are encoded with the online encoder and positives with
        ``ema=True``. The target class for logits row ``i`` is ``i``.
        '''
        obs_anchor, obs_pos = obs_anchor.to(self.device), obs_pos.to(self.device)
        obs_anchor, obs_pos = self._maybe_jitter(obs_anchor, obs_pos)

        z_a = self.CURL.encode(obs_anchor)
        z_pos = self.CURL.encode(obs_pos, ema=True)
        logits = self.CURL.compute_logits(z_a, z_pos)
        labels = torch.arange(logits.shape[0], device=self.device)
        return self.cross_entropy_loss(logits, labels)

    def update_cpc(self, obs_anchor, obs_pos):
        '''Take one optimizer step on the contrastive loss and log it to wandb.'''
        loss = self._contrastive_loss(obs_anchor, obs_pos)
        wandb.log({'Contrastive Training loss':loss.item()})

        self.cpc_optimizer.zero_grad()
        loss.backward()
        self.cpc_optimizer.step()

    def validation(self, dataloader,early_stopper):
        '''Compute the mean contrastive loss over ``dataloader`` and step the early stopper.

        The method works as follows:
        - increments ``epoch_step``;
        - logs the mean loss to wandb;
        - passes the loss, the CURL module and the optimizer to ``early_stopper``.

        The CURL module is in eval mode during the pass and is always returned
        to training mode afterwards, even if an error occurs.

        Raises:
            ValueError: if ``dataloader`` yields no batches.
        '''
        total_loss = 0.0
        num_batches = 0
        self.CURL.eval()
        try:
            with torch.no_grad():
                for _obses, _actions, _next_obses, cpc_kwargs in dataloader:
                    loss = self._contrastive_loss(cpc_kwargs["obs_anchor"], cpc_kwargs["obs_pos"])
                    total_loss += loss.item()
                    num_batches += 1

                if num_batches == 0:
                    raise ValueError('validation dataloader yielded no batches')

                average_epoch_contrastive_loss = total_loss / num_batches
                self.epoch_step += 1 # increase epoch counter
                wandb.log({'Contrastive Validation loss':average_epoch_contrastive_loss,'epoch': self.epoch_step})

                print('epoch:', self.epoch_step)
                early_stopper(average_epoch_contrastive_loss,self.CURL,self.cpc_optimizer)
        finally:
            self.train()
    
def make_agent(obs_shape, device, dict_info):
    '''Construct a ``CurlAgent`` from a flat hyper-parameter dictionary.

    Every key in ``hyperparameter_keys`` must be present in ``dict_info``.
    ``cpc_update_freq`` is not read here, so the ``CurlAgent`` default applies.
    '''
    hyperparameter_keys = (
        'frames',
        'detach_encoder',
        'random_jitter',
        'encoder_update_freq',
        'encoder_feature_dim',  # size of the embedding from the projection head
        'encoder_lr',
        'encoder_tau',
        'num_layers',
        'num_filters',  # number of conv filters
    )
    agent_kwargs = {key: dict_info[key] for key in hyperparameter_keys}
    return CurlAgent(obs_shape=obs_shape, device=device, **agent_kwargs)

In [0]:
ENV_NAME = 'MsPacmanDeterministic-v4'
# Reduced action set: the full MsPacman set has 9 actions. The dropped actions are
# treated as duplicates of the kept ones, which lowers the amount of data needed.
# NOTE(review): an earlier comment mentioned 5 actions — confirm 4 is intended.
n_actions = 4

# Alternative with colour augmentation applied inside the dataset:
#   transforms.Compose([transforms.ColorJitter(0.8, 0.8, 0.8, 0.2), transforms.ToTensor()])
data_transform = transforms.Compose([
                                    transforms.ToTensor()])

no_agents = 5
state_space = no_agents*2

parse_dict= {'pre_transform_image_size':100,
             'image_size':84,
             'frame_stack':True,
             'frames': 4,
             'state_space':state_space,
             'train_capacity':100000,
             'val_capacity':20000,
             'num_train_epochs':20,
             'batch_size':512,
             'random_crop': True,
             'encoder_update_freq':1,
             'encoder_feature_dim':50,
             'encoder_lr':1e-3,
             'encoder_tau':0.05, # value used for atari experiments in curl
             'num_layers':4,
             'num_filters':32,
             'grayscale': False,
             'load_pretrain_model': False,
             'walls_present':True,
             'pretrain_model':False,
             'save_data':False,
             'num_pretrain_epochs':25,
             'transform': data_transform,
             'random_jitter':False,
             'detach_encoder': True
            }

# All runs of this sweep go to the same wandb project so they can be compared.
WANDB_ENTITY = "nerdk312"
WANDB_PROJECT = "Embed2Contrast"

def run_name(cfg):
    '''Build the wandb run name from the hyper-parameters being swept.'''
    return ('Contrastive_hp_testing_random_jitter-' + str(cfg['random_jitter'])
            + '_encoder_tau-' + str(cfg['encoder_tau']))

custom_name = run_name(parse_dict)
wandb.init(entity=WANDB_ENTITY, name=custom_name, project=WANDB_PROJECT, config=parse_dict)

# allow_pickle=True can execute code embedded in the file: only load trusted data.
possible_positions = np.load('/content/drive/My Drive/MsPacman-data/possible_pacman_positions.npy',allow_pickle=True)

config = wandb.config

if parse_dict['load_pretrain_model']:
    # The checkpoint directory must be supplied in parse_dict when loading a pretrained model.
    config.pretrained_model = parse_dict['pretrain_model_dir']

# Data collection
data_object = Data_collection(ENV_NAME,n_actions,possible_positions, parse_dict,parse_dict['train_capacity'])
val_data_object = Data_collection(ENV_NAME,n_actions,possible_positions, parse_dict, parse_dict['val_capacity'])

data_object.gather_random_trajectories(5000)
val_data_object.gather_random_trajectories(5000)

data_object.replay_buffer.crop_control(parse_dict['random_crop'])
val_data_object.replay_buffer.crop_control(parse_dict['random_crop'])

train_dataloader = DataLoader(data_object.replay_buffer, batch_size = parse_dict['batch_size'], shuffle = True)
val_dataloader = DataLoader(val_data_object.replay_buffer, batch_size = parse_dict['batch_size'], shuffle = True)

# encoder_tau values to sweep (run i > 0 uses test_info[i-1] with random jitter on).
test_info = [0.001,0.005,0.01,0.05,0.1,0.5,1]
# Set to len(test_info) + 1 to run the baseline plus every value in test_info.
tests = 1

for i in range(tests):
    print(i)
    if i >0:
        parse_dict['random_jitter'] = True
        parse_dict['encoder_tau'] = test_info[i-1]
        custom_name = run_name(parse_dict)
        wandb.init(entity=WANDB_ENTITY, name=custom_name, project=WANDB_PROJECT, config=parse_dict)

    agent = make_agent(
    obs_shape = data_object.obs_shape,
    device =data_object.device,
    dict_info = parse_dict
    )

    pretrain_model_name = 'Contrastive' +'_' + data_object.ts
    early_stopping_contrastive = EarlyStopping_loss(patience=3, verbose=True, wandb=wandb, name=pretrain_model_name)

    for step in range(parse_dict['num_train_epochs']):
        if early_stopping_contrastive.early_stop: #  Stops the training if early stopping counter is hit
            break
        agent.update(train_dataloader,val_dataloader,early_stopping_contrastive)

    wandb.join()

cuda
cuda
trajectory number: 0
trajectory number: 10
trajectory number: 20
trajectory number: 30
trajectory number: 40
trajectory number: 50
trajectory number: 60
trajectory number: 70
trajectory number: 80
trajectory number: 90
trajectory number: 100
trajectory number: 110
trajectory number: 120
trajectory number: 130
trajectory number: 140
trajectory number: 150
trajectory number: 160
trajectory number: 170
trajectory number: 180
trajectory number: 190
trajectory number: 200
trajectory number: 210
trajectory number: 220
trajectory number: 230
trajectory number: 240


In [0]:
# encoder_tau sweep values; must match the list used in the training cell.
test_info = [0.001,0.005,0.01,0.05,0.1,0.5,1]
test_info[2]

0.01

In [0]:
# Continue training the existing `agent`, using a fresh early-stopping tracker.
pretrain_model_name = 'Contrastive_' + data_object.ts
early_stopping_contrastive = EarlyStopping_loss(
    patience=3,
    verbose=True,
    wandb=wandb,
    name=pretrain_model_name,
)

# Train until the epoch budget is used up or the early stopper fires.
for step in range(parse_dict['num_train_epochs']):
    if not early_stopping_contrastive.early_stop:
        agent.update(train_dataloader, val_dataloader, early_stopping_contrastive)
        continue
    break

epoch: 1
epoch: 2
epoch: 3
epoch: 4
epoch: 5
epoch: 6
epoch: 7
epoch: 8
epoch: 9
epoch: 10
epoch: 11
epoch: 12
epoch: 13
epoch: 14
epoch: 15
epoch: 16
epoch: 17
epoch: 18
epoch: 19
epoch: 20


# Old code (earlier scratch experiments, kept for reference)

In [0]:
# Display the transform attached to the training replay buffer (shown as the cell output).
data_object.replay_buffer.transform

Compose(
    ColorJitter(brightness=[0.19999999999999996, 1.8], contrast=[0.19999999999999996, 1.8], saturation=[0.19999999999999996, 1.8], hue=[-0.2, 0.2])
    ToTensor()
)

In [0]:
# Scratch rollout: step the wrapped environment with one fixed action until the
# episode ends, recording [obs, next_obs] pairs in `values`.
# NOTE(review): data_transform is not used in this cell. It overwrites the
# notebook-level transform until the config cell is run again.
data_transform =  transforms.Compose([
        transforms.ColorJitter(0.8 * 1, 0.8 * 1, 0.8 * 1, 0.2 * 1),
        transforms.ToTensor()])
ENV_NAME = 'MsPacmanDeterministic-v4'
ROLLOUT_ACTION = 1  # action index used for every step of this rollout
env = gym.make(ENV_NAME)
env = custom_wrapper(env,ATARI_labels=False)
obs =  env.reset()
values = []
try:
    while True:
        next_obs, reward, done, _ = env.step(ROLLOUT_ACTION)
        values.append([obs,next_obs])
        obs = next_obs
        if done:
            break
finally:
    # Release the emulator even if a step raises.
    env.close()
