In [None]:
!pip install -U datasets -q

In [1]:
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from dataclasses import dataclass
import copy
import wandb
from torch.utils.data import DataLoader
from transformers import get_linear_schedule_with_warmup
from datasets import load_dataset
import torch.nn.functional as F
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup
from tqdm import tqdm
from torch.amp import autocast, GradScaler

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
tokenizer = AutoTokenizer.from_pretrained('Intel/bert-base-uncased-mrpc')
distil_model = AutoModelForSequenceClassification.from_pretrained("Intel/bert-base-uncased-mrpc").to(device)
distil_model.eval()

1) EMA обновление весов

2) Loss = a * task_loss (CrossEntropy) + b * logits_loss (KLDiv) + c * attention_loss (KLDiv)

In [20]:
def prepare_data(batch_size=16):
    dataset = load_dataset("nyu-mll/glue", "mrpc")
    tokenizer = AutoTokenizer.from_pretrained("Intel/bert-base-uncased-mrpc")

    def tokenize(batch):
        return tokenizer(
            text=batch["sentence1"],
            text_pair=batch["sentence2"],
            padding="max_length",
            truncation=True,
            max_length=512,
            return_tensors="pt"
        )

    encoded_dataset = dataset.map(tokenize, batched=True, batch_size=batch_size)
    encoded_dataset = encoded_dataset.remove_columns(["sentence1", "sentence2", "idx"])
    encoded_dataset = encoded_dataset.rename_column("label", "labels")
    encoded_dataset.set_format("torch")

    train_loader = DataLoader(encoded_dataset["train"], batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(encoded_dataset["test"], batch_size=batch_size)

    return train_loader, val_loader

def kl_div_loss(student_logits, teacher_logits, temperature=2.3):
    student_log_probs = F.log_softmax(student_logits / temperature, dim=1)
    teacher_probs = F.softmax(teacher_logits / temperature, dim=1)
    return F.kl_div(student_log_probs, teacher_probs, reduction='batchmean') * (temperature ** 2)

def attention_kl_loss(student_attns, teacher_attns, mapping_attn, temperature=2.3):
    total_loss = 0.0
    num_layers = len(student_attns)

    for student_idx in range(num_layers):
        student_attn = student_attns[student_idx]  # [B, H, T, T]
        teacher_attn = teacher_attns[mapping_attn[student_idx]]  # [B, H, T, T]

        B, H, T, _ = student_attn.shape

        # Применяем температурное масштабирование
        student_attn = student_attn / temperature
        teacher_attn = teacher_attn / temperature

        # Вычисляем логарифмы вероятностей и вероятности
        student_log_probs = F.log_softmax(student_attn, dim=-1)  # [B, H, T, T]
        teacher_probs = F.softmax(teacher_attn, dim=-1)  # [B, H, T, T]

        # Вычисляем KL-дивергенцию для каждого элемента
        kl_per_element = F.kl_div(
            student_log_probs,
            teacher_probs,
            reduction='none'
        )  # [B, H, T, T]

        # Суммируем по последнему измерению (T)
        kl_per_token = kl_per_element.sum(dim=-1)  # [B, H, T]

        # Суммируем по всем головам и токенам (но не по батчу!)
        kl_per_layer = kl_per_token.sum(dim=(1, 2))  # [B]

        # Усредняем по батчу
        layer_loss = kl_per_layer.mean()

        # Масштабируем обратно температурой
        layer_loss = layer_loss * (temperature ** 2)

        total_loss += layer_loss

    return total_loss / num_layers

In [8]:
class EMAModel:
    def __init__(self, model, decay=0.99):
        self.decay = decay
        self.shadow = copy.deepcopy(model)
        self.shadow.eval()

    def update(self, model):
        with torch.no_grad():
            for shadow_param, model_param in zip(self.shadow.parameters(), model.parameters()):
                shadow_param.copy_(shadow_param * self.decay + (1 - self.decay) * model_param)

In [None]:
wandb.init(project="bert_mrpc_self-distil", config={
    "a": 2,
    "b": 1,
    "c": 1,
    "batch_size": 8,
    "epochs": 13,
    "lr": 6e-5,
    "temperature": 1.5
})

In [26]:
def train_teacher(model, tokenizer, device):
    train_loader, val_loader = prepare_data(batch_size=wandb.config.batch_size)

    optimizer = AdamW(model.parameters(), lr=wandb.config.lr)
    total_steps = len(train_loader) * wandb.config.epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=100, num_training_steps=total_steps
    )

    best_val_acc = 0.0
    step = 0

    model.train()

    for epoch in range(wandb.config.epochs):
        epoch_task_loss = 0

        pb = tqdm(train_loader, desc=f"Epoch {epoch + 1}")
        for batch in pb:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            student_outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask
            )
            student_logits = student_outputs.logits

            task_loss = F.cross_entropy(student_logits, labels)

            total_loss = task_loss

            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

        model.eval()
        correct = 0
        total = 0
        train_loader, val_loader = prepare_data(batch_size=wandb.config.batch_size)
        with torch.no_grad():
            for batch in tqdm(val_loader):
                input_ids = batch["input_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device)
                labels = batch["labels"].to(device)

                outputs = model(input_ids, attention_mask=attention_mask)
                preds = torch.argmax(outputs.logits, dim=1)
                correct += (preds == labels).sum().item()
                total += labels.size(0)
        print("\n", correct / total)

        step += 1

train_teacher(distil_model, tokenizer, 'cuda')

Epoch 1: 100%|██████████| 459/459 [05:41<00:00,  1.34it/s]
100%|██████████| 216/216 [00:50<00:00,  4.25it/s]



 0.8057971014492754


Epoch 2: 100%|██████████| 459/459 [05:33<00:00,  1.37it/s]
100%|██████████| 216/216 [00:50<00:00,  4.27it/s]



 0.7831884057971015


Epoch 3:  12%|█▏        | 56/459 [00:40<04:53,  1.37it/s]


KeyboardInterrupt: 

In [28]:
def set_dropout_and_norm_to_eval(module):
    if isinstance(module, torch.nn.modules.dropout._DropoutNd):
        module.eval()
    elif isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
        module.eval()
    elif isinstance(module, torch.nn.LayerNorm):
        module.eval()

    for child in module.children():
        set_dropout_and_norm_to_eval(child)

In [29]:
def train_distil_model(student_model, tokenizer, device):
    train_loader, val_loader = prepare_data(batch_size=wandb.config.batch_size)

    optimizer = AdamW(student_model.parameters(), lr=wandb.config.lr)
    total_steps = len(train_loader) * wandb.config.epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=100, num_training_steps=total_steps
    )

    ema_teacher = EMAModel(student_model, decay=0.99)
    best_val_acc = 0.0
    step = 0

    mapping_attn = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]

    student_model.train()
    set_dropout_and_norm_to_eval(student_model)

    for epoch in range(wandb.config.epochs):
        epoch_task_loss = 0
        epoch_kl_loss = 0
        epoch_attn_loss = 0
        epoch_total_loss = 0

        pb = tqdm(train_loader, desc=f"Epoch {epoch + 1}")
        for batch in pb:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            with torch.no_grad():
                teacher_outputs = ema_teacher.shadow(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    output_attentions=True
                )
            teacher_logits = teacher_outputs.logits
            teacher_attns = teacher_outputs.attentions

            student_outputs = student_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                output_attentions=True
            )
            student_logits = student_outputs.logits
            student_attns = student_outputs.attentions

            task_loss = F.cross_entropy(student_logits, labels)
            kl_loss = kl_div_loss(student_logits, teacher_logits, temperature=wandb.config.temperature)
            attn_loss = attention_kl_loss(student_attns, teacher_attns, mapping_attn, temperature=wandb.config.temperature)

            total_loss = (
                wandb.config.a * task_loss +
                wandb.config.b * kl_loss +
                wandb.config.c * attn_loss
            )

            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(student_model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

            ema_teacher.update(student_model)

            step_loss = total_loss.item()
            step_task = task_loss.item()
            step_kl = kl_loss.item()
            step_attn = attn_loss.item()

            epoch_task_loss += step_task
            epoch_kl_loss += step_kl
            epoch_attn_loss += step_attn
            epoch_total_loss += step_loss

            pb.set_postfix(loss=step_loss)

            wandb.log({
                "step": step + 1,
                "train_loss": step_loss,
                "task_loss": step_task,
                "kl_loss": step_kl,
                "attn_loss": step_attn
            })

            step += 1

        student_model.eval()
        ema_teacher.shadow.eval()

        student_correct = 0
        ema_correct = 0
        total = 0

        with torch.no_grad():
            for batch in val_loader:
                input_ids = batch["input_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device)
                labels = batch["labels"].to(device)

                student_outputs = student_model(input_ids, attention_mask=attention_mask)
                student_preds = torch.argmax(student_outputs.logits, dim=1)
                student_correct += (student_preds == labels).sum().item()

                ema_outputs = ema_teacher.shadow(input_ids, attention_mask=attention_mask)
                ema_preds = torch.argmax(ema_outputs.logits, dim=1)
                ema_correct += (ema_preds == labels).sum().item()

                total += labels.size(0)

        student_val_acc = student_correct / total
        ema_val_acc = ema_correct / total

        wandb.log({
            "epoch": epoch + 1,
            "student_val_acc": student_val_acc,
            "ema_val_acc": ema_val_acc,
            "avg_train_loss": epoch_total_loss / len(train_loader),
            "avg_task_loss": epoch_task_loss / len(train_loader),
            "avg_kl_loss": epoch_kl_loss / len(train_loader),
            "avg_attn_loss": epoch_attn_loss / len(train_loader),
        })

        if ema_val_acc > best_val_acc:
            best_val_acc = ema_val_acc
            torch.save(ema_teacher.shadow.state_dict(), "best_ema_teacher.pt")

        print(f"Epoch {epoch+1} | Student Acc: {student_val_acc*100:.2f}% | EMA Acc: {ema_val_acc*100:.2f}%")
        student_model.train()

    return student_model, ema_teacher.shadow

In [30]:
train_distil_model(distil_model, tokenizer, device)    # Тут у меня в wandb всё в одном ране смешалось, я вначале не обучал учителя и не отключал дропауты, потом обучил и отключил. KL лосы сразу упали, там можно посмотреть.

Epoch 1: 100%|██████████| 459/459 [09:08<00:00,  1.19s/it, loss=0.0276]


Epoch 1 | Student Acc: 79.48% | EMA Acc: 80.81%


Epoch 2: 100%|██████████| 459/459 [09:16<00:00,  1.21s/it, loss=2.33]


Epoch 2 | Student Acc: 80.29% | EMA Acc: 81.10%


Epoch 3: 100%|██████████| 459/459 [09:16<00:00,  1.21s/it, loss=0.128]


Epoch 3 | Student Acc: 81.04% | EMA Acc: 81.97%


Epoch 4: 100%|██████████| 459/459 [09:16<00:00,  1.21s/it, loss=0.0703]


Epoch 4 | Student Acc: 81.86% | EMA Acc: 83.07%


Epoch 5: 100%|██████████| 459/459 [09:16<00:00,  1.21s/it, loss=0.0622]


Epoch 5 | Student Acc: 80.87% | EMA Acc: 82.14%


Epoch 6: 100%|██████████| 459/459 [09:15<00:00,  1.21s/it, loss=3.79]


Epoch 6 | Student Acc: 81.33% | EMA Acc: 81.39%


Epoch 7:  13%|█▎        | 59/459 [01:11<08:07,  1.22s/it, loss=0.0429]


KeyboardInterrupt: 