# Imports


In [1]:
import minari
import torch

In [2]:
dataset = minari.load_dataset("D4RL/pointmaze/large-v2")

# Prior Functionality

### Getting functions/methods from other classes

In [3]:
class TrajectorySet:
    """
    An in-memory collection of trajectories, keyed by insertion order.

    Each stored entry is a two-element list: [trajectory, length of trajectory].
    """

    def __init__(self):
        """
        trajectories: a dictionary housing all of the trajectories, keyed by
            0-based insertion index. The dictionary structure is:
            {
                0: [trajectory, length of trajectory]
                1: [ ... ]
                etc...
            }

        num_trajectories: the number of trajectories currently in the set.
        """
        self.trajectories = {}
        self.num_trajectories = 0

    def add_trajectory(self, trajectory):
        """Store `trajectory` (with its length) under the next free index."""
        self.trajectories[self.num_trajectories] = [trajectory, len(trajectory)]
        self.num_trajectories += 1

    def get_num_trajectories(self):
        """Return the number of trajectories currently stored."""
        return self.num_trajectories

    def get_trajectory(self, index):
        """
        Return the entry [trajectory, length of trajectory] stored at `index`.

        Raises:
            IndexError: if `index` is negative or not smaller than the number
                of stored trajectories.
        """
        # Explicit check instead of `assert`: asserts vanish under `python -O`,
        # and negative indices would otherwise surface as a bare KeyError.
        if not 0 <= index < self.num_trajectories:
            raise IndexError(
                f"Trajectory index {index} is out of range for a set of "
                f"{self.num_trajectories} trajectories."
            )
        return self.trajectories[index]

    def get_trajectory_set(self):
        """Return the underlying dictionary itself (not a copy)."""
        return self.trajectories

    def get_total_states(self):
        """Return the summed length of every stored trajectory."""
        return sum(length for _, length in self.trajectories.values())

    def generate_trajectories(self, n_trajectories: int = 2):
        """
        Samples a specified number of episodes from the module-level `dataset`
        and saves them into the TrajectorySet class.

        Per the original author's note, the dataset episodes come from a scripted
        agent that uses a PD controller to follow a path of waypoints generated
        with QIteration until it reaches the goal.
        """
        # Note: only saving states since we only need state representations in the encoder
        for episode in dataset.sample_episodes(n_episodes=n_trajectories):
            self.add_trajectory(episode.observations["observation"])

# Sampling 

### We need to sample anchor states and, for each one, obtain a positive pair and negative pairs (all of them states).
### Check out Torch distributions page: https://docs.pytorch.org/docs/stable/distributions.html

In [52]:
def sample_anchor_state(t: list) -> tuple[list, int]: 
    """
    Pick the anchor state s_i from a trajectory, uniformly at random over its time steps.

    Args:
        t: The trajectory to draw from; it is indexed by time step.

    Returns:
        A two-element list [s_i, idx]
        s_i: The element of `t` found at the drawn time step.
        idx: The drawn time step, as a plain Python int.
    """
    # torch.randint's upper bound is exclusive, so every valid index of `t` is reachable.
    drawn = torch.randint(low=0, high=len(t), size=(1,))
    step = drawn.item()
    return [t[step], step]

def sample_positive_pair_gaussian(t: list, anchor_state: tuple[list, int], std: float = 15) -> tuple[list, int]: 
    """
    Given the same trajectory that s_i was sampled from,
    center a gaussian distribution around the time step of s_i to obtain its positive pair: s_j.

    Candidate time steps are drawn repeatedly (rejection sampling) until one
    lands inside the trajectory and differs from the anchor's time step.

    Reference: https://docs.pytorch.org/docs/stable/generated/torch.normal.html

    Args:
        t: The given trajectory, which must be the same as the trajectory that was used to sample the anchor state.
            It must contain at least two states.
        anchor_state: The anchor state; a pair [s_i, idx].
            s_i: The state itself.
            idx: The time step of s_i.
        std: Standard deviation, in time steps, of the gaussian. Defaults to 15,
            which the original author chose to replicate the paper's hyperparams.

    Returns:
        A two-element list [s_j, idx]
        s_j: The element of `t` found at the sampled time step.
        idx: The time step of s_j; always within [0, len(t)) and never the anchor's time step.

    Raises:
        ValueError: if `t` has fewer than two states (no distinct partner exists)
            or if `std` is not positive (the draw could never leave the anchor).
    """
    n_steps = len(t)
    if n_steps < 2:
        raise ValueError("The trajectory needs at least two states to form a positive pair.")
    if std <= 0:
        raise ValueError("std must be positive.")

    _, si_idx = anchor_state

    while True:
        draw = torch.normal(mean=float(si_idx), std=float(std), size=(1,)).item()
        # Round to the nearest step so the discrete distribution stays symmetric
        # around the anchor; truncating toward zero would favour earlier steps.
        sj_idx = round(draw)

        # Reject the anchor itself and anything outside the trajectory. The lower
        # bound matters: a negative index would silently wrap around to the END
        # of the trajectory instead of failing.
        if 0 <= sj_idx < n_steps and sj_idx != si_idx:
            break

    return [t[sj_idx], sj_idx]

