In [None]:
%pip install peft evaluate transformers Levenshtein ipywidgets
%pip install protobuf==3.20.3

In [None]:
import os
# Environment switches for the Hugging Face libraries. They are set in this early cell,
# ahead of the cells that import and use those libraries.
os.environ.update(
    {
        "TRANSFORMERS_DISABLE_CHAT_TEMPLATES": "1",
        "HF_HUB_DISABLE_SYMLINKS_WARNING": "1",
        "TRANSFORMERS_NO_ADDITIONAL_CHAT_TEMPLATES": "1",
    }
)

In [None]:
from datasets import load_dataset, load_from_disk
# from UQA.canine_utils import preprocess_uqa, lora_config, print_trainable_parameters, normalize_answer, exact_match_score, f1_score, edit_distance_score, gold_answer, decode_prediction
from transformers import CanineTokenizer
from peft import LoraConfig, TaskType, get_peft_model
import re
import string
from collections import Counter
import numpy as np
import Levenshtein

from transformers import TrainingArguments, Trainer, TrainerCallback
import json
from huggingface_hub import HfApi, notebook_login, whoami

In [None]:
# Log in to the Hugging Face Hub. Later cells upload checkpoints to the Hub
# (push_to_hub=True in training_args and HfApi.upload_folder in the callback).
notebook_login()
# whoami()

In [None]:
def _shuffled_subset(split, size, seed=42):
    """Return the first ``size`` rows of ``split`` after shuffling it with ``seed``."""
    return split.shuffle(seed=seed).select(range(size))


# Load UQA and keep fixed-size, seed-shuffled subsets of its train and validation splits.
uqa_dataset = load_dataset("uqa/UQA")
uqa_train = _shuffled_subset(uqa_dataset["train"], 40000)
uqa_val = _shuffled_subset(uqa_dataset["validation"], 10000)

In [None]:
from transformers import CanineTokenizer, CanineForQuestionAnswering
import torch
model_name = 'google/canine-s'

# Pick the best available device: CUDA first, then Apple MPS, otherwise CPU.
# Falling back straight to "mps" would hand a non-existent device to `.to(device)`
# on machines that have neither CUDA nor MPS.
_mps_backend = getattr(torch.backends, "mps", None)  # absent on older torch builds
if torch.cuda.is_available():
    device = torch.device("cuda")
elif _mps_backend is not None and _mps_backend.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Tokenizer and extractive-QA model are both loaded from the same pretrained checkpoint name.
tokenizer = CanineTokenizer.from_pretrained(model_name, use_fast=False, trust_remote_code=False)
model = CanineForQuestionAnswering.from_pretrained(model_name, trust_remote_code=False)

In [None]:
# preprocessors
# Default upper bound on the length of one encoded feature (special tokens + question +
# context window); preprocess_uqa may lower it to the tokenizer's or model's own limit.
MAX_SEQ_LENGTH = 384
# Default overlap, in tokens, between consecutive context windows: preprocess_uqa advances
# each window by (window size - DOC_STRIDE).
DOC_STRIDE = 64

def _build_byte_to_char_index(text):
    cumulative = [0]
    for char in text:
        cumulative.append(cumulative[-1] + len(char.encode("utf-8")))
    return cumulative

def _byte_to_char(cumulative_bytes, byte_index):
    from bisect import bisect_right
    position = bisect_right(cumulative_bytes, byte_index) - 1
    return max(position, 0)

