In [1]:
import os 
import sys 
# Put the repository root (the parent of the notebook's working directory) at the front of
# sys.path so the project-local `utils` and `models` packages can be imported; skip the
# insert if it is already there so re-running this cell does not grow sys.path.
project_root = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
if project_root not in sys.path:
    sys.path.insert(0, project_root)

from utils.tensor_utils import split_data
from models.cl_model import mlpCL

import minari
import torch



In [14]:

class EpisodesDataset(torch.utils.data.Dataset):
    """
    Dataset of per-episode z-representations produced by a contrastive-learning encoder.

    Each item is the encoder output for all observations of one episode, in the order the
    episode stores them (in this notebook, e.g. ``torch.Size([633, 32])``).
    """

    def __init__(self, cl_model = None, minari_dataset = "D4RL/pointmaze/large-v2", n_episodes=1, episodeData=None):
        """
        Dataset to store the z-representation of the states of their corresponding episodes.

        Args:
            cl_model: A pretrained contrastive learning model (a callable taking a float32
                tensor of observations) used to encode states to z-representations. Required.
            minari_dataset: Id of the Minari dataset to load when ``episodeData`` is not given.
            n_episodes: Number of episodes to sample from ``minari_dataset`` when
                ``episodeData`` is not given.
            episodeData: Pre-loaded episodes to encode instead of loading from Minari. Each
                episode must expose ``observations["observation"]``. When this is not None
                (even if it is empty), no Minari dataset is loaded.

        Raises:
            AssertionError: If ``cl_model`` is None.
        """
        assert cl_model is not None, "Must input a contrastive learning model to obtain z-representations!"

        self.cl_model = cl_model
        # Always defined; only populated when episodes are sampled from Minari below.
        self.minari_dataset = None

        # Compare against None rather than relying on truthiness: an explicitly empty split
        # must stay empty instead of silently loading and sampling a Minari dataset.
        if episodeData is not None:
            self.episodeData = episodeData
        else:
            self.minari_dataset = minari.load_dataset(minari_dataset)
            self.episodeData = self.minari_dataset.sample_episodes(n_episodes=n_episodes) # list of episodes [ep1, ep2, ep3, ...]

        # Precompute all z-representations once so __getitem__ is a plain list lookup.
        self.z_data = self._encode_episodes(cl_model, self.episodeData)

    @staticmethod
    def _encode_episodes(cl_model, episodes):
        """Return one encoder output per episode, computed without gradients and in eval mode."""
        # Record every submodule's train/eval flag so the caller's model is left exactly as it
        # was; eval mode makes dropout / batch-norm layers (if the encoder has any) deterministic.
        saved_modes = (
            [(module, module.training) for module in cl_model.modules()]
            if isinstance(cl_model, torch.nn.Module)
            else []
        )
        if saved_modes:
            cl_model.eval()
        try:
            with torch.no_grad():
                return [
                    cl_model(torch.as_tensor(ep.observations["observation"], dtype=torch.float32))
                    for ep in episodes
                ]
        finally:
            for module, was_training in saved_modes:
                module.training = was_training

    def __len__(self):
        """
        Returns the number of episodes in the dataset.
        """
        return len(self.z_data)

    def __getitem__(self, idx):
        """
        Returns the z-representation of the episode at position ``idx``: the encoder output
        for that episode's observations, as a single tensor.
        """
        return self.z_data[idx]

        



In [15]:
# Load the trained contrastive-learning (CL) encoder from <project_root>/saved_models.
model_name = "best_model.ckpt"
pretrained_model_file = os.path.join(project_root, "saved_models", model_name)

# Fail loudly here; otherwise `cl_model` stays undefined and later cells die with a
# confusing NameError far from the real cause.
if not os.path.isfile(pretrained_model_file):
    raise FileNotFoundError(f"No pretrained CL model found at {pretrained_model_file}")

print(f"Found pretrained model at {pretrained_model_file}, loading...")
cl_model = mlpCL.load_from_checkpoint(pretrained_model_file, map_location=torch.device("cpu"))
# The model is only used for inference in this notebook; disable dropout/batch-norm updates.
cl_model.eval()

Found pretrained model at /Users/ray/Documents/Research Assistancy UofA 2025/Reproduce Paper/contrastive-abstraction-RL/saved_models/best_model.ckpt, loading...


In [10]:
# Sanity check: confirm which device the encoder reports (it was loaded with map_location=cpu).
cl_model.device

device(type='cpu')

In [16]:
# Sample episodes from the offline Minari dataset, split them into train/val, and encode
# each split with the CL model.
DATASET_ID = "D4RL/pointmaze/large-v2"
N_EPISODES = 3360      # number of episodes sampled from the dataset
TRAIN_FRACTION = 0.7   # passed to split_data; presumably the train share — confirm in utils.tensor_utils
SEED = 0

# Seed both the Minari episode sampler and torch (in case split_data shuffles with torch's
# RNG) so Restart & Run All reproduces the same splits.
torch.manual_seed(SEED)
minari_dataset = minari.load_dataset(DATASET_ID)
minari_dataset.set_seed(seed=SEED)
episodeData = minari_dataset.sample_episodes(N_EPISODES)

train, val = split_data(episodeData, TRAIN_FRACTION)

train_ds = EpisodesDataset(cl_model=cl_model, episodeData=train)
val_ds = EpisodesDataset(cl_model=cl_model, episodeData=val)

train_ds[2].size()

torch.Size([633, 32])

In [17]:
# Peek at one encoded training episode; torch truncates long tensors in the repr.
# NOTE(review): assumes cl_model maps each observation row to one z row — confirm.
train_ds[1]

tensor([[  4.4051,  -0.3216,   2.7463,  ...,  -6.9307,   2.3275,  14.7764],
        [  5.4566,  -1.6109,   2.5574,  ...,  -6.3021,   2.4560,  14.9638],
        [  5.2345,  -2.7181,   1.7200,  ...,  -6.5735,   2.8840,  14.8254],
        ...,
        [ 12.8086, -27.9627, -23.3460,  ...,  24.1302, -11.9508,  10.4717],
        [ 13.1106, -28.3226, -23.2933,  ...,  23.9085, -12.0890,   9.9810],
        [ 13.6523, -28.4744, -22.7852,  ...,  23.9270, -12.5379,   9.2509]])