# Imports


In [1]:
import sys 
import os

In [2]:

# Put the parent of the current working directory on the import path so the
# project-local packages imported below can be resolved from this notebook.
project_root = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
if not (project_root in sys.path):
    # Prepend so the project packages win over same-named installed ones.
    sys.path.insert(0, project_root)

# Personal 
from data.TrajectorySet import TrajectorySet
from data.Sampler import Sampler 
from data.DatasetCL import DatasetCL 
from utils.tensor_utils import convert_batch_to_tensor

# Misc
import minari 
import numpy as np
import wandb
import os

# Torch 
import torch 
import torch.nn as nn 
import torch.nn.functional as F 
import torch.optim as optim
import torch.utils.data as data 

# PyTorch Lightning 
import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint



In [3]:
# Load the Minari dataset registered under the id "D4RL/pointmaze/large-v2";
# it is later wrapped by TrajectorySet in the testing cell.
dataset = minari.load_dataset("D4RL/pointmaze/large-v2")

In [4]:
# Authenticate this session with Weights & Biases; the call's return value is
# what the cell displays as its output.
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mray-s[0m ([33mray-s-university-of-alberta[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

# Model Architecture

In [5]:
class mlpCL(pl.LightningModule):
    """MLP encoder trained with a contrastive InfoNCE objective.

    The network maps 4-dimensional inputs through three hidden ReLU layers to
    an ``h4``-dimensional representation ``z``.

    Args:
        lr: Initial learning rate passed to AdamW.
        weight_decay: Weight decay passed to AdamW.
        temperature: Divisor applied to the cosine similarities before the
            log-softmax in the InfoNCE loss.
        max_epochs: Used as ``T_max`` of the cosine-annealing LR schedule.
        h1, h2, h3: Widths of the three hidden layers.
        h4: Width of the output representation.
    """

    def __init__(self, lr, weight_decay, temperature=30, max_epochs=1000, h1=256, h2=128, h3=64, h4=32):
        super().__init__()
        # Stores every constructor argument under self.hparams.
        self.save_hyperparameters()

        self.mlp = nn.Sequential(
            nn.Linear(4, h1),
            nn.ReLU(inplace=True),

            nn.Linear(h1, h2),
            nn.ReLU(inplace=True),

            nn.Linear(h2, h3),
            nn.ReLU(inplace=True),

            nn.Linear(h3, h4),  # representation z
        )

    def configure_optimizers(self):
        """Return AdamW with a cosine-annealing schedule decaying to ``lr / 50``.

        Returns:
            A ``([optimizer], [lr_scheduler])`` pair. Lightning interprets a
            flat ``[optimizer, scheduler]`` list as a list of optimizers and
            rejects it, so the two-list form is required.
        """
        optimizer = optim.AdamW(params=self.parameters(),
                                lr=self.hparams.lr,
                                weight_decay=self.hparams.weight_decay)
        lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer,
                                                            T_max=self.hparams.max_epochs,
                                                            eta_min=self.hparams.lr / 50)
        return [optimizer], [lr_scheduler]

    def info_nce_loss(self, batch, mode="train"):
        """Compute the InfoNCE loss for a batch and log ranking metrics.

        Args:
            batch: Raw batch handed to ``convert_batch_to_tensor``. The result
                is concatenated along dim 0, so it is presumably a sequence of
                tensors whose rows are ordered such that the positive partner
                of row ``i`` sits ``N // 2`` rows away (N = total rows) --
                TODO confirm against DatasetCL.
            mode: Prefix for the logged metric names (``"train"`` / ``"val"``).

        Returns:
            Scalar tensor: mean negative log-likelihood of the positive pair.
        """
        batch = convert_batch_to_tensor(batch=batch)
        batch = torch.cat(batch, dim=0)

        # Encode states
        z = self.mlp(batch)

        # Pairwise cosine similarity: entry [i, j] is the similarity of z_i and z_j.
        cos_sim = F.cosine_similarity(x1=z[:, None, :], x2=z[None, :, :], dim=-1)
        n_rows = cos_sim.size(0)

        # Boolean mask that is True on the diagonal.
        self_mask = torch.eye(n=n_rows, dtype=torch.bool, device=cos_sim.device)

        # Push self-similarities to a very negative value so z_i vs z_i never
        # contributes to the softmax denominator or to the ranking.
        cos_sim = cos_sim.masked_fill(self_mask, -9e15)

        # Rolling the identity by N // 2 rows marks, in every row i, the column
        # of the sample that is N // 2 positions away: the positive pair.
        pos_mask = self_mask.roll(shifts=n_rows // 2, dims=0)

        cos_sim = cos_sim / self.hparams.temperature
        # cos_sim[pos_mask] is ordered by row (anchor). The matrix is symmetric,
        # so reducing over dim=0 yields the same denominator as reducing each row.
        nll = -cos_sim[pos_mask] + torch.logsumexp(cos_sim, dim=0)
        nll = torch.mean(nll)

        # Ranking metrics: put each anchor's positive similarity in column 0 and
        # all remaining similarities after it. The fill must be out-of-place:
        # cos_sim was saved by autograd for the loss above, and mutating it
        # in place makes the later backward() call fail.
        positives = cos_sim[pos_mask][:, None]
        comb_sim = torch.cat([positives, cos_sim.masked_fill(pos_mask, -9e15)], dim=1)

        # Rank of the positive within its own row (0 = most similar).
        sim_argsort = comb_sim.argsort(dim=1, descending=True).argmin(dim=1)

        correct = (sim_argsort == 0).float().mean()      # positive ranked first
        top5 = (sim_argsort < 5).float().mean()          # positive within the top 5
        mean_position = 1 + sim_argsort.float().mean()   # average rank, 1-indexed

        # Log through Lightning so the attached logger and any callback that
        # monitors these keys (e.g. "val/top5") can see them.
        self.log(f"{mode}/nll_loss", nll)
        self.log(f"{mode}/top1", correct)
        self.log(f"{mode}/top5", top5)
        self.log(f"{mode}/mean_pos", mean_position)

        return nll

    def training_step(self, batch, batch_idx=None):
        """Return the training InfoNCE loss; ``batch_idx`` is accepted but unused."""
        return self.info_nce_loss(batch, mode='train')

    def validation_step(self, batch, batch_idx=None):
        """Compute and log validation metrics; ``batch_idx`` is accepted but unused."""
        self.info_nce_loss(batch, mode='val')

In [9]:
"""
TESTING CELL! 

Testing InfoNCE loss 
"""

T = TrajectorySet(dataset=dataset) 
S = Sampler(T, dist="l")

ds = DatasetCL(S, batch_size=32, k=2)
batch = ds.get_batch()
model = mlpCL(lr = 1, weight_decay=1)

nll= model.info_nce_loss(batch)

print(nll)

tensor(4.1381, grad_fn=<MeanBackward0>)


In [11]:
CHECKPOINT_PATH = "../saved_models"

# Wrap the contrastive dataset in a DataLoader with default settings and show
# the first item it yields, to inspect the structure handed to the model.
DS = data.DataLoader(ds)

print(next(iter(DS)))

[tensor([[-4.3051,  2.9311, -3.9615, -0.4089]]), tensor([[-4.4661,  2.8752, -2.7294, -1.5892]])]


# Trainer


In [8]:
# Root directory given to the trainer for its saved artifacts.
CHECKPOINT_PATH = "../saved_models"
# First CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

def train_cl(batch_size, max_epochs=1000, **kwargs):
    """Build the Lightning trainer for contrastive training (work in progress).

    Only the ``pl.Trainer`` is constructed so far: no model is created, no
    fitting is started, and nothing is returned.

    Args:
        batch_size: Not used yet.
        max_epochs: Upper bound on training epochs passed to the trainer.
        **kwargs: Not used yet.
    """
    use_gpu = str(device).startswith("cuda")

    trainer = pl.Trainer(
        default_root_dir=CHECKPOINT_PATH,
        accelerator="gpu" if use_gpu else "cpu",
        devices=1,
        max_epochs=max_epochs,
        callbacks=[
            # Keep the checkpoint with the highest "val/top5"; the keyword is
            # `monitor` (the previous `moniter` spelling raised a TypeError).
            ModelCheckpoint(save_weights_only=True, mode="max", monitor="val/top5"),
            LearningRateMonitor("epoch"),
        ],
    )

    # TODO: build the data loaders and model, then call trainer.fit(...).