<a href="https://colab.research.google.com/github/RosaMeyer/2023-lectures/blob/main/Week_4_QA_teluguContext_to_telugu.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# QA Generation using Telugu questions and English contexts to generate the Telugu answer

Use the subset of Telugu questions that have an answer_inlang value to train (or fine-tune) a model that receives the Telugu question and English context as input and generates the Telugu answer.

TODO:
Use answer_inlang for labels and English answers.
Only use answer_inlang for answerable questions that don't have an English answer or context.

## Imports

In [52]:
from utils import *

# !pip install evaluate
# %pip install sacrebleu

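import os
import random

import numpy as np
import pandas as pd  # needed for pd.notna when building target_text
import torch
from datasets import Dataset
import evaluate  # the sacrebleu and chrf metrics are loaded through evaluate below

# set_seed is defined in this notebook, so it is not imported from transformers
from transformers import (
    MT5Tokenizer,
    MT5ForConditionalGeneration,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)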

## Get, filter and prep data

In [53]:
training_data = get_training_data()
validation_data = get_validation_data()

# Filtering for Telugu only
te_train = training_data[training_data['lang'] == 'te']
te_val = validation_data[validation_data['lang'] == 'te']

In [54]:
# Build the 'target_text' column that is used as the label text.
# Use answer_inlang when it is present and non-empty; otherwise fall back to the English answer.

te_train = te_train.copy()
te_val = te_val.copy()

# te_train['target_text'] = te_train['answer_inlang'].fillna("").apply(lambda x: x.strip())
# te_val['target_text'] = te_val['answer_inlang'].fillna("").apply(lambda x: x.strip())

# Use answer_inlang if available and non-empty, otherwise use English answer
te_train['target_text'] = te_train.apply(
    lambda row: row['answer_inlang'].strip() if pd.notna(row['answer_inlang']) and row['answer_inlang'].strip() else row['answer'],
    axis=1
)

te_val['target_text'] = te_val.apply(
    lambda row: row['answer_inlang'].strip() if pd.notna(row['answer_inlang']) and row['answer_inlang'].strip() else row['answer'],
    axis=1
)
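
The fallback rule used for `target_text` can be illustrated on plain Python values. This is a minimal sketch: `choose_target` is a hypothetical helper that is not part of the notebook, the example strings are made up, and it assumes missing values arrive as `None` rather than as pandas NaN.

```python
def choose_target(answer_inlang, answer):
    """Return the stripped Telugu answer when present and non-blank, else the English answer."""
    if answer_inlang is not None and answer_inlang.strip():
        return answer_inlang.strip()
    return answer


# Made-up rows for illustration only
rows = [
    {"answer_inlang": " హైదరాబాద్ ", "answer": "Hyderabad"},  # Telugu answer present
    {"answer_inlang": None, "answer": "Portland"},             # missing, use English
    {"answer_inlang": "   ", "answer": "Seattle"},             # blank, use English
]

targets = [choose_target(r["answer_inlang"], r["answer"]) for r in rows]
print(targets)
```

Note that with this rule the labels mix Telugu and English text, which is one likely reason the model sometimes answers in English.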

## Model setup + load

We will fine-tune Google's mT5-small model to generate Telugu answers from a Telugu question and an English context.

In [55]:
MODEL_NAME = "google/mt5-small"
OUTPUT_DIR = "mt5_telugu_openqa"
MAX_SOURCE_LENGTH = 512
MAX_TARGET_LENGTH = 128
BATCH_SIZE = 8
NUM_EPOCHS = 5
LEARNING_RATE = 3e-4
SEED = 42

In [56]:
def set_seed(seed: int = SEED):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

set_seed()

In [57]:
# Use the GPU when one is available (for example on Colab); the Trainer moves the model to the right device itself.
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = MT5Tokenizer.from_pretrained(MODEL_NAME)
model = MT5ForConditionalGeneration.from_pretrained(MODEL_NAME)

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'T5Tokenizer'. 
The class this function is called from is 'MT5Tokenizer'.


In [58]:
# Ensure tokenizer has proper special tokens
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Update model config to match tokenizer
model.config.pad_token_id = tokenizer.pad_token_id
model.config.eos_token_id = tokenizer.eos_token_id
model.config.decoder_start_token_id = tokenizer.pad_token_id
model.config.forced_eos_token_id = tokenizer.eos_token_id

print(f"Tokenizer vocab size: {len(tokenizer)}")
print(f"Model vocab size: {model.config.vocab_size}")
print(f"Pad token ID: {tokenizer.pad_token_id}")
print(f"EOS token ID: {tokenizer.eos_token_id}")

# FIX THE MISMATCH - Resize model embeddings to match tokenizer
# Removes the last 12 token embeddings that the tokenizer doesn't know about
# Ensures the model and tokenizer speak the same "language" with the same vocabulary
model.resize_token_embeddings(len(tokenizer))
print(f"Resized model vocab to: {model.config.vocab_size}")

Tokenizer vocab size: 250100
Model vocab size: 250112
Pad token ID: 0
EOS token ID: 1
Resized model vocab to: 250100
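
To make the resize step concrete, the sketch below works through the arithmetic from the output above and mimics shrinking on a toy embedding table. It assumes that shrinking keeps the first rows and drops the trailing ones, which is our understanding of `resize_token_embeddings`; `shrink_embedding_rows` is a hypothetical helper, not a library function.

```python
def shrink_embedding_rows(rows, new_num_tokens):
    """Keep the first new_num_tokens rows; each row stands for one token's vector."""
    if new_num_tokens > len(rows):
        raise ValueError("this sketch only covers shrinking")
    return rows[:new_num_tokens]


MODEL_VOCAB = 250112      # model.config.vocab_size before resizing
TOKENIZER_VOCAB = 250100  # len(tokenizer)

dropped = MODEL_VOCAB - TOKENIZER_VOCAB
print(dropped)  # 12 trailing rows that no tokenizer ID can ever select

# Toy table with 8 one-dimensional "embeddings", shrunk to 5 rows
toy = [[float(i)] for i in range(8)]
small = shrink_embedding_rows(toy, 5)
print(small)  # [[0.0], [1.0], [2.0], [3.0], [4.0]]
```

Under that assumption, every ID the tokenizer can produce keeps its original vector, so the resize does not change the model's behaviour on valid inputs.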


## Defining preprocessing pipeline for our data

In [59]:
def make_input(example):
    q = example.get('question', '')
    ctx = example.get('context', '')
    return f"Question (Telugu): {q}\nContext (English): {ctx}\nAnswer (Telugu):"

def add_input_target(example):
    example['input_text'] = make_input(example)
    example['target_text'] = example.get('target_text', '')
    return example

def preprocess_fn(examples):
    inputs = examples['input_text']
    targets = examples['target_text']
    model_inputs = tokenizer(inputs, max_length=MAX_SOURCE_LENGTH, truncation=True)

    # Use text_target parameter for target tokenization
    labels = tokenizer(text_target=targets, max_length=MAX_TARGET_LENGTH, truncation=True)

    model_inputs['labels'] = labels['input_ids']
    return model_inputs

train_data = Dataset.from_pandas(te_train)
val_data = Dataset.from_pandas(te_val)

train_data = train_data.map(add_input_target)
val_data = val_data.map(add_input_target)

train_tokenized = train_data.map(
    preprocess_fn,
    batched=True,
    remove_columns=train_data.column_names
)

val_tokenized = val_data.map(
    preprocess_fn,
    batched=True,
    remove_columns=val_data.column_names
)

data_collator = DataCollatorForSeq2Seq(
    tokenizer,
    model=model,
    label_pad_token_id=-100,
    pad_to_multiple_of=8,
    padding=True,
    max_length=MAX_SOURCE_LENGTH
)

Map:   0%|          | 0/1355 [00:00<?, ? examples/s]

Map:   0%|          | 0/384 [00:00<?, ? examples/s]

Map:   0%|          | 0/1355 [00:00<?, ? examples/s]

Map:   0%|          | 0/384 [00:00<?, ? examples/s]
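
The collator settings `label_pad_token_id=-100` and `pad_to_multiple_of=8` can be sketched in plain Python. This is an illustration of the idea, not the `DataCollatorForSeq2Seq` implementation: `pad_labels` and `count_loss_tokens` are hypothetical helpers, the token IDs are illustrative, and right-side padding is assumed.

```python
def pad_labels(label_batch, pad_id=-100, multiple_of=8):
    """Right-pad every label list to the batch maximum, rounded up to a multiple."""
    longest = max(len(labels) for labels in label_batch)
    target_len = -(-longest // multiple_of) * multiple_of  # ceiling to the multiple
    return [labels + [pad_id] * (target_len - len(labels)) for labels in label_batch]


def count_loss_tokens(padded_batch, pad_id=-100):
    """Count positions that contribute to the loss (everything except pad_id)."""
    return sum(1 for labels in padded_batch for x in labels if x != pad_id)


# Two label sequences of length 2 and 6
batch = [[5176, 1], [78408, 12, 7, 9, 3, 1]]
padded = pad_labels(batch)

print([len(p) for p in padded])   # [8, 8]
print(count_loss_tokens(padded))  # 8
```

The value -100 matters because it is the default `ignore_index` of PyTorch's cross-entropy loss, so padded positions do not affect the gradient.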

In [60]:
# Validate that every token ID falls inside the tokenizer's vocabulary
def validate_tokens(dataset, raw_dataset, name):
    print(f"\nValidating {name}...")
    vocab_size = len(tokenizer)
    for i, example in enumerate(dataset):
        input_ids = example['input_ids']
        # Ignore the -100 padding value; default=0 below covers empty sequences
        labels = [x for x in example['labels'] if x != -100]

        max_input = max(input_ids, default=0)
        max_label = max(labels, default=0)

        if max_input >= vocab_size:
            print(f"ERROR in example {i}: input_ids max = {max_input} >= vocab_size {vocab_size}")
            # Show the raw text from the matching split to see what caused this
            print(f"Problematic input text: {raw_dataset[i]['input_text'][:200]}")
            break

        if max_label >= vocab_size:
            print(f"ERROR in example {i}: labels max = {max_label} >= vocab_size {vocab_size}")
            print(f"Problematic target text: {raw_dataset[i]['target_text'][:200]}")
            break

    print(f"{name} validation complete!")

validate_tokens(train_tokenized, train_data, "train")
validate_tokens(val_tokenized, val_data, "val")


Validating train...
train validation complete!

Validating val...
val validation complete!


In [64]:
# Check the actual tokenized data more carefully
print("Checking actual tokenized data...")
for i in range(min(5, len(train_tokenized))):
    example = train_tokenized[i]
    input_ids = example['input_ids']
    labels = example['labels']

    max_input = max(input_ids) if len(input_ids) > 0 else 0
    max_label = max([x for x in labels if x != -100]) if any(x != -100 for x in labels) else 0

    print(f"\nExample {i}:")
    print(f"  Input IDs: min={min(input_ids)}, max={max_input}, len={len(input_ids)}")
    print(f"  Labels: min={min(labels)}, max={max_label}, len={len(labels)}")

    if max_input >= tokenizer.vocab_size:
        print(f"  ⚠️ INVALID INPUT ID: {max_input} >= {tokenizer.vocab_size}")
        print(f"  Raw text: {train_data[i]['input_text'][:200]}")

    if max_label >= tokenizer.vocab_size:
        print(f"  ⚠️ INVALID LABEL ID: {max_label} >= {tokenizer.vocab_size}")
        print(f"  Raw text: {train_data[i]['target_text'][:200]}")

print(f"\nTokenizer vocab size: {len(tokenizer)}")

Checking actual tokenized data...

Example 0:
  Input IDs: min=1, max=230172, len=224
  Labels: min=1, max=5176, len=2

Example 1:
  Input IDs: min=1, max=220182, len=215
  Labels: min=1, max=78408, len=6

Example 2:
  Input IDs: min=1, max=213231, len=190
  Labels: min=1, max=68959, len=5

Example 3:
  Input IDs: min=1, max=200708, len=134
  Labels: min=1, max=733, len=2

Example 4:
  Input IDs: min=1, max=237614, len=269
  Labels: min=1, max=44201, len=4

Tokenizer vocab size: 250100


## Training setup - defining our evaluation metrics

We will use SacreBLEU and chrF, both loaded through the evaluate module.

In [61]:
sacrebleu = evaluate.load("sacrebleu")
chrf = evaluate.load("chrf")

def postprocess_text(preds, labels):
    preds = [pred.strip() for pred in preds]
    labels = [label.strip() for label in labels]
    return preds, labels

def compute_metrics(eval_pred):
    preds, labels = eval_pred

    # Some models return a tuple; the generated token IDs come first
    if isinstance(preds, tuple):
        preds = preds[0]

    # Replace -100 with pad_token_id BEFORE decoding, because the tokenizer
    # cannot decode -100. In some transformers versions the predictions are
    # padded with -100 as well, so both arrays are cleaned.
    preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)

    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

    # Both metrics expect a list of reference lists, one list per prediction
    references = [[l] for l in decoded_labels]
    bleu_score = sacrebleu.compute(predictions=decoded_preds, references=references)["score"]
    chrf_score = chrf.compute(predictions=decoded_preds, references=references)["score"]

    return {"sacrebleu": bleu_score, "chrf": chrf_score}
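
To give some intuition for what a character-level score such as chrF rewards, here is a deliberately simplified character n-gram F-score. It is not the implementation behind `evaluate.load("chrf")`: it averages a per-order F-score instead of following sacreBLEU's exact formula, and `char_ngrams` and `simple_chrf` are hypothetical names introduced only for this sketch.

```python
from collections import Counter


def char_ngrams(text, n):
    """Count character n-grams of a string after removing spaces."""
    chars = text.replace(" ", "")
    return Counter(chars[i:i + n] for i in range(len(chars) - n + 1))


def simple_chrf(prediction, reference, max_n=3, beta=2.0):
    """Average character n-gram F-beta over orders 1..max_n, scaled to 0-100."""
    scores = []
    for n in range(1, max_n + 1):
        pred, ref = char_ngrams(prediction, n), char_ngrams(reference, n)
        if not pred or not ref:
            continue  # string too short for this n-gram order
        overlap = sum((pred & ref).values())
        precision = overlap / sum(pred.values())
        recall = overlap / sum(ref.values())
        if precision + recall == 0:
            scores.append(0.0)
            continue
        scores.append((1 + beta**2) * precision * recall / (beta**2 * precision + recall))
    return 100.0 * sum(scores) / len(scores) if scores else 0.0


print(simple_chrf("Portland", "Portland"))  # exact match
print(simple_chrf("Chlera", "Cholera"))     # near miss still earns partial credit
print(simple_chrf("abc", "xyz"))            # no shared characters
```

Character-level matching gives partial credit for answers that are almost right, which is useful for a morphologically rich language like Telugu where word-level BLEU can be harsh on short answers.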

## Training the model

Finally, we train the model for 5 epochs with a learning rate of 0.0003 and a weight decay of 0.01.

In [62]:
print("Checking tokenization...")
sample = train_tokenized[0]
print(f"Input IDs range: {min(sample['input_ids'])} to {max(sample['input_ids'])}")
print(f"Label IDs range: {min([x for x in sample['labels'] if x != -100])} to {max([x for x in sample['labels'] if x != -100])}")
print(f"Tokenizer vocab size: {tokenizer.vocab_size}")

Checking tokenization...
Input IDs range: 1 to 230172
Label IDs range: 1 to 5176
Tokenizer vocab size: 250100


In [65]:
training_args = Seq2SeqTrainingArguments(
    output_dir=OUTPUT_DIR,
    overwrite_output_dir=True,
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    eval_strategy="epoch",
    logging_strategy="epoch",
    save_strategy="epoch",
    predict_with_generate=True,
    learning_rate=LEARNING_RATE,
    weight_decay=0.01,
    generation_max_length=MAX_TARGET_LENGTH,
    generation_num_beams=4,
    remove_unused_columns=False,
    logging_nan_inf_filter=False,
    fp16=False, # torch.cuda.is_available()
    bf16=False,
    metric_for_best_model="chrf",
    greater_is_better=True,
    report_to="none",
    save_total_limit=1,
    seed=SEED,
)

from transformers import GenerationConfig

# Create a generation config with constraints
generation_config = GenerationConfig(
    max_length=MAX_TARGET_LENGTH,
    num_beams=4,
    early_stopping=True,
    pad_token_id=tokenizer.pad_token_id,
    eos_token_id=tokenizer.eos_token_id,
    decoder_start_token_id=tokenizer.pad_token_id,
    forced_eos_token_id=tokenizer.eos_token_id,
    # Allow answers as short as one token and do not block repeated n-grams
    min_length=1,
    no_repeat_ngram_size=0,
)

# Set it on the model
model.generation_config = generation_config

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_tokenized,
    eval_dataset=val_tokenized,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

trainer.train()
print(trainer.evaluate())


| Epoch | Training Loss | Validation Loss | Sacrebleu | Chrf |
|---|---|---|---|---|
| 1 | 4.1901 | 2.828332 | 4.618975 | 12.439263 |
| 2 | 2.5749 | 2.343723 | 6.856505 | 16.131668 |
| 3 | 2.0042 | 2.136548 | 9.498912 | 18.785723 |
| 4 | 1.6823 | 2.003389 | 9.114259 | 20.668285 |
| 5 | 1.4872 | 1.979692 | 9.660872 | 20.928386 |




{'eval_loss': 1.9796921014785767, 'eval_sacrebleu': 9.660871850194463, 'eval_chrf': 20.928386152445942, 'eval_runtime': 17.7799, 'eval_samples_per_second': 21.597, 'eval_steps_per_second': 2.7, 'epoch': 5.0}
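
As a sanity check on the run above, the step counts can be worked out from the dataset sizes shown in the Map output (1355 training and 384 validation examples). This sketch assumes a single device, no gradient accumulation, and that the last partial batch is kept; `steps_per_epoch` is a hypothetical helper.

```python
import math


def steps_per_epoch(num_examples, batch_size):
    """Number of batches needed to see every example once, keeping the last partial batch."""
    return math.ceil(num_examples / batch_size)


TRAIN_EXAMPLES = 1355
VAL_EXAMPLES = 384
BATCH_SIZE = 8
NUM_EPOCHS = 5
EVAL_RUNTIME = 17.7799  # seconds, from the evaluation output

train_steps = steps_per_epoch(TRAIN_EXAMPLES, BATCH_SIZE)
total_steps = train_steps * NUM_EPOCHS
eval_steps = steps_per_epoch(VAL_EXAMPLES, BATCH_SIZE)

print(train_steps)                             # 170
print(total_steps)                             # 850
print(eval_steps)                              # 48
print(round(eval_steps / EVAL_RUNTIME, 1))     # 2.7
print(round(VAL_EXAMPLES / EVAL_RUNTIME, 3))   # 21.597
```

The last two values agree with the reported `eval_steps_per_second` and `eval_samples_per_second`, which suggests the evaluation did run over all 384 validation examples.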


## Preview some results

Finally, we see our fine-tuned model in action using a simple preview function that compares the gold answers for a few validation examples with the answers generated by the model.

In [66]:
train_data = Dataset.from_pandas(te_train)
val_data = Dataset.from_pandas(te_val)

def preview(n=5):
    """
    Preview generations with a prompt that matches training and
    decoding settings tuned for short, factual answers.
    """

    # Run generation on whichever device the trained model lives on
    device = next(model.parameters()).device

    rows = val_data.select(range(min(n, len(val_data))))
    for x in rows:
        # Must match the exact training prompt, so reuse make_input
        inp = make_input(x)

        enc = tokenizer(
            [inp],
            return_tensors="pt",
            truncation=True,
            max_length=MAX_SOURCE_LENGTH,  # same limit as training, so the context is not cut short
            padding=False,
        ).to(device)

        with torch.no_grad():
            out = model.generate(
                **enc,
                max_new_tokens=32,
                do_sample=True,
                top_p=0.9,
                top_k=40,
                temperature=0.8,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )

        pred = tokenizer.decode(out[0], skip_special_tokens=True).strip()

        print("Q (te):", x["question"])
        print("Gold (te):", x["answer_inlang"])
        print("Pred (te):", pred)
        print("Gold (en)", x["answer"])
        print("Context (en):", x["context"])
        print("----")
preview(5)

Q (te): ఒరెగాన్ రాష్ట్రంలోని అతిపెద్ద నగరం ఏది ?
Gold (te): None
Pred (te): Portland
Gold (en) Portland
Context (en): Portland is the largest city in the U.S. state of Oregon and the seat of Multnomah County. It is a major port in the Willamette Valley region of the Pacific Northwest, at the confluence of the Willamette and Columbia rivers. As of 2017, Portland had an estimated population of 647,805, making it the 26th-largest city in the United States, and the second-most populous in the Pacific Northwest (after Seattle). Approximately 2.4 million people live in the Portland metropolitan statistical area (MSA), making it the 25th most populous MSA in the United States. Its Combined Statistical Area (CSA) ranks 18th-largest with a population of around 3.2 million. Approximately 60% of Oregon's population resides within the Portland metropolitan area.
----
Q (te): కలరా వ్యాధిని మొదటగా ఏ దేశంలో కనుగొన్నారు ?
Gold (te): None
Pred (te): Chlera
Gold (en) Indian subcontinent
Context (en): Th