# Imports


In [3]:
import os 
import sys

In [4]:
# Set path to parent dir to import personal imports
project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
if project_root not in sys.path:
    sys.path.insert(0, project_root)

from data.TrajectorySet import TrajectorySet
from utils.truncated_distributions import truncated_normal
from utils.truncated_distributions import truncated_laplace
from utils.truncated_distributions import truncated_exponential

import torch 
import minari 
import numpy as np 


In [5]:
dataset = minari.load_dataset("D4RL/pointmaze/large-v2")

# Sampler

### Need to sample and get positive and negative pairs (states). 

In [6]:
class Sampler(): 
    def __init__(self, T: TrajectorySet, dist="g", sigma = 15, b = 15, rate = 0.99): 
        """
        T: The Trajectory Set class 
        dist: The distribution used for centering over the anchor state. 
            ['u', 'g', 'l', 'e'] - uniform, gaussian, laplace, exponential
        """

        self.T = T 
        self.dist = dist
        self.total_episodes = T.get_total_episodes() 
        self.T.generate_trajectories(n_trajectories= self.total_episodes)
        
        # Hyperparameters
        self.sigma = sigma
        self.b = b 
        self.rate = rate  


    def sample_anchor_state(self, t_idx: int) -> tuple[list, int]: 
        """
        Given a trajectory, we sample the anchor state s_i uniformly. 

        Args: 
            t: The given trajectory we sample from. 

        Returns: 
            A tuple containing [s_i, idx]
            s_i: The state that is sampled, represented as a list of (x,y) coordinates and velocities. 
            idx: The time step of s_i. 
        """
        trajectory = self.T.get_trajectory(index=t_idx)[0]
        idx = torch.randint(low=0, high=len(trajectory), size=(1,)).item()
        s_i = trajectory[idx] 
        return [s_i, idx]


    def sample_positive_pair(self, t_idx: int, anchor_state: tuple[list, int]) -> tuple[list, int]: 
        """
        Given the same trajectory that s_i was sampled from, 
        center a gaussian distribution around s_i to get obtain its positve pair: s_j. 
        
        Args: 
            t: The given trajectory, which must be the same as the trajectory that was used to sample the anchor state. 
            anchor_state: The anchor state; a tuple containing [s_i, idx].
            s_i: The state itself.
            idx: The time step of s_i.
            
        Return: 
            Returns the positive pair's state and state index.
        """

        _, si_idx = anchor_state
        trajectory = self.T.get_trajectory(index=t_idx)[0]

        if self.dist == "u": 
            # uniform 
            sj_idx = torch.randint(low=0, high=len(trajectory), size=(1,))

        elif self.dist == "g": 
            # gaussian 
            p = truncated_normal(len(trajectory), mu=si_idx, sigma=self.sigma) 
            sj_idx = np.random.choice(a=len(trajectory), p=p)
            
        elif self.dist == "l": 
            # laplacian
            p = truncated_laplace(len=len(trajectory), mu=si_idx, b=self.b)
            sj_idx = np.random.choice(a=len(trajectory), p=p)

        elif self.dist == "e": 
            # exponential 
            p = truncated_exponential(len=len(trajectory), anchor_state_index=si_idx, rate=self.rate)
            sj_idx = np.random.choice(a=len(trajectory), p=p) 

        else: 
            # default to gaussian
            p = truncated_normal(len(trajectory), mu=si_idx, sigma=self.sigma) 
            sj_idx = np.random.choice(a=len(trajectory), p=p)
        
        s_j = trajectory[sj_idx]
        return [s_j, sj_idx]
    

    def sample_batch(self, batch_size=1024,) -> list[tuple]: 
        """ 
        Creates a batch of anchor states, their positive pairs, and negative pairs. 
        There will be 2(batch_size - 1) amount of negative examples per positive pair.

        Args: 
            batch_size: The size of the batch to be generated.
        
        Returns: 
            A list of tuples containing the anchor_state and its positive pair. 
            The list is the same length as batch_size. 
        """ 

        batch = [] 
    
        for _ in range(batch_size): 
            # Sample anchor state 
            t_idx = torch.randint(low=0, high=self.total_episodes, size=(1,)).item() 
            
            anchor_state = self.sample_anchor_state(t_idx) 

            # Sample positive pair 
            positive_pair = self.sample_positive_pair(t_idx, anchor_state=anchor_state)

            # Retrieve states; time-steps aren't necessary. 
            s_i = anchor_state[0]
            s_j = positive_pair[0]

            batch.append([s_i, s_j]) 

        return batch 

    def sample_states(self, batch_size=1024) -> list[tuple]: 
        """ 
        Creates a batch of anchor states, and its corresponding trajectory to use to sample positive pairs.. 
        There will be 2(batch_size - 1) amount of negative examples per positive pair.

        Args: 
            batch_size: The size of the batch to be generated.
            k: A hyperparameter that dictates the average number of 
                positive pairs sampled from the same trajectory. The 
                lower the number, the lesser the chance of false negatives. 
        
        Returns: 
            A list of tuples containing the anchor_state and its corresponding trajectory. 
            The list is the same length as batch_size. 
        """ 

        batch = [] 
            

        for _ in range(batch_size): 
            # Sample anchor state 
            t_idx = torch.randint(low=0, high=self.total_episodes, size=(1,)).item() 
        
            state = self.sample_anchor_state(t_idx) 

            # Retrieve states; time-steps aren't necessary. 
            s_i = state[0]

            batch.append((s_i, t_idx)) 

        return batch

