## Calculate the similarity of joystick samples in a batch
For contrastive learning, the joystick samples in a batch need to be diverse; otherwise the
model will not be able to learn.

In [None]:
import torch
import pandas as pd
from dataset import CLIPDataModule
from IPython.display import display

In [None]:
# Passed as `batch_size` to both CLIPDataModule instances created below.
BATCH_SIZE = 128
# NOTE(review): NUM_TRIALS is not referenced by any cell visible in this
# notebook — confirm whether it is still needed.
NUM_TRIALS = 10

### CLIPDataModule with Weighted Sampling

In [None]:
# Data module reading from 'data' with weighted sampling switched on.
dm = CLIPDataModule(data_path='data',
                    batch_size=BATCH_SIZE,
                    num_workers=10,
                    use_weighted_sampling=True)

# Presumably prepares the datasets served by train_dataloader() — confirm
# against CLIPDataModule in dataset.py.
dm.setup()


## Function to calculate batch similarity
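Calculate the norm of the difference between every pair of joystick vectors in a batch and average it over all pairs. Note that this is a distance, so a larger value means a more diverse batch.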

In [None]:
def get_batch_similarity(joystick_batch: torch.Tensor) -> float:
    """Return the mean pairwise Euclidean distance between samples in a batch.

    Each sample (indexed along dim 0) is flattened to a vector. The L2 norm of
    the difference is taken for every ordered pair of samples, including each
    sample paired with itself (which contributes 0), and the result is
    averaged over all ``batch_size ** 2`` pairs.

    Despite the name, the returned value is a distance: a larger value means
    the samples in the batch are further apart, i.e. more diverse.

    Args:
        joystick_batch: Floating-point tensor with at least two dimensions;
            dim 0 indexes the samples in the batch.

    Returns:
        The average pairwise distance as a Python float.

    Raises:
        ZeroDivisionError: If the batch contains no samples.
    """
    flat = joystick_batch.flatten(1)
    batch_size = flat.shape[0]
    # One vectorized call replaces batch_size ** 2 Python-level norm calls.
    # The non-matmul mode computes distances from direct differences, so the
    # result matches the explicit pairwise norm and self-distances are 0.
    distances = torch.cdist(flat, flat, p=2.0,
                            compute_mode='donot_use_mm_for_euclid_dist')
    return distances.sum().item() / batch_size ** 2


Calculate Average Similarity

In [None]:
# Score every training batch drawn with weighted sampling.
# Assumes element 1 of each batch is the joystick tensor — TODO confirm
# against the dataset's __getitem__.
similarities = [get_batch_similarity(batch[1])
                for batch in dm.train_dataloader()]
df = pd.DataFrame(data=similarities, columns=['batch similarity'])
display(df)
print(f"average batch similarity: {df['batch similarity'].mean():.2f}")



## DataModule without Weighted Sampling

In [None]:
# Data module with weighted sampling switched off, used as the baseline.
# NOTE(review): unlike the weighted-sampling module above, this one reads from
# 'data_bak' and uses 4 workers — confirm both directories hold the same data
# so the comparison is like-for-like.
dm_ns = CLIPDataModule(data_path='data_bak',
                       batch_size=BATCH_SIZE,
                       num_workers=4,
                       use_weighted_sampling=False)

# Presumably prepares the datasets served by train_dataloader() — confirm
# against CLIPDataModule in dataset.py.
dm_ns.setup()


In [None]:
# Score every training batch drawn without weighted sampling.
# Assumes element 1 of each batch is the joystick tensor — TODO confirm
# against the dataset's __getitem__.
similarities = [get_batch_similarity(batch[1])
                for batch in dm_ns.train_dataloader()]
df = pd.DataFrame(data=similarities, columns=['batch similarity'])
display(df)
print(f"average batch similarity: {df['batch similarity'].mean():.2f}")