# Safe preprocessing: enforce tokenizer/model limits
def preprocess_uqa(examples, tokenizer, max_length=MAX_SEQ_LENGTH, doc_stride=DOC_STRIDE, model_obj=None):
    """Convert a batch of UQA rows into fixed-length, sliding-window QA features.

    Every context is cut into overlapping windows; each window is combined with its
    question, padded to a common length and labelled with the answer's start/end token
    positions. Windows that do not fully contain the answer (and unanswerable rows) are
    labelled with the CLS position for both start and end.

    Answer positions are computed in characters and used directly as token positions, so
    this assumes the tokenizer emits one token per character (presumably true for the
    character-level CANINE tokenizer used in this notebook — confirm before reusing with
    another tokenizer).

    Args:
        examples: batch dict with parallel lists under "question", "context", "answer"
            and "answer_start".
        tokenizer: tokenizer providing ``encode``, ``num_special_tokens_to_add``,
            ``build_inputs_with_special_tokens``, ``create_token_type_ids_from_sequences``,
            ``cls_token_id`` and ``pad_token_id``.
        max_length: requested feature length; lowered to ``tokenizer.model_max_length``
            and/or the model's ``max_position_embeddings`` when those are smaller.
        doc_stride: number of context tokens shared by two consecutive windows.
        model_obj: optional model whose ``config.max_position_embeddings`` also caps the
            feature length.

    Returns:
        dict of equal-length lists: "input_ids", "attention_mask", "token_type_ids",
        "start_positions", "end_positions" and "overflow_to_sample_mapping". The last one
        holds, for each feature, the position of its source row *within this batch*
        (not within the whole dataset).
    """
    # Effective feature length: the smallest of the requested, tokenizer and model limits.
    tokenizer_max = getattr(tokenizer, "model_max_length", max_length)
    model_max = getattr(model_obj.config, "max_position_embeddings", None) if model_obj is not None else None
    max_allowed = max_length
    if tokenizer_max is not None and tokenizer_max > 0:
        max_allowed = min(max_allowed, tokenizer_max)
    if model_max is not None and model_max > 0:
        max_allowed = min(max_allowed, model_max)

    questions = [q.strip() for q in examples["question"]]
    contexts = examples["context"]
    answers = examples["answer"]
    answer_starts = examples["answer_start"]
    special_tokens = tokenizer.num_special_tokens_to_add(pair=True)

    encoded = {"input_ids": [], "attention_mask": [], "token_type_ids": [],
               "start_positions": [], "end_positions": [], "overflow_to_sample_mapping": []}

    for example_idx, (question, context, answer, answer_start) in enumerate(zip(questions, contexts, answers, answer_starts)):
        question_tokens = tokenizer.encode(question, add_special_tokens=False)
        context_tokens = tokenizer.encode(context, add_special_tokens=False)

        # Room left for context once the question and the special tokens are reserved.
        max_context_tokens = max_allowed - len(question_tokens) - special_tokens
        if max_context_tokens <= 0 or not context_tokens:
            # Nothing can be encoded for this row; it contributes no features.
            continue

        # Distance between window starts; fall back to non-overlapping windows when the
        # requested overlap is at least as large as the window itself.
        stride_tokens = max_context_tokens - doc_stride
        if stride_tokens <= 0:
            stride_tokens = max_context_tokens

        # Locate the answer once per example (it does not depend on the window).
        # answer_start is interpreted as a UTF-8 byte offset into the context, so the
        # answer's length must be measured in bytes too before both ends are mapped to
        # character indices; adding the character count would end the span too early for
        # multi-byte text.
        # NOTE(review): gold_answer slices the context with answer_start as a *character*
        # offset — confirm which unit the dataset actually uses.
        answer_span = None
        if answer and answer_start is not None and answer_start != -1:
            byte_map = _build_byte_to_char_index(context)
            last_answer_byte = answer_start + len(answer.encode("utf-8")) - 1
            answer_span = (
                _byte_to_char(byte_map, answer_start),
                _byte_to_char(byte_map, last_answer_byte),
            )

        span_start = 0
        context_length = len(context_tokens)
        while span_start < context_length:
            span_end = min(span_start + max_context_tokens, context_length)
            context_chunk = context_tokens[span_start:span_end]
            input_ids = tokenizer.build_inputs_with_special_tokens(question_tokens, context_chunk)
            token_type_ids = tokenizer.create_token_type_ids_from_sequences(question_tokens, context_chunk)
            attention_mask = [1] * len(input_ids)
            cls_index = input_ids.index(tokenizer.cls_token_id)
            # Index of the first context token: the chunk is followed by exactly one
            # trailing token in the built sequence.
            context_offset = len(input_ids) - len(context_chunk) - 1

            # Label the window with the answer only when the whole span lies inside it.
            if answer_span is not None and answer_span[0] >= span_start and answer_span[1] < span_end:
                start_pos = context_offset + (answer_span[0] - span_start)
                end_pos = context_offset + (answer_span[1] - span_start)
            else:
                start_pos = cls_index
                end_pos = cls_index

            # Defensive: ensure final length <= max_allowed by truncating if necessary.
            if len(input_ids) > max_allowed:
                input_ids = input_ids[:max_allowed]
                attention_mask = attention_mask[:max_allowed]
                token_type_ids = token_type_ids[:max_allowed]
                # If the labels were cut off, point them to CLS instead.
                if start_pos >= max_allowed or end_pos >= max_allowed:
                    start_pos = cls_index
                    end_pos = cls_index

            # Right-pad every feature to the same length.
            padding = max_allowed - len(input_ids)
            if padding > 0:
                pad_id = tokenizer.pad_token_id
                input_ids += [pad_id] * padding
                attention_mask += [0] * padding
                token_type_ids += [0] * padding

            encoded["input_ids"].append(input_ids)
            encoded["attention_mask"].append(attention_mask)
            encoded["token_type_ids"].append(token_type_ids)
            encoded["start_positions"].append(start_pos)
            encoded["end_positions"].append(end_pos)
            encoded["overflow_to_sample_mapping"].append(example_idx)

            if span_end == context_length:
                break
            span_start += stride_tokens

    return encoded




