In [None]:
!pip install -U datasets -q

In [1]:
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from dataclasses import dataclass
import copy
import wandb
from torch.utils.data import DataLoader
from transformers import get_linear_schedule_with_warmup
from datasets import load_dataset
import torch.nn.functional as F
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup
from tqdm import tqdm
from torch.amp import autocast, GradScaler
import numpy as np

# Run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Load the 'bert-base-uncased' tokenizer and a sequence-classification model
# with num_labels=4 from the same checkpoint.
# NOTE: despite its name, `distil_model` is the full-size BERT at this point.
# It is fine-tuned by `train_teacher`, copied into `bert_model`, and only later
# is the name rebound to the small `CleanBERT` student.
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

distil_model = AutoModelForSequenceClassification.from_pretrained(
    'bert-base-uncased',
    num_labels=4,
)
distil_model = distil_model.to(device)
distil_model.eval()

In [3]:
def prepare_data(batch_size=16, limit=6000, max_length=512, seed=42):
    """Build train/validation DataLoaders over a class-balanced AG News subset.

    The "train" split of the "fancyzhx/ag_news" dataset is reduced to
    ``limit // num_classes`` examples per label (the first ones in dataset
    order), shuffled, and cut in half: the first half becomes the training
    set and the second half the validation set.

    The shuffle uses a local generator seeded with ``seed`` so that every call
    yields the *same* split. This function is called once for the teacher and
    once for the student; with an unseeded shuffle the two calls would produce
    different splits, so the student's validation half could contain examples
    the teacher was trained on.

    Args:
        batch_size: Batch size of both DataLoaders (also used as the chunk
            size for the tokenization ``map``).
        limit: Total number of examples to keep before the 50/50 split.
        max_length: Every text is truncated/padded to exactly this many tokens.
        seed: Seed for the subset shuffle. Pass ``None`` for a fresh,
            non-reproducible shuffle on every call.

    Returns:
        ``(train_loader, val_loader)``. Batches are dicts of torch tensors
        containing the tokenizer outputs and a ``"labels"`` entry. Only the
        training loader reshuffles every epoch.
    """
    dataset = load_dataset("fancyzhx/ag_news")
    # NOTE(review): this checkpoint differs from the 'bert-base-uncased' model
    # loaded elsewhere in the notebook; presumably it shares the same
    # vocabulary -- confirm before changing either side.
    tokenizer = AutoTokenizer.from_pretrained("Intel/bert-base-uncased-mrpc")

    def tokenize(batch):
        encoded = tokenizer(
            text=batch["text"],
            padding="max_length",
            truncation=True,
            max_length=max_length,
            return_tensors="pt"
        )
        return {k: v.long() for k, v in encoded.items()}

    def finalize(split):
        # Tokenize, drop the raw text, and expose the label column under the
        # "labels" key that the training loops read.
        encoded = split.map(tokenize, batched=True, batch_size=batch_size)
        encoded = encoded.remove_columns(["text"])
        encoded = encoded.rename_column("label", "labels")
        encoded.set_format("torch")
        return encoded

    full_train = dataset["train"]
    labels = np.array(full_train["label"])
    unique_labels = np.unique(labels)
    per_class = limit // len(unique_labels)

    # Take the first `per_class` rows of every label to keep classes balanced.
    indices = []
    for label in unique_labels:
        indices.extend(np.where(labels == label)[0][:per_class].tolist())

    # Local, seeded generator: reproducible and independent of global RNG state.
    np.random.default_rng(seed).shuffle(indices)
    selected = full_train.select(indices)

    split_idx = len(selected) // 2
    small_train = selected.select(range(split_idx))
    small_val = selected.select(range(split_idx, len(selected)))

    train_loader = DataLoader(finalize(small_train), batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(finalize(small_val), batch_size=batch_size)

    return train_loader, val_loader

def kl_div_loss(student_logits, teacher_logits, temperature=2.3):
    """Temperature-softened KL divergence from teacher to student logits.

    Both logit tensors are divided by ``temperature`` and normalized over
    dim 1. The divergence uses the 'batchmean' reduction and is multiplied by
    ``temperature ** 2``.

    Args:
        student_logits: Student scores, classes along dim 1.
        teacher_logits: Teacher scores with the same shape.
        temperature: Softening factor applied to both sets of logits.

    Returns:
        Scalar loss tensor.
    """
    soft_targets = F.softmax(teacher_logits / temperature, dim=1)
    log_predictions = F.log_softmax(student_logits / temperature, dim=1)
    divergence = F.kl_div(log_predictions, soft_targets, reduction='batchmean')
    return divergence * (temperature ** 2)

def hidden_state_loss(student_hs, teacher_hs):
    """Mean-squared error between student and teacher hidden states.

    The teacher tensor is detached first, so gradients flow only into the
    student side.
    """
    frozen_target = teacher_hs.detach()
    return F.mse_loss(student_hs, frozen_target)

def attention_kl_loss(student_attns, teacher_attns, mapping_attn, temperature=2.3, eps=1e-12):
    """KL divergence between teacher and student attention distributions.

    The inputs are treated as attention *probabilities* (each row along the
    last dim already sums to 1), which is what both callers in this notebook
    pass: the weights returned by ``nn.MultiheadAttention`` and the
    ``attentions`` output of the Hugging Face model. Feeding probabilities
    straight into another softmax would squash every row towards a uniform
    distribution and make the loss almost vanish. Temperature is therefore
    applied in log-probability space: ``softmax(log(p) / T)``, which reduces
    to ``p`` itself for ``T == 1``.

    Args:
        student_attns: Sequence of student attention maps, one per student
            layer, each of shape [B, H, T, T].
        teacher_attns: Sequence of teacher attention maps, each [B, H, T, T].
        mapping_attn: ``mapping_attn[i]`` is the index into ``teacher_attns``
            that student layer ``i`` is matched against.
        temperature: Softening factor applied to both distributions.
        eps: Lower clamp applied before ``log`` so that exactly-zero weights
            (e.g. masked keys) do not produce ``-inf``.

    Returns:
        Scalar tensor: per layer, the KL is summed over heads, query and key
        positions, averaged over the batch, multiplied by ``temperature ** 2``;
        the result is then averaged over student layers.

    Raises:
        ValueError: If ``student_attns`` is empty.
    """
    num_layers = len(student_attns)
    if num_layers == 0:
        raise ValueError("attention_kl_loss requires at least one student attention map")

    total_loss = 0.0
    for student_idx, student_attn in enumerate(student_attns):
        # The teacher is a fixed target; never backpropagate into it.
        teacher_attn = teacher_attns[mapping_attn[student_idx]].detach()

        # Temperature scaling in log space (float32 so eps does not underflow
        # when the inputs are half precision).
        student_scores = torch.log(student_attn.float().clamp_min(eps)) / temperature
        teacher_scores = torch.log(teacher_attn.float().clamp_min(eps)) / temperature

        student_log_probs = F.log_softmax(student_scores, dim=-1)  # [B, H, T, T]
        teacher_probs = F.softmax(teacher_scores, dim=-1)  # [B, H, T, T]

        # Element-wise KL terms, reduced manually below.
        kl_per_element = F.kl_div(
            student_log_probs,
            teacher_probs,
            reduction='none'
        )  # [B, H, T, T]

        # Sum over heads, query positions and key positions (not the batch),
        # then average over the batch.
        layer_loss = kl_per_element.sum(dim=(1, 2, 3)).mean()

        # Rescale by T^2, as done for the logit-level KL loss.
        total_loss += layer_loss * (temperature ** 2)

    return total_loss / num_layers

In [4]:
def train_teacher(model, tokenizer, device):
    """Fine-tune ``model`` with cross-entropy and print validation accuracy.

    Trains for two epochs on the loaders returned by ``prepare_data`` using
    AdamW (lr 6e-5), a linear schedule with 100 warm-up steps and gradient
    clipping at norm 1.0. After each epoch the fraction of correctly
    classified validation examples is printed.

    Args:
        model: Model whose forward accepts ``input_ids`` / ``attention_mask``
            and returns an object with a ``.logits`` attribute. Updated in
            place.
        tokenizer: Unused; kept so existing calls keep working
            (``prepare_data`` loads its own tokenizer).
        device: Device the batches are moved to; ``model`` must already live
            there.

    Returns:
        None.
    """
    num_epochs = 2
    train_loader, val_loader = prepare_data(batch_size=16)

    optimizer = AdamW(model.parameters(), lr=6e-5)
    total_steps = len(train_loader) * num_epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=100, num_training_steps=total_steps
    )

    for epoch in range(num_epochs):
        # Re-enter training mode every epoch: the validation pass below
        # switches to eval(), which would otherwise leave dropout disabled
        # for all following epochs.
        model.train()

        for batch in tqdm(train_loader, desc=f"Epoch {epoch + 1}"):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            logits = model(
                input_ids=input_ids,
                attention_mask=attention_mask
            ).logits
            loss = F.cross_entropy(logits, labels)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

        model.eval()
        correct = 0
        total = 0

        with torch.no_grad():
            for batch in tqdm(val_loader):
                input_ids = batch["input_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device)
                labels = batch["labels"].to(device)

                outputs = model(input_ids, attention_mask=attention_mask)
                preds = torch.argmax(outputs.logits, dim=1)
                correct += (preds == labels).sum().item()
                total += labels.size(0)
        print("\n", correct / total)

# Fine-tune the full-size BERT (still bound to `distil_model` at this point).
# Pass the shared `device` instead of a hard-coded 'cuda' so the batches land
# on the same device the model was moved to, including on CPU-only machines.
train_teacher(distil_model, tokenizer, device)

# Recorded run below: validation accuracy 0.884 after epoch 1 and ~0.916 after
# epoch 2 (an earlier note here read 0.81).

Map:   0%|          | 0/3000 [00:00<?, ? examples/s]

Map:   0%|          | 0/3000 [00:00<?, ? examples/s]

Epoch 1: 100%|██████████| 188/188 [04:27<00:00,  1.42s/it]
100%|██████████| 188/188 [01:28<00:00,  2.13it/s]



 0.884


Epoch 2: 100%|██████████| 188/188 [04:19<00:00,  1.38s/it]
100%|██████████| 188/188 [01:28<00:00,  2.12it/s]


 0.9156666666666666





In [5]:
# Snapshot the fine-tuned BERT as the teacher. `distil_model` is rebound to the
# student network in a later cell, so this cell must run before that one;
# re-running it afterwards would copy the student instead.
bert_model = copy.deepcopy(distil_model)

In [6]:
@dataclass
class DistBERTConfig:
    """Hyperparameters of the student network (`TransformerBlock` / `CleanBERT`)."""
    # Width of the token representations: used as the attention embed_dim, the
    # LayerNorm size and the pooler/classifier input size.
    bert_hidden_size: int = 768
    # Number of TransformerBlock layers stacked in CleanBERT.
    num_blocks: int = 4
    # Number of attention heads per block.
    num_heads: int = 12
    # Hidden width of each block's feed-forward network.
    intermediate_size: int = 1024
    # Dropout probability shared by attention, residual branches and the
    # classification head.
    dropout: float = 0.1

# Default configuration used to build the student below.
config = DistBERTConfig()

In [7]:
class TransformerBlock(nn.Module):
    """Post-LayerNorm encoder block: self-attention followed by a GELU feed-forward net.

    Each sub-layer output goes through dropout, is added back to its input and
    then layer-normalized.
    """

    def __init__(self, config):
        """Create the sub-modules from ``config``.

        Reads ``bert_hidden_size``, ``num_heads``, ``intermediate_size`` and
        ``dropout`` from ``config``.
        """
        super().__init__()
        self.config = config
        width = config.bert_hidden_size

        self.mha = nn.MultiheadAttention(
            embed_dim=width,
            num_heads=config.num_heads,
            dropout=config.dropout,
            batch_first=True,
        )
        self.ffn = nn.Sequential(
            nn.Linear(width, config.intermediate_size),
            nn.GELU(),
            nn.Linear(config.intermediate_size, width),
        )
        self.layernorm1 = nn.LayerNorm(width, eps=1e-12)
        self.layernorm2 = nn.LayerNorm(width, eps=1e-12)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x, attention_mask=None, return_attention=False):
        """Run the block on a batch-first input ``x``.

        Args:
            x: Input tensor (batch-first, last dim ``bert_hidden_size``).
            attention_mask: Optional mask whose truthy entries mark tokens to
                attend to; it is inverted to form the key-padding mask.
            return_attention: When truthy, per-head attention weights are
                requested from the attention layer and returned as well.

        Returns:
            The block output, or ``(output, attn_weights)`` when
            ``return_attention`` is truthy.
        """
        # nn.MultiheadAttention expects True at positions to IGNORE.
        key_padding_mask = None if attention_mask is None else ~attention_mask.bool()

        attended, attn_weights = self.mha(
            x,
            x,
            x,
            key_padding_mask=key_padding_mask,
            need_weights=return_attention,
            average_attn_weights=False,
        )

        x = self.layernorm1(x + self.dropout(attended))
        output = self.layernorm2(x + self.dropout(self.ffn(x)))

        if not return_attention:
            return output
        return output, attn_weights

