In [None]:
!pip install -U datasets -q

In [2]:
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from dataclasses import dataclass
import copy
import wandb
from torch.utils.data import DataLoader
from transformers import get_linear_schedule_with_warmup
from datasets import load_dataset
import torch.nn.functional as F
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup
from tqdm import tqdm
from torch.amp import autocast, GradScaler

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
tokenizer = AutoTokenizer.from_pretrained('Intel/bert-base-uncased-mrpc')
bert_model = AutoModelForSequenceClassification.from_pretrained("Intel/bert-base-uncased-mrpc").to(device)
bert_model.eval()

In [4]:
print(bert_model)
print(sum(p.numel() for p in bert_model.parameters())) # 110M

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e

In [5]:
print(bert_model.bert.embeddings)

BertEmbeddings(
  (word_embeddings): Embedding(30522, 768, padding_idx=0)
  (position_embeddings): Embedding(512, 768)
  (token_type_embeddings): Embedding(2, 768)
  (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
  (dropout): Dropout(p=0.1, inplace=False)
)


In [6]:
@dataclass
class DistBERTConfig:
    bert_hidden_size: int = 768
    num_blocks: int = 4
    num_heads: int = 12
    intermediate_size: int = 1024
    dropout: float = 0.1

config = DistBERTConfig()

## Часть 1. Дистиляция в fp16

In [7]:
class TransformerBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config

        self.mha = nn.MultiheadAttention(
            embed_dim=config.bert_hidden_size,
            num_heads=config.num_heads,
            dropout=config.dropout,
            batch_first=True
        )

        self.ffn = nn.Sequential(
            nn.Linear(config.bert_hidden_size, config.intermediate_size),
            nn.GELU(),
            nn.Linear(config.intermediate_size, config.bert_hidden_size)
        )

        self.layernorm1 = nn.LayerNorm(config.bert_hidden_size, eps=1e-12)
        self.layernorm2 = nn.LayerNorm(config.bert_hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x, attention_mask=None, return_attention=False):

        if attention_mask is not None:
            attention_mask = ~attention_mask.bool()
        else:
            attention_mask = None

        mha_out, attn_weights = self.mha(
            query=x,
            key=x,
            value=x,
            key_padding_mask=attention_mask,
            need_weights=return_attention,
            average_attn_weights=False
        )

        residual = x + self.dropout(mha_out)
        x = self.layernorm1(residual)

        ffn_out = self.ffn(x)
        residual = x + self.dropout(ffn_out)
        output = self.layernorm2(residual)

        if return_attention:
            return output, attn_weights
        else:
            return output

class CleanBERT(nn.Module):
    def __init__(self, embedding_layer, config, num_labels=2):
        super().__init__()
        self.embeddings = embedding_layer

        self.blocks = nn.ModuleList([
            TransformerBlock(config) for _ in range(config.num_blocks)
        ])

        self.pooler = nn.Sequential(
            nn.Linear(config.bert_hidden_size, config.bert_hidden_size),
            nn.Tanh()
        )

        self.dropout = nn.Dropout(config.dropout)
        self.classifier = nn.Linear(config.bert_hidden_size, num_labels)

    def forward(self, input_ids, attention_mask=None,
                return_hidden_states=False, return_attentions=False):
        x = self.embeddings(input_ids)

        all_hidden_states = []
        all_attentions = []

        if return_hidden_states:
            all_hidden_states.append(x)

        for block in self.blocks:
            if return_attentions:
                x, attn = block(x, attention_mask=attention_mask, return_attention=True)
                all_attentions.append(attn)
            else:
                x = block(x, attention_mask=attention_mask, return_attention=False)

            if return_hidden_states:
                all_hidden_states.append(x)

        cls_output = x[:, 0]

        pooled = self.pooler(cls_output)
        pooled = self.dropout(pooled)
        logits = self.classifier(pooled)

        outputs = {"logits": logits}

        if return_hidden_states:
            outputs["hidden_states"] = all_hidden_states

        if return_attentions:
            outputs["attentions"] = all_attentions

        return outputs

In [74]:
distil_model = CleanBERT(copy.deepcopy(bert_model.bert.embeddings), config).to(device)
for param in distil_model.embeddings.parameters():
    param.requires_grad = False

In [9]:
sentence_pairs = [
    ["Hello, how are you?", "Hi, what's up?"],
    ["Transformers are amazing", "BERT is a powerful model"],
    ["Distill BERT into a smaller model", "Make BERT lighter while preserving performance"],
    ["We reduce dimensions but keep performance", "Performance stays the same after compression"],
]

inputs = tokenizer(
    text=[pair[0] for pair in sentence_pairs],
    text_pair=[pair[1] for pair in sentence_pairs],
    padding=True,
    truncation=True,
    max_length=512,
    return_tensors="pt"
).to(device)

out = distil_model(inputs["input_ids"], attention_mask=inputs["attention_mask"],
                   return_hidden_states=True, return_attentions=True)
logits, hidden_states, attentions = out["logits"], out["hidden_states"], out["attentions"]
print(logits, len(hidden_states), hidden_states[0].shape, len(attentions), attentions[0].shape)

tensor([[0.0445, 0.4749],
        [0.0869, 0.2349],
        [0.2543, 0.0531],
        [0.0382, 0.2491]], device='cuda:0', grad_fn=<AddmmBackward0>) 5 torch.Size([4, 17, 768]) 4 torch.Size([4, 12, 17, 17])


In [10]:
print(distil_model)
print(sum(p.numel() for p in distil_model.parameters())) # 40M (0.36 * BertOrigNumParam)

CleanBERT(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (blocks): ModuleList(
    (0-3): 4 x TransformerBlock(
      (mha): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
      )
      (ffn): Sequential(
        (0): Linear(in_features=768, out_features=1024, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=1024, out_features=768, bias=True)
      )
      (layernorm1): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (layernorm2): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
  (pooler): Sequential(
    (0): Linear(in_features=768, out_features=768, bi

Loss = a * task_loss (CrossEntropy) + b * logits_loss (KLDiv) + c * hidden_states (MSE) + d * attention_loss (KLDiv)

Чтобы посчитать hidden_states и attention_loss нужен маппинг слоёв, о нём ниже.

In [11]:
with torch.no_grad():
    outputs = bert_model(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        output_hidden_states=True,
        output_attentions=True
    )

# Получаем hidden_states: tuple из 13 элементов (embeddings + 12 слоёв)
hidden_states = outputs.hidden_states
attentions = outputs.attentions
hidden_states[12].shape
print(len(hidden_states), hidden_states[0].shape)
print(len(attentions), attentions[0].shape)



13 torch.Size([4, 17, 768])
12 torch.Size([4, 12, 17, 17])


In [52]:
def prepare_data(batch_size=16):
    dataset = load_dataset("nyu-mll/glue", "mrpc")

    def tokenize(batch):
        return tokenizer(
            batch["sentence1"],
            batch["sentence2"],
            padding="max_length",
            truncation=True,
            max_length=512,
            return_tensors="pt"
        )

    encoded_dataset = dataset.map(tokenize, batched=True, batch_size=batch_size)
    encoded_dataset = encoded_dataset.remove_columns(["sentence1", "sentence2", "idx"])
    encoded_dataset = encoded_dataset.rename_column("label", "labels")
    encoded_dataset.set_format("torch")

    train_loader = DataLoader(encoded_dataset["train"], batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(encoded_dataset["test"], batch_size=batch_size)

    return train_loader, val_loader

def kl_div_loss(student_logits, teacher_logits, temperature=2.3):
    student_log_probs = F.log_softmax(student_logits / temperature, dim=1)
    teacher_probs = F.softmax(teacher_logits / temperature, dim=1)
    return F.kl_div(student_log_probs, teacher_probs, reduction='batchmean') * (temperature ** 2)

def hidden_state_loss(student_hs, teacher_hs):
    return F.mse_loss(student_hs, teacher_hs.detach())

def attention_kl_loss(student_attns, teacher_attns, mapping_attn, temperature=2.3):
    total_loss = 0.0
    num_layers = len(student_attns)

    for student_idx in range(num_layers):
        student_attn = student_attns[student_idx]  # [B, H, T, T]
        teacher_attn = teacher_attns[mapping_attn[student_idx]]  # [B, H, T, T]

        B, H, T, _ = student_attn.shape

        # Применяем температурное масштабирование
        student_attn = student_attn / temperature
        teacher_attn = teacher_attn / temperature

        # Вычисляем логарифмы вероятностей и вероятности
        student_log_probs = F.log_softmax(student_attn, dim=-1)  # [B, H, T, T]
        teacher_probs = F.softmax(teacher_attn, dim=-1)  # [B, H, T, T]

        # Вычисляем KL-дивергенцию для каждого элемента
        kl_per_element = F.kl_div(
            student_log_probs,
            teacher_probs,
            reduction='none'
        )  # [B, H, T, T]

        # Суммируем по последнему измерению (T)
        kl_per_token = kl_per_element.sum(dim=-1)  # [B, H, T]

        # Суммируем по всем головам и токенам (но не по батчу!)
        kl_per_layer = kl_per_token.sum(dim=(1, 2))  # [B]

        # Усредняем по батчу
        layer_loss = kl_per_layer.mean()

        # Масштабируем обратно температурой
        layer_loss = layer_loss * (temperature ** 2)

        total_loss += layer_loss

    return total_loss / num_layers

In [None]:
wandb.init(project="bert_mrpc_fp16dist", config={
    "a": 4,
    "b": 1,
    "c": 0.3,
    "d": 0.5,
    "batch_size": 8,
    "epochs": 13,
    "lr": 6e-5,
    "temperature": 1.25
})

In [76]:
def train_distil_model(student_model, teacher_model, tokenizer, device):
    train_loader, val_loader = prepare_data(batch_size=wandb.config.batch_size)

    optimizer = AdamW(student_model.parameters(), lr=wandb.config.lr)
    total_steps = len(train_loader) * wandb.config.epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=100, num_training_steps=total_steps
    )

    best_val_acc = 0.0
    step = 0

    scaler = GradScaler('cuda')

    mapping_attn = [2, 5, 8, 11]
    mapping_mse_full = [0, 3, 6, 9, 12]

    student_model.train()
    teacher_model.eval()

    for epoch in range(wandb.config.epochs):
        total_loss = 0
        total_task_loss = 0
        total_kl_loss = 0
        total_hs_loss = 0
        total_attn_loss = 0

        pb = tqdm(train_loader, desc=f"Epoch {epoch + 1}")
        for batch in pb:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            with torch.no_grad():
                teacher_outputs = teacher_model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    output_hidden_states=True,
                    output_attentions=True
                )
            teacher_logits = teacher_outputs.logits
            teacher_hiddens = teacher_outputs.hidden_states
            teacher_attns = teacher_outputs.attentions

            with autocast('cuda', dtype=torch.float16):
                student_outputs = student_model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    return_hidden_states=True,
                    return_attentions=True
                )
                student_logits = student_outputs["logits"]
                student_hiddens = student_outputs["hidden_states"]
                student_attns = student_outputs["attentions"]

                task_loss = F.cross_entropy(student_logits, labels)
                kl_loss = kl_div_loss(student_logits, teacher_logits, temperature=wandb.config.temperature)

                hs_loss = 0
                for i, student_hs in enumerate(student_hiddens):
                    teacher_hs = teacher_hiddens[mapping_mse_full[i]]
                    hs_loss += hidden_state_loss(student_hs, teacher_hs)

                attn_loss = attention_kl_loss(
                    student_attns,
                    teacher_attns,
                    mapping_attn,
                    temperature=wandb.config.temperature
                )

                total_loss = (
                    wandb.config.a * task_loss +
                    wandb.config.b * kl_loss +
                    wandb.config.c * hs_loss +
                    wandb.config.d * attn_loss
                )

            scaler.scale(total_loss).backward()
            torch.nn.utils.clip_grad_norm_(student_model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()
            optimizer.zero_grad()

            total_loss_value = total_loss.item()
            total_task_loss += task_loss.item()
            total_kl_loss += kl_loss.item()
            total_hs_loss += hs_loss.item()
            total_attn_loss += attn_loss.item()

            pb.set_postfix(
                loss=total_loss_value,
                task=task_loss.item(),
                kl=kl_loss.item(),
                hs=hs_loss.item(),
                attn=attn_loss.item()
            )

            wandb.log({
                "step": step + 1,
                "train_loss": total_loss_value,
                "task_loss": task_loss.item(),
                "kl_loss": kl_loss.item(),
                "hs_loss": hs_loss.item(),
                "attn_loss": attn_loss.item()
            })

            step += 1

        student_model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for batch in val_loader:
                input_ids = batch["input_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device)
                labels = batch["labels"].to(device)

                outputs = student_model(input_ids, attention_mask=attention_mask)
                logits = outputs["logits"] if isinstance(outputs, dict) else outputs
                predictions = torch.argmax(logits, dim=1)
                correct += (predictions == labels).sum().item()
                total += labels.size(0)

        val_acc = correct / total
        wandb.log({"epoch": epoch + 1, "val_acc": val_acc})

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(student_model.state_dict(), "best_distilled_model.pt")

            state_dict = {}
            for k, v in student_model.state_dict().items():
                state_dict[k] = v.clone().half()

            torch.save(state_dict, "best_distilled_model_fp16.pt")

        print(f"Epoch {epoch+1} | Val Acc: {val_acc * 100:.2f}%\n")

    return student_model

In [77]:
train_distil_model(distil_model, bert_model, tokenizer, device)

Epoch 1: 100%|██████████| 459/459 [02:52<00:00,  2.66it/s, attn=1.23, hs=2.67, kl=0.256, loss=2.45, task=0.196]


Epoch 1 | Val Acc: 70.43%



Epoch 2: 100%|██████████| 459/459 [02:47<00:00,  2.74it/s, attn=0.902, hs=2.09, kl=0.728, loss=3.74, task=0.484]


Epoch 2 | Val Acc: 62.43%



Epoch 3: 100%|██████████| 459/459 [02:47<00:00,  2.74it/s, attn=0.808, hs=1.95, kl=0.744, loss=4.25, task=0.63]


Epoch 3 | Val Acc: 67.54%



Epoch 4: 100%|██████████| 459/459 [02:47<00:00,  2.74it/s, attn=0.76, hs=1.83, kl=0.991, loss=3.7, task=0.445]


Epoch 4 | Val Acc: 62.72%



Epoch 5: 100%|██████████| 459/459 [02:47<00:00,  2.74it/s, attn=0.805, hs=1.88, kl=1.21, loss=4.12, task=0.486]


Epoch 5 | Val Acc: 68.58%



Epoch 6: 100%|██████████| 459/459 [02:47<00:00,  2.74it/s, attn=0.707, hs=1.88, kl=0.707, loss=2.85, task=0.305]


Epoch 6 | Val Acc: 63.71%



Epoch 7: 100%|██████████| 459/459 [02:47<00:00,  2.74it/s, attn=0.784, hs=1.86, kl=0.812, loss=2.41, task=0.163]


Epoch 7 | Val Acc: 62.72%



Epoch 8: 100%|██████████| 459/459 [02:47<00:00,  2.74it/s, attn=0.816, hs=1.89, kl=0.26, loss=2.65, task=0.353]


Epoch 8 | Val Acc: 58.20%



Epoch 9: 100%|██████████| 459/459 [02:47<00:00,  2.74it/s, attn=0.793, hs=1.84, kl=0.881, loss=2.17, task=0.0848]


Epoch 9 | Val Acc: 63.36%



Epoch 10: 100%|██████████| 459/459 [02:48<00:00,  2.73it/s, attn=0.864, hs=1.81, kl=0.303, loss=1.95, task=0.169]


Epoch 10 | Val Acc: 61.68%



Epoch 11: 100%|██████████| 459/459 [02:47<00:00,  2.74it/s, attn=0.753, hs=1.77, kl=0.0176, loss=0.99, task=0.0165]


Epoch 11 | Val Acc: 61.86%



Epoch 12: 100%|██████████| 459/459 [02:47<00:00,  2.74it/s, attn=0.748, hs=1.71, kl=0.421, loss=1.7, task=0.0983]


Epoch 12 | Val Acc: 60.70%



Epoch 13:   9%|▊         | 40/459 [00:14<02:36,  2.67it/s, attn=0.825, hs=1.8, kl=0.367, loss=1.59, task=0.0675]


KeyboardInterrupt: 

Тут такой прикол, что в fp16 я не стал обучать явно, а через автокаст, который только градиенты с некоторых (почти всех) слоев считает в fp16, веса хранит в fp32. По идее после torch.half() точность не упадет (у меня не падала). Я так сделал, потому что карпатый, вроде как, тоже в автокасте обучал. А в half() я никого не видел.

In [78]:
distil_model.load_state_dict(torch.load(f"best_distilled_model_fp16.pt"))
distil_model.eval()
correct = 0
total = 0
train_loader, val_loader = prepare_data(batch_size=wandb.config.batch_size)
with torch.no_grad():
    for batch in tqdm(val_loader):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        outputs = distil_model(input_ids, attention_mask=attention_mask)
        logits = outputs["logits"] if isinstance(outputs, dict) else outputs
        predictions = torch.argmax(logits, dim=1)
        correct += (predictions == labels).sum().item()
        total += labels.size(0)
print("\n", correct / total)

100%|██████████| 216/216 [00:10<00:00, 19.85it/s]


 0.7049275362318841





У меня максимально выбивает 0.7. После всех игр с лоссами. Нестабильность в vall_accuracy, как будто бы, из-за того, что модель инстантно переучиваться начинает. Ну или, может быть, из-за нестабильности в важности лоссов. Хотя я попытался это минимизировать, навряд ли из-за этого.

Короче как будто бы я в хард кап упираюсь вычислительных способностей берта с 4 слоями и fp16 Точностью

## Часть 2. Дистиляция с шумом, в fp16.

**Тут я обучаю в fp16 с шумом.**

In [79]:
def add_noise(tensor, noise_std=0.05):
    if tensor is None:
        return None
    if noise_std > 0.0 and tensor.requires_grad:
        noise = torch.randn_like(tensor) * noise_std
        return tensor + noise
    return tensor

class TransformerBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config

        self.mha = nn.MultiheadAttention(
            embed_dim=config.bert_hidden_size,
            num_heads=config.num_heads,
            dropout=config.dropout,
            batch_first=True
        )

        self.ffn = nn.Sequential(
            nn.Linear(config.bert_hidden_size, config.intermediate_size),
            nn.GELU(),
            nn.Linear(config.intermediate_size, config.bert_hidden_size)
        )

        self.layernorm1 = nn.LayerNorm(config.bert_hidden_size, eps=1e-12)
        self.layernorm2 = nn.LayerNorm(config.bert_hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x, attention_mask=None, return_attention=False):

        if attention_mask is not None:
            attention_mask = ~attention_mask.bool()
        else:
            attention_mask = None

        mha_out, attn_weights = self.mha(
            query=x,
            key=x,
            value=x,
            key_padding_mask=attention_mask,
            need_weights=return_attention,
            average_attn_weights=False
        )

        residual = x + self.dropout(mha_out)
        x = self.layernorm1(residual)

        ffn_out = self.ffn(x)
        residual = x + self.dropout(ffn_out)
        output = self.layernorm2(residual)

        if return_attention:
            return output, attn_weights
        else:
            return output

class CleanBERT(nn.Module):
    def __init__(self, embedding_layer, config, num_labels=2):
        super().__init__()
        self.embeddings = embedding_layer

        self.blocks = nn.ModuleList([
            TransformerBlock(config) for _ in range(config.num_blocks)
        ])

        self.pooler = nn.Sequential(
            nn.Linear(config.bert_hidden_size, config.bert_hidden_size),
            nn.Tanh()
        )

        self.dropout = nn.Dropout(config.dropout)
        self.classifier = nn.Linear(config.bert_hidden_size, num_labels)

    def forward(self, input_ids, attention_mask=None,
                return_hidden_states=False, return_attentions=False):
        x = self.embeddings(input_ids)

        if self.training:
            x = add_noise(x, noise_std=0.05)

        all_hidden_states = []
        all_attentions = []

        if return_hidden_states:
            all_hidden_states.append(x)

        for block in self.blocks:
            if return_attentions:
                x, attn = block(x, attention_mask=attention_mask, return_attention=True)
                all_attentions.append(attn)
            else:
                x = block(x, attention_mask=attention_mask, return_attention=False)

            if self.training:
                x = add_noise(x, noise_std=0.05)

            if return_hidden_states:
                all_hidden_states.append(x)

        cls_output = x[:, 0]

        pooled = self.pooler(cls_output)
        if self.training:
            pooled = add_noise(pooled, noise_std=0.05)
        pooled = self.dropout(pooled)
        logits = self.classifier(pooled)

        outputs = {"logits": logits}

        if return_hidden_states:
            outputs["hidden_states"] = all_hidden_states

        if return_attentions:
            outputs["attentions"] = all_attentions

        return outputs

In [None]:
wandb.init(project="bert_mrpc_noice", config={
    "a": 4,
    "b": 1,
    "c": 0.3,
    "d": 0.5,
    "batch_size": 8,
    "epochs": 13,
    "lr": 6e-5,
    "temperature": 1.25
})

In [81]:
distil_model = CleanBERT(copy.deepcopy(bert_model.bert.embeddings), config).to(device)
for param in distil_model.embeddings.parameters():
    param.requires_grad = False

train_distil_model(distil_model, bert_model, tokenizer, device)

Epoch 1: 100%|██████████| 459/459 [02:50<00:00,  2.69it/s, attn=1.17, hs=2.75, kl=1.18, loss=5.25, task=0.667]


Epoch 1 | Val Acc: 70.03%



Epoch 2: 100%|██████████| 459/459 [02:47<00:00,  2.74it/s, attn=0.955, hs=2.16, kl=0.649, loss=3.58, task=0.451]


Epoch 2 | Val Acc: 69.33%



Epoch 3: 100%|██████████| 459/459 [02:47<00:00,  2.74it/s, attn=0.847, hs=1.95, kl=0.789, loss=4.83, task=0.759]


Epoch 3 | Val Acc: 59.48%



Epoch 4: 100%|██████████| 459/459 [02:47<00:00,  2.74it/s, attn=0.862, hs=2.1, kl=0.469, loss=2.92, task=0.347]


Epoch 4 | Val Acc: 63.01%



Epoch 5: 100%|██████████| 459/459 [02:47<00:00,  2.74it/s, attn=0.806, hs=1.92, kl=1.07, loss=4.85, task=0.701]


Epoch 5 | Val Acc: 66.20%



Epoch 6: 100%|██████████| 459/459 [02:47<00:00,  2.74it/s, attn=0.821, hs=1.91, kl=1.01, loss=3.81, task=0.455]


Epoch 6 | Val Acc: 65.45%



Epoch 7: 100%|██████████| 459/459 [02:47<00:00,  2.74it/s, attn=0.839, hs=1.83, kl=1.43, loss=3.36, task=0.24]


Epoch 7 | Val Acc: 66.14%



Epoch 8: 100%|██████████| 459/459 [02:47<00:00,  2.74it/s, attn=0.851, hs=1.98, kl=0.682, loss=2.86, task=0.289]


Epoch 8 | Val Acc: 64.99%



Epoch 9: 100%|██████████| 459/459 [02:47<00:00,  2.74it/s, attn=0.823, hs=1.95, kl=0.677, loss=3.95, task=0.569]


Epoch 9 | Val Acc: 60.64%



Epoch 10: 100%|██████████| 459/459 [02:47<00:00,  2.74it/s, attn=0.824, hs=1.89, kl=0.758, loss=3.52, task=0.447]


Epoch 10 | Val Acc: 64.35%



Epoch 11: 100%|██████████| 459/459 [02:47<00:00,  2.74it/s, attn=0.779, hs=1.83, kl=0.203, loss=1.44, task=0.0746]


Epoch 11 | Val Acc: 63.54%



Epoch 12:  21%|██        | 95/459 [00:35<02:14,  2.70it/s, attn=0.767, hs=1.91, kl=0.483, loss=1.75, task=0.0792]


KeyboardInterrupt: 

In [82]:
distil_model.load_state_dict(torch.load(f"best_distilled_model_fp16.pt"))
distil_model.eval()
correct = 0
total = 0
train_loader, val_loader = prepare_data(batch_size=wandb.config.batch_size)
with torch.no_grad():
    for batch in tqdm(val_loader):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        outputs = distil_model(input_ids, attention_mask=attention_mask)
        logits = outputs["logits"] if isinstance(outputs, dict) else outputs
        predictions = torch.argmax(logits, dim=1)
        correct += (predictions == labels).sum().item()
        total += labels.size(0)
print("\n", correct / total)

100%|██████████| 216/216 [00:10<00:00, 19.73it/s]


 0.7002898550724638





##Часть 3. Дистиляция с шумом, в fp32.

In [None]:
wandb.init(project="bert_mrpc_noice", config={
    "a": 4,
    "b": 1,
    "c": 0.3,
    "d": 0.5,
    "batch_size": 8,
    "epochs": 13,
    "lr": 6e-5,
    "temperature": 1.25
})

In [84]:
def train_distil_model_2(student_model, teacher_model, tokenizer, device):
    train_loader, val_loader = prepare_data(batch_size=wandb.config.batch_size)

    optimizer = AdamW(student_model.parameters(), lr=wandb.config.lr)
    total_steps = len(train_loader) * wandb.config.epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=100, num_training_steps=total_steps
    )

    best_val_acc = 0.0
    step = 0

    scaler = GradScaler('cuda')

    mapping_attn = [2, 5, 8, 11]
    mapping_mse_full = [0, 3, 6, 9, 12]

    student_model.train()
    teacher_model.eval()

    for epoch in range(wandb.config.epochs):
        total_loss = 0
        total_task_loss = 0
        total_kl_loss = 0
        total_hs_loss = 0
        total_attn_loss = 0

        pb = tqdm(train_loader, desc=f"Epoch {epoch + 1}")
        for batch in pb:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            with torch.no_grad():
                teacher_outputs = teacher_model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    output_hidden_states=True,
                    output_attentions=True
                )
            teacher_logits = teacher_outputs.logits
            teacher_hiddens = teacher_outputs.hidden_states
            teacher_attns = teacher_outputs.attentions

            student_outputs = student_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                return_hidden_states=True,
                return_attentions=True
            )
            student_logits = student_outputs["logits"]
            student_hiddens = student_outputs["hidden_states"]
            student_attns = student_outputs["attentions"]

            task_loss = F.cross_entropy(student_logits, labels)
            kl_loss = kl_div_loss(student_logits, teacher_logits, temperature=wandb.config.temperature)

            hs_loss = 0
            for i, student_hs in enumerate(student_hiddens):
                teacher_hs = teacher_hiddens[mapping_mse_full[i]]
                hs_loss += hidden_state_loss(student_hs, teacher_hs)

            attn_loss = attention_kl_loss(
                student_attns,
                teacher_attns,
                mapping_attn,
                temperature=wandb.config.temperature
            )

            total_loss = (
                wandb.config.a * task_loss +
                wandb.config.b * kl_loss +
                wandb.config.c * hs_loss +
                wandb.config.d * attn_loss
            )

            scaler.scale(total_loss).backward()
            torch.nn.utils.clip_grad_norm_(student_model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()
            optimizer.zero_grad()

            total_loss_value = total_loss.item()
            total_task_loss += task_loss.item()
            total_kl_loss += kl_loss.item()
            total_hs_loss += hs_loss.item()
            total_attn_loss += attn_loss.item()

            pb.set_postfix(
                loss=total_loss_value,
                task=task_loss.item(),
                kl=kl_loss.item(),
                hs=hs_loss.item(),
                attn=attn_loss.item()
            )

            wandb.log({
                "step": step + 1,
                "train_loss": total_loss_value,
                "task_loss": task_loss.item(),
                "kl_loss": kl_loss.item(),
                "hs_loss": hs_loss.item(),
                "attn_loss": attn_loss.item()
            })

            step += 1

        student_model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for batch in val_loader:
                input_ids = batch["input_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device)
                labels = batch["labels"].to(device)

                outputs = student_model(input_ids, attention_mask=attention_mask)
                logits = outputs["logits"] if isinstance(outputs, dict) else outputs
                predictions = torch.argmax(logits, dim=1)
                correct += (predictions == labels).sum().item()
                total += labels.size(0)

        val_acc = correct / total
        wandb.log({"epoch": epoch + 1, "val_acc": val_acc})

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(student_model.state_dict(), "best_distilled_model.pt")

            state_dict = {}
            for k, v in student_model.state_dict().items():
                state_dict[k] = v.clone().half()

            torch.save(state_dict, "best_distilled_model_fp16.pt")

        print(f"Epoch {epoch+1} | Val Acc: {val_acc * 100:.2f}%\n")

    return student_model

In [89]:
distil_model = CleanBERT(copy.deepcopy(bert_model.bert.embeddings), config).to(device)
for param in distil_model.embeddings.parameters():
    param.requires_grad = False

train_distil_model_2(distil_model, bert_model, tokenizer, device)

Epoch 1: 100%|██████████| 459/459 [03:30<00:00,  2.18it/s, attn=1.54, hs=3.66, kl=0.665, loss=6.31, task=0.944]


Epoch 1 | Val Acc: 65.10%



Epoch 2: 100%|██████████| 459/459 [03:24<00:00,  2.25it/s, attn=1.11, hs=2.57, kl=0.815, loss=5.34, task=0.8]


Epoch 2 | Val Acc: 64.81%



Epoch 3: 100%|██████████| 459/459 [03:24<00:00,  2.25it/s, attn=1.15, hs=2.29, kl=0.561, loss=3.73, task=0.477]


Epoch 3 | Val Acc: 57.04%



Epoch 4: 100%|██████████| 459/459 [03:24<00:00,  2.25it/s, attn=1.03, hs=2.23, kl=0.49, loss=3.3, task=0.407]


Epoch 4 | Val Acc: 64.58%



Epoch 5:  31%|███       | 141/459 [01:02<02:21,  2.25it/s, attn=1.07, hs=2.17, kl=0.667, loss=4.07, task=0.554]


KeyboardInterrupt: 

In [90]:
distil_model.load_state_dict(torch.load(f"best_distilled_model_fp16.pt"))
distil_model.eval()
correct = 0
total = 0
train_loader, val_loader = prepare_data(batch_size=wandb.config.batch_size)
with torch.no_grad():
    for batch in tqdm(val_loader):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        outputs = distil_model(input_ids, attention_mask=attention_mask)
        logits = outputs["logits"] if isinstance(outputs, dict) else outputs
        predictions = torch.argmax(logits, dim=1)
        correct += (predictions == labels).sum().item()
        total += labels.size(0)
print("\n", correct / total)

100%|██████████| 216/216 [00:10<00:00, 19.67it/s]


 0.6504347826086957





**Вроде как шум спас от потери точности при конвертации fp32 -> fp16, но, походу, зашумленные fp32 градиенты хуже справляются со своей задачей, нежели зашумленные fp16**

## Часть 4. Просто дистиляция и torch.half()

In [None]:
wandb.init(project="bert_mrpc_part4", config={
    "a": 4,
    "b": 1,
    "c": 0.3,
    "d": 0.5,
    "batch_size": 8,
    "epochs": 13,
    "lr": 6e-5,
    "temperature": 1.25
})

In [92]:
class TransformerBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config

        self.mha = nn.MultiheadAttention(
            embed_dim=config.bert_hidden_size,
            num_heads=config.num_heads,
            dropout=config.dropout,
            batch_first=True
        )

        self.ffn = nn.Sequential(
            nn.Linear(config.bert_hidden_size, config.intermediate_size),
            nn.GELU(),
            nn.Linear(config.intermediate_size, config.bert_hidden_size)
        )

        self.layernorm1 = nn.LayerNorm(config.bert_hidden_size, eps=1e-12)
        self.layernorm2 = nn.LayerNorm(config.bert_hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x, attention_mask=None, return_attention=False):

        if attention_mask is not None:
            attention_mask = ~attention_mask.bool()
        else:
            attention_mask = None

        mha_out, attn_weights = self.mha(
            query=x,
            key=x,
            value=x,
            key_padding_mask=attention_mask,
            need_weights=return_attention,
            average_attn_weights=False
        )

        residual = x + self.dropout(mha_out)
        x = self.layernorm1(residual)

        ffn_out = self.ffn(x)
        residual = x + self.dropout(ffn_out)
        output = self.layernorm2(residual)

        if return_attention:
            return output, attn_weights
        else:
            return output

class CleanBERT(nn.Module):
    def __init__(self, embedding_layer, config, num_labels=2):
        super().__init__()
        self.embeddings = embedding_layer

        self.blocks = nn.ModuleList([
            TransformerBlock(config) for _ in range(config.num_blocks)
        ])

        self.pooler = nn.Sequential(
            nn.Linear(config.bert_hidden_size, config.bert_hidden_size),
            nn.Tanh()
        )

        self.dropout = nn.Dropout(config.dropout)
        self.classifier = nn.Linear(config.bert_hidden_size, num_labels)

    def forward(self, input_ids, attention_mask=None,
                return_hidden_states=False, return_attentions=False):
        x = self.embeddings(input_ids)

        all_hidden_states = []
        all_attentions = []

        if return_hidden_states:
            all_hidden_states.append(x)

        for block in self.blocks:
            if return_attentions:
                x, attn = block(x, attention_mask=attention_mask, return_attention=True)
                all_attentions.append(attn)
            else:
                x = block(x, attention_mask=attention_mask, return_attention=False)

            if return_hidden_states:
                all_hidden_states.append(x)

        cls_output = x[:, 0]

        pooled = self.pooler(cls_output)
        pooled = self.dropout(pooled)
        logits = self.classifier(pooled)

        outputs = {"logits": logits}

        if return_hidden_states:
            outputs["hidden_states"] = all_hidden_states

        if return_attentions:
            outputs["attentions"] = all_attentions

        return outputs

In [93]:
distil_model = CleanBERT(copy.deepcopy(bert_model.bert.embeddings), config).to(device)
for param in distil_model.embeddings.parameters():
    param.requires_grad = False

train_distil_model_2(distil_model, bert_model, tokenizer, device)

Epoch 1: 100%|██████████| 459/459 [03:26<00:00,  2.22it/s, attn=1.43, hs=3.56, kl=1.07, loss=5.5, task=0.662]


Epoch 1 | Val Acc: 68.23%



Epoch 2: 100%|██████████| 459/459 [03:24<00:00,  2.25it/s, attn=1.11, hs=2.5, kl=1.03, loss=4.8, task=0.617]


Epoch 2 | Val Acc: 64.70%



Epoch 3: 100%|██████████| 459/459 [03:24<00:00,  2.25it/s, attn=1.05, hs=2.35, kl=0.236, loss=3.3, task=0.459]


Epoch 3 | Val Acc: 62.96%



Epoch 4: 100%|██████████| 459/459 [03:24<00:00,  2.25it/s, attn=1.07, hs=2.2, kl=0.559, loss=3.95, task=0.549]


Epoch 4 | Val Acc: 64.81%



Epoch 5:  40%|████      | 184/459 [01:22<02:03,  2.23it/s, attn=1.1, hs=2.2, kl=0.411, loss=3.55, task=0.484]


KeyboardInterrupt: 

In [94]:
distil_model.load_state_dict(torch.load(f"best_distilled_model_fp16.pt"))
distil_model.eval()
correct = 0
total = 0
train_loader, val_loader = prepare_data(batch_size=wandb.config.batch_size)
with torch.no_grad():
    for batch in tqdm(val_loader):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        outputs = distil_model(input_ids, attention_mask=attention_mask)
        logits = outputs["logits"] if isinstance(outputs, dict) else outputs
        predictions = torch.argmax(logits, dim=1)
        correct += (predictions == labels).sum().item()
        total += labels.size(0)
print("\n", correct / total)

100%|██████████| 216/216 [00:10<00:00, 19.73it/s]


 0.6823188405797102



