In [1]:
from transformers import AutoModel
import torch

# Initialize the model
model = AutoModel.from_pretrained("jinaai/jina-embeddings-v3", trust_remote_code=True).to("cuda")
model.eval()

  def forward(
  def backward(ctx, dout, *args):


XLMRobertaLoRA(
  (roberta): XLMRobertaModel(
    (embeddings): XLMRobertaEmbeddings(
      (word_embeddings): ParametrizedEmbedding(
        250002, 1024, padding_idx=1
        (parametrizations): ModuleDict(
          (weight): ParametrizationList(
            (0): LoRAParametrization()
          )
        )
      )
      (token_type_embeddings): ParametrizedEmbedding(
        1, 1024
        (parametrizations): ModuleDict(
          (weight): ParametrizationList(
            (0): LoRAParametrization()
          )
        )
      )
    )
    (emb_drop): Dropout(p=0.1, inplace=False)
    (emb_ln): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
    (encoder): XLMRobertaEncoder(
      (layers): ModuleList(
        (0-23): 24 x Block(
          (mixer): MHA(
            (rotary_emb): RotaryEmbedding()
            (Wqkv): ParametrizedLinearResidual(
              in_features=1024, out_features=3072, bias=True
              (parametrizations): ModuleDict(
                (weight): 

In [2]:
from datasets import load_dataset

dataset = load_dataset("toughdata/quora-question-answer-dataset", split="train")

Downloading readme:   0%|          | 0.00/485 [00:00<?, ?B/s]

In [3]:
from tqdm import trange
import numpy as np

questions = dataset["question"]
answer = dataset["answer"]


def get_embedding(texts, model, task="text-matching", batch_size=128):
    embeddings = []

    for i in trange(0, len(texts), batch_size):
        batch = texts[i : i + batch_size]
        result = model.encode(batch, task=task, convert_to_tensor=True)
        embeddings.append(result)
    return torch.cat(embeddings)

In [4]:
def mean_reciprocal_rank(similarity_matrix):
    sim_matrix = similarity_matrix.cpu().numpy()
    n = sim_matrix.shape[0]

    reciprocal_ranks = []
    for i in range(n):
        row = sim_matrix[i]
        ranks = (-row).argsort()  # Rank in descending order
        rank_of_diag = np.where(ranks == i)[0][0] + 1  # 1-based rank
        reciprocal_ranks.append(1 / rank_of_diag)

    return np.mean(reciprocal_ranks)

In [5]:
def get_uniform_loss(embeddings):
    distance_matrix = torch.pdist(embeddings, p=2).pow(2)
    exp_kernel = torch.exp(-2 * distance_matrix)
    uniform_loss = torch.log(exp_kernel.mean())
    return uniform_loss.item()


def get_alignment_loss(embeddings_1, embeddings_2):
    return (embeddings_1 - embeddings_2).norm(p=2, dim=1).pow(2).mean().item()

In [6]:
from matplotlib import pyplot as plt

def get_ranking(similarity_matrix):
    rank = torch.argsort(similarity_matrix, dim=1, descending=True)
    # find the index of i in ith row
    row_indices = torch.arange(rank.size(0)).to(rank.device)
    # Compare each element in base_rank with its row index
    # row_indices.unsqueeze(1)
    comparison = rank == row_indices.unsqueeze(1)
    # Find the index where the value is True in each row
    positions = comparison.nonzero()[:, 1]

    return positions


def evaluate(use_lora=False, use_query_lora=False, test_size=1024):
    if use_lora:
        if use_query_lora:
            questions_embedding = get_embedding(
                questions[:test_size], model, task="retrieval.query"
            )
            answer_embedding = get_embedding(
                answer[:test_size], model, task="retrieval.passage"
            )
        else:
            questions_embedding = get_embedding(questions[:test_size], model, task='text-matching')
            answer_embedding = get_embedding(answer[:test_size], model, batch_size=32, task='text-matching')
    else:
        questions_embedding = get_embedding(questions[:test_size], model, task=None)
        answer_embedding = get_embedding(answer[:test_size], model, batch_size=32, task=None)

    similarity_matrix = torch.matmul(questions_embedding, answer_embedding.T)

    print("\tMRR (higher is better):", mean_reciprocal_rank(similarity_matrix))

    uniform_loss_questions = get_uniform_loss(questions_embedding)
    uniform_loss_answer = get_uniform_loss(answer_embedding)
    alignment_loss = get_alignment_loss(questions_embedding, answer_embedding)

    print("\tUniform Loss (Questions):", uniform_loss_questions)
    print("\tUniform Loss (Answer):", uniform_loss_answer)
    print("\tAlignment Loss:", alignment_loss)

    ranking = get_ranking(similarity_matrix)
    top1 = (ranking == 0).float().mean().item()
    top_5 = (ranking < 5).float().mean().item()
    top_10 = (ranking < 10).float().mean().item()
    top_20 = (ranking < 20).float().mean().item()

    print("\tTop 1 Accuracy:", top1)
    print("\tTop 5 Accuracy:", top_5)
    print("\tTop 10 Accuracy:", top_10)
    print("\tTop 20 Accuracy:", top_20)

test_size = 5120
print("Base Model")
evaluate(use_lora=False, test_size=test_size)
print("LoRA Model")
evaluate(use_lora=True, test_size=test_size)
print("LoRA Model with Query")
evaluate(use_lora=True, use_query_lora=True, test_size=test_size)

Base Model


100%|██████████| 40/40 [00:05<00:00,  6.71it/s]
100%|██████████| 160/160 [00:59<00:00,  2.70it/s]


	MRR (higher is better): 0.43431357869125675
	Uniform Loss (Questions): -2.7598531246185303
	Uniform Loss (Answer): -2.7526440620422363
	Alignment Loss: 0.6807968020439148
	Top 1 Accuracy: 0.2777343690395355
	Top 5 Accuracy: 0.637890636920929
	Top 10 Accuracy: 0.7685546875
	Top 20 Accuracy: 0.8353515863418579
LoRA Model


100%|██████████| 40/40 [00:21<00:00,  1.84it/s]
100%|██████████| 160/160 [01:09<00:00,  2.31it/s]


	MRR (higher is better): 0.4183315753429723
	Uniform Loss (Questions): -3.756124973297119
	Uniform Loss (Answer): -3.5556962490081787
	Alignment Loss: 0.9674510955810547
	Top 1 Accuracy: 0.26835939288139343
	Top 5 Accuracy: 0.6103515625
	Top 10 Accuracy: 0.735156238079071
	Top 20 Accuracy: 0.8003906607627869
LoRA Model with Query


100%|██████████| 40/40 [00:21<00:00,  1.84it/s]
100%|██████████| 40/40 [00:50<00:00,  1.25s/it]


	MRR (higher is better): 0.42633959762058343
	Uniform Loss (Questions): -3.6905629634857178
	Uniform Loss (Answer): -3.698300361633301
	Alignment Loss: 0.9925322532653809
	Top 1 Accuracy: 0.275390625
	Top 5 Accuracy: 0.6255859732627869
	Top 10 Accuracy: 0.746874988079071
	Top 20 Accuracy: 0.80859375
