In [1]:
import os 
import sys 
project_root = os.path.abspath(os.path.join(os.getcwd(), "..")) 
if project_root not in sys.path: 
    sys.path.insert(0, project_root)

import minari

import torch 
import torch.utils.data as data
import torch.nn as nn 
import torch.nn.functional as F 
import torch.optim as optim

from models.cl_model import mlpCL 
from models.cmhn import cmhn 

from data.StatesDataset import StatesDataset
from data.TrajectorySet import TrajectorySet 
from data.Sampler import Sampler 

from utils.tensor_utils import split_data

import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger


In [2]:
# Import minari dataset
MINARI_DATASET = minari.load_dataset("D4RL/pointmaze/large-v2")

# Load cmhn model 
DEVICE = "mps"
mhn = cmhn(update_steps=1, device=DEVICE)

# Load trained CL model 
model_name = "best_model.ckpt"
pretrained_model_file = os.path.join(project_root+ "/saved_models", model_name) 

if os.path.isfile(pretrained_model_file): 
    print(f"Found pretrained model at {pretrained_model_file}, loading...") 
    cl_model = mlpCL.load_from_checkpoint(pretrained_model_file, map_location=torch.device("cpu"))

Found pretrained model at /Users/ray/Documents/Research Assistancy UofA 2025/Reproduce Paper/contrastive-abstraction-RL/saved_models/best_model.ckpt, loading...


In [3]:
class LearnedBetaModel(pl.LightningModule): 
    def __init__(self, cmhn, beta_max, lr=1e-3, weight_decay=1e-5, masking_ratio=0.3, max_epochs=1000, input_dim=32, h1=128, h2=32, fc_h1 = 64, device="cpu"):
        super().__init__() 
        self.save_hyperparameters()
        self.cmhn = cmhn 
        self.device_type = torch.device(device=device)

        self.dropout = nn.Dropout(p=masking_ratio, inplace=False)

        self.beta_net = nn.Sequential(
            nn.Linear(input_dim, h1),
            nn.ReLU(), 

            nn.Linear(h1, h2), 
            nn.ReLU(),

            nn.Linear(h2, 1),
            nn.Sigmoid() 
        ).to(self.device_type)

        self.fc_nn = nn.Sequential( 
            nn.Linear(input_dim, fc_h1),
            nn.ReLU(), 
            nn.Linear(fc_h1, input_dim)
        ).to(self.device_type)
    
    def configure_optimizers(self):
        optimizer = optim.AdamW(params=self.parameters(), 
                                lr= self.hparams.lr, 
                                weight_decay=self.hparams.weight_decay)
        lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, 
                                                            T_max=self.hparams.max_epochs,
                                                            eta_min=self.hparams.lr / 50)
        return ([optimizer], [lr_scheduler])

    def loss(self, batch, mode="train"): 
        """
        The loss function for the beta network. 

        Args: 
            batch: The batch data that the beta network will use (z representations). 
        
        Returns: 
            loss: The infoNCE loss. 
        """
        batch = batch.to(self.device_type)
        
        print("batch: ", batch)
        print("batch size: ", batch.size())

        # get the trial beta 
        beta = self.beta_net(batch)

        # get abstract representation 'u' 
        U = self.cmhn.run(batch, batch, beta, run_as_batch=True) 

        # get the noisy batch, nn.Dropout uses scaling=True to maintain expected value of tensor
        z_prime = self.dropout(batch)

        # create positive pairs
        pairs = torch.cat([U, z_prime], dim= -2)

        print("pairs size: ", pairs.size())
      
        # put new batch pairs into fc_nn to obtain vectors in new embedding space useful for contrastive learning 
        p = self.fc_nn(pairs)

        # DEBUGGING STUFF: 
        #print("z rep: ", batch)
        #print("beta: ", beta)
        #print("U: ", U)
        #print("U size: ", U.size())
        #print("pairs: ", pairs)
        #print("pairs size: ", pairs.size())
        #print("p: ", p)
        #print("p size: ", p.size())


        ######################################################################
        #     use p for contrastive loss 
        ######################################################################

        N = p.size(-2) // 2

        print("N: ", N)

        # normalize vector embedding
        p = F.normalize(p, dim=-1)

        print("p size: ", p.size())

        sim = torch.matmul(p, torch.transpose(p, -2, -1)) # cosine sim matrix [2N, 2N]
        #print("sim: ", sim)

        # mask diagonals to large negative numbers so we don't calculate same state similarities
        mask = torch.eye(2 * N, device=sim.device).bool()
        sim = sim.masked_fill_(mask, -9e15)

        # positives: i-th sample matches i + N mod 2N
        labels = (torch.arange(2 * N, device=sim.device) + N) % (2 * N)

        print("sim size: ", sim.size())

        loss = F.cross_entropy(sim, labels) # over mean reduction 

        # extra statistics 
        if mode=="train": 
            with torch.no_grad(): 
                norms = torch.norm(p, dim=1)
                self.log(f"{mode}/sim_mean", sim.mean(), on_epoch=True)
                self.log(f"{mode}/sim_std", sim.std(), on_epoch=True)
                self.log(f"{mode}/p_norm_mean", norms.mean(), on_epoch=True)
                self.log(f"{mode}/p_norm_std", norms.std(), on_epoch=True)

        # metrics
        preds = sim.argmax(dim=1)
        top1 = (preds == labels).float().mean()   # top1: true positive is most similar to anchor 
        top5 = (sim.topk(5, dim=1).indices == labels.unsqueeze(1)).any(dim=1).float().mean() # top5: true positive is atleast in the top 5 most similar to anchor 

        self.log(f"{mode}/nll_loss", loss, on_epoch=True, prog_bar=True)
        self.log(f"{mode}/top1", top1, on_epoch=True, prog_bar=True)
        self.log(f"{mode}/top5", top5, on_epoch=True, prog_bar=True)

        return loss
    
    def training_step(self, batch):
        return self.loss(batch, mode='train')

    def validation_step(self, batch):
        self.loss(batch, mode='val')

    def debugging(self): 
        for name, param in self.beta_net.named_parameters():
            print(f"{name} device:", param.device)

