In [1]:
import os 
import sys 
# Make the repository root (the parent of the notebook's working directory)
# importable so the project-local packages imported below resolve.
working_dir = os.getcwd()
project_root = os.path.abspath(os.path.join(working_dir, os.pardir))
if project_root not in sys.path:
    sys.path.insert(0, project_root)

import minari

import torch 
import torch.utils.data as data
import torch.nn as nn 
import torch.nn.functional as F 
import torch.optim as optim

from models.cl_model import mlpCL 
from models.beta_model import LearnedBetaModel
from models.cmhn import cmhn 

from data.StatesDataset import StatesDataset

from utils.sampling_states import sample_states
from utils.tensor_utils import split_data


In [2]:
# Load the Minari dataset that states are sampled from.
MINARI_DATASET = minari.load_dataset("D4RL/pointmaze/large-v2")

# Device that the checkpoints are mapped to and that the cmhn model is built with.
DEVICE = "cpu"


def _require_checkpoint(directory, filename):
    """Return the path of a checkpoint stored under ``project_root``.

    Args:
        directory: Sub-directory of ``project_root`` holding the checkpoint.
        filename: Checkpoint file name.

    Returns:
        The joined checkpoint path.

    Raises:
        FileNotFoundError: If the checkpoint does not exist. Failing here gives
            a clear message instead of a ``NameError`` on the model variable in
            a later cell.
    """
    checkpoint_file = os.path.join(project_root, directory, filename)
    if not os.path.isfile(checkpoint_file):
        raise FileNotFoundError(f"Pretrained checkpoint not found: {checkpoint_file}")
    print(f"Found pretrained model at {checkpoint_file}, loading...")
    return checkpoint_file


# Load trained CL model
cl_model_name = "best_model_laplace_15.ckpt"
pretrained_cl_model_file = _require_checkpoint("best_models", cl_model_name)
cl_model = mlpCL.load_from_checkpoint(pretrained_cl_model_file, map_location=torch.device(DEVICE))

# Load beta model
bm_model_name = "test_run_model.ckpt"
pretrained_bm_model_file = _require_checkpoint("saved_beta_models", bm_model_name)
beta_model = LearnedBetaModel.load_from_checkpoint(pretrained_bm_model_file, map_location=torch.device(DEVICE))

# Load cmhn model.
# The instance is deliberately bound to the name `cmhn` because later cells call
# `cmhn.run(...)`. That binding shadows the imported `cmhn`, so construct through
# a fresh alias; otherwise re-running this cell would call the previous instance
# instead of the constructor.
from models.cmhn import cmhn as CMHN

cmhn = CMHN(update_steps=1, device=DEVICE)


Found pretrained model at c:\Users\ray\Documents\2025 RA\contrastive-learning-RL/best_models\best_model_laplace_15.ckpt, loading...
Found pretrained model at c:\Users\ray\Documents\2025 RA\contrastive-learning-RL/saved_beta_models\test_run_model.ckpt, loading...


In [10]:
# Draw states from the Minari dataset with the project helper, requesting
# 50000 of them.
# NOTE(review): if sample_states draws randomly, seed its RNG so that
# Restart & Run All reproduces the same states — confirm against the helper.
states = sample_states(
    dataset=MINARI_DATASET,
    num_states=50000,
)

In [11]:
# Convert states to z representations
states = torch.as_tensor(states, dtype=torch.float32)

# Inference only: an nn.Module is in training mode unless eval() is called, and
# training mode keeps layers such as dropout / batch-norm in their train-time
# behaviour. Assumes mlpCL is an nn.Module (it is loaded via
# load_from_checkpoint) — TODO confirm.
cl_model.eval()

# No gradients are needed for a pure forward pass.
with torch.no_grad():
    z = cl_model(states)

In [12]:
# Fixed beta values (float32) available to the cmhn run below.
fixed_beta = torch.tensor([20.0, 25.0, 35.0], dtype=torch.float32)
fixed_beta

tensor([20., 25., 35.])

In [13]:
# Compute U by running the cmhn model with the first fixed beta; the encoded
# states z are passed as both X and xi.
beta = fixed_beta[0]
U = cmhn.run(
    X=z,
    xi=z,
    beta=beta,
    run_as_batch=True,
)


In [15]:
# Shape of the cmhn output.
U.shape

torch.Size([50000, 32])

In [None]:
# Cosine-similarity threshold: a row is kept only if its cosine similarity to
# every previously kept row is strictly below this value.
tol = 0.999

# Unit-length rows, so a dot product between two rows is their cosine similarity.
normalized = F.normalize(U, dim=1)


def _greedy_cosine_dedup(unit_rows, max_cosine):
    """Greedily keep rows that are not near-duplicates (in direction) of kept rows.

    Rows are visited in order. A row is kept when its dot product with every
    already kept row is strictly below ``max_cosine``. The first row is always
    kept.

    Args:
        unit_rows: 2-D tensor whose rows are already L2-normalized.
        max_cosine: Similarity threshold at or above which a row is dropped.

    Returns:
        Tensor of the kept rows in their original order; it has zero rows when
        ``unit_rows`` has zero rows.
    """
    # Preallocate room for the worst case (every row kept) so the kept set can
    # be compared against with a single matrix-vector product per row.
    kept = unit_rows.new_empty(unit_rows.shape)
    num_kept = 0
    for row in unit_rows:
        # Similarity of `row` to every kept row at once. With nothing kept yet
        # this is empty and `.all()` is True, so the first row is accepted.
        similarities = kept[:num_kept] @ row
        if bool((similarities < max_cosine).all()):
            kept[num_kept] = row
            num_kept += 1
    # Copy so the result does not keep the full-size buffer alive.
    return kept[:num_kept].clone()


a = _greedy_cosine_dedup(normalized, tol)

# Kept rows as a list of 1-D tensors, for code that still uses `unique`.
unique = list(a)


In [17]:
# Preview the deduplicated, normalized rows kept by the previous cell.
a

tensor([[ 0.1131, -0.1111,  0.0443,  ..., -0.1197, -0.3092,  0.1489],
        [ 0.1109, -0.0642,  0.0534,  ..., -0.1024, -0.3535,  0.1274],
        [ 0.1383, -0.0927,  0.0471,  ..., -0.1190, -0.3667,  0.0816],
        ...,
        [-0.0792,  0.1686,  0.1318,  ...,  0.1148, -0.2005,  0.2910],
        [-0.0725,  0.2215,  0.1738,  ...,  0.0928, -0.1811,  0.3133],
        [-0.0974,  0.1816,  0.1640,  ...,  0.0348, -0.1768,  0.3348]])

In [18]:
# Shape of the deduplicated tensor: its first dimension is the number of rows kept.
a.shape

torch.Size([7147, 32])