In [1]:
import os
import random
import wandb
import torch
import numpy as np
import matplotlib.pyplot as plt
from beir.beir.datasets.data_loader import GenericDataLoader

from matryoshka import Matryoshka, PairwiseSimilarityLoss, PairwiseSimilarityLossParallel, RegularizingLoss, TopKSimilarityLoss

random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
torch.cuda.manual_seed(42)

  from tqdm.autonotebook import tqdm


In [3]:
!pip install wandb

Collecting wandb
  Using cached wandb-0.19.1-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (10 kB)
Collecting docker-pycreds>=0.4.0 (from wandb)
  Using cached docker_pycreds-0.4.0-py2.py3-none-any.whl.metadata (1.8 kB)
Collecting sentry-sdk>=2.0.0 (from wandb)
  Using cached sentry_sdk-2.19.2-py2.py3-none-any.whl.metadata (9.9 kB)
Collecting setproctitle (from wandb)
  Using cached setproctitle-1.3.4-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (10 kB)
Using cached wandb-0.19.1-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (20.0 MB)
Using cached docker_pycreds-0.4.0-py2.py3-none-any.whl (9.0 kB)
Using cached sentry_sdk-2.19.2-py2.py3-none-any.whl (322 kB)
Using cached setproctitle-1.3.4-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (31 kB)
Installing collected packages: setproctitle, sentry-sdk, docker-pycreds, wandb
Successfully installed dock

In [2]:
data_path = "data/nfcorpus"
corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split="train")

length = None
corpus = {k: v for k, v in list(corpus.items())[:length]}
queries = {k: v for k, v in list(queries.items())[:length]}
qrels = {k: v for k, v in list(qrels.items())[:length]}

  0%|          | 0/3633 [00:00<?, ?it/s]

In [3]:
base_model = Matryoshka(matryoshka_dim=768, adaptor=False)
model = Matryoshka(matryoshka_dim=768, adaptor=True)
tokenizer = model.tokenizer

sentences = ["sentence"]
inputs = tokenizer(sentences, return_tensors="pt", padding=True, truncation=True)