In [None]:
T = TrajectorySet(dataset=MINARI_DATASET)
S = Sampler(T=T) 

batch_size = 100
split = 0.7
train_split = int(batch_size)
#val_split = int(batch_size * (1 - split))

train_ds = StatesDataset(cl_model = cl_model, sampler=S, num_states=train_split)
#val_ds = StatesDataset(cl_model = cl_model, sampler=S, num_states=val_split)

minibatch_size = 10 
train_loader = data.DataLoader(dataset=train_ds, batch_size=minibatch_size, shuffle=True, drop_last=True)
#val_loader = data.DataLoader(dataset=val_ds, batch_size=batch_size, shuffle=False, drop_last=False)

device = "mps"

bm = LearnedBetaModel(cmhn=mhn, beta_max=200, max_epochs=1, device=device)

trainer = pl.Trainer(
    accelerator=device,
)

trainer.fit(model= bm, train_dataloaders=train_loader)


  states = torch.tensor(sampler.sample_states(batch_size=num_states), dtype=torch.float32)
Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/Users/ray/Documents/Research Assistancy UofA 2025/Reproduce Paper/contrastive-abstraction-RL/CL_RL/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:76: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
/Users/ray/Documents/Res

Epoch 0:   0%|          | 0/10 [00:00<?, ?it/s] batch:  tensor([[ 7.0562e+00, -5.8528e+00,  6.5353e+00,  2.3514e+00, -2.8791e+00,
         -1.7313e+00,  3.2142e+00,  2.5070e+01, -1.0900e+01, -1.4021e+01,
          1.2437e+01,  4.0531e+00, -3.9539e-02,  1.2092e+00, -3.3643e+00,
          1.9345e+00, -1.5970e+01, -4.4869e+00,  2.2603e+00, -4.4064e+01,
          6.5065e+00, -2.0252e+01,  6.5319e+00, -1.9507e+01, -3.3382e-01,
         -2.4224e+00,  1.5023e+01,  5.7491e+00,  3.2347e+01,  6.0673e+00,
         -4.5899e+00,  2.0640e+01],
        [-7.6892e+00, -3.4949e-01,  3.3495e-01, -1.5433e+01,  1.0744e+01,
          2.8085e+00,  4.3322e+00,  1.5831e+01, -1.1246e+01, -8.3291e+00,
          9.0711e+00,  4.3851e+00, -1.8073e+01,  2.5675e-01, -5.3219e+00,
          1.1384e+01, -2.7354e+00, -9.1010e+00,  1.1904e+01, -3.0936e+01,
          1.9790e+01, -1.2092e+01,  1.4305e+01, -3.2855e+01, -2.2950e+00,
          9.3239e+00,  1.4638e+01,  3.0674e+00,  2.9820e+01, -1.2097e+01,
          3.7445e+00