## Calculate the similarity of joystick samples in a batch
For contrastive learning, the joystick samples in a batch need to be diverse; otherwise the
model will not be able to learn.

In [1]:
import torch
import pandas as pd
from dataset import CLIPDataModule
from IPython.display import display

# Lift the row limit so that display() renders every row of a DataFrame
# instead of truncating long tables.
pd.set_option('display.max_rows', None)

  from .autonotebook import tqdm as notebook_tqdm


### Define Batch Size

In [2]:
# Batch size passed to every CLIPDataModule created in this notebook, so the
# weighted and non-weighted runs are compared on batches of the same size.
BATCH_SIZE = 128

### CLIPDataModule with Weighted Sampling

In [4]:
# Data module configured with weighted sampling switched on.
dm = CLIPDataModule(
    data_path='data',
    batch_size=BATCH_SIZE,
    num_workers=10,
    use_weighted_sampling=True,
)

# Prepare the module before any dataloader is requested from it.
dm.setup()


[1m[32mloading data from data...[0m


100%|██████████| 12/12 [12:27<00:00, 62.25s/it]
100%|██████████| 13/13 [00:17<00:00,  1.34s/it]


## Function to calculate batch similarity
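Calculate the norm of the difference between every pair of joystick vectors and average it over all pairs. Note that this value is a distance: the larger it is, the more diverse the samples in the batch are.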

In [5]:
def get_batch_similarity(joystick_batch: torch.Tensor) -> float:
    """Return the mean pairwise Euclidean distance between the samples of a batch.

    Every sample (all dimensions after the first) is flattened into a vector.
    The L2 norm of the difference is taken for every ordered pair of samples,
    self-pairs included (they contribute 0), and the sum of those norms is
    divided by ``batch_size ** 2``.

    Despite the function name, the returned value is a distance: a larger
    value means the samples in the batch lie further apart.

    Args:
        joystick_batch: Tensor with at least two dimensions whose first
            dimension indexes the samples of the batch.

    Returns:
        The averaged pairwise distance as a Python float.

    Raises:
        ZeroDivisionError: If the batch contains no samples.
    """
    samples = joystick_batch.flatten(1)
    # cdist only accepts floating-point input; promote integer / half-precision
    # tensors to float32 and leave float32 / float64 inputs untouched.
    samples = samples.to(torch.promote_types(samples.dtype, torch.float32))

    # One call yields the full (batch, batch) distance matrix, replacing the
    # element-by-element Python double loop. The non-matmul mode evaluates
    # norm(a - b) directly for each pair, like the loop it replaces.
    distances = torch.cdist(
        samples,
        samples,
        p=2.0,
        compute_mode="donot_use_mm_for_euclid_dist",
    )

    batch_size = samples.shape[0]
    return distances.sum().item() / batch_size ** 2


### Calculate Average Similarity

In [6]:
# Score every training batch of the weighted-sampling data module.
# Element 1 of each batch is what gets passed as the joystick tensor.
similarities = [
    get_batch_similarity(batch[1]) for batch in dm.train_dataloader()
]

# One row per batch, followed by the mean over all batches.
df = pd.DataFrame(similarities, columns=['batch similarity'])
display(df)
print(f"average batch similarity: {df['batch similarity'].mean():.2f}")



Unnamed: 0,batch similarity
0,19.882969
1,18.316841
2,19.351839
3,20.446171
4,20.893623
5,20.326469
6,19.756655
7,19.230129
8,20.763271
9,20.760746


average batch similarity: 20.43


## DataModule without Weighted Sampling

In [3]:
# Data module configured with weighted sampling switched off, for comparison.
dm_ns = CLIPDataModule(
    data_path='data',
    batch_size=BATCH_SIZE,
    num_workers=10,
    use_weighted_sampling=False,
    verbose=True,
)

# Prepare the module before any dataloader is requested from it.
dm_ns.setup()


[1m[32mloading data from data...[0m


100%|██████████| 12/12 [00:18<00:00,  1.56s/it]
100%|██████████| 13/13 [00:17<00:00,  1.34s/it]


In [6]:
# Score every training batch of the data module without weighted sampling.
# Element 1 of each batch is what gets passed as the joystick tensor.
similarities = [
    get_batch_similarity(batch[1]) for batch in dm_ns.train_dataloader()
]

# One row per batch, followed by the mean over all batches.
df = pd.DataFrame(similarities, columns=['batch similarity'])
display(df)
print(f"average batch similarity: {df['batch similarity'].mean():.2f}")

Unnamed: 0,batch similarity
0,10.112962
1,10.545479
2,7.45576
3,8.856468
4,10.359069
5,9.359221
6,8.544786
7,9.096355
8,10.072435
9,10.010383


average batch similarity: 8.93
