In [1]:
import os 
import sys 
project_root = os.path.abspath(os.path.join(os.getcwd(), "..")) 
if project_root not in sys.path: 
    sys.path.insert(0, project_root)

import minari

import faiss 
import torch 
import torch.nn as nn 
import torch.nn.functional as F 
import torch.optim as optim

from src.models.cl_model import mlpCL 
#from src.models.cmhn import cmhn 

from src.utils.sampling_states import sample_states

import pytorch_lightning as pl

print(project_root)

c:\Users\ray\Documents\2025 RA\contrastive-learning-RL


In [2]:
"""
MADE SOME CHANGES TO CMHN AND BETA MODEL: check chatgpt "gradient flow analysis" for most of the changes 
- removed using cmhn on z to get z' 
- normalized dropout on z, changes the magnitude but doesn't affect direction and thus also the cos_sim calcs
- in CMHN, changed the __build_index to use inner product (IP) and normalized the input argument 
- in CMHN, stabilized the softmax at the end with logits - logits.max

TRAIN WITH AN IPYNB FILE (copy stuff from beta_main.ipynb) to see results 
- in CMHN, changed to use torch.bmm for a more efficient computation instead of my own cos sim computation 
- removed use_gpu in CMHN args 
- changed pairing (u,z') to (u, u') for more representative learning 
- combined getting u and u' by concatenating batch and noisy batch into the cmhn (faster)
"""

'\nMADE SOME CHANGES TO CMHN AND BETA MODEL: check chatgpt "gradient flow analysis" for most of the changes \n- removed using cmhn on z to get z\' \n- normalized dropout on z, changes the magnitude but doesn\'t affect direction and thus also the cos_sim calcs\n- in CMHN, changed the __build_index to use inner product (IP) and normalized the input argument \n- in CMHN, stabilized the softmax at the end with logits - logits.max\n\nTRAIN WITH AN IPYNB FILE (copy stuff from beta_main.ipynb) to see results \n- in CMHN, changed to use torch.bmm for a more efficient computation instead of my own cos sim computation \n- removed use_gpu in CMHN args \n- changed pairing (u,z\') to (u, u\') for more representative learning \n- combined getting u and u\' by concatenating batch and noisy batch into the cmhn (faster)\n'

In [3]:
# Import minari dataset
MINARI_DATASET = minari.load_dataset("D4RL/pointmaze/large-v2")
DEVICE = "cuda"

# Load trained CL model 
model_name = "laplace_cos_sim-v1.ckpt"
pretrained_model_file = os.path.join(project_root+ "/trained_models", model_name) 

if os.path.isfile(pretrained_model_file): 
    print(f"Found pretrained model at {pretrained_model_file}, loading...") 
    cl_model = mlpCL.load_from_checkpoint(pretrained_model_file, map_location=torch.device(DEVICE))

Found pretrained model at c:\Users\ray\Documents\2025 RA\contrastive-learning-RL/trained_models\laplace_cos_sim-v1.ckpt, loading...


In [4]:
def is_normalized(v, tol=1e-6):
    return (v.norm(dim=-1) - 1).abs() < tol

