# Imports


In [1]:
import sys 
import os

In [2]:

# Set path to parent dir to import personal imports
project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
if project_root not in sys.path:
    sys.path.insert(0, project_root)

# Personal 
from data.TrajectorySet import TrajectorySet
from data.Sampler import Sampler 
from data.DatasetCL import DatasetCL 
from utils.tensor_utils import convert_batch_to_tensor

# Misc
import minari 
import numpy as np
import wandb
import os

# Torch 
import torch 
import torch.nn as nn 
import torch.nn.functional as F 
import torch.optim as optim
import torch.utils.data as data 

# PyTorch Lightning 
import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger



In [3]:
minari_dataset = minari.load_dataset("D4RL/pointmaze/large-v2")

In [4]:
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mray-s[0m ([33mray-s-university-of-alberta[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

# Model Architecture

In [4]:
class mlpCL(pl.LightningModule): 
    def __init__(self, lr, weight_decay, temperature=30, max_epochs=1000, h1=256, h2=128, h3=64, h4=32):
        super().__init__() # inherit from LightningModule and nn.module 
        self.save_hyperparameters() # save args  

        self.mlp = nn.Sequential(
            nn.Linear(4, h1), 
            nn.ReLU(inplace=True), 

            nn.Linear(h1, h2), 
            nn.ReLU(inplace=True),

            nn.Linear(h2, h3), 
            nn.ReLU(inplace=True),

            nn.Linear(h3, h4), # representation z 
        )

    def configure_optimizers(self):
        optimizer = optim.AdamW(params=self.parameters(), 
                                lr= self.hparams.lr, 
                                weight_decay=self.hparams.weight_decay)
        lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, 
                                                            T_max=self.hparams.max_epochs,
                                                            eta_min=self.hparams.lr / 50)
        return ([optimizer], [lr_scheduler])
    
    def info_nce_loss(self, batch, mode="train"): 
        # Organizes the states such that their positive pairs are (i + N // 2) away. 
        batch = torch.cat(batch, dim=0)  

        # Encode states 
        z = self.mlp(batch)

        # Get cosine similarity matrix, where the i'th row and j'th index correspond to the 
        # similarity between z_i and z_j 
        cos_sim = F.cosine_similarity(x1=z[:, None, :], x2=z[None, :, :], dim=-1)

        # Create a boolean mask where the diagonals are true
        self_mask = torch.eye(n=cos_sim.size()[0], dtype=bool, device=cos_sim.device)

        # Change the diagonals to become really small numbers, zero-ing out their similarity value
        # i.e we dont want the similarity values of z_i and z_i to be in the calculations 
        cos_sim = cos_sim.masked_fill(mask=self_mask, value= -9e15)  

        # Create a mask that corresponds to the [i,j] location of positive pairs. 
        # Since positive pairs are i + N // 2 away from the i'th anchor, we roll the cos_sim matrix as such 
        # Rolling the Identity matrix row-wise will create this effect. 
        pos_mask = self_mask.roll(shifts=len(batch) // 2, dims=0)

        cos_sim = cos_sim / self.hparams.temperature 
        nll = -cos_sim[pos_mask] + torch.logsumexp(cos_sim, dim=0)
        nll = torch.mean(nll) # Loss 


        reshaped_cos_sim = cos_sim[pos_mask][:, None] # Create 2D matrix with positive pairs on the first column 
        comb_sim = torch.cat([reshaped_cos_sim, cos_sim.masked_fill(mask=pos_mask, value=-9e15)], dim=1) # concatenate all other values column wise 

        # Sort to find if the true positive pair has the highest similarity value with its column neighbours
        sim_argsort = comb_sim.argsort(dim=1, descending=True).argmin(dim=1)

        # Metrics 
        correct = (sim_argsort == 0).float().mean() # Average true positive pairs that had the highest similarity in their column neighbours 
        top5 = (sim_argsort < 5).float().mean() # Average true positive pairs that were in the top5 highest similarity values 
        mean_position = 1 + sim_argsort.float().mean() # Average position of true positive pairs (using 1-indexing)

        # Logging metrics 
        self.log(f"{mode}/top1", correct, prog_bar=True, on_epoch=True)
        self.log(f"{mode}/top5", top5, prog_bar=True, on_epoch=True)
        self.log(f"{mode}/mean_pos", mean_position, on_epoch=True)
        self.log(f"{mode}/nll_loss", nll, prog_bar=True, on_epoch=True)
        
        return nll
    
    def training_step(self, batch):
        return self.info_nce_loss(batch, mode='train')

    def validation_step(self, batch):
        self.info_nce_loss(batch, mode='val')

In [None]:
"""
TESTING CELL! 

Testing InfoNCE loss 
"""

T = TrajectorySet(dataset=minari_dataset) 
S = Sampler(T, dist="l")

ds = DatasetCL(S, batch_size=4, k=2)
batch = ds.get_batch()
model = mlpCL(lr = 1, weight_decay=1)
print(batch)

nll= model.info_nce_loss(batch)

print(nll)

(tensor([[ 1.3912,  2.8372, -3.4308, -0.4404],
        [-1.5635,  0.8735,  1.4885, -1.2872],
        [ 4.4222,  0.3994,  0.0472,  5.0314],
        [-0.5007,  0.8776,  3.9664,  0.4011]]), tensor([[ 1.2745e+00,  2.8308e+00, -4.1192e+00, -4.8980e-02],
        [-1.4805e+00,  2.7749e+00, -3.1463e-03, -4.8342e+00],
        [ 4.4155e+00,  6.2710e-01, -2.3755e-01,  4.0578e+00],
        [ 3.3278e-02,  9.2593e-01,  4.6660e+00,  3.8236e-01]]))


Error: You must call wandb.init() before wandb.log()

In [6]:

DS = data.DataLoader(dataset=ds)
print(next(iter(DS)))

NameError: name 'ds' is not defined

# Trainer


In [5]:
CHECKPOINT_PATH = "../saved_models"
checkpoint_callback = ModelCheckpoint(dirpath=CHECKPOINT_PATH,
                                      filename="best_model", 
                                      save_top_k=3, 
                                      save_weights_only=True, 
                                      mode="max",
                                      monitor="val/top5")

def train_cl(train_ds, val_ds, logger, max_epochs=1000, **kwargs):
    trainer = pl.Trainer(
        default_root_dir=CHECKPOINT_PATH, 
        logger = logger,
        accelerator= "mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu", 
        devices=1, 
        max_epochs=max_epochs,
        callbacks=[checkpoint_callback,
                   LearningRateMonitor("epoch")]) # creates a model checkpoint when a new max in val/top5 has been reached 
    train_loader = data.DataLoader(dataset=train_ds, shuffle=True, drop_last=True)
    val_loader = data.DataLoader(dataset= val_ds, shuffle=False, drop_last=True)

    print("train loader:", len(train_loader))
    print("val loader:", len(val_loader))

    pl.seed_everything(10)
    model = mlpCL(max_epochs=max_epochs, **kwargs) 
    trainer.fit(model, train_loader, val_loader)

    print("Best model path:", checkpoint_callback.best_model_path)
    model = mlpCL.load_from_checkpoint(checkpoint_callback.best_model_path)
    
    return model 

In [6]:
config = {
    "distribution": "g",
    "batch_size": 256,
    "k": 2,
    "lr": 5e-4,
    "weight_decay": 1e-4, 
    "temperature": 0.08,
    "max_epochs": 20
}

wandb_logger = WandbLogger(
    project="Contrastive Learning RL", 
    name="test-run", 
    save_dir = project_root, 
    log_model=True
) 

dist = config["distribution"]
batch_size = config["batch_size"]
k = config["k"]
lr = config["lr"]
weight_decay = config["weight_decay"]
temperature = config["temperature"]
max_epochs = config["max_epochs"]

T = TrajectorySet(dataset=minari_dataset)
S = Sampler(T, dist=dist)
train_dataset = DatasetCL(S, batch_size=batch_size, k=k)

val_dataset = DatasetCL(S, batch_size=batch_size, k=k)

print(len(train_dataset))
print(len(val_dataset))

torch.autograd.set_detect_anomaly(True)
trainer = train_cl(train_ds=train_dataset, 
                val_ds=val_dataset, 
                logger=wandb_logger, 
                max_epochs=max_epochs, 
                lr=lr, 
                temperature=temperature, 
                weight_decay = weight_decay)
    
    
    

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
Seed set to 10


256
256
train loader: 256
val loader: 256


[34m[1mwandb[0m: Currently logged in as: [33mray-s[0m ([33mray-s-university-of-alberta[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


/Users/ray/Documents/Research Assistancy UofA 2025/Reproduce CL/contrastive-learning-RL/CL_RL/lib/python3.9/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:654: Checkpoint directory /Users/ray/Documents/Research Assistancy UofA 2025/Reproduce CL/contrastive-learning-RL/saved_models exists and is not empty.

  | Name | Type       | Params | Mode 
--------------------------------------------
0 | mlp  | Sequential | 44.5 K | train
--------------------------------------------
44.5 K    Trainable params
0         Non-trainable params
44.5 K    Total params
0.178     Total estimated model params size (MB)
8         Modules in train mode
0         Modules in eval mode


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

/Users/ray/Documents/Research Assistancy UofA 2025/Reproduce CL/contrastive-learning-RL/CL_RL/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


                                                                           

/Users/ray/Documents/Research Assistancy UofA 2025/Reproduce CL/contrastive-learning-RL/CL_RL/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Epoch 19: 100%|██████████| 256/256 [00:05<00:00, 44.96it/s, v_num=r5qf, train/top1_step=1.000, train/top5_step=1.000, train/nll_loss_step=0.000, val/top1=1.000, val/top5=1.000, val/nll_loss=0.000, train/top1_epoch=1.000, train/top5_epoch=1.000, train/nll_loss_epoch=0.000]

`Trainer.fit` stopped: `max_epochs=20` reached.


Epoch 19: 100%|██████████| 256/256 [00:05<00:00, 44.94it/s, v_num=r5qf, train/top1_step=1.000, train/top5_step=1.000, train/nll_loss_step=0.000, val/top1=1.000, val/top5=1.000, val/nll_loss=0.000, train/top1_epoch=1.000, train/top5_epoch=1.000, train/nll_loss_epoch=0.000]
Best model path: /Users/ray/Documents/Research Assistancy UofA 2025/Reproduce CL/contrastive-learning-RL/saved_models/best_model.ckpt