In [11]:
"""
TESTING CELL!
"""
T = TrajectorySet(dataset=dataset)

uniform = "u"
gaussian = "g"
laplacian = "l"
exponential = "e"

s = Sampler(T, laplacian)

batch = s.sample_states(batch_size=1_000_000)
batch

[(array([-4.04910266,  2.90661252,  4.26074936, -0.64447898]), 2324),
 (array([-4.65028084, -0.99642687,  0.40253345, -2.51646003]), 46),
 (array([-3.66225091,  0.98120751, -3.1390291 ,  0.09336701]), 2936),
 (array([-1.51277266,  1.44237359,  1.81603151,  4.75450653]), 1621),
 (array([ 2.76442509, -1.15951479, -4.2820801 , -0.20141582]), 1603),
 (array([-1.62548599,  2.43786906, -0.44980141,  4.94031416]), 2540),
 (array([-3.38059979, -1.92727999, -1.29580134, -1.68853201]), 1898),
 (array([-4.57098969,  0.12958918, -0.01873894, -3.72285325]), 1062),
 (array([ 0.95412365, -1.19430713,  4.6579405 ,  0.83514932]), 2929),
 (array([ 2.5054536 , -1.0468366 , -0.16159242,  2.66673743]), 1080),
 (array([-1.23051673,  0.98821022, -4.24182684,  0.18668688]), 2739),
 (array([ 1.55342202, -1.09336055, -4.09038044,  0.05657038]), 3105),
 (array([ 0.52869882, -2.8363104 , -0.37483146,  3.80365627]), 1402),
 (array([-1.65775694,  3.08084548, -2.33967941,  0.54506146]), 1143),
 (array([-4.52902414, 

In [15]:
"""
TESTING CELL! 

Testing to see if sample_batch works properly. 
"""

batch = s.sample_batch(batch_size=128, k=2)
print("Batch:", batch)
len(batch)

Batch: [[array([ 0.56870167,  2.41734414, -2.1953144 , -2.82081276]), array([ 0.73601277,  2.5516822 , -3.48755141, -1.42048912])], [array([-4.16956237,  2.9079329 , -4.81422535,  0.06935036]), array([-3.36274968,  2.84272199, -3.87854267,  0.11520598])], [array([-3.50391923e+00, -2.70694707e+00, -1.58601976e-03,  3.86680992e+00]), array([-3.56736572, -1.84723318, -1.09876067,  4.06927189])], [array([ 0.45081867,  2.44274512, -0.01286939,  4.71943654]), array([0.45267535, 2.78909559, 0.53293338, 2.39661628])], [array([-1.012238  ,  0.89482413, -5.03012776,  1.13837793]), array([ 0.44761699,  0.5602349 , -0.19576949,  4.7233684 ])], [array([ 4.45895717,  3.00595937, -1.31085017,  1.37946242]), array([ 4.49120896,  2.95730895, -0.60235248,  2.1072741 ])], [array([ 0.7489631 ,  2.93579248,  4.21883557, -0.105629  ]), array([ 0.70677474,  2.93684877,  3.99017488, -0.05372305])], [array([ 2.47168207, -1.57292286, -0.17360297,  4.88771031]), array([ 2.46703662, -1.48244743, -0.18362607,  4.4

128