In [14]:
class cmhn(): 
    def __init__(self, max_iter = 100, threshold = 0.95, topk = 512, device="cpu"):
        """
        Continuous Modern Hopfield Network 

        Args: 
            update_steps: The number of iterations the cmhn will do. (Usually just one).
            topk: Using faiss, only the top k most similar patterns will be used. (more efficient in batch-wise updates) 
            use_gpu: Tells faiss if we use faiss-cpu or faiss-gpu for behind the scenes calculations. 
            device: The device that torch will use. 
        """
        self.max_iter = max_iter 
        self.threshold = threshold
        self.topk = topk

        self.device = torch.device(device)
        self.index = None 

    def build_index(self, X, d): 
        """
        Builds a faiss index (an object) for efficient searching of top-k patterns from X (on cpu). 
        """
        X_np = X.detach().cpu().numpy().astype("float32") # convert X from tensor to numpy 

        self.index = faiss.IndexFlatIP(d)
        
        self.index.add(X_np)
    
    def __update(self, X, xi, beta): 
        """
        The update rule for a continuous modern hopfield network. 

        Args: 
            X: The stored patterns. X is of size [N, d], where N is the number of patterns, and d the size of the patterns. 
            xi: The state pattern (ie. the current pattern being updated). xi is of size [d, 1]. 
            beta: The scalar inverse-temperature hyperparamater. Controls the number of metastable states that occur in the energy landscape. 
                - High beta corresponds to low temp, more separation between patterns.  
                - Low beta corresponds to high temp, less separation (more metastable states). 
        """
        X_norm = F.normalize(X, p=2, dim=1)
        xi_norm = F.normalize(xi, p=2, dim=0)
        sims = X_norm @ xi_norm  # simularity between stored patterns and current pattern 
        p = F.softmax(beta * sims, dim=0, dtype=torch.float32)  # softmax dist along patterns (higher probability => more likely to be that stored pattern)
        # p of size [N, 1] 

        X_T = X_norm.transpose(0, 1) 
        xi_new = X_T @ p  # xi_new, the updated state pattern; size [d, 1]
        return xi_new

    def __run_batch(self, X, queries, beta=None): 
        """
        Runs the mhn batch-wise for efficient computation. 

        Args: 
            X: Stored patterns, size [N, d].
            queries: Input queries, size [N, d].
            beta: The beta value per sample, size [N].
        """        
        
        assert beta != None, "Must have a value for beta." 

        # normalize for cos sim calcs
        X_norm = F.normalize(X, p=2, dim=-1)
        queries_norm = F.normalize(queries, p=2, dim=-1)

        with torch.no_grad():
            queries_np = queries_norm.detach().cpu().numpy().astype("float32")
            _, indices = self.index.search(queries_np, self.topk)
            indices = torch.from_numpy(indices).to(X.device) # indices of shape [N, topk]

        topk_X = X_norm[indices] # size [N, topk, d] 
        topk_q = queries_norm.unsqueeze(1) # change queries from [N, d] to [N, 1, d] for broadcasting
        
        # dot product of x_ij * q_i along "d dim" to obtain tensor of [N, topk]
        # q_i represents the i'th query
        # x_ij represents the corresponding i'th query and j'th pattern, where j is among the topk 
        # then sum over d to obtain the similarity between row i and col j. 
        # sims = torch.sum(topk_X * topk_q, dim=-1) 

        # USE torch.bmm instead of the above comments for more efficient computation (they do the same thing tho) 
        sims = torch.bmm(topk_q, topk_X.transpose(1,2)).squeeze(1)

        # removing beta broadcasting
        #beta = beta.view(-1, 1)  # beta: [N, 1], broadcasting beta. 
        logits = beta * sims       # sims * beta: [N, topk]
        logits_max = torch.max(logits, dim=-1, keepdim=True)[0]
        probs = F.softmax(logits - logits_max.detach(), dim=-1)   # calculate probs along patterns (NOT queries) ie. along topk --> [N, topk]

        # weighted sum over topk_X: x_ij * probs_i
        xi_new = torch.sum(probs.unsqueeze(-1) * topk_X, dim=1)

        return xi_new
    
    def __has_converged(self, old_xi, new_xi): 
        """ 
        Checks whether or not the hopfield network has converged. Convergence is measured through taking the average cosine similarity 
        between old_xi and new_xi. If this average meets the threshold (ie. avg_cos_sim >= threshold), then we say that old_xi and 
        new_xi are the same and that the hopfield network has converged. 

        old_xi and new_xi are shapes: [N, d]

        Args: 
            old_xi xi before running the udpate rule.
            new_xi: xi after running the update rule.
        
        Returns:
            True: if the average cosine similarity between old_xi and new_xi is meets the threshold.
            False: if the average cosine similarity between old_xi and new_xi is below the threshold.
        """
        converged = False 

        old_norm = F.normalize(old_xi, p=2, dim=-1)  # normalize along rows
        new_norm = F.normalize(new_xi, p=2, dim=-1)

        cos_sim = torch.sum(old_norm * new_norm, dim=1)  # [N], similarity for each query

        min_cos_sim = cos_sim.min().item()
        if min_cos_sim >= self.threshold:
            converged = True

        return converged

    def run(self, X, xi, beta=None, run_as_batch=False): 
        """
        Runs the network. 

        Args: 
            X: The stored patterns. X is of size [N, d], where B is the batches, N is the number of patterns, and d the size of the patterns. 
            xi: The state pattern (ie. the current pattern being updated). xi is of size [d, 1]. xi can also be a batch of queries [N, d].
            beta: The scalar inverse-temperature hyperparamater. Controls the number of metastable states that occur in the energy landscape. 
                - High beta corresponds to low temp, more separation between patterns.  
                - Low beta corresponds to high temp, less separation (more metastable states). 
        """
        assert beta != None, "Must have a value for beta."

        if not isinstance(beta, torch.Tensor):
           beta = torch.as_tensor(beta, dtype=torch.float32)

        X = X.to(self.device)
        xi = xi.to(self.device)
        beta = beta.to(self.device)

        if run_as_batch: 
            if xi.dim() == 1: 
                raise ValueError("Query shape should be [N, d] when updating as a batch.")
            for _ in range(self.max_iter): 
                old_xi = xi.clone()
                xi = self.__run_batch(X, xi, beta)

                if self.__has_converged(old_xi=old_xi, new_xi=xi): 
                    break 
            return xi

        else:
            # if xi is of size [d], then change to [d, 1] 
            if xi.dim() == 1: 
                xi = xi.unsqueeze(1) #[d, 1]
            elif xi.dim() == 2 and xi.size(1) != 1: 
                raise ValueError("Query shape should be [d] or [d, 1].") 

            for _ in range(self.max_iter): 
                xi = self.__update(X, xi, beta)
            return xi 

