In [None]:
! pip install datasets transformers

In [None]:
import transformers
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, T5ForConditionalGeneration
from datasets import load_dataset

In [None]:
# Load the tokenizer published with the FLAN-T5 base checkpoint.
tokenizer_checkpoint = "google/flan-t5-base"
tokenizer = AutoTokenizer.from_pretrained(tokenizer_checkpoint)

In [None]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)

# Load the FLAN-T5 base weights, then move the model onto the chosen device.
model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-base")
model = model.to(device)

In [None]:
# Load SQuAD; the "train" and "validation" splits are indexed from this object later.
dataset_name = "squad"
datasets = load_dataset(dataset_name)

In [None]:
def preprocess_and_tokenize(examples):
  """Build tokenized model inputs and labels from SQuAD-style examples.

  Each example contributes one input string, "<context> <question>", and one
  reference string, the first entry of its answers. Pairs whose input string is
  rejected by is_shorter_than_512 are dropped, so the returned tensors and
  references are index-aligned with each other but not with `examples`.

  Args:
    examples: Iterable of mappings providing "context", "question" and
      "answers" (whose "text" entry holds at least one answer string).

  Returns:
    A tuple (input_ids, attention_masks, labels, references). The first three
    are PyTorch tensors padded to the longest kept sequence; padding positions
    in `labels` are replaced by -100. `references` is a tuple of the kept
    answer strings.

  Raises:
    ValueError: If no example passes the length filter.
  """
  kept_sequences = []
  kept_references = []

  for example in examples:
    sequence = example["context"] + " " + example["question"]
    # Only the first annotated answer is used as the target / reference.
    reference = example["answers"]["text"][0]

    # Drop over-long inputs up front instead of truncating away context.
    if is_shorter_than_512(sequence):
      kept_sequences.append(sequence)
      kept_references.append(reference)

  if not kept_sequences:
    # Fail with an explicit message instead of an opaque unpacking error.
    raise ValueError("No examples with at most 512 input tokens to tokenize.")

  tokenized_input_sequences = tokenizer(
      kept_sequences,
      max_length=512,
      # max_length is only enforced when truncation is enabled. Every kept
      # sequence already fits, so this does not change the encoded inputs.
      truncation=True,
      padding=True,
      return_tensors="pt"
  )

  input_ids = tokenized_input_sequences.input_ids
  attention_masks = tokenized_input_sequences.attention_mask

  labels = tokenizer(kept_references, padding="longest", return_tensors="pt").input_ids
  # Mark padded label positions with -100 (the ignore value used for T5 labels).
  labels[labels == tokenizer.pad_token_id] = -100

  return input_ids, attention_masks, labels, tuple(kept_references)

def is_shorter_than_512(sequence):
  """Return True when `sequence` tokenizes to at most 512 token ids.

  Truncation is disabled so that the full, untruncated length is measured.
  """
  token_ids = tokenizer(sequence, truncation=False).input_ids
  return not len(token_ids) > 512

# Training

In [None]:
from datasets import Dataset

# Tokenize the training split; the reference strings (fourth value) are not
# needed for training and are discarded.
train_split = datasets["train"]
input_ids, attention_masks, labels, _ = preprocess_and_tokenize(train_split)

Token indices sequence length is longer than the specified maximum sequence length for this model (516 > 512). Running this sequence through the model will result in indexing errors


In [None]:
from transformers import get_scheduler
from tqdm.auto import tqdm

# Not referenced in this cell; kept as originally defined.
max_source_length = 512
max_target_length = 128

batch_size = 4

# Integer division: a trailing partial batch is not trained on.
num_training_steps = len(input_ids) // batch_size

progress_bar = tqdm(range(num_training_steps))

# torch.optim.AdamW replaces the deprecated transformers.AdamW. eps and
# weight_decay are set to that class's defaults (1e-6 and 0.0), because
# PyTorch's own defaults differ (1e-8 and 0.01).
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5, eps=1e-6, weight_decay=0.0)
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)

# Per-step loss values, kept so the curve can be inspected after training.
training_losses = []

# from_pretrained returns the model in evaluation mode; switch to training
# mode so dropout is active while fine-tuning.
model.train()

# One pass over the data, in stored order, one optimizer step per batch.
for i in range(0, num_training_steps * batch_size, batch_size):
  batch = slice(i, i + batch_size)
  loss = model(
      input_ids=input_ids[batch].to(device),
      attention_mask=attention_masks[batch].to(device),
      labels=labels[batch].to(device),
  ).loss

  loss.backward()
  optimizer.step()
  lr_scheduler.step()
  optimizer.zero_grad()

  # Report the loss on the progress bar instead of printing one line per step.
  step_loss = loss.item()
  training_losses.append(step_loss)
  progress_bar.set_postfix(loss=step_loss)
  progress_bar.update(1)

progress_bar.close()

# Restore evaluation mode so later generation runs without dropout.
model.eval()

# Inference

In [None]:
# Number of validation examples to generate answers for, and how many of them
# are sent through the model per generate() call.
EVAL_SAMPLE_COUNT = 1000
GENERATION_BATCH_SIZE = 32


def normalize_answer(text):
  """Lower-case `text` and drop a single trailing period, if present.

  An empty string is returned unchanged.
  """
  if text.endswith("."):
    text = text[:-1]
  return text.lower()


evaluation_input_ids, evaluation_attention_masks, _, references = preprocess_and_tokenize(datasets["validation"])

# Make sure dropout is disabled while generating.
model.eval()