if torch.cuda.is_available():
    model = model.cuda()
    base_model = base_model.cuda()

  warn(
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


In [7]:
cs = [c["text"] for c in corpus.values()]
qs = list(queries.values())

In [8]:
def chunk_text(text, tokenizer, chunk_size=512, overlap=50):
    """
    Splits text into overlapping chunks.
    Args:
        text: input string
        tokenizer: tokenizer instance
        chunk_size: max token size for each chunk
        overlap: overlap between consecutive chunks
    Returns:
        List of tokenized chunks (strings)
    """
    tokens = tokenizer.tokenize(text)  # Tokenize input text
    chunks = []
    for i in range(0, len(tokens), chunk_size - overlap):
        chunk = tokens[i:i + chunk_size]
        chunks.append(tokenizer.convert_tokens_to_string(chunk))
        if len(chunk) < chunk_size:
            break
    return chunks


In [11]:
run_name = "pairwise_reg_topk_skip_layernorm"
epochs = 15
batch_size = 64
running_loss_step = 10
learning_rate = 1e-5

device = "cuda"
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
loss_fn = PairwiseSimilarityLossParallel()
loss_fn_reg = RegularizingLoss()
loss_fn_topk = TopKSimilarityLoss(k=10)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

data = qs + cs
random.shuffle(data)
train_data = data[:int(len(data) * 0.95)]
test_data = data[int(len(data) * 0.95):]


def process_long_texts(text_list, model, tokenizer, train=True):
    final_embeddings = []
    for text in text_list:
        chunks = chunk_text(text, tokenizer, chunk_size=512)
        chunk_embeddings = []

        if train = True:
            for chunk in chunks:
                inputs = tokenizer(chunk, return_tensors="pt", padding=True, truncation=True, max_length=512).to(device)
                embedding = model(pooling=True, **inputs)  # Always calculate gradients during training
                chunk_embeddings.append(embedding.squeeze(0))
            
            # Replace hard mean pooling with a learnable weighted aggregation
            chunk_embeddings = torch.stack(chunk_embeddings)
            final_embedding = torch.mean(chunk_embeddings, dim=0) if not train else chunk_embeddings.mean(dim=0, keepdim=True)
            final_embeddings.append(final_embedding)

        else:
            

    return torch.cat(final_embeddings)  # Proper gradient flow maintained



ls = []
for i in range(epochs):
    epoch_loss = []
    running_loss = []
    partial_running_loss = {64: [], 128: [], 256: [], 384: [], 768:[]}

    model.train()
    random.shuffle(train_data)
    for j in range(0, len(train_data), batch_size):
        print(j)
        if j + batch_size > len(train_data):
            break
        # c = cs[j : j + 32]
        # q = qs[j : j + 32]
        q = train_data[j : j + batch_size]

        outputs = process_long_texts(q, model, tokenizer, train=True)
        target_outputs = process_long_texts(q, base_model, tokenizer, train=True)
        #inputs = tokenizer(q, return_tensors="pt", padding=True, max_length = 512, truncation=True)
        # if torch.cuda.is_available():
        #     for k, v in inputs.items():
        #         inputs[k] = v.cuda()
        # outputs = model(pooling=True, **inputs)
        # target_outputs = base_model(pooling=True, **inputs)

        loss, loss_partial = loss_fn(target_outputs, outputs, [64, 128, 256, 384, 768])
        loss_reg = loss_fn_reg(target_outputs, outputs)
        loss_topk = loss_fn_topk(target_outputs, outputs, [64, 128, 256, 384, 768])

        loss = loss + loss_reg + loss_topk

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step() 
        ls.append(loss.item())
        epoch_loss.append(loss.item())
        running_loss.append(loss.item())
        for k, v in loss_partial.items():
            partial_running_loss[k].append(v.item())
        if len(running_loss) % 10 == 0:
            print("Batch:", j, "loss:", np.mean(running_loss))
            #wandb.log({"batch": j, "loss": np.mean(running_loss)} | {f"loss_{k}": np.mean(v) for k, v in partial_running_loss.items()})
            partial_running_loss = {64: [], 128: [], 256: [], 384: [], 768:[]}
            running_loss = []
 
    model.eval()
    base_model.eval()
    eval_loss = []
    partial_eval_loss = {64: [], 128: [], 256: [], 384: [], 768:[]}
    

    for j in range(0, len(test_data), batch_size):
        if j + batch_size > len(test_data):
            break
        q = test_data[j : j + batch_size]
    
        # Process long texts into embeddings
        with torch.no_grad():
            outputs = process_long_texts(q, model, tokenizer, train=False)
            target_outputs = process_long_texts(q, base_model, tokenizer, train=False)
            
            # Compute the losses
            loss, loss_partial = loss_fn(target_outputs, outputs, [64, 128, 256, 384, 768])
            eval_loss.append(loss.item())
            
            # Collect partial losses
            for k, v in loss_partial.items():
                partial_eval_loss[k].append(v.item())
    
    #wandb.log({"epoch": i, "epoch_loss": np.mean(epoch_loss), "eval_loss": np.mean(eval_loss)} | {f"eval_loss_{k}": np.mean(v) for k, v in partial_eval_loss.items()})
    print("Epoch:", i, "loss:", np.mean(epoch_loss), "eval_loss:", np.mean(eval_loss))
    os.makedirs(f"modelsbert1/{run_name }", exist_ok=True)
    torch.save(model.state_dict(), f"modelsbert1/{run_name}/{i}_state_dict.pth")

0
64
128
192
256
320
384
448
512
576
Batch: 576 loss: 1.4785133957862855
640
704
768
832
896
960
1024
1088
1152
1216
Batch: 1216 loss: 1.4668502449989318
1280
1344
1408
1472
1536
1600
1664
1728
1792
1856
Batch: 1856 loss: 1.4672417044639587
1920
1984
2048
2112
2176
2240
2304
2368
2432
2496
Batch: 2496 loss: 1.4672711610794067
2560
2624
2688
2752
2816
2880
2944
3008
3072
3136
Batch: 3136 loss: 1.4685245037078858
3200
3264
3328
3392
3456
3520
3584
3648
3712
3776
Batch: 3776 loss: 1.47001770734787
3840
3904
3968
4032
4096
4160
4224
4288
4352
4416
Batch: 4416 loss: 1.466200590133667
4480
4544
4608
4672
4736
4800
4864
4928
4992
5056
Batch: 5056 loss: 1.4673577785491942
5120
5184
5248
5312
5376
5440
5504
5568
5632
5696
Batch: 5696 loss: 1.4688364624977113
5760
5824
5888


IndexError: too many indices for tensor of dimension 1

In [9]:
!git config --global --add safe.directory /home/shreya34/matryoshka

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [10]:
!wandb login

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit: 
Aborted!