In [53]:
"""
TESTING CELL!
"""
# Smoke test: pull a single trajectory from the dataset, sample an anchor state
# from it, then sample that anchor's positive pair from the same trajectory.
T = TrajectorySet() 
T.generate_trajectories(n_trajectories=1) 
# Each stored entry is [trajectory, length]; element 0 is the trajectory itself.
t = T.get_trajectory(0)[0] 
anchor_state = sample_anchor_state(t)
print(f"s_i: {anchor_state[0]}, at time step: {anchor_state[1]}")

# The positive pair must be drawn from the same trajectory `t` as the anchor.
positive_pair = sample_positive_pair_gaussian(t=t, anchor_state=anchor_state)
print(f"s_j: {positive_pair[0]}, at time step: {positive_pair[1]}")


s_i: [ 2.68499918 -1.08981343  1.79511564  1.17858673], at time step: 13
s_j: [ 2.35291474 -1.21387728  4.50045152  0.13545453], at time step: 2


# Sample a minibatch of states. 

### For example: A batch of 2048 samples will result in 2(2048 - 1) = 4094 negative examples per positive pair. 

In [57]:
def sample_batch(T: TrajectorySet, batch_size=1024, k=2) -> list[tuple]: 
    """
    Creates a batch of anchor states and their positive pairs.

    Negative examples are not materialised here; they are implied by the other
    entries of the batch, giving 2(batch_size - 1) negative examples per positive pair.

    Args:
        T: The trajectory set class. Freshly generated trajectories are appended
            to it, and only those new trajectories are sampled from, so a
            non-empty set is tolerated (its earlier contents are left unused).
        batch_size: The size of the batch to be generated.
        k: A hyperparameter that dictates the average number of
            positive pairs sampled from the same trajectory. The
            lower the number, the lesser the chance of false negatives.

    Returns:
        A list of [s_i, s_j] pairs, one per batch element: the anchor state and
        its positive pair. The list is the same length as batch_size (empty when
        batch_size is not positive).

    Raises:
        ValueError: if k is smaller than 1.
        RuntimeError: if generating trajectories added nothing to sample from.
    """
    if k < 1:
        raise ValueError("k must be at least 1.")
    if batch_size < 1:
        return []

    # Generate trajectory set. Keep at least one trajectory even when
    # k > batch_size, where the integer division alone would give zero.
    n_trajectories = max(1, batch_size // k)

    # Remember where the new trajectories start so that anything already
    # stored in T is never sampled by mistake.
    first_new = T.get_num_trajectories()
    T.generate_trajectories(n_trajectories=n_trajectories)
    end_new = T.get_num_trajectories()
    if end_new <= first_new:
        raise RuntimeError("generate_trajectories added no trajectories to sample from.")

    batch = []

    for _ in range(batch_size):
        # Pick one of the freshly generated trajectories uniformly at random.
        traj_idx = torch.randint(low=first_new, high=end_new, size=(1,)).item()
        t = T.get_trajectory(index=traj_idx)[0]

        # Sample anchor state, then its positive pair from the same trajectory.
        anchor_state = sample_anchor_state(t)
        positive_pair = sample_positive_pair_gaussian(t, anchor_state=anchor_state)

        # Retrieve states; time-steps aren't necessary.
        batch.append([anchor_state[0], positive_pair[0]])

    return batch

In [58]:
"""
TESTING CELL! 

Testing to see if sample_batch works properly. 
"""

# Build a 64-element batch with k=1, so 64 // 1 = 64 trajectories are requested.
T = TrajectorySet() 
batch = sample_batch(T, 64, 1)
print("Batch:", batch)
# Last expression of the cell, so the notebook displays the batch length.
len(batch)

Batch: [[array([-4.63952485,  2.42895509, -0.70942291,  5.07436753]), array([-4.60797118,  2.22234224, -0.77327669,  5.09043942])], [array([3.67080735, 0.85142329, 4.84847044, 0.54643165]), array([3.82608333, 0.86599257, 5.22625565, 0.45977857])], [array([2.76826608, 0.89862059, 3.90544118, 0.57572993]), array([ 2.44107337,  0.96781873,  0.8800313 , -2.08661115])], [array([-3.56475894,  0.98381129, -3.08658069, -0.11977307]), array([-4.04334881,  0.99029188, -4.88998371, -0.03143712])], [array([ 4.4815226 , -1.01473624, -0.78154177, -2.50279187]), array([ 4.3609642 , -1.13010142, -2.42388747, -1.10584767])], [array([ 2.43688701,  1.77557874, -1.32337217, -4.60261191]), array([ 2.48204679,  2.17004052,  0.5620716 , -3.50312065])], [array([ 0.49972185, -2.66915211,  0.14275037,  2.93397442]), array([ 0.49553987, -2.82567333,  0.08765508,  1.30021421])], [array([-2.50478729,  2.99670037, -2.82860567,  0.111341  ]), array([-2.02992743,  2.96886626, -4.10945937,  0.3165978 ])], [array([ 3.2

64