In [1]:
!pip install transformers datasets accelerate sentencepiece

Collecting fsspec<=2025.3.0,>=2023.1.0 (from fsspec[http]<=2025.3.0,>=2023.1.0->datasets)
  Downloading fsspec-2025.3.0-py3-none-any.whl.metadata (11 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=2.0.0->accelerate)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=2.0.0->accelerate)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=2.0.0->accelerate)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch>=2.0.0->accelerate)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch>=2.0.0->accelerate)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata

In [4]:
"""
Distill a TunBERT (teacher) into a smaller student for Masked Language Modeling.
Requirements:
    pip install transformers datasets accelerate sentencepiece
Run:
    python distill_tunbert.py
"""

import os
import math
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from transformers import (
    AutoTokenizer,
    AutoConfig,
    AutoModelForMaskedLM,
    get_linear_schedule_with_warmup,
)
from torch.optim import AdamW
from datasets import load_dataset


In [5]:
# Config / HParams
# --------------------
# Every name below is a module-level setting read by the later cells.
teacher_name_or_path = "not-lain/TunBERT"   # replace with your teacher HF path
tokenizer_name_or_path = "not-lain/TunBERT"  # or path to your tokenizer
# The student's config is loaded from this identifier and its pretrained
# weights are used when they can be loaded.
# NOTE(review): the KD loss compares teacher and student logits over the
# vocabulary dimension, so both models presumably need the same vocab size as
# the tokenizer above -- confirm for this teacher/student pair.
student_init_model = "distilbert-base-uncased"  # or custom small BERT path/config
# Output directory: per-epoch checkpoints go in sub-folders, the final model in the root.
save_dir = "./tunbert-distilled"
os.makedirs(save_dir, exist_ok=True)

# NOTE(review): train_dataset_name_or_path is not referenced by the visible code.
train_dataset_name_or_path = "text"  # or path; using datasets' text loader example
# prepare_datasets() reads this file when it exists, otherwise it falls back to wikitext-2.
train_file = "train.txt"  # if using local text file with one sentence per line
valid_file = None  # optional validation file

# Sequences are truncated and padded to exactly this many tokens.
max_length = 128
train_batch_size = 32
eval_batch_size = 64
num_train_epochs = 3
learning_rate = 5e-5
# Applied to all student parameters except those matched by the optimizer's no-decay list.
weight_decay = 0.01
# Number of scheduler steps over which the learning rate ramps up linearly.
warmup_steps = 1000
# Divides the number of batches when the scheduler's total step count is computed.
gradient_accumulation_steps = 1
# Gradient-norm clipping threshold for the student's parameters.
max_grad_norm = 1.0
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Distillation hyperparams
# Both teacher and student logits are divided by `temperature` before the
# softmax, and the KD loss is multiplied by temperature**2.
temperature = 2.0
alpha_kd = 0.7   # weight for KD loss (soft labels)
alpha_mlm = 0.3  # weight for hard MLM loss (cross-entropy)
alpha_hidden = 0.0  # if >0, match hidden states (MSE)
# Only the first entry is read by the training code; the same index is applied
# to both the teacher's and the student's hidden_states tuples.
hidden_layers_to_match = [-1]  # list of layer indices from student to match teacher (e.g. last layer)

In [None]:
# Utilities
# --------------------
def prepare_datasets():
    """Load the raw text corpus used for distillation.

    When the module-level ``train_file`` is set and exists on disk, it is
    loaded with the ``text`` loader (plus ``valid_file`` as a ``validation``
    split when that is set). Otherwise the train/validation splits of
    ``wikitext-2-raw-v1`` are loaded as a small fallback corpus.

    Returns:
        Whatever ``load_dataset`` returns for the chosen source.
    """
    has_local_corpus = bool(train_file) and os.path.exists(train_file)
    if not has_local_corpus:
        # Fallback: a small public corpus so the notebook still runs end to end.
        return load_dataset(
            "wikitext",
            "wikitext-2-raw-v1",
            split={"train": "train", "validation": "validation"},
        )

    # Local text file(s): validation is only registered when one was configured.
    data_files = {"train": train_file}
    if valid_file:
        data_files["validation"] = valid_file
    return load_dataset("text", data_files=data_files)

def tokenize_function(examples, tokenizer):
    """Tokenize the ``text`` column of a batch.

    Every row is truncated and padded to the module-level ``max_length``;
    ``return_tensors=None`` is passed through to the tokenizer.

    Args:
        examples: Mapping with a ``"text"`` entry holding the raw strings.
        tokenizer: Callable tokenizer applied to those strings.

    Returns:
        The tokenizer's output for the batch.
    """
    tokenizer_kwargs = dict(
        truncation=True,
        max_length=max_length,
        padding="max_length",
        return_tensors=None,
    )
    return tokenizer(examples["text"], **tokenizer_kwargs)

