In [1]:
!pip install transformers datasets accelerate sentencepiece

Collecting fsspec<=2025.3.0,>=2023.1.0 (from fsspec[http]<=2025.3.0,>=2023.1.0->datasets)
  Downloading fsspec-2025.3.0-py3-none-any.whl.metadata (11 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=2.0.0->accelerate)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=2.0.0->accelerate)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=2.0.0->accelerate)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch>=2.0.0->accelerate)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch>=2.0.0->accelerate)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata

In [3]:
!pip install python-dotenv

Collecting python-dotenv
  Downloading python_dotenv-1.1.1-py3-none-any.whl.metadata (24 kB)
Downloading python_dotenv-1.1.1-py3-none-any.whl (20 kB)
Installing collected packages: python-dotenv
Successfully installed python-dotenv-1.1.1


In [9]:
import os
import torch
import wandb
import tqdm
from transformers import AutoTokenizer, AutoModelForMaskedLM, AutoConfig, get_linear_schedule_with_warmup
from torch.optim import AdamW
from torch.utils.data import DataLoader
from datasets import load_dataset
import torch.nn.functional as F
import logging
from dotenv import load_dotenv

In [None]:


load_dotenv()
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


if os.getenv("KAGGLE_CONTAINER_NAME"):
    from kaggle_secrets import UserSecretsClient
    user_secrets = UserSecretsClient()
    HF_TOKEN = user_secrets.get_secret("hamza_hf_token")
    WANDB_API_KEY = user_secrets.get_secret("wandb_api_key")
else:   
    HF_TOKEN = os.getenv("HF_TOKEN")  # Hugging Face token
    WANDB_API_KEY = os.getenv("WANDB_API_KEY")  # W&B API key

# --- Configuration ---
DATASET_NAME = "hamzabouajila/tunisian-derja-unified-raw-corpus" # Replace with your dataset name
TEACHER_MODEL = "not-lain/TunBERT"
STUDENT_MODEL = "distilbert-base-uncased"
SAVE_DIR = "/kaggle/working/distilled_tunbert"
TRAIN_BATCH_SIZE = 16
EVAL_BATCH_SIZE = 16
NUM_TRAIN_EPOCHS = 3
MAX_LENGTH = 128
LEARNING_RATE = 5e-5
WEIGHT_DECAY = 0.01
WARMUP_STEPS = 100
MAX_GRAD_NORM = 1.0
TEMPERATURE = 2.0
ALPHA_KD = 0.5
ALPHA_MLM = 0.5
ALPHA_HIDDEN = 0.0  # Set to 0 to disable hidden state loss for simplicity
GRADIENT_ACCUMULATION_STEPS = 1
HIDDEN_LAYERS_TO_MATCH = [-1]  # Match last hidden layer if used

# --- Install Dependencies ---
try:
    import transformers
    import datasets
    import wandb
except ImportError:
    logger.info("Installing required libraries...")
    os.system("pip install transformers datasets wandb torch")

# --- W&B Initialization ---
def init_wandb():
    status = wandb.login(key=WANDB_API_KEY)
    print("wandb logged in ", status)
    wandb.init(
        project="tunbert-distillation",
        config={
            "teacher_model": TEACHER_MODEL,
            "student_model": STUDENT_MODEL,
            "dataset": DATASET_NAME,
            "train_batch_size": TRAIN_BATCH_SIZE,
            "eval_batch_size": EVAL_BATCH_SIZE,
            "num_train_epochs": NUM_TRAIN_EPOCHS,
            "max_length": MAX_LENGTH,
            "learning_rate": LEARNING_RATE,
            "weight_decay": WEIGHT_DECAY,
            "warmup_steps": WARMUP_STEPS,
            "max_grad_norm": MAX_GRAD_NORM,
            "temperature": TEMPERATURE,
            "alpha_kd": ALPHA_KD,
            "alpha_mlm": ALPHA_MLM,
            "alpha_hidden": ALPHA_HIDDEN,
            "gradient_accumulation_steps": GRADIENT_ACCUMULATION_STEPS,
        },
    )

# --- Utilities ---
def prepare_datasets():
    logger.info(f"Loading dataset: {DATASET_NAME}")
    dataset = load_dataset(DATASET_NAME, split="train", token=HF_TOKEN)
    # Split into train and validation if needed
    dataset = dataset.train_test_split(test_size=0.1, seed=42)
    return {"train": dataset["train"], "validation": dataset["test"]}

def tokenize_function(examples, tokenizer):
    texts = [str(text) if text is not None and str(text) != "nan" else "" for text in examples["text"]]
    return tokenizer(
        texts,
        truncation=True,
        max_length=MAX_LENGTH,
        padding="max_length",
        return_tensors=None,
        return_special_tokens_mask=True,
    )

