In [None]:
!pip install -U datasets -q

In [2]:
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from dataclasses import dataclass
import copy
import wandb
from torch.utils.data import DataLoader
from transformers import get_linear_schedule_with_warmup
from datasets import load_dataset
import torch.nn.functional as F
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup
from tqdm import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
bert_model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=4).to(device)

In [None]:
print(bert_model)
print(sum(p.numel() for p in bert_model.parameters())) # 110M
# 768 embedding
# 12 слоёв, {Attention (768), Intermidiate + Output (768 -> 3072 -> 768)}
# 768 -> 4 classifier

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e

In [None]:
print(bert_model.bert.embeddings)

BertEmbeddings(
  (word_embeddings): Embedding(30522, 768, padding_idx=0)
  (position_embeddings): Embedding(512, 768)
  (token_type_embeddings): Embedding(2, 768)
  (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
  (dropout): Dropout(p=0.1, inplace=False)
)


In [4]:
@dataclass
class LSTMBERTConfig:
    bert_hidden_size: int = 768
    num_blocks: int = 6
    num_heads: int = 12
    intermediate_size: int = 1024
    dropout: float = 0.1

config = LSTMBERTConfig()

In [17]:
class LSTMTransformerBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config

        self.lstm = nn.LSTM(
            input_size=config.bert_hidden_size,
            hidden_size=config.bert_hidden_size // 2,
            num_layers=1,
            batch_first=True,
            bidirectional=True
        )

        self.mha = nn.MultiheadAttention(
            embed_dim=config.bert_hidden_size,
            num_heads=config.num_heads,
            dropout=config.dropout,
            batch_first=True
        )

        self.ffn = nn.Sequential(
            nn.Linear(config.bert_hidden_size, config.intermediate_size),
            nn.GELU(),
            nn.Linear(config.intermediate_size, config.bert_hidden_size)
        )

        self.layernorm1 = nn.LayerNorm(config.bert_hidden_size, eps=1e-12)
        self.layernorm2 = nn.LayerNorm(config.bert_hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x, attention_mask=None, return_attention=False):
        lstm_out, _ = self.lstm(x)

        if attention_mask is not None:
            attention_mask = ~attention_mask.bool()
        else:
            attention_mask = None

        mha_out, attn_weights = self.mha(
            query=x,
            key=x,
            value=x,
            key_padding_mask=attention_mask,
            need_weights=return_attention
        )

        combined = self.layernorm1(mha_out + lstm_out)
        combined = self.dropout(combined)

        ffn_out = self.ffn(combined)
        output = self.layernorm2(ffn_out + combined)
        output = self.dropout(output)

        if return_attention:
            return output, attn_weights
        else:
            return output

class CleanBERT(nn.Module):
    def __init__(self, embedding_layer, config, num_labels=4):
        super().__init__()
        self.embeddings = embedding_layer

        self.blocks = nn.ModuleList([
            LSTMTransformerBlock(config) for _ in range(config.num_blocks)
        ])

        self.pooler = nn.Sequential(
            nn.Linear(config.bert_hidden_size, config.bert_hidden_size),
            nn.Tanh()
        )

        self.dropout = nn.Dropout(config.dropout)
        self.classifier = nn.Linear(config.bert_hidden_size, num_labels)

    def forward(self, input_ids, attention_mask=None,
                return_hidden_states=False, return_attentions=False):
        x = self.embeddings(input_ids)

        all_hidden_states = []
        all_attentions = []

        if return_hidden_states:
            all_hidden_states.append(x)

        for block in self.blocks:
            if return_attentions:
                x, attn = block(x, attention_mask=attention_mask, return_attention=True)
                all_attentions.append(attn)
            else:
                x = block(x, attention_mask=attention_mask, return_attention=False)

            if return_hidden_states:
                all_hidden_states.append(x)

        cls_output = x[:, 0]

        pooled = self.pooler(cls_output)
        pooled = self.dropout(pooled)
        logits = self.classifier(pooled)

        outputs = {"logits": logits}

        if return_hidden_states:
            outputs["hidden_states"] = all_hidden_states

        if return_attentions:
            outputs["attentions"] = all_attentions

        return outputs

In [28]:
distil_model = CleanBERT(copy.deepcopy(bert_model.bert.embeddings), config).to(device)
for param in distil_model.embeddings.parameters():
    param.requires_grad = False

In [None]:
sentences = [
    "Hello, how are you?",
    "Transformers are amazing.",
    "Distill BERT into a smaller model.",
    "We reduce dimensions but keep performance."
]

inputs = tokenizer(
    sentences,
    padding=True,
    truncation=True,
    max_length=512,
    return_tensors="pt"
).to(device)

out = distil_model(inputs["input_ids"], attention_mask=inputs["attention_mask"],
                   return_hidden_states=True, return_attentions=True)
logits, hidden_states, attentions = out["logits"], out["hidden_states"], out["attentions"]
print(logits, len(hidden_states), hidden_states[0].shape, len(attentions), attentions[0].shape)

tensor([[-0.0782,  0.1744, -0.6629,  0.1904],
        [ 0.1722,  0.2532, -0.2189,  0.3747],
        [    nan,     nan,     nan,     nan],
        [-0.2258, -0.3420, -0.8435, -0.1836]], device='cuda:0',
       grad_fn=<AddmmBackward0>) 7 torch.Size([4, 11, 768]) 6 torch.Size([4, 11, 11])


In [None]:
print(distil_model)
print(sum(p.numel() for p in distil_model.parameters())) # 69M (0.62 * BertOrigNumParam)

CleanBERT(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (blocks): ModuleList(
    (0-5): 6 x LSTMTransformerBlock(
      (lstm): LSTM(768, 384, batch_first=True, bidirectional=True)
      (mha): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
      )
      (ffn): Sequential(
        (0): Linear(in_features=768, out_features=1024, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=1024, out_features=768, bias=True)
      )
      (layernorm1): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (layernorm2): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
  (poo

Я беру уже готовые эмбединги. Так что CosDist loss не считаю. Attention loss не использую, у меня разное количество голов.

Loss = a * task_loss (CrossEntropy) + b * logits_loss (KLDiv) + c * hidden_states (MSE)

In [None]:
with torch.no_grad():
    outputs = bert_model(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        output_hidden_states=True,
        output_attentions=True
    )

# Получаем hidden_states: tuple из 13 элементов (embeddings + 12 слоёв)
hidden_states = outputs.hidden_states
attentions = outputs.attentions
hidden_states[12].shape
print(len(hidden_states), hidden_states[0].shape)
print(len(attentions), attentions[0].shape)

In [8]:
def prepare_data(batch_size=16):
    dataset = load_dataset("wangrongsheng/ag_news")

    def tokenize(batch):
        return tokenizer(batch["text"], padding="max_length", truncation=True, max_length=512, return_tensors="pt")

    encoded_dataset = dataset.map(tokenize, batched=True, batch_size=batch_size)
    encoded_dataset = encoded_dataset.remove_columns(["text"])
    encoded_dataset = encoded_dataset.rename_column("label", "labels")
    encoded_dataset.set_format("torch")

    train_loader = DataLoader(encoded_dataset["train"], batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(encoded_dataset["test"], batch_size=batch_size)

    return train_loader, val_loader

def kl_div_loss(student_logits, teacher_logits, temperature=2.55):
    student_log_probs = F.log_softmax(student_logits / temperature, dim=1)
    teacher_probs = F.softmax(teacher_logits / temperature, dim=1)
    return F.kl_div(student_log_probs, teacher_probs, reduction='batchmean') * (temperature ** 2)

def hidden_state_loss(student_hs, teacher_hs):
    return F.mse_loss(student_hs, teacher_hs.detach())


In [9]:
def train_teacher(model, train_loader, val_loader, device, epochs=3, lr=2e-5, save_path="bert_teacher_ag_news.pt"):
    optimizer = AdamW(model.parameters(), lr=lr)
    total_steps = len(train_loader) * epochs

    model.train()
    for epoch in range(epochs):
        total_loss = 0
        pb = tqdm(train_loader, desc=f"Epoch {epoch + 1}")

        for i, batch in enumerate(pb):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss
            total_loss += loss.item()

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            optimizer.zero_grad()
            pb.set_postfix(loss=loss.item())

            if i == 100:
                break

        avg_train_loss = total_loss / len(train_loader)
        print(f"\nEpoch {epoch + 1} | Train Loss: {avg_train_loss:.4f}")

        model.eval()
        correct = 0
        total = 0

        with torch.no_grad():
            for i, batch in enumerate(val_loader):
                input_ids = batch["input_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device)
                labels = batch["labels"].to(device)

                outputs = model(input_ids, attention_mask=attention_mask)
                logits = outputs.logits
                predictions = torch.argmax(logits, dim=-1)
                correct += (predictions == labels).sum().item()
                total += labels.size(0)

                if i == 100:
                    break

        accuracy = correct / total
        print(f"Validation Accuracy: {accuracy * 100:.2f}%\n")

        model.train()

    # --- Сохранение модели ---
    torch.save(model.state_dict(), save_path)

    return model

In [10]:
# bert_model.load_state_dict(torch.load("bert_teacher_ag_news.pt"))
bert_model  = AutoModelForSequenceClassification.from_pretrained("fabriceyhc/bert-base-uncased-ag_news")
bert_model.eval()
bert_model = bert_model.to(device)

config.json:   0%|          | 0.00/919 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/438M [00:00<?, ?B/s]

In [29]:
wandb.init(project="distil_bert_ag_news", config={
    "a": 1.55,
    "b": 0.5,
    "c": 0.3,
    "batch_size": 8,
    "epochs": 5,
    "lr": 6e-5,
    "temperature": 2.55
})

0,1
hs_loss,█▁▃▇▁▁▅
kl_loss,█▆▁▅▂▁▂
step,▁▂▃▅▆▇█
task_loss,▇█▁▄▅▃▃
train_loss,█▆▁▅▃▂▃

0,1
hs_loss,9.67016
kl_loss,5.50192
step,7.0
task_loss,1.3368
train_loss,7.72405


In [20]:
def train_distil_model(student_model, teacher_model, tokenizer, device):
    train_loader, val_loader = prepare_data(batch_size=wandb.config.batch_size)

    optimizer = AdamW(student_model.parameters(), lr=wandb.config.lr)
    total_steps = len(train_loader) * wandb.config.epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=100, num_training_steps=total_steps
    )

    best_val_acc = 0.0
    step = 0

    for epoch in range(wandb.config.epochs):
        student_model.train()
        teacher_model.eval()

        total_loss = 0
        total_task_loss = 0
        total_kl_loss = 0
        total_hs_loss = 0

        pb = tqdm(train_loader, desc=f"Epoch {epoch + 1}")
        for batch in pb:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            with torch.no_grad():
                teacher_outputs = teacher_model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    output_hidden_states=True
                )
            teacher_logits = teacher_outputs.logits
            teacher_hiddens = teacher_outputs.hidden_states

            student_outputs = student_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                return_hidden_states=True
            )
            student_logits = student_outputs["logits"] if isinstance(student_outputs, dict) else student_outputs
            student_hiddens = student_outputs["hidden_states"] if isinstance(student_outputs, dict) else []

            task_loss = F.cross_entropy(student_logits, labels)
            kl_loss = kl_div_loss(student_logits, teacher_logits, temperature=wandb.config.temperature)

            hs_loss = 0
            for i, student_hs in enumerate(student_hiddens):
                try:
                    teacher_hs = teacher_hiddens[i * 2]
                    hs_loss += hidden_state_loss(student_hs, teacher_hs)
                except IndexError:
                    break

            total_loss = (
                wandb.config.a * task_loss +
                wandb.config.b * kl_loss +
                wandb.config.c * hs_loss
            )

            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(student_model.parameters(), 1.0)

            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

            total_loss_value = total_loss.item()
            total_loss += total_loss_value
            total_task_loss += task_loss.item()
            total_kl_loss += kl_loss.item()
            total_hs_loss += hs_loss.item()

            pb.set_postfix(loss=total_loss_value, task=task_loss.item(),
                           kl=kl_loss.item(), hs=hs_loss.item())

            wandb.log({
            "step": step + 1,
            "train_loss": total_loss_value,
            "task_loss": task_loss.item(),
            "kl_loss": kl_loss.item(),
            "hs_loss": hs_loss.item()
            })

            step += 1

        student_model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for batch in val_loader:
                input_ids = batch["input_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device)
                labels = batch["labels"].to(device)

                outputs = student_model(input_ids, attention_mask=attention_mask)
                logits = outputs["logits"] if isinstance(outputs, dict) else outputs
                predictions = torch.argmax(logits, dim=1)
                correct += (predictions == labels).sum().item()
                total += labels.size(0)

        val_acc = correct / total

        wandb.log({"epoch": epoch + 1, "val_acc": val_acc})

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(student_model.state_dict(), "best_distilled_model.pt")

        print(f"Epoch {epoch+1} | Train Loss: {total_loss / len(train_loader):.4f} | Val Acc: {val_acc * 100:.2f}%\n")

    return student_model

In [30]:
train_distil_model(distil_model, bert_model, tokenizer, device)

Epoch 1:   9%|▉         | 1374/15000 [19:36<3:14:25,  1.17it/s, hs=1.84, kl=0.03, loss=0.574, task=0.00476]


KeyboardInterrupt: 

Возможно лоссы стоит нормализовывать, чтобы их значения были в одном масштабе.

In [31]:
distil_model.eval()
correct = 0
total = 0
train_loader, val_loader = prepare_data(batch_size=wandb.config.batch_size)
with torch.no_grad():
    for batch in tqdm(val_loader):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        outputs = distil_model(input_ids, attention_mask=attention_mask)
        logits = outputs["logits"] if isinstance(outputs, dict) else outputs
        predictions = torch.argmax(logits, dim=1)
        correct += (predictions == labels).sum().item()
        total += labels.size(0)

100%|██████████| 950/950 [03:42<00:00,  4.26it/s]


In [32]:
print(correct / total)
# wandb: https://wandb.ai/honkers/distil_bert_ag_news
# У учителя было че-то около 0.93

0.8993421052631579