In [12]:
class LearnedBetaModel(pl.LightningModule): 
    def __init__(self, cmhn, beta_max=200, lr=1e-3, weight_decay=1e-5, temperature=1, masking_ratio=0.3, max_epochs=1000, input_dim=32, h1=128, h2=32, fc_h1 = 256, fc_h2 = 128, fc_h3 = 64, device="cpu"):
        super().__init__() 
        self.save_hyperparameters()
        self.cmhn = cmhn 
        self.device_type = torch.device(device=device)

        self.dropout = nn.Dropout(p=masking_ratio, inplace=False)

        self.beta_net = nn.Sequential(
            nn.Linear(input_dim, h1),
            nn.ReLU(), 

            nn.Linear(h1, h2), 
            nn.ReLU(),

            nn.Linear(h2, 1),
            nn.Softplus() 
            #nn.Sigmoid()       12/20/2025: removing sigmoid for softplus + clamping, because sigmoid is too steep and can cause beta to easily max out  
        ).to(self.device_type)
    
    def configure_optimizers(self):
        optimizer = optim.AdamW(params=self.parameters(), 
                                lr= self.hparams.lr, 
                                weight_decay=self.hparams.weight_decay)
        lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, 
                                                            T_max=self.hparams.max_epochs,
                                                            eta_min=self.hparams.lr / 50)
        return ([optimizer], [lr_scheduler])

    def loss(self, batch, mode="train"): 
        """
        The loss function for the beta network. 

        Args: 
            batch: The batch data that the beta network will use (z representations). 
        
        Returns: 
            loss: The infoNCE loss. 
        """
        batch_norm = F.normalize(batch, p=2, dim=-1)        
        N, d = batch_norm.shape

        # build the faiss index beforehand
        self.cmhn.build_index(X=batch_norm, d=d)

        # get the trial beta 
        beta = torch.mean(self.beta_net(batch_norm)) * self.hparams.beta_max
        
        beta = beta.clamp(max = self.hparams.beta_max)

        # get abstract representation 'u' and normalize 
        #u = F.normalize(self.cmhn.run(batch_norm, batch_norm, beta, run_as_batch=True) , p=2, dim=-1)

        # get the noisy batch, nn.Dropout uses scaling=True to maintain expected value of tensor
        z_prime = F.normalize(self.dropout(batch), p=2, dim=-1) # adding normalization to dropout, since cos_sim doesn't care about vector magnitude 

        #get u' from z' and normalize
        #u_prime = F.normalize(self.cmhn.run(batch_norm, z_prime, beta, run_as_batch=True), p=2, dim=-1)  # removing this as a test

        # concat batch_norm and z_prime together to pass them both into one run of the cmhn (faster than two separate runs of cmhn and stronger backprop signal)
        queries_combined = torch.cat([batch_norm, z_prime], dim=0)
        u_combined = self.cmhn.run(batch_norm, queries_combined, beta, run_as_batch=True)
        u = u_combined[:N]
        u_prime = u_combined[N:]

        # create positive pairs
        p = torch.cat([u, u_prime], dim=0)

        # put new batch pairs into fc_nn to obtain vectors in new embedding space useful for contrastive learning 
        ######################################################################
        #     use p for contrastive loss 
        ######################################################################

        N = p.size(0) // 2
        p_norm = F.normalize(p, p=2, dim=-1)  
        
        sim = torch.matmul(p_norm, p_norm.T) / self.hparams.temperature # cosine sim matrix [2N, 2N]
        if mode=="train": 
            with torch.no_grad():
                self.log(f"{mode}/sim_mean", sim.mean(), on_epoch=True)
                self.log(f"{mode}/sim_std", sim.std(), on_epoch=True)
                sim_xy = torch.mean(torch.sum(batch_norm * u, dim=-1))
                self.log(f"{mode}/sim_xy", sim_xy, on_epoch=True)

        # mask diagonals to large negative numbers so we don't calculate same state similarities
        mask = torch.eye(2 * N, device=sim.device).bool()
        sim = sim.masked_fill_(mask, -9e15)

        # positives: i-th sample matches i + N mod 2N
        labels = (torch.arange(2 * N, device=sim.device) + N) % (2 * N)

        loss = F.cross_entropy(sim, labels) # over mean reduction 

        # extra statistics 
        if mode=="train": 
            with torch.no_grad(): 
                self.log(f"{mode}/p_norm_mean", p_norm.mean(), on_epoch=True)
                self.log(f"{mode}/p_norm_std", p_norm.std(), on_epoch=True)
                self.log(f"{mode}/beta", beta.item(), on_epoch=True)

                self.log(f"{mode}/U_norm_mean", u.mean(), on_epoch=True)
                self.log(f"{mode}/U_norm_std", u.std(), on_epoch=True)
                self.log(f"{mode}/U_norm_max", u.max(), on_epoch=True)

        # metrics
        preds = sim.argmax(dim=1)
        top1 = (preds == labels).float().mean()   # top1: true positive is most similar to anchor 
        top5 = (sim.topk(5, dim=1).indices == labels.unsqueeze(1)).any(dim=1).float().mean() # top5: true positive is atleast in the top 5 most similar to anchor 

        self.log(f"{mode}/nll_loss", loss, on_epoch=True, prog_bar=True)
        self.log(f"{mode}/top1", top1, on_epoch=True, prog_bar=True)
        self.log(f"{mode}/top5", top5, on_epoch=True, prog_bar=True)

        return loss
    
    def training_step(self, batch):
        return self.loss(batch, mode='train')

    def validation_step(self, batch):
        self.loss(batch, mode='val')

    def get_beta(self, batch): 
        """Returns the beta value."""
        beta = self.beta_net(F.normalize(batch, dim=-1))
        beta = (torch.mean(beta) * self.hparams.beta_max).clamp(max=self.hparams.beta_max)
        return beta 

