In [2]:
!pip install transformers

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import torch

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [6]:
# Load the model and its tokenizer from a single checkpoint name so the two
# can never drift apart when the checkpoint is changed.
MODEL_NAME = "google/flan-t5-base"
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

In [7]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
# (Assign 'cpu' to `device` here instead to force CPU execution.)
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
model = model.to(device)

## Check Flan-T5 on Target Ranking

In [26]:
# Instruction text placed at the very start of every prompt.
lm_instruction = 'Find <extra_id_0>.'

# Hand-written evaluation samples. Each entry holds:
#   id      - number printed next to the result
#   query   - sentence containing the <extra_id_0> placeholder to fill in
#   facts   - supporting sentences placed before the query in the prompt
#   targets - candidate completions; targets[0] is the one the evaluation
#             loop reports as the answer
samples = [
    {
        "id": 1,
        "query": "Scientists are studying the quality of the water in the <extra_id_0>.",
        "facts": [
            "Taking samples of water is used for studying the quality of water.",
            "Scientists go to a lake once a month to take samples of water.",
        ],
        "targets": [
            "<extra_id_0> lake",
            "<extra_id_0> water",
        ],
    },
    {
        "id": 2,
        "query": "Matter in solid phase has definite shape and <extra_id_0>.",
        "facts": [
            "Matter in the solid phase has definite shape.",
            "Matter in the solid phase has definite volume.",
        ],
        "targets": [
            "<extra_id_0> volume",
            "<extra_id_0> shape",
            "<extra_id_0> volume",
        ],
    },
    {
        "id": 3,
        "query": "The <extra_id_0> is opaque.",
        "facts": [
            "If an object is opaque , then light will not shine through that object.",
            "Opacity is a property of an object and includes ordered values of opaque / translucent / transparent.",
            "The light cannot shine through an object.",
        ],
        "targets": [
            "<extra_id_0> object",
            "<extra_id_0> opacity",
        ],
    },
]

In [51]:
def encode_ex(query, retrieved_passages, tokenizer, instruction=''):
    """Build a single prompt from instruction, passages and query, and tokenize it.

    The prompt has the form ``"<instruction>\\n <passages joined by spaces> <query>"``.

    Args:
        query: The query text, either as a plain string or as a sequence whose
            first element is the query string (only that first element is used).
        retrieved_passages: Iterable of passage strings; joined with single spaces.
        tokenizer: Object exposing ``batch_encode_plus`` that returns a mapping
            with ``"input_ids"`` and ``"attention_mask"`` tensors.
        instruction: Text placed at the very start of the prompt.

    Returns:
        Tuple ``(passage_ids, passage_masks, text_passages)``: the token ids as
        returned by the tokenizer for the one-prompt batch, the attention mask
        converted to bool, and the one-element list holding the prompt text.
    """
    # A bare string would otherwise be indexed like a list and contribute only
    # its first character to the prompt.
    query_text = query if isinstance(query, str) else query[0]
    text_passages = ['{}\n {} {}'.format(instruction, " ".join(retrieved_passages), query_text)]

    # NOTE(review): the query sits at the end of the prompt; if the tokenizer
    # truncates from the right (presumably its default - confirm), a prompt
    # longer than 512 tokens would lose the query.
    encoded = tokenizer.batch_encode_plus(
        text_passages,
        max_length=512,
        padding='max_length',
        return_tensors='pt',
        truncation=True
    )
    # The tokenizer output is already a batch of one prompt, so no extra
    # stacking or squeezing is needed.
    passage_ids = encoded['input_ids']
    passage_masks = encoded['attention_mask'].bool()
    return passage_ids, passage_masks, text_passages

def encode_target(targets, tokenizer):
    """Tokenize candidate target strings into label ids for loss computation.

    Args:
        targets: List of target strings, encoded together as one batch.
        tokenizer: Object exposing ``batch_encode_plus`` that returns a mapping
            with ``"input_ids"`` and ``"attention_mask"`` tensors.

    Returns:
        Tuple ``(target_ids, target_mask)``: the token ids with every padded
        position replaced by -100, and the attention mask converted to bool.
    """
    encoding_options = dict(
        max_length=200,
        padding=True,
        truncation=True,
        return_tensors='pt',
    )
    encoded = tokenizer.batch_encode_plus(targets, **encoding_options)

    target_mask = encoded["attention_mask"].bool()
    # Positions outside the attention mask are padding; overwrite them with -100.
    is_padding = target_mask.logical_not()
    target_ids = encoded["input_ids"].masked_fill(is_padding, -100)
    return target_ids, target_mask

