In [10]:
from transformers import Trainer
import torch
import torch.nn.functional as F

class DistillationSeq2SeqTrainer(Trainer):
    def __init__(self, teacher_model=None, distil_weight=0.5, temperature=2.0, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.teacher_model = teacher_model
        self.distil_weight = distil_weight
        self.temperature = temperature
        self.teacher_model.eval()
        self.teacher_model.to(self.model.device)

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        labels = inputs.get("labels")
        inputs_no_labels = {k: v for k, v in inputs.items() if k != "labels"}

        student_outputs = model(**inputs)
        student_logits = student_outputs.logits

        with torch.no_grad():
            teacher_outputs = self.teacher_model(**inputs_no_labels)
            teacher_logits = teacher_outputs.logits

        loss_ce = F.cross_entropy(
            student_logits.view(-1, student_logits.size(-1)),
            labels.view(-1),
            ignore_index=self.tokenizer.pad_token_id
        )

        student_log_probs = F.log_softmax(student_logits / self.temperature, dim=-1)
        teacher_probs = F.softmax(teacher_logits / self.temperature, dim=-1)

        loss_kl = F.kl_div(
            student_log_probs, teacher_probs,
            reduction="batchmean",
            log_target=False
        ) * (self.temperature ** 2)

        loss = (1 - self.distil_weight) * loss_ce + self.distil_weight * loss_kl

        # Wandb logging
        if self.state.global_step % 10 == 0 and self.args.report_to and "wandb" in self.args.report_to:
            wandb.log({
                "loss": loss.item(),
                "loss_ce": loss_ce.item(),
                "loss_kl": loss_kl.item(),
                "step": self.state.global_step
            })

        return (loss, student_outputs) if return_outputs else loss


In [20]:
from transformers import (
    BartTokenizer,
    BartForConditionalGeneration,
    BartConfig,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    DataCollatorForSeq2Seq
)
from datasets import load_dataset
import os

# 1. Charger le tokenizer (on garde celui de bart-large)
tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
vocab_size = tokenizer.vocab_size  # = 50265

# 2. Créer une architecture MiniBART
mini_config = BartConfig(
    d_model=256,
    encoder_layers=3,
    decoder_layers=3,
    encoder_attention_heads=4,
    decoder_attention_heads=4,
    encoder_ffn_dim=1024,
    decoder_ffn_dim=1024,
    vocab_size=tokenizer.vocab_size -1,
    max_position_embeddings=1024,
    pad_token_id=tokenizer.pad_token_id,
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
    is_encoder_decoder=True
)

student_model = BartForConditionalGeneration(mini_config)

# 3. Charger les données de summarization (CNN/DM)
dataset = load_dataset("cnn_dailymail", "3.0.0")
train_data = dataset["train"].select(range(10))  # pour test rapide
val_data = dataset["validation"].select(range(3))

# 4. Tokenization
def preprocess(examples):
    inputs = tokenizer(examples["article"], max_length=512, truncation=True, padding="max_length")
    targets = tokenizer(examples["highlights"], max_length=128, truncation=True, padding="max_length")
    inputs["labels"] = targets["input_ids"]
    return inputs

tokenized_train = train_data.map(preprocess, batched=True)
tokenized_val = val_data.map(preprocess, batched=True)

Map: 100%|██████████| 10/10 [00:00<00:00, 87.01 examples/s]
Map: 100%|██████████| 3/3 [00:00<00:00, 59.87 examples/s]


In [21]:
training_args = Seq2SeqTrainingArguments(
    output_dir="./mini-bart-distilled",
    evaluation_strategy="epoch",
    learning_rate=5e-4,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    weight_decay=0.01,
    save_total_limit=2,
    num_train_epochs=2,
    predict_with_generate=True,
    logging_dir="./logs",
    fp16=torch.cuda.is_available()
)

In [22]:
data_collator = DataCollatorForSeq2Seq(tokenizer, model=student_model)
teacher_model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")

In [23]:
trainer = DistillationSeq2SeqTrainer(
    model=student_model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_val,
    tokenizer=tokenizer,
    data_collator=data_collator,
    teacher_model=teacher_model,  # BART-LARGE
    distil_weight=0.5,
    temperature=2.0
)
trainer.train()


  super().__init__(*args, **kwargs)
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.


Epoch,Training Loss,Validation Loss
1,No log,143.057709
2,No log,139.622131


Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.


TrainOutput(global_step=4, training_loss=134.80177307128906, metrics={'train_runtime': 16.0448, 'train_samples_per_second': 1.247, 'train_steps_per_second': 0.249, 'total_flos': 339801538560.0, 'train_loss': 134.80177307128906, 'epoch': 2.0})

In [15]:
print("student vocab size:", student_model.config.vocab_size)
print("teacher vocab size:", teacher_model.config.vocab_size)

student vocab size: 50265
teacher vocab size: 50264
