In [6]:
import os
import random
import wandb
import torch
import numpy as np
import matplotlib.pyplot as plt
from beir.datasets.data_loader import GenericDataLoader

from matryoshka import Matryoshka, PairwiseSimilarityLoss, PairwiseSimilarityLossParallel, RegularizingLoss, TopKSimilarityLoss

random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
torch.cuda.manual_seed(42)

In [7]:
data_path = "data/nfcorpus"
corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split="train")

length = None
corpus = {k: v for k, v in list(corpus.items())[:length]}
queries = {k: v for k, v in list(queries.items())[:length]}
qrels = {k: v for k, v in list(qrels.items())[:length]}

  0%|          | 0/3633 [00:00<?, ?it/s]

In [8]:
base_model = Matryoshka(matryoshka_dim=384, adaptor=False)
model = Matryoshka(matryoshka_dim=384, adaptor=True)
tokenizer = model.tokenizer

sentences = ["sentence"]
inputs = tokenizer(sentences, return_tensors="pt", padding=True, truncation=True)

if torch.cuda.is_available():
    model = model.cuda()
    base_model = base_model.cuda()



In [9]:
cs = [c["text"] for c in corpus.values()]
qs = list(queries.values())

In [10]:
run_name = "pairwise_reg_topk_skip_layernorm"
epochs = 25
batch_size = 64
running_loss_step = 10
learning_rate = 1e-5

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
loss_fn = PairwiseSimilarityLossParallel()
loss_fn_reg = RegularizingLoss()
loss_fn_topk = TopKSimilarityLoss(k=10)
# scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

data = qs + cs
random.shuffle(data)
train_data = data[:int(len(data) * 0.95)]
test_data = data[int(len(data) * 0.95):]

wandb.init(
    project="matryoshka_pairwise",  
    name=run_name,        
    config={                      
        "learning_rate": learning_rate,
        "epochs": epochs,
        "batch_size": batch_size,
        "loss": "PairwiseSimilarityLoss+RegLoss",
        "model": model.name,
        "model_card": model.model_card_data,
        "loss_resolution": running_loss_step,
        "architecture": model.__str__(),
    }
)

ls = []
for i in range(epochs):
    epoch_loss = []
    running_loss = []
    partial_running_loss = {64: [], 128: [], 256: [], 384: []}

    model.train()
    random.shuffle(train_data)
    for j in range(0, len(train_data), batch_size):
        print(j)
        if j + batch_size > len(train_data):
            break
        # c = cs[j : j + 32]
        # q = qs[j : j + 32]
        q = train_data[j : j + batch_size]

        inputs = tokenizer(q, return_tensors="pt", padding=True, truncation=True)
        if torch.cuda.is_available():
            for k, v in inputs.items():
                inputs[k] = v.cuda()
        outputs = model(pooling=True, **inputs)
        target_outputs = base_model(pooling=True, **inputs)

        loss, loss_partial = loss_fn(target_outputs, outputs, [64, 128, 256, 384])
        loss_reg = loss_fn_reg(target_outputs, outputs)
        loss_topk = loss_fn_topk(target_outputs, outputs, [64, 128, 256, 384])

        loss = loss + loss_reg + loss_topk

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        ls.append(loss.item())
        epoch_loss.append(loss.item())
        running_loss.append(loss.item())
        for k, v in loss_partial.items():
            partial_running_loss[k].append(v.item())
        if len(running_loss) % 10 == 0:
            print("Batch:", j, "loss:", np.mean(running_loss))
            wandb.log({"batch": j, "loss": np.mean(running_loss)} | {f"loss_{k}": np.mean(v) for k, v in partial_running_loss.items()})
            partial_running_loss = {64: [], 128: [], 256: [], 384: []}
            running_loss = []
 
    model.eval()
    eval_loss = []
    partial_eval_loss = {64: [], 128: [], 256: [], 384: []}
    for j in range(0, len(test_data), batch_size):
        if j + batch_size > len(test_data):
            break
        q = test_data[j : j + batch_size]

        inputs = tokenizer(q, return_tensors="pt", padding=True, truncation=True)
        if torch.cuda.is_available():
            for k, v in inputs.items():
                inputs[k] = v.cuda()
        with torch.no_grad():
            outputs = model(pooling=True, **inputs)
            target_outputs = base_model(pooling=True, **inputs)
            loss, loss_partial = loss_fn(outputs, target_outputs, [64, 128, 256, 384])
            eval_loss.append(loss.item())
            for k, v in loss_partial.items():
                partial_eval_loss[k].append(v.item())
    wandb.log({"epoch": i, "epoch_loss": np.mean(epoch_loss), "eval_loss": np.mean(eval_loss)} | {f"eval_loss_{k}": np.mean(v) for k, v in partial_eval_loss.items()})
    print("Epoch:", i, "loss:", np.mean(epoch_loss), "eval_loss:", np.mean(eval_loss))
    os.makedirs(f"models/{run_name}", exist_ok=True)
    torch.save(model, f"models/{run_name}/{i}.pth")

VBox(children=(Label(value='0.013 MB of 0.013 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011111914697620604, max=1.0…

0
64
128
192
256
320
384
448
512
576
Batch: 576 loss: 2.0714909553527834
640
704
768
832
896
960
1024
1088
1152
1216
Batch: 1216 loss: 2.0424808740615843
1280
1344
1408
1472
1536
1600
1664
1728
1792
1856
Batch: 1856 loss: 2.017916536331177
1920
1984
2048
2112
2176
2240
2304
2368
2432
2496
Batch: 2496 loss: 1.9935804724693298
2560
2624
2688
2752
2816
2880
2944
3008
3072
3136
Batch: 3136 loss: 1.973906433582306
3200
3264
3328
3392
3456
3520
3584
3648
3712
3776
Batch: 3776 loss: 1.9541754961013793
3840
3904
3968
4032
4096
4160
4224
4288
4352
4416
Batch: 4416 loss: 1.932689893245697
4480
4544
4608
4672
4736
4800
4864
4928
4992
5056
Batch: 5056 loss: 1.9012663960456848
5120
5184
5248
5312
5376
5440
5504
5568
5632
5696
Batch: 5696 loss: 1.8807628393173217
5760
5824
5888
Epoch: 0 loss: 1.9717461091020834 eval_loss: 0.24278169125318527
0
64
128
192
256
320
384
448
512
576
Batch: 576 loss: 1.8423680782318115
640
704
768
832
896
960
1024
1088
1152
1216
Batch: 1216 loss: 1.8209438800811768
1280
1