In [1]:
import os 
import sys 
# Put the parent of the current working directory at the front of sys.path (once),
# so the project-local `models` and `data` packages can be imported.
project_root = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
if project_root not in sys.path:
    sys.path.insert(0, project_root)

import torch 
import torch.nn as nn 
import torch.nn.functional as F 
import torch.optim as optim

from models.cl_model import mlpCL 
from models.cmhn import cmhn 

from data.DatasetCL import DatasetCL 
from data.Sampler import Sampler 
from data.TrajectorySet import TrajectorySet

import pytorch_lightning as pl


In [2]:
# Load cmhn model
mhn = cmhn(update_steps=1)

# Load trained CL model from its checkpoint under <project_root>/saved_models.
model_name = "best_model.ckpt"
pretrained_model_file = os.path.join(project_root, "saved_models", model_name)

# Fail fast: without this guard a missing checkpoint is skipped silently, `cl_model`
# is never defined, and a later cell fails with a NameError far from the real cause.
if not os.path.isfile(pretrained_model_file):
    raise FileNotFoundError(
        f"No pretrained CL model checkpoint found at {pretrained_model_file}"
    )

print(f"Found pretrained model at {pretrained_model_file}, loading...")
cl_model = mlpCL.load_from_checkpoint(pretrained_model_file)

Found pretrained model at /Users/ray/Documents/Research Assistancy UofA 2025/Reproduce Paper/contrastive-abstraction-RL/saved_models/best_model.ckpt, loading...


In [3]:
class LearnedBetaModel(pl.LightningModule):
    """
    Lightning module holding a small MLP (`beta_net`) that predicts one beta value
    per input row, plus a Hopfield-network object (`cmhn`) that is queried with
    those betas.

    NOTE(review): `beta_max` and `masking_ratio` are accepted and stored in
    `self.hparams`, and `fc_nn` is constructed, but none of them is used by the
    methods visible in this class. `beta_net` ends in a Sigmoid, so its output is
    in (0, 1); presumably it is meant to be scaled by `beta_max` — confirm.
    """

    def __init__(self, cmhn, beta_max, lr=1e-3, weight_decay=1e-5, masking_ratio=0.3, max_epochs=1000, input_dim=32, h1=128, h2=32, fc_h1 = 64):
        """
        Args:
            cmhn: Hopfield-network object; `loss` calls its `run(...)` and
                `run_batch(...)` methods.
            beta_max: Stored in hparams only (see class note).
            lr: Learning rate for AdamW; also sets the scheduler floor (lr / 50).
            weight_decay: Weight decay for AdamW.
            masking_ratio: Stored in hparams only (see class note).
            max_epochs: Used as `T_max` of the cosine-annealing scheduler.
            input_dim: Width of the input rows fed to `beta_net` and `fc_nn`.
            h1: Width of the first hidden layer of `beta_net`.
            h2: Width of the second hidden layer of `beta_net`.
            fc_h1: Width of the hidden layer of `fc_nn`.
        """
        super().__init__()
        # NOTE(review): this records every __init__ argument, including the `cmhn`
        # object itself; consider ignoring it if it should not live in hparams.
        self.save_hyperparameters()
        self.cmhn = cmhn

        # Maps each input row (input_dim) to a single value in (0, 1).
        self.beta_net = nn.Sequential(
            nn.Linear(input_dim, h1),
            nn.ReLU(),

            nn.Linear(h1, h2),
            nn.ReLU(),

            nn.Linear(h2, 1),
            nn.Sigmoid()
        )

        # input_dim -> fc_h1 -> input_dim MLP; not used by the methods in this class.
        self.fc_nn = nn.Sequential(
            nn.Linear(input_dim, fc_h1),
            nn.ReLU(),
            nn.Linear(fc_h1, input_dim)
        )

    def configure_optimizers(self):
        """
        Build the AdamW optimizer and a cosine-annealing LR schedule.

        Returns:
            ([optimizer], [lr_scheduler]): AdamW over all module parameters, and
            CosineAnnealingLR with T_max = max_epochs and eta_min = lr / 50.
        """
        optimizer = optim.AdamW(params=self.parameters(),
                                lr=self.hparams.lr,
                                weight_decay=self.hparams.weight_decay)
        lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer,
                                                            T_max=self.hparams.max_epochs,
                                                            eta_min=self.hparams.lr / 50)
        return ([optimizer], [lr_scheduler])

    def loss(self, batch, mode="train"):
        """
        Run the Hopfield network on `batch` with betas predicted by `beta_net`,
        once row-by-row and once through the batched entry point.

        No loss value is computed yet: the two Hopfield outputs are returned so
        they can be compared.

        Args:
            batch: 2-D tensor of z representations, one row per sample.
            mode: Currently unused.

        Returns:
            (U, U2): `U` is the row-wise concatenation of the transposed
            `cmhn.run(...)` outputs, one call per sample; `U2` is the output of
            `cmhn.run_batch(batch, batch, beta)`.
        """
        # One beta per row, shape (batch_size, 1).
        beta = self.beta_net(batch)

        # Query the Hopfield network one sample at a time. Rows are collected in a
        # list and concatenated once, instead of growing a tensor inside the loop.
        per_sample = []
        for i in range(batch.size(0)):
            u_i = self.cmhn.run(batch, batch[i, :], beta[i])
            per_sample.append(torch.transpose(u_i, 1, 0))

        if per_sample:
            U = torch.cat(per_sample, dim=0)
        else:
            # Empty batch: keep a (0, feature_dim) result that follows the batch's
            # width, dtype and device rather than a hard-coded CPU (0, 32) tensor.
            U = batch.new_empty((0, batch.size(1)))

        # Same computation through the batched entry point.
        U2 = self.cmhn.run_batch(batch, batch, beta)

        return U, U2

