## T5 Question Answering Inference (PyTorch)

[T5](https://arxiv.org/abs/1910.10683) is a recent approach to sequence-to-sequence modeling in which both the input and the output are text, also called a *text-to-text* approach. I've been deeply interested in this model since the moment I read about it.

I believe that the combination of *text-to-text* as a universal interface for NLP tasks, paired with multi-task learning (a single model learning multiple tasks), will have a huge impact on how deep learning is applied in practice. This competition is my first attempt at using T5 on a real-world dataset, so I hope it helps you guys use it for your own purposes!

In [None]:
from transformers import T5Tokenizer, T5ForConditionalGeneration
from tqdm import tqdm
import pandas as pd
from torch.utils.data import DataLoader, SequentialSampler, TensorDataset
import torch
import pandas as pd

In [None]:
def encode_file(tokenizer, data_path, max_length, pad_to_max_length=True, return_tensors="pt"):
    """
    Tokenize a text file, one example per line.

    Each line of ``data_path`` (without its line terminator) is encoded on its
    own with ``tokenizer.batch_encode_plus``.

    Args:
        tokenizer: object exposing ``batch_encode_plus`` (e.g. ``T5Tokenizer``).
        data_path: path to a UTF-8 text file holding one example per line.
        max_length: forwarded to ``batch_encode_plus``.
        pad_to_max_length: forwarded to ``batch_encode_plus``.
        return_tensors: forwarded to ``batch_encode_plus`` ("pt" -> torch tensors).

    Returns:
        list of the ``input_ids`` produced for each line, in file order. Each
        entry keeps a leading batch dimension of 1 (1 x SL for "pt").
    """
    examples = []
    with open(data_path, "r", encoding="utf-8") as f:
        for line in f:
            # Every line except the last ends with "\n". Strip it so all
            # examples reach the tokenizer in the same form.
            text = line.rstrip("\n")
            tokenized = tokenizer.batch_encode_plus(
                [text], max_length=max_length, pad_to_max_length=pad_to_max_length, return_tensors=return_tensors,
            )
            # Keep dimension 0 as a singleton since `model.generate` requires dimensionality of (BS x SL)
            examples.append(tokenized["input_ids"])  # 1 x SL
    return examples

def get_span_from_ids(input_ids, t5):
    """
    Predict the selected-text span for a single encoded example.

    The module-level ``tokenizer`` decodes ``input_ids`` back into the prompt
    string ("question: <sentiment> context: <tweet>").

    Args:
        input_ids: encoded ids of one example, with a leading batch dimension
            (1 x SL, as produced by ``encode_file``).
        t5: model exposing ``generate`` (e.g. ``T5ForConditionalGeneration``).

    Returns:
        str: the whole context for neutral examples. Otherwise, the generated
        words that also occur in the context (may be "").
    """
    whole_input_str = tokenizer.decode(input_ids.squeeze())
    input_str = whole_input_str.split('context: ')[-1]

    # Neutral examples: return the whole context without running the model.
    # The prompts built in this notebook read "question: neutral context: ...".
    # The original "Which section was neutral?" check is kept so inputs using
    # that phrasing are still recognised.
    if "question: neutral" in whole_input_str or "Which section was neutral?" in whole_input_str:
        return input_str.strip()

    # Predict
    generated_ids = t5.generate(
        input_ids=input_ids,
        num_beams=1,
        max_length=80,
        repetition_penalty=2.5,
    ).squeeze()
    predicted_span = tokenizer.decode(generated_ids)

    # Keep only predicted words that appear in the context. Building the set
    # once makes each membership test O(1).
    context_words = set(input_str.split())
    return " ".join(word for word in predicted_span.split() if word in context_words)

def process_span(pred_span, input_ids):
    """
    Turn one raw generated span into the final ``selected_text`` value.

    The module-level ``tokenizer`` decodes ``input_ids`` to recover the
    context (everything after "context:").

    Args:
        pred_span: decoded model output for this example.
        input_ids: encoded ids of the same example (1 x SL).

    Returns:
        str: for neutral examples, the whole context. Otherwise, the predicted
        words that also occur in the context, falling back to the whole context
        when none survive. In every case " ⁇ " is replaced with a backtick and
        all double quotes are removed.
    """
    whole_input_str = tokenizer.decode(input_ids.squeeze())
    input_str = whole_input_str.split('context:')[-1].strip()

    if "question: neutral" in whole_input_str:
        final_span = input_str
    else:
        # Build the set once so each membership test is O(1).
        context_words = set(input_str.split())
        predicted_span_filtered = " ".join(s for s in pred_span.split() if s in context_words)
        # Simple heuristic given that blank answers are typically for short contexts
        final_span = predicted_span_filtered or input_str

    # NOTE(review): " ⁇ " is presumably how the tokenizer renders an unknown
    # token (assumed here to be a backtick). Confirm against the vocabulary.
    return final_span.replace(' ⁇ ', '`').replace('"', '')

def get_span_from_ids_batch(input_ids, t5):
    """
    Run beam-search generation on a batch and decode each output row.

    Args:
        input_ids: batch of encoded examples (BS x SL).
        t5: model exposing ``generate``.

    Returns:
        list[str]: one decoded prediction per row, in batch order, decoded
        with the module-level ``tokenizer``.
    """
    beam_search_kwargs = dict(
        num_beams=4,
        max_length=80,
        length_penalty=2,
        early_stopping=True,
    )
    output_ids = t5.generate(input_ids=input_ids, **beam_search_kwargs)
    return list(map(tokenizer.decode, output_ids))

def post_process(selected):
    """
    Lowercase a predicted span and drop repeated words.

    Words are kept in order of first appearance, so the output is
    deterministic. A ``set`` would keep the same words, but in hash order,
    which changes between interpreter runs.

    Args:
        selected: predicted span.

    Returns:
        str: the unique lowercase words joined by single spaces.
    """
    return " ".join(dict.fromkeys(selected.lower().split()))

## Prepare data

In [None]:
test = pd.read_csv('/kaggle/input/tweet-sentiment-extraction/test.csv')  # add .iloc[:200] for a quick debug run
# One prompt per row: "question: <sentiment> context: <tweet>".
# encode_file reads test.source line by line, and the predictions are later
# matched back to rows by position. Each row must therefore occupy exactly one
# line, so line breaks inside a tweet are collapsed to a space. A missing text
# becomes "" so that the join below does not fail.
context = test.text.fillna("").str.replace(r"[\r\n]+", " ", regex=True)
processed_input_test = ("question: " + test.sentiment + " context: " + context)
processed_input_str_test = '\n'.join(processed_input_test.values.tolist())

with open('../working/test.source', 'w', encoding='utf-8') as f:
    f.write(processed_input_str_test)


In [None]:
# Show the first 20 prompts written to test.source (IPython shell escape).
!head -20 ../working/test.source

## Read the model

In [None]:
# The tokenizer and the fine-tuned weights come from two different input datasets.
# NOTE(review): confirm both directories use the same T5 vocabulary.
tokenizer_dir = '../input/t5-qa-training-short-question-pytorch/'
model_dir = '../input/t5-5-epochs-sentiment-extraction/'
tokenizer = T5Tokenizer.from_pretrained(tokenizer_dir)
t5 = T5ForConditionalGeneration.from_pretrained(model_dir)

## Read the data

In [None]:
# Each line is padded to max_length=80 tokens (pad_to_max_length=True), so all
# examples share one sequence length. The torch.cat further down relies on this.
# NOTE(review): confirm that lines longer than 80 tokens are truncated by the
# installed tokenizer version.
test_input_ids = encode_file(tokenizer, '../working/test.source', 80, pad_to_max_length=True, return_tensors='pt')

In [None]:
# Number of encoded examples (one per line of test.source).
len(test_input_ids)

In [None]:
# Checking max len
#lens = [len(test_input_ids[i].squeeze()) for i in range(len(test_input_ids))]
#max(lens)

In [None]:
# Inspect the encoded ids of example 6.
test_input_ids[6]

In [None]:
# Shape of one encoded example; encode_file keeps a leading batch dimension of 1.
test_input_ids[6].size()

## Set up data as a DataLoader

In [None]:
# Use the GPU when PyTorch can see one; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Stack the per-example 1 x SL tensors into a single N x SL tensor on `device`.
input_ids_tensor = torch.cat(test_input_ids).to(device)
input_dataset = TensorDataset(input_ids_tensor)
# Keep the DataLoader itself rather than a one-shot iterator. This way:
#   - re-running the prediction cell iterates it again instead of yielding nothing;
#   - tqdm can read its length;
#   - no batch is lost to an earlier next() call.
# SequentialSampler visits rows in order, so predictions stay aligned with
# test_input_ids. To pull one batch for testing, use next(iter(input_dataloader)).
input_dataloader = DataLoader(
    input_dataset,
    sampler=SequentialSampler(input_dataset),
    batch_size=16,
)

## Sample prediction

Careful: these cells pull a batch from `input_dataloader` with `next()` and run generation on it. Only run them for testing.

In [None]:
#generated_ids = t5.generate(
#    input_ids=next(input_dataloader)[0],
#    num_beams=1,
#    max_length=80,
#    #repetition_penalty=2.5
#)

In [None]:
#tokenizer.decode(generated_ids[2])

In [None]:
# Move the model to `device` and switch it to inference (eval) mode.
t5.to(device).eval()

In [None]:
# Freeze every weight: this notebook only runs inference.
for weight in t5.parameters():
    weight.requires_grad_(False)

## Make predictions

Generating predictions for the test set takes about 2 hours on CPU, but only ~5 minutes on GPU with a batch size of 16.

In [None]:
# Generate spans batch by batch. Each DataLoader item holds the ids tensor at index 0.
test_preds = []
for batch in tqdm(input_dataloader):
    batch_ids = batch[0]
    test_preds.extend(get_span_from_ids_batch(batch_ids, t5))

In [None]:
# Final processing: pair each raw prediction with the encoded input it came from.
# zip() stops at the shorter list. A length mismatch (e.g. a batch pulled off
# the loader before the prediction loop ran) would silently misalign rows,
# so fail loudly instead.
if len(test_preds) != len(test_input_ids):
    raise ValueError(
        f"{len(test_preds)} predictions for {len(test_input_ids)} inputs; "
        "predictions and inputs are out of sync"
    )
test_preds = [process_span(s, ids) for s, ids in zip(test_preds, test_input_ids)]

In [None]:
# Display the post-processed spans.
test_preds

In [None]:
#tokenizer.decode(tokenizer.encode('Cramps . . .'))

In [None]:
# Number of predictions; should equal len(test_input_ids).
len(test_preds)

## Save predictions

In [None]:
# Template for the submission file. Predictions are written into it by
# position, which assumes its rows are in the same order as test.csv.
# NOTE(review): confirm, e.g. by comparing an ID column of both frames.
sub = pd.read_csv('../input/tweet-sentiment-extraction/sample_submission.csv')
sub.shape

In [None]:
# Assign predictions by position, lowercasing and de-duplicating the words of each.
sub['selected_text'] = [post_process(span) for span in test_preds]

In [None]:
# Quick look at the first rows before writing the file.
sub.head()

In [None]:
# Write the submission file without the DataFrame index column.
sub.to_csv('submission.csv', index=False)