In [1]:
import os 
import sys 
# Put the parent of the current working directory on the import path so the
# project's own packages (models/, data/, utils/) resolve from this notebook.
parent_dir = os.path.join(os.getcwd(), os.pardir)
project_root = os.path.abspath(parent_dir)
if sys.path.count(project_root) == 0:
    sys.path.insert(0, project_root)

import minari

import torch 
import torch.utils.data as data
import torch.nn as nn 
import torch.nn.functional as F 
import torch.optim as optim

from models.cl_model import mlpCL 
from models.beta_model import LearnedBetaModel
from models.cmhn import cmhn 

from data.StatesDataset import StatesDataset

from utils.sampling_states import sample_states
from utils.tensor_utils import split_data


In [2]:
# Import minari dataset
MINARI_DATASET = minari.load_dataset("D4RL/pointmaze/large-v2")

# Device on which the checkpoints are mapped and the CMHN is created
DEVICE = "cpu"

# Load trained CL model
cl_model_name = "best_model_laplace_15.ckpt"
pretrained_cl_model_file = os.path.join(project_root, "best_models", cl_model_name)

# Fail here, naming the missing path, instead of leaving `cl_model` undefined
# and hitting a NameError in a later cell.
if not os.path.isfile(pretrained_cl_model_file):
    raise FileNotFoundError(f"CL model checkpoint not found: {pretrained_cl_model_file}")
print(f"Found pretrained model at {pretrained_cl_model_file}, loading...")
cl_model = mlpCL.load_from_checkpoint(pretrained_cl_model_file, map_location=torch.device(DEVICE))

# Load beta model
bm_model_name = "test_run_model.ckpt"
pretrained_bm_model_file = os.path.join(project_root, "saved_beta_models", bm_model_name)

if not os.path.isfile(pretrained_bm_model_file):
    raise FileNotFoundError(f"Beta model checkpoint not found: {pretrained_bm_model_file}")
print(f"Found pretrained model at {pretrained_bm_model_file}, loading...")
beta_model = LearnedBetaModel.load_from_checkpoint(pretrained_bm_model_file, map_location=torch.device(DEVICE))

# Load cmhn model
# The instance is deliberately still called `cmhn` because later cells call
# `cmhn.run(...)`. That name shadows the imported class, so the class is
# re-imported under an alias here; without this, running the cell a second time
# would call the instance instead of the constructor.
from models.cmhn import cmhn as CMHN
cmhn = CMHN(update_steps=1, device=DEVICE)


Found pretrained model at c:\Users\ray\Documents\2025 RA\contrastive-learning-RL/best_models\best_model_laplace_15.ckpt, loading...
Found pretrained model at c:\Users\ray\Documents\2025 RA\contrastive-learning-RL/saved_beta_models\test_run_model.ckpt, loading...


In [3]:
# Sample states from the Minari dataset with the project's sampling helper.
NUM_STATES = 50000
states = sample_states(dataset=MINARI_DATASET, num_states=NUM_STATES)

In [4]:
# Convert states to z representations
states = torch.as_tensor(states, dtype=torch.float32)

# Put the encoder in evaluation mode before inference. torch.no_grad() only
# disables gradient tracking; it does not switch layers such as dropout or
# batch norm out of training behaviour.
# NOTE(review): assumes mlpCL is an nn.Module (it is loaded through
# load_from_checkpoint) - confirm. If it has no such layers this is a no-op.
cl_model.eval()
with torch.no_grad():
    z = cl_model(states)

In [5]:
# Fixed beta values, held as a 1-D float32 tensor.
fixed_beta = torch.tensor([20.0, 25.0, 35.0], dtype=torch.float32)
fixed_beta

tensor([20., 25., 35.])

In [6]:
# Run the CMHN with the first fixed beta, passing the embeddings `z` as both
# the X and xi arguments in a single batched call.
beta = fixed_beta[0]
U = cmhn.run(
    X=z,
    xi=z,
    beta=beta,
    run_as_batch=True,
)


In [7]:
# Dimensions of the CMHN output.
U.shape

torch.Size([50000, 32])

In [8]:
def select_unique_directions(vectors, cos_threshold):
    """Greedily keep rows whose dot product with every kept row is below a threshold.

    Rows are visited in order. A row is kept only if its dot product with each
    previously kept row is strictly less than ``cos_threshold``; the first row
    is always kept. For unit-length rows the dot product is the cosine
    similarity, so this drops rows pointing in (almost) the same direction as
    an earlier one.

    Args:
        vectors: 2-D tensor of shape (n, d); expected to be row-normalised.
        cos_threshold: similarity at or above which a row counts as a duplicate.

    Returns:
        Tensor of shape (k, d) holding the kept rows in their original order.
        With no input rows the result has shape (0, d).
    """
    # Pre-allocate room for the worst case (every row kept) so the kept set is
    # one contiguous matrix rather than a Python list of row tensors.
    kept = torch.empty_like(vectors)
    count = 0
    for row in vectors:
        # One matrix-vector product compares `row` with all kept rows at once,
        # replacing a Python-level loop of individual torch.dot calls.
        if count == 0 or bool((kept[:count] @ row < cos_threshold).all()):
            kept[count] = row
            count += 1
    # Clone so the result does not keep the full-size scratch buffer alive.
    return kept[:count].clone()


# Cosine-similarity threshold: rows at least this similar to an already kept
# row are treated as duplicates.
tol = 0.999

normalized = F.normalize(U, dim=1)
a = select_unique_directions(normalized, tol)

# List-of-rows form of the same result, kept under its original name in case
# later cells use it.
unique = list(a.unbind(0))


In [9]:
# Display the stacked, unit-normalised rows that survived the similarity filter.
a

tensor([[ 0.3372, -0.0260,  0.1224,  ...,  0.0443, -0.2981,  0.0308],
        [ 0.3103,  0.0182,  0.0205,  ..., -0.0638, -0.2854,  0.1258],
        [ 0.3117,  0.0150,  0.0026,  ..., -0.0760, -0.2815,  0.1479],
        ...,
        [ 0.3475, -0.0117, -0.0234,  ..., -0.1253, -0.0689,  0.0613],
        [ 0.3089, -0.0647, -0.0893,  ..., -0.1778, -0.1076,  0.0764],
        [ 0.3057, -0.0824, -0.0915,  ..., -0.1871, -0.1153,  0.0561]])

In [11]:
# Shape of the retained set: (number of kept rows, width of each row).
a.shape

torch.Size([12864, 32])

In [12]:
# NOTE(review): this repeats the a.shape check from the previous cell, and the
# execution counts skip In [10]; likely a leftover that can be removed after a
# clean Restart & Run All.
a.shape

torch.Size([12864, 32])