In [4]:

# Two hand-written 4-feature inputs for a quick end-to-end check.
x = torch.as_tensor([
    [1, -1, 1, 0.4],
    [1, 2, -0.1, 0.2],
])

# Encode the inputs with the loaded CL model; gradients are not tracked here.
with torch.no_grad():
    z = cl_model(x)
print("z representation:", z)

# Pass the representations through the beta model and print both Hopfield
# outputs (per-sample loop, then batched call) one after the other.
bm = LearnedBetaModel(cmhn=mhn, beta_max=200)
U1, U2 = bm.loss(z)
print(U1, U2, sep="\n")



z representation: tensor([[-14.1995,  16.4025,  -2.2677,  16.8622,   3.8314, -24.4883,  -8.4043,
          19.3913, -18.0510, -16.5705,   0.3289,   9.5356,   0.9896,   5.2742,
          -2.1161, -13.5134, -10.2675,  -8.2921,  13.5510, -21.6915,  13.4183,
           4.1641,  20.5247, -16.7965,  20.0339,   7.7674,   8.0739,  14.6877,
          22.4745,  -2.0929,  -0.2420,  25.1546],
        [ 17.8858, -26.0878,  -6.0880,  -4.0136,  -3.1719,   9.8321,  14.3580,
           4.9684,  -8.7627, -18.6895,  27.3732,   6.5467,  -4.1518,  -3.7496,
         -15.4645,  -2.3632, -13.1712,   7.4133,   8.6765, -28.7746,  -4.9537,
         -22.8542,   9.5407, -17.5092,  -2.9711,  -7.6025,  16.0034,   2.1269,
          23.1665,  10.5346,  -2.3052,  20.7228]])
beta tensor([[0.2980],
        [0.3188]], grad_fn=<SigmoidBackward0>)
1 tensor([[6342.5029, 1773.1714],
        [1773.1714, 6371.3105]])
2 tensor([[1890.1979,  528.4420],
        [ 565.3519, 2031.4069]], grad_fn=<MulBackward0>)
probs tensor([[1., 0.