In [7]:
# Load cmhn model 
mhn = cmhn(max_iter=1, threshold=0, device=DEVICE)

In [8]:
data = sample_states(dataset=MINARI_DATASET, num_states=10)
states = data["states"]

states = torch.as_tensor(states, dtype=torch.float32)
with torch.no_grad(): 
    z = cl_model(states) 


In [9]:
'''
bm = LearnedBetaModel(cmhn=mhn, beta_max=200, device=DEVICE)
_, d = z.shape 
z_norm = F.normalize(z, p=2, dim=-1)
mhn.build_index(X=z_norm, d=d)
loss = bm.loss(z)
loss 
'''


'\nbm = LearnedBetaModel(cmhn=mhn, beta_max=200, device=DEVICE)\n_, d = z.shape \nz_norm = F.normalize(z, p=2, dim=-1)\nmhn.build_index(X=z_norm, d=d)\nloss = bm.loss(z)\nloss \n'

In [13]:
from src.data.StatesDataset import StatesDataset

bm = LearnedBetaModel(cmhn=mhn, beta_max=200, device=DEVICE)

train_ds = StatesDataset(cl_model=cl_model, minari_dataset=MINARI_DATASET, data=states)
train_loader = torch.utils.data.DataLoader(dataset=train_ds, batch_size=5, shuffle=True, drop_last=True)
torch.autograd.set_detect_anomaly(True)

