In [1]:
import os
import random
import wandb
import torch
import numpy as np
import matplotlib.pyplot as plt
from beir.datasets.data_loader import GenericDataLoader

from matryoshka import Matryoshka

random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
torch.cuda.manual_seed(42)

  from tqdm.autonotebook import tqdm


In [2]:
data_path = "data/nfcorpus"
corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split="train")

length = None
corpus = {k: v for k, v in list(corpus.items())[:length]}
queries = {k: v for k, v in list(queries.items())[:length]}
qrels = {k: v for k, v in list(qrels.items())[:length]}

  0%|          | 0/3633 [00:00<?, ?it/s]

In [3]:
base_model = Matryoshka(matryoshka_dim=384, adaptor=False)
model = Matryoshka(matryoshka_dim=384, adaptor=True)
tokenizer = model.tokenizer

sentences = ["sentence"]
inputs = tokenizer(sentences, return_tensors="pt", padding=True, truncation=True)



In [4]:
# sanity check on shape
model(pooling=False, **inputs).shape

torch.Size([1, 3, 384])

In [5]:
# sanity check on learble parameters
for i in model.named_parameters():
    if i[1][1].requires_grad:
        print(i[0])

adaptor.down_project.weight
adaptor.down_project.bias
adaptor.ffn.0.weight
adaptor.ffn.0.bias
adaptor.ffn.2.weight
adaptor.ffn.2.bias
adaptor.up_project.weight
adaptor.up_project.bias


##### Trying to reproduce original embeddings with adaptor on

In [6]:
cs = [c["text"] for c in corpus.values()]
qs = list(queries.values())

In [7]:
if torch.cuda.is_available():
    model = model.cuda()
    base_model = base_model.cuda()

epochs = 100
batch_size = 64
running_loss_step = 10
learning_rate = 1e-5

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
similarity_loss = torch.nn.CosineEmbeddingLoss()

data = qs + cs
random.shuffle(data)
train_data = data[:int(len(data) * 0.8)]
test_data = data[int(len(data) * 0.8):]

wandb.init(
    project="matryoshka",  
    name="identity_MSE_pooling_adapter",        
    config={                      
        "learning_rate": learning_rate,
        "epochs": epochs,
        "batch_size": batch_size,
        "loss": "MSE",
        "model": model.name,
        "model_card": model.model_card_data,
        "loss_resolution": running_loss_step,
        "architecture": model.__str__(),
    }
)

ls = []
for i in range(25):
    epoch_loss = []
    running_loss = []

    model.train()
    random.shuffle(train_data)
    for j in range(0, len(train_data), batch_size):
        print(j)
        if j + batch_size > len(train_data):
            break
        # c = cs[j : j + 32]
        # q = qs[j : j + 32]
        q = train_data[j : j + batch_size]

        inputs = tokenizer(q, return_tensors="pt", padding=True, truncation=True)
        if torch.cuda.is_available():
            for k, v in inputs.items():
                inputs[k] = v.cuda()
        outputs = model(pooling=True, skip=False, **inputs)
        target_outputs = base_model(pooling=True, **inputs)
        loss = torch.nn.MSELoss()(outputs, target_outputs)
        # similarity_loss = similarity_loss(outputs, target_outputs, torch.ones(len(q)))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        ls.append(loss.item())
        epoch_loss.append(loss.item())
        running_loss.append(loss.item())
        if len(running_loss) % 10 == 0:
            print("Batch:", j, "loss:", np.mean(running_loss))
            wandb.log({"batch": j, "loss": np.mean(running_loss)})
            running_loss = []
 
    model.eval()
    eval_loss = []
    for j in range(0, len(test_data), batch_size):
        if j + batch_size > len(test_data):
            break
        q = test_data[j : j + batch_size]

        inputs = tokenizer(q, return_tensors="pt", padding=True, truncation=True)
        if torch.cuda.is_available():
            for k, v in inputs.items():
                inputs[k] = v.cuda()
        with torch.no_grad():
            outputs = model(pooling=True, skip=False, **inputs)
            target_outputs = base_model(pooling=True, **inputs)
            loss = torch.nn.MSELoss()(outputs, target_outputs)
            eval_loss.append(loss.item())

    wandb.log({"epoch": i, "epoch_loss": np.mean(epoch_loss), "eval_loss": np.mean(eval_loss)})
    print("Epoch:", i, "loss:", np.mean(epoch_loss), "eval_loss:", np.mean(eval_loss))

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mlaz4rz[0m. Use [1m`wandb login --relogin`[0m to force relogin


0
64
128
192
256
320
384
448
512
576
Batch: 576 loss: 0.005053099431097508
640
704
768
832
896
960
1024
1088
1152
1216
Batch: 1216 loss: 0.005006635328754783
1280
1344
1408
1472
1536
1600
1664
1728
1792
1856
Batch: 1856 loss: 0.004983232775703073
1920
1984
2048
2112
2176
2240
2304
2368
2432
2496
Batch: 2496 loss: 0.004969768784940242
2560
2624
2688
2752
2816
2880
2944
3008
3072
3136
Batch: 3136 loss: 0.004957592720165849
3200
3264
3328
3392
3456
3520
3584
3648
3712
3776
Batch: 3776 loss: 0.004912418080493808
3840
3904
3968
4032
4096
4160
4224
4288
4352
4416
Batch: 4416 loss: 0.004866496985778212
4480
4544
4608
4672
4736
4800
4864
4928
Epoch: 0 loss: 0.004956041892095433 eval_loss: 0.004850901491743953
0
64
128
192
256
320
384
448
512
576
Batch: 576 loss: 0.0048340227454900745
640
704
768
832
896
960
1024
1088
1152
1216
Batch: 1216 loss: 0.004804914770647884
1280
1344
1408
1472
1536
1600
1664
1728
1792
1856
Batch: 1856 loss: 0.004781795339658857
1920
1984
2048
2112
2176
2240
2304
2368
2

In [1]:
s = 0
for i in range(11):
    for j in range(i, 11):
        s += 1
print(s)

66
