## Calculate the similarity of joystick samples in a batch
For contrastive learning, we need the joystick samples in a batch to be diverse, otherwise the \
model will not be able to learn

In [1]:
import torch
import pandas as pd
from dataset import CLIPDataModule
from IPython.display import display

  from .autonotebook import tqdm as notebook_tqdm


### Define Batch Size

In [2]:
BATCH_SIZE = 128

### CLIPDataModule with Weighted Sampling

In [3]:
dm = CLIPDataModule(data_path='data',
                    batch_size=BATCH_SIZE,
                    num_workers=10,
                    use_weighted_sampling=True)

dm.setup()


[1m[32mloading data from data...[0m


100%|██████████| 12/12 [00:22<00:00,  1.88s/it]
100%|██████████| 3/3 [00:00<00:00,  5.57it/s]


## Function to calculate batch similarity
Calcuate the norm of the vector between every pair of joystick vectors

In [4]:
def get_batch_similarity(joystick_batch: torch.Tensor) -> float:
    m = torch.empty((joystick_batch.shape[0], joystick_batch.shape[0]))
    joystick_batch = joystick_batch.flatten(1)
    for x in range(m.shape[0]):
        x_1 = joystick_batch[x, :]
        for y in range(m.shape[0]):
            x_2 = joystick_batch[y, :]
            m[x, y] = torch.linalg.norm(x_2 - x_1)
    similarity = torch.sum(m).item() / m.shape[0] ** 2
    return similarity


Calculate Average Similarity

In [5]:
similarities = []
for i, batch in enumerate(dm.train_dataloader()):
    s = get_batch_similarity(batch[1])
    similarities.append(s)
df = pd.DataFrame(data=similarities, columns=['batch similarity'])
display(df)
print(f"average batch similarity: {df['batch similarity'].mean():.2f}")



Unnamed: 0,batch similarity
0,16.115494
1,16.700123
2,16.626482
3,16.615232
4,16.802391
...,...
242,16.520035
243,16.698954
244,16.789499
245,16.542904


average batch similarity: 16.57


## DataModule without Weighted Sampling

In [6]:
dm_ns = CLIPDataModule(data_path='data',
                       batch_size=BATCH_SIZE,
                       num_workers=10,
                       use_weighted_sampling=False, 
                       verbose=True)

dm_ns.setup()


[1m[32mloading data from data...[0m
[36mskip first 50 frames[0m
[36mbatch size: 128[0m
[36mfuture joystick length: 300
[0m
[32mcreating training set...[0m


100%|██████████| 12/12 [00:21<00:00,  1.78s/it]


[32mcreating validation set...[0m


100%|██████████| 3/3 [00:00<00:00,  5.59it/s]

[36mtraining size: 31738 samples[0m
[36mvalidation size: 2720 samples[0m





In [7]:
similarities = []
for i, batch in enumerate(dm_ns.train_dataloader()):
    s = get_batch_similarity(batch[1])
    similarities.append(s)
df = pd.DataFrame(data=similarities, columns=['batch similarity'])
display(df)
print(f"average batch similarity: {df['batch similarity'].mean():.2f}")

Unnamed: 0,batch similarity
0,8.922907
1,10.385017
2,9.210779
3,7.556968
4,8.690161
...,...
242,9.395877
243,9.053288
244,9.625372
245,8.903849


average batch similarity: 8.94
