In [1]:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Pure-PyTorch fine-tuning for DeBERTa-v3-Large on SQuAD v2.0 (span extraction)
- Faithful to typical HF settings used in your notebook:
  * model: microsoft/deberta-v3-large
  * max_length=384, doc_stride=128
  * per_device_train_batch_size=4, per_device_eval_batch_size=8
  * fp16 mixed precision
  * save/eval every ~1/4 epoch
- Rich, explicit print statements for visibility at every step.
- Checkpointing and resume support.
- Lightweight evaluation: validation loss + feature-level start/end token accuracy
  (fast and informative; EM/F1 text metrics require heavier post-processing).
"""

'\nPure-PyTorch fine-tuning for DeBERTa-v3-Large on SQuAD v2.0 (span extraction)\n- Faithful to typical HF settings used in your notebook:\n  * model: microsoft/deberta-v3-large\n  * max_length=384, doc_stride=128\n  * per_device_train_batch_size=4, per_device_eval_batch_size=8\n  * fp16 mixed precision\n  * save/eval every ~1/4 epoch\n- Rich, explicit print statements for visibility at every step.\n- Checkpointing and resume support.\n- Lightweight evaluation: validation loss + feature-level start/end token accuracy\n  (fast and informative; EM/F1 text metrics require heavier post-processing).\n'

In [2]:
import os
import math
import time
import json
from dataclasses import dataclass
from typing import Dict, List, Optional, Any

import torch
from torch.utils.data import DataLoader, Dataset
from torch.optim import AdamW

from transformers import (
    AutoTokenizer,
    AutoModelForQuestionAnswering,
    get_linear_schedule_with_warmup,
    DataCollatorWithPadding,
)

from datasets import load_dataset
import numpy as np
import random


2025-11-12 09:34:28.815775: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1762940069.079775      19 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1762940069.146808      19 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

In [3]:
# ---------------------------
# Reproducibility helpers
# ---------------------------
def set_seed(seed: int = 42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [4]:
from dataclasses import dataclass
import os

@dataclass
class Config:
    model_name: str = "microsoft/deberta-v3-large"
    output_dir: str = "./deberta-squad-v2-pytorch"
    checkpoint_dir: str = "./deberta-squad-v2-pytorch/ckpts"

    # ✅ Cut sequence and batch length to fit memory
    max_length: int = 384        
    doc_stride: int = 64

    train_batch_size: int = 1    # per-GPU batch
    eval_batch_size: int = 1
    gradient_accumulation_steps: int = 4   # effective batch = 4

    lr: float = 2e-5
    weight_decay: float = 0.01
    num_epochs: int = 1
    warmup_ratio: float = 0.06
    logging_steps: int = 100
    fp16: bool = True

    eval_per_epoch: int = 4
    save_per_epoch: int = 2

    save_total_limit: int = 2
    seed: int = 42



# Instantiate & make dirs
cfg = Config()
os.makedirs(cfg.output_dir, exist_ok=True)
os.makedirs(cfg.checkpoint_dir, exist_ok=True)


In [5]:

# ---------------------------
# Load data and tokenizer
# ---------------------------
print("[INFO] Loading tokenizer and datasets...")
tokenizer = AutoTokenizer.from_pretrained(cfg.model_name, use_fast=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token if tokenizer.eos_token else tokenizer.cls_token

print("[INFO] Loading SQuAD v2 dataset (train/validation)...")
raw_datasets = load_dataset("squad_v2")

max_length = cfg.max_length
doc_stride = cfg.doc_stride

def preprocess_function(examples):
    # Strip left spaces in questions (common clean-up)
    questions = [q.lstrip() for q in examples["question"]]

    tokenized = tokenizer(
        questions,
        examples["context"],
        truncation="only_second",
        max_length=max_length,
        stride=doc_stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    sample_mapping = tokenized.pop("overflow_to_sample_mapping")
    offset_mapping = tokenized["offset_mapping"]
    start_positions = []
    end_positions = []

    for i, offsets in enumerate(offset_mapping):
        input_ids = tokenized["input_ids"][i]
        cls_index = input_ids.index(tokenizer.cls_token_id) if tokenizer.cls_token_id in input_ids else 0

        sequence_ids = tokenized.sequence_ids(i)
        sample_idx = sample_mapping[i]
        answers = examples["answers"][sample_idx]

        # Set offsets outside the context to None
        # (so we can easily skip them)
        for k, s_id in enumerate(sequence_ids):
            if s_id != 1:  # 1 corresponds to the context in QA tokenizers
                offsets[k] = (None, None)

        if len(answers["answer_start"]) == 0:
            # No answer (SQuAD v2): position the span at CLS
            start_positions.append(cls_index)
            end_positions.append(cls_index)
            continue

        start_char = answers["answer_start"][0]
        end_char = start_char + len(answers["text"][0])

        # Find start/end token indices in the context
        token_start_index = 0
        while token_start_index < len(sequence_ids) and sequence_ids[token_start_index] != 1:
            token_start_index += 1
        token_end_index = len(sequence_ids) - 1
        while token_end_index >= 0 and sequence_ids[token_end_index] != 1:
            token_end_index -= 1

        # If the answer is not fully inside the context span, use CLS
        if token_start_index > token_end_index:
            start_positions.append(cls_index)
            end_positions.append(cls_index)
            continue

        # Otherwise move the token_start_index and token_end_index to the two ends of the answer
        # Note: we could still find mismatch because of stride; we handle fallback to CLS.
        start_index = token_start_index
        end_index = token_end_index
        found_start = False
        found_end = False
        for idx in range(token_start_index, token_end_index + 1):
            if offsets[idx] is None:
                continue
            start, _end = offsets[idx]
            if start is not None and start <= start_char < _end:
                start_index = idx
                found_start = True
                break
        for idx in range(token_end_index, token_start_index - 1, -1):
            if offsets[idx] is None:
                continue
            start, _end = offsets[idx]
            if start is not None and start < end_char <= _end:
                end_index = idx
                found_end = True
                break

        if not (found_start and found_end):
            # The answer doesn't fit in the current span
            start_positions.append(cls_index)
            end_positions.append(cls_index)
        else:
            start_positions.append(start_index)
            end_positions.append(end_index)

    tokenized["start_positions"] = start_positions
    tokenized["end_positions"] = end_positions
    return tokenized

print("[INFO] Tokenizing train split (this can take a bit)...")
tokenized_train = raw_datasets["train"].map(
    preprocess_function,
    batched=True,
    remove_columns=raw_datasets["train"].column_names,
    desc="Tokenizing train",
)

print("[INFO] Tokenizing validation split...")
tokenized_val = raw_datasets["validation"].map(
    preprocess_function,
    batched=True,
    remove_columns=raw_datasets["validation"].column_names,
    desc="Tokenizing validation",
)

[INFO] Loading tokenizer and datasets...


tokenizer_config.json:   0%|          | 0.00/52.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/580 [00:00<?, ?B/s]

spm.model:   0%|          | 0.00/2.46M [00:00<?, ?B/s]



[INFO] Loading SQuAD v2 dataset (train/validation)...


README.md: 0.00B [00:00, ?B/s]

squad_v2/train-00000-of-00001.parquet:   0%|          | 0.00/16.4M [00:00<?, ?B/s]

squad_v2/validation-00000-of-00001.parqu(…):   0%|          | 0.00/1.35M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/130319 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/11873 [00:00<?, ? examples/s]

[INFO] Tokenizing train split (this can take a bit)...


Tokenizing train:   0%|          | 0/130319 [00:00<?, ? examples/s]

[INFO] Tokenizing validation split...


Tokenizing validation:   0%|          | 0/11873 [00:00<?, ? examples/s]

In [6]:
# ---------------------------
# Torch Dataset wrappers
# ---------------------------
class QADataset(Dataset):
    def __init__(self, encodings):
        self.encodings = encodings

    def __len__(self):
        return len(self.encodings)

    def __getitem__(self, idx):
        example = self.encodings[idx]
        return {k: torch.tensor(v) for k, v in example.items() if k != "offset_mapping"}

train_dataset = QADataset(tokenized_train)
eval_dataset  = QADataset(tokenized_val)

data_collator = DataCollatorWithPadding(tokenizer=tokenizer, pad_to_multiple_of=8 if cfg.fp16 else None)

In [7]:
import torch
from transformers import AutoTokenizer, AutoModelForQuestionAnswering

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

base_model_name = "microsoft/deberta-v3-large"
checkpoint_path = "/kaggle/input/deberta-v3-large/step24648_epoch0.pt"

print("[INFO] Loading base architecture:", base_model_name)
model = AutoModelForQuestionAnswering.from_pretrained(base_model_name)

# -----------------------------
# 1. Load full checkpoint
# -----------------------------
print("[INFO] Loading checkpoint:", checkpoint_path)
checkpoint = torch.load(checkpoint_path, map_location=device)

# If the actual weights are in checkpoint["model"]
if "model" in checkpoint:
    state_dict = checkpoint["model"]
else:
    state_dict = checkpoint   # maybe already a pure dict

# -----------------------------
# 2. Clean prefixes if needed
# -----------------------------
new_state_dict = {}

for k, v in state_dict.items():

    # remove "module." (DataParallel)
    if k.startswith("module."):
        k = k.replace("module.", "")

    # remove "model." prefix if present
    if k.startswith("model."):
        k = k.replace("model.", "")

    new_state_dict[k] = v

# -----------------------------
# 3. Load weights
# -----------------------------
missing, unexpected = model.load_state_dict(new_state_dict, strict=False)

print("[INFO] Missing keys:", missing)
print("[INFO] Unexpected keys:", unexpected)

model.to(device)

print("[INFO] Checkpoint successfully loaded!")



[INFO] Loading base architecture: microsoft/deberta-v3-large


pytorch_model.bin:   0%|          | 0.00/874M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/874M [00:00<?, ?B/s]

Some weights of DebertaV2ForQuestionAnswering were not initialized from the model checkpoint at microsoft/deberta-v3-large and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


[INFO] Loading checkpoint: /kaggle/input/deberta-v3-large/step24648_epoch0.pt
[INFO] Missing keys: []
[INFO] Unexpected keys: []
[INFO] Checkpoint successfully loaded!


In [8]:

# ---------------------------
# Optimizer & Scheduler
# ---------------------------
# Weight decay for all but LayerNorm/bias
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
    {
        "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
        "weight_decay": cfg.weight_decay,
    },
    {
        "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
        "weight_decay": 0.0,
    },
]
optimizer = AdamW(optimizer_grouped_parameters, lr=cfg.lr)

In [9]:

# ---------------------------
# DataLoaders
# ---------------------------
train_loader = DataLoader(
    train_dataset,
    batch_size=cfg.train_batch_size,
    shuffle=True,
    collate_fn=data_collator,
    pin_memory=True,
)

eval_loader = DataLoader(
    eval_dataset,
    batch_size=cfg.eval_batch_size,
    shuffle=False,
    collate_fn=data_collator,
    pin_memory=True,
)

In [10]:
# ---------------------------
# Scheduler (linear warmup)
# ---------------------------
num_update_steps_per_epoch = math.ceil(len(train_loader) / cfg.gradient_accumulation_steps)
t_total = num_update_steps_per_epoch * cfg.num_epochs
num_warmup_steps = int(cfg.warmup_ratio * t_total)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=t_total)

In [11]:
# ---------------------------
# AMP scaler
# ---------------------------
scaler = torch.cuda.amp.GradScaler(enabled=cfg.fp16)

  scaler = torch.cuda.amp.GradScaler(enabled=cfg.fp16)


In [12]:
# ---------------------------
# Save/Load checkpoints
# ---------------------------
def save_checkpoint(step: int, epoch: int, best_val: Optional[float] = None):
    ckpt_path = os.path.join(cfg.checkpoint_dir, f"step{step}_epoch{epoch}.pt")
    state = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict(),
        "scaler": scaler.state_dict(),
        "step": step,
        "epoch": epoch,
        "best_val": best_val,
        "config": cfg.__dict__,
    }
    torch.save(state, ckpt_path)
    print(f"[CKPT] Saved checkpoint: {ckpt_path}")
    # Rotate: keep only last N
    ckpts = sorted([f for f in os.listdir(cfg.checkpoint_dir) if f.endswith(".pt")], key=lambda x: os.path.getmtime(os.path.join(cfg.checkpoint_dir, x)))
    excess = len(ckpts) - cfg.save_total_limit
    for i in range(excess):
        to_del = os.path.join(cfg.checkpoint_dir, ckpts[i])
        try:
            os.remove(to_del)
            print(f"[CKPT] Removed old checkpoint: {to_del}")
        except Exception as e:
            print(f"[CKPT] Failed to remove {to_del}: {e}")


def load_latest_checkpoint():
    ckpts = [os.path.join(cfg.checkpoint_dir, f) for f in os.listdir(cfg.checkpoint_dir) if f.endswith(".pt")]
    if not ckpts:
        return None
    latest = max(ckpts, key=os.path.getmtime)
    print(f"[CKPT] Found latest checkpoint: {latest}. Resuming...")
    state = torch.load(latest, map_location="cpu")
    model.load_state_dict(state["model"])
    optimizer.load_state_dict(state["optimizer"])
    scheduler.load_state_dict(state["scheduler"])
    try:
        scaler.load_state_dict(state["scaler"])
    except Exception:
        print("[CKPT] Skipping scaler state load (mismatch or missing).")
    return state

In [13]:

# ---------------------------
# Evaluation helpers
# ---------------------------
@torch.no_grad()
def evaluate() -> Dict[str, float]:
    model.eval()
    total_loss = 0.0
    n_batches = 0
    token_hits = 0
    token_total = 0

    for batch in eval_loader:
        for k in batch:
            batch[k] = batch[k].to(device, non_blocking=True)

        outputs = model(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            token_type_ids=batch.get("token_type_ids", None),
            start_positions=batch["start_positions"],
            end_positions=batch["end_positions"],
        )
        loss = outputs.loss
        start_logits = outputs.start_logits
        end_logits = outputs.end_logits

        total_loss += loss.item()
        n_batches += 1

        # feature-level token accuracy: did we predict the correct start & end token positions?
        pred_start = start_logits.argmax(dim=-1)
        pred_end = end_logits.argmax(dim=-1)
        token_hits += (pred_start == batch["start_positions"]).sum().item()
        token_hits += (pred_end == batch["end_positions"]).sum().item()
        token_total += 2 * batch["input_ids"].size(0)

    avg_loss = total_loss / max(n_batches, 1)
    token_acc = token_hits / max(token_total, 1)
    model.train()
    return {"val_loss": avg_loss, "feature_token_accuracy": token_acc}

In [14]:
def train():
    set_seed(cfg.seed)
    start_epoch = 0
    global_step = 0
    best_val = float("inf")

    # ----------------------------
    # Resume from checkpoint if any
    # ----------------------------
    resume_state = load_latest_checkpoint()
    if resume_state is not None:
        global_step = resume_state["step"]
        start_epoch = resume_state["epoch"]
        best_val = resume_state.get("best_val", best_val)
        print(f"[RESUME] epoch={start_epoch}, step={global_step}, best_val={best_val:.4f}")
    else:
        print("[RESUME] No checkpoint found. Starting fresh.")

    # ----------------------------
    # Setup training
    # ----------------------------
    model.train()
    print(f"[SETUP] num_epochs={cfg.num_epochs} | steps/epoch={num_update_steps_per_epoch} | total_steps={t_total}")

    eval_every = max(1, num_update_steps_per_epoch // cfg.eval_per_epoch)
    save_every = max(1, num_update_steps_per_epoch // cfg.save_per_epoch)
    print(f"[SETUP] Evaluating every {eval_every} steps (~quarter-epoch)")
    print(f"[SETUP] Saving checkpoints every {save_every} steps (~half-epoch)")

    running_loss = 0.0
    running_count = 0
    last_log_time = time.time()

    # ----------------------------
    # Epoch loop
    # ----------------------------
    for epoch in range(start_epoch, cfg.num_epochs):
        print(f"\n[EPOCH {epoch+1}/{cfg.num_epochs}] Starting...")
        epoch_start = time.time()

        for step, batch in enumerate(train_loader):
            # Move batch to GPU
            for k in batch:
                batch[k] = batch[k].to(device, non_blocking=True)

            # Forward pass (mixed precision)
            with torch.cuda.amp.autocast(enabled=cfg.fp16):
                outputs = model(
                    input_ids=batch["input_ids"],
                    attention_mask=batch["attention_mask"],
                    token_type_ids=batch.get("token_type_ids", None),
                    start_positions=batch["start_positions"],
                    end_positions=batch["end_positions"],
                )
                loss = outputs.loss / cfg.gradient_accumulation_steps

            # Backprop
            scaler.scale(loss).backward()
            running_loss += loss.item()
            running_count += 1

            # Update after accumulation
            if (step + 1) % cfg.gradient_accumulation_steps == 0:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad(set_to_none=True)
                scheduler.step()
                global_step += 1

                # ----------------------------
                # Logging
                # ----------------------------
                if global_step % cfg.logging_steps == 0:
                    avg_loss = running_loss / max(running_count, 1)
                    elapsed = time.time() - last_log_time
                    current_lr = scheduler.get_last_lr()[0]
                    print(f"[LOG] step={global_step:6d} | lr={current_lr:.6e} | "
                          f"avg_loss={avg_loss:.4f} | {elapsed:.1f}s since last log")
                    running_loss, running_count = 0.0, 0
                    last_log_time = time.time()

                # ----------------------------
                # Evaluation (4× per epoch)
                # ----------------------------
                if global_step % eval_every == 0:
                    print(f"[EVAL] step={global_step} ...")
                    metrics = evaluate()
                    val_loss = metrics["val_loss"]
                    token_acc = metrics.get("feature_token_accuracy", 0)
                    print(f"[EVAL] step={global_step} | val_loss={val_loss:.4f} | token_acc={token_acc:.4f}")

                    # Track & save best model
                    if val_loss < best_val:
                        best_val = val_loss
                        best_model_path = os.path.join(cfg.output_dir, "best_model.pt")
                        torch.save(model.state_dict(), best_model_path)
                        print(f"[BEST] New best val_loss={best_val:.4f} saved to {best_model_path}")

                # ----------------------------
                # Save checkpoint (2× per epoch)
                # ----------------------------
                if global_step % save_every == 0 or (step + 1 == len(train_loader)):
                    save_checkpoint(global_step, epoch, best_val=best_val)
                    print(f"[SAVE] Checkpoint saved at step {global_step}")

        epoch_time = (time.time() - epoch_start) / 60
        print(f"[EPOCH {epoch+1}] completed in {epoch_time:.2f} min")

    # ----------------------------
    # Final save
    # ----------------------------
    final_model_path = os.path.join(cfg.output_dir, "final_model.pt")
    torch.save(model.state_dict(), final_model_path)
    print(f"[DONE] Training complete. Final model saved to {final_model_path}")


In [15]:
train()

[RESUME] No checkpoint found. Starting fresh.
[SETUP] num_epochs=1 | steps/epoch=32859 | total_steps=32859
[SETUP] Evaluating every 8214 steps (~quarter-epoch)
[SETUP] Saving checkpoints every 16429 steps (~half-epoch)

[EPOCH 1/1] Starting...


  with torch.cuda.amp.autocast(enabled=cfg.fp16):


[LOG] step=   100 | lr=1.014713e-06 | avg_loss=0.1798 | 84.3s since last log
[LOG] step=   200 | lr=2.029427e-06 | avg_loss=0.2025 | 83.9s since last log
[LOG] step=   300 | lr=3.044140e-06 | avg_loss=0.1389 | 84.4s since last log
[LOG] step=   400 | lr=4.058853e-06 | avg_loss=0.1621 | 83.9s since last log
[LOG] step=   500 | lr=5.073567e-06 | avg_loss=0.1740 | 83.6s since last log
[LOG] step=   600 | lr=6.088280e-06 | avg_loss=0.1598 | 83.8s since last log
[LOG] step=   700 | lr=7.102993e-06 | avg_loss=0.1362 | 83.8s since last log
[LOG] step=   800 | lr=8.117707e-06 | avg_loss=0.1536 | 83.6s since last log
[LOG] step=   900 | lr=9.132420e-06 | avg_loss=0.1453 | 83.9s since last log
[LOG] step=  1000 | lr=1.014713e-05 | avg_loss=0.1506 | 83.8s since last log
[LOG] step=  1100 | lr=1.116185e-05 | avg_loss=0.1683 | 83.8s since last log
[LOG] step=  1200 | lr=1.217656e-05 | avg_loss=0.1433 | 83.6s since last log
[LOG] step=  1300 | lr=1.319127e-05 | avg_loss=0.1705 | 83.8s since last log