# Fine-tuning prototype

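Goal: Fine-tune LongT5 (`google/long-t5-tglobal-base`) for question answering. The cells below first measure the zero-shot baseline on the COVID-QA test split using BERTScore.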


In [None]:
import os

# GPU selection must be in the environment before CUDA is initialised, so it is
# set before torch is imported. setdefault keeps any value the caller already
# exported instead of silently overriding it.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "3")
# NOTE(review): TORCH_USE_CUDA_DSA is documented by PyTorch as a build-time
# option; confirm it has any effect when set at runtime like this.
os.environ.setdefault("TORCH_USE_CUDA_DSA", "1")

import torch

# Fall back to CPU when no (visible) GPU is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
from transformers import LongT5ForConditionalGeneration, AutoTokenizer
from askem.data import get_covid_qa

model_name = "google/long-t5-tglobal-base"

# Evaluation data: the "test" split returned by the project's COVID-QA loader.
dataset = get_covid_qa(split="test")

# Load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = LongT5ForConditionalGeneration.from_pretrained(model_name)

# T5-family checkpoints are known to overflow (inf/NaN activations) in float16,
# and float16 is poorly supported on CPU, so `.half()` is replaced by:
# bfloat16 on GPUs that support it (same memory as fp16, fp32 exponent range),
# float32 everywhere else. NOTE(review): float32 doubles memory versus fp16;
# revisit if long inputs run out of memory on pre-Ampere GPUs.
if device.type == "cuda" and torch.cuda.is_bf16_supported():
    model_dtype = torch.bfloat16
else:
    model_dtype = torch.float32
model = model.to(device=device, dtype=model_dtype)

In [None]:
def generate_answer(batch, *, max_input_length=16384, max_answer_length=512) -> dict:
    """Generate a predicted answer for every question/context pair in a batch.

    Intended as a ``datasets.Dataset.map(..., batched=True)`` callback. Each
    question and its context are formatted as a single prompt, tokenized, run
    through the module-level ``model.generate`` on ``device``, and decoded
    back to text with the module-level ``tokenizer``.

    Args:
        batch: A batched mapping of column name to list of values; must contain
            ``'question'`` and ``'context'`` columns of equal length.
        max_input_length: Token limit per prompt; longer prompts are truncated
            (default 16384).
        max_answer_length: Passed to ``model.generate`` as ``max_length``
            (default 512).

    Returns:
        The same ``batch`` with a ``'predicted_answer'`` list added, holding one
        decoded string per row.
    """

    def _to_input(question, context):
        # Prompt layout: "question: <q> context: <c>".
        return f"question: {question} context: {context}"

    input_text = [_to_input(q, c) for q, c in zip(batch['question'], batch['context'])]

    # padding=True pads to the longest prompt in this batch so the rows stack
    # into one tensor; the attention mask marks which positions are padding.
    tokenized = tokenizer(
        input_text,
        max_length=max_input_length,
        truncation=True,
        return_tensors="pt",
        padding=True,
    )

    outputs = model.generate(
        input_ids=tokenized.input_ids.to(device),
        attention_mask=tokenized.attention_mask.to(device),
        max_length=max_answer_length,
    )
    batch['predicted_answer'] = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    return batch
    

In [None]:
# Generate predictions for the whole split, four rows per generate_answer call.
# The bulky input columns are removed from the mapped output; "id" and
# "answers" are kept for scoring.
results = dataset.map(
    generate_answer,
    batch_size=4,
    batched=True,
    remove_columns=["question", "context"],
)

In [None]:
from bert_score import score

# Reference: the first gold answer text of each example; prediction: generated text.
y_true = [a['text'][0] for a in results['answers']]
y_pred = results['predicted_answer']

# BERTScore returns per-example precision, recall and F1 tensors.
precision, recall, f1 = score(y_true, y_pred, lang="en", verbose=True)

# Report the mean of every component, not just precision.
print(f"Precision: {precision.mean():.3f}")
print(f"Recall: {recall.mean():.3f}")
print(f"F1: {f1.mean():.3f}")

In [None]:
import pandas as pd
from pathlib import Path

# One row per evaluated example: reference answer, prediction, and the
# per-example BERTScore components computed in the previous cell.
df = pd.DataFrame(
    {
        "id": results['id'],
        "y_true": y_true,
        "y_pred": y_pred,
        "precision": precision.numpy(),
        "recall": recall.numpy(),
        "f1": f1.numpy(),
    }
)

output_path = Path("data/long_t5_zero_shot.parquet")
# to_parquet does not create missing directories; make sure "data/" exists
# relative to the current working directory.
output_path.parent.mkdir(parents=True, exist_ok=True)
df.to_parquet(output_path)


In [None]:
import pandas as pd

# Reload previously saved zero-shot results.
# NOTE(review): this file name ("longt5_zeroshot") differs from the one written
# by the save cell ("data/long_t5_zero_shot.parquet") — confirm which artifact
# is intended.
saved_results_path = "/askem/data/longt5_zeroshot.parquet"
df = pd.read_parquet(saved_results_path)