In [3]:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
PII NER (BIO) with DeBERTa on Kaggle 'pii-detection-removal-from-educational-data' (FAST)

Zero-arg Kaggle usage:
----------------------
python train_pii_ner.py

Optional:
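- Set the PII_DATA_PATH environment variable to a folder or .zip containing train.json to override auto-detect (an explicit --data-path takes precedence).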
- Inference:
  python train_pii_ner.py --mode infer --text "Email me at alice@school.edu"

Speed notes:
- Fused AdamW, TF32, fp16, group_by_length, persistent workers, pad_to_multiple_of=8, parallel tokenization.
- Use --compile on A100 for extra speed after the first-epoch compile warmup.
"""

import argparse
import gc
import inspect
import io
import json
import os
import random
import re
import shutil
import tempfile
import zipfile
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import torch
from datasets import Dataset
from seqeval.metrics import precision_score, recall_score, f1_score, accuracy_score
from transformers import (
    AutoTokenizer,
    AutoConfig,
    AutoModelForTokenClassification,
    DebertaV2ForTokenClassification,   # direct class (safe path for DeBERTa v2/v3)
    DataCollatorForTokenClassification,
    TrainingArguments,
    Trainer,
    EarlyStoppingCallback,
    set_seed,
    pipeline as hf_pipeline,
)

# Process-wide runtime knobs, applied once at import time.
# TOKENIZERS_PARALLELISM only gets a default; a value already present in the
# environment is left alone.
if "TOKENIZERS_PARALLELISM" not in os.environ:
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
try:
    # Opt in to TF32 for CUDA matmuls and cuDNN when a GPU is visible, then
    # request "high" float32 matmul precision regardless of device.
    if torch.cuda.is_available():
        setattr(torch.backends.cuda.matmul, "allow_tf32", True)
        setattr(torch.backends.cudnn, "allow_tf32", True)
    torch.set_float32_matmul_precision("high")
except Exception:
    # Best effort only: these are speed hints, never a reason to abort.
    pass


# --------------------------------
# Kaggle-friendly auto path finder
# --------------------------------
def auto_find_data_path() -> str:
    """Locate the dataset folder (or zip archive) that provides ``train.json``.

    Search order:
      1. The ``PII_DATA_PATH`` environment variable, when it names an existing path.
      2. The fixed Kaggle mount ``/kaggle/input/pii-detection-removal-from-educational-data``.
      3. The first directory under ``/kaggle/input`` that contains ``train.json``.
      4. The first ``.zip`` under ``/kaggle/input`` with a member named ``train.json``.
      5. The current working directory, when it contains ``train.json``.

    Returns:
        Path to a directory or to a zip file.

    Raises:
        FileNotFoundError: if none of the locations above yields ``train.json``.
    """
    hint = os.environ.get("PII_DATA_PATH", "").strip()
    if hint:
        if os.path.exists(hint):
            return hint
        # An override that points nowhere is most likely a typo; report it
        # instead of silently using whatever auto-detection happens to find.
        print(f"[warn] PII_DATA_PATH={hint!r} does not exist; falling back to auto-detection.")

    common = "/kaggle/input/pii-detection-removal-from-educational-data"
    if os.path.isdir(common) and os.path.exists(os.path.join(common, "train.json")):
        return common

    input_root = "/kaggle/input"
    if os.path.isdir(input_root):
        # Single traversal: directories win immediately, zip archives are only
        # remembered and inspected afterwards (same priority as two walks).
        zip_candidates = []
        for root, dirs, files in os.walk(input_root):
            dirs.sort()  # make the traversal order (and thus the pick) deterministic
            if "train.json" in files:
                return root
            zip_candidates.extend(
                os.path.join(root, name)
                for name in sorted(files)
                if name.lower().endswith(".zip")
            )
        for zpath in zip_candidates:
            try:
                with zipfile.ZipFile(zpath) as z:
                    if any(os.path.basename(n) == "train.json" for n in z.namelist()):
                        return zpath
            except zipfile.BadZipFile:
                # A corrupt archive is skipped; another candidate may still match.
                continue

    if os.path.exists("train.json"):
        return os.getcwd()

    raise FileNotFoundError(
        "Could not auto-locate 'train.json'. "
        "Attach the Kaggle dataset or set PII_DATA_PATH=/path/to/folder_or_zip."
    )


def resolve_data_paths(data_path: str) -> Tuple[str, Optional[str]]:
    """Turn a dataset folder or zip archive into concrete JSON file paths.

    Args:
        data_path: A directory that directly contains ``train.json``, or a zip
            archive with a member named ``train.json`` at any depth.

    Returns:
        ``(train_json, test_json)`` where ``test_json`` is ``None`` when no
        ``test.json`` sits next to ``train.json``. For a zip input the paths
        point into a fresh temporary directory (prefix ``pii_zip_``) that is
        left on disk for the caller to use.

    Raises:
        FileNotFoundError: ``data_path`` is a directory without ``train.json``.
        ValueError: ``data_path`` is neither a directory nor a zip archive, or
            the archive holds no ``train.json``.
    """
    if os.path.isdir(data_path):
        train_json = os.path.join(data_path, "train.json")
        if not os.path.isfile(train_json):
            raise FileNotFoundError(f"train.json not found under {data_path}")
        test_json = os.path.join(data_path, "test.json")
        return train_json, (test_json if os.path.isfile(test_json) else None)

    if zipfile.is_zipfile(data_path):
        wanted = ("train.json", "test.json")
        td = tempfile.mkdtemp(prefix="pii_zip_")
        found = None
        try:
            with zipfile.ZipFile(data_path) as z:
                # Only the two JSON files are ever read, so skip everything else
                # in the archive instead of unpacking it wholesale.
                members = [n for n in z.namelist() if os.path.basename(n) in wanted]
                z.extractall(td, members=members)
            # os.walk visits `td` itself first, so a top-level train.json wins
            # over nested copies; sorting keeps the choice deterministic.
            for root, dirs, files in os.walk(td):
                dirs.sort()
                if "train.json" in files:
                    test_json = os.path.join(root, "test.json") if "test.json" in files else None
                    found = (os.path.join(root, "train.json"), test_json)
                    break
        finally:
            if found is None:
                # Nothing usable was extracted (or extraction failed): do not
                # leave an orphaned temp directory behind.
                shutil.rmtree(td, ignore_errors=True)
        if found is None:
            raise ValueError(f"No train.json found inside zip archive: {data_path}")
        return found

    raise ValueError(f"Invalid data_path: {data_path}")


# -----------------------------
# JSON loading and HF datasets
# -----------------------------
def load_json_records(path: str) -> List[Dict[str, Any]]:
    """Read a JSON file and return its documents as a list of dicts.

    Accepted top-level layouts:
      * a list of records;
      * a dict whose values are all records (the keys are discarded);
      * a dict with a ``"data"`` key holding a list of records.

    Every record must be a JSON object with a ``"tokens"`` entry. When a record
    carries non-null ``"labels"``, their count must equal the token count. A
    missing or null ``"labels"`` is accepted here (unlabelled record).

    Args:
        path: Location of the UTF-8 encoded JSON file.

    Returns:
        The records, in file order.

    Raises:
        ValueError: on an unsupported layout or an invalid record.
    """
    with open(path, "r", encoding="utf-8") as f:
        obj = json.load(f)

    if isinstance(obj, list):
        records = obj
    elif isinstance(obj, dict):
        if all(isinstance(v, dict) for v in obj.values()):
            records = list(obj.values())
        elif isinstance(obj.get("data"), list):
            records = obj["data"]
        else:
            raise ValueError("Unexpected JSON structure; expected list or dict-of-records.")
    else:
        raise ValueError("Unexpected JSON structure; expected list or dict.")

    for i, r in enumerate(records):
        # Reject scalars/lists up front so the checks below cannot fail with an
        # unrelated TypeError.
        if not isinstance(r, dict):
            raise ValueError(f"Record {i} is not a JSON object")
        if "tokens" not in r:
            raise ValueError(f"Missing 'tokens' in record {i}")
        # A null label list is treated like an absent one, matching how
        # build_label_list skips records whose labels are None.
        labels = r.get("labels")
        if labels is not None and len(labels) != len(r["tokens"]):
            raise ValueError(f"labels length != tokens length at record {i}")
    return records


def records_to_hf_dataset(records: List[Dict[str, Any]], with_labels: bool = True) -> Dataset:
    """Pack per-document records into a column-oriented ``Dataset``.

    Columns: ``tokens``, ``trailing_whitespace`` (all ``True`` when a record
    has none), ``document`` (``-1`` when absent), ``full_text`` (empty string
    when absent) and, if ``with_labels`` is set, ``ner_tags_str`` taken from
    each record's ``"labels"``.
    """
    token_col = [rec["tokens"] for rec in records]
    whitespace_col = [
        rec.get("trailing_whitespace", [True] * len(rec["tokens"])) for rec in records
    ]
    document_col = [rec.get("document", -1) for rec in records]
    text_col = [rec.get("full_text", "") for rec in records]

    columns = {
        "tokens": token_col,
        "trailing_whitespace": whitespace_col,
        "document": document_col,
        "full_text": text_col,
    }
    if with_labels:
        # Word-level string tags; they are mapped to ids during tokenization.
        columns["ner_tags_str"] = [rec["labels"] for rec in records]
    return Dataset.from_dict(columns)


def build_label_list(train_records: List[Dict[str, Any]]) -> List[str]:
    """Derive the ordered tag vocabulary from the training records.

    The result always starts with ``"O"``, followed by the sorted ``B-`` tags
    and then the sorted ``I-`` tags that occur in the data. Records without
    labels (missing or ``None``) are skipped; tags with any other prefix are
    not included.
    """
    seen = {
        tag
        for rec in train_records
        if "labels" in rec and rec["labels"] is not None
        for tag in rec["labels"]
    }
    ordered = ["O"]
    for prefix in ("B-", "I-"):
        ordered.extend(sorted(tag for tag in seen if tag.startswith(prefix)))
    return ordered


# -----------------------------
# Tokenization & BIO alignment
# -----------------------------
def tokenize_and_align_labels_fast(
    examples: Dict[str, Any],
    tokenizer: AutoTokenizer,
    label2id: Dict[str, int],
    max_length: int,
    doc_stride: int,
) -> Dict[str, Any]:
    """Tokenize a batch of pre-split documents and align word tags to sub-tokens.

    Documents longer than ``max_length`` are returned as several overflow
    windows (``stride=doc_stride``), so the output can hold more rows than the
    input batch. When the batch has no ``"ner_tags_str"`` column the raw
    tokenizer output is returned unchanged.

    Label rule per window: the first sub-token of each word receives the word's
    label id (unknown tags fall back to the id of ``"O"``); tokens without a
    word id and continuation sub-tokens receive ``-100``.
    """
    encoded = tokenizer(
        examples["tokens"],
        max_length=max_length,
        stride=doc_stride,
        truncation=True,
        padding=False,
        is_split_into_words=True,
        return_overflowing_tokens=True,
        return_offsets_mapping=False,
    )

    if "ner_tags_str" not in examples:
        return encoded

    # Maps every produced window back to the batch row it was cut from.
    window_to_doc = encoded.pop("overflow_to_sample_mapping")
    aligned = []
    for window_idx, doc_idx in enumerate(window_to_doc):
        window_word_ids = encoded.word_ids(window_idx)
        doc_tags = examples["ner_tags_str"][doc_idx]

        row = []
        last_word = None
        for word in window_word_ids:
            if word is None:
                row.append(-100)
                continue
            if word == last_word:
                row.append(-100)  # continuation piece of the same word
            else:
                row.append(label2id.get(doc_tags[word], label2id["O"]))
            last_word = word
        aligned.append(row)

    encoded["labels"] = aligned
    return encoded


# -----------------------------
# Metrics (seqeval)
# -----------------------------
def compute_seqeval_metrics(p: Any, id2label: Dict[int, str]) -> Dict[str, float]:
    """Score predictions with seqeval.

    ``p.predictions`` is arg-maxed over its last axis and compared with
    ``p.label_ids``; positions whose label is ``-100`` are dropped before both
    sequences are converted to tag strings via ``id2label``.

    Returns:
        A dict with ``precision``, ``recall``, ``f1`` and ``accuracy``.
    """
    predicted_ids = np.argmax(p.predictions, axis=-1)
    gold_ids = p.label_ids

    gold_seqs = []
    pred_seqs = []
    for pred_row, gold_row in zip(predicted_ids, gold_ids):
        # Keep only the positions that carry a real label.
        kept = [(gold, pred) for pred, gold in zip(pred_row, gold_row) if gold != -100]
        gold_seqs.append([id2label[int(gold)] for gold, _ in kept])
        pred_seqs.append([id2label[int(pred)] for _, pred in kept])

    scorers = (
        ("precision", precision_score),
        ("recall", recall_score),
        ("f1", f1_score),
        ("accuracy", accuracy_score),
    )
    return {name: scorer(gold_seqs, pred_seqs) for name, scorer in scorers}


# -----------------------------
# Train/val split
# -----------------------------
def build_datasets(train_json: str, val_ratio: float, seed: int):
    """Load ``train_json`` and split it into train and validation datasets.

    The records are shuffled with a ``random.Random(seed)`` generator; the
    first ``max(1, int(n * val_ratio))`` of them form the validation set and
    the remainder the training set.

    Returns:
        ``(train_dataset, val_dataset, train_records)`` — the last item is the
        plain list of training records.
    """
    shuffled = load_json_records(train_json)
    random.Random(seed).shuffle(shuffled)

    n_val = int(len(shuffled) * val_ratio)
    if n_val < 1:
        n_val = 1  # always hold out at least one document

    held_out, kept = shuffled[:n_val], shuffled[n_val:]
    train_ds = records_to_hf_dataset(kept, with_labels=True)
    val_ds = records_to_hf_dataset(held_out, with_labels=True)
    return train_ds, val_ds, kept


# -----------------------------
# Main
# -----------------------------
def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser shared by the train and infer modes."""
    parser = argparse.ArgumentParser(description="DeBERTa NER (BIO) for Kaggle PII Detection — FAST")
    parser.add_argument("--mode", type=str, default="train", choices=["train", "infer"])
    parser.add_argument("--data-path", type=str, default="", help="Leave empty on Kaggle; auto-detects /kaggle/input/**")
    parser.add_argument("--out-dir", type=str, default="/kaggle/working/pii_deberta_v3_base")
    parser.add_argument("--model-name", type=str, default="microsoft/deberta-v3-base")
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument("--lr", type=float, default=2e-5)
    parser.add_argument("--weight-decay", type=float, default=0.01)
    parser.add_argument("--warmup-ratio", type=float, default=0.1)
    parser.add_argument("--train-batch", type=int, default=8)
    parser.add_argument("--eval-batch", type=int, default=16)
    parser.add_argument("--gradient-accumulation", type=int, default=1)
    parser.add_argument("--max-length", type=int, default=512)
    parser.add_argument("--doc-stride", type=int, default=128)
    parser.add_argument("--val-ratio", type=float, default=0.1)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--fp16", action="store_true", help="Enable FP16 mixed precision")
    parser.add_argument("--bf16", action="store_true", help="Enable BF16 mixed precision")
    parser.add_argument("--grad-checkpointing", action="store_true")
    parser.add_argument("--dataloader-workers", type=int, default=4)
    parser.add_argument("--log-steps", type=int, default=50)
    parser.add_argument("--patience", type=int, default=2)
    parser.add_argument("--save-best-only", action="store_true")
    parser.add_argument("--compile", action="store_true", help="Use torch.compile (A100 recommended)")
    parser.add_argument("--text", type=str, default="")
    return parser


def _load_token_classifier(name_or_path: str, config: Any = None):
    """Load a token-classification model, using the DeBERTa-v2 class directly when the config calls for it.

    The decision is based on ``config.model_type`` rather than on the model
    name, so local checkpoint folders are routed correctly and DeBERTa v1
    checkpoints are not forced into the v2 architecture.
    """
    if config is None:
        config = AutoConfig.from_pretrained(name_or_path)
    if getattr(config, "model_type", None) == "deberta-v2":
        # Direct class: avoids AutoModel's import scan (see the note on the import).
        return DebertaV2ForTokenClassification.from_pretrained(name_or_path, config=config)
    return AutoModelForTokenClassification.from_pretrained(name_or_path, config=config)


def _run_inference(args: argparse.Namespace) -> None:
    """Load the model saved under ``args.out_dir`` and print aggregated entities for one text."""
    print(f"[info] Loading model from {args.out_dir} ...")
    tokenizer = AutoTokenizer.from_pretrained(args.out_dir, use_fast=True)
    model = _load_token_classifier(args.out_dir)
    nlp = hf_pipeline(
        "token-classification",
        model=model,
        tokenizer=tokenizer,
        aggregation_strategy="simple",
        device=0 if torch.cuda.is_available() else -1,
    )
    text = args.text.strip() or "Email me at alice_01@school.edu or call 555-123-4567. I'm Alice from 221B Baker Street."
    print("[demo] Input:", text)
    print("[demo] Aggregated entities:", nlp(text))


def main():
    """Command-line entry point: train a token classifier or run a one-off inference.

    ``--mode train`` (default) locates the data, tokenizes it into overflow
    windows, fine-tunes the model with early stopping on validation F1, saves
    the model and tokenizer to ``--out-dir`` and prints a small demo.
    ``--mode infer`` loads the model from ``--out-dir`` and tags ``--text``.
    """
    # parse_known_args ignores unrecognised argv entries instead of exiting;
    # presumably so the script also runs inside a notebook kernel — confirm.
    args, _ = _build_arg_parser().parse_known_args()

    # Default to fp16 on GPU unless bf16 requested
    if torch.cuda.is_available() and not args.bf16:
        args.fp16 = True

    set_seed(args.seed)

    if args.mode == "infer":
        # out_dir is only read here, so it is not created: an empty directory
        # would just turn "model not found" into a more confusing load error.
        _run_inference(args)
        return

    # --------------------- TRAIN MODE ---------------------
    os.makedirs(args.out_dir, exist_ok=True)

    data_path = args.data_path.strip() or auto_find_data_path()
    print(f"[info] Using data path: {data_path}")
    train_json, _ = resolve_data_paths(data_path)

    train_ds, val_ds, train_records = build_datasets(train_json, args.val_ratio, args.seed)

    # Labels
    label_list = build_label_list(train_records)
    id2label = {i: l for i, l in enumerate(label_list)}
    label2id = {l: i for i, l in enumerate(label_list)}
    print("[info] Labels:", label_list)

    # Tokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.model_name, use_fast=True)

    # Parallel map/tokenize with overflow windows + BIO alignment
    num_proc = max(1, min(4, (os.cpu_count() or 2)))

    def _map_fn(batch):
        return tokenize_and_align_labels_fast(
            batch, tokenizer, label2id, max_length=args.max_length, doc_stride=args.doc_stride
        )

    # `desc` carries no worker count: the progress bar already appends
    # "(num_proc=N)" itself, which previously showed up twice.
    train_tok = train_ds.map(
        _map_fn, batched=True, num_proc=num_proc,
        remove_columns=train_ds.column_names, desc="Tokenizing train"
    )
    val_tok = val_ds.map(
        _map_fn, batched=True, num_proc=num_proc,
        remove_columns=val_ds.column_names, desc="Tokenizing val"
    )

    # Model config
    config = AutoConfig.from_pretrained(
        args.model_name,
        num_labels=len(label_list),
        id2label=id2label,
        label2id=label2id,
    )

    try:
        model = _load_token_classifier(args.model_name, config)
    except OSError as e:
        raise OSError(
            f"Failed to load {args.model_name}. On Kaggle (no internet), make sure the model is cached "
            f"or attach it as a dataset / enable internet for first run."
        ) from e

    if args.grad_checkpointing:
        model.gradient_checkpointing_enable()
    model.config.use_cache = False

    # Collator: dynamic padding, align to multiple of 8 to use Tensor Cores
    collator = DataCollatorForTokenClassification(tokenizer=tokenizer, pad_to_multiple_of=8)

    # The evaluation-schedule argument is `eval_strategy` in current
    # transformers and `evaluation_strategy` in older releases; use whichever
    # the installed version accepts.
    ta_params = inspect.signature(TrainingArguments.__init__).parameters
    eval_key = "eval_strategy" if "eval_strategy" in ta_params else "evaluation_strategy"

    # TrainingArguments (FAST)
    training_args = TrainingArguments(
        output_dir=args.out_dir,
        save_strategy="epoch",
        logging_strategy="steps",
        logging_steps=args.log_steps,

        per_device_train_batch_size=args.train_batch,
        per_device_eval_batch_size=args.eval_batch,
        gradient_accumulation_steps=args.gradient_accumulation,

        num_train_epochs=args.epochs,
        learning_rate=args.lr,
        weight_decay=args.weight_decay,
        warmup_ratio=args.warmup_ratio,
        lr_scheduler_type="linear",

        load_best_model_at_end=True,
        metric_for_best_model="f1",
        greater_is_better=True,

        # SPEED knobs
        optim="adamw_torch_fused" if torch.cuda.is_available() else "adamw_torch",
        group_by_length=True,                    # less padding -> faster
        dataloader_num_workers=args.dataloader_workers,
        dataloader_pin_memory=True,
        # Persistent workers are rejected by the DataLoader when there are no
        # workers, so only request them if at least one is configured.
        dataloader_persistent_workers=args.dataloader_workers > 0,
        fp16=args.fp16,
        bf16=args.bf16,
        fp16_full_eval=args.fp16,               # full half-precision eval (not AMP); faster, may shift metrics slightly
        eval_accumulation_steps=64,             # lower GPU mem spikes during eval
        torch_compile=args.compile,             # use --compile to enable
        save_total_limit=1 if args.save_best_only else 3,
        save_safetensors=True,
        report_to="none",
        seed=args.seed,
        **{eval_key: "epoch"},
    )

    # Newer Trainer versions take the tokenizer as `processing_class` and warn
    # about (or drop) the `tokenizer` keyword; older ones only know `tokenizer`.
    trainer_params = inspect.signature(Trainer.__init__).parameters
    tokenizer_key = "processing_class" if "processing_class" in trainer_params else "tokenizer"

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_tok,
        eval_dataset=val_tok,
        data_collator=collator,
        compute_metrics=lambda p: compute_seqeval_metrics(p, id2label),
        callbacks=[EarlyStoppingCallback(early_stopping_patience=args.patience)],
        **{tokenizer_key: tokenizer},
    )

    print("[info] Starting training...")
    trainer.train()
    metrics = trainer.evaluate()
    print("[eval] metrics:", metrics)

    print("[info] Saving to", args.out_dir)
    trainer.save_model(args.out_dir)
    tokenizer.save_pretrained(args.out_dir)

    # Quick demo with aggregation
    nlp = None
    try:
        nlp = hf_pipeline(
            "token-classification",
            model=trainer.model,
            tokenizer=tokenizer,
            aggregation_strategy="simple",
            device=0 if torch.cuda.is_available() else -1,
        )
        sample = val_ds[0]
        toks = sample["tokens"]
        ws = sample["trailing_whitespace"]
        # Rebuild the raw text from tokens and their trailing-whitespace flags.
        text = "".join([t + (" " if (i < len(ws) and ws[i]) else "") for i, t in enumerate(toks)])
        text = re.sub(r"\s+", " ", text).strip()
        print("[demo] Aggregated NER on a val sample:", nlp(text))
    except Exception as e:
        # The demo is a convenience; a failure must not mask a finished training run.
        print(f"[warn] Demo pipeline failed: {e}")

    # The pipeline also references the model, so drop it together with the
    # trainer before asking CUDA to release cached memory.
    del trainer, model, nlp
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

# Script entry point: run the CLI only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()


[info] Using data path: /kaggle/input/pii-detection-removal-from-educational-data
[info] Labels: ['O', 'B-EMAIL', 'B-ID_NUM', 'B-NAME_STUDENT', 'B-PHONE_NUM', 'B-STREET_ADDRESS', 'B-URL_PERSONAL', 'B-USERNAME', 'I-ID_NUM', 'I-NAME_STUDENT', 'I-PHONE_NUM', 'I-STREET_ADDRESS', 'I-URL_PERSONAL']


tokenizer_config.json:   0%|          | 0.00/52.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/579 [00:00<?, ?B/s]

spm.model:   0%|          | 0.00/2.46M [00:00<?, ?B/s]



Tokenizing train (num_proc=4) (num_proc=4):   0%|          | 0/6127 [00:00<?, ? examples/s]

Tokenizing val (num_proc=4) (num_proc=4):   0%|          | 0/680 [00:00<?, ? examples/s]

pytorch_model.bin:   0%|          | 0.00/371M [00:00<?, ?B/s]

Some weights of DebertaV2ForTokenClassification were not initialized from the model checkpoint at microsoft/deberta-v3-base and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  trainer = Trainer(


[info] Starting training...


model.safetensors:   0%|          | 0.00/371M [00:00<?, ?B/s]

Epoch,Training Loss,Validation Loss,Precision,Recall,F1,Accuracy
1,0.0005,0.000503,0.846154,0.814815,0.830189,0.999854
2,0.0003,0.000378,0.775229,0.89418,0.830467,0.999851
3,0.0001,0.000395,0.89071,0.862434,0.876344,0.99989


[eval] metrics: {'eval_loss': 0.0003909986699000001, 'eval_precision': 0.8907103825136612, 'eval_recall': 0.8624338624338624, 'eval_f1': 0.8763440860215054, 'eval_accuracy': 0.9998897342595655, 'eval_runtime': 30.3462, 'eval_samples_per_second': 45.805, 'eval_steps_per_second': 2.867, 'epoch': 3.0}
[info] Saving to /kaggle/working/pii_deberta_v3_base


Device set to use cuda:0
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


[demo] Aggregated NER on a val sample: []