In [None]:
# LoRA config
# Adapter settings handed to get_peft_model when the LoRA model is built.
lora_config = LoraConfig(
    task_type=TaskType.QUESTION_ANS,
    # NOTE(review): an earlier comment here also mentioned "output.dense", but only
    # query/value/key are actually targeted.
    target_modules=["query", "value", "key"],
    modules_to_save=["qa_outputs"],
    r=32,   # changed from 8
    lora_alpha=64,  # changed from 32
    lora_dropout=0.1,
    bias="none",
)

def print_trainable_parameters(model):
    """Print the number of trainable parameters, the total, and the trainable percentage.

    ``model`` must expose ``named_parameters()`` yielding ``(name, parameter)`` pairs whose
    parameters provide ``numel()`` and ``requires_grad``.
    """
    sizes = [(param.numel(), param.requires_grad) for _, param in model.named_parameters()]
    all_param = sum(size for size, _ in sizes)
    trainable_params = sum(size for size, is_trainable in sizes if is_trainable)
    print(f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}")


In [None]:
# preprocess the train and val splits
def _preprocess_with_global_ids(examples, indices):
    """Run preprocess_uqa on one batch and re-key its features to split-level row numbers.

    preprocess_uqa numbers examples by their position inside the batch it receives, so
    with batched mapping the same ids repeat in every batch. ``indices`` (supplied by
    ``Dataset.map(..., with_indices=True)``) holds each batch row's index in the split
    being mapped; substituting it makes ``overflow_to_sample_mapping`` unique across the
    split, which evaluate_checkpoint relies on when it looks rows up in uqa_val.
    """
    features = preprocess_uqa(examples, tokenizer)
    features["overflow_to_sample_mapping"] = [
        indices[local_idx] for local_idx in features["overflow_to_sample_mapping"]
    ]
    return features


processed_train = uqa_train.map(_preprocess_with_global_ids, batched=True, with_indices=True, remove_columns=uqa_train.column_names)
processed_val = uqa_val.map(_preprocess_with_global_ids, batched=True, with_indices=True, remove_columns=uqa_val.column_names)

In [None]:
# Show the processed training split as this cell's output.
processed_train

In [None]:
# Show the processed validation split as this cell's output.
processed_val

In [None]:
# Write both processed splits to disk, then read them back so the following cells work
# from the on-disk copies.
_train_cache_dir = "/kaggle/working/cache/processed_train_uqa"
_val_cache_dir = "/kaggle/working/cache/processed_val_uqa"

