# Week 40 - Classifying Span containing Answer

In [None]:
import os
import polars as pl
import torch
import numpy as np

# Huggingface imports
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
)
from datasets import load_dataset, Dataset


In [None]:
# Pick the compute device: CUDA wins over Apple MPS, and CPU is the fallback.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

print(f'Using device: {device}')

In [None]:
# Fetch the coastalcph/tydi_xor_rc dataset and convert both splits to Polars frames
dataset = load_dataset("coastalcph/tydi_xor_rc")
df_train, df_val = (dataset[split].to_polars() for split in ("train", "validation"))

# Row filters for the three languages of interest: Arabic, Telugu and Korean
is_arabic = pl.col("lang") == "ar"
is_telugu = pl.col("lang") == "te"
is_korean = pl.col("lang") == "ko"

df_ar_train, df_ar_val = df_train.filter(is_arabic), df_val.filter(is_arabic)
df_te_train, df_te_val = df_train.filter(is_telugu), df_val.filter(is_telugu)

# Korean training rows are additionally restricted to answerable questions;
# the Korean validation split keeps every row. The `== True` is an
# element-wise comparison on a Polars expression, not Python truthiness.
df_ko_train = df_train.filter(is_korean & (pl.col("answerable") == True))  # noqa: E712
df_ko_val = df_val.filter(is_korean)

# Per-language lookup of the train/val frames
data = {
    "arabic": dict(train=df_ar_train, val=df_ar_val),
    "telegu": dict(train=df_te_train, val=df_te_val),
    "korean": dict(train=df_ko_train, val=df_ko_val),
}
# Sanity check: every remaining Korean training row is flagged answerable
assert df_ko_train.height == sum(df_ko_train["answerable"]), "All answers should be answerable"

In [None]:
# Inspect the `answer` value of the Korean training row at index 1
df_ko_train["answer"][1]

In [None]:
# Take the Korean training row at index 1 as a worked example
tst_answer = df_ko_train["answer"][1]
tst_cont = df_ko_train["context"][1]
print("CONTEXT: ", tst_cont)
print("ANSWER: ", tst_answer)

# Multilingual BERT checkpoint and tokenizer reused by the cells below
mbert_checkpoint = "bert-base-multilingual-uncased"
mbert_tokenizer = AutoTokenizer.from_pretrained(mbert_checkpoint)

# Encode the context and the answer separately and show the token ids, plus
# the character offsets of each answer token
context_encoding = mbert_tokenizer(tst_cont, return_offsets_mapping=True)
answer_encoding = mbert_tokenizer(tst_answer, return_offsets_mapping=True)
print("CONTEXT TOKEN: ", context_encoding["input_ids"])
print("ANSWER TOKEN: ", answer_encoding["input_ids"])
print("ANSWER OFFSETS: ", answer_encoding["offset_mapping"])

In [None]:
# BIO tag scheme for answer spans: id 0 is "O", id 1 marks the first answer
# token ("B-ANS") and id 2 marks the following answer tokens ("I-ANS").
num2bio = dict(enumerate(("O", "B-ANS", "I-ANS")))
# Reverse lookup (tag name -> id), derived from the forward map
bio2num = {tag: idx for idx, tag in num2bio.items()}

In [None]:
def sequence_labeler(
    context: str,
    answer_start: int,
    answer_text: str,
    max_length: int = 1024,
) -> list:
    """Build per-token BIO label ids for the answer span inside ``context``.

    The context is tokenized on its own with ``mbert_tokenizer``. The tokens
    whose character offsets overlap ``[answer_start, answer_start +
    len(answer_text))`` give the answer's token-id sequence, and every
    occurrence of that id sequence in the context is tagged ``1`` (B-ANS) on
    its first token and ``2`` (I-ANS) on the rest. All other tokens,
    including special tokens, stay ``0`` (O).

    NOTE(review): the labels line up with a context-only encoding, whereas
    ``tokenize_function`` encodes (question, context) pairs padded to 512
    tokens. Confirm the labels are re-aligned before they are used as
    training targets.

    Args:
        context: Passage text to label.
        answer_start: Character offset of the answer in ``context``; ``-1``
            marks an unanswerable example.
        answer_text: Answer string; only its length is used to find the end
            of the span.
        max_length: Truncation limit (in tokens) applied to the context.

    Returns:
        A plain list of ints (0/1/2), one per token of the truncated context
        encoding. Answerable and unanswerable examples share the same
        tokenization settings, so their lengths are computed the same way.
    """
    # Tokenize once, with identical settings for every kind of example, so
    # label length never depends on whether the example is answerable.
    encoding = mbert_tokenizer(
        context,
        return_offsets_mapping=True,
        add_special_tokens=True,
        truncation=True,
        max_length=max_length,
    )
    tokens = encoding["input_ids"]
    labels = [0] * len(tokens)

    if answer_start == -1:  # Unanswerable: everything stays "O"
        return labels

    answer_end = answer_start + len(answer_text)

    # Tokens whose character span overlaps the answer span. An offset of
    # (0, 0) is skipped: that is how special tokens are reported.
    answer_token_indices = [
        idx
        for idx, (token_start, token_end) in enumerate(encoding["offset_mapping"])
        if (token_start, token_end) != (0, 0)
        and token_start < answer_end
        and token_end > answer_start
    ]

    if answer_token_indices:
        answer_token_ids = [tokens[i] for i in answer_token_indices]
        span_len = len(answer_token_ids)
        # Tag every occurrence of the answer's token sequence, not only the
        # one located through the character offsets.
        for i in range(len(tokens) - span_len + 1):
            if tokens[i:i + span_len] == answer_token_ids:
                labels[i] = 1  # B-ANS
                labels[i + 1:i + span_len] = [2] * (span_len - 1)  # I-ANS
    elif answer_text:  # Validation mode: the span matched no token
        print(f"WARNING: No tokens found for answer '{answer_text}' at position {answer_start}")
        print(f"Context: {context[max(0, answer_start-20):answer_end+20]}")

    return labels

