In [1]:
import os
import random
import wandb
import torch
import numpy as np
import matplotlib.pyplot as plt
from beir.datasets.data_loader import GenericDataLoader

from matryoshka import Matryoshka, PairwiseSimilarityLoss, PairwiseSimilarityLossParallel, RegularizingLoss

# Single source of truth for every RNG seed in this notebook; change it here
# rather than editing each call below.
SEED = 42

random.seed(SEED)        # Python's built-in RNG
np.random.seed(SEED)     # NumPy's legacy global RNG
torch.manual_seed(SEED)  # PyTorch default generator
# Seed every CUDA device rather than only the current one, so multi-GPU runs
# are covered too. This call is a silent no-op when CUDA is unavailable.
torch.cuda.manual_seed_all(SEED)

  from tqdm.autonotebook import tqdm


In [2]:
# Load the "train" split of the dataset stored under data/nfcorpus.
data_path = "data/nfcorpus"
corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split="train")

# Optional cap on how many entries of each mapping to keep (handy for quick
# experiments). `None` makes `[:length]` a full slice, i.e. nothing is dropped.
# Each mapping is truncated independently, in its own iteration order.
length = None
corpus = dict(list(corpus.items())[:length])
queries = dict(list(queries.items())[:length])
qrels = dict(list(qrels.items())[:length])

  0%|          | 0/3633 [00:00<?, ?it/s]

In [3]:
# Two encoders with the same 384-dim setting: `base_model` is built with the
# adaptor disabled, `model` with it enabled. The tokenizer is taken from `model`.
base_model = Matryoshka(matryoshka_dim=384, adaptor=False)
model = Matryoshka(matryoshka_dim=384, adaptor=True)
tokenizer = model.tokenizer

# Tokenize a single dummy sentence into padded/truncated PyTorch tensors.
sentences = ["sentence"]
inputs = tokenizer(sentences, return_tensors="pt", padding=True, truncation=True)

# Run on the GPU whenever one is available.
if torch.cuda.is_available():
    model, base_model = model.cuda(), base_model.cuda()

# Plain lists of the raw inputs: the "text" field of every corpus entry and
# every query value.
cs = [doc["text"] for doc in corpus.values()]
qs = [*queries.values()]



In [4]:
# Tokenize the first ten queries as one padded/truncated batch of tensors.
inputs = tokenizer(qs[:10], return_tensors="pt", padding=True, truncation=True)

# Move every tensor in the batch to the GPU when one is available.
if torch.cuda.is_available():
    for name, tensor in list(inputs.items()):
        inputs[name] = tensor.cuda()

# Forward pass with pooling enabled.
outputs = model(pooling=True, **inputs)

# Display the shape of the result (the recorded run shows [10, 384]).
outputs.shape

torch.Size([10, 384])

In [13]:
# Pairwise dot products between all rows of `outputs`, keeping only the strict
# upper triangle (diagonal=1): the diagonal and everything below it become 0,
# so each unordered pair (i, j) with i < j appears exactly once.
# NOTE(review): these are cosine similarities only if `outputs` rows are
# L2-normalized — not visible from this cell, confirm in the model.
sims = (outputs @ outputs.T).triu(diagonal=1)
sims

tensor([[0.0000, 0.4798, 0.4460, 0.5219, 0.5122, 0.4382, 0.4393, 0.3483, 0.4456,
         0.4915],
        [0.0000, 0.0000, 0.6895, 0.7388, 0.7228, 0.4461, 0.4331, 0.3755, 0.4809,
         0.5321],
        [0.0000, 0.0000, 0.0000, 0.7614, 0.7354, 0.4166, 0.3989, 0.4310, 0.3575,
         0.5292],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.9799, 0.4803, 0.4642, 0.3819, 0.3633,
         0.4787],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.4950, 0.4697, 0.3932, 0.3673,
         0.4756],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.5225, 0.4936, 0.4942,
         0.5034],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.5088, 0.4200,
         0.4417],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.4505,
         0.3823],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.6284],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000]], device='cuda:0', grad_fn=<TriuBackward0>)

In [20]:
# For every row of `sims`, take the two largest entries: `val` holds the values
# and `idx` the column indices they came from.
val, idx = torch.topk(sims, k=2, dim=1)

In [21]:
# Boolean mask shaped like `sims`: start all-False, then flip to True exactly
# the positions named by `idx` along dim 1 (`scatter_` works in place and
# returns the same tensor, so the two steps can be chained).
mask = torch.zeros_like(sims, dtype=torch.bool).scatter_(1, idx, True)

# Multiplying by the mask keeps the selected entries and zeroes every other one.
masked_sims = sims * mask
masked_sims

tensor([[0.0000, 0.0000, 0.0000, 0.5219, 0.5122, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.7388, 0.7228, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.7614, 0.7354, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.9799, 0.4803, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.4950, 0.0000, 0.0000, 0.0000,
         0.4756],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.5225, 0.0000, 0.0000,
         0.5034],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.5088, 0.0000,
         0.4417],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.4505,
         0.3823],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.6284],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000]], device='cuda:0', grad_fn=<MulBackward0>)

In [22]:
# Divide every flattened top-k value by itself and add the results up: each
# non-zero entry contributes 1, but any zero entry gives 0/0 = NaN, and a
# single NaN makes the whole sum NaN.
(val.reshape(-1) / val.reshape(-1)).sum()


tensor(nan, device='cuda:0', grad_fn=<SumBackward0>)

In [25]:
# NaN-free indicator over the flattened top-k values: positions where the value
# equals 0 keep that value (0), every other position is replaced by 1.
torch.where(val.reshape(-1) == 0, val.reshape(-1), 1)

tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0.,
        0., 0.], device='cuda:0', grad_fn=<WhereBackward0>)