num_eval_examples = min(EVAL_SAMPLE_COUNT, len(evaluation_input_ids))
predictions = []
# Generate in small batches so the whole sample never has to fit on the
# device at once.
for start in range(0, num_eval_examples, GENERATION_BATCH_SIZE):
  end = min(start + GENERATION_BATCH_SIZE, num_eval_examples)
  prediction_ids = model.generate(
      input_ids=evaluation_input_ids[start:end].to(device),
      # Pass the padding mask explicitly instead of leaving it to be inferred.
      attention_mask=evaluation_attention_masks[start:end].to(device),
  )
  predictions.extend(tokenizer.batch_decode(prediction_ids, skip_special_tokens=True))

# Apply the same normalization to both sides before scoring.
predictions = [normalize_answer(prediction) for prediction in predictions]
references = [normalize_answer(reference) for reference in references[:num_eval_examples]]

# Evaluation

In [None]:
# Workaround for a weird failure of the `pip install evaluate` cell below:
# force the reported preferred encoding to UTF-8.
# The replacement accepts the same optional `do_setlocale` argument as the real
# function, so callers that pass it (e.g. locale.getpreferredencoding(False))
# keep working instead of raising TypeError.
import locale
locale.getpreferredencoding = lambda do_setlocale=True: "UTF-8"

In [None]:
%pip install evaluate
%pip install git+https://github.com/google-research/bleurt.git
%cd bleurt
%pip install .

In [None]:
import evaluate
import numpy as np

bleu = evaluate.load("bleu")
bleurt = evaluate.load("bleurt", module_type="metric")

# Wrap each reference in its own list: one reference per prediction.
bleu_references = [[reference] for reference in references]

# BLEU restricted to n-grams of order 1 through 4.
bleu1_results = bleu.compute(predictions=predictions, references=bleu_references, max_order=1)
bleu2_results = bleu.compute(predictions=predictions, references=bleu_references, max_order=2)
bleu3_results = bleu.compute(predictions=predictions, references=bleu_references, max_order=3)
bleu4_results = bleu.compute(predictions=predictions, references=bleu_references, max_order=4)

bleurt_results = bleurt.compute(predictions=predictions, references=references)
bleurt_scores = bleurt_results["scores"]

print(bleu1_results)
print(bleu2_results)
print(bleu3_results)
print(bleu4_results)
print(np.mean(bleurt_scores))

# Show the first (up to) ten prediction/reference pairs with their BLEURT
# score, reusing the scores computed above instead of re-running the metric
# once per example.
for prediction, reference, score in zip(predictions[:10], references[:10], bleurt_scores[:10]):
  print(prediction, "|", reference, "|", score)

In [None]:
from evaluate import load

# Score the normalized predictions against the normalized references with the
# "exact_match" metric and print the result.
exact_match_metric = load("exact_match")
results = exact_match_metric.compute(
    references=references,
    predictions=predictions,
)
print(results)

In [None]:
def calculate_precision(prediction_tokens, reference_tokens):
    """Return the fraction of distinct predicted tokens that occur in the reference.

    Both inputs are reduced to sets, so repeated tokens count once.
    Returns 0 when the prediction contains no tokens.
    """
    predicted = set(prediction_tokens)
    expected = set(reference_tokens)

    if not predicted:
        return 0

    shared = predicted & expected
    return len(shared) / len(predicted)

def calculate_recall(prediction_tokens, reference_tokens):
    """Return the fraction of distinct reference tokens that occur in the prediction.

    Both inputs are reduced to sets, so repeated tokens count once.
    Returns 0 when the reference contains no tokens.
    """
    predicted = set(prediction_tokens)
    expected = set(reference_tokens)

    if not expected:
        return 0

    shared = predicted & expected
    return len(shared) / len(expected)

def calculate_f1_score(prediction_tokens, reference_tokens):
    """Return the harmonic mean of calculate_precision and calculate_recall.

    Returns 0 when precision and recall sum to zero.
    """
    precision = calculate_precision(prediction_tokens, reference_tokens)
    recall = calculate_recall(prediction_tokens, reference_tokens)

    total = precision + recall
    if total == 0:
        return 0

    return 2 * (precision * recall) / total

In [None]:
# Upper bound on how many prediction/reference pairs are scored here.
NUM_SCORED_EXAMPLES = 100

precisions = []
recalls = []
f1s = []
# Slicing (rather than indexing up to a fixed count) keeps this safe when
# fewer than NUM_SCORED_EXAMPLES pairs are available.
for prediction, reference in zip(predictions[:NUM_SCORED_EXAMPLES], references[:NUM_SCORED_EXAMPLES]):
  # Tokenize each string once. Special tokens are excluded: otherwise the
  # end-of-sequence id the tokenizer appends to every encoding would count as
  # a shared token and inflate every score (an empty prediction would get
  # precision 1.0).
  prediction_tokens = tokenizer(prediction, add_special_tokens=False)["input_ids"]
  reference_tokens = tokenizer(reference, add_special_tokens=False)["input_ids"]

  precisions.append(calculate_precision(prediction_tokens, reference_tokens))
  recalls.append(calculate_recall(prediction_tokens, reference_tokens))
  f1s.append(calculate_f1_score(prediction_tokens, reference_tokens))
print(np.mean(precisions))
print(np.mean(recalls))
print(np.mean(f1s))

In [None]:
pip install rouge_score

In [None]:
# Score the normalized predictions against the normalized references with the
# "rouge" metric and print the result.
rouge = evaluate.load('rouge')
results = rouge.compute(
    references=references,
    predictions=predictions,
)
print(results)