# Imports


In [3]:
import minari
import torch

In [4]:
# Load the "D4RL/pointmaze/large-v2" dataset through Minari. The resulting
# `dataset` object is what the TrajectorySet in the testing cell samples from.
dataset = minari.load_dataset("D4RL/pointmaze/large-v2")

# Prior Functionality

### Getting functions/methods from other classes

In [1]:
class TrajectorySet:
    """Growing collection of state trajectories sampled from a Minari dataset.

    Attributes:
        dataset: The Minari dataset that episodes are sampled from.
        trajectories: A dictionary housing all of the trajectories, keyed by
            a zero-based insertion index. The dictionary structure is:
                {
                    0: [trajectory, length of trajectory]
                    1: [ ... ]
                    etc...
                }
        num_trajectories: The number of trajectories currently in the set
            (also the key the next added trajectory will receive).
    """

    def __init__(self, dataset):
        """
        Args:
            dataset: The minari dataset to use.
        """
        self.dataset = dataset

        self.trajectories = {}
        self.num_trajectories = 0

    def add_trajectory(self, trajectory):
        """Append `trajectory` to the set under the next free index.

        The length is cached next to the trajectory so that
        `get_total_states` does not have to call `len` again.
        """
        self.trajectories[self.num_trajectories] = [trajectory, len(trajectory)]
        self.num_trajectories += 1

    def get_num_trajectories(self):
        """Return how many trajectories have been added so far."""
        return self.num_trajectories

    def get_trajectory(self, index):
        """Return the `[trajectory, length]` entry stored at `index`.

        Raises:
            AssertionError: If `index` is negative or not smaller than the
                number of stored trajectories.
        """
        # Keys are 0 .. num_trajectories - 1, so a negative index can never be
        # valid; reject it explicitly instead of surfacing a bare KeyError.
        assert index >= 0, "Specified index must be non-negative."
        assert index < self.num_trajectories, "Specified index is too large."
        return self.trajectories[index]

    def get_trajectory_set(self):
        """Return the underlying index -> [trajectory, length] dictionary."""
        return self.trajectories

    def get_total_states(self):
        """Return the total number of states across all stored trajectories."""
        return sum(length for _, length in self.trajectories.values())

    def generate_trajectories(self, n_trajectories: int = 2):
        """
        Samples a specified number of episodes from the dataset and saves
        their state sequences into the TrajectorySet.

        NOTE(review): the original author notes that the dataset's episodes
        come from a scripted agent (a PD controller following QIteration
        waypoints until it reaches the goal); nothing is simulated here, the
        episodes are only sampled from `self.dataset`.

        Args:
            n_trajectories: The number of trajectories to sample and add.
        """
        episodes = self.dataset.sample_episodes(n_episodes=n_trajectories)

        for episode in episodes:
            # Note: only saving states since we only need state
            # representations in the encoder.
            self.add_trajectory(episode.observations["observation"])


# Sampler

### Sample anchor states and form their positive and negative pairs (of states).

