In [1]:
!pip install -q transformers==3.0.2

[K     |████████████████████████████████| 778kB 6.1MB/s 
[K     |████████████████████████████████| 901kB 9.9MB/s 
[K     |████████████████████████████████| 1.2MB 19.8MB/s 
[K     |████████████████████████████████| 3.0MB 30.1MB/s 
[?25h

In [2]:
!wget https://dl.fbaipublicfiles.com/FiD/pretrained_models/nq_reader_base.tar.gz

--2021-04-21 14:40:35--  https://dl.fbaipublicfiles.com/FiD/pretrained_models/nq_reader_base.tar.gz
Resolving dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)... 104.22.74.142, 104.22.75.142, 172.67.9.4, ...
Connecting to dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)|104.22.74.142|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 825940399 (788M) [application/gzip]
Saving to: ‘nq_reader_base.tar.gz’


2021-04-21 14:41:01 (31.6 MB/s) - ‘nq_reader_base.tar.gz’ saved [825940399/825940399]



In [4]:
!tar -xf nq_reader_base.tar.gz

In [5]:
!git clone https://github.com/facebookresearch/FiD.git

Cloning into 'FiD'...
remote: Enumerating objects: 27, done.[K
remote: Counting objects: 100% (24/24), done.[K
remote: Compressing objects: 100% (20/20), done.[K
remote: Total 27 (delta 4), reused 23 (delta 4), pack-reused 3[K
Unpacking objects: 100% (27/27), done.


In [1]:
cd FiD

/content/FiD


In [2]:
import torch
import transformers
import numpy as np
from pathlib import Path
import torch.distributed as dist
from torch.utils.data import DataLoader, SequentialSampler


import src.slurm
import src.util
from src.options import Options
import src.data
import src.evaluation
import src.model

In [3]:
# Load the pretrained 't5-base' tokenizer.
# `return_dict` is a model-output option rather than a tokenizer argument, so it
# is not passed here.
tokenizer = transformers.T5Tokenizer.from_pretrained('t5-base')

In [4]:
# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Load the FiD reader from the `../nq_reader_base/` directory (relative to the
# current working directory) and place it on the selected device.
model_class = src.model.FiDT5
model = model_class.from_pretrained("../nq_reader_base/").to(device)

In [5]:
# One hand-written example: a question plus the candidate passages to read.
example = dict(
    question='Who is PM of india?',
    passages=[
        "Narendra Modi PM of india",
        "I am PM of india",
        "We are from india",
    ],
)

# The encoding cells operate on a list of examples, so wrap the single example.
batch = [example]

In [9]:
def encode_passages(batch_text_passages, tokenizer, max_length):
    """Tokenize every example's passages and stack the results along a new leading axis.

    Args:
        batch_text_passages: Iterable with one entry per example; each entry is
            the list of passage strings handed to the tokenizer for that example.
        tokenizer: Object exposing ``batch_encode_plus`` that returns a mapping
            with 'input_ids' and 'attention_mask' tensors.
        max_length: Forwarded to the tokenizer as the padding/truncation length.

    Returns:
        A ``(passage_ids, passage_masks)`` tuple. Both are the per-example
        tensors concatenated along a new first dimension; the mask is converted
        to ``bool``. Concatenation requires every example to produce tensors of
        the same shape.
    """
    encodings = [
        tokenizer.batch_encode_plus(
            passages,
            max_length=max_length,
            pad_to_max_length=True,
            return_tensors='pt',
            truncation=True,
        )
        for passages in batch_text_passages
    ]

    # Give each example a leading axis of size 1 so they can be joined on dim 0.
    all_ids = torch.cat([enc['input_ids'].unsqueeze(0) for enc in encodings], dim=0)
    all_masks = torch.cat([enc['attention_mask'].unsqueeze(0) for enc in encodings], dim=0)
    return all_ids, all_masks.bool()

def append_question(example):
    """Build the reader input strings for one example by prefixing each passage with the question.

    Args:
        example: Mapping with a 'question' string and, optionally, a 'passages'
            iterable of passage strings.

    Returns:
        A list with one ``"<question> <passage>"`` string per passage. When
        'passages' is missing, ``None`` or empty, the list holds only the
        question, so the caller always receives at least one text to encode.
    """
    question = example['question']
    passages = example.get('passages')
    if passages is None:
        return [question]

    prefixed = [question + " " + t for t in passages]
    # An empty passage collection would otherwise produce an empty list, leaving
    # nothing to encode; treat it the same as "no passages".
    return prefixed or [question]

In [10]:
# Upper bound on the tokenized length of each "question + passage" text.
text_maxlength = 128

# Prefix every passage with its question, then tokenize the whole batch.
text_passages = list(map(append_question, batch))
passage_ids, passage_masks = encode_passages(text_passages, tokenizer, text_maxlength)

In [17]:
# Generate answer token ids (3 beams, at most 10 tokens).
# The model was moved to `device` after loading, but the encoded tensors are
# never moved in this notebook; put them on the same device so that generation
# also works when CUDA is available.
# NOTE(review): `top_k` normally only matters when sampling is enabled — confirm
# whether it has any effect with this configuration.
outputs = model.generate(
    input_ids=passage_ids.to(device),
    attention_mask=passage_masks.to(device),
    max_length=10,
    early_stopping=True,
    num_beams=3,
    top_k=1,
)

In [18]:
# Turn each generated id sequence back into text, dropping special tokens.
for generated_ids in outputs:
    ans = tokenizer.decode(generated_ids, skip_special_tokens=True)
    print(ans)

Narendra Modi PM of India Narendra