def group_texts(examples):
    """Concatenate a batch of tokenized rows and re-split it into fixed-size chunks.

    Useful when the corpus consists of many short lines: all rows of every
    column are joined end to end and then cut into pieces of the module-level
    ``max_length``.

    Args:
        examples: Mapping from column name to a list of per-row sequences.
            Must contain an ``"input_ids"`` column, whose concatenated length
            decides where the chunks are cut.

    Returns:
        Mapping with the same keys, each holding the list of chunks. When the
        concatenated length is at least ``max_length`` the trailing remainder
        is dropped; when it is shorter, the whole (short) sequence is kept as
        a single chunk.
    """
    from itertools import chain  # local import keeps this helper self-contained

    # chain.from_iterable is linear in the total number of tokens, whereas
    # sum(list_of_lists, []) re-copies the accumulator for every row (quadratic).
    concatenated = {k: list(chain.from_iterable(examples[k])) for k in examples.keys()}
    total_length = len(concatenated["input_ids"])
    if total_length >= max_length:
        # Round down to a whole number of chunks.
        total_length = (total_length // max_length) * max_length
    result = {
        k: [t[i : i + max_length] for i in range(0, total_length, max_length)]
        for k, t in concatenated.items()
    }
    return result

# MLM masking helper (simple dynamic masking)
import random
def mask_tokens(inputs, tokenizer, mlm_probability=0.15):
    """
    Prepare masked tokens inputs/labels for masked language modeling:
    80% mask -> [MASK], 10% random token, 10% keep original.

    Args:
        inputs: Integer tensor of token ids. It is modified IN PLACE, so pass
            a clone if the unmasked ids are still needed.
        tokenizer: Tokenizer providing ``get_special_tokens_mask``,
            ``mask_token``, ``convert_tokens_to_ids`` and ``__len__``.
        mlm_probability: Probability (float) with which each non-special token
            is selected for prediction.

    Returns:
        ``(inputs, labels)`` where ``labels`` holds the original id at every
        selected position and ``-100`` everywhere else.

    Raises:
        ValueError: If the tokenizer has no mask token.
    """
    if tokenizer.mask_token is None:
        raise ValueError(
            "This tokenizer does not have a mask token, which is required for masked language modeling."
        )

    # All helper tensors are created on the same device as `inputs`, so the
    # function also works when the batch already lives on an accelerator.
    device = inputs.device
    labels = inputs.clone()
    # We sample a few tokens in each sequence for MLM training (with probability mlm_probability)
    probability_matrix = torch.full(labels.shape, mlm_probability, device=device)
    special_tokens_mask = [
        tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
    ]
    special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool, device=device)
    # Special tokens are never selected for prediction.
    probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
    masked_indices = torch.bernoulli(probability_matrix).bool()
    labels[~masked_indices] = -100  # only compute loss on masked tokens

    # 80% of the time, replace masked input tokens with tokenizer.mask_token_id ([MASK])
    indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8, device=device)).bool() & masked_indices
    inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)

    # 10% of the time, replace masked input tokens with random token
    # (half of the remaining 20% of selected positions)
    indices_random = (
        torch.bernoulli(torch.full(labels.shape, 0.5, device=device)).bool()
        & masked_indices
        & ~indices_replaced
    )
    random_words = torch.randint(len(tokenizer), labels.shape, dtype=torch.long, device=device)
    inputs[indices_random] = random_words[indices_random]

    # The rest 10% of the time, keep original

    return inputs, labels

In [None]:
# Main
# --------------------
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, use_fast=True)
# Load teacher: inference mode (no dropout) and no gradient tracking, since it
# only provides distillation targets.
teacher = AutoModelForMaskedLM.from_pretrained(teacher_name_or_path).to(device)
teacher.eval()
teacher.requires_grad_(False)

# Create student model and config.
# Pretrained weights are preferred; the randomly initialised model is only
# built when loading them fails, so the student is instantiated exactly once.
student_config = AutoConfig.from_pretrained(student_init_model)
try:
    student = AutoModelForMaskedLM.from_pretrained(student_init_model)
except Exception as load_error:
    # Best-effort fallback, but say so: training from scratch behaves very
    # differently from fine-tuning pretrained weights.
    print(
        f"Could not load pretrained weights from {student_init_model!r} ({load_error}); "
        "initializing the student randomly from its config."
    )
    student = AutoModelForMaskedLM.from_config(student_config)

