In [None]:
!pip install -q evaluate seqeval
!pip install -q onnx onnxruntime

In [None]:
import copy
import gc
import json
import os
import shutil
from collections import Counter

# Cache locations must be set BEFORE importing datasets / transformers /
# evaluate: those libraries read these variables once, when they are first
# imported, so assigning them after the imports does not move the caches.
os.environ["HF_HOME"] = "/kaggle/working/hf"
os.environ["TRANSFORMERS_CACHE"] = "/kaggle/working/hf"
os.environ["HF_DATASETS_CACHE"] = "/kaggle/working/hf"

import evaluate
import numpy as np
import onnxruntime as ort
import torch
from datasets import load_dataset
from IPython.display import FileLink
from onnxruntime.quantization import QuantType, quantize_dynamic
from torch.nn.utils import prune
from tqdm import tqdm
from transformers import (
    AutoModelForTokenClassification,
    AutoTokenizer,
    DataCollatorForTokenClassification,
    EarlyStoppingCallback,
    Trainer,
    TrainingArguments,
)

In [None]:
# Report which accelerator this kernel will use (fp16 training below is only
# enabled when CUDA is available).
if torch.cuda.is_available():
    device_name = torch.cuda.get_device_name(0)
    total_mem = torch.cuda.get_device_properties(0).total_memory

    print(f"✅ GPU Found: {device_name}")
    print(f"Memory: {total_mem / 1024**3:.2f} GB")
else:
    print("❌ GPU NOT found. Check your drivers or Secure Boot.")

# PII Masking Model Pipeline (Browser Optimized)

This notebook implements a pipeline to create a small, efficient PII masking model using the Lottery Ticket Hypothesis (LTH) and Quantization.

**Key Steps:**
1. **Load Data**: `ai4privacy/pii-masking-300k` (Filter for **English** only).
2. **Preprocessing**: Robust tokenization using character offsets to handle dataset quirks.
3. **Save Initial Weights**: Critical for LTH "rewinding".
4. **Fine-tune** → for each sparsity target repeat (**Prune** → **Rewind** to the initial weights → **Retrain**) → **Export to ONNX & Quantize** (INT8).


In [None]:
# Configuration ---------------------------------------------------------------
# Base encoder and source dataset (Hugging Face Hub ids).
MODEL_CHECKPOINT = "distilbert-base-uncased"
DATASET_NAME = "ai4privacy/pii-masking-300k"

# Training checkpoints, the W0 snapshot and the pruned tickets are written here.
OUTPUT_DIR = "/kaggle/working/pii_model_output"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Snapshot of the starting weights, used later for LTH rewinding.
INITIAL_WEIGHTS_PATH = os.path.join(OUTPUT_DIR, "initial_weights.pt")

## 1. Data Loading & Splits
We load the dataset and filter for **English** (`language == 'English'`).
Since the dataset only provides `train` and `validation` splits, we split `validation` into disjoint `validation` and `test` sets (50/50).

In [None]:
# Load the dataset and keep English rows only (applied to every split).
def _is_english(example):
    """Row predicate for ``Dataset.filter``: keep English examples only."""
    return example["language"] == "English"


dataset = load_dataset(DATASET_NAME, trust_remote_code=True).filter(_is_english)

# Without a published test split, carve one out of validation (50/50, seeded).
if "test" not in dataset:
    print("Creating Test split from Validation...")
    halves = dataset["validation"].train_test_split(test_size=0.5, seed=42)
    dataset["validation"], dataset["test"] = halves["train"], halves["test"]

print(dataset)


## 2. Parsing & Label Extraction
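Depending on the dataset version, the `privacy_mask` field is stored either as a JSON string or as an already-decoded list of entity spans (`label`, `start`, `end`). We normalise it to a list, collect the unique entity labels from the train split, and expand each one into a `B-`/`I-` tag pair (plus a single `O` tag).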

In [None]:
def parse_dataset_row(example):
    """Decode ``privacy_mask`` from JSON text in place when necessary.

    Some dataset versions store the span annotations as a JSON string, others
    as an already-decoded list; both end up as the decoded form. The same
    (mutated) ``example`` mapping is returned, as ``Dataset.map`` expects.
    """
    raw_mask = example['privacy_mask']
    if not isinstance(raw_mask, str):
        return example
    example['privacy_mask'] = json.loads(raw_mask)
    return example

dataset = dataset.map(parse_dataset_row)

print("Extracting unique labels from Train split...")
# Entity types are taken from the train split only.
unique_labels = {
    entity['label']
    for privacy_mask in dataset['train']['privacy_mask']
    for entity in privacy_mask
}

# BIO scheme: "O" first, then B-/I- pairs in sorted entity order.
label_list = ["O"] + [
    f"{prefix}-{label}"
    for label in sorted(unique_labels)
    for prefix in ("B", "I")
]

label2id = {label: i for i, label in enumerate(label_list)}
id2label = dict(zip(label2id.values(), label2id.keys()))
print(f"Unique labels: {len(label_list)}")


## 3. Robust Tokenization & Alignment
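We use `return_offsets_mapping=True` to map the character spans in `privacy_mask` directly onto tokenizer tokens. Texts longer than 512 tokens are split into overlapping windows (stride 64) with `return_overflowing_tokens=True`, so a single example may produce several training rows.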

In [None]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_CHECKPOINT)
# tokenize_and_align_labels needs offset mappings and overflow_to_sample_mapping,
# which only "fast" (Rust-backed) tokenizers provide; fail here with a clear
# message rather than deep inside Dataset.map.
if not tokenizer.is_fast:
    raise ValueError(
        f"{MODEL_CHECKPOINT} has no fast tokenizer; offset mappings are required."
    )

def tokenize_and_align_labels(examples):
    """Tokenize a batch into overlapping windows and project spans onto BIO tags.

    Texts longer than 512 tokens are split into windows overlapping by 64
    tokens, so one input row can produce several output rows
    (``overflow_to_sample_mapping`` records which row each window came from).
    Each character span in ``privacy_mask`` is mapped onto the tokens whose
    offsets intersect it:

    * the first intersecting token gets ``B-`` when the span starts at or after
      that token's start (i.e. the span did not begin in an earlier window),
      otherwise ``I-``;
    * the remaining intersecting tokens get ``I-``;
    * a token already tagged by an earlier span keeps its tag;
    * special tokens (offset ``(0, 0)``) become -100 so the loss ignores them.

    Spans whose label is missing from ``label2id`` are skipped (stay ``O``).
    Relies on the notebook-level ``tokenizer`` and ``label2id``.

    Returns:
        The tokenizer output without ``offset_mapping`` and
        ``overflow_to_sample_mapping``, plus a ``"labels"`` list per window.
    """
    encoded = tokenizer(
        examples["source_text"],
        truncation=True,
        max_length=512,
        stride=64,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding=False,
    )

    window_to_row = encoded.pop("overflow_to_sample_mapping")
    window_offsets = encoded.pop("offset_mapping")

    all_labels = []
    for window_idx, offsets in enumerate(window_offsets):
        spans = examples["privacy_mask"][window_to_row[window_idx]]
        outside_id = label2id["O"]
        tags = [outside_id] * len(encoded["input_ids"][window_idx])

        # (position, start, end) of every real token; specials carry (0, 0).
        real_tokens = [
            (pos, tok_start, tok_end)
            for pos, (tok_start, tok_end) in enumerate(offsets)
            if not (tok_start == 0 and tok_end == 0)
        ]

        for span in spans:
            kind = span["label"]
            span_start = span["start"]
            span_end = span["end"]

            begin_id = label2id.get(f"B-{kind}")
            inside_id = label2id.get(f"I-{kind}")
            if begin_id is None:
                continue

            # Tokens whose [start, end) interval intersects [span_start, span_end).
            hits = [
                (pos, tok_start)
                for pos, tok_start, tok_end in real_tokens
                if tok_end > span_start and tok_start < span_end
            ]

            for rank, (pos, tok_start) in enumerate(hits):
                if tags[pos] not in (outside_id, -100):
                    continue  # the first span to claim a token wins
                starts_here = rank == 0 and span_start >= tok_start
                tags[pos] = begin_id if starts_here else inside_id

        # Exclude special tokens from the loss.
        for pos, (tok_start, tok_end) in enumerate(offsets):
            if tok_start == 0 and tok_end == 0:
                tags[pos] = -100

        all_labels.append(tags)

    encoded["labels"] = all_labels
    return encoded

# One source row may expand into several windows, so every original column is
# dropped (their lengths would no longer match the output rows).
tokenized_datasets = dataset.map(
    function=tokenize_and_align_labels,
    batched=True,
    remove_columns=list(dataset["train"].column_names),
)


## 4. Save Initial Weights
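Save the initial weights $W_0$ (the pre-trained DistilBERT encoder plus the freshly initialised token-classification head) so that the Lottery Ticket Hypothesis step can later rewind the surviving weights to them.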

In [None]:
from transformers import set_seed

# Seed every RNG before building the model. The token-classification head on
# top of the pre-trained encoder is freshly initialised and is part of W0, so
# seeding makes W0 (and therefore every rewound ticket) reproducible.
# 42 matches the `seed` used by the TrainingArguments below.
set_seed(42)

model = AutoModelForTokenClassification.from_pretrained(
    MODEL_CHECKPOINT,
    num_labels=len(label_list),
    id2label=id2label,
    label2id=label2id,
)

# W0 snapshot, read back by strict_imp_loop when rewinding pruned tickets.
torch.save(model.state_dict(), INITIAL_WEIGHTS_PATH)
print(f"Initial weights saved to {INITIAL_WEIGHTS_PATH} with {len(label_list)} labels and sliding window configuration.")


In [None]:
metric = evaluate.load(path="seqeval")  # entity-level scores over BIO tag sequences

def compute_metrics(p):
    """Seqeval scores for ``Trainer`` evaluation.

    Args:
        p: evaluation prediction, unpacked as ``(logits, label_ids)``; logits
            are reduced with argmax over axis 2.

    Positions whose reference id is -100 are dropped from both sequences
    before ids are mapped to tag names through ``id2label``.

    Returns:
        dict with ``precision``, ``recall``, ``f1`` and ``accuracy`` taken from
        the metric's ``overall_*`` results.
    """
    logits, label_ids = p
    pred_ids = np.argmax(logits, axis=2)

    true_predictions = []
    true_labels = []
    for pred_row, gold_row in zip(pred_ids, label_ids):
        kept = [(pred, gold) for pred, gold in zip(pred_row, gold_row) if gold != -100]
        true_predictions.append([id2label[pred] for pred, _ in kept])
        true_labels.append([id2label[gold] for _, gold in kept])

    results = metric.compute(predictions=true_predictions, references=true_labels)
    return {
        name: results[f"overall_{name}"]
        for name in ("precision", "recall", "f1", "accuracy")
    }


In [None]:
def evaluate_accuracy(model, dataset, batch_size=16):
    """Run a one-off ``Trainer.evaluate`` pass over ``dataset``.

    A throwaway Trainer is built with the token-classification collator and
    ``compute_metrics``; its scratch output directory is ``/tmp/eval``.

    Args:
        model: token-classification model to score.
        dataset: tokenized split with ``labels``.
        batch_size: per-device evaluation batch size.

    Returns:
        The metrics dict returned by ``Trainer.evaluate``.
    """
    print("Evaluating model accuracy...")

    eval_args = TrainingArguments(
        output_dir="/tmp/eval",
        per_device_eval_batch_size=batch_size,
        report_to="none",
    )
    evaluator = Trainer(
        model=model,
        args=eval_args,
        data_collator=DataCollatorForTokenClassification(tokenizer),
        eval_dataset=dataset,
        compute_metrics=compute_metrics,
    )

    metrics = evaluator.evaluate()
    print("Evaluation Results:", metrics)
    return metrics


In [None]:
# 6. Class Imbalance Handling
from torch.nn import CrossEntropyLoss

def compute_class_weights(dataset, label2id):
    """Per-label weights for the weighted cross-entropy loss.

    Counts how often each label id occurs in ``dataset`` (positions labelled
    -100 are ignored), turns the counts into smoothed inverse frequencies
    ``1 / (count + 100)``, rescales them to sum to ``len(label2id)``, clamps
    them to ``[0.1, 10.0]`` and finally pins the ``"O"`` class to ``1.0``.

    Args:
        dataset: iterable of examples, each with a ``"labels"`` sequence of
            integer label ids (e.g. ``tokenized_datasets["train"]``).
        label2id: mapping label name -> id; its size fixes the output length
            and it must contain ``"O"``.

    Returns:
        float32 tensor of shape ``(len(label2id),)``.
    """
    print("Computing class weights...")
    num_labels = len(label2id)

    # Count with collections.Counter (C-implemented). The previous per-token
    # ``tensor[label] += 1`` dispatched one torch op per training token, and a
    # float32 accumulator stops increasing once a count reaches 2**24.
    counter = Counter()
    for example in dataset:
        counter.update(example['labels'])
    counter.pop(-100, None)  # ignore padding / special-token positions

    label_counts = torch.zeros(num_labels)
    for label_id, count in counter.items():
        label_counts[label_id] = count

    print(f"Label counts: {label_counts}")

    weights = 1.0 / (label_counts + 100)  # +100 smooths rare / unseen labels
    weights = weights / weights.sum() * num_labels
    weights = torch.clamp(weights, min=0.1, max=10.0)
    weights[label2id["O"]] = 1.0

    return weights

class_weights = compute_class_weights(label2id=label2id, dataset=tokenized_datasets['train'])
print("Class weights calculated.", class_weights)


In [None]:
# Custom Weighted Trainer
# Windows were tokenized with padding=False, so batches are padded by the collator.
data_collator = DataCollatorForTokenClassification(tokenizer)


class WeightedTrainer(Trainer):
    """``Trainer`` whose training/eval loss is class-weighted cross-entropy.

    Args:
        class_weights: optional 1-D tensor of per-label weights. When omitted,
            the notebook-level ``class_weights`` global is read at loss time
            (the original behaviour). All other arguments go to ``Trainer``.
    """

    def __init__(self, *trainer_args, class_weights=None, **trainer_kwargs):
        super().__init__(*trainer_args, **trainer_kwargs)
        self._class_weights = class_weights

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Weighted CE over all non-ignored (-100) token positions."""
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.logits

        weights = self._class_weights if self._class_weights is not None else class_weights
        loss_fct = CrossEntropyLoss(weight=weights.to(logits.device), ignore_index=-100)
        # Flatten every token position into one row of class scores.
        loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))

        return (loss, outputs) if return_outputs else loss


In [None]:
# Training Configuration & Execution
args = TrainingArguments(
    OUTPUT_DIR,
    eval_strategy="steps",
    eval_steps=500,
    save_strategy="steps",
    save_steps=500,
    save_total_limit=3,
    load_best_model_at_end=True,
    metric_for_best_model="f1",
    greater_is_better=True,

    learning_rate=2e-5,
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    num_train_epochs=4,

    # Effective train batch = 4 * 4 (gradient accumulation) per device.
    per_device_train_batch_size=4,
    per_device_eval_batch_size=16,
    gradient_accumulation_steps=4,

    weight_decay=0.01,
    fp16=torch.cuda.is_available(),
    dataloader_num_workers=4,
    logging_steps=100,
    report_to="none",
    seed=42
)

trainer = WeightedTrainer(
    model=model,
    args=args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)]
)

print("Starting Training...")
try:
    trainer.train()
    print("✅ Training completed successfully.")
    final_model_dir = os.path.join(OUTPUT_DIR, "final_model")
    trainer.save_model(final_model_dir)
    # The trainer was built without a tokenizer, so save it explicitly:
    # final_model is later zipped for download and should be self-contained.
    tokenizer.save_pretrained(final_model_dir)
except KeyboardInterrupt:
    print("\n🛑 Training interrupted by user. Saving checkpoint...")
    trainer.save_model(os.path.join(OUTPUT_DIR, "interrupted_checkpoint"))
    print("Checkpoint saved.")
except Exception as e:
    print(f"\n❌ info: Training failed with error: {e}")
    # Best effort: keep whatever weights exist, then re-raise the original error.
    try:
        trainer.save_model(os.path.join(OUTPUT_DIR, "failed_checkpoint"))
        print("Crash checkpoint saved.")
    except Exception as save_error:
        print(f"Could not save crash checkpoint: {save_error}")
    raise


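## 6. Pruning (Iterative Magnitude Pruning)
Each round prunes the smallest-magnitude `Linear` weights globally until the next sparsity target is reached, rewinds the surviving weights to $W_0$ (Lottery Ticket Hypothesis), and retrains with the pruning mask held fixed.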

In [None]:
# Dense fine-tuned model that seeds the first pruning round.
BASELINE_MODEL_PATH = os.path.join(OUTPUT_DIR, "final_model")

# Global sparsity (fraction of Linear weights set to zero) after each IMP round.
PRUNING_TARGETS = [pct / 100 for pct in (25, 50, 75, 85)]

def get_sparsity(model):
    """Fraction of ``torch.nn.Linear`` weight entries that are zero.

    For a layer that currently carries a pruning mask (``weight_mask`` buffer
    added by ``torch.nn.utils.prune``), the zeros of the mask are counted.
    Otherwise the zero entries of ``weight`` itself are counted. Embeddings,
    LayerNorms and biases are not included.

    Returns:
        float in [0, 1]; 0.0 for a model without Linear layers (previously a
        ZeroDivisionError).
    """
    total_zeros = 0
    total_params = 0
    for module in model.modules():
        if not isinstance(module, torch.nn.Linear):
            continue
        mask = getattr(module, "weight_mask", None)
        source = mask if mask is not None else module.weight
        total_zeros += int(torch.count_nonzero(source == 0))
        total_params += module.weight.nelement()
    return total_zeros / total_params if total_params else 0.0

def _linear_modules(model):
    """Yield ``(name, module)`` for every ``torch.nn.Linear`` inside ``model``."""
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            yield name, module


def _global_magnitude_masks(model, sparsity):
    """Globally L1-prune all Linear weights of ``model`` to ``sparsity``.

    ``sparsity`` is the fraction of *all* Linear weight entries to zero out.
    Returns ``{module_name: mask}`` (CPU tensors, 1 = keep). The pruning
    re-parametrisation is then removed, leaving plain ``weight`` parameters, so
    the model's state-dict keys match a normal checkpoint again.
    """
    layers = list(_linear_modules(model))
    prune.global_unstructured(
        [(module, "weight") for _, module in layers],
        pruning_method=prune.L1Unstructured,
        amount=sparsity,
    )
    masks = {}
    for name, module in layers:
        masks[name] = module.weight_mask.detach().cpu()
        prune.remove(module, "weight")
    return masks


def _rewind_and_mask(model, w0_state_dict, masks, device):
    """Load W0 into ``model`` (strictly) and attach ``masks`` as fixed pruning masks."""
    # strict=True is safe and essential here: with the re-parametrisation
    # removed, the keys are plain "...weight" again, so every Linear weight is
    # guaranteed to be rewound. Loading W0 into a *pruned* module with
    # strict=False silently skipped them, because its state dict only contains
    # "...weight_orig" / "...weight_mask".
    model.load_state_dict(w0_state_dict, strict=True)
    model.to(device)
    for name, module in _linear_modules(model):
        if name in masks:
            prune.custom_from_mask(module, name="weight", mask=masks[name].to(device))


def strict_imp_loop():
    """Iterative magnitude pruning with rewinding to W0 (Lottery Ticket Hypothesis).

    For every target in ``PRUNING_TARGETS`` (global fraction of zero Linear
    weights) the loop does the following:
      1. load the previous round's trained ticket (round 1: ``BASELINE_MODEL_PATH``);
      2. globally magnitude-prune the Linear weights to the target sparsity;
      3. rewind all weights to W0 (``INITIAL_WEIGHTS_PATH``) and re-apply the masks;
      4. retrain with ``WeightedTrainer``, bake the masks into the weights and
         save to ``OUTPUT_DIR/lth_sparsity_<pct>``, which seeds the next round.

    Raises:
        FileNotFoundError: if the W0 snapshot does not exist.
    """
    print("🚀 Starting Strict IMP (Lottery Ticket Hypothesis)...")
    print(f"🎯 Targets: {PRUNING_TARGETS}")

    if not os.path.exists(INITIAL_WEIGHTS_PATH):
        raise FileNotFoundError(f"❌ Initial weights not found at {INITIAL_WEIGHTS_PATH}!")

    w0_state_dict = torch.load(INITIAL_WEIGHTS_PATH, map_location="cpu")
    print("✅ Loaded Initial Weights (W0).")

    # Fall back to CPU instead of crashing on machines without CUDA.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    current_model_path = BASELINE_MODEL_PATH
    current_sparsity = 0.0

    for round_idx, target_sparsity in enumerate(PRUNING_TARGETS):
        print("\n" + "=" * 50)
        print(f"   🎫 ROUND {round_idx + 1}: Target {target_sparsity*100}% Sparsity")
        print("=" * 50)

        if current_sparsity >= target_sparsity:
            print(f"⚠️ Already at {current_sparsity:.1%}, skipping...")
            continue

        print(f"📂 Loading model from: {current_model_path}")
        model = AutoModelForTokenClassification.from_pretrained(current_model_path)
        model.to(device)

        # The checkpoint is loaded WITHOUT pruning masks, but weights pruned in
        # earlier rounds are exact zeros and therefore the smallest magnitudes.
        # Pruning `target_sparsity` of ALL weights re-selects them first and
        # lands exactly on the target. A relative amount
        # (target - current) / (1 - current) would only be right if the old
        # masks were still attached; here it undershot every later round.
        print(f"✂️  Pruning Linear weights to {target_sparsity:.1%} global sparsity...")
        global_masks = _global_magnitude_masks(model, target_sparsity)

        print("✨ Rewinding to W0 (Initial Weights) and applying masks...")
        _rewind_and_mask(model, w0_state_dict, global_masks, device)

        current_sparsity = get_sparsity(model)
        print(f"✅ Verified Sparsity: {current_sparsity:.2%}")

        run_dir = os.path.join(OUTPUT_DIR, f"lth_sparsity_{int(target_sparsity*100)}")
        print(f"🏋️‍♀️ Retraining Ticket... Output: {run_dir}")

        torch.cuda.empty_cache()
        gc.collect()

        round_args = copy.deepcopy(args)
        round_args.output_dir = run_dir
        round_args.learning_rate = 2e-5
        round_args.num_train_epochs = 4

        trainer = WeightedTrainer(
            model=model,
            args=round_args,
            train_dataset=tokenized_datasets["train"],
            eval_dataset=tokenized_datasets["validation"],
            data_collator=data_collator,
            compute_metrics=compute_metrics,
            callbacks=[EarlyStoppingCallback(early_stopping_patience=3)]
        )

        trainer.train()

        print("💾 Saving model for next round...")
        # Bake the masks into the weights so the checkpoint is a plain model
        # whose pruned entries are exact zeros (relied on by the next round).
        for _, module in _linear_modules(model):
            prune.remove(module, 'weight')

        trainer.save_model(run_dir)
        tokenizer.save_pretrained(run_dir)

        current_model_path = run_dir

        del model, trainer
        torch.cuda.empty_cache()
        gc.collect()

    print("\n🎉 IMP Loop Completed Successfully!")


strict_imp_loop()

In [None]:

# The last IMP round is the sparsest ticket. Derive its folder exactly as
# strict_imp_loop names it, so changing PRUNING_TARGETS cannot leave a stale path.
BEST_MODEL_DIR = os.path.join(OUTPUT_DIR, f"lth_sparsity_{int(PRUNING_TARGETS[-1]*100)}")
FINAL_OUTPUT_DIR = os.path.join(OUTPUT_DIR, "pii_model_pruned_final")

print(f"📂 Loading best winning ticket from: {BEST_MODEL_DIR}")
model = AutoModelForTokenClassification.from_pretrained(BEST_MODEL_DIR)

print("🔨 Making pruning permanent (baking masks into weights)...")

# strict_imp_loop already calls prune.remove before saving, and from_pretrained
# builds a fresh module, so this normally processes 0 layers. It is kept as a
# safety net in case a still-masked model ends up here.
layers_processed = 0
for name, module in model.named_modules():
    if isinstance(module, torch.nn.Linear) and prune.is_pruned(module):
        prune.remove(module, 'weight')
        layers_processed += 1

print(f"✅ Processed {layers_processed} layers. Masks are now removed.")

for name, module in model.named_modules():
    if isinstance(module, torch.nn.Linear) and hasattr(module, 'weight_mask'):
        print(f"⚠️ WARNING: Layer {name} still has a mask!")

# Report the sparsity actually present in the exported weights.
print(f"Linear-weight sparsity: {get_sparsity(model):.2%}")

print(f"💾 Saving final clean model to {FINAL_OUTPUT_DIR}...")
model.save_pretrained(FINAL_OUTPUT_DIR)
tokenizer.save_pretrained(FINAL_OUTPUT_DIR)

print("🎉 Success! The model is now standard architecture with zeroed weights.")

## 7. Export & Quantization

### Export FP32 Model to ONNX

In [None]:
# Configuration: build the input path from OUTPUT_DIR so this cell reads the
# same folder the previous cell wrote, instead of a duplicated literal path.
PRUNED_MODEL_PATH = os.path.join(OUTPUT_DIR, "pii_model_pruned_final")  # From previous step
ONNX_MODEL_PATH = "/kaggle/working/pii_model.onnx"
QUANTIZED_ONNX_PATH = "/kaggle/working/pii_model_quantized.onnx"


print(f"📂 Loading model from {PRUNED_MODEL_PATH}...")
model = AutoModelForTokenClassification.from_pretrained(PRUNED_MODEL_PATH)
tokenizer = AutoTokenizer.from_pretrained(PRUNED_MODEL_PATH)
model.eval()  # disable dropout so the exported graph is deterministic

# Any short example works: batch and sequence axes are exported as dynamic.
dummy_input_text = "My name is John Doe."
inputs = tokenizer(dummy_input_text, return_tensors="pt")

print("🔄 Exporting to ONNX (FP32)...")
with torch.no_grad():
    torch.onnx.export(
        model,
        (inputs['input_ids'], inputs['attention_mask']),
        ONNX_MODEL_PATH,
        input_names=['input_ids', 'attention_mask'],
        output_names=['logits'],
        dynamic_axes={
            'input_ids': {0: 'batch_size', 1: 'sequence_length'},
            'attention_mask': {0: 'batch_size', 1: 'sequence_length'},
            'logits': {0: 'batch_size', 1: 'sequence_length'}
        },
        opset_version=14,
        do_constant_folding=True
    )
print(f"✅ Exported to {ONNX_MODEL_PATH}")

### Quantize the ONNX Model (INT8)

In [None]:
print("📉 Quantizing ONNX model to INT8...")

# Dynamic quantization: weights are stored as unsigned 8-bit integers.
quantize_dynamic(
    model_input=ONNX_MODEL_PATH,
    model_output=QUANTIZED_ONNX_PATH,
    weight_type=QuantType.QUInt8
)

size_fp32 = os.path.getsize(ONNX_MODEL_PATH) / (1024 * 1024)
size_int8 = os.path.getsize(QUANTIZED_ONNX_PATH) / (1024 * 1024)

print("🎉 Done!")
print(f"Original ONNX (FP32):   {size_fp32:.2f} MB")
print(f"Quantized ONNX (INT8):  {size_int8:.2f} MB")
print(f"🔻 Reduction:            {100 - (size_int8 / size_fp32 * 100):.1f}%")

In [None]:
metric = evaluate.load(path="seqeval")  # reloaded so the ONNX evaluation cell is self-contained

def evaluate_onnx(model_path, dataset, label_list):
    """Score an exported ONNX token-classification model with ``metric``.

    Runs the model on the CPU execution provider one example at a time (batch
    of 1, no padding), takes the argmax class per token, and drops positions
    whose reference label is -100 before computing the metric.

    Args:
        model_path: path to the ``.onnx`` file.
        dataset: iterable of examples with ``input_ids``, ``attention_mask``
            and ``labels`` (e.g. a tokenized split).
        label_list: tag names indexed by class id.

    Returns:
        The full ``metric.compute`` result dict.
    """
    print(f"🕵️‍♀️ Evaluating ONNX model: {model_path}")

    session = ort.InferenceSession(model_path, providers=['CPUExecutionProvider'])
    # Feed only the inputs the graph declares; ONNX Runtime rejects unknown
    # feed names (e.g. for a graph exported without attention_mask).
    graph_inputs = {graph_input.name for graph_input in session.get_inputs()}

    predictions = []
    references = []

    print("Running inference...")
    for example in tqdm(dataset):
        feeds = {
            "input_ids": np.array([example["input_ids"]], dtype=np.int64),
            "attention_mask": np.array([example["attention_mask"]], dtype=np.int64),
        }
        feeds = {name: value for name, value in feeds.items() if name in graph_inputs}

        logits = session.run(None, feeds)[0]
        pred_ids = np.argmax(logits, axis=2)[0]

        kept = [(pred, gold) for pred, gold in zip(pred_ids, example["labels"]) if gold != -100]
        predictions.append([label_list[pred] for pred, _ in kept])
        references.append([label_list[gold] for _, gold in kept])

    return metric.compute(predictions=predictions, references=references)

In [None]:
import shutil
import json
import os

SOURCE_MODEL_DIR = "/kaggle/working/pii_model_output/pii_model_pruned_final"
QUANTIZED_MODEL_PATH = "/kaggle/working/pii_model_quantized.onnx"
EXPORT_DIR = "/kaggle/working/browser_ready_pack"

# Start from an empty folder so files left over from earlier runs are not zipped.
if os.path.exists(EXPORT_DIR):
    shutil.rmtree(EXPORT_DIR)
os.makedirs(EXPORT_DIR)

print(f"🚀 Preparing browser artifacts in: {EXPORT_DIR}")

dst_model_path = os.path.join(EXPORT_DIR, "model.onnx")
shutil.copy(QUANTIZED_MODEL_PATH, dst_model_path)
print(f"✅ Copied Model: {dst_model_path}")

tokenizer_files = [
    "tokenizer.json",
    "tokenizer_config.json",
    "special_tokens_map.json",
    "vocab.txt"
]

for filename in tokenizer_files:
    src = os.path.join(SOURCE_MODEL_DIR, filename)
    dst = os.path.join(EXPORT_DIR, filename)
    if os.path.exists(src):
        shutil.copy(src, dst)
        print(f"✅ Copied Tokenizer: {filename}")
    else:
        print(f"⚠️ Warning: {filename} not found (some tokenizers don't use all of them).")

# labels.json holds the id -> tag map from the model config (presumably used
# by the browser side to decode predicted ids). This step is best effort: I/O
# and JSON errors are reported and the archive is still built, while
# unexpected programming errors are no longer hidden.
try:
    with open(os.path.join(SOURCE_MODEL_DIR, "config.json"), "r") as f:
        config_data = json.load(f)

    if "id2label" in config_data:
        labels_path = os.path.join(EXPORT_DIR, "labels.json")
        with open(labels_path, "w") as f:
            json.dump(config_data["id2label"], f, indent=2)
        print("✅ Extracted Label Map: labels.json")
    else:
        print("⚠️ 'id2label' not found in config.json!")
except (OSError, ValueError) as e:
    print(f"❌ Error extracting labels: {e}")

shutil.make_archive("/kaggle/working/pii_browser_pack", 'zip', EXPORT_DIR)
print("\n🎉 Done! Download 'pii_browser_pack.zip' from the Output tab.")

In [None]:
# Relative path: assumes the kernel's cwd is /kaggle/working, where the archive was written.
FileLink(path='pii_browser_pack.zip')

In [None]:
print("Zipping model for download...")
shutil.make_archive(
    base_name="/kaggle/working/pii_model",
    format="zip",
    root_dir=os.path.join(OUTPUT_DIR, "final_model"),
)
print("Done! You can now download pii_model.zip from the Output tab.")

In [None]:
# Relative path: assumes the kernel's cwd is /kaggle/working, where the archive was written.
FileLink(path='pii_model.zip')