processed_train.save_to_disk(_train_cache_dir)
processed_val.save_to_disk(_val_cache_dir)   # cached it

processed_train = load_from_disk(_train_cache_dir)
processed_val = load_from_disk(_val_cache_dir)

In [None]:
# NOTE(review): print_trainable_parameters is also defined in the LoRA config cell above;
# this later definition shadows it. Consider keeping only one copy.
def print_trainable_parameters(model):
    """Print the trainable parameter count, the total count and the trainable percentage."""
    counts = {True: 0, False: 0}
    for _name, parameter in model.named_parameters():
        counts[bool(parameter.requires_grad)] += parameter.numel()
    trainable_params = counts[True]
    all_param = counts[True] + counts[False]
    print(f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}")



In [None]:
# build LoRA model

# Wrap the base QA model with the adapters described by lora_config.
peft_model = get_peft_model(model, lora_config)
# Turn on gradient checkpointing on the wrapped model (training_args also requests it
# via gradient_checkpointing=True).
peft_model.gradient_checkpointing_enable()
# Report how many parameters are trainable after wrapping.
print_trainable_parameters(peft_model)

In [None]:
# evals


def normalize_answer(text):
    """Normalize an answer string for metric comparison.

    Lower-cases the text (``None`` is treated as empty), drops every character found in
    ``string.punctuation``, replaces the standalone words "a", "an" and "the" with a
    space, and collapses all whitespace runs to single spaces.
    """
    lowered = (text or "").lower()
    without_punctuation = "".join(filter(lambda ch: ch not in string.punctuation, lowered))
    without_articles = re.sub(r"\b(a|an|the)\b", " ", without_punctuation)
    return " ".join(without_articles.split())

def exact_match_score(prediction, ground_truth):
    """Return 1.0 when prediction and reference are equal after normalize_answer, else 0.0."""
    is_match = normalize_answer(prediction) == normalize_answer(ground_truth)
    return 1.0 if is_match else 0.0

def f1_score(prediction, ground_truth):
    """Token-level F1 between the normalized prediction and the normalized reference.

    Both strings go through normalize_answer and are split on whitespace. Two empty token
    lists score 1.0; exactly one empty list, or no shared tokens, scores 0.0. Otherwise
    the harmonic mean of token precision and recall is returned.
    """
    predicted = normalize_answer(prediction).split()
    expected = normalize_answer(ground_truth).split()

    if not predicted or not expected:
        return 1.0 if not predicted and not expected else 0.0

    # Multiset intersection counts each shared token at most as often as it occurs in both.
    overlap = sum((Counter(predicted) & Counter(expected)).values())
    if overlap == 0:
        return 0.0

    precision, recall = overlap / len(predicted), overlap / len(expected)
    return 2 * precision * recall / (precision + recall)

def edit_distance_score(prediction, ground_truth):
    """Similarity in [0, 1] derived from the Levenshtein distance of the normalized strings.

    Two empty strings score 1.0 and exactly one empty string scores 0.0. Otherwise the
    score is ``1 - distance / length of the longer normalized string``.
    """
    candidate = normalize_answer(prediction)
    reference = normalize_answer(ground_truth)

    if not candidate and not reference:
        return 1.0
    if not (candidate and reference):
        return 0.0

    longest = max(len(candidate), len(reference))
    if longest <= 0:
        return 1.0
    edits = Levenshtein.distance(candidate, reference)
    return 1.0 - (edits / longest)

def gold_answer(example):
    """Return the reference answer text for one dataset row, or "[CLS]" when there is none.

    The text is sliced out of ``example["context"]`` starting at ``answer_start`` and
    spanning ``len(answer)`` characters. Rows with an empty/missing answer, or with an
    ``answer_start`` that is missing, ``None`` or -1, yield the "[CLS]" sentinel.

    NOTE(review): the slice treats answer_start as a character offset, whereas
    preprocess_uqa converts the same field from a byte offset — confirm which unit the
    dataset uses.
    """
    answer_text = example.get("answer")
    start = example.get("answer_start", -1)
    if not answer_text or start is None or start == -1:
        return "[CLS]"
    context = example.get("context")
    return context[start: start + len(answer_text)]


