In [1]:
import os 
import sys

In [2]:
# Make the project root (the parent of the notebook's working directory)
# importable, so the local `data` and `utils` packages below resolve.
project_root = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
if project_root not in sys.path:
    # Prepend so the local packages take precedence over same-named installed ones.
    sys.path.insert(0, project_root)

from data.TrajectorySet import TrajectorySet
from data.Sampler import Sampler 
from utils.tensor_utils import convert_batch_to_tensor

import torch 
import torch.utils.data as data 
from torch.utils.data import Dataset

import minari 

In [3]:
# Load the D4RL PointMaze (large) offline dataset. With `download=True` Minari
# fetches it from the remote only when it is not found locally, so a fresh
# "Restart & Run All" works without a separate `minari download` step.
minari_dataset = minari.load_dataset("D4RL/pointmaze/large-v2", download=True)

In [7]:
class DatasetCL(Dataset):
    """Map-style dataset over one fixed batch of (s_i, s_j) state pairs.

    The batch is drawn from ``sampler`` exactly once, in ``__init__``; every
    ``__getitem__`` call reads from that same batch.

    NOTE(review): ``self.batch`` is assumed to be a pair
    ``(s_i_batch, s_j_batch)`` holding ``batch_size`` states each. This is
    inferred from the testing cell's output, where ``torch.cat(batch, dim=0)``
    yields ``2 * batch_size`` rows and rows ``i`` / ``i + batch_size`` look like
    nearby states. Under that layout ``len(self.batch)`` is 2 rather than the
    number of pairs, and ``self.batch[index]`` is a whole stack, not a pair.
    Confirm against ``convert_batch_to_tensor``.
    """

    def __init__(self, sampler, batch_size: int, k: int = 2):
        """
        sampler: The Sampler instance used to draw the batch.
        batch_size: The size of the batch (i.e. the number of state pairs).
        k: A hyperparameter that dictates the average number of
                positive pairs sampled from the same trajectory. The
                lower the number, the lower the chance of false negatives.
        """
        self.sampler = sampler
        self.batch_size = batch_size
        self.k = k

        self.batch = convert_batch_to_tensor(self.sampler.sample_batch(batch_size=self.batch_size, k=self.k))

    def _pair_stacks(self):
        """Return ``(s_i_batch, s_j_batch)`` from ``self.batch``.

        Raises:
            ValueError: if ``self.batch`` is not two stacks of equal length.
        """
        try:
            s_i_batch, s_j_batch = self.batch
        except (TypeError, ValueError) as err:
            raise ValueError(
                "DatasetCL expects the converted batch to be an (s_i_batch, s_j_batch) pair"
            ) from err
        if len(s_i_batch) != len(s_j_batch):
            raise ValueError(
                f"s_i and s_j stacks differ in length: {len(s_i_batch)} != {len(s_j_batch)}"
            )
        return s_i_batch, s_j_batch

    def __len__(self):
        """Number of (s_i, s_j) pairs, i.e. rows in each stack."""
        s_i_batch, _ = self._pair_stacks()
        return len(s_i_batch)

    def __getitem__(self, index):
        """Return the ``index``-th pair as two float32 tensors."""
        s_i_batch, s_j_batch = self._pair_stacks()
        # as_tensor avoids torch.tensor's copy-construct warning when the rows
        # are already tensors, and only converts when the dtype differs.
        return (
            torch.as_tensor(s_i_batch[index], dtype=torch.float32),
            torch.as_tensor(s_j_batch[index], dtype=torch.float32),
        )

    def get_batch(self):
        """Return the raw converted batch exactly as stored at construction."""
        return self.batch

In [13]:
# Smoke test: build TrajectorySet -> Sampler -> DatasetCL end to end and
# inspect one sampled batch. No seed is set here, so the displayed values
# presumably change between runs if the sampler is random.

T = TrajectorySet(dataset=minari_dataset)

S = Sampler(T, dist="g")
ds = DatasetCL(sampler=S, batch_size=4, k=2)

# Concatenate all sampled states into one tensor for display.
batch = torch.cat(ds.get_batch(), dim=0)

# Each of the `batch_size` pairs contributes two states (s_i and s_j).
assert batch.shape[0] == 2 * ds.batch_size, batch.shape

batch


tensor([[ 2.5125, -1.0768, -3.2496, -0.1181],
        [-4.6480, -0.7711,  0.0483, -4.6350],
        [ 0.4696,  0.6035,  0.2569,  4.0585],
        [ 0.3488, -0.0111,  0.9485,  3.9890],
        [ 2.6997, -1.0737, -4.4064, -0.0589],
        [-4.6326, -0.4750, -0.8742, -5.1609],
        [ 0.3191, -0.1569,  0.2807,  3.0689],
        [ 0.3270, -0.5274, -0.2245,  4.8216]])