In [1]:
import os 
import sys 
# Resolve the parent of the current working directory and put it at the front
# of sys.path so the `src.*` imports below resolve.
# NOTE(review): assumes the kernel is started one level below the repo root.
project_root = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
if project_root not in sys.path:
    sys.path.insert(0, project_root)
import minari

import torch 
import torch.utils.data as data

import faiss

from src.data.EpisodesDataset import EpisodesDataset

from src.models.cl_model import mlpCL 
from src.models.cmhn import cmhn 
from src.models.beta_model import LearnedBetaModel

from src.data.StatesDataset import StatesDataset
from src.data.TrajectorySet import TrajectorySet 
from src.data.Sampler import Sampler 

from src.trainers.beta_trainer import train_beta_model

from src.utils.sampling_states import sample_states 
from src.utils.tensor_utils import split_data

import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger

In [2]:
# Quick environment check: CUDA visibility, then the resolved project root.
print(torch.cuda.is_available(), project_root, sep="\n")

True
c:\Users\ray\Documents\2025 RA\contrastive-learning-RL


In [3]:
# Guard against the Jupyter kernel crashing when faiss is used for the chn
# calculations: set the two OpenMP-related environment variables, then pin
# both torch and faiss to a single thread.
os.environ.update({"KMP_DUPLICATE_LIB_OK": "TRUE", "OMP_NUM_THREADS": "1"})
torch.set_num_threads(1)
faiss.omp_set_num_threads(1)

In [4]:
# Dataset handle and filesystem locations.
MINARI_DATASET = minari.load_dataset("D4RL/pointmaze/large-v2")
PROJECT_ROOT = project_root
CHECKPOINT_PATH = PROJECT_ROOT + "/trained_models"

# Names used for the W&B logger; FILENAME is also handed to train_beta_model.
PROJECT_NAME = "Learning Beta Model"
RUN_NAME = "real_run"
FILENAME = RUN_NAME

# Accelerator preference order: MPS, then CUDA, then CPU.
if torch.backends.mps.is_available():
    DEVICE = "mps"
elif torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Hyperparameters for this run; the same dict is passed to the logger as config.
CONFIG = dict(
    num_states=1_000_000,
    lr=1e-3,
    temperature=0.03796123348109251,
    weight_decay=1e-5,
    masking_ratio=0.3,
    beta_max=200,
    max_epochs=100,
    filename=FILENAME,
    device=DEVICE,
)

In [5]:
# Load cmhn model
mhn = cmhn(max_iter=1000, threshold=0.9999, device=DEVICE)

# Load trained CL model from the shared checkpoint directory.
model_name = "laplace_cos_sim-v1.ckpt"
pretrained_model_file = os.path.join(CHECKPOINT_PATH, model_name)

# Fail fast when the checkpoint is missing: the following cells need `cl_model`,
# and merely printing a message would leave it undefined (or stale from an
# earlier run of this kernel).
if not os.path.isfile(pretrained_model_file):
    raise FileNotFoundError(
        f"Model not found: expected a pretrained CL checkpoint at {pretrained_model_file}"
    )

print(f"Found pretrained model at {pretrained_model_file}, loading...")
cl_model = mlpCL.load_from_checkpoint(
    pretrained_model_file,
    map_location=torch.device(DEVICE),  # load the weights onto the selected device
)

Found pretrained model at c:\Users\ray\Documents\2025 RA\contrastive-learning-RL/trained_models\laplace_cos_sim-v1.ckpt, loading...


In [6]:
# Preprocessing step

# Draw CONFIG["num_states"] states from the Minari dataset. The result is bound
# to `sampled` rather than `data` so that the `torch.utils.data as data` module
# alias created in the import cell is not overwritten.
sampled = sample_states(dataset=MINARI_DATASET, num_states=CONFIG["num_states"])
states = sampled["states"]

# Train/validation split; split_val=0.8 is passed straight to split_data.
train, val = split_data(states, split_val=0.8)

print(train.shape)
print(val.shape)

# Wrap each split in a StatesDataset built with the pretrained CL model.
train_ds = StatesDataset(cl_model=cl_model, minari_dataset=MINARI_DATASET, data=train)
val_ds = StatesDataset(cl_model=cl_model, minari_dataset=MINARI_DATASET, data=val)

(800000, 4)
(200000, 4)


In [7]:
# W&B logger for this run; CONFIG is attached as the run config and
# log_model=True is forwarded to the logger.
wandb_logger = WandbLogger(
    project=PROJECT_NAME,
    name=RUN_NAME,
    save_dir=PROJECT_ROOT,
    config=CONFIG,
    log_model=True,
)

In [8]:
# Batch size handed to train_beta_model.
MINIBATCH_SIZE = 4096

# Echo the resolved paths and device before training starts.
print(PROJECT_ROOT, CHECKPOINT_PATH, DEVICE, sep="\n")

c:\Users\ray\Documents\2025 RA\contrastive-learning-RL
c:\Users\ray\Documents\2025 RA\contrastive-learning-RL/trained_models
cuda


In [9]:
# Train the beta model. The first group of arguments configures the training
# run itself; the trailing group is passed as extra keyword arguments.
model = train_beta_model(
    bm_model=LearnedBetaModel,
    cmhn=mhn,
    train_ds=train_ds,
    val_ds=val_ds,
    batch_size=MINIBATCH_SIZE,
    logger=wandb_logger,
    checkpoint_path=CHECKPOINT_PATH,
    max_epochs=CONFIG["max_epochs"],
    device=CONFIG["device"],
    filename=FILENAME,
    # kwargs
    lr=CONFIG["lr"],
    weight_decay=CONFIG["weight_decay"],
    masking_ratio=CONFIG["masking_ratio"],
    beta_max=CONFIG["beta_max"],
    temperature=CONFIG["temperature"],
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
Seed set to 10
You are using a CUDA device ('NVIDIA GeForce RTX 3060') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
[34m[1mwandb[0m: Currently logged in as: [33mray-s[0m ([33mray-s-university-of-alberta[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


c:\Users\ray\AppData\Local\anaconda3\envs\CL_RL_gpu\lib\site-packages\pytorch_lightning\callbacks\model_checkpoint.py:751: Checkpoint directory C:\Users\ray\Documents\2025 RA\contrastive-learning-RL\trained_models exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name     | Type       | Params | Mode 
------------------------------------------------
0 | dropout  | Dropout    | 0      | train
1 | beta_net | Sequential | 8.4 K  | train
2 | fc_nn    | Sequential | 51.7 K | train
------------------------------------------------
60.1 K    Trainable params
0         Non-trainable params
60.1 K    Total params
0.240     Total estimated model params size (MB)
16        Modules in train mode
0         Modules in eval mode


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

c:\Users\ray\AppData\Local\anaconda3\envs\CL_RL_gpu\lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:433: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


                                                                           

c:\Users\ray\AppData\Local\anaconda3\envs\CL_RL_gpu\lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:433: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Epoch 14:  36%|███▌      | 70/195 [07:48<13:56,  0.15it/s, v_num=wmvw, train/nll_loss_step=1.140, train/top1_step=0.691, train/top5_step=0.942, val/nll_loss=3.380, val/top1=0.0594, val/top5=0.315, train/nll_loss_epoch=1.140, train/top1_epoch=0.692, train/top5_epoch=0.934] 


Detected KeyboardInterrupt, attempting graceful shutdown ...


AttributeError: 'tuple' object has no attribute 'tb_frame'