In [6]:
class Sampler:
    """Samples (anchor, positive) state pairs from trajectories.

    An anchor state is drawn from a trajectory, and its positive pair is
    drawn from the same trajectory using a distribution over time steps
    that is centered on (or, for the exponential case, starts after) the
    anchor's time step.
    """

    def __init__(self, T: "TrajectorySet", dist="g", std=15, b=15, gamma=0.99):
        """
        T: The Trajectory Set class
        dist: The distribution used for centering over the anchor state.
            ['u', 'g', 'l', 'e'] - uniform, gaussian, laplace, exponential.
            Any other value falls back to gaussian.
        std: Standard deviation of the gaussian (in time steps). The default
            of 15 is, per the original author's note, the paper's value.
        b: Scale of the laplace distribution (in time steps).
        gamma: Rate passed to the exponential distribution.
        """

        self.T = T
        self.dist = dist
        self.std = std
        self.b = b
        self.gamma = gamma

    def sample_anchor_state(self, t: list) -> tuple[list, int]:
        """
        Given a trajectory, we sample the anchor state s_i uniformly.

        For the exponential distribution the positive pair always lies after
        the anchor, so the last time step is excluded from the draw whenever
        the trajectory has more than one state.

        Args:
            t: The given trajectory we sample from.

        Returns:
            A two-element list [s_i, idx]
            s_i: The state that is sampled (t[idx]).
            idx: The time step of s_i.

        Raises:
            ValueError: If the trajectory is empty.
        """
        n_states = len(t)
        if n_states == 0:
            raise ValueError("Cannot sample an anchor state from an empty trajectory.")

        forward_only = self.dist == "e" and n_states > 1
        high = n_states - 1 if forward_only else n_states

        idx = torch.randint(low=0, high=high, size=(1,)).item()
        s_i = t[idx]
        return [s_i, idx]

    def _draw_candidate_index(self, si_idx: int, n_states: int) -> int:
        """Draw one (possibly out-of-range) candidate time step for s_j."""
        if self.dist == "u":
            # uniform over the whole trajectory
            return int(torch.randint(low=0, high=n_states, size=(1,)).item())

        if self.dist == "l":
            # laplacian centered on the anchor
            laplace = torch.distributions.laplace.Laplace(
                loc=float(si_idx), scale=float(self.b)
            )
            # round (not truncate) so the offset stays symmetric around 0
            return round(laplace.sample().item())

        if self.dist == "e":
            # exponential: strictly forward offset
            exponential = torch.distributions.exponential.Exponential(
                rate=float(self.gamma)
            )
            # +1 so we don't get an offset of 0
            offset = int(exponential.sample().item()) + 1
            return si_idx + offset

        # gaussian, which is also the default for unrecognised `dist` values.
        draw = torch.normal(mean=float(si_idx), std=float(self.std), size=(1,))
        # round (not truncate) so the offset stays symmetric around 0
        return round(draw.item())

    def sample_positive_pair(self, t: list, anchor_state: tuple[list, int]) -> tuple[list, int]:
        """
        Given the same trajectory that s_i was sampled from, draw a time step
        from the configured distribution around s_i to obtain its positive
        pair: s_j. Candidates that fall outside the trajectory or land on the
        anchor itself are rejected and redrawn.

        Args:
            t: The given trajectory, which must be the same as the trajectory
                that was used to sample the anchor state.
            anchor_state: The anchor state; a pair [s_i, idx].
                s_i: The state itself.
                idx: The time step of s_i.

        Returns:
            A two-element list [s_j, idx]
            s_j: The state that is sampled (t[idx]).
            idx: The time step of s_j.

        Raises:
            ValueError: If no valid positive pair can exist: the anchor index
                is outside the trajectory, the trajectory has fewer than two
                states, or the exponential distribution is used with the
                anchor on the last time step.
        """
        _, si_idx = anchor_state
        si_idx = int(si_idx)
        n_states = len(t)

        # Fail fast in every situation where the rejection loop below could
        # never accept a candidate (it would otherwise spin forever).
        if not 0 <= si_idx < n_states:
            raise ValueError("Anchor index is outside of the given trajectory.")
        if n_states < 2:
            raise ValueError("Trajectory needs at least two states to form a positive pair.")
        if self.dist == "e" and si_idx == n_states - 1:
            raise ValueError(
                "Exponential sampling needs a state after the anchor, "
                "but the anchor is the last state of the trajectory."
            )

        while True:
            sj_idx = self._draw_candidate_index(si_idx, n_states)

            # Ensures we don't choose an index out of range or the same state.
            # Index 0 is a valid time step and must be accepted.
            if 0 <= sj_idx < n_states and sj_idx != si_idx:
                break

        s_j = t[sj_idx]

        return [s_j, sj_idx]

    def sample_batch(self, batch_size=1024, k=2) -> list[tuple]:
        """
        Creates a batch of anchor states and their positive pairs.

        A fresh set of batch_size // k trajectories (at least one) is added
        to the TrajectorySet, and every pair of the batch is drawn from those
        freshly added trajectories only.

        Negative pairs are not materialized here. Per the original author's
        note there will be 2(batch_size - 1) negative examples per positive
        pair, presumably formed from the other batch entries downstream.

        NOTE: every call adds trajectories to the TrajectorySet and none are
        removed, so the set keeps growing across calls.

        Args:
            batch_size: The size of the batch to be generated.
            k: A hyperparameter that dictates the average number of
                positive pairs sampled from the same trajectory. The
                lower the number, the lesser the chance of false negatives.

        Returns:
            A list of [s_i, s_j] pairs containing the anchor state and its
            positive pair. The list is the same length as batch_size.

        Raises:
            ValueError: If batch_size or k is smaller than 1, or if none of
                the newly generated trajectories has at least two states.
        """
        if batch_size < 1 or k < 1:
            raise ValueError("batch_size and k must both be at least 1.")

        # Generate trajectory set; remember which indices are new so stale
        # trajectories from earlier calls are not sampled again.
        n_trajectories = max(1, batch_size // k)
        first_new = self.T.get_num_trajectories()
        self.T.generate_trajectories(n_trajectories=n_trajectories)
        end_new = self.T.get_num_trajectories()

        # A positive pair needs two distinct states from one trajectory.
        eligible = [
            index
            for index in range(first_new, end_new)
            if self.T.get_trajectory(index=index)[1] >= 2
        ]
        if not eligible:
            raise ValueError("No generated trajectory has at least two states.")

        batch = []

        for _ in range(batch_size):
            # Sample anchor state
            pick = torch.randint(low=0, high=len(eligible), size=(1,)).item()
            t = self.T.get_trajectory(index=eligible[pick])[0]

            anchor_state = self.sample_anchor_state(t)

            # Sample positive pair
            positive_pair = self.sample_positive_pair(t, anchor_state=anchor_state)

            # Retrieve states; time-steps aren't necessary.
            batch.append([anchor_state[0], positive_pair[0]])

        return batch

In [7]:
"""
TESTING CELL!
"""
# Smoke test: sample one anchor state and its positive pair from a single
# trajectory and print both with their time steps.
T = TrajectorySet(dataset=dataset)

# Shorthand names for the distribution codes accepted by Sampler.
uniform = "u"
gaussian = "g"
laplacian = "l"
exponential = "e"

# This cell exercises the laplace distribution.
s = Sampler(T, laplacian)

# Add one trajectory to T and fetch it; entry [0] is the trajectory itself
# ([1] would be its length).
T.generate_trajectories(n_trajectories=1) 
t = T.get_trajectory(0)[0] 
anchor_state = s.sample_anchor_state(t)
print(f"s_i: {anchor_state[0]}, at time step: {anchor_state[1]}")

# The positive pair must be drawn from the same trajectory as the anchor.
positive_pair = s.sample_positive_pair(t=t, anchor_state=anchor_state)
print(f"s_j: {positive_pair[0]}, at time step: {positive_pair[1]}")


s_i: [ 3.10517333 -1.0592575   4.67465834  0.32619211], at time step: 530
s_j: [ 2.13934799 -1.12443064  4.57106431 -0.21759863], at time step: 506


In [8]:
"""
TESTING CELL! 

Testing to see if sample_batch works properly. 
"""

# Uses the Sampler `s` created in the previous testing cell; with k=2 this
# asks sample_batch to generate 128 // 2 = 64 trajectories.
batch = s.sample_batch(batch_size=128, k=2)
print("Batch:", batch)
# Last expression of the cell, so the notebook displays the batch length.
len(batch)

Batch: [[array([ 0.42347977, -0.21610121, -0.19178915,  2.7706739 ]), array([ 0.41777253, -0.55428052,  0.11493824,  5.071069  ])], [array([2.82581935, 0.81448915, 4.09075007, 0.75110628]), array([3.04735105, 0.85969424, 4.18007384, 0.94973478])], [array([-3.71913882,  2.8601187 ,  2.46014687, -0.44114982]), array([-3.86567585,  2.88328691,  3.53477764, -0.52096189])], [array([-1.65290574,  1.78238473, -0.25230643,  2.99153212]), array([-1.71262417,  2.96956474, -1.15763502,  2.03130241])], [array([ 2.4387001 ,  1.77763996, -0.85396076, -4.08780283]), array([ 2.46807916,  1.95907802, -0.08632228, -2.9374559 ])], [array([-4.62628172, -0.21190188, -0.4183882 , -4.35286083]), array([-4.62221664,  1.59472762,  0.41057932, -5.21585639])], [array([ 4.42589167, -1.10844326,  1.51804618,  1.30979735]), array([ 4.23626319, -1.1542856 ,  3.47317596, -0.03689346])], [array([ 4.42589167, -1.10844326,  1.51804618,  1.30979735]), array([4.41997441, 0.26012585, 0.61158345, 4.86815322])], [array([ 1.0

128