# The KD loss compares teacher and student logits over the vocabulary, so both
# heads must have the same size. Fail here with a readable message instead of
# a shape error in the middle of the first training step.
if teacher.config.vocab_size != student.config.vocab_size:
    raise ValueError(
        f"Teacher vocab size ({teacher.config.vocab_size}) and student vocab size "
        f"({student.config.vocab_size}) differ; logit distillation needs a student built "
        "on the teacher's vocabulary (e.g. a config with a matching vocab_size)."
    )

student.to(device)
student.train()

# Dataset: load the raw text splits, then replace the raw columns of each split
# with the tokenizer's output. The validation split is optional.
raw_datasets = prepare_datasets()
tokenized_train = raw_datasets["train"].map(
    lambda x: tokenize_function(x, tokenizer),
    batched=True,
    remove_columns=raw_datasets["train"].column_names,
)
tokenized_val = (
    raw_datasets["validation"].map(
        lambda x: tokenize_function(x, tokenizer),
        batched=True,
        remove_columns=raw_datasets["validation"].column_names,
    )
    if "validation" in raw_datasets
    else None
)

# Collate tokenized examples into PyTorch tensors and apply dynamic MLM masking
def collate_fn(batch):
    """Stack a list of tokenized examples into one masked-LM training batch.

    Args:
        batch: List of examples, each with ``"input_ids"`` and
            ``"attention_mask"`` sequences of equal length.

    Returns:
        Dict with ``"input_ids"`` (ids after masking), ``"attention_mask"``
        and ``"labels"`` as produced by ``mask_tokens``.
    """

    def _stack(column):
        # One long tensor per column, one row per example.
        return torch.tensor([example[column] for example in batch], dtype=torch.long)

    token_ids = _stack("input_ids")
    padding_mask = _stack("attention_mask")
    # Masking is drawn afresh on every call; it runs on a clone of the stacked ids.
    masked_ids, mlm_labels = mask_tokens(token_ids.clone(), tokenizer)
    return {"input_ids": masked_ids, "attention_mask": padding_mask, "labels": mlm_labels}

train_loader = DataLoader(tokenized_train, batch_size=train_batch_size, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(tokenized_val, batch_size=eval_batch_size, shuffle=False, collate_fn=collate_fn) if tokenized_val else None

# Optimizer & scheduler
# Biases and layer-norm weights are excluded from weight decay. BERT-style
# models name their norms "LayerNorm", while DistilBERT uses "sa_layer_norm",
# "output_layer_norm" and "vocab_layer_norm", so both spellings are matched.
no_decay = ["bias", "LayerNorm.weight", "layer_norm.weight"]
decay_params, no_decay_params = [], []
for param_name, param in student.named_parameters():
    if any(nd in param_name for nd in no_decay):
        no_decay_params.append(param)
    else:
        decay_params.append(param)
optimizer_grouped_parameters = [
    {"params": decay_params, "weight_decay": weight_decay},
    {"params": no_decay_params, "weight_decay": 0.0},
]
optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate)
# Length of the linear schedule: batches per epoch divided by the accumulation
# factor, times the number of epochs.
total_steps = (len(train_loader) // gradient_accumulation_steps) * num_train_epochs
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps)

# Training loop
# Each batch is run through the frozen teacher and the student; the student is
# trained on a weighted sum of
#   * KD loss: KL divergence between temperature-softened distributions,
#   * hard MLM loss: cross-entropy on the masked positions,
#   * optionally an MSE loss between selected hidden states.
global_step = 0  # counts optimizer updates
# Built once up front so the validation pass can use it as well.
mlm_loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100)
use_hidden = alpha_hidden > 0
# Projector for the optional hidden-state loss. It is created lazily (the
# hidden sizes are only known after the first forward pass), kept for the
# whole run and trained by its own optimizer; a fresh random projection on
# every step would give the student no consistent target.
hidden_proj = None
hidden_proj_optimizer = None
num_train_batches = len(train_loader)

