In [2]:
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from dataclasses import dataclass
import copy
import wandb
from torch.utils.data import DataLoader
from transformers import get_linear_schedule_with_warmup
from datasets import load_dataset
import torch.nn.functional as F
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup
from tqdm import tqdm
from torch.amp import autocast, GradScaler

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

import os

In [3]:
tokenizer = AutoTokenizer.from_pretrained('Intel/bert-base-uncased-mrpc')
bert_model = AutoModelForSequenceClassification.from_pretrained("Intel/bert-base-uncased-mrpc").to(device)
bert_model.eval()

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/321 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/711k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/892 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/438M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/438M [00:00<?, ?B/s]

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e

In [4]:
@dataclass
class DistBERTConfig:
    bert_hidden_size: int = 768
    num_blocks: int = 4
    num_heads: int = 12
    intermediate_size: int = 1024
    dropout: float = 0.1

config = DistBERTConfig()

In [5]:
# pip install -U datasets fsspec huggingface_hub


In [6]:
class TransformerBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config

        self.mha = nn.MultiheadAttention(
            embed_dim=config.bert_hidden_size,
            num_heads=config.num_heads,
            dropout=config.dropout,
            batch_first=True
        )

        self.ffn = nn.Sequential(
            nn.Linear(config.bert_hidden_size, config.intermediate_size),
            nn.GELU(),
            nn.Linear(config.intermediate_size, config.bert_hidden_size)
        )

        self.layernorm1 = nn.LayerNorm(config.bert_hidden_size, eps=1e-12)
        self.layernorm2 = nn.LayerNorm(config.bert_hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x, attention_mask=None, return_attention=False):

        if attention_mask is not None:
            attention_mask = ~attention_mask.bool()
        else:
            attention_mask = None

        mha_out, attn_weights = self.mha(
            query=x,
            key=x,
            value=x,
            key_padding_mask=attention_mask,
            need_weights=return_attention,
            average_attn_weights=False
        )

        residual = x + self.dropout(mha_out)
        x = self.layernorm1(residual)

        ffn_out = self.ffn(x)
        residual = x + self.dropout(ffn_out)
        output = self.layernorm2(residual)

        if return_attention:
            return output, attn_weights
        else:
            return output

In [7]:

class CleanBERT(nn.Module):
    def __init__(self, embedding_layer, config, num_labels=2):
        super().__init__()
        self.embeddings = embedding_layer

        self.blocks = nn.ModuleList([
            TransformerBlock(config) for _ in range(config.num_blocks)
        ])

        self.pooler = nn.Sequential(
            nn.Linear(config.bert_hidden_size, config.bert_hidden_size),
            nn.Tanh()
        )

        self.dropout = nn.Dropout(config.dropout)
        self.classifier = nn.Linear(config.bert_hidden_size, num_labels)

    def forward(self, input_ids, attention_mask=None,
                return_hidden_states=False, return_attentions=False):
        x = self.embeddings(input_ids)

        all_hidden_states = []
        all_attentions = []

        if return_hidden_states:
            all_hidden_states.append(x)

        for block in self.blocks:
            if return_attentions:
                x, attn = block(x, attention_mask=attention_mask, return_attention=True)
                all_attentions.append(attn)
            else:
                x = block(x, attention_mask=attention_mask, return_attention=False)

            if return_hidden_states:
                all_hidden_states.append(x)

        cls_output = x[:, 0]

        pooled = self.pooler(cls_output)
        pooled = self.dropout(pooled)
        logits = self.classifier(pooled)

        outputs = {"logits": logits}

        if return_hidden_states:
            outputs["hidden_states"] = all_hidden_states

        if return_attentions:
            outputs["attentions"] = all_attentions

        return outputs

In [8]:
distil_model = CleanBERT(copy.deepcopy(bert_model.bert.embeddings), config).to(device)
for param in distil_model.embeddings.parameters():
    param.requires_grad = False


In [9]:

sentence_pairs = [
    ["Hello, how are you?", "Hi, what's up?"],
    ["Transformers are amazing", "BERT is a powerful model"],
    ["Distill BERT into a smaller model", "Make BERT lighter while preserving performance"],
    ["We reduce dimensions but keep performance", "Performance stays the same after compression"],
]

