In [3]:
import re
import torch
import json
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from datasets import load_dataset

In [None]:
# Smoke test: load the checkpoint and push a single review-classification
# prompt through it. The decoded string is printed as-is (no token filtering).
checkpoint = "bigscience/T0_3B"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)

prompt = "Is this review positive or negative? Review: this is the best cast iron skillet you will ever buy"
inputs = tokenizer.encode(prompt, return_tensors="pt")
outputs = model.generate(inputs)
print(tokenizer.decode(outputs[0]))

In [4]:
# Load the tokenizer and seq2seq model from the same checkpoint; the cells
# below use the notebook-level names `tokenizer` and `model`.
checkpoint = "bigscience/T0_3B"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)

In [26]:
# "Context / Question / Answer:" prompt layout; the context never mentions
# the colour the question asks about.
prompt = "Context: This banana is very sweet. Question: What color is the banana? Answer:"
inputs = tokenizer.encode(prompt, return_tensors="pt")
outputs = model.generate(inputs)
print(tokenizer.decode(outputs[0]))

<pad> It is yellow.</s>


In [27]:
# Same context and question as the "Context: ... Question: ..." cell, but
# phrased as a natural-language instruction.
prompt = (
    'Given the context "This banana is very sweet." '
    'and the question "What color is the banana?", write an answer'
)
inputs = tokenizer.encode(prompt, return_tensors="pt")
outputs = model.generate(inputs)
print(tokenizer.decode(outputs[0]))

<pad> yellow</s>


In [25]:
# Conversation-continuation prompt: a context sentence plus a dialogue that
# stops right where Chris is expected to speak.
text = (
    'Given the context "Bananas are very sweet, but they can be the cause of tropical diseases." '
    'and the conversation "Amy: I like bananas. Chris: Why? Amy: Because they are very sweet! '
    "Chris: I don't like bananas. Amy: Why? Chris:\", write a follow up conversation"
)
inputs = tokenizer.encode(text, return_tensors="pt")
outputs = model.generate(inputs)
print(tokenizer.decode(outputs[0]))

<pad> Amy likes bananas because they are very sweet. Chris doesn't like bananas.


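The wording of the prompt matters: with the same context and conversation, changing only the final instruction changes the model's answer. Examples: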

Given the context "Bananas are very sweet, but they can be the cause of tropical diseases." and the conversation "Amy: I like bananas. Chris: Why? Amy: Because they are very sweet! Chris: I don't like bananas. Amy: Why? Chris:", **write Chris' answer**

**Answer : <pad> Amy doesn't like bananas because they can be the cause of tropical diseases.</s>**


Given the context "Bananas are very sweet, but they can be the cause of tropical diseases." and the conversation "Amy: I like bananas. Chris: Why? Amy: Because they are very sweet! Chris: I don't like bananas. Amy: Why? Chris:", **write a follow up conversation**

**Answer : <pad> Amy likes bananas because they are very sweet. Chris doesn't like bananas.**


Given the context "Bananas are very sweet, but they can be the cause of tropical diseases." and the conversation "Amy: I like bananas. Chris: Why? Amy: Because they are very sweet! Chris: I don't like bananas. Amy: Why?", **what did Chris say?**

**Answer : <pad> Chris doesn't like bananas because they can be the cause of tropical diseases.</s>**

In [8]:
# Parse the question-answering dump into `dataset` (presumably the SQuAD v2.0
# training split, judging by the file name -- confirm).
# The context manager closes the file handle deterministically, and the
# explicit UTF-8 encoding keeps the non-ASCII passages readable on platforms
# whose default text encoding is not UTF-8.
with open("train-v2.0.json", encoding="utf-8") as squad_file:
    dataset = json.load(squad_file)

In [104]:
def get_answerable_single_example(index):
    """Return the ``index``-th answerable question together with its context.

    Questions are visited in dataset order (entries of ``dataset["data"]``,
    then their ``"paragraphs"``, then each paragraph's ``"qas"``) and only
    those whose ``"is_impossible"`` flag is falsy are counted, starting at 1.

    Args:
        index: 1-based rank of the answerable question to fetch
            (1 is the first one found, 100 the 100th).

    Returns:
        A ``(question, context)`` tuple, where ``question`` is the qa dict and
        ``context`` is the ``"context"`` value of the paragraph it belongs to.
        ``None`` when no question reaches that rank.
    """
    # Lazily flatten the nested structure, keeping only answerable entries.
    answerable = (
        (qa, para["context"])
        for article in tqdm(dataset["data"])
        for para in article["paragraphs"]
        for qa in para["qas"]
        if not qa["is_impossible"]
    )
    for rank, hit in enumerate(answerable, start=1):
        if rank == index:
            return hit
    return None

def get_unanswerable_single_example(index):
    """Return the ``index``-th unanswerable question together with its context.

    Questions are visited in dataset order (entries of ``dataset["data"]``,
    then their ``"paragraphs"``, then each paragraph's ``"qas"``) and only
    those whose ``"is_impossible"`` flag is truthy are counted, starting at 1.

    Args:
        index: 1-based rank of the unanswerable question to fetch
            (1 is the first one found, 100 the 100th).

    Returns:
        A ``(question, context)`` tuple, where ``question`` is the qa dict and
        ``context`` is the ``"context"`` value of the paragraph it belongs to.
        ``None`` when no question reaches that rank.
    """
    # Lazily flatten the nested structure, keeping only unanswerable entries.
    unanswerable = (
        (qa, para["context"])
        for article in tqdm(dataset["data"])
        for para in article["paragraphs"]
        for qa in para["qas"]
        if qa["is_impossible"]
    )
    for rank, hit in enumerate(unanswerable, start=1):
        if rank == index:
            return hit
    return None

In [106]:
def infer(context, question, skip_special_tokens=False):
    """Ask the model ``question`` about ``context`` and return the decoded answer.

    The prompt sent to the model is
    ``"Context: <context>. Question: <question text>"``. It is printed after
    generation so the exact model input shows up in the cell output.

    Args:
        context: Passage text placed after ``"Context: "``.
        question: Mapping whose ``"question"`` entry holds the question text.
        skip_special_tokens: Forwarded to ``tokenizer.decode``. The default
            (``False``) keeps the previous behaviour, where markers such as
            ``<pad>`` and ``</s>`` are part of the returned string. Pass
            ``True`` to get plain text, e.g. before splicing the answer into
            another prompt.

    Returns:
        The decoded first generated sequence.
    """
    prompt = "Context: " + context + ". Question: " + question["question"]
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    generated = model.generate(input_ids)
    print(prompt)
    return tokenizer.decode(generated[0], skip_special_tokens=skip_special_tokens)

# Run the model on the first question flagged `is_impossible` (dataset order).
# `infer` prints the prompt; its return value is the cell's displayed result.
q, c = get_unanswerable_single_example(1)
infer(c, q)

  1%|          | 4/442 [00:00<00:00, 5211.93it/s]


Context: The Legend of Zelda: Twilight Princess (Japanese: ゼルダの伝説 トワイライトプリンセス, Hepburn: Zeruda no Densetsu: Towairaito Purinsesu?) is an action-adventure game developed and published by Nintendo for the GameCube and Wii home video game consoles. It is the thirteenth installment in the The Legend of Zelda series. Originally planned for release on the GameCube in November 2005, Twilight Princess was delayed by Nintendo to allow its developers to refine the game, add more content, and port it to the Wii. The Wii version was released alongside the console in North America in November 2006, and in Japan, Europe, and Australia the following month. The GameCube version was released worldwide in December 2006.[b]. Question: What category of game is Legend of Zelda: Australia Twilight?


'<pad> action-adventure</s>'

In [107]:
def get_answerable_single_example(index):
    """Return the ``index``-th answerable question together with its context.

    Questions are visited in dataset order (entries of ``dataset["data"]``,
    then their ``"paragraphs"``, then each paragraph's ``"qas"``) and only
    those whose ``"is_impossible"`` flag is falsy are counted, starting at 1.

    Args:
        index: 1-based rank of the answerable question to fetch
            (1 is the first one found, 100 the 100th).

    Returns:
        A ``(question, context)`` tuple, where ``question`` is the qa dict and
        ``context`` is the ``"context"`` value of the paragraph it belongs to.
        ``None`` when no question reaches that rank.
    """
    # Lazily flatten the nested structure, keeping only answerable entries.
    answerable = (
        (qa, para["context"])
        for article in tqdm(dataset["data"])
        for para in article["paragraphs"]
        for qa in para["qas"]
        if not qa["is_impossible"]
    )
    for rank, hit in enumerate(answerable, start=1):
        if rank == index:
            return hit
    return None

def get_unanswerable_single_example(index):
    """Return the ``index``-th unanswerable question together with its context.

    Questions are visited in dataset order (entries of ``dataset["data"]``,
    then their ``"paragraphs"``, then each paragraph's ``"qas"``) and only
    those whose ``"is_impossible"`` flag is truthy are counted, starting at 1.

    Args:
        index: 1-based rank of the unanswerable question to fetch
            (1 is the first one found, 100 the 100th).

    Returns:
        A ``(question, context)`` tuple, where ``question`` is the qa dict and
        ``context`` is the ``"context"`` value of the paragraph it belongs to.
        ``None`` when no question reaches that rank.
    """
    # Lazily flatten the nested structure, keeping only unanswerable entries.
    unanswerable = (
        (qa, para["context"])
        for article in tqdm(dataset["data"])
        for para in article["paragraphs"]
        for qa in para["qas"]
        if qa["is_impossible"]
    )
    for rank, hit in enumerate(unanswerable, start=1):
        if rank == index:
            return hit
    return None

# Build a prompt with two worked question/answer pairs under the first
# context, followed by a new context and question for the model to answer.
q, c = get_unanswerable_single_example(1)
# NOTE(review): c1 is never used -- q1 is presented under context `c`.
# Confirm that q1 really belongs to the same paragraph as q.
q1, c1 = get_answerable_single_example(2070)
q2, c2 = get_unanswerable_single_example(6000)

# infer() returns the raw decode, which includes the tokenizer's pad/eos
# markers (e.g. '<pad> action-adventure</s>'). Left in place, those markers
# land in the middle of the new prompt, so strip them before reuse.
first_answer = infer(c, q)
for marker in (tokenizer.pad_token, tokenizer.eos_token):
    if marker:
        first_answer = first_answer.replace(marker, "")
first_answer = first_answer.strip()

text = (
    "Context: " + c
    + "\nQuestion: " + q["question"] + " Answer: " + first_answer
    + "\nQuestion: " + q1["question"] + " Answer: " + q1["answers"][0]["text"]
    + "\n\n Context: " + c2 + " Question: " + q2["question"]
)
inputs = tokenizer.encode(text, return_tensors="pt")
outputs = model.generate(inputs)
print(text)
tokenizer.decode(outputs[0])

  1%|          | 4/442 [00:00<00:00, 5300.86it/s]
  1%|          | 4/442 [00:00<00:00, 4908.49it/s]
 17%|█▋        | 73/442 [00:00<00:00, 12353.61it/s]


Context: The Legend of Zelda: Twilight Princess (Japanese: ゼルダの伝説 トワイライトプリンセス, Hepburn: Zeruda no Densetsu: Towairaito Purinsesu?) is an action-adventure game developed and published by Nintendo for the GameCube and Wii home video game consoles. It is the thirteenth installment in the The Legend of Zelda series. Originally planned for release on the GameCube in November 2005, Twilight Princess was delayed by Nintendo to allow its developers to refine the game, add more content, and port it to the Wii. The Wii version was released alongside the console in North America in November 2006, and in Japan, Europe, and Australia the following month. The GameCube version was released worldwide in December 2006.[b]. Question: What category of game is Legend of Zelda: Australia Twilight?
Context: The Legend of Zelda: Twilight Princess (Japanese: ゼルダの伝説 トワイライトプリンセス, Hepburn: Zeruda no Densetsu: Towairaito Purinsesu?) is an action-adventure game developed and published by Nintendo for the GameCube 

'<pad> 5%</s>'