In [3]:
from transformers import AutoTokenizer, BertModel
import pandas as pd
import torch

In [4]:
# Inference device. CPU is the default; set USE_GPU = True to run on CUDA
# whenever torch reports a usable GPU (falls back to CPU otherwise).
USE_GPU = False
device = "cuda" if USE_GPU and torch.cuda.is_available() else "cpu"

In [5]:
# Load the E5 encoder and its tokenizer from one shared checkpoint name so they cannot drift apart.
# Placement is done explicitly with .to(device); device_map="auto" is deliberately not used because it
# lets accelerate pick the placement itself, which conflicts with the explicit .to(device) call.
E5_CHECKPOINT = "intfloat/e5-large-v2"
e5_model = BertModel.from_pretrained(E5_CHECKPOINT, torch_dtype="auto").to(device).eval()
e5_tokenizer = AutoTokenizer.from_pretrained(E5_CHECKPOINT)

In [6]:
import torch.nn.functional as F
from torch import Tensor


def average_pool(last_hidden_states: Tensor,
                 attention_mask: Tensor) -> Tensor:
    """Mean-pool token vectors over the positions the attention mask keeps.

    Args:
        last_hidden_states: per-token hidden states; presumably (batch, seq, hidden),
            with the mask broadcast over the last axis.
        attention_mask: non-zero for real tokens, zero for padding; presumably (batch, seq).

    Returns:
        The sum over dim 1 of the unmasked token vectors, divided by the number of
        unmasked positions in each row.
    """
    keep = attention_mask.unsqueeze(-1).bool()
    # Zero out padding positions so they contribute nothing to the sum.
    zeroed = last_hidden_states.masked_fill(keep.logical_not(), 0.0)
    token_counts = attention_mask.sum(dim=1).unsqueeze(-1)
    return zeroed.sum(dim=1) / token_counts


# Corpus: the first N_PASSAGES abstracts, each prefixed with "passage: " as E5 expects.
N_PASSAGES = 10
passages_raw = pd.read_csv("../../ml/data/cleaned.csv")[:N_PASSAGES].abstract.tolist()
passages = ["passage: " + passage for passage in passages_raw]

# Evaluation set: each question paired with the id of the abstract that should rank first.
queries_df = pd.read_csv("test_data.csv")
answer_ids = queries_df["relatable_abstract_id"].tolist()
queries = ["query: " + query for query in queries_df["question"].tolist()]

# The ranking cell indexes scores[row][answer_ids[row]], i.e. treats each id as a position
# in `passages`. Out-of-range ids would raise there, and negative ids would silently wrap
# to the wrong passage, so check the range up front.
bad_ids = [i for i in answer_ids if not 0 <= i < len(passages)]
assert not bad_ids, f"relatable_abstract_id outside [0, {len(passages)}): {bad_ids}"

# Each input text should start with "query: " or "passage: ".
# For tasks other than retrieval, you can simply use the "query: " prefix.
input_texts = queries + passages

# The tokenizer's max_length counts tokens, not characters, so report both lengths.
max_chars = max(map(len, input_texts))
max_tokens = max(len(ids) for ids in e5_tokenizer(input_texts)["input_ids"])
print(max_chars, max_tokens)

1592


In [7]:
# Tokenize the input texts
# Tokenize the input texts. Truncate to the encoder's positional limit (512 for e5-large-v2):
# truncation=True only protects the model if max_length does not exceed that limit.
batch_dict = e5_tokenizer(
    input_texts,
    max_length=e5_model.config.max_position_embeddings,
    padding=True,
    truncation=True,
    return_tensors='pt',
).to(e5_model.device)  # follow the model's actual placement

# Pure inference: disable autograd bookkeeping to save memory and time.
with torch.inference_mode():
    outputs = e5_model(**batch_dict)
    embeddings = average_pool(outputs.last_hidden_state, batch_dict['attention_mask'])

    # normalize embeddings so the dot product below is cosine similarity
    embeddings = F.normalize(embeddings, p=2, dim=1)
    # input_texts was built as queries + passages: rows are queries, columns are passages
    scores = (embeddings[:len(queries)] @ embeddings[len(queries):].T).tolist()


In [11]:
# For each query, the rank of its gold passage among all passages:
# 1 + the number of passages scoring strictly higher (ties do not push the gold passage down).
ranks = []
for row, row_scores in enumerate(scores):
    gold_score = row_scores[answer_ids[row]]
    ranks.append(1 + sum(score > gold_score for score in row_scores))

ranks

[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]