In [1]:
!pip install transformers datasets torch accelerate tqdm




In [None]:
import math
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from transformers import (BertForMaskedLM, BertConfig, AutoTokenizer,
                          AdamW, get_linear_schedule_with_warmup,
                          DataCollatorForLanguageModeling, AutoModelForMaskedLM)
from datasets import load_dataset, concatenate_datasets
from tqdm.notebook import tqdm


In [None]:
# Load the teacher (BERT-base) and define a DistilBERT-style student:
# same vocabulary and width as the teacher, half the encoder layers.
SEED = 42
torch.manual_seed(SEED)  # reproducible student init, dropout and DataLoader shuffling

teacher = BertForMaskedLM.from_pretrained('bert-base-uncased')
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

# Student config mirrors the teacher except for depth.
# `add_pooling_layer` is an argument of `BertModel.__init__`, not a `BertConfig`
# field, so it is not passed here; `BertForMaskedLM` already builds its encoder
# without the NSP pooler.
student_config = BertConfig(
    vocab_size=teacher.config.vocab_size,                      # same vocab
    hidden_size=teacher.config.hidden_size,                    # 768
    num_hidden_layers=6,                                       # half the layers
    num_attention_heads=teacher.config.num_attention_heads,    # 12
    intermediate_size=teacher.config.intermediate_size,        # 3072
    hidden_act=teacher.config.hidden_act,
    hidden_dropout_prob=teacher.config.hidden_dropout_prob,
    attention_probs_dropout_prob=teacher.config.attention_probs_dropout_prob,
    max_position_embeddings=teacher.config.max_position_embeddings,
    type_vocab_size=teacher.config.type_vocab_size,
    layer_norm_eps=teacher.config.layer_norm_eps,
    pad_token_id=teacher.config.pad_token_id,
)
student = AutoModelForMaskedLM.from_config(student_config)


def _init_student_from_teacher(student_model, teacher_model):
    """Initialise the student's weights from the teacher (DistilBERT recipe).

    Embeddings and the MLM head are copied as-is. Student encoder layer ``i`` is
    copied from teacher layer ``round(i * (T - 1) / (S - 1))`` (T/S = teacher/student
    depth), spreading the chosen layers evenly from the first to the last; for
    12 -> 6 this selects teacher layers [0, 2, 4, 7, 9, 11]. Only tensors whose
    names resolve in the teacher and whose shapes match are copied; anything else
    keeps its random initialisation and is reported.

    Returns:
        list[int]: the teacher layer index used for each student layer.
    """
    n_teacher = teacher_model.config.num_hidden_layers
    n_student = student_model.config.num_hidden_layers
    if n_student == 1:
        layer_map = [0]
    else:
        layer_map = [round(i * (n_teacher - 1) / (n_student - 1)) for i in range(n_student)]

    layer_prefix = 'bert.encoder.layer.'
    teacher_state = teacher_model.state_dict()
    init_state = {}
    for key, tensor in student_model.state_dict().items():
        src_key = key
        if key.startswith(layer_prefix):
            # e.g. 'bert.encoder.layer.3.attention...' -> teacher layer layer_map[3]
            layer_idx, rest = key[len(layer_prefix):].split('.', 1)
            src_key = f"{layer_prefix}{layer_map[int(layer_idx)]}.{rest}"
        src = teacher_state.get(src_key)
        if src is not None and src.shape == tensor.shape:
            init_state[key] = src.detach().clone()

    result = student_model.load_state_dict(init_state, strict=False)
    if result.missing_keys:
        print(f"Not initialised from teacher: {result.missing_keys}")
    return layer_map


student_layer_map = _init_student_from_teacher(student, teacher)
print(f"Student layers initialised from teacher layers {student_layer_map}")


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [None]:
# Load a small corpus for distillation (WikiText-2, raw) and tokenize it.
MAX_SEQ_LENGTH = 128

raw_dataset = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')
# WikiText stores blank separator lines as empty strings; those would tokenize to
# just [CLS] [SEP] and give the MLM objective nothing to mask, so drop them.
dataset = raw_dataset.filter(lambda example: example['text'].strip() != '')


def tokenize_function(examples):
    """Tokenize a batch of raw text lines, truncating to MAX_SEQ_LENGTH tokens.

    No padding is applied here: DataCollatorForLanguageModeling pads every batch
    to its longest sequence, so the models are not run over padding to a fixed
    length for short lines. `special_tokens_mask` is returned so the collator
    never selects [CLS]/[SEP] for masking.
    """
    return tokenizer(
        examples['text'],
        truncation=True,
        max_length=MAX_SEQ_LENGTH,
        return_special_tokens_mask=True,
    )


tokenized = dataset.map(tokenize_function, batched=True, remove_columns=['text'])
tokenized.set_format(type='torch', columns=['input_ids', 'attention_mask', 'special_tokens_mask'])


In [None]:
# Hyperparameters, DataLoader with dynamic masking, optimizer and LR schedule.
epochs = 3
micro_batch_size = 16
gradient_accumulation_steps = 256   # effective batch = micro_batch_size * 256 sequences
learning_rate = 5e-4
weight_decay = 0.01
warmup_ratio = 0.1

data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)
dataloader = DataLoader(tokenized, batch_size=micro_batch_size, shuffle=True, collate_fn=data_collator)

