In [59]:
import os 
import sys 
# Treat the parent of the current working directory as the project root and put it
# first on sys.path so the local `models` and `data` packages below are importable.
project_root = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
if project_root not in sys.path:
    sys.path.insert(0, project_root)

import minari

import torch 
import torch.utils.data as data
import torch.nn as nn 
import torch.nn.functional as F 
import torch.optim as optim

from models.cl_model import mlpCL 
from models.cmhn import cmhn 

from data.DatasetCL import DatasetCL 
from data.Sampler import Sampler 
from data.TrajectorySet import TrajectorySet

import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger


In [60]:
# Import minari dataset
MINARI_DATASET = minari.load_dataset("D4RL/pointmaze/large-v2")

# Load cmhn model; fall back to CPU so the notebook also runs on machines without CUDA.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
mhn = cmhn(update_steps=1, device=DEVICE)

# Load trained CL model.
# Join path components separately (instead of string-concatenating "/saved_models")
# so the separator is consistent on every OS.
model_name = "best_model.ckpt"
pretrained_model_file = os.path.join(project_root, "saved_models", model_name)

if not os.path.isfile(pretrained_model_file):
    # Fail loudly here instead of leaving `cl_model` undefined for the cells below.
    raise FileNotFoundError(f"No pretrained CL checkpoint found at {pretrained_model_file}")

print(f"Found pretrained model at {pretrained_model_file}, loading...")
cl_model = mlpCL.load_from_checkpoint(pretrained_model_file)
# The encoder is only used for inference in this notebook; eval() turns off any
# training-mode behaviour (dropout / batch-norm updates) it may contain.
cl_model.eval()

Found pretrained model at c:\Users\ray\Documents\2025 RA\contrastive-learning-RL/saved_models\best_model.ckpt, loading...


In [61]:
class LearnedBetaModel(pl.LightningModule):
    """Learns the ``beta`` passed to a continuous modern Hopfield network via a contrastive loss.

    ``beta_net`` maps each representation ``z`` in a batch to a scalar in (0, 1) (Sigmoid output).
    ``cmhn.run(z, z, beta, run_as_batch=True)`` produces an abstract representation ``U``; ``U`` and a
    dropout-corrupted copy of ``z`` are embedded by ``fc_nn`` and trained with an InfoNCE loss in which
    ``U[i]`` and ``dropout(z)[i]`` form the positive pair.

    Args:
        cmhn: Hopfield-network object; only its ``run`` method is used here.
        beta_max: Stored in ``hparams`` only. NOTE(review): it is never applied, so beta stays in
            (0, 1). If it was meant to scale the Sigmoid output, confirm and wire it in.
        lr: AdamW learning rate (cosine-annealed down to ``lr / 50``).
        weight_decay: AdamW weight decay.
        masking_ratio: Dropout probability used to corrupt ``z``.
        max_epochs: ``T_max`` of the cosine learning-rate schedule.
        input_dim: Size of each representation ``z``.
        h1, h2: Hidden sizes of ``beta_net``.
        fc_h1: Hidden size of the projection head ``fc_nn``.
        device: Device the sub-networks are created on.
        temperature: Divisor applied to the cosine similarities before the softmax. The default
            1.0 reproduces plain (un-scaled) cosine logits.
    """

    def __init__(self, cmhn, beta_max, lr=1e-3, weight_decay=1e-5, masking_ratio=0.3, max_epochs=1000,
                 input_dim=32, h1=128, h2=32, fc_h1=64, device="cpu", temperature=1.0):
        super().__init__()
        if temperature <= 0:
            raise ValueError(f"temperature must be positive, got {temperature}")
        self.save_hyperparameters()
        self.cmhn = cmhn
        # Construction-time device; loss() follows the weights' actual device so that a
        # Trainer moving the module to another accelerator does not cause a mismatch.
        self.device_type = torch.device(device)

        self.dropout = nn.Dropout(p=masking_ratio, inplace=False)

        self.beta_net = nn.Sequential(
            nn.Linear(input_dim, h1),
            nn.ReLU(),

            nn.Linear(h1, h2),
            nn.ReLU(),

            nn.Linear(h2, 1),
            nn.Sigmoid()
        ).to(self.device_type)

        self.fc_nn = nn.Sequential(
            nn.Linear(input_dim, fc_h1),
            nn.ReLU(),
            nn.Linear(fc_h1, input_dim)
        ).to(self.device_type)

    def configure_optimizers(self):
        """AdamW with cosine annealing over ``max_epochs`` down to ``lr / 50``."""
        optimizer = optim.AdamW(params=self.parameters(),
                                lr=self.hparams.lr,
                                weight_decay=self.hparams.weight_decay)
        lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer,
                                                            T_max=self.hparams.max_epochs,
                                                            eta_min=self.hparams.lr / 50)
        return ([optimizer], [lr_scheduler])

    def loss(self, batch, mode="train"):
        """
        The InfoNCE loss for the beta network.

        Args:
            batch: Representations ``z`` fed to ``beta_net`` (expected shape
                ``(batch_size, input_dim)``).
            mode: Prefix for the logged metrics; ``"train"`` additionally logs similarity and
                embedding-norm statistics.

        Returns:
            loss: The scalar InfoNCE (cross-entropy) loss.
        """
        # Follow the device of the weights: under a Trainer this is wherever Lightning placed
        # the module, otherwise it is the `device` given at construction.
        batch = batch.to(next(self.beta_net.parameters()).device)

        # per-sample beta for the Hopfield network
        beta = self.beta_net(batch)

        # abstract representation 'U' from the Hopfield network (from z and beta)
        U = self.cmhn.run(batch, batch, beta, run_as_batch=True)

        # noisy view of the batch; nn.Dropout rescales kept entries to preserve the expectation
        z_prime = self.dropout(batch)

        # rows i and i + N are a positive pair
        pairs = torch.cat([U, z_prime], dim=0)

        # project into the embedding space used for contrastive learning, then L2-normalise
        p = F.normalize(self.fc_nn(pairs), dim=1)

        N = p.size(0) // 2
        sim = torch.matmul(p, p.T)  # cosine similarity matrix [2N, 2N]
        self_mask = torch.eye(2 * N, dtype=torch.bool, device=sim.device)

        # extra statistics
        if mode == "train":
            with torch.no_grad():
                # Off-diagonal entries only: the diagonal is always 1 and is replaced by a huge
                # negative number below, which would otherwise dominate mean/std.
                off_diag = sim[~self_mask]
                norms = torch.norm(p, dim=1)
                self.log(f"{mode}/sim_mean", off_diag.mean(), on_epoch=True)
                self.log(f"{mode}/sim_std", off_diag.std(), on_epoch=True)
                self.log(f"{mode}/p_norm_mean", norms.mean(), on_epoch=True)
                self.log(f"{mode}/p_norm_std", norms.std(), on_epoch=True)

        # Temperature-scaled logits; self-similarities are excluded from the softmax.
        logits = (sim / self.hparams.temperature).masked_fill(self_mask, -9e15)

        # positives: i-th sample matches (i + N) mod 2N
        labels = (torch.arange(2 * N, device=sim.device) + N) % (2 * N)

        loss = F.cross_entropy(logits, labels)  # mean reduction

        # metrics
        preds = logits.argmax(dim=1)
        top1 = (preds == labels).float().mean()  # true positive is the most similar to the anchor
        # true positive is within the top 5; clamp k so small (e.g. last validation) batches work
        k = min(5, logits.size(1))
        top5 = (logits.topk(k, dim=1).indices == labels.unsqueeze(1)).any(dim=1).float().mean()

        self.log(f"{mode}/nll_loss", loss, on_epoch=True, prog_bar=True)
        self.log(f"{mode}/top1", top1, on_epoch=True, prog_bar=True)
        self.log(f"{mode}/top5", top5, on_epoch=True, prog_bar=True)

        return loss

    def training_step(self, batch, batch_idx=None):
        """Lightning training hook: returns the InfoNCE loss for backprop."""
        return self.loss(batch, mode="train")

    def validation_step(self, batch, batch_idx=None):
        """Lightning validation hook: logs the ``val/*`` metrics (``val/top5`` drives checkpointing)."""
        self.loss(batch, mode="val")

    def debugging(self):
        """Print the device of every ``beta_net`` parameter."""
        for name, param in self.beta_net.named_parameters():
            print(f"{name} device:", param.device)

In [57]:

# Smoke test: encode two hand-written 4-D observations with the pretrained CL encoder.
x = torch.as_tensor([[1, -1, 1, 0.4], [1, 2, -0.1, 0.2]], dtype=torch.float32)
with torch.no_grad():
    z = cl_model(x)
assert z.shape[0] == x.shape[0], "encoder should return one representation per input row"
z.shape



In [56]:


# Smoke test on one real episode: encode its observations, then evaluate the beta-model loss.
episode = MINARI_DATASET.sample_episodes(n_episodes=1)[0]
obs = torch.as_tensor(episode.observations["observation"], dtype=torch.float32)
print(obs.size())

with torch.no_grad():
    z = cl_model(obs)
print(z.size())

# Build the model in this cell so it runs on a fresh kernel (Restart & Run All).
# NOTE(review): beta_max=200 is an ad-hoc smoke-test value.
bm = LearnedBetaModel(cmhn=mhn, beta_max=200, device=DEVICE)
loss = bm.loss(z)
print(loss)

torch.Size([491, 4])
torch.Size([491, 32])
tensor(6.6316, device='cuda:0', grad_fn=<NllLossBackward0>)


# Trainer


In [None]:
def train_beta_model(cl_model, train_ds, val_ds, batch_size, logger, checkpoint_path, max_epochs=1000,
                     device="cpu", filename="best_model", seed=10, **kwargs):
    """Train a LightningModule class and return the model from its best ``val/top5`` checkpoint.

    Args:
        cl_model: The LightningModule *class* to instantiate (e.g. ``LearnedBetaModel``), not an instance.
        train_ds: Training dataset (shuffled, last incomplete batch dropped).
        val_ds: Validation dataset (not shuffled, all samples kept).
        batch_size: Batch size for both loaders.
        logger: Lightning logger passed to the Trainer.
        checkpoint_path: Directory for checkpoints and the Trainer's ``default_root_dir``.
        max_epochs: Number of training epochs, also forwarded to the model constructor.
        device: Forwarded to the model constructor.
        filename: Checkpoint file name stem.
        seed: Seed passed to ``pl.seed_everything`` right before the model is built.
        **kwargs: Forwarded to the model constructor.

    Returns:
        The model restored from the best checkpoint, or the in-memory trained model when no
        checkpoint was written.
    """
    # Keep the 3 best checkpoints by val/top5 (higher is better).
    checkpoint_callback = ModelCheckpoint(dirpath=checkpoint_path,
                                          filename=filename,
                                          save_top_k=3,
                                          save_weights_only=True,
                                          mode="max",
                                          monitor="val/top5")

    accelerator = "mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu"
    trainer = pl.Trainer(
        default_root_dir=checkpoint_path,
        logger=logger,
        accelerator=accelerator,
        devices=1,
        max_epochs=max_epochs,
        callbacks=[checkpoint_callback, LearningRateMonitor("epoch")])

    train_loader = data.DataLoader(dataset=train_ds, batch_size=batch_size, shuffle=True, drop_last=True)
    val_loader = data.DataLoader(dataset=val_ds, batch_size=batch_size, shuffle=False, drop_last=False)

    pl.seed_everything(seed)
    model = cl_model(max_epochs=max_epochs, device=device, **kwargs)
    trainer.fit(model, train_loader, val_loader)

    best_path = checkpoint_callback.best_model_path
    if not best_path:
        # No checkpoint was written (e.g. val/top5 was never logged); loading "" would fail.
        print("No checkpoint was saved; returning the model from the final epoch.")
        return model

    print("Best model path:", best_path)
    return cl_model.load_from_checkpoint(best_path)

# Model 

### Configs

In [None]:
# Hyperparameters for this run; the same dict is handed to Weights & Biases as the run config.
config = dict(
    distribution="g",
    batch_size=256,
    k=2,
    lr=5e-4,
    weight_decay=1e-4,
    temperature=0.08,
    max_epochs=10,
)

# Expose each hyperparameter as a module-level name for the cells below.
dist = config["distribution"]
batch_size, k, lr, weight_decay, temperature, max_epochs = (
    config[name]
    for name in ("batch_size", "k", "lr", "weight_decay", "temperature", "max_epochs")
)

wandb_logger = WandbLogger(
    project="Contrastive Learning RL",
    name="test-run-new-infoNCE-loss",
    save_dir=project_root,
    log_model=True,
    config=config,
)

### Datasets

In [None]:
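# Train and val episodes
# TODO: define `train_dataset` and `val_dataset` here; the training cell below uses them,
# and they are not created anywhere in this notebook yet.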

In [None]:




# Train the learned-beta model.
# NOTE(review): `train_dataset` / `val_dataset` must be built in the Datasets section first.
# NOTE(review): config["temperature"] is not forwarded; add `temperature=temperature`
# only if LearnedBetaModel's constructor accepts that argument.
model = train_beta_model(
    cl_model=LearnedBetaModel,
    train_ds=train_dataset,
    val_ds=val_dataset,
    batch_size=batch_size,
    logger=wandb_logger,
    # Separate sub-directory so beta-model checkpoints cannot collide with the CL
    # encoder's saved_models/best_model.ckpt.
    checkpoint_path=os.path.join(project_root, "saved_models", "beta_model"),
    max_epochs=max_epochs,
    device=DEVICE,
    cmhn=mhn,
    beta_max=200,
    lr=lr,
    weight_decay=weight_decay,
)