In [1]:
import os 
import sys 
# Put the parent of the current working directory at the front of the import
# path; the `data`, `models` and `utils` imports below are resolved from there.
project_root = os.path.normpath(os.path.join(os.getcwd(), os.pardir))
if sys.path.count(project_root) == 0:
    sys.path.insert(0, project_root)

import minari

import torch 
import torch.utils.data as data

from data.EpisodesDataset import EpisodesDataset

from models.cl_model import mlpCL 
from models.cmhn import cmhn 
from models.beta_model import LearnedBetaModel

from utils.tensor_utils import split_data

import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger

In [2]:
def train_beta_model(bm_model, cmhn, train_ds, val_ds, batch_size, logger, checkpoint_path, max_epochs=1000, device="cpu", filename= "best_model", **kwargs):
    """Fit a beta model with PyTorch Lightning and return the best checkpointed model.

    Args:
        bm_model: Model class. It is instantiated as
            ``bm_model(cmhn=cmhn, max_epochs=max_epochs, device=device, **kwargs)``
            and must provide ``load_from_checkpoint``.
        cmhn: Object forwarded to the model constructor as ``cmhn``.
        train_ds: Dataset for the training loader (shuffled, incomplete last
            batch dropped).
        val_ds: Dataset for the validation loader (original order, nothing dropped).
        batch_size: Batch size used by both loaders.
        logger: Logger handed to ``pl.Trainer``.
        checkpoint_path: Directory for checkpoints; also the trainer's
            ``default_root_dir``.
        max_epochs: Epoch limit for the trainer; also forwarded to the model.
        device: Torch-style device spec such as ``"cpu"``, ``"cuda"``,
            ``"cuda:1"`` or ``"mps"``. It selects the trainer's accelerator and
            is forwarded to the model so both agree on where tensors live.
        filename: File name passed to ``ModelCheckpoint``.
        **kwargs: Extra keyword arguments forwarded to the model constructor.

    Returns:
        The model reloaded from the best checkpoint, or the in-memory trained
        model when no checkpoint was written.
    """
    # Seed first so model initialisation and loader shuffling are reproducible.
    pl.seed_everything(10)

    # Derive the Lightning accelerator from the requested device instead of
    # auto-detecting hardware; otherwise the trainer could move the model to a
    # different device than the one the model itself was told to use.
    target = torch.device(device)
    accelerator = target.type
    # Honour an explicit CUDA index ("cuda:1"); every other case uses one device.
    devices = [target.index] if target.type == "cuda" and target.index is not None else 1

    # Keep the 3 checkpoints with the highest "val/top5" (weights only).
    # NOTE(review): assumes the model logs a metric named "val/top5" — confirm.
    checkpoint_callback = ModelCheckpoint(dirpath=checkpoint_path,
                                          filename=filename,
                                          save_top_k=3,
                                          save_weights_only=True,
                                          mode="max",
                                          monitor="val/top5")

    trainer = pl.Trainer(
        default_root_dir=checkpoint_path,
        logger=logger,
        accelerator=accelerator,
        devices=devices,
        max_epochs=max_epochs,
        callbacks=[checkpoint_callback,
                   LearningRateMonitor("epoch")])  # log the learning rate once per epoch

    train_loader = data.DataLoader(dataset=train_ds, batch_size=batch_size, shuffle=True, drop_last=True)
    val_loader = data.DataLoader(dataset=val_ds, batch_size=batch_size, shuffle=False, drop_last=False)

    model = bm_model(cmhn=cmhn, max_epochs=max_epochs, device=device, **kwargs)
    trainer.fit(model, train_loader, val_loader)

    best_path = checkpoint_callback.best_model_path
    if not best_path:
        # Nothing was checkpointed, so there is no file to reload from.
        print("No checkpoint was saved; returning the model from the last epoch.")
        return model

    print("Best model path:", best_path)
    # NOTE(review): relies on the checkpoint containing every constructor
    # argument (including cmhn) as a saved hyperparameter — confirm.
    model = bm_model.load_from_checkpoint(best_path)

    return model

In [3]:
# --- Paths -------------------------------------------------------------------
PROJECT_ROOT = project_root
CHECKPOINT_PATH = f"{PROJECT_ROOT}/saved_beta_models"

# --- Data --------------------------------------------------------------------
MINARI_DATASET = minari.load_dataset("D4RL/pointmaze/large-v2")

# --- Experiment naming (W&B project/run and checkpoint file name) --------------
PROJECT_NAME, RUN_NAME, FILENAME = "Learning Beta Model", "test-run", "test_run_model"

