In [1]:
# Prerequisite: navigate to home_project/unicoil and install the package in editable mode with "pip install -e ."

In [2]:
import torch
import random
from tqdm import tqdm
from datasets import load_dataset
from tevatron.arguments import DataArguments
from torch.utils.data import DataLoader, Dataset
from transformers import AutoConfig, AutoTokenizer
from tevatron.data import EncodeDataset, EncodeCollator
from tevatron.modeling import EncoderOutput, UniCoilModel


# Model / encoding configuration.
#
# `encode_is_qry` selects whether texts are encoded as queries or passages;
# the matching maximum length becomes `text_max_length`, which later cells use
# when building the dataset and the collator.
q_max_len = 16
p_max_len = 384
encode_is_qry = False
if encode_is_qry:
    text_max_length = q_max_len
else:
    text_max_length = p_max_len

# Checkpoint identifier shared by the tokenizer, the config and the model.
hf_model_name = "pxyu/UniCOIL-MSMARCO-KL-Distillation-CSV100k"

tokenizer = AutoTokenizer.from_pretrained(hf_model_name, use_fast=True)
config = AutoConfig.from_pretrained(hf_model_name, num_labels=1)
model = UniCoilModel.load(model_name_or_path=hf_model_name, config=config)

# Register the ids of the special tokens on the model. The original author
# flagged this step as important.
# NOTE(review): presumably UniCoilModel uses `disabled_token_ids` to exclude
# these tokens from the produced representation -- confirm in the model code.
model.disabled_token_ids = disabled_token_ids = tokenizer.convert_tokens_to_ids(
    ["[SEP]", "[CLS]", "[MASK]", "[PAD]"]
)

In [3]:
# Load the "train" split of the MS MARCO passage corpus dataset.
marco_passage = load_dataset("Tevatron/msmarco-passage-corpus", split="train")

In [4]:
# Collect the text of the first 100 distinct docids from the corpus, then stop
# iterating (the full corpus is far larger, as the progress bar total shows).
id2text = {}

for record in tqdm(marco_passage):
    id2text[record['docid']] = record['text']
    if len(id2text) == 100:
        break

  0%|          | 99/8841823 [00:00<11:46, 12507.86it/s]


In [5]:
# Draw a reproducible random subset of the collected passages to encode.
# A dedicated, seeded RNG is used so that "Restart & Run All" selects the same
# passages every time, without touching the global `random` state.
SAMPLE_SEED = 42
SAMPLE_SIZE = 20

_sample_rng = random.Random(SAMPLE_SEED)
# min() avoids a ValueError when fewer than SAMPLE_SIZE passages were collected.
_sampled_ids = _sample_rng.sample(list(id2text), min(SAMPLE_SIZE, len(id2text)))
sampled_id2text = {doc_id: id2text[doc_id] for doc_id in _sampled_ids}

In [6]:
# Pre-tokenize every sampled passage (without special tokens) into records of
# the form {"text_id": ..., "text": [token ids]} and wrap them in an
# EncodeDataset limited to `text_max_length`.
pairs = [
    {
        "text_id": doc_id,
        "text": tokenizer.encode(passage, add_special_tokens=False),
    }
    for doc_id, passage in sampled_id2text.items()
]
encode_dataset = EncodeDataset(pairs, tokenizer, text_max_length)

In [7]:
# Sanity check: display the first dataset item. As the output below shows, it
# is a (text_id, features) tuple whose features hold `input_ids` and
# `attention_mask`.
encode_dataset[0]

You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


('41',
 {'input_ids': [2, 1180, 10408, 10822, 14234, 1627, 30, 275, 79589, 117, 16, 653, 5285, 202, 193, 1501, 263, 26069, 17, 1958, 2662, 111, 211, 683, 838, 1298, 16, 377, 318, 330, 1455, 1918, 8206, 249, 210, 16, 176, 1180, 9863, 16, 189, 2724, 16, 5913, 43, 2335, 2156, 190, 14996, 18, 3], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]})

In [8]:
# Collator that pads each batch to the fixed `text_max_length`.
encode_collator = EncodeCollator(
    tokenizer,
    max_length=text_max_length,
    padding='max_length',
)

# Iterate the dataset in its original order, four texts at a time, keeping the
# final (possibly smaller) batch.
encode_loader = DataLoader(
    encode_dataset,
    batch_size=4,
    shuffle=False,
    drop_last=False,
    collate_fn=encode_collator,
)

In [9]:
# Encode every batch with the model and collect one representation per text.
# After the loop, `lookup_indices[i]` is the text id that `encoded[i]` belongs to.
encoded = []
lookup_indices = []
device = "cpu"

# A previous run of this cell failed inside the forward pass with
# "TypeError: 'NoneType' object is not callable".
# NOTE(review): a missing pooler is the suspected cause -- confirm against
# UniCoilModel. Fail fast with an actionable message instead of the opaque
# TypeError.
if hasattr(model, "pooler") and model.pooler is None:
    raise RuntimeError(
        "model.pooler is None; the encoder forward pass would fail with "
        "\"'NoneType' object is not callable\". Make sure the checkpoint's "
        "pooler weights are loaded before encoding."
    )

# `process_output` is not defined anywhere in this notebook. Use it when it has
# been defined in the session; otherwise keep each representation row as-is.
postprocess_row = globals().get("process_output", lambda row: row)

# Keep the model on the same device as the batches and disable dropout etc.
model.to(device)
model.eval()

with torch.no_grad():
    for (batch_ids, batch) in tqdm(encode_loader):
        lookup_indices.extend(batch_ids)

        # Move the tensors in place so the batch keeps its original container type.
        for k, v in batch.items():
            batch[k] = v.to(device)

        if encode_is_qry:
            model_output: EncoderOutput = model(query=batch)
            reps = model_output.q_reps
        else:
            model_output: EncoderOutput = model(passage=batch)
            reps = model_output.p_reps

        # Iterating the array yields one row per text in the batch.
        output = reps.cpu().detach().numpy()
        encoded.extend(postprocess_row(row) for row in output)


  0%|          | 0/5 [00:04<?, ?it/s]


TypeError: 'NoneType' object is not callable

In [None]:
# Debugging aid: display the model's `pooler` attribute.
# NOTE(review): presumably added to check whether the pooler is None after the
# TypeError raised by the encoding cell -- confirm, and remove once resolved.
model.pooler