# Week 40 - Classifying Span containing Answer

In [None]:
import gc
import os
import polars as pl
import torch

# Huggingface imports
from transformers import (
    AutoTokenizer,
    BertForTokenClassification,
    Seq2SeqTrainingArguments,
    Trainer,
)

# Import own modules
from bert_utils import (
    bio_sequence_labeler,
    get_results,
    display_results,
)

from datasets import load_dataset, Dataset

In [None]:
# Select device for training: CUDA wins over MPS, CPU is the fallback
mps_ready = torch.backends.mps.is_available()
cuda_ready = torch.cuda.is_available()
if cuda_ready:
    device_name = "cuda"
elif mps_ready:
    device_name = "mps"
else:
    device_name = "cpu"
device = torch.device(device_name)

print(f'Using device: {device}')

In [None]:
# Load dataset
dataset = load_dataset("coastalcph/tydi_xor_rc")

# Keep only Arabic, Telugu and Korean rows in both splits
keep_language = pl.col("lang").is_in(["ar", "te", "ko"])
df_train = dataset["train"].to_polars().filter(keep_language)
df_val = dataset["validation"].to_polars().filter(keep_language)

In [None]:
# Checkpoint name reused below when instantiating the classification model,
# and the matching tokenizer used for labelling, tokenization and evaluation.
mbert_checkpoint = "bert-base-multilingual-uncased"
mbert_tokenizer = AutoTokenizer.from_pretrained(mbert_checkpoint)

### Create label columns
Note that we truncate every input at 512 tokens, because that is the model's maximum sequence length. Two instances have their answer located after the 512-token cutoff; we accept this for simplicity and expect the model to generalize well without these two data points.

In [None]:
# Create BIO label column for train and val
def _bio_labels_for_row(row):
    """Label one (question, context, answer) struct row with the shared tokenizer."""
    return bio_sequence_labeler(
        row["answer_start"],
        row["answer"],
        row["question"],
        row["context"],
        mbert_tokenizer,
    )


# One expression, applied unchanged to both splits
_bio_labels_expr = (
    pl.struct(["question", "context", "answer_start", "answer"])
    .map_elements(_bio_labels_for_row, return_dtype=pl.List(pl.Int8))
    .alias("labels")
)

df_train = df_train.with_columns(_bio_labels_expr)
df_val = df_val.with_columns(_bio_labels_expr)

In [None]:
def prepare_data(df: pl.DataFrame) -> Dataset:
    """Build a Hugging Face ``Dataset`` from the question, context and labels columns of ``df``."""
    # Each Polars column becomes a plain Python list keyed by its column name
    wanted_columns = ("question", "context", "labels")
    return Dataset.from_dict({name: df[name].to_list() for name in wanted_columns})

def tokenize_function(examples: Dataset, tokenizer: AutoTokenizer) -> Dataset:
    """Tokenize question/context pairs, padded and truncated to 512 tokens.

    ``examples`` is indexed by the "question" and "context" keys; the two
    values are passed to ``tokenizer`` as the first and second sequence.
    Returns whatever ``tokenizer`` returns for that call.
    """
    # Passing two sequences yields question and context separated by [SEP];
    # [CLS] is added automatically
    tokenizer_options = {
        "truncation": True,
        "padding": "max_length",
        "max_length": 512,
        "return_tensors": "pt",
    }
    questions = examples["question"]
    contexts = examples["context"]
    return tokenizer(questions, contexts, **tokenizer_options) # type: ignore

# Prepare datasets
def _tokenize_batch(batch):
    """Tokenize one batch with the shared mBERT tokenizer."""
    return tokenize_function(batch, mbert_tokenizer)


train_dataset = prepare_data(df_train)
val_dataset = prepare_data(df_val)
tokenized_train = train_dataset.map(_tokenize_batch, batched=True)
tokenized_val = val_dataset.map(_tokenize_batch, batched=True)


## Check model performance before fine-tuning

In [None]:
# Fresh 3-label classification head on top of the pretrained checkpoint
model = BertForTokenClassification.from_pretrained(mbert_checkpoint, num_labels=3)
# A small random sample is enough to see the poor pre-fine-tuning performance quickly
val_set = df_val.sample(50, shuffle=True)
baseline_true, baseline_pred = get_results(model, val_set, mbert_tokenizer)
display_results(baseline_true, baseline_pred)

In [None]:
# Weird memory issue fix
#model = None # before GC
#gc.collect()
#with torch.no_grad():
#    torch.cuda.empty_cache()

