In [1]:
import os
import random
import wandb
import torch
import numpy as np
import matplotlib.pyplot as plt
from beir.datasets.data_loader import GenericDataLoader

from matryoshka import Matryoshka, PairwiseSimilarityLoss, PairwiseSimilarityLossParallel

random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
torch.cuda.manual_seed(42)

  from tqdm.autonotebook import tqdm


In [2]:
data_path = "data/nfcorpus"
corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split="train")

length = None
corpus = {k: v for k, v in list(corpus.items())[:length]}
queries = {k: v for k, v in list(queries.items())[:length]}
qrels = {k: v for k, v in list(qrels.items())[:length]}

  0%|          | 0/3633 [00:00<?, ?it/s]

In [3]:
base_model = Matryoshka(matryoshka_dim=384, adaptor=False)
model = Matryoshka(matryoshka_dim=384, adaptor=True)
tokenizer = model.tokenizer

sentences = ["sentence"]
inputs = tokenizer(sentences, return_tensors="pt", padding=True, truncation=True)

if torch.cuda.is_available():
    model = model.cuda()
    base_model = base_model.cuda()



In [4]:
cs = [c["text"] for c in corpus.values()]
qs = list(queries.values())

In [5]:
inputs = tokenizer(qs[:10], return_tensors="pt", padding=True, truncation=True)
if torch.cuda.is_available():
    for k, v in inputs.items():
        inputs[k] = v.cuda()
outputs = model(pooling=True, skip=False, **inputs)
target_outputs = base_model(pooling=True, **inputs)

loss_fn = PairwiseSimilarityLoss()
loss_fn_parallel = PairwiseSimilarityLossParallel()
loss, loss_partial = loss_fn(outputs, target_outputs, [64, 128, 256, 384], reduce=False)
loss2, loss_partial2 = loss_fn_parallel(outputs, target_outputs, [64, 128, 256, 384], reduce=False)

print(loss)
print(loss_partial)
print(loss2)
print(loss_partial2)

tensor(133.9172, device='cuda:0', grad_fn=<AddBackward0>)
{64: tensor(35.2179, device='cuda:0', grad_fn=<AddBackward0>), 128: tensor(33.6590, device='cuda:0', grad_fn=<AddBackward0>), 256: tensor(31.9819, device='cuda:0', grad_fn=<AddBackward0>), 384: tensor(33.0584, device='cuda:0', grad_fn=<AddBackward0>)}
tensor(128.0457, device='cuda:0', grad_fn=<AddBackward0>)
{64: tensor(33.7338, device='cuda:0', grad_fn=<AddBackward0>), 128: tensor(32.2593, device='cuda:0', grad_fn=<AddBackward0>), 256: tensor(30.7107, device='cuda:0', grad_fn=<AddBackward0>), 384: tensor(31.3418, device='cuda:0', grad_fn=<AddBackward0>)}


In [6]:
import torch.nn.functional as F

loss = 0
embeddings = target_outputs
embeddings_adapted = outputs
for j in range(len(outputs)):
    target_similarity = F.cosine_similarity(embeddings[0].squeeze(0), embeddings[j].squeeze(0), dim=0)
    for m in [64, 128, 256, 384]:
        similarity = F.cosine_similarity(embeddings_adapted[0].squeeze(0)[:m], embeddings[j].squeeze(0)[:m], dim=0)
        loss += torch.abs(target_similarity - similarity)
print("loss")
print(loss)

loss
tensor(8.4998, device='cuda:0', grad_fn=<AddBackward0>)


In [7]:
loss = 0
embeddings = target_outputs
embeddings_adapted = outputs
cloned_target = embeddings[0].repeat(10, 1)
target_similarity = F.cosine_similarity(cloned_target, embeddings, dim=1)
cloned_adapted = embeddings_adapted[0].repeat(10, 1)
for m in [64, 128, 256, 384]:
    similarity = F.cosine_similarity(cloned_adapted[:, :m], embeddings[:, :m], dim=1)
    loss += torch.sum(torch.abs(target_similarity - similarity))
print(loss)

tensor(8.4998, device='cuda:0', grad_fn=<AddBackward0>)


In [8]:
%%timeit
loss, loss_partial = loss_fn(outputs, target_outputs, [64, 128, 256, 384], reduce=False)

23.4 ms ± 141 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [9]:
%%timeit
loss2, loss_partial2 = loss_fn_parallel(outputs, target_outputs, [64, 128, 256, 384], reduce=False)

5.07 ms ± 12.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [10]:
run_name = "pairwise_loss_skip_normalized"
epochs = 25
batch_size = 64
running_loss_step = 10
learning_rate = 1e-5

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
loss_fn = PairwiseSimilarityLossParallel()
# scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

