In [1]:
# Setup: before running this notebook, go to home_project/unicoil and install the package in editable mode with `pip install -e .`

In [None]:
import torch
import random
from tqdm import tqdm
from datasets import load_dataset
from tevatron.arguments import DataArguments
from torch.utils.data import DataLoader, Dataset
from transformers import AutoConfig, AutoTokenizer
from tevatron.data import EncodeDataset, EncodeCollator
from tevatron.modeling import EncoderOutput, UniCoilModel


# Configure the encoder: sequence lengths, checkpoint, tokenizer and model.

q_max_len = 16         # max tokens kept for a query
p_max_len = 384        # max tokens kept for a passage
encode_is_qry = False  # True -> encode queries (q_reps); False -> encode passages (p_reps)
text_max_length = q_max_len if encode_is_qry else p_max_len

hf_model_name = "pxyu/UniCOIL-MSMARCO-KL-Distillation-CSV100k"

config = AutoConfig.from_pretrained(
    hf_model_name,
    num_labels=1,
)

tokenizer = AutoTokenizer.from_pretrained(
    hf_model_name,
    use_fast=True,
)

model = UniCoilModel.load(
    model_name_or_path=hf_model_name,
    config=config,
)

# UniCoilModel reads `disabled_token_ids`; presumably the weights of these special
# tokens are suppressed in the output vectors (TODO confirm in UniCoilModel). Build
# the list from the tokenizer's own special-token attributes instead of hard-coded
# BERT strings so the ids always match this vocabulary; tokens a tokenizer does not
# define (None) are skipped.
disabled_token_ids = tokenizer.convert_tokens_to_ids([
    special_token
    for special_token in (
        tokenizer.sep_token,
        tokenizer.cls_token,
        tokenizer.mask_token,
        tokenizer.pad_token,
    )
    if special_token is not None
])
model.disabled_token_ids = disabled_token_ids

In [None]:
# Load the MS MARCO passage corpus; only its "train" split is used.
marco_passage = load_dataset("Tevatron/msmarco-passage-corpus", split="train")

In [None]:
# Collect docid -> passage text for the first 100 distinct docids, then stop early.
id2text = {}
for row in tqdm(marco_passage):
    id2text[row['docid']] = row['text']
    if len(id2text) >= 100:
        break

In [None]:
# Draw a reproducible random subset of the loaded passages to encode.
N_SAMPLED_PASSAGES = 20
SAMPLE_SEED = 42  # fixed seed so Restart & Run All always encodes the same passages
# Clamp the sample size: random.sample raises ValueError if asked for more items than exist.
sampled_ids = random.Random(SAMPLE_SEED).sample(
    list(id2text), min(N_SAMPLED_PASSAGES, len(id2text))
)
sampled_id2text = {k: id2text[k] for k in sampled_ids}

In [None]:
# Pre-tokenize each sampled passage without special tokens and wrap the
# (text_id, token ids) records in an EncodeDataset truncated to text_max_length.
pairs = [
    dict(text_id=text_id, text=tokenizer.encode(passage_text, add_special_tokens=False))
    for text_id, passage_text in sampled_id2text.items()
]
encode_dataset = EncodeDataset(pairs, tokenizer, text_max_length)

In [None]:
# Display the first item of encode_dataset to sanity-check what EncodeDataset produces.
encode_dataset[0]

In [None]:
# Batch the dataset in its original order; the collator pads with padding='max_length'
# up to text_max_length.
encode_loader = DataLoader(
    dataset=encode_dataset,
    collate_fn=EncodeCollator(tokenizer, max_length=text_max_length, padding='max_length'),
    batch_size=4,
    shuffle=False,    # keep batch order aligned with encode_dataset order
    drop_last=False,  # keep the final partial batch
)

In [None]:
def process_output(weights):
    """Convert one dense weight row into a sparse ``{token: weight}`` dict.

    NOTE(review): assumes ``weights`` is a 1-D numpy array with one entry per
    vocabulary id (i.e. p_reps/q_reps are (batch, vocab_size)) — confirm against
    UniCoilModel. Only non-zero entries are kept; ids are mapped back to token
    strings with the notebook's tokenizer.
    """
    token_ids = weights.nonzero()[0]
    tokens = tokenizer.convert_ids_to_tokens(token_ids.tolist())
    return {token: float(weight) for token, weight in zip(tokens, weights[token_ids])}


encoded = []         # one sparse {token: weight} dict per encoded text
lookup_indices = []  # text ids, in the same order as `encoded`
device = "cpu"
model.to(device)  # inputs are moved to `device` below, so the model must live there too
model.eval()

for (batch_ids, batch) in tqdm(encode_loader):
    lookup_indices.extend(batch_ids)
    with torch.no_grad():
        for k, v in batch.items():
            batch[k] = v.to(device)
        if encode_is_qry:
            model_output: EncoderOutput = model(query=batch)
            output = model_output.q_reps.cpu().detach().numpy()
        else:
            model_output: EncoderOutput = model(passage=batch)
            output = model_output.p_reps.cpu().detach().numpy()

    # Fail loudly if the representation is not vocab-sized, since process_output
    # would otherwise map positions to unrelated tokens.
    assert output.ndim == 2 and output.shape[1] == config.vocab_size, (
        f"expected (batch, {config.vocab_size}) representations, got {output.shape}"
    )
    encoded += list(map(process_output, output))


In [None]:
# Display `model.pooler` for inspection only (presumably the head that produces the term weights — TODO confirm).
model.pooler