for epoch in range(num_train_epochs):
    student.train()
    running_loss = 0.0
    batches_since_log = 0
    # Start every epoch from clean gradients (also protects against leftovers
    # when this cell is re-run after an interrupted execution).
    optimizer.zero_grad()
    for batch_idx, batch in enumerate(train_loader):
        inputs = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)  # -100 for non-mlm positions

        # Teacher logits (no grad)
        with torch.no_grad():
            teacher_outputs = teacher(input_ids=inputs, attention_mask=attention_mask, output_hidden_states=use_hidden)
            teacher_logits = teacher_outputs.logits  # (bs, seq_len, vocab_size)
            teacher_hidden = teacher_outputs.hidden_states if use_hidden else None

        # Student forward
        student_outputs = student(input_ids=inputs, attention_mask=attention_mask, output_hidden_states=use_hidden)
        student_logits = student_outputs.logits
        student_hidden = student_outputs.hidden_states if use_hidden else None

        # KD loss: KL divergence between softened probabilities, computed on
        # real tokens only. Sequences are padded to max_length, and the
        # teacher's predictions at padding positions carry no signal.
        T = temperature
        token_mask = attention_mask.reshape(-1).bool()
        s_logits = student_logits.reshape(-1, student_logits.size(-1))[token_mask] / T
        t_logits = teacher_logits.reshape(-1, teacher_logits.size(-1))[token_mask] / T
        if s_logits.size(0) > 0:
            s_log_prob = F.log_softmax(s_logits, dim=-1)
            t_prob = F.softmax(t_logits, dim=-1)
            # T*T keeps the gradient magnitude comparable across temperatures.
            kd_loss = F.kl_div(s_log_prob, t_prob, reduction="batchmean") * (T * T)
        else:
            # No real tokens: contribute a graph-connected zero instead of NaN.
            kd_loss = student_logits.sum() * 0.0

        # Hard MLM loss (student). Cross-entropy over zero targets is NaN, so
        # a batch in which no token was selected for masking contributes zero.
        flat_labels = labels.reshape(-1)
        if bool((flat_labels != -100).any()):
            mlm_loss = mlm_loss_fct(student_logits.reshape(-1, student_logits.size(-1)), flat_labels)
        else:
            mlm_loss = student_logits.sum() * 0.0

        total_loss = alpha_kd * kd_loss + alpha_mlm * mlm_loss

        # Optional hidden-state MSE distillation
        if use_hidden and student_hidden is not None and teacher_hidden is not None:
            # Match the configured layer of the student against the same index
            # of the teacher, projecting the student when the widths differ.
            th = teacher_hidden[hidden_layers_to_match[0]]
            sh = student_hidden[hidden_layers_to_match[0]]
            if th.shape != sh.shape:
                if hidden_proj is None:
                    hidden_proj = torch.nn.Linear(sh.size(-1), th.size(-1)).to(device)
                    hidden_proj_optimizer = AdamW(hidden_proj.parameters(), lr=learning_rate)
                sh = hidden_proj(sh)
            hidden_loss = F.mse_loss(sh, th)
            total_loss = total_loss + alpha_hidden * hidden_loss

        # Backprop. The loss is scaled so that the accumulated gradient is the
        # mean over the accumulation window.
        (total_loss / gradient_accumulation_steps).backward()

        running_loss += total_loss.item()
        batches_since_log += 1

        # Update the weights once per accumulation window, and also on the
        # last batch of the epoch so no gradients are thrown away.
        is_update_step = (
            (batch_idx + 1) % gradient_accumulation_steps == 0
            or (batch_idx + 1) == num_train_batches
        )
        if not is_update_step:
            continue

        torch.nn.utils.clip_grad_norm_(student.parameters(), max_grad_norm)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
        if hidden_proj_optimizer is not None:
            hidden_proj_optimizer.step()
            hidden_proj_optimizer.zero_grad()
        global_step += 1

        if global_step % 100 == 0:
            # Average over the batches actually accumulated since the last
            # report (fewer than 100 right after an epoch boundary).
            avg_loss = running_loss / batches_since_log
            print(f"Epoch {epoch+1} Step {global_step} AvgLoss {avg_loss:.4f}")
            running_loss = 0.0
            batches_since_log = 0

    # End epoch - optional evaluation
    if val_loader is not None:
        student.eval()
        eval_loss = 0.0
        nb_eval_steps = 0
        with torch.no_grad():
            for vb in val_loader:
                v_inputs = vb["input_ids"].to(device)
                v_attention = vb["attention_mask"].to(device)
                v_labels = vb["labels"].to(device)
                if not bool((v_labels != -100).any()):
                    # No masked token in this batch: its loss would be NaN.
                    continue
                v_student_out = student(input_ids=v_inputs, attention_mask=v_attention)
                v_logits = v_student_out.logits
                eval_loss += mlm_loss_fct(v_logits.view(-1, v_logits.size(-1)), v_labels.view(-1)).item()
                nb_eval_steps += 1
        if nb_eval_steps > 0:
            print(f"Validation MLM loss after epoch {epoch+1}: {eval_loss/nb_eval_steps:.4f}")
        else:
            print(f"Validation MLM loss after epoch {epoch+1}: n/a (no masked tokens in validation data)")

    # Save checkpoint each epoch
    student.save_pretrained(os.path.join(save_dir, f"student-epoch-{epoch+1}"))
    tokenizer.save_pretrained(os.path.join(save_dir, f"student-epoch-{epoch+1}"))

# Final save: write the student and then the tokenizer into the root of save_dir.
for _artifact in (student, tokenizer):
    _artifact.save_pretrained(save_dir)
print("Distillation finished. Model saved to", save_dir)

