In [23]:
import os 
import sys 
import platform
use_mac_workaround = platform.system() == "Darwin"  # True on macOS

if use_mac_workaround:
    os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
    os.environ["OMP_NUM_THREADS"] = "1"
    torch.set_num_threads(1)
    faiss.omp_set_num_threads(1)

project_root = os.path.abspath(os.path.join(os.getcwd(), "..")) 
if project_root not in sys.path: 
    sys.path.insert(0, project_root)

import torch 
import faiss 
import numpy as np
import torch.nn.functional as F

import psutil

from models.cl_model import mlpCL 
import minari 
from utils.sampling_states import sample_states


In [24]:
MINARI_DATASET = minari.load_dataset("D4RL/pointmaze/large-v2")
DEVICE = "cpu"
PROJECT_ROOT = project_root
TOTAL_STATES = 1_000_000

# Load trained CL model 
model_name = "best_model_laplace_15.ckpt"
pretrained_model_file = os.path.join(project_root+ "/saved_models", model_name) 

if os.path.isfile(pretrained_model_file): 
    print(f"Found pretrained model at {pretrained_model_file}, loading...") 
    cl_model = mlpCL.load_from_checkpoint(pretrained_model_file, map_location=torch.device(DEVICE))
else:
    print("Did not find a model.")

Found pretrained model at /Users/ray/Documents/Research Assistancy UofA 2025/Reproduce Paper/contrastive-abstraction-RL/saved_models/best_model_laplace_15.ckpt, loading...


In [25]:
class cmhn(): 
    def __init__(self, update_steps = 1, topk = 256, use_gpu = False, device="cpu"):
        """
        Continuous Modern Hopfield Network 

        Args: 
            update_steps: The number of iterations the cmhn will do. (Usually just one).
            topk: Using faiss, only the top k most similar patterns will be used. (more efficient in batch-wise updates) 
            use_gpu: Tells faiss if we use faiss-cpu or faiss-gpu for behind the scenes calculations. 
            device: The device that torch will use. 
        """
        self.update_steps = update_steps 
        self.topk = topk
        self.use_gpu = use_gpu
        self.device = torch.device(device)
        self.index = None 

    def __build_index(self, X, d): 
        """
        Builds a faiss index (an object) for efficient searching of top-k patterns from X. 
        """
        X_np = X.detach().cpu().numpy().astype("float32") # convert X from tensor to numpy 

        if self.use_gpu: 
            flat_index = faiss.IndexFlatL2(d) 
            self.index = faiss.index_cpu_to_all_gpus(flat_index)
        else: 
            self.index = faiss.IndexFlatL2(d)
        
        self.index.add(X_np)
    
    def __update(self, X, xi, beta): 
        """
        The update rule for a continuous modern hopfield network. 

        Args: 
            X: The stored patterns. X is of size [N, d], where N is the number of patterns, and d the size of the patterns. 
            xi: The state pattern (ie. the current pattern being updated). xi is of size [d, 1]. 
            beta: The scalar inverse-temperature hyperparamater. Controls the number of metastable states that occur in the energy landscape. 
                - High beta corresponds to low temp, more separation between patterns.  
                - Low beta corresponds to high temp, less separation (more metastable states). 
        """
        X_norm = F.normalize(X, p=2, dim=1)
        xi_norm = F.normalize(xi, p=2, dim=0)
        sims = X_norm @ xi_norm  # simularity between stored patterns and current pattern 
        p = F.softmax(beta * sims, dim=0, dtype=torch.float32)  # softmax dist along patterns (higher probability => more likely to be that stored pattern)
        # p of size [N, 1] 

        X_T = X_norm.transpose(0, 1) 
        xi_new = X_T @ p  # xi_new, the updated state pattern; size [d, 1]
        return xi_new

    def __run_batch(self, X, queries, beta=None): 
        """
        Runs the mhn batch-wise for efficient computation. 

        Args: 
            X: Stored patterns, size [N, d].
            queries: Input queries, size [N, d].
            beta: The beta value per sample, size [N].
        """        
        
        assert beta != None, "Must have a value for beta." 
        assert X.shape == queries.shape, "X and queries must be the same shape! (N, d)."
        N, d = X.shape 
        self.__build_index(X, d) 

        queries_np = queries.detach().cpu().numpy().astype("float32")
        distances, indices = self.index.search(queries_np, self.topk)
        
        queries = torch.from_numpy(queries_np).to(X.device)
        indices = torch.from_numpy(indices).to(X.device) # indices of shape [N, topk]

        topk_X = X[indices] # size [N, topk, d] 
        topk_q = queries.unsqueeze(1) # change queries from [N, d] to [N, 1, d] for broadcasting
        
        # dot product of x_ij * q_i along "d dim" to obtain tensor of [N, topk]
        # q_i represents the i'th query
        # x_ij represents the corresponding i'th query and j'th pattern, where j is among the topk 
        # then sum over d to obtain the similarity between row i and col j. 
        sims = torch.sum(topk_X * topk_q, dim=-1) 

        beta = beta.view(-1, 1)  # beta: [N, 1], broadcasting beta. 
        sims = beta * sims       # sims * beta: [N, topk]
        probs = F.softmax(sims, dim=-1) # calculate probs along patterns (NOT queries) ie. along topk --> [N, topk]
        
        # weighted sum over topk_X: x_ij * probs_i
        xi_new = torch.sum(probs.unsqueeze(-1) * topk_X, dim=1) 

        return xi_new

    def run(self, X, xi, beta=None, run_as_batch=False): 
        """
        Runs the network. 

        Args: 
            X: The stored patterns. X is of size [N, d], where B is the batches, N is the number of patterns, and d the size of the patterns. 
            xi: The state pattern (ie. the current pattern being updated). xi is of size [d, 1]. xi can also be a batch of queries [N, d].
            beta: The scalar inverse-temperature hyperparamater. Controls the number of metastable states that occur in the energy landscape. 
                - High beta corresponds to low temp, more separation between patterns.  
                - Low beta corresponds to high temp, less separation (more metastable states). 
        """
        assert beta != None, "Must have a value for beta."

        #if not isinstance(beta, torch.Tensor):
        #   beta = torch.as_tensor(beta, dtype=torch.float32)

        X = X.to(self.device)
        xi = xi.to(self.device)
        beta = beta.to(self.device)

        if run_as_batch: 
            if xi.dim() == 1: 
                raise ValueError("Query shape should be [N, d] when updating as a batch.")
            
            for _ in range(self.update_steps): 
                xi = self.__run_batch(X, xi, beta)
            return xi
        
        else:
            # if xi is of size [d], then change to [d, 1] 
            if xi.dim() == 1: 
                xi = xi.unsqueeze(1) #[d, 1]
            elif xi.dim() == 2 and xi.size(1) != 1: 
                raise ValueError("Query shape should be [d] or [d, 1].") 

            for _ in range(self.update_steps): 
                xi = self.__update(X, xi, beta)
            return xi 

