In [1]:
import os 
import sys

In [2]:
# Make the parent of the notebook's working directory importable, so the
# project-local packages (`data`, `utils`) resolve. Guarded so re-running the
# cell does not keep prepending duplicate entries to sys.path.
project_root = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
if project_root not in sys.path:
    sys.path.insert(0, project_root)

from data.TrajectorySet import TrajectorySet
from data.Sampler import Sampler 
from utils.tensor_utils import convert_batch_to_tensor

import torch 
import torch.utils.data as data 
from torch.utils.data import Dataset

import minari 

In [3]:
# Minari dataset id for the D4RL PointMaze "large" dataset used throughout.
DATASET_ID = "D4RL/pointmaze/large-v2"
minari_dataset = minari.load_dataset(DATASET_ID)

In [4]:
class DatasetCL(Dataset):
    """Map-style dataset over one contrastive batch of state pairs.

    The batch is drawn once, at construction time, from ``sampler`` and
    converted with ``convert_batch_to_tensor``. Item ``index`` is the
    ``index``-th pair ``(s_i, s_j)``, returned as float32 tensors.
    """

    def __init__(self, sampler, batch_size: int, k: int = 2):
        """
        sampler: The Sampler used to draw the batch; must expose
                ``sample_batch(batch_size=..., k=...)``.
        batch_size: The size of the batch (i.e. the number of state pairs).
        k: A hyperparameter that dictates the average number of
                positive pairs sampled from the same trajectory. The
                lower the number, the lower the chance of false negatives.
        """
        self.sampler = sampler
        self.batch_size = batch_size
        self.k = k

        # The testing cell's output shows convert_batch_to_tensor returning a
        # 2-tuple (s_i, s_j) of tensors that share their first (pair) dimension.
        # NOTE(review): confirm this contract in utils.tensor_utils.
        self.batch = convert_batch_to_tensor(
            self.sampler.sample_batch(batch_size=self.batch_size, k=self.k)
        )

    @staticmethod
    def _to_float32(x):
        # Same result as torch.tensor(x, dtype=torch.float32) -- a detached,
        # independent copy -- without the copy-construct warning torch.tensor
        # emits when x is already a tensor.
        return torch.as_tensor(x, dtype=torch.float32).detach().clone()

    def __len__(self):
        """Return the number of state pairs (rows of either side of the batch)."""
        # len(self.batch) would be 2 (the tuple), not the number of pairs.
        s_i, _ = self.batch
        return len(s_i)

    def __getitem__(self, index):
        """Return the ``index``-th pair ``(s_i, s_j)`` as float32 tensors."""
        # Index each side of the batch; self.batch[index] would select a whole
        # side rather than a pair.
        s_i, s_j = self.batch
        return self._to_float32(s_i[index]), self._to_float32(s_j[index])

    def get_batch(self):
        """Return the full batch exactly as produced by ``convert_batch_to_tensor``."""
        return self.batch

In [7]:
# TESTING CELL: smoke-test DatasetCL by drawing one small batch and displaying it.
# NOTE(review): no RNG seed is set here, so if Sampler is stochastic the
# displayed pairs will differ between runs.
T = TrajectorySet(dataset=minari_dataset)
S = Sampler(T, dist="u")  # "u" presumably selects uniform sampling -- TODO confirm in data.Sampler
ds = DatasetCL(sampler=S, batch_size=4, k=4)

# Bare last expression so Jupyter renders the batch.
batch = ds.get_batch()
batch


(tensor([[-3.6186, -1.9642, -0.1215, -2.8724],
         [-4.5962, -0.8383, -0.3059, -3.3184],
         [-4.3503, -1.0622,  3.2947,  0.1763],
         [-4.4889,  2.9897, -0.0629,  1.1910]]),
 tensor([[-3.7673, -1.0248,  2.9082, -0.0424],
         [-4.4843,  2.9467, -0.1456,  1.9175],
         [-3.6748, -1.0392,  1.9655, -0.7151],
         [-3.6944, -1.0320,  2.2089, -0.4781]]))