def group_texts(examples):
    concatenated = {k: sum(examples[k], []) for k in examples.keys()}
    total_length = len(concatenated["input_ids"])
    if total_length >= MAX_LENGTH:
        total_length = (total_length // MAX_LENGTH) * MAX_LENGTH
    result = {
        k: [t[i : i + MAX_LENGTH] for i in range(0, total_length, MAX_LENGTH)]
        for k, t in concatenated.items()
    }
    return result

def mask_tokens(inputs, tokenizer, mlm_probability=0.15):
    labels = inputs.clone()
    probability_matrix = torch.full(labels.shape, mlm_probability)
    special_tokens_mask = [
        tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
    ]
    special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
    probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
    masked_indices = torch.bernoulli(probability_matrix).bool()
    labels[~masked_indices] = -100

    indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
    inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)

    indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
    random_words = torch.randint(len(tokenizer), labels.shape, dtype=torch.long)
    inputs[indices_random] = random_words[indices_random]

    return inputs, labels


In [None]:
# Initialize W&B
init_wandb()

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")

# Load tokenizer and models
logger.info(f"Loading tokenizer and teacher model: {TEACHER_MODEL}")
tokenizer = AutoTokenizer.from_pretrained(TEACHER_MODEL, token=HF_TOKEN, use_fast=True)
teacher = AutoModelForMaskedLM.from_pretrained(TEACHER_MODEL, token=HF_TOKEN).to(device)
teacher.eval()

logger.info(f"Initializing student model: {STUDENT_MODEL}")
student_config = AutoConfig.from_pretrained(STUDENT_MODEL)
student = AutoModelForMaskedLM.from_config(student_config).to(device)
try:
    student = AutoModelForMaskedLM.from_pretrained(STUDENT_MODEL).to(device)
except Exception:
    logger.info("Using random initialization for student model")
student.train()

In [7]:

# Load and preprocess dataset
raw_datasets = prepare_datasets()
tokenized_train = raw_datasets["train"].map(
    lambda x: tokenize_function(x, tokenizer),
    batched=True,
    remove_columns=raw_datasets["train"].column_names
)
tokenized_val = raw_datasets["validation"].map(
    lambda x: tokenize_function(x, tokenizer),
    batched=True,
    remove_columns=raw_datasets["validation"].column_names
) if "validation" in raw_datasets else None

def collate_fn(batch):
    input_ids = torch.tensor([ex["input_ids"] for ex in batch], dtype=torch.long)
    attention_mask = torch.tensor([ex["attention_mask"] for ex in batch], dtype=torch.long)
    inputs_for_model, labels = mask_tokens(input_ids.clone(), tokenizer)
    return {"input_ids": inputs_for_model, "attention_mask": attention_mask, "labels": labels}