def decode_prediction(input_ids, start_idx, end_idx, tokenizer=None):
    """Decode the predicted token span of one feature into text.

    Args:
        input_ids: list of token ids of the feature (must contain the CLS id).
        start_idx, end_idx: predicted span boundaries; they are swapped when reversed.
        tokenizer: tokenizer providing ``cls_token_id`` and ``decode``.

    Returns:
        The stripped decoded text, or the "[CLS]" sentinel when both boundaries point at
        the CLS token, the clamped span is empty, or the decoded text is blank.

    Raises:
        ValueError: if ``tokenizer`` is None.
    """
    if end_idx < start_idx:
        start_idx, end_idx = end_idx, start_idx
    if tokenizer is None:
        raise ValueError("Tokenizer must be provided for decoding.")

    cls_position = input_ids.index(tokenizer.cls_token_id)
    if start_idx == end_idx == cls_position:
        # Both boundaries on CLS means "no answer".
        return "[CLS]"

    # Clamp the span to the valid index range of the feature.
    first = start_idx if start_idx > 0 else 0
    last_valid = len(input_ids) - 1
    last = end_idx if end_idx < last_valid else last_valid
    if first > last:
        return "[CLS]"

    decoded = tokenizer.decode(input_ids[first:last + 1], skip_special_tokens=True).strip()
    return decoded or "[CLS]"


def evaluate_checkpoint(checkpoint_path=None, model_instance=None, eval_dataset=None):
    """Evaluate either a saved adapter checkpoint or an in-memory model on QA features.

    Args:
        checkpoint_path: path to a checkpoint folder holding the saved adapter; used only
            when ``model_instance`` is None.
        model_instance: an in-memory model (preferably a PeftModel or
            CanineForQuestionAnswering); takes precedence over ``checkpoint_path``.
        eval_dataset: processed features to evaluate; defaults to ``processed_val``.
            Each feature's ``overflow_to_sample_mapping`` is used as a row index into
            ``uqa_val`` to fetch the reference answer, so it must identify the source row
            within the whole validation split.

    Returns:
        dict with the mean "exact_match", "f1" and "edit_distance" (each in [0, 1]) over
        the evaluated examples; the same figures are printed as percentages.

    Raises:
        ValueError: if neither ``checkpoint_path`` nor ``model_instance`` is given.
    """
    if eval_dataset is None:
        eval_dataset = processed_val

    # If a model_instance is given, use it directly (avoid re-loading a fresh base model)
    if model_instance is not None:
        eval_model = model_instance
    elif checkpoint_path is not None:
        # Attach the saved adapter to a *fresh* base model. Wrapping the base model with
        # get_peft_model first and then loading the checkpoint on top of it would load
        # the adapter into a model that already carries injected adapter layers.
        from peft import PeftModel
        base_model = CanineForQuestionAnswering.from_pretrained(model_name, trust_remote_code=False)
        eval_model = PeftModel.from_pretrained(base_model, checkpoint_path)
    else:
        raise ValueError("evaluate_checkpoint needs either checkpoint_path or model_instance.")

    eval_model.to(device)

    eval_args = TrainingArguments(
        # Small evaluation config. Half precision is only requested when CUDA is present
        # so that the evaluation can also run on cpu/mps.
        output_dir="outputs/canine-s-uqa",
        per_device_eval_batch_size=16,
        dataloader_drop_last=False,
        fp16=torch.cuda.is_available(),
        bf16=False,
        report_to="none",
    )

    # Run evaluation via a lightweight Trainer so prediction loop is standard
    eval_trainer = Trainer(
        model=eval_model,
        args=eval_args,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
    )

    predictions = eval_trainer.predict(eval_dataset)
    start_logits, end_logits = predictions.predictions

    # For every source example keep the highest-scoring window prediction, where the
    # score is the sum of the best start logit and the best end logit of that window.
    best_predictions = {}
    for feature_index, feature in enumerate(eval_dataset):
        sample_idx = int(feature["overflow_to_sample_mapping"])
        input_ids = feature["input_ids"]
        start_idx = int(np.argmax(start_logits[feature_index]))
        end_idx = int(np.argmax(end_logits[feature_index]))
        score = float(start_logits[feature_index][start_idx] + end_logits[feature_index][end_idx])
        prediction_text = decode_prediction(input_ids, start_idx, end_idx, tokenizer=tokenizer)
        stored = best_predictions.get(sample_idx)
        if stored is None or score > stored[0]:
            best_predictions[sample_idx] = (score, prediction_text)

    em_scores = []
    f1_scores = []
    edit_dist_scores = []
    for sample_idx, (_, prediction_text) in best_predictions.items():
        reference = gold_answer(uqa_val[int(sample_idx)])
        em_scores.append(exact_match_score(prediction_text, reference))
        f1_scores.append(f1_score(prediction_text, reference))
        edit_dist_scores.append(edit_distance_score(prediction_text, reference))

    # Empty score lists (no features) are reported as 0.0 rather than NaN.
    em = float(np.mean(em_scores)) if em_scores else 0.0
    f1 = float(np.mean(f1_scores)) if f1_scores else 0.0
    edit_dist = float(np.mean(edit_dist_scores)) if edit_dist_scores else 0.0
    print(f"Examples evaluated: {len(em_scores)}")
    print(f"Exact Match: {em * 100:.2f}")
    print(f"F1: {f1 * 100:.2f}")
    print(f"Edit Distance (normalized): {edit_dist * 100:.2f}")
    return {"exact_match": em, "f1": f1, "edit_distance": edit_dist}