inputs = tokenizer(
    text=[pair[0] for pair in sentence_pairs],
    text_pair=[pair[1] for pair in sentence_pairs],
    padding=True,
    truncation=True,
    max_length=512,
    return_tensors="pt"
).to(device)

out = distil_model(inputs["input_ids"], attention_mask=inputs["attention_mask"],
                   return_hidden_states=True, return_attentions=True)
logits, hidden_states, attentions = out["logits"], out["hidden_states"], out["attentions"]
print(logits, len(hidden_states), hidden_states[0].shape, len(attentions), attentions[0].shape)

tensor([[-0.4870, -0.1048],
        [-0.4097, -0.1200],
        [-0.6067, -0.1512],
        [-0.5719,  0.0411]], device='cuda:0', grad_fn=<AddmmBackward0>) 5 torch.Size([4, 17, 768]) 4 torch.Size([4, 12, 17, 17])


In [10]:

def prepare_data(batch_size=16):
    dataset = load_dataset("nyu-mll/glue", "mrpc")

    def tokenize(batch):
        return tokenizer(
            batch["sentence1"],
            batch["sentence2"],
            padding="max_length",
            truncation=True,
            max_length=512,
            return_tensors="pt"
        )

    encoded_dataset = dataset.map(tokenize, batched=True, batch_size=batch_size)
    encoded_dataset = encoded_dataset.remove_columns(["sentence1", "sentence2", "idx"])
    encoded_dataset = encoded_dataset.rename_column("label", "labels")
    encoded_dataset.set_format("torch")

    train_loader = DataLoader(encoded_dataset["train"], batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(encoded_dataset["test"], batch_size=batch_size)

    return train_loader, val_loader

def kl_div_loss(student_logits, teacher_logits, temperature=2.3):
    student_log_probs = F.log_softmax(student_logits / temperature, dim=1)
    teacher_probs = F.softmax(teacher_logits / temperature, dim=1)
    return F.kl_div(student_log_probs, teacher_probs, reduction='batchmean') * (temperature ** 2)

def hidden_state_loss(student_hs, teacher_hs):
    return F.mse_loss(student_hs, teacher_hs.detach())


def attention_kl_loss(student_attns, teacher_attns, mapping_attn, temperature=2.3):
    total_loss = 0.0
    num_layers = len(student_attns)

    for student_idx in range(num_layers):
        student_attn = student_attns[student_idx]  # [B, H, T, T]
        teacher_attn = teacher_attns[mapping_attn[student_idx]]  # [B, H, T, T]

        B, H, T, _ = student_attn.shape

        # Применяем температурное масштабирование
        student_attn = student_attn / temperature
        teacher_attn = teacher_attn / temperature

        # Вычисляем логарифмы вероятностей и вероятности
        student_log_probs = F.log_softmax(student_attn, dim=-1)  # [B, H, T, T]
        teacher_probs = F.softmax(teacher_attn, dim=-1)  # [B, H, T, T]

        # Вычисляем KL-дивергенцию для каждого элемента
        kl_per_element = F.kl_div(
            student_log_probs,
            teacher_probs,
            reduction='none'
        )  # [B, H, T, T]

        # Суммируем по последнему измерению (T)
        kl_per_token = kl_per_element.sum(dim=-1)  # [B, H, T]

        # Суммируем по всем головам и токенам (но не по батчу!)
        kl_per_layer = kl_per_token.sum(dim=(1, 2))  # [B]

        # Усредняем по батчу
        layer_loss = kl_per_layer.mean()

        # Масштабируем обратно температурой
        layer_loss = layer_loss * (temperature ** 2)

        total_loss += layer_loss

    return total_loss / num_layers

In [11]:
wandb.init(project="bert_mrpc_fp16dist", config={
    "a": 4,
    "b": 1,
    "c": 0.3,
    "d": 0.5,
    "batch_size": 8,
    "epochs": 13,
    "lr": 6e-5,
    "temperature": 1.25
})

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize?ref=models
wandb: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33megor283693[0m ([33megor283693-hse-university[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


# Обучение

In [None]:
wandb.init(project="bert_mrpc_fp16dist", config={
    "a": 4,
    "b": 1,
    "c": 0.3,
    "d": 0.5,
    "batch_size": 8,
    "epochs": 13,
    "lr": 6e-5,
    "temperature": 1.25
})

In [None]:
# wandb.init(project="bert_mrpc_fp16dist", config={
#     "a": 1,
#     "b": 1,
#     "c": 1,
#     "d": 1,
#     "batch_size": 8,
#     "epochs": 13,
#     "lr": 6e-5,
#     "temperature": 1.25
# })

In [None]:
def train_distil_model(student_model, teacher_model, tokenizer, device):
    train_loader, val_loader = prepare_data(batch_size=wandb.config.batch_size)

    optimizer = AdamW(student_model.parameters(), lr=wandb.config.lr)
    total_steps = len(train_loader) * wandb.config.epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=100, num_training_steps=total_steps
    )

    best_val_acc = 0.0
    step = 0

    scaler = GradScaler('cuda')

    mapping_attn = [2, 5, 8, 11]
    mapping_mse_full = [0, 3, 6, 9, 12]

    student_model.train()
    teacher_model.eval()

    for epoch in range(wandb.config.epochs):
        total_loss = 0
        total_task_loss = 0
        total_kl_loss = 0
        total_hs_loss = 0
        total_attn_loss = 0

        pb = tqdm(train_loader, desc=f"Epoch {epoch + 1}")
        for batch in pb:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            with torch.no_grad():
                teacher_outputs = teacher_model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    output_hidden_states=True,
                    output_attentions=True
                )
            teacher_logits = teacher_outputs.logits
            teacher_hiddens = teacher_outputs.hidden_states
            teacher_attns = teacher_outputs.attentions

            with autocast('cuda', dtype=torch.float16):
                student_outputs = student_model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    return_hidden_states=True,
                    return_attentions=True
                )
                student_logits = student_outputs["logits"]
                student_hiddens = student_outputs["hidden_states"]
                student_attns = student_outputs["attentions"]

                task_loss = F.cross_entropy(student_logits, labels)
                kl_loss = kl_div_loss(student_logits, teacher_logits, temperature=wandb.config.temperature)

                # hs_loss = 0
                # for i, student_hs in enumerate(student_hiddens):
                #     teacher_hs = teacher_hiddens[mapping_mse_full[i]]
                #     hs_loss += hidden_state_loss(student_hs, teacher_hs)

                min_len = min(len(student_hiddens), len(mapping_mse_full))
                hs_loss = 0
                for i in range(min_len):
                    hs_loss += hidden_state_loss(student_hiddens[i], teacher_hiddens[mapping_mse_full[i]])

                attn_loss = attention_kl_loss(
                    student_attns,
                    teacher_attns,
                    mapping_attn,
                    temperature=wandb.config.temperature
                )

                total_loss = (
                    wandb.config.a * task_loss +
                    wandb.config.b * kl_loss +
                    wandb.config.c * hs_loss +
                    wandb.config.d * attn_loss
                )

            scaler.scale(total_loss).backward()
            torch.nn.utils.clip_grad_norm_(student_model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()
            optimizer.zero_grad()

            total_loss_value = total_loss.item()
            total_task_loss += task_loss.item()
            total_kl_loss += kl_loss.item()
            total_hs_loss += hs_loss.item()
            total_attn_loss += attn_loss.item()

            pb.set_postfix(
                loss=total_loss_value,
                task=task_loss.item(),
                kl=kl_loss.item(),
                hs=hs_loss.item(),
                attn=attn_loss.item()
            )

            wandb.log({
                "step": step + 1,
                "train_loss": total_loss_value,
                "task_loss": task_loss.item(),
                "kl_loss": kl_loss.item(),
                "hs_loss": hs_loss.item(),
                "attn_loss": attn_loss.item()
            })

            step += 1

        student_model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for batch in val_loader:
                input_ids = batch["input_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device)
                labels = batch["labels"].to(device)

                outputs = student_model(input_ids, attention_mask=attention_mask)
                logits = outputs["logits"] if isinstance(outputs, dict) else outputs
                predictions = torch.argmax(logits, dim=1)
                correct += (predictions == labels).sum().item()
                total += labels.size(0)

        val_acc = correct / total
        wandb.log({"epoch": epoch + 1, "val_acc": val_acc})

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(student_model.state_dict(), "best_distilled_model.pt")

            state_dict = {}
            for k, v in student_model.state_dict().items():
                state_dict[k] = v.clone().half()

            torch.save(state_dict, "best_distilled_model_fp16.pt")

        print(f"Epoch {epoch+1} | Val Acc: {val_acc * 100:.2f}%\n")

    return student_model


In [None]:
train_distil_model(distil_model, bert_model, tokenizer, device)

In [None]:
# torch.save(distil_model.state_dict(), "student_model_fp32.pt")

# Тут уже новая часть:

## Тут я взял модель и обучил её на 10 эпох, взял за основу:

In [12]:
distil_model = CleanBERT(copy.deepcopy(bert_model.bert.embeddings), config).to(device)
for param in distil_model.embeddings.parameters():
    param.requires_grad = False


In [13]:
state_dict = torch.load("/content/sample_data/student_model_fp32.pt")
for k in state_dict:
    state_dict[k] = state_dict[k].float()

distil_model.load_state_dict(state_dict)


<All keys matched successfully>

In [14]:
def evaluate(model, dataloader, device):
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            outputs = model(input_ids, attention_mask=attention_mask)
            logits = outputs["logits"]
            preds = torch.argmax(logits, dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    return correct / total


In [15]:
train_loader, val_loader = prepare_data(batch_size=wandb.config.batch_size)

README.md:   0%|          | 0.00/35.3k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/649k [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/75.7k [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/308k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/3668 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/408 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1725 [00:00<?, ? examples/s]

Map:   0%|          | 0/3668 [00:00<?, ? examples/s]

Map:   0%|          | 0/408 [00:00<?, ? examples/s]

Map:   0%|          | 0/1725 [00:00<?, ? examples/s]

In [None]:
evaluate(distil_model, val_loader, device)

0.6272463768115942

In [None]:
import torch.quantization

quantized_model = torch.quantization.quantize_dynamic(
    distil_model.cpu(),
    {nn.Linear},
    dtype=torch.qint8
)



In [None]:
device = 'cuda'
quantized_model.to('cpu')

CleanBERT(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (blocks): ModuleList(
    (0-3): 4 x TransformerBlock(
      (mha): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
      )
      (ffn): Sequential(
        (0): DynamicQuantizedLinear(in_features=768, out_features=1024, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
        (1): GELU(approximate='none')
        (2): DynamicQuantizedLinear(in_features=1024, out_features=768, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
      )
      (layernorm1): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (layernorm2): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dro

In [None]:
evaluate(quantized_model, val_loader, device='cpu')

0.6295652173913043

In [None]:
torch.save(distil_model.state_dict(), "fp32.pt")
torch.save(quantized_model.state_dict(), "int8.pt")


print("FP32 size:", os.path.getsize("fp32.pt") / 1e6, "MB")
print("INT8 size:", os.path.getsize("int8.pt") / 1e6, "MB")


FP32 size: 160.774324 MB
INT8 size: 140.132924 MB


### Почти не изменилась точность, но немного уменьшилась по размеру

# Квантизация и дообучение:

### Тут прунинг, потом файнтюнинг, потом квантование

In [None]:
import torch.nn.utils.prune as prune

def apply_pruning(model, amount=0.2):
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            prune.l1_unstructured(module, name='weight', amount=amount)
    return model


In [None]:
# state_dict = torch.load("/content/sample_data/student_model_fp32.pt")
# # state_dict = torch.load("/content/int8.pt")
# for k in state_dict:
#     state_dict[k] = state_dict[k].float()

# distil_model.load_state_dict(state_dict)

In [None]:
def finetune_after_pruning(model, train_loader, val_loader, device, epochs=2, lr=1e-4):
    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(epochs):
        model.train()
        total_loss = 0
        for batch in train_loader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            optimizer.zero_grad()
            outputs = model(input_ids, attention_mask=attention_mask)
            logits = outputs["logits"]
            loss = criterion(logits, labels)
            loss.backward(retain_graph=True)
            optimizer.step()
            total_loss += loss.item()

        print(f"Epoch {epoch+1} finetune loss: {total_loss/len(train_loader):.4f}")


In [None]:
pruned_model = apply_pruning(distil_model)

### Тут я запрунил модель и немного зафайнтюнил, чтобы потом уже заквантовать


In [None]:
finetune_after_pruning(pruned_model, train_loader, val_loader, device, epochs=10, lr=1e-4)

Epoch 1 finetune loss: 0.2850
Epoch 2 finetune loss: 0.2455
Epoch 3 finetune loss: 0.2092
Epoch 4 finetune loss: 0.1658
Epoch 5 finetune loss: 0.1474
Epoch 6 finetune loss: 0.1270
Epoch 7 finetune loss: 0.1012
Epoch 8 finetune loss: 0.0960
Epoch 9 finetune loss: 0.0782
Epoch 10 finetune loss: 0.0742


In [None]:
# torch.save(quantized_model.state_dict(), "int8.pt")
torch.save(distil_model.state_dict(), "finetuned_model.pt")


# print("FP32 size:", os.path.getsize("fp32.pt") / 1e6, "MB")
# print("INT8 size:", os.path.getsize("int8.pt") / 1e6, "MB")


### Теперь сравним просто дистилированную модель до и после файнтютинга:

In [None]:
# state_dict = torch.load("/content/sample_data/student_model_fp32.pt")
# for k in state_dict:
#     state_dict[k] = state_dict[k].float()

# distil_model.load_state_dict(state_dict)

# distil_model.to(device)
# evaluate_model(distil_model, val_loader, device)

In [16]:
pruned_model = CleanBERT(copy.deepcopy(bert_model.bert.embeddings), config).to(device)
for param in pruned_model.embeddings.parameters():
    param.requires_grad = False

In [17]:
state_dict = torch.load("/content/sample_data/finetuned_model.pt",  map_location=torch.device("cpu"))
for k in state_dict:
    state_dict[k] = state_dict[k].float()

pruned_model.load_state_dict(state_dict)

<All keys matched successfully>

In [18]:
evaluate(pruned_model, val_loader, device)

0.6568115942028986

## Точность возросла на ~ 2 %

In [21]:
import torch.quantization

quantized_model_after_pruning = torch.quantization.quantize_dynamic(
    pruned_model,
    {nn.Linear},
    dtype=torch.qint8
)


In [26]:
evaluate(quantized_model_after_pruning.cpu(), val_loader, 'cpu')

0.6579710144927536

In [28]:
torch.save(pruned_model.state_dict(), "fp32.pt")
torch.save(quantized_model_after_pruning.state_dict(), "int8.pt")


print("FP32 size:", os.path.getsize("fp32.pt") / 1e6, "MB")
print("INT8 size:", os.path.getsize("int8.pt") / 1e6, "MB")


FP32 size: 160.774324 MB
INT8 size: 140.132924 MB


## Тут возросла точность у обеих моделей, квантование не сильно порезало точность также

In [None]:
import torch.quantization as quant

def qat_training(model, train_loader, val_loader, device, epochs=3, lr=5e-5):
    model.qconfig = quant.get_default_qat_qconfig('fbgemm')
    quant.prepare_qat(model, inplace=True)

    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(epochs):
        model.train()
        total_loss = 0
        for batch in train_loader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            optimizer.zero_grad()
            outputs = model(input_ids, attention_mask=attention_mask)
            logits = outputs["logits"]
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        print(f"Epoch {epoch+1} QAT loss: {total_loss/len(train_loader):.4f}")


    quantized_model = quant.convert(model.eval(), inplace=False)
    return quantized_model


## Тут сделаем qat, обучим дистилированную модель:

In [32]:
import torch
import torch.nn as nn
import torch.ao.quantization as quant
from torch.ao.quantization import float_qparams_weight_only_qconfig, get_default_qat_qconfig
import copy

In [33]:
qat_model = CleanBERT(copy.deepcopy(bert_model.bert.embeddings), config).to("cpu")

distilled_state_dict = torch.load("/content/sample_data/student_model_fp32.pt", map_location="cpu")
qat_model.load_state_dict(distilled_state_dict)

# Заменяем NonDynamicallyQuantizableLinear на обычные nn.Linear рекурсивно
def replace_non_dynamically_quantizable_linear(module):
    for name, child in module.named_children():
        if isinstance(child, torch.nn.modules.linear.NonDynamicallyQuantizableLinear):
            new_linear = nn.Linear(child.in_features, child.out_features, bias=child.bias is not None)
            new_linear.weight.data = child.weight.data.clone()
            if child.bias is not None:
                new_linear.bias.data = child.bias.data.clone()
            setattr(module, name, new_linear)
        else:
            replace_non_dynamically_quantizable_linear(child)

replace_non_dynamically_quantizable_linear(qat_model)

qat_model.qconfig = None
qat_model.qconfig_dict = {
    "": get_default_qat_qconfig('fbgemm'),
    "embeddings": float_qparams_weight_only_qconfig,
}

quant.prepare_qat(qat_model, inplace=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
qat_model.to(device)
qat_model.train()


def compute_accuracy(model, data_loader, device):
    model.eval()
    correct = 0
    total = 0
    val_loss = 0
    criterion = nn.CrossEntropyLoss()

    with torch.no_grad():
        for batch in data_loader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            outputs = model(input_ids, attention_mask=attention_mask)
            logits = outputs["logits"]
            val_loss += criterion(logits, labels).item()

            _, predicted = torch.max(logits.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = correct / total
    avg_loss = val_loss / len(data_loader)
    return accuracy, avg_loss

def train_qat_model(model, train_loader, val_loader, device, epochs=3):
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
    criterion = nn.CrossEntropyLoss()
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=1, factor=0.5)

    best_val_acc = 0.0
    history = {
        'train_loss': [],
        'val_loss': [],
        'val_acc': []
    }

    for epoch in range(epochs):
        model.train()
        epoch_train_loss = 0
        progress_bar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{epochs}')

        for batch in progress_bar:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            optimizer.zero_grad()
            outputs = model(input_ids, attention_mask=attention_mask)
            logits = outputs["logits"]
            loss = criterion(logits, labels)
            loss.backward()

            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            epoch_train_loss += loss.item()
            progress_bar.set_postfix({'loss': loss.item()})

        avg_train_loss = epoch_train_loss / len(train_loader)
        val_acc, avg_val_loss =  compute_accuracy(model, val_loader, device)

        scheduler.step(val_acc)

        history['train_loss'].append(avg_train_loss)
        history['val_loss'].append(avg_val_loss)
        history['val_acc'].append(val_acc)

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), "best_qat_model.pth")

        print(f"\nEpoch {epoch+1}/{epochs}:")
        print(f"Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")
        print(f"Val Accuracy: {val_acc:.4f}")
        print(f"Best Val Accuracy: {best_val_acc:.4f}")
        print("-" * 50)

    return model, history




In [34]:
trained_qat_model, training_history = train_qat_model(
    qat_model,
    train_loader,
    val_loader,
    device,
    epochs=1
)

Epoch 1/1: 100%|██████████| 459/459 [01:26<00:00,  5.29it/s, loss=0.000388]



Epoch 1/1:
Train Loss: 0.1398 | Val Loss: 2.4577
Val Accuracy: 0.6441
Best Val Accuracy: 0.6441
--------------------------------------------------


## Сначала обучил на одну эпоху, точность (возросла с 62,7%) :

In [35]:
evaluate(trained_qat_model, val_loader, device)

0.6440579710144928

## Тут обучаю побольше:

In [None]:
trained_qat_model, training_history = train_qat_model(
    qat_model,
    train_loader,
    val_loader,
    device,
    epochs=3
)

Epoch 1/3: 100%|██████████| 459/459 [01:25<00:00,  5.34it/s, loss=0.000431]



Epoch 1/3:
Train Loss: 0.1423 | Val Loss: 2.4796
Val Accuracy: 0.6557
Best Val Accuracy: 0.6557
--------------------------------------------------


Epoch 2/3: 100%|██████████| 459/459 [01:26<00:00,  5.28it/s, loss=0.000493]



Epoch 2/3:
Train Loss: 0.1136 | Val Loss: 2.6727
Val Accuracy: 0.6504
Best Val Accuracy: 0.6557
--------------------------------------------------


Epoch 3/3: 100%|██████████| 459/459 [01:26<00:00,  5.33it/s, loss=0.000317]



Epoch 3/3:
Train Loss: 0.0912 | Val Loss: 2.7288
Val Accuracy: 0.6371
Best Val Accuracy: 0.6557
--------------------------------------------------


## Тут лосс очень скачет сильно, лучшая точность на первой эпохе в итоге

In [None]:
evaluate(trained_qat_model, val_loader, device)

0.6371014492753623

### Максимальная точность при pruning + finetuning + квантование модели