In [1]:
# Personal 
from data.TrajectorySet import TrajectorySet
from data.Sampler import Sampler 
from data.DatasetCL import DatasetCL 
from models.cl_model import mlpCL
from trainer.cl_trainer import train_cl
from utils.visualizations import visualize_embeddings
from utils.tensor_utils import split_data

# Misc
import minari 
import os

# Torch 
import torch 

# PyTorch Lightning 
import pytorch_lightning
from pytorch_lightning.loggers import WandbLogger

In [2]:
# Identifier of the offline dataset fetched through Minari; kept as a named
# value so the dataset in use is visible at a glance.
MINARI_DATASET_ID = "D4RL/pointmaze/large-v2"
MINARI_DATASET = minari.load_dataset(MINARI_DATASET_ID)
# Paths are anchored at the directory the kernel was started from.
PROJECT_ROOT = os.getcwd()
# Built with os.path.join instead of concatenating a hard-coded "/" separator.
CHECKPOINT_PATH = os.path.join(PROJECT_ROOT, "saved_models")

# Labels for the experiment tracker and the stem used for the checkpoint file.
PROJECT_NAME = "Contrastive Learning RL"
RUN_NAME = "best-model-laplace-15"
FILENAME = "best_model_laplace_15"

# Accelerator preference order: Apple MPS first, then CUDA, then CPU.
if torch.backends.mps.is_available():
    DEVICE = "mps"
elif torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Single source of truth for the run's hyperparameters; this dict is also what
# gets handed to the experiment logger, so everything listed here is recorded
# with the run.
CONFIG = {
    "distribution": "l",
    "num_state_pairs": 1_000_000,
    "batch_size": 4096,
    "k": 1,
    # Sampler hyperparameters, recorded here so they are logged with the run.
    # They mirror the b=15 / sigma=15 values given to Sampler when the
    # dataset is built.
    # NOTE(review): presumably the Laplace scale and Gaussian std — confirm
    # against Sampler.
    "b": 15,
    "sigma": 15,
    "lr": 1e-3,
    "weight_decay": 1e-5,
    "temperature": 30,
    "max_epochs": 1000,
    "filename": FILENAME,
    "device": DEVICE,
}

In [3]:
# Experiment tracker: the full CONFIG dict is attached to the run, and
# log_model=True is forwarded to WandbLogger unchanged.
wandb_logger_kwargs = dict(
    project=PROJECT_NAME,
    name=RUN_NAME,
    save_dir=PROJECT_ROOT,
    log_model=True,
    config=CONFIG,
)
wandb_logger = WandbLogger(**wandb_logger_kwargs)

# Unpack CONFIG into the plain names the later cells use.
# Sampling / dataset settings.
dist, num_state_pairs, k = (
    CONFIG[key] for key in ("distribution", "num_state_pairs", "k")
)
# Optimisation settings.
batch_size, lr, weight_decay, temperature, max_epochs = (
    CONFIG[key]
    for key in ("batch_size", "lr", "weight_decay", "temperature", "max_epochs")
)
# Runtime target and checkpoint file stem.
device, filename = (CONFIG[key] for key in ("device", "filename"))

In [4]:
# Seed the global RNGs *before* any sampling or splitting so the generated
# state pairs and the train/val split can be reproduced on a fresh kernel.
# 10 matches the seed the training log reports ("Seed set to 10"), which is
# only applied later, once training starts.
# NOTE(review): this only helps if Sampler / split_data draw from the global
# Python / NumPy / torch generators — confirm.
SEED = 10
pytorch_lightning.seed_everything(SEED)

# Sampler hyperparameters: taken from CONFIG when it defines them, otherwise
# falling back to the values this notebook has always used (15 / 15).
sampler_b = CONFIG.get("b", 15)
sampler_sigma = CONFIG.get("sigma", 15)

# Wrap the Minari dataset and build a sampler over it.
T = TrajectorySet(dataset=MINARI_DATASET)
S = Sampler(T, dist=dist, b=sampler_b, sigma=sampler_sigma)

# Materialise the state pairs once, then split them so the train and
# validation datasets are built from disjoint pieces of the same batch.
ds = DatasetCL(S, num_state_pairs=num_state_pairs, k=k)
data = ds.get_batch()  # change to list format to ensure correct split

# NOTE(review): split_val=0.8 presumably means 80% train / 20% val — confirm
# against split_data.
train, val = split_data(data=data, split_val=0.8)

train_dataset = DatasetCL(data=train, k=k)
val_dataset = DatasetCL(data=val, k=k)

KeyboardInterrupt: 

In [5]:
# Batch size handed to the training call, taken from the unpacked config.
MINIBATCH_SIZE = batch_size
minibatch_type = type(MINIBATCH_SIZE)

# Sanity output: show the value and its Python type.
print("Minibatch size:", MINIBATCH_SIZE)
print(minibatch_type)

Minibatch size: 4096
<class 'int'>


In [None]:
# Gather every argument for the training entry point in one place, then launch.
# NOTE(review): the captured log of a previous run reports "You defined a
# `validation_step` but have no `val_dataloader`. Skipping val loop." even
# though val_ds is supplied here — check that train_cl actually wires val_ds
# into the Trainer.
training_kwargs = dict(
    cl_model=mlpCL,
    train_ds=train_dataset,
    val_ds=val_dataset,
    batch_size=MINIBATCH_SIZE,
    logger=wandb_logger,
    checkpoint_path=CHECKPOINT_PATH,
    max_epochs=max_epochs,
    filename=filename,
    device=device,
    lr=lr,
    temperature=temperature,
    weight_decay=weight_decay,
)
model = train_cl(**training_kwargs)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
Seed set to 10
/Users/ray/Documents/Research Assistancy UofA 2025/Reproduce Paper/contrastive-abstraction-RL/CL_RL/lib/python3.9/site-packages/pytorch_lightning/trainer/configuration_validator.py:70: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.
[34m[1mwandb[0m: Currently logged in as: [33mray-s[0m ([33mray-s-university-of-alberta[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


/Users/ray/Documents/Research Assistancy UofA 2025/Reproduce Paper/contrastive-abstraction-RL/CL_RL/lib/python3.9/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:654: Checkpoint directory /Users/ray/Documents/Research Assistancy UofA 2025/Reproduce Paper/contrastive-abstraction-RL/saved_models exists and is not empty.

  | Name | Type       | Params | Mode 
--------------------------------------------
0 | mlp  | Sequential | 44.5 K | train
--------------------------------------------
44.5 K    Trainable params
0         Non-trainable params
44.5 K    Total params
0.178     Total estimated model params size (MB)
8         Modules in train mode
0         Modules in eval mode
/Users/ray/Documents/Research Assistancy UofA 2025/Reproduce Paper/contrastive-abstraction-RL/CL_RL/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `n

Epoch 43:  15%|█▌        | 30/195 [00:10<00:57,  2.87it/s, v_num=m3ly, train/nll_loss_step=5.210, train/top1_step=0.0667, train/top5_step=0.210, train/nll_loss_epoch=5.210, train/top1_epoch=0.0652, train/top5_epoch=0.211]   


Detected KeyboardInterrupt, attempting graceful shutdown ...


NameError: name 'exit' is not defined