In [None]:
def _with_bio_labels(df: pl.DataFrame) -> pl.DataFrame:
    """Return ``df`` with a ``labels`` column produced by ``sequence_labeler``."""
    # Bundle the three inputs into one struct so each row can be labelled in
    # a single Python call.
    label_expr = pl.struct(["context", "answer_start", "answer"]).map_elements(
        lambda row: sequence_labeler(row["context"], row["answer_start"], row["answer"]),
        return_dtype=pl.List(pl.Int8),
    )
    return df.with_columns(label_expr.alias("labels"))


# Attach per-token labels to both Korean splits
df_ko_train = _with_bio_labels(df_ko_train)
df_ko_val = _with_bio_labels(df_ko_val)

In [None]:
def prepare_data(df: pl.DataFrame) -> Dataset:
    """Convert a labelled Polars frame into a Hugging Face ``Dataset``.

    Args:
        df: Frame with ``question``, ``context`` and ``labels`` columns.

    Returns:
        A ``Dataset`` with ``question``, ``context`` and ``label`` columns.
        Note that the output key is the singular ``label``.
    """
    # Materialize every column as plain Python lists so ``Dataset.from_dict``
    # receives the same kind of input for all three columns.
    data_dict = {
        "question": df["question"].to_list(),
        "context": df["context"].to_list(),
        "label": df["labels"].to_list(),
    }
    return Dataset.from_dict(data_dict)

def tokenize_function(examples: Dataset, tokenizer: AutoTokenizer) -> Dataset:
    """Encode question/context pairs as fixed-length PyTorch inputs.

    Questions are passed as the first sequence and contexts as the second, so
    the tokenizer joins them with its separator token and adds the leading
    classification token itself. Every pair is truncated or padded to 512
    tokens.

    Args:
        examples: Mapping with ``question`` and ``context`` entries.
        tokenizer: Callable tokenizer applied to the pairs.

    Returns:
        Whatever ``tokenizer`` returns for the pairs.
    """
    encode_options = {
        "truncation": True,
        "padding": "max_length",
        "max_length": 512,
        "return_tensors": "pt",
    }
    questions, contexts = examples["question"], examples["context"]
    return tokenizer(questions, contexts, **encode_options)  # type: ignore

def _tokenize_batch(batch):
    """Tokenize one batch of question/context pairs with the mBERT tokenizer."""
    return tokenize_function(batch, mbert_tokenizer)


# Build the Hugging Face datasets for the Korean splits, then tokenize them in batches
train_dataset = prepare_data(df_ko_train)
val_dataset = prepare_data(df_ko_val)
tokenized_train = train_dataset.map(_tokenize_batch, batched=True)
tokenized_val = val_dataset.map(_tokenize_batch, batched=True)

In [None]:
from transformers import BertForTokenClassification
from transformers import Seq2SeqTrainingArguments, Trainer


# The model is a token classifier trained with the base ``Trainer``, so the
# plain ``TrainingArguments`` class is used rather than the seq2seq variant.
from transformers import TrainingArguments

# mBERT with a 3-class token-classification head (one class per BIO tag id)
model = BertForTokenClassification.from_pretrained(mbert_checkpoint, num_labels=3)

# NOTE(review): no evaluation strategy is configured here, so `eval_dataset`
# is presumably only used by an explicit `trainer.evaluate()` call. Confirm.
args = TrainingArguments(
    output_dir="./mbert-iob",
    learning_rate=3e-5,
    per_device_train_batch_size=8,
    num_train_epochs=3,
    weight_decay=0.01,
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_val,
    tokenizer=mbert_tokenizer,
)
trainer.train()


In [None]:
from transformers import BertForTokenClassification

def predict(
    question: "str | list[str]",
    context: "str | list[str]",
    model: BertForTokenClassification,
    tokenizer: AutoTokenizer
):
    """Predict a label id for every token of the encoded question/context pair(s).

    Args:
        question: One question or a list of questions.
        context: The matching context or list of contexts.
        model: Token-classification model; it is moved to the notebook's
            selected ``device`` and switched to eval mode.
        tokenizer: Tokenizer used to encode the pairs (truncated or padded to
            512 tokens).

    Returns:
        Predicted label ids (argmax over the label dimension) as nested
        Python lists. The result is squeezed, so a single example yields a
        flat list of ids, while a batch yields one list per example.
    """
    # Echo the input so the cell output shows what was scored
    print(question)

    inputs = tokenizer(
        question,
        context,
        truncation=True,
        padding="max_length",
        max_length=512,
        return_tensors="pt"
    ) # type: ignore

    # Run on the device selected earlier in the notebook (CUDA, MPS or CPU),
    # so an MPS-resident model is not pushed back to the CPU.
    inputs = {k: v.to(device) for k, v in inputs.items()}
    model = model.to(device) # type: ignore

    model.eval() # type: ignore
    with torch.no_grad():
        logits = model(**inputs).logits # type: ignore

    # The last dimension of the logits indexes the labels; take the best one
    # for each token.
    return torch.argmax(logits, dim=2).squeeze().tolist()

# Try inference on the first ten Korean validation examples. This uses the
# `model` instance handed to the Trainer, so the predictions reflect the
# fine-tuning instead of a freshly loaded base checkpoint.
predict(
    df_ko_val["question"][:10].to_list(),
    df_ko_val["context"][:10].to_list(),
    model,
    mbert_tokenizer,
)
    