## Calculate the similarity of joystick samples in a batch
For contrastive learning, the joystick samples within a batch need to be diverse;
otherwise the model will not be able to learn.

In [1]:
import torch
from dataset import CLIPDataModule

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Batch size passed to every CLIPDataModule constructed in this notebook.
BATCH_SIZE = 64
# NOTE(review): not referenced by any cell visible here -- presumably a number
# of repeated runs; confirm it is still needed or remove it.
NUM_TRIALS = 10

### CLIPDataModule with Weighted Sampling

In [3]:
# Datamodule configured with weighted sampling switched on.
dm = CLIPDataModule(
    data_path='data_bak',
    batch_size=BATCH_SIZE,
    num_workers=4,
    use_weighted_sampling=True,
)

# Run the datamodule's setup step before any dataloader is requested from it.
dm.setup()


[1m[32mloading data from data_bak...[0m


100%|██████████| 10/10 [11:37<00:00, 69.71s/it]
100%|██████████| 7/7 [00:04<00:00,  1.44it/s]


In [4]:
def get_batch_similarity(joystick_batch):
    """Return the mean pairwise Euclidean distance between samples in a batch.

    Despite the name, the returned value is a *distance*: 0.0 means every
    sample is identical, and larger values mean the batch is more spread out
    (i.e. more diverse).

    Args:
        joystick_batch: floating-point tensor whose first dimension indexes
            the samples. Any trailing dimensions are flattened, so each sample
            is compared as a single vector.

    Returns:
        float: the sum of L2 distances over all ordered sample pairs
        (self-pairs included, each contributing zero) divided by
        ``batch_size ** 2``.

    Raises:
        ZeroDivisionError: if the batch contains no samples.
    """
    num_samples = joystick_batch.shape[0]
    # One row per sample; the 2-norm of a flattened difference is the same
    # value torch.linalg.norm yields for the un-flattened difference.
    flat = joystick_batch.flatten(start_dim=1)
    # Broadcast to every (i, j) difference at once instead of issuing
    # batch_size ** 2 separate norm calls from a Python double loop.
    # Note: this materialises a (B, B, D) intermediate tensor.
    pairwise_dist = torch.linalg.norm(flat[:, None, :] - flat[None, :, :], dim=-1)
    return torch.sum(pairwise_dist).item() / num_samples ** 2


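#### Calculate Average Similarity

Note: `get_batch_similarity` returns the mean pairwise Euclidean distance between the samples of a batch, so a *higher* value means a *more diverse* batch.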

In [5]:
# Build the training dataloader once and reuse it for both the iteration and
# the batch count, instead of constructing a second loader just to call len().
train_loader = dm.train_dataloader()

# Running sum of the per-batch scores; divided by the batch count when printed.
avg_similarity = 0.0
for i, batch in enumerate(train_loader):
    # presumably batch[1] holds the joystick tensor -- confirm against CLIPDataModule
    s = get_batch_similarity(batch[1])
    print(f'batch {i} similarity: {s:.2f}')
    avg_similarity += s
print(f"average batch similarity: {avg_similarity / len(train_loader):.2f}")



batch 0 similarity: 25.26
batch 1 similarity: 20.86
batch 2 similarity: 22.67
batch 3 similarity: 22.47
batch 4 similarity: 21.09
batch 5 similarity: 22.71
batch 6 similarity: 22.97
batch 7 similarity: 21.55
batch 8 similarity: 23.90
batch 9 similarity: 22.31
batch 10 similarity: 24.09
batch 11 similarity: 23.75
batch 12 similarity: 23.96
batch 13 similarity: 22.55
batch 14 similarity: 22.73
batch 15 similarity: 23.38
batch 16 similarity: 22.07
batch 17 similarity: 22.77
batch 18 similarity: 21.90
batch 19 similarity: 21.93
batch 20 similarity: 20.42
batch 21 similarity: 20.52
batch 22 similarity: 20.99
batch 23 similarity: 21.64
batch 24 similarity: 20.00
batch 25 similarity: 18.67
batch 26 similarity: 21.40
batch 27 similarity: 19.23
batch 28 similarity: 18.72
batch 29 similarity: 19.23
batch 30 similarity: 18.27
batch 31 similarity: 19.60
batch 32 similarity: 16.83
batch 33 similarity: 17.61
batch 34 similarity: 18.07
batch 35 similarity: 17.20
batch 36 similarity: 18.02
batch 37 si

### CLIPDataModule without Weighted Sampling

In [6]:
# Datamodule configured with weighted sampling switched off, for comparison.
dm_ns = CLIPDataModule(
    data_path='data_bak',
    batch_size=BATCH_SIZE,
    num_workers=4,
    use_weighted_sampling=False,
)

# Run the datamodule's setup step before any dataloader is requested from it.
dm_ns.setup()


[1m[32mloading data from data_bak...[0m


100%|██████████| 10/10 [00:22<00:00,  2.22s/it]
100%|██████████| 7/7 [00:01<00:00,  4.44it/s]


In [7]:
# Build the non-weighted training dataloader once and reuse it for both the
# iteration and the batch count.
train_loader_ns = dm_ns.train_dataloader()

# Running sum of the per-batch scores; divided by the batch count when printed.
avg_similarity = 0.0
for i, batch in enumerate(train_loader_ns):
    # presumably batch[1] holds the joystick tensor -- confirm against CLIPDataModule
    s = get_batch_similarity(batch[1])
    print(f'batch {i} similarity: {s:.2f}')
    avg_similarity += s
# The divisor must come from the loader that was iterated (dm_ns), not from
# the weighted-sampling datamodule (dm).
print(
    f"average batch similarity: {avg_similarity / len(train_loader_ns):.2f}")

batch 0 similarity: 7.71
batch 1 similarity: 7.68
batch 2 similarity: 9.02
batch 3 similarity: 7.01
batch 4 similarity: 7.96
batch 5 similarity: 5.96
batch 6 similarity: 8.76
batch 7 similarity: 8.86
batch 8 similarity: 9.36
batch 9 similarity: 7.97
batch 10 similarity: 8.70
batch 11 similarity: 7.50
batch 12 similarity: 8.17
batch 13 similarity: 9.04
batch 14 similarity: 10.88
batch 15 similarity: 8.82
batch 16 similarity: 9.83
batch 17 similarity: 9.77
batch 18 similarity: 7.98
batch 19 similarity: 10.14
batch 20 similarity: 6.81
batch 21 similarity: 8.55
batch 22 similarity: 8.92
batch 23 similarity: 8.22
batch 24 similarity: 8.24
batch 25 similarity: 9.46
batch 26 similarity: 7.43
batch 27 similarity: 8.19
batch 28 similarity: 9.07
batch 29 similarity: 9.34
batch 30 similarity: 7.98
batch 31 similarity: 8.08
batch 32 similarity: 8.66
batch 33 similarity: 10.51
batch 34 similarity: 7.57
batch 35 similarity: 10.91
batch 36 similarity: 7.29
batch 37 similarity: 8.05
batch 38 similarit