# The training loop steps the optimizer every `gradient_accumulation_steps`
# micro-batches AND once more at the end of each epoch for the leftover batches,
# so optimizer steps must be counted per epoch. ceil(len * epochs / accum) can
# undercount, which makes the linear schedule reach lr=0 before training ends.
# NOTE(review): on a small corpus a 256-step accumulation leaves very few
# optimizer steps in total (see the printout below); consider lowering it.
optimizer_steps_per_epoch = math.ceil(len(dataloader) / gradient_accumulation_steps)
total_steps = optimizer_steps_per_epoch * epochs
warmup_steps = int(warmup_ratio * total_steps)

# torch.optim.AdamW replaces the deprecated `transformers.AdamW`. As in the BERT /
# DistilBERT training recipes, biases and LayerNorm weights get no weight decay.
no_decay = ('bias', 'LayerNorm.weight')
param_groups = [
    {
        'params': [p for name, p in student.named_parameters() if not any(nd in name for nd in no_decay)],
        'weight_decay': weight_decay,
    },
    {
        'params': [p for name, p in student.named_parameters() if any(nd in name for nd in no_decay)],
        'weight_decay': 0.0,
    },
]
optimizer = torch.optim.AdamW(param_groups, lr=learning_rate)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps)
print(f"{total_steps} optimizer steps ({warmup_steps} warmup)")

In [None]:
# Training loop with the DistilBERT triple loss:
# MLM cross-entropy + soft-target KD (KL divergence) + hidden-state cosine alignment.
temperature = 2.0      # softmax temperature for the KD term
alpha_mlm = 1.0        # loss weights (equal weighting, as before)
alpha_kd = 1.0
alpha_cos = 1.0
max_grad_norm = 5.0    # gradient-norm clipping threshold


def distillation_losses(student_out, teacher_out, labels, attention_mask, temperature=2.0):
    """Compute (mlm_loss, kd_loss, cos_loss) for one micro-batch.

    Args:
        student_out, teacher_out: masked-LM model outputs exposing `.logits`
            (batch, seq, vocab) and `.hidden_states` (tuple; last entry is the
            final layer, (batch, seq, hidden)).
        labels: MLM labels, -100 at positions that were not masked.
        attention_mask: 1 for real tokens, 0 for padding.
        temperature: softmax temperature for the KD term.

    The KD and cosine terms are averaged over real (non-padding) tokens.
    Feeding the full (batch, seq, vocab) tensor to `kl_div(reduction='batchmean')`
    divides only by the batch size, i.e. sums over every position including
    padding and inflates the KD term by roughly the sequence length.
    """
    token_mask = attention_mask.bool()
    s_logits = student_out.logits
    vocab_size = s_logits.size(-1)

    # cross_entropy over zero target tokens returns NaN, which would poison the
    # accumulated gradients; a batch with nothing masked contributes 0 instead.
    if (labels != -100).any():
        mlm_loss = F.cross_entropy(s_logits.view(-1, vocab_size), labels.view(-1), ignore_index=-100)
    else:
        mlm_loss = s_logits.new_zeros(())

    s_tok_logits = s_logits[token_mask]               # (n_tokens, vocab)
    t_tok_logits = teacher_out.logits[token_mask]
    kd_loss = F.kl_div(
        F.log_softmax(s_tok_logits / temperature, dim=-1),
        F.softmax(t_tok_logits / temperature, dim=-1),
        reduction='batchmean',
    ) * temperature ** 2  # T^2 keeps gradient scale comparable across temperatures

    s_tok_hidden = student_out.hidden_states[-1][token_mask]   # (n_tokens, hidden)
    t_tok_hidden = teacher_out.hidden_states[-1][token_mask]
    cos_target = s_tok_hidden.new_ones(s_tok_hidden.size(0))
    cos_loss = F.cosine_embedding_loss(s_tok_hidden, t_tok_hidden, cos_target)
    return mlm_loss, kd_loss, cos_loss


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
teacher.to(device).eval()
student.to(device).train()

num_batches = len(dataloader)
global_step = 0
for epoch in range(epochs):
    loop = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}", unit='batch')
    optimizer.zero_grad()
    group_loss = 0.0
    for step, batch in enumerate(loop):
        # The last accumulation group of an epoch may hold fewer micro-batches;
        # dividing by its real size keeps that update the same scale as the others.
        group_start = step - step % gradient_accumulation_steps
        group_size = min(gradient_accumulation_steps, num_batches - group_start)

        inputs = batch['input_ids'].to(device)
        masks = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        with torch.no_grad():
            teacher_out = teacher(input_ids=inputs, attention_mask=masks, output_hidden_states=True)
        student_out = student(input_ids=inputs, attention_mask=masks, output_hidden_states=True)

        mlm_loss, kd_loss, cos_loss = distillation_losses(
            student_out, teacher_out, labels, masks, temperature=temperature
        )
        total_loss = alpha_mlm * mlm_loss + alpha_kd * kd_loss + alpha_cos * cos_loss
        (total_loss / group_size).backward()
        group_loss = group_loss + total_loss.detach() / group_size

        if (step + 1) % gradient_accumulation_steps == 0 or (step + 1) == num_batches:
            torch.nn.utils.clip_grad_norm_(student.parameters(), max_grad_norm)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
            global_step += 1
            # Mean unscaled loss over the accumulation group just applied.
            loop.set_postfix(step=global_step, loss=float(group_loss))
            group_loss = 0.0
    print(f"Epoch {epoch+1} done.")