data = qs + cs
random.shuffle(data)
train_data = data[:int(len(data) * 0.8)]
test_data = data[int(len(data) * 0.8):]

wandb.init(
    project="matryoshka_pairwise",  
    name=run_name,        
    config={                      
        "learning_rate": learning_rate,
        "epochs": epochs,
        "batch_size": batch_size,
        "loss": "PairwiseSimilarityLoss",
        "model": model.name,
        "model_card": model.model_card_data,
        "loss_resolution": running_loss_step,
        "architecture": model.__str__(),
    }
)

ls = []
for i in range(epochs):
    epoch_loss = []
    running_loss = []
    partial_running_loss = {64: [], 128: [], 256: [], 384: []}

    model.train()
    random.shuffle(train_data)
    for j in range(0, len(train_data), batch_size):
        print(j)
        if j + batch_size > len(train_data):
            break
        # c = cs[j : j + 32]
        # q = qs[j : j + 32]
        q = train_data[j : j + batch_size]

        inputs = tokenizer(q, return_tensors="pt", padding=True, truncation=True)
        if torch.cuda.is_available():
            for k, v in inputs.items():
                inputs[k] = v.cuda()
        outputs = model(pooling=True, skip=False, **inputs)
        target_outputs = base_model(pooling=True, **inputs)

        loss, loss_partial = loss_fn(outputs, target_outputs, [64, 128, 256, 384])

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        ls.append(loss.item())
        epoch_loss.append(loss.item())
        running_loss.append(loss.item())
        for k, v in loss_partial.items():
            partial_running_loss[k].append(v.item())
        if len(running_loss) % 10 == 0:
            print("Batch:", j, "loss:", np.mean(running_loss))
            wandb.log({"batch": j, "loss": np.mean(running_loss)} | {f"loss_{k}": np.mean(v) for k, v in partial_running_loss.items()})
            partial_running_loss = {64: [], 128: [], 256: [], 384: []}
            running_loss = []
 
    model.eval()
    eval_loss = []
    partial_eval_loss = {64: [], 128: [], 256: [], 384: []}
    for j in range(0, len(test_data), batch_size):
        if j + batch_size > len(test_data):
            break
        q = test_data[j : j + batch_size]

        inputs = tokenizer(q, return_tensors="pt", padding=True, truncation=True)
        if torch.cuda.is_available():
            for k, v in inputs.items():
                inputs[k] = v.cuda()
        with torch.no_grad():
            outputs = model(pooling=True, skip=False, **inputs)
            target_outputs = base_model(pooling=True, **inputs)
            loss, loss_partial = loss_fn(outputs, target_outputs, [64, 128, 256, 384])
            eval_loss.append(loss.item())
            for k, v in loss_partial.items():
                partial_eval_loss[k].append(v.item())
    wandb.log({"epoch": i, "epoch_loss": np.mean(epoch_loss), "eval_loss": np.mean(eval_loss)} | {f"eval_loss_{k}": np.mean(v) for k, v in partial_eval_loss.items()})
    print("Epoch:", i, "loss:", np.mean(epoch_loss), "eval_loss:", np.mean(eval_loss))
    os.makedirs(f"models/{run_name}", exist_ok=True)
    torch.save(model, f"models/{run_name}/{i}.pth")

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mlaz4rz[0m. Use [1m`wandb login --relogin`[0m to force relogin


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011112051188117929, max=1.0…

0
64
128
192
256
320
384
448
512
576
Batch: 576 loss: 0.38975848257541656
640
704
768
832
896
960
1024
1088
1152
1216
Batch: 1216 loss: 0.3867886245250702
1280
1344
1408
1472
1536
1600
1664
1728
1792
1856
Batch: 1856 loss: 0.3859786123037338
1920
1984
2048
2112
2176
2240
2304
2368
2432
2496
Batch: 2496 loss: 0.3808394134044647
2560
2624
2688
2752
2816
2880
2944
3008
3072
3136
Batch: 3136 loss: 0.37997758388519287
3200
3264
3328
3392
3456
3520
3584
3648
3712
3776
Batch: 3776 loss: 0.3750932842493057
3840
3904
3968
4032
4096
4160
4224
4288
4352
4416
Batch: 4416 loss: 0.3725013703107834
4480
4544
4608
4672
4736
4800
4864
4928
Epoch: 0 loss: 0.3812659282188911 eval_loss: 0.37232198683839096
0
64
128
192
256
320
384
448
512
576
Batch: 576 loss: 0.36411979496479036
640
704
768
832
896
960
1024
1088
1152
1216
Batch: 1216 loss: 0.3660570025444031
1280
1344
1408
1472
1536
1600
1664
1728
1792
1856
Batch: 1856 loss: 0.3615768730640411
1920
1984
2048
2112
2176
2240
2304
2368
2432
2496
Batch: 2496 