In [42]:
def target_ranking(sample, flan, tokenizer):
    """Rank a sample's candidate targets by the loss the model assigns to each.

    A prompt is built from ``lm_instruction``, the sample's facts and its query.
    Each candidate in ``sample["targets"]`` is then scored separately as the
    label sequence for that prompt, and the candidate with the lowest loss is
    the prediction.

    Args:
        sample: Dict with ``"query"`` (str), ``"facts"`` (list of str) and
            ``"targets"`` (non-empty list of str).
        flan: Model called as ``flan(input_ids=..., attention_mask=..., labels=...)``
            whose first output element is the loss for the given labels.
        tokenizer: Tokenizer forwarded to ``encode_ex`` and ``encode_target``.

    Returns:
        Tuple ``(new_q, predicted_alt, target_losses)``: the prompt text list
        returned by ``encode_ex``, the index of the lowest-loss candidate as a
        0-dim tensor, and a 1-D tensor with one loss per candidate.

    Raises:
        ValueError: If the sample has no candidate targets.
    """
    targets = sample["targets"]
    alt_num = len(targets)
    if alt_num == 0:
        raise ValueError("sample {!r} has no targets to rank".format(sample.get("id")))
    target_losses = torch.zeros(alt_num)
    query = sample["query"]

    context_ids, context_mask, new_q = encode_ex([query], sample["facts"], tokenizer, instruction=lm_instruction)
    label_ids, _ = encode_target(targets, tokenizer)

    # The prompt is identical for every candidate, so move tensors to the
    # device once instead of once per loop iteration.
    context_ids = context_ids.to(device)
    context_mask = context_mask.to(device)
    label_ids = label_ids.to(device)

    # Scoring only: disable autograd so no graph is built per forward pass and
    # the returned losses do not keep activations alive.
    with torch.no_grad():
        for alt_i in range(alt_num):
            labels_output = flan(
                input_ids=context_ids,
                attention_mask=context_mask,
                labels=label_ids[alt_i].unsqueeze(0),
            )
            # First output element is the loss when labels are supplied.
            target_losses[alt_i] = labels_output[0]
    predicted_alt = torch.argmin(target_losses)
    return new_q, predicted_alt, target_losses

In [52]:
# Score every sample and print the prediction next to the expected answer.
# targets[0] is reported as the answer, so a "+" marks a prediction of index 0.
for sample in samples:
    new_q, pred, loss = target_ranking(sample, model, tokenizer)
    print(
        "{} {})\tquery: {}\n\ttargets: {}\n\t\tpred:\t{} (loss: {}),\n\t\tanswer:\t{} (loss: {})\n".format(
            "+" if pred == 0 else "-",
            sample["id"],
            new_q,
            sample["targets"],
            sample["targets"][pred],
            loss[pred],
            sample["targets"][0],
            loss[0],
        )
    )

+ 1)	query: ['Find <extra_id_0>.\n Taking samples of water is used for studying the quality of water. Scientists go to a lake once a month to take samples of water. Scientists are studying the quality of the water in the <extra_id_0>.']
	targets: ['<extra_id_0> lake', '<extra_id_0> water']
		pred:	<extra_id_0> lake (loss: 16.30898094177246),
		answer:	<extra_id_0> lake (loss: 16.30898094177246)

- 2)	query: ['Find <extra_id_0>.\n Matter in the solid phase has definite shape. Matter in the solid phase has definite volume. Matter in solid phase has definite shape and <extra_id_0>.']
	targets: ['<extra_id_0> volume', '<extra_id_0> shape', '<extra_id_0> volume']
		pred:	<extra_id_0> shape (loss: 17.38692283630371),
		answer:	<extra_id_0> volume (loss: 18.05997657775879)

- 3)	query: ['Find <extra_id_0>.\n If an object is opaque , then light will not shine through that object. Opacity is a property of an object and includes ordered values of opaque / translucent / transparent. The light can