In [26]:
d = sample_states(MINARI_DATASET, TOTAL_STATES,)

In [27]:
states = d["states"]

# Subsample from the states array so there isn't so much clutter on visualization
#idx = np.random.choice(np.arange(TOTAL_STATES), size=100_000, replace=False)
#new_states = states[idx] 

subsampled_states = []
idx = np.random.choice(np.arange(TOTAL_STATES), size=50_000, replace=False)
subsampled_states = states[idx]  # [N, 4]

In [28]:
with torch.no_grad(): 
    z = cl_model(torch.as_tensor(subsampled_states, dtype=torch.float32))

# z of size [N, 32]

print(subsampled_states.shape)
print(z.size())

(50000, 4)
torch.Size([50000, 32])


In [29]:
chn = cmhn(update_steps=1, topk=256, use_gpu = False, device = DEVICE) 

In [30]:
beta = torch.as_tensor(35.0, dtype=torch.float32)
print(f"Memory usage: {psutil.Process(os.getpid()).memory_info().rss / 1e6} MB")


Memory usage: 1460.96128 MB


In [31]:
with torch.autograd.set_detect_anomaly(True): 
    U = chn.run(z, z, beta, run_as_batch=True)

In [33]:
U.size()

torch.Size([50000, 32])

In [42]:
def remove_dupes(x, k=1000, threshold=0.99): 
    x = x / np.linalg.norm(x, axis=1, keepdims=True)  # l2 normalize for cosine sim 

    index = faiss.IndexFlatIP(x.shape[1])
    index.add(x)

    D, I = index.search(x, k+1)

    N = x.shape[0]
    mask = np.ones(N, dtype=bool)
    visited = np.zeros(N, dtype=bool)

    # Finds the most similar vectors and masks them out.
    for i in range(N):
        if visited[i]:
            continue
        neighbors = I[i, 1:]  # skip self-match
        similarities = D[i, 1:]
        for j, sim in zip(neighbors, similarities):
            if sim > threshold:
                mask[j] = False
                visited[j] = True

    return x[mask]

In [47]:
u = remove_dupes(U, k = 5000, threshold= 0.80)

  x = x / np.linalg.norm(x, axis=1, keepdims=True)  # l2 normalize for cosine sim


In [48]:
u.size()

torch.Size([183, 32])

In [37]:
print(u)

tensor([[ -0.1593,   1.8885,   3.8210,   0.2488,   3.0884,   2.6527,  -0.8658,
          -4.8314,  -0.4564,  -0.2040,  -4.2156,  -7.8757,  -4.5268,   3.5428,
           1.6113,  -4.3523,   8.5002,  -6.9940,   5.2944,  -3.7729,   1.7530,
          -2.0040,  -5.4662,   1.4627,   2.1636,  -2.8145,  -3.0785,   1.1415,
          -4.8897, -11.8525,  -4.7602,   3.5978],
        [ -9.0760,   6.4994,  -7.3232,   1.5715,  -7.5245,  -4.3644,  -7.8839,
          -4.1005,   2.9459,   6.0185,  -3.5223,   3.5647,   6.4049,   3.3934,
           3.1928,  -3.8362,  -1.5759,   4.1450,   6.3996,   6.7548,   1.0043,
           2.1827,  -0.5473,  -3.9415,   4.2520,  -4.2653,   2.8234,  -1.9857,
           0.9425,   0.3422,   1.6241,   4.9217],
        [ -4.6427,  -3.3330,  -4.9547,  -5.9786,   0.0777,  -7.2652,   1.2147,
          -1.9109,  -0.5800,  -6.1060,  -0.4300,  -2.4197,  -0.9425,  -5.5252,
          -2.6725,  -1.9154,  -2.1081,   2.5641,  -7.0463,  -1.1112, -11.1808,
          -2.0582,   3.9627,   