In [None]:
import os
from transformers import pipeline, AutoTokenizer, AutoModelForQuestionAnswering, TrainingArguments, Trainer, AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer
import datasets
import torch
import torch.nn as nn

In [None]:
dataset_name = "flax-sentence-embeddings/stackexchange_titlebody_best_and_down_voted_answer_jsonl"
# Load one config per StackExchange site, keeping only the 'train' split of each.
ai_dataset, ds_dataset, se_dataset = (
    datasets.load_dataset(dataset_name, config)['train']
    for config in ('ai', 'datascience', 'softwareengineering')
)

In [None]:
# Merge the three sites into a single dataset and rename the columns to
# task-level names: title_body -> question, upvoted_answer -> answer.
combined_dataset = datasets.concatenate_datasets(
    [ai_dataset, ds_dataset, se_dataset]
).rename_columns({"title_body": "question", "upvoted_answer": "answer"})

In [None]:
# The per-site datasets are not needed once merged into combined_dataset.
del ai_dataset
del ds_dataset
del se_dataset

In [None]:
# Fine-tune T5 on the dataset: load the pretrained checkpoint and the
# tokenizer that belongs to the same checkpoint name.
model_name = "t5-base"
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [None]:
def preprocess_function(examples, max_length=512, tok=None):
    """Tokenize a batch of rows for answer -> question generation with T5.

    The encoder input is built from the answer only; the decoder target is the
    question. The previous prompt also embedded the question in the input,
    which let the model learn a trivial copy of its own target.

    NOTE(review): this assumes the intended task is question generation (the
    labels were already the question). If answer generation was intended,
    use the question as input and the answer as target instead.

    Args:
        examples: batched mapping with "question" and "answer" lists of strings,
            as passed by ``Dataset.map(..., batched=True)``.
        max_length: length that inputs and labels are padded/truncated to.
        tok: tokenizer to use; defaults to the notebook-level ``tokenizer``.

    Returns:
        The tokenizer's encoding of the inputs plus a "labels" list in which
        padding positions are replaced by -100.
    """
    if tok is None:
        tok = tokenizer

    inputs = [f"context and answer: {a}" for a in examples["answer"]]
    model_inputs = tok(inputs, max_length=max_length, padding="max_length", truncation=True)

    # `text_target=` tokenizes the targets; it replaces the deprecated
    # `as_target_tokenizer()` context manager.
    targets = tok(
        text_target=examples["question"],
        max_length=max_length,
        padding="max_length",
        truncation=True,
    )

    # -100 is the ignore index of the cross-entropy loss; leaving pad ids in the
    # labels would also train the model to predict the padding.
    pad_id = tok.pad_token_id
    model_inputs["labels"] = [
        [token if token != pad_id else -100 for token in seq]
        for seq in targets["input_ids"]
    ]
    return model_inputs

In [None]:
# Tokenize in 4 worker processes. The raw columns are dropped: only the
# tokenizer outputs and `labels` produced by preprocess_function are needed
# downstream, and keeping the original text duplicates it in every split.
tokenized_datasets = combined_dataset.map(
    preprocess_function,
    batched=True,
    num_proc=4,
    remove_columns=combined_dataset.column_names,
)

In [None]:
def split_train_validation(examples, train_fraction=0.9, shuffle=True, seed=42):
    """Split a dataset into "train" and "validation" subsets.

    Args:
        examples: object supporting ``len()`` and ``select(indices)``
            (e.g. a ``datasets.Dataset``).
        train_fraction: fraction of rows assigned to "train"; must be strictly
            between 0 and 1.
        shuffle: shuffle row order before splitting. For a concatenation of
            several sources (like ``combined_dataset`` here), an unshuffled
            tail split would put almost only the last source into validation.
        seed: seed for the shuffle, so the split is reproducible.

    Returns:
        dict with keys "train" and "validation".

    Raises:
        ValueError: if ``train_fraction`` is not in (0, 1).
    """
    import random

    if not 0 < train_fraction < 1:
        raise ValueError(f"train_fraction must be in (0, 1), got {train_fraction}")

    order = list(range(len(examples)))
    if shuffle:
        random.Random(seed).shuffle(order)
    train_size = int(train_fraction * len(order))
    return {
        "train": examples.select(order[:train_size]),
        "validation": examples.select(order[train_size:]),
    }

In [None]:
split = split_train_validation(tokenized_datasets)
# Report the size of each split, train first.
for split_name in ("train", "validation"):
    print(len(split[split_name]))

In [None]:
# One epoch of fine-tuning, evaluated at the end of each epoch.
training_args = Seq2SeqTrainingArguments(
    output_dir="test-t5",
    num_train_epochs=1,
    evaluation_strategy="epoch",
    # optimizer
    learning_rate=2e-5,
    weight_decay=0.01,
    # batching
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    predict_with_generate=True,
)

In [None]:
# Wire the model, configuration and the two splits into the seq2seq trainer.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    tokenizer=tokenizer,
    train_dataset=split["train"],
    eval_dataset=split["validation"],
)

In [None]:
train_result = trainer.train()
# Persist the final weights: without this only the periodic checkpoints exist,
# and the end-of-training model is lost when the kernel stops.
trainer.save_model()  # no path given -> writes to training_args.output_dir
train_result