In [7]:
import os

In [1]:
from beir import util, LoggingHandler
from beir.datasets.data_loader import GenericDataLoader

  from tqdm.autonotebook import tqdm


In [2]:
dataset = "scifact"
data_path = f"../datasets/{dataset}"

# Fetch the BEIR dataset on a fresh machine so Restart & Run All works; this is
# skipped when the directory already exists. download_and_unzip extracts into
# the parent directory, producing `data_path`.
if not os.path.isdir(data_path):
    url = f"https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{dataset}.zip"
    util.download_and_unzip(url, os.path.dirname(data_path))

# Loading test set
corpus, queries, qrels = GenericDataLoader(data_path).load(split="test")

100%|██████████| 5183/5183 [00:00<00:00, 17430.22it/s]


In [3]:
# Relevance judgments for query id "179" (doc id -> relevance label), shown as the cell output.
qrels["179"]

{'16322674': 1, '27123743': 1, '23557241': 1, '17450673': 1}

In [4]:
import torch
from transformers import AutoModel, AutoTokenizer
from peft import PeftModel, PeftConfig

In [9]:

def get_model(peft_model_name, **base_model_kwargs):
    """Load a PEFT adapter on top of its base model, merge it, and return it ready for inference.

    Args:
        peft_model_name: Hub id or local path of the PEFT adapter. Its PeftConfig
            supplies the name of the base model to load.
        **base_model_kwargs: Extra keyword arguments forwarded to
            ``AutoModel.from_pretrained`` for the base model, e.g.
            ``torch_dtype=torch.float16`` to reduce memory for a 7B model.
            Passing none reproduces the previous behaviour exactly.

    Returns:
        The model returned by ``merge_and_unload()``, switched to eval mode.
    """
    config = PeftConfig.from_pretrained(peft_model_name)
    base_model = AutoModel.from_pretrained(config.base_model_name_or_path, **base_model_kwargs)
    # merge_and_unload folds the adapter weights into the base model and returns it
    # without the PEFT wrapper.
    model = PeftModel.from_pretrained(base_model, peft_model_name).merge_and_unload()
    # eval() disables training-only behaviour such as dropout.
    model.eval()
    return model

# Tokenizer and LoRA-merged encoder. NOTE(review): the tokenizer is presumably the
# adapter's base model's tokenizer — confirm against PeftConfig.base_model_name_or_path.
tokenizer = AutoTokenizer.from_pretrained('meta-llama/Llama-2-7b-hf')
model = get_model('castorini/repllama-v1-7b-lora-passage')

# A toy query/passage pair for a quick sanity check of the similarity score.
query, title = "What is llama?", "Llama"
passage = "The llama is a domesticated South American camelid, widely used as a meat and pack animal by Andean cultures since the pre-Columbian era."

# Prompt templates "query: <q></s>" and "passage: <title> <text></s>". The appended
# "</s>" is meant to be the final token; the embedding code reads the hidden state
# at that last position (last_hidden_state[0][-1]).
query_input, passage_input = (
    tokenizer(f'query: {query}</s>', return_tensors='pt'),
    tokenizer(f'passage: {title} {passage}</s>', return_tensors='pt'),
)


Downloading shards: 100%|██████████| 2/2 [04:44<00:00, 142.35s/it]
Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

: 

In [None]:
def get_embed(model, input):
    """Embed one tokenized sequence with last-token pooling.

    Runs `model` on `input` (unpacked as keyword arguments) without tracking
    gradients. It takes the hidden state at the final position of the first
    sequence in the batch and rescales it to unit L2 norm.

    Args:
        model: callable whose output exposes `last_hidden_state`; presumably a
            (batch, seq, hidden) tensor — TODO confirm for other models.
        input: mapping of model keyword arguments, e.g. the result of
            `tokenizer(text, return_tensors='pt')`.

    Returns:
        The L2-normalized hidden state at position -1 of sequence 0.
    """
    with torch.no_grad():
        hidden_states = model(**input).last_hidden_state
        # NOTE(review): if sequence 0 were right-padded, this position could be padding.
        pooled = hidden_states[0, -1]
        unit_vector = torch.nn.functional.normalize(pooled, p=2, dim=0)
    return unit_vector

In [None]:
# `import tqdm` binds the module, which is not callable; import the progress-bar function.
from tqdm.auto import tqdm


def _embed_text(text):
    """Tokenize `text` with the notebook's `tokenizer` and embed it via `get_embed(model, ...)`."""
    return get_embed(model, tokenizer(text, return_tensors='pt'))


# Iterate (id, value) pairs with .items(): iterating a dict directly yields keys only.
# Texts use the same "query: ...</s>" / "passage: <title> <text></s>" templates as the demo pair.
# NOTE(review): assumes BEIR format, i.e. queries map qid -> query string.
query_embeddings = {
    qid: _embed_text(f'query: {query_text}</s>')
    for qid, query_text in tqdm(queries.items(), desc="Encoding queries")
}

# NOTE(review): assumes BEIR corpus entries are dicts with "text" and an optional "title".
doc_embeddings = {
    doc_id: _embed_text(f'passage: {doc.get("title", "")} {doc["text"]}</s>')
    for doc_id, doc in tqdm(corpus.items(), desc="Encoding passages")
}


In [None]:

# Score the demo query/passage pair with the shared `get_embed` helper rather than
# re-implementing its last-token pooling and normalization inline.
# Both vectors are unit-normalized, so their dot product is the cosine similarity.
query_embedding = get_embed(model, query_input)
passage_embeddings = get_embed(model, passage_input)  # a single vector; name kept for compatibility
score = torch.dot(query_embedding, passage_embeddings)
score