class CleanBERT(nn.Module):
    """Compact classifier: given embeddings, a stack of TransformerBlocks, a tanh pooler and a linear head."""

    def __init__(self, embedding_layer, config, num_labels=4):
        """Assemble the network.

        Args:
            embedding_layer: Module mapping ``input_ids`` to token
                representations; stored as ``self.embeddings``.
            config: Object providing ``bert_hidden_size``, ``num_blocks`` and
                ``dropout`` (plus whatever ``TransformerBlock`` reads).
            num_labels: Size of the classifier output.
        """
        super().__init__()
        self.embeddings = embedding_layer

        layers = [TransformerBlock(config) for _ in range(config.num_blocks)]
        self.blocks = nn.ModuleList(layers)

        self.pooler = nn.Sequential(
            nn.Linear(config.bert_hidden_size, config.bert_hidden_size),
            nn.Tanh(),
        )

        self.dropout = nn.Dropout(config.dropout)
        self.classifier = nn.Linear(config.bert_hidden_size, num_labels)

    def forward(self, input_ids, attention_mask=None,
                return_hidden_states=False, return_attentions=False):
        """Classify a batch of token ids.

        Returns:
            Dict with ``"logits"``; additionally ``"hidden_states"`` (the
            embedding output followed by every block output) when
            ``return_hidden_states`` is truthy, and ``"attentions"`` (one
            entry per block) when ``return_attentions`` is truthy.
        """
        want_attn = bool(return_attentions)

        hidden = self.embeddings(input_ids)
        hidden_trace = [hidden] if return_hidden_states else []
        attn_trace = []

        for layer in self.blocks:
            result = layer(hidden, attention_mask=attention_mask, return_attention=want_attn)
            if want_attn:
                hidden, weights = result
                attn_trace.append(weights)
            else:
                hidden = result

            if return_hidden_states:
                hidden_trace.append(hidden)

        # The first position of the final hidden state feeds the
        # pooler -> dropout -> classifier head.
        first_token = hidden[:, 0]
        logits = self.classifier(self.dropout(self.pooler(first_token)))

        outputs = {"logits": logits}
        if return_hidden_states:
            outputs["hidden_states"] = hidden_trace
        if return_attentions:
            outputs["attentions"] = attn_trace
        return outputs