# --- Hyperparameters and run settings -----------------------------------------
CONFIG = dict(
    num_eps=10,  # max eps is 3360
    lr=1e-3,
    weight_decay=1e-5,
    masking_ratio=0.3,
    beta_max=200,
    max_epochs=10,
    filename=FILENAME,
    # Prefer Apple MPS, then CUDA, and fall back to the CPU.
    device=(
        "mps" if torch.backends.mps.is_available()
        else "cuda" if torch.cuda.is_available()
        else "cpu"
    ),
)

In [4]:
# Use the device selected in CONFIG rather than a hard-coded "mps", so this
# cell also runs on machines without Apple silicon.
DEVICE = CONFIG["device"]

# Instantiate the cmhn model on that device.
mhn = cmhn(update_steps=1, device=DEVICE)

# Load the trained CL model from <PROJECT_ROOT>/saved_models.
model_name = "best_model.ckpt"
pretrained_model_file = os.path.join(PROJECT_ROOT, "saved_models", model_name)

# Fail here with a clear message instead of leaving `cl_model` undefined and
# hitting a NameError in a later cell.
if not os.path.isfile(pretrained_model_file):
    raise FileNotFoundError(f"Pretrained CL model not found at {pretrained_model_file}")

print(f"Found pretrained model at {pretrained_model_file}, loading...")
cl_model = mlpCL.load_from_checkpoint(pretrained_model_file, map_location=torch.device(DEVICE))

Found pretrained model at /Users/ray/Documents/Research Assistancy UofA 2025/Reproduce Paper/contrastive-abstraction-RL/saved_models/best_model.ckpt, loading...


# Dataset Preprocessing

In [5]:
# Seed Minari's episode sampler so the same episodes are drawn on every run
# (and on every re-execution of this cell).
SAMPLING_SEED = 10
MINARI_DATASET.set_seed(SAMPLING_SEED)

episodeData = MINARI_DATASET.sample_episodes(n_episodes=CONFIG["num_eps"])

# NOTE(review): split_val=0.7 presumably means 70% train / 30% validation, and
# split_data may use its own randomness — confirm in utils.tensor_utils.
train, val = split_data(episodeData, split_val=0.7)

# Both datasets are built with the pretrained CL model.
train_ds = EpisodesDataset(cl_model=cl_model, episodeData=train)
val_ds = EpisodesDataset(cl_model=cl_model, episodeData=val)

# Logger

In [6]:
# Weights & Biases logger for this run; CONFIG is attached as the run config
# and run files are written under the project root.
wandb_logger = WandbLogger(
    save_dir=PROJECT_ROOT,
    project=PROJECT_NAME,
    name=RUN_NAME,
    config=CONFIG,
    log_model=True,
)

# Training

In [7]:
# Batch size handed to train_beta_model as `batch_size`.
# NOTE(review): presumably 1 because episodes differ in length and cannot be
# stacked into one batch — confirm against EpisodesDataset.
MINIBATCH_SIZE = 1

In [8]:
# Train the beta model. The leading positional arguments are, in order:
# model class, cmhn, train dataset, val dataset, batch size, logger and
# checkpoint directory.
model = train_beta_model(
    LearnedBetaModel,
    mhn,
    train_ds,
    val_ds,
    MINIBATCH_SIZE,
    wandb_logger,
    CHECKPOINT_PATH,
    max_epochs=CONFIG["max_epochs"],
    device=CONFIG["device"],
    filename=FILENAME,
    # Remaining hyperparameters travel through **kwargs to the model constructor.
    **{key: CONFIG[key] for key in ("lr", "weight_decay", "masking_ratio", "beta_max")},
)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
Seed set to 10
[34m[1mwandb[0m: Currently logged in as: [33mray-s[0m ([33mray-s-university-of-alberta[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin



  | Name     | Type       | Params | Mode 
------------------------------------------------
0 | dropout  | Dropout    | 0      | train
1 | beta_net | Sequential | 8.4 K  | train
2 | fc_nn    | Sequential | 4.2 K  | train
------------------------------------------------
12.6 K    Trainable params
0         Non-trainable params
12.6 K    Total params
0.050     Total estimated model params size (MB)
12        Modules in train mode
0         Modules in eval mode


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

/Users/ray/Documents/Research Assistancy UofA 2025/Reproduce Paper/contrastive-abstraction-RL/CL_RL/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.
  sims = X @ queries.T   # shape [N, N]


torch.Size([1, 603, 32])
torch.Size([1, 603, 32])


RuntimeError: Expected size for first two dimensions of batch2 tensor to be: [32, 32] but got: [32, 603].