In [None]:
# Training cell
output_path = "./mbert-iob"
# NOTE(review): a seq2seq argument class and generation_max_length are used here,
# but the model trained below is a token classifier run through the plain Trainer.
# Confirm whether TrainingArguments without generation_max_length would suffice.
args = Seq2SeqTrainingArguments(
    # Where checkpoints go
    output_dir=output_path,
    overwrite_output_dir=True,

    # Optimisation
    learning_rate=2e-5,
    weight_decay=0.01,  # L2 regularization term
    num_train_epochs=5,
    generation_max_length=512,

    # Memory efficient params
    fp16=False,
    auto_find_batch_size=True,

    # Checkpointing: keep at most two checkpoints and reload the best one at the end
    save_strategy="best",
    save_total_limit=2,
    load_best_model_at_end=True,

    # Evaluate and log once per epoch, with no external reporting
    eval_strategy="epoch",
    logging_strategy="epoch",
    log_level="info",
    logging_dir=None,
    report_to=[],
)

In [None]:
# TODO: Ideally these are global vars up top.
model_path = os.path.join(os.getcwd(), "results", "answer_span_classifier")
save_path = os.path.join(model_path, "fine_tuned")
patience = 2  # evaluations without improvement tolerated before early stopping

TRAIN = False # FLIP THIS TO TRAIN
from transformers import EarlyStoppingCallback

if not os.path.exists(model_path):
    print(f"No classifiers folder found, creating {model_path}...")
    # exist_ok guards against the folder appearing between the check and the call
    os.makedirs(model_path, exist_ok=True)

if TRAIN:
    # Only claim that no model exists when that is actually the case
    if os.path.exists(save_path):
        print(f"Model already exists at {save_path}, retraining and overwriting it...")
    else:
        print(f"No model found at {save_path}, training model...")
    # Always start from the pretrained checkpoint, never from a previous fine-tune
    model = BertForTokenClassification.from_pretrained(mbert_checkpoint, num_labels=3)
    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=tokenized_train,
        eval_dataset=tokenized_val,
        # NOTE(review): newer transformers releases may expect `processing_class`
        # instead of `tokenizer` — confirm against the installed version.
        tokenizer=mbert_tokenizer, # type: ignore
        callbacks=[EarlyStoppingCallback(early_stopping_patience=patience)],
    )
    trainer.train()
    print("Training completed.")
    # Persist model and tokenizer together so the load branch can restore both
    model.save_pretrained(save_path) # type: ignore
    mbert_tokenizer.save_pretrained(save_path) # type: ignore
    print("Model saved.")
elif os.path.exists(save_path):
    print(f"Model already exists at {save_path}, loading model...")
    model = BertForTokenClassification.from_pretrained(save_path, num_labels=3)
    mbert_tokenizer = AutoTokenizer.from_pretrained(save_path)
else:
    # Best effort: keep going, but make it explicit that nothing was loaded
    print(
        f"No model found at {save_path}. "
        "Set TRAIN = True to train one; `model` is left unchanged."
    )

## First get a taste

In [None]:
smaller_val_set = df_val.sample(15, shuffle=True) # get 15 random samples
y_true, y_pred = get_results(model, smaller_val_set, mbert_tokenizer)
display_results(y_true, y_pred)

## Get Confusion Matrix for each language

In [None]:
def _with_bio_labels(df):
    """Return ``df`` with a "labels" column produced by ``bio_sequence_labeler``."""
    return df.with_columns(
        pl.struct(["question", "context", "answer_start", "answer"]).map_elements(
            lambda x: bio_sequence_labeler(
                x["answer_start"],
                x["answer"],
                x["question"],
                x["context"],
                mbert_tokenizer,
            ),
            return_dtype=pl.List(pl.Int8),
        ).alias("labels")
    )


# One validation frame per language. The Arabic frame must filter on "ar";
# it previously filtered on "te", so the "Arabic" matrix showed Telugu data.
df_val_ar = _with_bio_labels(df_val.filter(pl.col("lang") == "ar"))
df_val_ko = _with_bio_labels(df_val.filter(pl.col("lang") == "ko"))
df_val_te = _with_bio_labels(df_val.filter(pl.col("lang") == "te"))

# Evaluate and display each language in turn
for title, df_lang in (
    ("Arabic", df_val_ar),
    ("Korean", df_val_ko),
    ("Telugu", df_val_te),
):
    y_true, y_pred = get_results(model, df_lang, mbert_tokenizer)
    display_results(y_true, y_pred, title=title)

## Get full Confusion Matrix for fully trained model

In [None]:
#y_true, y_pred = get_results(model, df_val, mbert_tokenizer)
#display_results(y_true, y_pred)