train_loader = DataLoader(tokenized_train, batch_size=TRAIN_BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(tokenized_val, batch_size=EVAL_BATCH_SIZE, shuffle=False, collate_fn=collate_fn) if tokenized_val else None

# Optimizer and scheduler
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
    {"params": [p for n, p in student.named_parameters() if not any(nd in n for nd in no_decay)], "weight_decay": WEIGHT_DECAY},
    {"params": [p for n, p in student.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
]
optimizer = AdamW(optimizer_grouped_parameters, lr=LEARNING_RATE)
total_steps = (len(train_loader) // GRADIENT_ACCUMULATION_STEPS) * NUM_TRAIN_EPOCHS
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=WARMUP_STEPS, num_training_steps=total_steps)


Map:   0%|          | 0/722393 [00:00<?, ? examples/s]

Map:   0%|          | 0/80266 [00:00<?, ? examples/s]

In [10]:
# Training loop
global_step = 0
for epoch in range(NUM_TRAIN_EPOCHS):
    student.train()
    running_loss = {"total": 0.0, "kd": 0.0, "mlm": 0.0, "hidden": 0.0}
    running_steps = 0
    for batch_idx, batch in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}")):
        inputs = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        # Teacher forward
        with torch.no_grad():
            teacher_outputs = teacher(input_ids=inputs, attention_mask=attention_mask, output_hidden_states=ALPHA_HIDDEN > 0)
            teacher_logits = teacher_outputs.logits
            teacher_hidden = teacher_outputs.hidden_states if ALPHA_HIDDEN > 0 else None

        # Student forward
        student_outputs = student(input_ids=inputs, attention_mask=attention_mask, output_hidden_states=ALPHA_HIDDEN > 0)
        student_logits = student_outputs.logits
        student_hidden = student_outputs.hidden_states if ALPHA_HIDDEN > 0 else None

        # KD loss
        s_logits = student_logits.view(-1, student_logits.size(-1)) / TEMPERATURE
        t_logits = teacher_logits.view(-1, teacher_logits.size(-1)) / TEMPERATURE
        s_log_prob = F.log_softmax(s_logits, dim=-1)
        t_prob = F.softmax(t_logits, dim=-1)
        kd_loss = F.kl_div(s_log_prob, t_prob, reduction="batchmean") * (TEMPERATURE ** 2)

        # MLM loss
        mlm_loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100)
        mlm_loss = mlm_loss_fct(student_logits.view(-1, student_logits.size(-1)), labels.view(-1))

        # Total loss
        total_loss = ALPHA_KD * kd_loss + ALPHA_MLM * mlm_loss

        # Hidden state loss
        hidden_loss = 0.0
        if ALPHA_HIDDEN > 0 and student_hidden is not None and teacher_hidden is not None:
            th = teacher_hidden[HIDDEN_LAYERS_TO_MATCH[0]]
            sh = student_hidden[HIDDEN_LAYERS_TO_MATCH[0]]
            if th.shape != sh.shape:
                proj = torch.nn.Linear(sh.size(-1), th.size(-1)).to(device)
                sh_proj = proj(sh)
                hidden_loss = F.mse_loss(sh_proj, th)
            else:
                hidden_loss = F.mse_loss(sh, th)
            total_loss = total_loss + ALPHA_HIDDEN * hidden_loss

        # Backprop
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(student.parameters(), MAX_GRAD_NORM)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

        # Log metrics
        running_loss["total"] += total_loss.item()
        running_loss["kd"] += kd_loss.item()
        running_loss["mlm"] += mlm_loss.item()
        running_loss["hidden"] += hidden_loss.item() if ALPHA_HIDDEN > 0 else 0.0
        running_steps += 1

        if global_step % 100 == 0:
            avg_loss = {k: v / running_steps for k, v in running_loss.items()}
            wandb.log({
                "epoch": epoch + 1,
                "step": global_step,
                "total_loss": avg_loss["total"],
                "kd_loss": avg_loss["kd"],
                "mlm_loss": avg_loss["mlm"],
                "hidden_loss": avg_loss["hidden"] if ALPHA_HIDDEN > 0 else 0.0,
            })
            logger.info(f"Epoch {epoch+1} Step {global_step} TotalLoss {avg_loss['total']:.4f} KDLoss {avg_loss['kd']:.4f} MLMLoss {avg_loss['mlm']:.4f}")
            running_loss = {k: 0.0 for k in running_loss}
            running_steps = 0

        global_step += 1

    # Validation
    if val_loader is not None:
        student.eval()
        eval_loss = 0.0
        eval_steps = 0
        with torch.no_grad():
            for vb in tqdm(val_loader, desc="Validation"):
                v_inputs = vb["input_ids"].to(device)
                v_attention = vb["attention_mask"].to(device)
                v_labels = vb["labels"].to(device)
                v_student_out = student(input_ids=v_inputs, attention_mask=v_attention)
                v_logits = v_student_out.logits
                v_loss = mlm_loss_fct(v_logits.view(-1, v_logits.size(-1)), v_labels.view(-1)).item()
                eval_loss += v_loss
                eval_steps += 1
        avg_eval_loss = eval_loss / eval_steps
        wandb.log({"epoch": epoch + 1, "validation_mlm_loss": avg_eval_loss})
        logger.info(f"Validation MLM loss after epoch {epoch+1}: {avg_eval_loss:.4f}")

    # Save checkpoint
    checkpoint_dir = os.path.join(SAVE_DIR, f"student-epoch-{epoch+1}")
    student.save_pretrained(checkpoint_dir)
    tokenizer.save_pretrained(checkpoint_dir)
    wandb.save(checkpoint_dir + "/*")
    logger.info(f"Saved checkpoint to {checkpoint_dir}")

# Final save
student.save_pretrained(SAVE_DIR)
tokenizer.save_pretrained(SAVE_DIR)
wandb.save(SAVE_DIR + "/*")
logger.info(f"Distillation finished. Model saved to {SAVE_DIR}")

# Push to Hugging Face Hub
logger.info("Pushing model to Hugging Face Hub")
student.push_to_hub("hamzabouajila/distilled_tunbert", use_auth_token=HF_TOKEN)
tokenizer.push_to_hub("hamzabouajila/distilled_tunbert", use_auth_token=HF_TOKEN)

wandb.finish()


TypeError: 'module' object is not callable