In [1]:
import os 
import sys 

import minari

import torch 
import torch.utils.data as data

from data.EpisodesDataset import EpisodesDataset

from models.cl_model import mlpCL 
from models.cmhn import cmhn 
from models.beta_model import LearnedBetaModel

from data.StatesDataset import StatesDataset
from data.TrajectorySet import TrajectorySet 
from data.Sampler import Sampler 

from trainer.beta_trainer import train_beta_model

from utils.sampling_states import sample_states 
from utils.tensor_utils import split_data

import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger

In [2]:
MINARI_DATASET = minari.load_dataset("D4RL/pointmaze/large-v2")
PROJECT_ROOT = os.getcwd()
CHECKPOINT_PATH = PROJECT_ROOT + "\saved_beta_models"

PROJECT_NAME = "Learning Beta Model"
RUN_NAME = "test-run-3"
FILENAME = "test_run_model"
DEVICE = "mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu"

CONFIG = {
        "num_states": 1_000,  
        "lr": 1e-3,
        "temperature": 0.1, 
        "weight_decay": 1e-5, 
        "masking_ratio": 0.3,
        "beta_max": 200,
        "max_epochs": 10,
        "filename": FILENAME,
        "device": DEVICE,
    }

In [3]:
# Load cmhn model 
mhn = cmhn(update_steps=1, device=DEVICE)

# Load trained CL model 
model_name = "best_model_laplacian.ckpt"
pretrained_model_file = os.path.join(PROJECT_ROOT+ "/best_models", model_name) 

if os.path.isfile(pretrained_model_file): 
    print(f"Found pretrained model at {pretrained_model_file}, loading...") 
    cl_model = mlpCL.load_from_checkpoint(pretrained_model_file, map_location=torch.device(DEVICE))

Found pretrained model at /Users/ray/Documents/Research Assistancy UofA 2025/Reproduce Paper/contrastive-abstraction-RL/best_models/best_model_laplacian.ckpt, loading...


In [4]:
# Preprocessing step 

data = sample_states(dataset=MINARI_DATASET, num_states=CONFIG["num_states"])
train, val = split_data(data, split_val=0.8) 

print(train.shape)
print(val.shape)

train_ds = StatesDataset(cl_model=cl_model, minari_dataset=MINARI_DATASET, data=train)
val_ds = StatesDataset(cl_model=cl_model, minari_dataset=MINARI_DATASET, data=val)

(800, 4)
(200, 4)


In [5]:
wandb_logger = WandbLogger(
        project=PROJECT_NAME, 
        name=RUN_NAME, 
        save_dir = PROJECT_ROOT, 
        log_model=True,
        config = CONFIG) 

In [6]:
MINIBATCH_SIZE = CONFIG["num_states"]
#print(PROJECT_ROOT)
#print(CHECKPOINT_PATH)

In [7]:
model = train_beta_model(
    bm_model=LearnedBetaModel,
    cmhn=mhn, 
    train_ds=train_ds,
    val_ds = val_ds,
    batch_size=1, 
    logger=wandb_logger, 
    checkpoint_path=CHECKPOINT_PATH, 
    max_epochs=CONFIG["max_epochs"],
    device=CONFIG["device"], 
    filename= FILENAME,

    # kwaargs
    lr=CONFIG["lr"],
    weight_decay=CONFIG["weight_decay"], 
    masking_ratio=CONFIG["masking_ratio"], 
    beta_max=CONFIG["beta_max"],
    temperature=CONFIG["temperature"]
)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
Seed set to 10
[34m[1mwandb[0m: Currently logged in as: [33mray-s[0m ([33mray-s-university-of-alberta[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


/Users/ray/Documents/Research Assistancy UofA 2025/Reproduce Paper/contrastive-abstraction-RL/CL_RL/lib/python3.9/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:654: Checkpoint directory /Users/ray/Documents/Research Assistancy UofA 2025/Reproduce Paper/contrastive-abstraction-RL\saved_beta_models exists and is not empty.

  | Name     | Type       | Params | Mode 
------------------------------------------------
0 | dropout  | Dropout    | 0      | train
1 | beta_net | Sequential | 8.4 K  | train
2 | fc_nn    | Sequential | 4.2 K  | train
------------------------------------------------
12.6 K    Trainable params
0         Non-trainable params
12.6 K    Total params
0.050     Total estimated model params size (MB)
12        Modules in train mode
0         Modules in eval mode


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

/Users/ray/Documents/Research Assistancy UofA 2025/Reproduce Paper/contrastive-abstraction-RL/CL_RL/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


                                                                           

/Users/ray/Documents/Research Assistancy UofA 2025/Reproduce Paper/contrastive-abstraction-RL/CL_RL/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Epoch 0:   0%|          | 0/800 [00:00<?, ?it/s] batch type: <class 'torch.Tensor'>
batch shape: torch.Size([1, 32])
batch dtype: torch.float32
batch mean: 2.7164931297302246
Epoch 0:   0%|          | 1/800 [00:01<21:07,  0.63it/s, v_num=0a53, train/nll_loss_step=0.000, train/top1_step=1.000]batch type: <class 'torch.Tensor'>
batch shape: torch.Size([1, 32])
batch dtype: torch.float32
batch mean: 3.466097116470337
Epoch 0:   0%|          | 2/800 [00:01<10:37,  1.25it/s, v_num=0a53, train/nll_loss_step=0.000, train/top1_step=1.000]batch type: <class 'torch.Tensor'>
batch shape: torch.Size([1, 32])
batch dtype: torch.float32
batch mean: 1.9046266078948975
Epoch 0:   0%|          | 3/800 [00:01<07:06,  1.87it/s, v_num=0a53, train/nll_loss_step=0.000, train/top1_step=1.000]batch type: <class 'torch.Tensor'>
batch shape: torch.Size([1, 32])
batch dtype: torch.float32
batch mean: 2.7505640983581543
Epoch 0:   0%|          | 4/800 [00:01<05:21,  2.47it/s, v_num=0a53, train/nll_loss_step=0.000


Detected KeyboardInterrupt, attempting graceful shutdown ...


NameError: name 'exit' is not defined