# 掩码语言模型(MLM)蒸馏示例

In [None]:
import torch
import torch.nn.functional as F
from transformers import AutoModelForMaskedLM, AutoTokenizer, DataCollatorForLanguageModeling
from transformers import Trainer, TrainingArguments
from datasets import load_dataset

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
# Load the masked language models: the BERT checkpoint acts as the teacher and
# the DistilBERT checkpoint as the student. Both are moved to `device`; the
# tokenizer is taken from the student checkpoint.
teacher = AutoModelForMaskedLM.from_pretrained('bert-base-uncased')
teacher = teacher.to(device)

student = AutoModelForMaskedLM.from_pretrained('distilbert-base-uncased')
student = student.to(device)

tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')

In [None]:
# Prepare the data: the first 5000 rows of the WikiText-2 (raw) training split.
dataset = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train[:5000]')

# WikiText contains many blank rows. They tokenize to special tokens only, so
# they can never be selected for masking; a batch made only of such rows has no
# MLM targets and produces a NaN loss. Drop them before tokenizing. `dataset`
# itself is left untouched in case other cells rely on it.
tokenized = dataset.filter(lambda row: bool(row['text'].strip())).map(
    lambda batch: tokenizer(batch['text'], truncation=True, max_length=128),
    batched=True,
)

# MLM data collator: applies the masking automatically when a batch is built,
# using a masking probability of 0.15.
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=True,
    mlm_probability=0.15,
)

In [None]:
# Trainer that distills a teacher masked-LM into the student being trained.
class MLMDistillTrainer(Trainer):
    """Trainer whose loss mixes the student's MLM loss with a distillation term.

    The distillation term is the KL divergence between the temperature-softened
    teacher and student distributions, evaluated only at positions whose label
    is not -100 (the positions selected for masking).

    Args:
        teacher: Model called with the same inputs as the student (minus
            ``labels``); its output must expose ``.logits``.
        temp: Temperature that divides both sets of logits before the softmax.
        alpha: Weight of the distillation term; the MLM loss is weighted by
            ``1 - alpha``.

    Raises:
        ValueError: If ``teacher`` is not provided.
    """

    def __init__(self, *args, teacher=None, temp=3.0, alpha=0.5, **kwargs):
        super().__init__(*args, **kwargs)
        if teacher is None:
            # Fail early instead of with an opaque "'NoneType' is not callable"
            # on the first training step.
            raise ValueError("MLMDistillTrainer requires a `teacher` model")
        self.teacher = teacher
        # The teacher only provides targets: keep dropout disabled.
        self.teacher.eval()
        self.temp = temp
        self.alpha = alpha

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Return ``alpha * distillation_loss + (1 - alpha) * mlm_loss``.

        Args:
            model: The student model being trained.
            inputs: Batch dict passed to the model; must contain ``labels``.
            return_outputs: When true, return ``(loss, outputs)``.
            num_items_in_batch: Accepted because recent ``transformers``
                releases pass it to ``compute_loss``; it is not used here.

        Returns:
            The combined loss, or ``(loss, student_outputs)`` when
            ``return_outputs`` is true.
        """
        outputs = model(**inputs)

        # Distill only at the positions that carry an MLM label.
        mask_indices = inputs['labels'] != -100

        if not mask_indices.any():
            # No MLM targets in this batch: both the MLM loss and the
            # 'batchmean' KL term would be NaN. Return a zero loss that is
            # still connected to the graph so backward() works.
            zero_loss = outputs.logits.sum() * 0.0
            return (zero_loss, outputs) if return_outputs else zero_loss

        mlm_loss = outputs.loss

        # The teacher does not need the labels; leaving them out avoids
        # computing a loss that is never used.
        teacher_inputs = {key: value for key, value in inputs.items() if key != 'labels'}
        with torch.no_grad():
            teacher_logits = self.teacher(**teacher_inputs).logits

        student_logits = outputs.logits[mask_indices]
        teacher_logits = teacher_logits[mask_indices]

        # 'batchmean' divides by the number of selected positions. The temp**2
        # factor is the usual rescaling so the distillation gradients stay
        # comparable in magnitude to those of the hard-label loss.
        distill_loss = F.kl_div(
            F.log_softmax(student_logits / self.temp, dim=-1),
            F.softmax(teacher_logits / self.temp, dim=-1),
            reduction='batchmean',
        ) * (self.temp ** 2)

        total_loss = self.alpha * distill_loss + (1 - self.alpha) * mlm_loss
        return (total_loss, outputs) if return_outputs else total_loss

In [None]:
# Train the student with the distillation trainer.
training_args = TrainingArguments(
    output_dir='./mlm_distilled',
    num_train_epochs=2,
    per_device_train_batch_size=8,
    logging_steps=100,
)

trainer = MLMDistillTrainer(
    model=student,
    args=training_args,
    train_dataset=tokenized,
    data_collator=data_collator,
    tokenizer=tokenizer,
    teacher=teacher,
)

trainer.train()

In [None]:
# Test masked-token prediction with the student model.
text = "The capital of France is [MASK]."
inputs = tokenizer(text, return_tensors='pt').to(device)

# The student may still be in training mode after trainer.train(); switch to
# eval mode so dropout is disabled and the prediction is deterministic.
student.eval()

with torch.no_grad():
    outputs = student(**inputs)
    predictions = outputs.logits

# Column index of the [MASK] token in the (single) input sequence.
mask_token_index = (inputs['input_ids'] == tokenizer.mask_token_id).nonzero(as_tuple=True)[1]
# Highest-scoring vocabulary id at the masked position.
predicted_token_id = predictions[0, mask_token_index].argmax(dim=-1)
print(f"预测结果: {tokenizer.decode(predicted_token_id)}")