In [1]:
import os 
import sys 
project_root = os.path.abspath(os.path.join(os.getcwd(), "..")) 
if project_root not in sys.path: 
    sys.path.insert(0, project_root)

import minari

import torch 
import torch.utils.data as data
import torch.nn as nn 
import torch.nn.functional as F 
import torch.optim as optim

from models.cl_model import mlpCL 
from models.cmhn import cmhn 

from data.StatesDataset import StatesDataset
from data.TrajectorySet import TrajectorySet 
from data.Sampler import Sampler 

from utils.sampling_states import sample_states
from utils.tensor_utils import split_data

import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger


In [2]:
# Import minari dataset
MINARI_DATASET = minari.load_dataset("D4RL/pointmaze/large-v2")

# Load cmhn model 
DEVICE = "mps"
mhn = cmhn(update_steps=1, device=DEVICE)

# Load trained CL model 
model_name = "best_model.ckpt"
pretrained_model_file = os.path.join(project_root+ "/saved_models", model_name) 

if os.path.isfile(pretrained_model_file): 
    print(f"Found pretrained model at {pretrained_model_file}, loading...") 
    cl_model = mlpCL.load_from_checkpoint(pretrained_model_file, map_location=torch.device(DEVICE))

Found pretrained model at /Users/ray/Documents/Research Assistancy UofA 2025/Reproduce Paper/contrastive-abstraction-RL/saved_models/best_model.ckpt, loading...


In [3]:
class LearnedBetaModel(pl.LightningModule): 
    def __init__(self, cmhn, beta_max, lr=1e-3, weight_decay=1e-5, temperature=0.1, masking_ratio=0.3, max_epochs=1000, input_dim=32, h1=128, h2=32, fc_h1 = 64, device="cpu"):
        super().__init__() 
        self.save_hyperparameters()
        self.cmhn = cmhn 
        self.device_type = torch.device(device=device)

        self.dropout = nn.Dropout(p=masking_ratio, inplace=False)

        self.beta_net = nn.Sequential(
            nn.Linear(input_dim, h1),
            nn.ReLU(), 

            nn.Linear(h1, h2), 
            nn.ReLU(),

            nn.Linear(h2, 1),
            nn.Sigmoid() 
        ).to(self.device_type)

        self.fc_nn = nn.Sequential( 
            nn.Linear(input_dim, fc_h1),
            nn.ReLU(), 
            nn.Linear(fc_h1, input_dim)
        ).to(self.device_type)
    
    def configure_optimizers(self):
        optimizer = optim.AdamW(params=self.parameters(), 
                                lr= self.hparams.lr, 
                                weight_decay=self.hparams.weight_decay)
        lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, 
                                                            T_max=self.hparams.max_epochs,
                                                            eta_min=self.hparams.lr / 50)
        return ([optimizer], [lr_scheduler])

    def loss(self, batch, mode="train"): 
        """
        The loss function for the beta network. 

        Args: 
            batch: The batch data that the beta network will use (z representations). 
        
        Returns: 
            loss: The infoNCE loss. 
        """
        batch = batch

        # get the trial beta 
        beta = self.beta_net(batch)

        beta = beta * self.hparams.beta_max

        # get abstract representation 'u' 
        U = self.cmhn.run(batch, batch, beta, run_as_batch=True) 

        # get the noisy batch, nn.Dropout uses scaling=True to maintain expected value of tensor
        z_prime = self.dropout(batch)

        # create positive pairs
        pairs = torch.cat([U, z_prime], dim=0)
      
        # put new batch pairs into fc_nn to obtain vectors in new embedding space useful for contrastive learning 
        p = self.fc_nn(pairs)


        ######################################################################
        #     use p for contrastive loss 
        ######################################################################

        N = p.size(0) // 2

        sim = torch.matmul(p, p.T) / self.hparams.temperature # cosine sim matrix [2N, 2N]
        #print("sim: ", sim)

        # mask diagonals to large negative numbers so we don't calculate same state similarities
        mask = torch.eye(2 * N, device=sim.device).bool()
        sim = sim.masked_fill_(mask, -9e15)

        # positives: i-th sample matches i + N mod 2N
        labels = (torch.arange(2 * N, device=sim.device) + N) % (2 * N)

        loss = F.cross_entropy(sim, labels) # over mean reduction 

        # extra statistics 
        if mode=="train": 
            with torch.no_grad(): 
                norms = torch.norm(p, dim=1)
                self.log(f"{mode}/sim_mean", sim.mean(), on_epoch=True)
                self.log(f"{mode}/sim_std", sim.std(), on_epoch=True)
                self.log(f"{mode}/p_norm_mean", norms.mean(), on_epoch=True)
                self.log(f"{mode}/p_norm_std", norms.std(), on_epoch=True)
                self.log(f"{mode}/beta_mean", beta.mean(), on_epoch=True)

        # metrics
        preds = sim.argmax(dim=1)
        top1 = (preds == labels).float().mean()   # top1: true positive is most similar to anchor 
        #top5 = (sim.topk(5, dim=1).indices == labels.unsqueeze(1)).any(dim=1).float().mean() # top5: true positive is atleast in the top 5 most similar to anchor 

        self.log(f"{mode}/nll_loss", loss, on_epoch=True, prog_bar=True)
        self.log(f"{mode}/top1", top1, on_epoch=True, prog_bar=True)
        #self.log(f"{mode}/top5", top5, on_epoch=True, prog_bar=True)

        return loss
    
    def training_step(self, batch):
        return self.loss(batch, mode='train')

    def validation_step(self, batch):
        self.loss(batch, mode='val')

In [4]:
data = sample_states(dataset=MINARI_DATASET, num_states=10)
train, val = split_data(data, split_val=0.8) 

train_ds = StatesDataset(cl_model=cl_model, minari_dataset=MINARI_DATASET, data=train)
val_ds = StatesDataset(cl_model=cl_model, minari_dataset=MINARI_DATASET, data=val)

data = torch.as_tensor(data, dtype=torch.float32)
with torch.no_grad(): 
    z = cl_model(data) 


In [5]:
bm = LearnedBetaModel(cmhn=mhn, beta_max=200, device=DEVICE)
loss = bm.loss(z)
loss 

/Users/ray/Documents/Research Assistancy UofA 2025/Reproduce Paper/contrastive-abstraction-RL/CL_RL/lib/python3.9/site-packages/pytorch_lightning/core/module.py:441: You are trying to `self.log()` but the `self.trainer` reference is not registered on the model yet. This is most likely because the model hasn't been passed to the `Trainer`


tensor(754.7344, device='mps:0', grad_fn=<NllLossBackward0>)