In [1]:
import os 
import sys
# Set path to parent dir to import personal imports
project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
if project_root not in sys.path:
    sys.path.insert(0, project_root)

from data.TrajectorySet import TrajectorySet
from data.Sampler import Sampler 
from utils.tensor_utils import convert_batch_to_tensor

import torch 
import torch.utils.data as data 
from torch.utils.data import Dataset

import minari 

In [2]:
# Minari id of the D4RL PointMaze "large" dataset used throughout this notebook.
DATASET_ID = "D4RL/pointmaze/large-v2"
# download=True lets a fresh machine fetch the dataset instead of failing when it
# is not already present locally, so Restart & Run All works anywhere.
minari_dataset = minari.load_dataset(DATASET_ID, download=True)

In [3]:
import torch 
from utils.tensor_utils import convert_batch_to_tensor
class DatasetCL(torch.utils.data.Dataset):
    """Map-style dataset of (anchor, positive) state pairs sampled up front.

    All pairs are drawn by ``sample_pairs`` when the dataset is constructed and
    kept in ``self.pairs``. Call ``sample_pairs()`` again to replace them with a
    freshly drawn set, e.g. between epochs.
    """

    def __init__(self, sampler=None, num_state_pairs: int = None, k: int = 2):
        """
        sampler: The Sampler used to draw pairs; must provide
            ``sample_batch(batch_size=..., k=...)``.
        num_state_pairs: The number of (anchor, positive) pairs to request
            from the sampler.
        k: A hyperparameter, forwarded to ``sampler.sample_batch``, that
            dictates the average number of positive pairs sampled from the
            same trajectory. The lower the number, the lower the chance of
            false negatives.

        Raises AssertionError if ``sampler`` or ``num_state_pairs`` is None.
        """
        # Identity checks (`is not None`) cannot be hijacked by a custom or
        # element-wise __ne__, unlike `!= None`.
        assert sampler is not None, "DatasetCL requires a sampler."
        assert num_state_pairs is not None, "DatasetCL requires num_state_pairs (the number of pairs to sample)."

        self.sampler = sampler
        self.num_state_pairs = num_state_pairs
        self.k = k
        self.pairs = []
        self.sample_pairs()

    def __len__(self):
        """Return the number of currently stored pairs."""
        return len(self.pairs)

    def __getitem__(self, index):
        """Return stored pair ``index`` as a tuple of two float32 tensors (anchor, positive)."""
        s_i, s_j = self.pairs[index]
        return (torch.as_tensor(s_i, dtype=torch.float32), torch.as_tensor(s_j, dtype=torch.float32))

    def sample_pairs(self):
        """Request a fresh set of pairs from the sampler, replacing ``self.pairs``."""
        batch = self.sampler.sample_batch(batch_size=self.num_state_pairs, k=self.k)
        # convert_batch_to_tensor must yield exactly two sequences (anchors,
        # positives). Zipping them pairs the i-th anchor with the i-th positive.
        anchor, positive = convert_batch_to_tensor(batch)
        self.pairs = list(zip(anchor, positive))


In [8]:
"""
TESTING CELL! 
"""

T = TrajectorySet(dataset=minari_dataset) 

S = Sampler(T, dist="l") 
ds = DatasetCL(sampler=S, num_state_pairs=1000, k=1)


#batch = ds.get_batch()
#print(batch)
print(len(ds))
print(ds[0])

ds.sample_pairs()


#DS = data.DataLoader(dataset=ds, batch_size=2, shuffle=True)

#for batch in DS: 
#    print(batch)




resampled pairs!
1000
(tensor([ 0.2935,  0.8678, -1.8925, -1.3350]), tensor([ 0.1639,  0.8273, -3.0553, -0.5356]))
resampled pairs!


In [3]:
class DatasetCL2(torch.utils.data.Dataset):
    """Map-style dataset of pre-sampled anchor states with positives drawn on access.

    Only the anchor states are sampled up front, via ``sampler.sample_states``.
    Every ``__getitem__`` call asks the sampler for a positive, so repeated
    access to the same index may return a different positive.
    """

    def __init__(self, sampler=None, num_state_pairs: int = None, k: int = 2):
        """
        sampler: The Sampler used to draw states; must provide
            ``sample_states(batch_size=...)`` and
            ``sample_positive_pair(t_idx, state)``.
        num_state_pairs: The number of anchor states to pre-sample. One
            (anchor, positive) pair is produced per anchor.
        k: Intended to dictate the average number of positive pairs sampled
            from the same trajectory. It is stored on the instance but not
            used anywhere in this class. NOTE(review): DatasetCL forwards k to
            the sampler; confirm whether this class should do the same.

        Raises AssertionError if ``sampler`` or ``num_state_pairs`` is None.
        """
        # Identity checks (`is not None`) cannot be hijacked by a custom or
        # element-wise __ne__, unlike `!= None`.
        assert sampler is not None, "DatasetCL2 requires a sampler."
        assert num_state_pairs is not None, "DatasetCL2 requires num_state_pairs (the number of anchor states to sample)."

        self.sampler = sampler
        self.num_state_pairs = num_state_pairs
        self.k = k
        self.states = self.sampler.sample_states(batch_size=self.num_state_pairs)

    def __len__(self):
        """Return the number of pre-sampled anchor states."""
        return len(self.states)

    def __getitem__(self, idx):
        """Return ``(anchor, positive)`` for anchor ``idx`` as float32 tensors.

        The positive is requested from the sampler on every call; it is not cached.
        """
        # Each stored entry unpacks as (state, t_idx). t_idx is presumably the
        # index of the source trajectory; confirm against Sampler.sample_states.
        s_i, t_idx = self.states[idx]
        s_j, _ = self.sampler.sample_positive_pair(t_idx, s_i)
        return (torch.as_tensor(s_i, dtype=torch.float32), torch.as_tensor(s_j, dtype=torch.float32))

    def get_states(self):
        """Return the pre-sampled anchor entries exactly as returned by ``sample_states``."""
        return self.states

In [5]:
# Training data: pre-sample the anchor states once. DatasetCL2 draws each
# anchor's positive lazily whenever that anchor is indexed.
NUM_ANCHOR_STATES = 1_000_000

T = TrajectorySet(dataset=minari_dataset)
S = Sampler(T, dist="l")
ds = DatasetCL2(
    sampler=S,
    num_state_pairs=NUM_ANCHOR_STATES,
    k=1,
)
states = ds.get_states()

In [6]:
# Sanity check: the sampler returned one entry per requested anchor state.
assert len(states) == ds.num_state_pairs
len(states)

1000000