In [None]:
# Training configuration for the LoRA fine-tuning run.
training_args = TrainingArguments(
    output_dir="outputs/canine-s-uqa",

    per_device_train_batch_size=4,
    per_device_eval_batch_size=16,

    # 4 samples per step x 4 accumulation steps per optimizer update (per device).
    gradient_accumulation_steps=4,
    gradient_checkpointing=True,

    num_train_epochs=1,
    learning_rate=3e-4,
    weight_decay=0.01,
    # Built-in evaluation is switched off; metrics come from CustomEvalCallback, which
    # runs on every checkpoint save. eval_steps is kept but has no effect while
    # eval_strategy is "no".
    eval_strategy="no",
    eval_steps=500,
    save_strategy="steps",
    save_steps=500,
    logging_steps=25,
    # Request fp16 mixed precision only when CUDA is available: the notebook may select
    # a non-CUDA device, where an unconditional fp16=True is expected to be rejected.
    fp16=torch.cuda.is_available(),
    bf16=False,
    report_to="none",
    push_to_hub=True,
    hub_model_id="VohraAK/canine-s-uqa",
    hub_strategy="checkpoint",
    )

class CustomEvalCallback(TrainerCallback):
    """Trainer callback that evaluates, records and uploads every saved checkpoint.

    On each save it (1) computes metrics with ``eval_func``, (2) appends them to
    ``state.log_history``, (3) rewrites the checkpoint's ``trainer_state.json`` so the
    file contains them, and (4) uploads the checkpoint folder to ``args.hub_model_id``.
    Steps (3) and (4) are best-effort: failures are reported (when verbose) and ignored.

    Args:
        eval_func: callable taking ``checkpoint_path`` (positional or keyword),
            ``model_instance`` and ``eval_dataset`` and returning a dict with the keys
            "exact_match", "f1" and "edit_distance".
        eval_dataset: dataset passed to ``eval_func`` on every evaluation.
        use_in_memory_model: evaluate the live training model instead of reloading the
            checkpoint from disk.
        verbose: print progress messages.
    """

    def __init__(self, eval_func, eval_dataset, use_in_memory_model=True, verbose=True):
        self.eval_func = eval_func
        self.eval_dataset = eval_dataset
        self.use_in_memory_model = use_in_memory_model
        self.verbose = verbose
        # Optional trainer reference (may be assigned after the trainer exists). When it
        # is left unset, on_save uses the model handed to the callback instead.
        self.trainer = None

    def on_save(self, args, state, control, model=None, **kwargs):
        """Evaluate the checkpoint written at ``state.global_step``; returns ``control``."""
        checkpoint_path = f"{args.output_dir}/checkpoint-{state.global_step}"
        if self.verbose:
            print(f"\nüîç Running custom evaluation at step {state.global_step}...")

        # Prefer evaluating the in-memory model (fast + avoids re-loading). Fall back to
        # the `model` argument of this hook when no trainer reference was attached, so the
        # in-memory path does not silently depend on that attribute being set.
        live_model = self.trainer.model if self.trainer is not None else model
        if self.use_in_memory_model and live_model is not None:
            if self.verbose:
                print("Using in-memory model for evaluation (no reloading).")
            try:
                metrics = self.eval_func(checkpoint_path=None, model_instance=live_model, eval_dataset=self.eval_dataset)
            except Exception as e:
                print("‚ö†Ô∏è in-memory evaluation failed, falling back to checkpoint load:", e)
                metrics = self.eval_func(checkpoint_path, eval_dataset=self.eval_dataset)
        else:
            metrics = self.eval_func(checkpoint_path, eval_dataset=self.eval_dataset)

        # record metrics in state.log_history
        state.log_history.append({
            "step": state.global_step,
            "eval_exact_match": metrics.get("exact_match"),
            "eval_f1": metrics.get("f1"),
            "eval_edit_distance": metrics.get("edit_distance"),
        })

        if self.verbose:
            print(f"‚úÖ Step {state.global_step}: EM={metrics.get('exact_match',0)*100:.2f}, F1={metrics.get('f1',0)*100:.2f}, EditDist={metrics.get('edit_distance',0)*100:.2f}")

        # Update trainer_state.json to include custom metrics (best-effort).
        state_path = f"{checkpoint_path}/trainer_state.json"
        try:
            with open(state_path, 'r') as f:
                state_dict = json.load(f)
            state_dict['log_history'] = state.log_history
            with open(state_path, 'w') as f:
                json.dump(state_dict, f, indent=2)
            if self.verbose:
                print(f"üíæ Updated trainer_state.json with custom metrics")
        except Exception as e:
            if self.verbose:
                print(f"‚ö†Ô∏è  Warning: Could not update trainer_state.json: {e}")

        # Upload the checkpoint folder, metrics in the commit message (best-effort).
        try:
            if self.verbose:
                print(f"‚òÅÔ∏è  Pushing checkpoint-{state.global_step} to Hub...")
            api = HfApi()
            api.upload_folder(
                folder_path=checkpoint_path,
                repo_id=args.hub_model_id,
                path_in_repo=f"checkpoint-{state.global_step}",
                commit_message=f"Add checkpoint {state.global_step} (EM={metrics.get('exact_match',0)*100:.1f}%, F1={metrics.get('f1',0)*100:.1f}%)",
                repo_type="model"
            )
            if self.verbose:
                print(f"‚úÖ Pushed checkpoint-{state.global_step} to Hub")
        except Exception as e:
            if self.verbose:
                print(f"‚ö†Ô∏è  Warning: Could not push to Hub: {e}")

        return control

In [None]:
# Callback that evaluates the model on processed_val at every checkpoint save.
trainer_cb = CustomEvalCallback(evaluate_checkpoint, processed_val, use_in_memory_model=True)

trainer = Trainer(
    model=peft_model,
    args=training_args,
    train_dataset=processed_train,
    eval_dataset=processed_val,
    callbacks=[trainer_cb],
)

# Hand the callback its trainer reference. CustomEvalCallback initialises `trainer` to
# None and expects it to be set once the trainer exists; it reads `trainer.model` for
# in-memory evaluation.
trainer_cb.trainer = trainer


In [None]:
# Start fine-tuning with the trainer configured above.
trainer.train()