In [20]:
# Build the student around a private copy of the teacher's embedding module
# and keep that copy frozen, so only the new blocks, pooler and classifier
# receive gradient updates.
teacher_embeddings_copy = copy.deepcopy(bert_model.bert.embeddings)
distil_model = CleanBERT(teacher_embeddings_copy, config).to(device)
for frozen_param in distil_model.embeddings.parameters():
    frozen_param.requires_grad = False

In [None]:
# Start the Weights & Biases run. `train_distil_model` reads every
# hyperparameter below back from `wandb.config`.
wandb.init(project="bert_mrpc_restricted-losses", config={
    # Weights of the four loss terms in the total distillation loss:
    "a": 0,  # cross-entropy against the ground-truth labels (task loss)
    "b": 1,  # KL divergence between teacher and student logits
    "c": 1,  # MSE between mapped teacher/student hidden states
    "d": 0,  # KL divergence between mapped teacher/student attention maps
    "batch_size": 8,
    "epochs": 13,
    "lr": 6e-5,
    # Temperature passed to both KL-based losses.
    "temperature": 1.5
})

In [14]:
def train_distil_model(student_model, teacher_model, tokenizer, device):
    """Distil ``teacher_model`` into ``student_model`` and log progress to wandb.

    Hyperparameters come from ``wandb.config``: ``batch_size``, ``epochs``,
    ``lr``, ``temperature`` and the loss weights ``a`` (task cross-entropy),
    ``b`` (logit KL), ``c`` (hidden-state MSE) and ``d`` (attention KL).
    The student forward pass and the losses run under float16 autocast with a
    ``GradScaler``; gradients are clipped to norm 1.0.

    Args:
        student_model: Model called with ``return_hidden_states=True`` and
            ``return_attentions=True`` that returns a dict with ``"logits"``,
            ``"hidden_states"`` and ``"attentions"``. Trained in place.
        teacher_model: Model called with ``output_hidden_states=True`` and
            ``output_attentions=True`` that exposes ``.logits``,
            ``.hidden_states`` and ``.attentions``. Kept in eval mode and run
            without gradients.
        tokenizer: Unused; kept so existing calls keep working
            (``prepare_data`` loads its own tokenizer).
        device: Device the batches are moved to; both models must already
            live there.

    Returns:
        The trained ``student_model``.
    """
    train_loader, val_loader = prepare_data(batch_size=wandb.config.batch_size)

    optimizer = AdamW(student_model.parameters(), lr=wandb.config.lr)
    total_steps = len(train_loader) * wandb.config.epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=100, num_training_steps=total_steps
    )

    step = 0
    scaler = GradScaler('cuda')

    # mapping_attn[i]: teacher attention map matched with student block i.
    mapping_attn = [2, 5, 8, 11]
    # mapping_mse_full[i]: teacher hidden state matched with student hidden
    # state i (index 0 on the student side is the embedding output).
    mapping_mse_full = [0, 3, 6, 9, 12]

    teacher_model.eval()

    for epoch in range(wandb.config.epochs):
        # Re-enter training mode every epoch: the validation pass at the end
        # of the loop switches to eval(), which would otherwise leave dropout
        # disabled for every epoch after the first.
        student_model.train()

        pb = tqdm(train_loader, desc=f"Epoch {epoch + 1}")
        for batch in pb:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            with torch.no_grad():
                teacher_outputs = teacher_model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    output_hidden_states=True,
                    output_attentions=True
                )
            teacher_logits = teacher_outputs.logits
            teacher_hiddens = teacher_outputs.hidden_states
            teacher_attns = teacher_outputs.attentions

            with autocast('cuda', dtype=torch.float16):
                student_outputs = student_model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    return_hidden_states=True,
                    return_attentions=True
                )
                student_logits = student_outputs["logits"]
                student_hiddens = student_outputs["hidden_states"]
                student_attns = student_outputs["attentions"]

                task_loss = F.cross_entropy(student_logits, labels)
                kl_loss = kl_div_loss(student_logits, teacher_logits, temperature=wandb.config.temperature)

                hs_loss = 0
                for i, student_hs in enumerate(student_hiddens):
                    teacher_hs = teacher_hiddens[mapping_mse_full[i]]
                    hs_loss += hidden_state_loss(student_hs, teacher_hs)

                attn_loss = attention_kl_loss(
                    student_attns,
                    teacher_attns,
                    mapping_attn,
                    temperature=wandb.config.temperature
                )

                total_loss = (
                    wandb.config.a * task_loss +
                    wandb.config.b * kl_loss +
                    wandb.config.c * hs_loss +
                    wandb.config.d * attn_loss
                )

            scaler.scale(total_loss).backward()
            # Gradients are still multiplied by the loss scale here; unscale
            # them first so the 1.0 clipping threshold applies to the true
            # gradient norm.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(student_model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()
            optimizer.zero_grad()

            # Read each scalar once; every .item() forces a device sync.
            total_loss_value = total_loss.item()
            task_value = task_loss.item()
            kl_value = kl_loss.item()
            hs_value = hs_loss.item()
            attn_value = attn_loss.item()

            pb.set_postfix(
                loss=total_loss_value,
                task=task_value,
                kl=kl_value,
                hs=hs_value,
                attn=attn_value
            )

            wandb.log({
                "step": step + 1,
                "train_loss": total_loss_value,
                "task_loss": task_value,
                "kl_loss": kl_value,
                "hs_loss": hs_value,
                "attn_loss": attn_value
            })

            step += 1

        student_model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for batch in val_loader:
                input_ids = batch["input_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device)
                labels = batch["labels"].to(device)

                outputs = student_model(input_ids, attention_mask=attention_mask)
                logits = outputs["logits"] if isinstance(outputs, dict) else outputs
                predictions = torch.argmax(logits, dim=1)
                correct += (predictions == labels).sum().item()
                total += labels.size(0)

        val_acc = correct / total
        wandb.log({"epoch": epoch + 1, "val_acc": val_acc})

        print(f"Epoch {epoch+1} | Val Acc: {val_acc * 100:.2f}%\n")

    return student_model

In [None]:
# Distil the fine-tuned teacher (`bert_model`) into the student (`distil_model`).
# Requires an active wandb run: hyperparameters are read from `wandb.config`.
train_distil_model(distil_model, bert_model, tokenizer, device)