# Imports


In [83]:
import minari 
import numpy as np
import wandb
import os

# Torch 
import torch 
import torch.nn as nn 
import torch.nn.functional as F 
import torch.optim as optim
import torch.utils.data as data 

# PyTorch Lightning 
import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint

In [2]:
# Load the Minari dataset registered under this id; `dataset` is what
# TrajectorySet later samples episodes from.
dataset = minari.load_dataset("D4RL/pointmaze/large-v2")

In [64]:
# Authenticate this session with Weights & Biases before any wandb.log call.
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mray-s[0m ([33mray-s-university-of-alberta[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

# Prior Functionality 

In [96]:
class TrajectorySet:
    """Growing, index-addressable collection of state trajectories.

    Trajectories are stored in a dictionary keyed by insertion order,
    starting at 0:
        {
            0: [trajectory, length of trajectory],
            1: [ ... ],
            etc...
        }
    """

    def __init__(self, dataset):
        """
        Args:
            dataset: The minari dataset to use; it must provide
                `sample_episodes(n_episodes=...)`.
        """
        self.dataset = dataset

        self.trajectories = {}     # index -> [trajectory, len(trajectory)]
        self.num_trajectories = 0  # number of trajectories currently in the set

    def add_trajectory(self, trajectory):
        """Append `trajectory` under the next free index."""
        self.trajectories[self.num_trajectories] = [trajectory, len(trajectory)]
        self.num_trajectories += 1

    def get_num_trajectories(self):
        """Return how many trajectories have been added so far."""
        return self.num_trajectories

    def get_trajectory(self, index):
        """Return the `[trajectory, length]` entry stored at `index`.

        Raises:
            IndexError: if `index` is not in `[0, num_trajectories)`.
        """
        # An explicit raise (instead of `assert`) survives `python -O` and also
        # rejects negative indices, which previously surfaced as a bare KeyError.
        if not 0 <= index < self.num_trajectories:
            raise IndexError(
                f"Trajectory index {index} is out of range for a set of "
                f"{self.num_trajectories} trajectories."
            )
        return self.trajectories[index]

    def get_trajectory_set(self):
        """Return the underlying index -> [trajectory, length] dictionary."""
        return self.trajectories

    def get_total_states(self):
        """Return the total number of states over all stored trajectories."""
        return sum(length for _, length in self.trajectories.values())

    def generate_trajectories(self, n_trajectories: int = 2):
        """
        Generates a specified number of trajectories and saves them into the TrajectorySet class.

        This runs the scripted agent, where the agent uses a PD controller to follow a
        path of waypoints generated with QIteration until it reaches the goal.

        Args:
            n_trajectories: The number of trajectories to generate.
        """
        episodes = self.dataset.sample_episodes(n_episodes=n_trajectories)  # sample trajectories

        # Adds all of the sampled trajectories into the TrajectorySet.
        for episode in episodes:
            # Note: only saving states since we only need state representations in the encoder
            self.add_trajectory(episode.observations["observation"])


###########################################
#  Sampler Class
###########################################
class Sampler():
    """Draws (anchor state, positive state) pairs from trajectories.

    The positive state comes from the same trajectory as the anchor; its time
    step is drawn from a distribution placed around the anchor's time step.
    """

    # Hyper-parameters of the positive-pair distributions. They are class
    # attributes so they can be overridden per instance or subclass.
    GAUSSIAN_STD = 15        # we use 15 to replicate the paper's hyperparams
    LAPLACE_SCALE = 15       # laplace scale hyper param
    EXPONENTIAL_RATE = 0.99  # exponential hyper param

    def __init__(self, T: "TrajectorySet", dist="g"):
        """
        T: The Trajectory Set class
        dist: The distribution used for centering over the anchor state.
            ['u', 'g', 'l', 'e'] - uniform, gaussian, laplace, exponential.
            Any other value falls back to gaussian.
        """
        self.T = T
        self.dist = dist

    def sample_anchor_state(self, t: list) -> tuple[list, int]:
        """
        Given a trajectory, we sample the anchor state s_i uniformly.

        Args:
            t: The given trajectory we sample from.

        Returns:
            A two-element list [s_i, idx]
            s_i: The state that is sampled, i.e. t[idx].
            idx: The time step of s_i.

        Raises:
            ValueError: if `t` has no time step that can serve as an anchor.
        """
        # The exponential option only looks forward in time, so the last time
        # step can never have a positive pair and is excluded as an anchor.
        n_candidates = len(t) - 1 if self.dist == "e" else len(t)
        if n_candidates < 1:
            raise ValueError(
                f"Trajectory of length {len(t)} is too short to sample an anchor state."
            )

        idx = torch.randint(low=0, high=n_candidates, size=(1,)).item()
        return [t[idx], idx]

    def _draw_index(self, si_idx: int, n_states: int) -> int:
        """Draw one candidate time step for the positive pair (may be out of range)."""
        if self.dist == "u":
            # uniform over the whole trajectory
            return int(torch.randint(low=0, high=n_states, size=(1,)))
        if self.dist == "l":
            # laplacian centred on the anchor
            raw = torch.distributions.laplace.Laplace(
                loc=float(si_idx), scale=float(self.LAPLACE_SCALE)
            ).sample()
            # Round to the nearest step; truncating toward zero would skew the
            # discretised distribution.
            return round(float(raw))
        if self.dist == "e":
            # exponential: strictly positive offset into the future
            offset = int(
                torch.distributions.exponential.Exponential(rate=self.EXPONENTIAL_RATE).sample()
            ) + 1  # +1 so we don't get an offset of 0
            return si_idx + offset

        # gaussian ("g"), which is also the default for unknown codes
        raw = torch.normal(mean=float(si_idx), std=float(self.GAUSSIAN_STD), size=(1,))
        return round(float(raw))

    def sample_positive_pair(self, t: list, anchor_state: tuple[list, int]) -> tuple[list, int]:
        """
        Given the same trajectory that s_i was sampled from, draw the time step
        of its positive pair s_j from the configured distribution around s_i.

        Args:
            t: The given trajectory, which must be the same as the trajectory
                that was used to sample the anchor state.
            anchor_state: The anchor state; a pair [s_i, idx] where s_i is the
                state itself and idx is its time step.

        Return:
            A two-element list [s_j, idx]
            s_j: The state that is sampled, i.e. t[idx].
            idx: The time step of s_j.

        Raises:
            ValueError: if no valid positive time step exists for this anchor
                (the rejection loop below would otherwise never terminate).
        """
        _, si_idx = anchor_state
        n_states = len(t)

        if self.dist == "e":
            has_candidate = si_idx < n_states - 1  # needs a later time step
        else:
            has_candidate = n_states >= 2          # needs any other time step
        if not has_candidate:
            raise ValueError(
                f"No positive pair exists for anchor index {si_idx} in a "
                f"trajectory of length {n_states} with dist={self.dist!r}."
            )

        # Rejection sampling: redraw until the index is inside the trajectory
        # (index 0 included) and differs from the anchor.
        while True:
            sj_idx = self._draw_index(si_idx, n_states)
            if 0 <= sj_idx < n_states and sj_idx != si_idx:
                break

        return [t[sj_idx], sj_idx]

    def sample_batch(self, batch_size=1024, k=2) -> list[tuple]:
        """
        Creates a batch of anchor states and their positive pairs.
        The other entries in the batch act as the negative examples:
        there will be 2(batch_size - 1) negative examples per positive pair.

        Every call generates fresh trajectories, appends them to `self.T`, and
        samples only from those new trajectories, so `self.T` does not need to
        be empty. Note that `self.T` therefore keeps growing across calls.

        Args:
            batch_size: The size of the batch to be generated.
            k: A hyperparameter that dictates the average number of
                positive pairs sampled from the same trajectory. The
                lower the number, the lesser the chance of false negatives.

        Returns:
            A list of [anchor_state, positive_pair] pairs.
            The list is the same length as batch_size.

        Raises:
            ValueError: if no trajectories could be generated.
        """
        if batch_size <= 0:
            return []

        # Generate trajectory set (at least one, even when batch_size < k).
        n_trajectories = max(1, batch_size // k)
        first_new = self.T.get_num_trajectories()
        self.T.generate_trajectories(n_trajectories=n_trajectories)
        n_new = self.T.get_num_trajectories() - first_new
        if n_new < 1:
            raise ValueError("generate_trajectories did not add any trajectories.")

        batch = []
        for _ in range(batch_size):
            # Pick one of the trajectories generated by this call.
            offset = torch.randint(low=0, high=n_new, size=(1,)).item()
            t = self.T.get_trajectory(index=first_new + offset)[0]

            # Sample anchor state and its positive pair.
            anchor_state = self.sample_anchor_state(t)
            positive_pair = self.sample_positive_pair(t, anchor_state=anchor_state)

            # Retrieve states; time-steps aren't necessary.
            batch.append([anchor_state[0], positive_pair[0]])

        return batch

In [20]:
def convert_batch_to_tensor(batch: list[list]) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Converts the batch to a tuple of tensors.
    The first tensor corresponds to the anchor states.
    The second tensor corresponds to their corresponding positive pair.
    i.e. the i'th anchor state in the first tensor has its positive pair in
    the i'th row of the second tensor.

    Args:
        batch: A non-empty sequence of [anchor_state, positive_state] pairs,
            where every state is an array-like of the same shape.

    Returns:
        (anchors, positives): two float32 tensors with one row per pair.

    Raises:
        ValueError: if `batch` is empty.
    """
    if len(batch) == 0:
        raise ValueError("Cannot convert an empty batch to tensors.")

    # Unzips the batch into two tuples: all anchors and all positives.
    anchors, positives = zip(*batch)

    # Stack arrays row-wise and then convert to tensor of dtype float
    # (to be compatible w/ model weights).
    a_t = torch.tensor(np.stack(anchors, axis=0), dtype=torch.float32)
    b_t = torch.tensor(np.stack(positives, axis=0), dtype=torch.float32)

    return (a_t, b_t)

# Model Architecture

In [93]:
class mlpCL(pl.LightningModule):
    """MLP encoder trained with an InfoNCE contrastive loss on state pairs.

    The encoder maps a 4-dimensional state to an `h4`-dimensional
    representation z through three hidden ReLU layers.
    """

    def __init__(self, lr, weight_decay, temperature=30, max_epochs=1000, h1=256, h2=128, h3=64, h4=32):
        """
        Args:
            lr: Learning rate for AdamW.
            weight_decay: Weight decay for AdamW.
            temperature: Divisor applied to the cosine similarities in the loss.
            max_epochs: Period (T_max) of the cosine learning-rate schedule.
            h1, h2, h3: Widths of the hidden layers.
            h4: Width of the output representation z.
        """
        super().__init__()  # inherit from LightningModule and nn.module
        self.save_hyperparameters()  # save args; exposed as self.hparams

        self.mlp = nn.Sequential(
            nn.Linear(4, h1),
            nn.ReLU(inplace=True),

            nn.Linear(h1, h2),
            nn.ReLU(inplace=True),

            nn.Linear(h2, h3),
            nn.ReLU(inplace=True),

            nn.Linear(h3, h4),  # representation z
        )

    def configure_optimizers(self):
        """Return AdamW with a cosine-annealed learning rate (floor at lr / 50)."""
        optimizer = optim.AdamW(params=self.parameters(),
                                lr=self.hparams.lr,
                                weight_decay=self.hparams.weight_decay)
        lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer,
                                                            T_max=self.hparams.max_epochs,
                                                            eta_min=self.hparams.lr / 50)
        # Lightning expects two lists (optimizers, schedulers). A single
        # [optimizer, scheduler] list is interpreted as two optimizers.
        return [optimizer], [lr_scheduler]

    def info_nce_loss(self, batch, mode="train"):
        """Compute the InfoNCE loss for a batch of (anchor, positive) pairs.

        Args:
            batch: Sequence of [anchor_state, positive_state] pairs, as accepted
                by `convert_batch_to_tensor`.
            mode: Prefix for the logged metric names (e.g. "train", "val").

        Returns:
            The mean negative log-likelihood as a scalar tensor.

        Side effects:
            Logs `{mode}/nll_loss`, `{mode}/top1`, `{mode}/top5` and
            `{mode}/mean_pos` with `wandb.log`.
            NOTE(review): metrics logged only through wandb.log are presumably
            not visible to Lightning callbacks that monitor a metric by name;
            confirm before relying on e.g. a checkpoint monitor.
        """
        # Organizes the states such that their positive pairs are N // 2 rows
        # away (anchors first, then positives), on the same device as the model.
        anchors, positives = convert_batch_to_tensor(batch=batch)
        states = torch.cat((anchors, positives), dim=0).to(self.device)
        n_states = states.shape[0]

        # Encode states
        z = self.mlp(states)

        # Cosine similarity matrix: entry [i, j] is the similarity between z_i and z_j.
        cos_sim = F.cosine_similarity(x1=z[:, None, :], x2=z[None, :, :], dim=-1)

        # Boolean mask that is True on the diagonal.
        self_mask = torch.eye(n=n_states, dtype=torch.bool, device=cos_sim.device)

        # Push the diagonal to a very large negative number so that the
        # similarity of z_i with itself does not take part in the calculations.
        cos_sim = cos_sim.masked_fill(self_mask, -9e15)

        # Mask of the [i, j] locations of positive pairs. Since positives sit
        # N // 2 rows away from their anchor, rolling the identity mask
        # row-wise by N // 2 marks exactly those entries.
        pos_mask = self_mask.roll(shifts=n_states // 2, dims=0)

        cos_sim = cos_sim / self.hparams.temperature
        # Row i: -sim(i, positive(i)) + log sum_j exp(sim(i, j)).
        nll = -cos_sim[pos_mask] + torch.logsumexp(cos_sim, dim=-1)
        nll = torch.mean(nll)

        # Ranking metrics: put the positive similarity in column 0, followed by
        # all similarities with the positive entry masked out. The masking is
        # out-of-place so the tensor autograd saved for `nll` stays untouched.
        positive_column = cos_sim[pos_mask][:, None]
        comb_sim = torch.cat([positive_column, cos_sim.masked_fill(pos_mask, -9e15)], dim=1)

        # Rank of the true positive pair among its row's similarities (0 = best).
        sim_argsort = comb_sim.argsort(dim=1, descending=True).argmin(dim=1)

        # Metrics
        correct = (sim_argsort == 0).float().mean()     # share of positives ranked first
        top5 = (sim_argsort < 5).float().mean()         # share of positives ranked in the top 5
        mean_position = 1 + sim_argsort.float().mean()  # average rank of positives (1-indexed)

        # Logging metrics
        wandb.log({
            f"{mode}/nll_loss": nll.item(),
            f"{mode}/top1": correct.item(),
            f"{mode}/top5": top5.item(),
            f"{mode}/mean_pos": mean_position.item()
        })

        return nll

    def training_step(self, batch):
        """Return the InfoNCE loss for one training batch."""
        return self.info_nce_loss(batch, mode='train')

    def validation_step(self, batch):
        """Compute and log the InfoNCE metrics for one validation batch."""
        self.info_nce_loss(batch, mode='val')

In [81]:
"""
TESTING CELL! 

Testing InfoNCE loss 
"""

# TrajectorySet requires the dataset it samples episodes from.
T = TrajectorySet(dataset)
S = Sampler(T, dist="l")

batch = S.sample_batch(4)

model = mlpCL(lr = 1, weight_decay=1)

#nll= model.info_nce_loss(batch)

#print(nll)

In [92]:
# Relative directory for saved models (relative to the process's working directory).
CHECKPOINT_PATH = "../saved_models"

# os.path.join with a single argument returns that argument unchanged.
path = os.path.join(CHECKPOINT_PATH)
path
# Last expression of the cell: shows the current working directory, i.e. the
# directory that CHECKPOINT_PATH is resolved against.
os.getcwd()

'/Users/ray/Documents/Research Assistancy UofA 2025/Reproduce CL/contrastive-learning-RL/notebooks'

# Trainer


In [None]:
# Root directory handed to the Trainer as `default_root_dir`.
CHECKPOINT_PATH = "../saved_models"
# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

def train_cl(batch_size, max_epochs=1000, **kwargs):
    """Build the Lightning trainer for contrastive training (work in progress).

    Args:
        batch_size: Currently unused; reserved for the data pipeline.
        max_epochs: Maximum number of epochs passed to the trainer.
        **kwargs: Currently unused.

    Returns:
        None. Only the trainer is constructed so far.
    """
    trainer = pl.Trainer(
        default_root_dir=CHECKPOINT_PATH,

        accelerator="gpu" if str(device).startswith("cuda") else "cpu",
        devices=1,
        max_epochs=max_epochs,
        callbacks=[
            # `monitor` (previously misspelled `moniter`) names the metric whose
            # maximum decides which checkpoint is kept.
            # NOTE(review): the monitored metric presumably has to be logged
            # through Lightning's own logging under this exact name; confirm.
            ModelCheckpoint(save_weights_only=True, mode="max", monitor="val/top5"),
            LearningRateMonitor("epoch"),
        ])

    # TODO: build the data loaders / model and call trainer.fit(...).
    pass