bm.train()
batch = next(iter(train_loader)).to(DEVICE)

loss = bm.loss(batch)
bm.zero_grad(set_to_none=True)
loss.backward()

def grad_stat(p):
    if p.grad is None:
        return None
    return (p.grad.norm().item(), p.grad.abs().mean().item(), p.grad.abs().max().item())

print("loss:", loss.item())

# beta_net grads
for name, p in bm.beta_net.named_parameters():
    print("beta_net", name, grad_stat(p))

print(bm.get_beta(batch))

loss: 2.19704270362854
beta_net 0.weight (0.00012130623508710414, 8.010457577256602e-07, 1.9972449081251398e-05)
beta_net 0.bias (0.00012225241516716778, 6.079752893128898e-06, 4.129078661208041e-05)
beta_net 2.weight (0.0003270024317316711, 1.9494175376166822e-06, 4.903965600533411e-05)
beta_net 2.bias (0.00029898397042416036, 3.445614129304886e-05, 0.00011448246368672699)
beta_net 4.weight (0.00019059520855080336, 2.0398505512275733e-05, 0.00010913576988968998)
beta_net 4.bias (0.0006565195508301258, 0.0006565195508301258, 0.0006565195508301258)
tensor(140.4356, device='cuda:0', grad_fn=<ClampBackward1>)


c:\Users\ray\AppData\Local\anaconda3\envs\CL_RL_gpu\lib\site-packages\pytorch_lightning\core\module.py:449: You are trying to `self.log()` but the `self.trainer` reference is not registered on the model yet. This is most likely because the model hasn't been passed to the `Trainer`


In [17]:
import pytorch_lightning as pl
from pytorch_lightning.profilers import PyTorchProfiler

from src.data.StatesDataset import StatesDataset

# 1. Instantiate the profiler
# We'll use the built-in PyTorch profiler. 
# We'll save the results to the 'lightning_logs' directory for TensorBoard.
profiler = PyTorchProfiler(
    dirpath="lightning_logs",
    filename="cmhn_profile",
    # We only need a few steps to check the gradients
)

# 2. Instantiate the Trainer
# Set a limit on the number of batches to run so the profiler finishes quickly.
trainer = pl.Trainer(
    profiler=profiler,
    limit_train_batches=10,  # Run just enough batches to hit the profiling schedule
    max_epochs=1,
    # Add other necessary args like devices=1, accelerator='cuda' if using GPU
)


train_ds = StatesDataset(cl_model=cl_model, minari_dataset=MINARI_DATASET, data=states)
train_loader = torch.utils.data.DataLoader(dataset=train_ds, batch_size=5, shuffle=True, drop_last=True)


# 3. Run training
print("Starting training with profiler...")
trainer.fit(bm, train_loader)
print("Profiling complete.")

ðŸ’¡ Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
c:\Users\ray\AppData\Local\anaconda3\envs\CL_RL_gpu\lib\site-packages\pytorch_lightning\trainer\configuration_validator.py:70: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.
You are using a CUDA device ('NVIDIA GeForce RTX 3060') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name     | Type       | Params | Mode 
------------------

Starting training with profiler...
Epoch 0: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 2/2 [00:00<00:00, 11.31it/s, v_num=11, train/nll_loss_step=2.150, train/top1_step=0.500, train/top5_step=1.000]

c:\Users\ray\AppData\Local\anaconda3\envs\CL_RL_gpu\lib\site-packages\pytorch_lightning\profilers\pytorch.py:467: The PyTorch Profiler default schedule will be overridden as there is not enough steps to properly record traces.


Epoch 0: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 2/2 [00:00<00:00, 10.99it/s, v_num=11, train/nll_loss_step=2.170, train/top1_step=0.400, train/top5_step=0.700, train/nll_loss_epoch=2.160, train/top1_epoch=0.450, train/top5_epoch=0.850]

`Trainer.fit` stopped: `max_epochs=1` reached.


Epoch 0: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 2/2 [00:00<00:00,  9.99it/s, v_num=11, train/nll_loss_step=2.170, train/top1_step=0.400, train/top5_step=0.700, train/nll_loss_epoch=2.160, train/top1_epoch=0.450, train/top5_epoch=0.850]
Profiling complete.
