In [None]:
%reload_ext autoreload
%autoreload 2

In [None]:
from pprint import pprint

In [None]:
import torch
import math
import transformers
from transformers import (
    RobertaConfig,
    RobertaModel,
    AutoTokenizer,
    pipeline,
    AutoModel,
    RobertaTokenizerFast,
    RobertaForQuestionAnswering
)

In [None]:
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

In [None]:
%matplotlib inline

In [None]:
import random

In [None]:
from typing import List, Optional, Dict, Union, Tuple

In [None]:
# DEFINE THE MODEL

# Build a RoBERTa config with a 65,536-entry vocabulary; special-token ids are
# passed to the constructor rather than patched on afterwards.
# NOTE(review): bos_token_id and pad_token_id are both 0 here, while the
# commented-out alternative used pad_token_id=1 — confirm against the tokenizer
# in data/nasawiki-v6/. This object is not referenced by the cells below (the QA
# model is loaded with from_pretrained()).
configuration = RobertaConfig(
    vocab_size=65536,
    bos_token_id=0,
    eos_token_id=2,
    # pad_token_id=1,
    pad_token_id=0,
)
# Stored as an extra attribute on the config; presumably informational only.
configuration.device = "cpu"
pprint(configuration)

# Fine-Tuning for QA

## Dataset prep

In [None]:
# %pip install datasets

### Load Custom dataset

In [None]:
from datasets import Dataset

In [None]:
import functools

In [None]:
from sklearn.model_selection import train_test_split

In [None]:
import re

In [None]:
def load_askathon_clean(path: str) -> pd.DataFrame:
    """Load the Askathon form-responses CSV and normalise its columns.

    The "Email Address" column is dropped. The first 14 remaining columns are then
    renamed by position to ``context, id, source, topics, q1, a1, ..., q5, a5``, and
    ``source``/``topics`` are discarded. Any further columns keep their original names.

    Args:
        path: Anything ``pd.read_csv`` accepts (file path or buffer).

    Returns:
        DataFrame with ``context``, ``id`` and ``q1``..``a5`` columns.

    Raises:
        IndexError: If fewer than 14 columns remain after dropping the email column.
    """
    raw = pd.read_csv(path)
    responses = raw.drop(columns=["Email Address"]).reset_index(drop=True)

    # The original header texts are not relied upon: columns are mapped by position.
    new_names = ["context", "id", "source", "topics"]
    for n in range(1, 6):
        new_names += [f"q{n}", f"a{n}"]
    positional_map = {responses.columns[pos]: name for pos, name in enumerate(new_names)}

    return responses.rename(columns=positional_map).drop(columns=["source", "topics"])

In [None]:
def create_qa_dataset(data: pd.DataFrame) -> pd.DataFrame:
    """Flatten Askathon responses into one row per (question, answer) pair.

    Each response row carries a context passage plus up to five question/answer
    pairs (``q1``/``a1`` ... ``q5``/``a5``). A pair is kept only when both the question
    and the answer are strings and the answer occurs in the context. Matching ignores
    case, surrounding whitespace and trailing `` ,.!?``.

    Args:
        data: Frame as returned by ``load_askathon_clean`` (needs ``context``,
            ``id`` and ``q1``..``a5`` columns).

    Returns:
        DataFrame with columns ``id, context, question, answer_text, answer_start``.

        - ``answer_text`` is exactly ``context[answer_start:answer_start + len(answer_text)]``,
          so character offsets stay consistent for span labelling.
        - ``id`` is unique per pair (``<response id>_q<k>``), because the SQuAD metric
          pairs predictions with references by id.
        - An empty result still has these columns.
    """
    columns = ["id", "context", "question", "answer_text", "answer_start"]
    records = []

    def _find_span(context: str, answer: str) -> Optional[Tuple[int, str]]:
        # Trim whitespace and trailing punctuation the annotator may have added.
        needle = answer.strip().rstrip(" ,.!?")
        if not needle:
            # str.index("") would "match" at position 0 -> treat as no answer.
            return None
        # Case-insensitive search on the original context, so the offset is valid in
        # `context` itself (lower() can change string length for some characters).
        match = re.search(re.escape(needle), context, flags=re.IGNORECASE)
        if match is None:
            return None
        return match.start(), match.group(0)

    for row in data.itertuples():
        context = row.context.strip()
        # Remove spaces, colons and slashes from the response id to get a compact key.
        base_id = "".join(re.split(r"[ :/]", str(row.id)))
        for k in range(1, 6):
            question, answer = getattr(row, f"q{k}"), getattr(row, f"a{k}")
            if not isinstance(question, str) or not isinstance(answer, str):
                continue
            span = _find_span(context, answer)
            if span is None:
                continue
            start, text = span
            records.append(dict(
                id=f"{base_id}_q{k}",
                context=context,
                question=question,
                answer_text=text,
                answer_start=start,
            ))
    return pd.DataFrame(records, columns=columns)

In [None]:
data_qa = create_qa_dataset(load_askathon_clean("data/qa/Askathon Cleaned responses - Form Responses 1.csv"))

In [None]:
# (max(data_qa["context"], key=lambda x: len(x.split())))

In [None]:
# Attach a SQuAD-style `answers` column ({"text": [...], "answer_start": [...]},
# one answer per row); this is the layout the preprocessing functions read.
# Columns are accessed by name: positional Series access like r[0] is deprecated
# in pandas, and DataFrame.apply(axis=1) misbehaves on an empty frame.
data_qa["answers"] = [
    dict(text=[text], answer_start=[int(start)])
    for text, start in zip(data_qa["answer_text"], data_qa["answer_start"])
]

In [None]:
data_qa.head()

In [None]:
# Hold out 20% of the QA pairs. A fixed random_state makes the split, and therefore
# every downstream metric, reproducible across Restart & Run All.
# NOTE(review): pairs sharing one context can land in both splits (context leakage);
# consider a group-aware split keyed on `context` if that matters.
data_qa_train, data_qa_test = train_test_split(data_qa, test_size=0.2, random_state=42)

In [None]:
data_qa_train.shape, data_qa_test.shape

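### Preprocess for training

- tokenization of (question, context) pairs
- chunking long contexts into overlapping windows (stride)
- labelling the answer's start/end token positions in each window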

References:
- https://huggingface.co/docs/transformers/tasks/question_answering
- https://github.com/AmitNikhade/Kaggle/blob/main/chaii%20-%20Hindi%20and%20Tamil%20Question%20Answering/question-answering-roberta-starter-explained.ipynb

In [None]:
tokenizer = AutoTokenizer.from_pretrained("data/nasawiki-v6/")
# tokenizer = AutoTokenizer.from_pretrained("data/sq2-v6/train-watbertv6-squad-2ep/")

In [None]:
tokenizer.max_len_single_sentence, tokenizer.padding_side

In [None]:
# tokenizer.padding_side == "right"

In [None]:
def preprocess_function_1(examples, tokenizer, max_length=382, stride=128):
    """Tokenize a batch of QA examples into overlapping features with answer-span labels.

    This variant honours ``tokenizer.padding_side``. With right padding the inputs are
    (question, context); otherwise they are (context, question). Both start and end are
    labelled with the index of the CLS token when:

    - the example has no answer, or
    - the feature's window does not fully contain the answer.

    Args:
        examples: Batch dict with ``question``, ``context`` and ``answers`` lists. Each
            ``answers`` entry is ``{"text": [...], "answer_start": [...]}``; only the
            first answer is used.
        tokenizer: Fast tokenizer supporting offset mappings and ``sequence_ids``.
        max_length: Maximum tokens per feature.
            NOTE(review): the default is 382, whereas preprocess_function_2 and the
            ``map`` calls in this notebook use 384. This is likely a typo; confirm.
        stride: Number of overlapping tokens between consecutive context windows.

    Returns:
        The tokenizer's batch encoding with ``overflow_to_sample_mapping`` and
        ``offset_mapping`` removed and ``start_positions`` / ``end_positions``
        lists added (one entry per feature).

    Note:
        The ``Dataset.map`` cells in this notebook use ``preprocess_function_2``,
        not this function.
    """
    # Some of the questions have lots of whitespace on the left, which is not useful and will make the
    # truncation of the context fail (the tokenized question will take a lots of space). So we remove that
    # left whitespace. (The stripping below is currently commented out.)
#     examples["question"] = [q.lstrip() for q in examples["question"]]
    
    # Decides the input order: (question, context) with right padding, else swapped.
    pad_on_right = tokenizer.padding_side == "right"

    # Tokenize our examples with truncation and padding, but keep the overflows using a stride. This results
    # in one example possible giving several features when a context is long, each of those features having a
    # context that overlaps a bit the context of the previous feature.
    tokenized_examples = tokenizer(
        examples["question" if pad_on_right else "context"],
        examples["context" if pad_on_right else "question"],
        truncation="only_second" if pad_on_right else "only_first",
        max_length=max_length,
        stride=stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    # Since one example might give us several features if it has a long context, we need a map from a feature to
    # its corresponding example. This key gives us just that.
    sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
    # The offset mappings will give us a map from token to character position in the original context. This will
    # help us compute the start_positions and end_positions.
    offset_mapping = tokenized_examples.pop("offset_mapping")

    # Let's label those examples!
    tokenized_examples["start_positions"] = []
    tokenized_examples["end_positions"] = []

    for i, offsets in enumerate(offset_mapping):
        # We will label impossible answers with the index of the CLS token.
        input_ids = tokenized_examples["input_ids"][i]
        cls_index = input_ids.index(tokenizer.cls_token_id)

        # Grab the sequence corresponding to that example (to know what is the context and what is the question).
        sequence_ids = tokenized_examples.sequence_ids(i)

        # One example can give several spans, this is the index of the example containing this span of text.
        sample_index = sample_mapping[i]
        answers = examples["answers"][sample_index]
        # If no answers are given, set the cls_index as answer.
        if len(answers["answer_start"]) == 0:
            tokenized_examples["start_positions"].append(cls_index)
            tokenized_examples["end_positions"].append(cls_index)
        else:
            # Start/end character index of the answer in the text.
            start_char = answers["answer_start"][0]
            end_char = start_char + len(answers["text"][0])

            # Start token index of the current span in the text.
            # The context's sequence id is 1 with right padding, 0 otherwise.
            token_start_index = 0
            while sequence_ids[token_start_index] != (1 if pad_on_right else 0):
                token_start_index += 1

            # End token index of the current span in the text.
            token_end_index = len(input_ids) - 1
            while sequence_ids[token_end_index] != (1 if pad_on_right else 0):
                token_end_index -= 1

            # Detect if the answer is out of the span (in which case this feature is labeled with the CLS index).
            if not (offsets[token_start_index][0] <= start_char and offsets[token_end_index][1] >= end_char):
                tokenized_examples["start_positions"].append(cls_index)
                tokenized_examples["end_positions"].append(cls_index)
            else:
                # Otherwise move the token_start_index and token_end_index to the two ends of the answer.
                # Note: we could go after the last offset if the answer is the last word (edge case).
                while token_start_index < len(offsets) and offsets[token_start_index][0] <= start_char:
                    token_start_index += 1
                tokenized_examples["start_positions"].append(token_start_index - 1)
                while offsets[token_end_index][1] >= end_char:
                    token_end_index -= 1
                tokenized_examples["end_positions"].append(token_end_index + 1)

    return tokenized_examples

In [None]:
def preprocess_function_2(examples, tokenizer, max_length=384, stride=128):
    """Tokenize a batch of QA examples into overlapping features with answer-span labels.

    Long contexts are split into windows of at most ``max_length`` tokens that overlap
    by ``stride`` tokens. Each window is labelled with the token indices of the answer
    start/end. Both labels point at the CLS token when:

    - the example has no answer, or
    - the answer is not fully contained in the window.

    Args:
        examples: Batch dict (as passed by ``Dataset.map(batched=True)``) with
            ``question``, ``context`` and ``answers`` lists. Each ``answers`` entry is
            ``{"text": [...], "answer_start": [...]}``; only the first answer is used.
        tokenizer: Fast tokenizer supporting offset mappings and ``sequence_ids``.
        max_length: Maximum tokens per feature (question + context + special tokens).
        stride: Number of overlapping tokens between consecutive context windows.

    Returns:
        The tokenizer's batch encoding with ``offset_mapping`` and
        ``overflow_to_sample_mapping`` removed and ``start_positions`` /
        ``end_positions`` added (one entry per feature).
    """
    inputs = tokenizer(
        examples["question"],
        examples["context"],
        max_length=max_length,
        truncation="only_second",  # only ever truncate the context, never the question
        stride=stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    offset_mapping = inputs.pop("offset_mapping")
    # Maps each feature back to the example it was cut from.
    sample_mapping = inputs.pop("overflow_to_sample_mapping")
    answers = examples["answers"]
    start_positions = []
    end_positions = []

    for i, offset in enumerate(offset_mapping):
        # Unanswerable windows point both labels at the CLS token.
        cls_index = inputs["input_ids"][i].index(tokenizer.cls_token_id)
        sequence_ids = inputs.sequence_ids(i)
        answer = answers[sample_mapping[i]]

        if len(answer["answer_start"]) == 0:
            start_positions.append(cls_index)
            end_positions.append(cls_index)
            continue

        start_char = answer["answer_start"][0]
        end_char = start_char + len(answer["text"][0])

        # Locate the first and last context tokens (sequence id 1) in this window.
        idx = 0
        while sequence_ids[idx] != 1:
            idx += 1
        context_start = idx
        while idx < len(sequence_ids) and sequence_ids[idx] == 1:
            idx += 1
        context_end = idx - 1

        # The answer must lie entirely inside this window. Partial overlaps would
        # otherwise yield start/end indices that land on special tokens.
        if offset[context_start][0] > start_char or offset[context_end][1] < end_char:
            start_positions.append(cls_index)
            end_positions.append(cls_index)
            continue

        # Walk inward from the window edges to the tokens covering the answer.
        idx = context_start
        while idx <= context_end and offset[idx][0] <= start_char:
            idx += 1
        start_positions.append(idx - 1)

        idx = context_end
        while idx >= context_start and offset[idx][1] >= end_char:
            idx -= 1
        end_positions.append(idx + 1)

    inputs["start_positions"] = start_positions
    inputs["end_positions"] = end_positions
    return inputs

In [None]:
# preprocess_function(data_qa_train.iloc[:2, :])

In [None]:
train_dataset = Dataset.from_pandas(data_qa_train)
test_dataset = Dataset.from_pandas(data_qa_test)

In [None]:
len(train_dataset), len(test_dataset)

In [None]:
tokenized_trains = train_dataset.map(
    functools.partial(preprocess_function_2, tokenizer=tokenizer, max_length=384, stride=128),
    batched=True,
    remove_columns=train_dataset.column_names
)

In [None]:
tokenized_tests = test_dataset.map(
    functools.partial(preprocess_function_2, tokenizer=tokenizer, max_length=384, stride=128),
    batched=True,
    remove_columns=test_dataset.column_names
)

In [None]:
tokenized_trains, tokenized_tests

## Train

In [None]:
import pathlib

In [None]:
from transformers import TrainingArguments, Trainer, default_data_collator
from transformers import AutoModelForQuestionAnswering

In [None]:
import wandb

In [None]:
# model = AutoModelForQuestionAnswering.from_pretrained("data/sq2-v6/train-watbertv6-squad-2ep/")
model = AutoModelForQuestionAnswering.from_pretrained("data/nasawiki-v6/")

In [None]:
model

In [None]:
pathlib.Path(model.name_or_path).stem

In [None]:
wandb.login()

In [None]:
wandb.init(
    project="llm-test",
    entity="nish-test",
    tags=["qa", pathlib.Path(model.name_or_path).stem]
)

In [None]:
train_args = TrainingArguments(
    f"tmp/finetuned/qa/{pathlib.Path(model.name_or_path).stem}",
    evaluation_strategy = "epoch",
    save_strategy="epoch",
    save_total_limit=2,
    learning_rate=3e-5,
    warmup_ratio=0.1,
    gradient_accumulation_steps=8,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    num_train_epochs=10,
    weight_decay=0.01,
    report_to="wandb",
    logging_steps=1,
)

In [None]:
train_args.output_dir

In [None]:
trainer = Trainer(
    model,
    train_args,
    train_dataset=tokenized_trains,
    eval_dataset=tokenized_tests,
    data_collator=default_data_collator,
    tokenizer=tokenizer,
)

In [None]:
trainer.train()

## Predict / evaluate

In [None]:
import evaluate
from rapidfuzz import fuzz

In [None]:
from jury import Jury

In [None]:
qa_pipe = pipeline(
    "question-answering",
    model="tmp/finetuned/qa/nasawiki-v6/checkpoint-30/",
    tokenizer="tmp/finetuned/qa/nasawiki-v6/checkpoint-30/",
)

In [None]:
qa_pipe.model

In [None]:
predictions = qa_pipe(
    list(map(lambda x: dict(context=x["context"], question=x["question"]), data_qa_train.to_dict("records")))
)

In [None]:
len(predictions)

In [None]:
predictions[:3]

In [None]:
def evaluate_fuzzy(gts: List[str], predictions: List[str]):
    """Score each prediction against its ground truth with RapidFuzz ``token_set_ratio``.

    Both sides are lower-cased and stripped of surrounding spaces and ``.,!?``
    before comparison.

    Args:
        gts: Ground-truth answer strings (any iterable).
        predictions: Predicted answer strings, aligned index-by-index with ``gts``.

    Returns:
        List of RapidFuzz similarity scores, one per (gt, prediction) pair.

    Raises:
        ValueError: If ``gts`` and ``predictions`` differ in length. Without this check,
            ``zip`` would silently drop the unmatched tail.
    """
    gts = list(gts)
    predictions = list(predictions)
    if len(gts) != len(predictions):
        raise ValueError(
            f"gts and predictions must have the same length, got {len(gts)} and {len(predictions)}"
        )

    def _normalize(text: str) -> str:
        return text.strip(" .,!?").lower()

    return [
        fuzz.token_set_ratio(_normalize(gt), _normalize(pred))
        for gt, pred in zip(gts, predictions)
    ]

In [None]:
def evaluate_exact(gts: List[str], predictions: List[str]):
    """Exact-match score (100 or 0) for each prediction against its ground truth.

    Both sides are lower-cased and stripped of surrounding spaces and ``.,!?``
    before comparison.

    Args:
        gts: Ground-truth answer strings (any iterable).
        predictions: Predicted answer strings, aligned index-by-index with ``gts``.

    Returns:
        List of ints: 100 where the normalised strings are equal, else 0.

    Raises:
        ValueError: If ``gts`` and ``predictions`` differ in length. Without this check,
            ``zip`` would silently drop the unmatched tail.
    """
    gts = list(gts)
    predictions = list(predictions)
    if len(gts) != len(predictions):
        raise ValueError(
            f"gts and predictions must have the same length, got {len(gts)} and {len(predictions)}"
        )

    def _normalize(text: str) -> str:
        return text.strip(" .,!?").lower()

    return [
        100 if _normalize(gt) == _normalize(pred) else 0
        for gt, pred in zip(gts, predictions)
    ]

In [None]:
def evaluate_squad_1(gts: List[dict], predictions: List[str], squad_metric):
    """Compute SQuAD metrics for plain-string predictions paired with ground truths by position.

    Each pair gets a synthetic id ("0", "1", ...), so the ``id`` field of the
    ground-truth records is not used.

    Args:
        gts: Ground-truth records with ``answer_text`` and ``answer_start`` keys,
            aligned index-by-index with ``predictions``.
        predictions: Predicted answer strings.
        squad_metric: Object exposing ``compute(predictions=..., references=...)``,
            e.g. ``evaluate.load("squad")``.

    Returns:
        Whatever ``squad_metric.compute`` returns.

    Raises:
        ValueError: If ``gts`` and ``predictions`` differ in length.
    """
    gts = list(gts)
    predictions = list(predictions)
    if len(gts) != len(predictions):
        raise ValueError(
            f"gts and predictions must have the same length, got {len(gts)} and {len(predictions)}"
        )

    prediction_records = [
        dict(prediction_text=text, id=str(i)) for i, text in enumerate(predictions)
    ]
    references = [
        {
            "answers": dict(answer_start=[gt["answer_start"]], text=[gt["answer_text"]]),
            "id": str(i),
        }
        for i, gt in enumerate(gts)
    ]
    return squad_metric.compute(predictions=prediction_records, references=references)

In [None]:
def evaluate_squad_2(gts: List[dict], predictions: List[dict], squad_metric):
    """Compute SQuAD metrics for predictions that already carry their own ids.

    Args:
        gts: Ground-truth records with ``id``, ``answer_text`` and ``answer_start`` keys.
        predictions: Dicts with ``prediction_text`` and ``id`` keys, one per ground truth.
        squad_metric: Object exposing ``compute(predictions=..., references=...)``,
            e.g. ``evaluate.load("squad")``.

    Returns:
        Whatever ``squad_metric.compute`` returns.

    Raises:
        ValueError: Raised in either of these cases:
            - ``gts`` and ``predictions`` differ in length;
            - ``gts`` contain duplicate ids. The SQuAD metric pairs predictions with
              references by id, so duplicates would silently score the wrong answers.
    """
    gts = list(gts)
    predictions = list(predictions)
    if len(gts) != len(predictions):
        raise ValueError(
            f"gts and predictions must have the same length, got {len(gts)} and {len(predictions)}"
        )
    ids = [gt["id"] for gt in gts]
    if len(set(ids)) != len(ids):
        raise ValueError("gts contain duplicate ids; the SQuAD metric pairs predictions with references by id")

    references = [
        {
            "answers": dict(answer_start=[gt["answer_start"]], text=[gt["answer_text"]]),
            "id": gt["id"],
        }
        for gt in gts
    ]
    return squad_metric.compute(predictions=predictions, references=references)

In [None]:
res = evaluate_exact(
    data_qa_train["answer_text"].to_list(),
    list(map(lambda p: p["answer"], predictions)),
)
print(np.mean(res))
sns.boxplot(res)
plt.title("Ground truth vs prediction exact match")

In [None]:
res = evaluate_fuzzy(
    data_qa_train["answer_text"].to_list(),
    list(map(lambda p: p["answer"], predictions)),
)
print(np.mean(res))
sns.boxplot(res)
plt.title("Ground truth vs prediction fuzzy match")

In [None]:
# `predictions` were generated from data_qa_train, so score them against the same
# split. Pairing them with data_qa_test compared answers to unrelated questions.
evaluate_squad_1(
    gts=data_qa_train.to_dict("records"),
    predictions=[p["answer"] for p in predictions],
    squad_metric=evaluate.load("squad"),
)

In [None]:
# Predictions come from data_qa_train: take both references and ids from that split
# so each prediction is keyed to the question it actually answers.
evaluate_squad_2(
    gts=data_qa_train.to_dict("records"),
    predictions=[
        dict(prediction_text=p["answer"], id=qa_id)
        for p, qa_id in zip(predictions, data_qa_train["id"])
    ],
    squad_metric=evaluate.load("squad"),
)

In [None]:
# evaluate_squad_3(
#     gts=data_qa_test.to_dict("records"),
#     predictions=list(map(lambda p: p["answer"], predictions)),
#     squad_metric=evaluate.load("squad"),
# )

In [None]:
def _peek_predictions(gts: List[dict], predictions: List[dict], index: int = 0) -> dict:
    """Return the ground-truth record and prediction at ``index`` side by side for inspection.

    NOTE(review): this helper was called but never defined in the notebook (NameError on
    Restart & Run All). This minimal version is a reconstruction; adjust if the
    original did more.
    """
    gt, pred = gts[index], predictions[index]
    return dict(
        id=gt["id"],
        question=gt["question"],
        answer_text=gt["answer_text"],
        prediction_text=pred["prediction_text"],
        prediction_id=pred["id"],
    )


# Predictions come from data_qa_train, so peek against the same split.
_peek_predictions(
    gts=data_qa_train.to_dict("records"),
    predictions=[
        dict(prediction_text=p["answer"], id=qa_id)
        for p, qa_id in zip(predictions, data_qa_train["id"])
    ],
    index=0,
)

In [None]:
import jury

In [None]:
_preprocess_fn = lambda x: x.strip(" .,!?").lower()

In [None]:
Jury(metrics=["exact_match", "bleu", "squad"])(
    predictions=list(map(lambda p: p["answer"], predictions)),
    references=data_qa_train["answer_text"].to_list(),
)

In [None]:
Jury(metrics=["exact_match", "bleu", "squad"])(
    predictions=list(map(lambda p: _preprocess_fn(p["answer"]), predictions)),
    references=list(map(_preprocess_fn, data_qa_train["answer_text"].to_list())),
)

In [None]:
# `predictions` were generated from data_qa_train; compare them with that split's
# answers (data_qa_test has a different length and different questions).
Jury(metrics=["squad"])(
    predictions=[p["answer"] for p in predictions],
    references=data_qa_train["answer_text"].to_list(),
)