In [1]:
!pip install -U datasets -q

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/491.5 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m491.5/491.5 kB[0m [31m54.8 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m491.5/491.5 kB[0m [31m10.8 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/193.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m193.6/193.6 kB[0m [31m10.4 MB/s[0m eta [36m0:00:00[0m
[?25h[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
gcsfs 2025.3.2 requires fsspec==2025.3.2, but you have fsspec 2025.3.0 which is incompatible.
torch 2.6.0+cu124 requires nvidia-cublas-cu12==12.4.5.8; platform_system == "Linux" and platform_machine == "x86_64"

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from datasets import load_dataset
from torch.utils.data import DataLoader
from torch.optim import AdamW
from tqdm import tqdm
import torch.nn.utils.prune as prune

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def set_seed(seed=42):
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

BATCH_SIZE   = 16
LR           = 3e-5
EPOCHS       = 5
ALPHA        = 0.5
TEMPERATURE  = 2.55
MAX_LEN      = 128
PRUNE_AMOUNT = 0.5
NUM_LABELS   = 2

TEACHER_INTEL = "Intel/bert-base-uncased-mrpc"
TEACHER_BERT  = "bert-base-uncased"

tokenizer = AutoTokenizer.from_pretrained(TEACHER_INTEL)
teacher_intel = AutoModelForSequenceClassification.from_pretrained(
    TEACHER_INTEL, num_labels=NUM_LABELS, output_attentions=True
).to(device).eval()
teacher_bert = AutoModelForSequenceClassification.from_pretrained(
    TEACHER_BERT, num_labels=NUM_LABELS
).to(device).eval()

# ------- Вспомогательные функции -------

def kl_div_loss(student_logits, teacher_logits, temperature=TEMPERATURE):
    student_log_probs = F.log_softmax(student_logits / temperature, dim=1)
    teacher_probs     = F.softmax(teacher_logits / temperature, dim=1)
    return F.kl_div(student_log_probs, teacher_probs, reduction='batchmean') * (temperature**2)

def hidden_state_loss(student_hs, teacher_hs):
    return F.mse_loss(student_hs, teacher_hs.detach())

def prepare_mrpc(tokenizer, batch_size=BATCH_SIZE):
    ds = load_dataset("glue", "mrpc")
    def prep(examples):
        return tokenizer(
            examples['sentence1'], examples['sentence2'],
            padding='max_length', truncation=True,
            max_length=MAX_LEN
        )
    ds = ds.map(prep, batched=True)
    ds = ds.rename_column("label", "labels")
    ds.set_format(
        "torch",
        columns=["input_ids", "attention_mask", "token_type_ids", "labels"]
    )
    train_dl = DataLoader(ds['train'], batch_size=batch_size, shuffle=True)
    val_dl   = DataLoader(ds['validation'], batch_size=batch_size)
    return train_dl, val_dl

class TransformerStudent(nn.Module):
    def __init__(self, teacher_emb_dim, hidden_size=768,
                 num_layers=4, num_heads=4, intermediate_size=1024):
        super().__init__()
        self.embed_proj = nn.Linear(teacher_emb_dim, hidden_size)
        self.register_buffer("pos_ids", torch.arange(MAX_LEN).unsqueeze(0))
        self.layers     = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=hidden_size, nhead=num_heads,
                dim_feedforward=intermediate_size, dropout=0.1,
                batch_first=True
            ) for _ in range(num_layers)
        ])
        self.pooler     = nn.Linear(hidden_size, hidden_size)
        self.activation = nn.Tanh()
        self.classifier = nn.Linear(hidden_size, NUM_LABELS)
        self.hid_proj   = nn.Linear(teacher_emb_dim, hidden_size)

    def forward(self, input_emb, attention_mask):
        x = self.embed_proj(input_emb)
        pad_mask = attention_mask == 0
        for layer in self.layers:
            x = layer(x, src_key_padding_mask=pad_mask)
        cls    = x[:, 0]
        pooled = self.activation(self.pooler(cls))
        logits = self.classifier(pooled)
        return logits, x


def evaluate(model, val_dl, teachers=None):
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for batch in val_dl:
            ids            = batch['input_ids'].to(device)
            mask           = batch['attention_mask'].to(device)
            token_type_ids = batch['token_type_ids'].to(device)
            labels         = batch['labels'].to(device)

            if isinstance(model, TransformerStudent):
                emb = teachers[0].bert.embeddings(
                    input_ids=ids,
                    token_type_ids=token_type_ids,
                    position_ids=model.pos_ids[:, :ids.size(1)].to(device)
                )
                logits, _ = model(emb, mask)
            else:
                outputs = model(
                    input_ids=ids,
                    attention_mask=mask,
                    token_type_ids=token_type_ids
                )
                logits = outputs.logits

            preds = logits.argmax(dim=1)
            correct += (preds == labels).sum().item()
            total   += labels.size(0)
    return correct / total


def train_and_prune(config_name, teachers, train_dl, val_dl):
    print(f"=== Student distillation from {config_name} ===")
    # Дистилляция
    embed_dim = teachers[0].bert.embeddings.word_embeddings.embedding_dim
    student = TransformerStudent(embed_dim).to(device)
    optimizer = AdamW(student.parameters(), lr=LR)

    for epoch in range(1, EPOCHS+1):
        student.train()
        for batch in train_dl:
            ids            = batch['input_ids'].to(device)
            mask           = batch['attention_mask'].to(device)
            token_type_ids = batch['token_type_ids'].to(device)
            labels         = batch['labels'].to(device)

            with torch.no_grad():
                emb = teachers[0].bert.embeddings(
                    input_ids=ids,
                    token_type_ids=token_type_ids,
                    position_ids=student.pos_ids[:, :ids.size(1)].to(device)
                )
                # логиты от всех учителей
                logits_list = [
                    t(
                        input_ids=ids,
                        attention_mask=mask,
                        token_type_ids=token_type_ids
                    ).logits for t in teachers
                ]
                t_logits = sum(logits_list) / len(logits_list)
                # скрытые состояния от каждого учителя (если один — берём его)
                hs_list = [
                    t(
                        input_ids=ids,
                        attention_mask=mask,
                        token_type_ids=token_type_ids,
                        output_hidden_states=True
                    ).hidden_states[-1]
                    for t in teachers
                ]
                # усредняем hidden_states
                t_hidden = sum(hs_list) / len(hs_list)

            s_logits, s_hs = student(emb, mask)
            t_hs_proj      = student.hid_proj(t_hidden)

            loss = (
                ALPHA * F.cross_entropy(s_logits, labels)
                + (1 - ALPHA) * kl_div_loss(s_logits, t_logits)
                + hidden_state_loss(s_hs, t_hs_proj)
            )
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        acc = evaluate(student, val_dl, teachers)
        print(f"Epoch {epoch}: val_acc = {acc:.4f}")

    # Оценка до прунинга
    acc_before = evaluate(student, val_dl, teachers)
    print(f"Student val_acc (before prune) [{config_name}]: {acc_before:.4f}")

    # Прунинг
    params_to_prune = [(m, 'weight') for m in student.modules() if isinstance(m, nn.Linear)]
    prune.global_unstructured(
        params_to_prune,
        pruning_method=prune.L1Unstructured,
        amount=PRUNE_AMOUNT
    )
    print(f"Applied global L1 pruning to student [{config_name}].")

    # Оценка после прунинга
    acc_after = evaluate(student, val_dl, teachers)
    print(f"Student val_acc (after prune)  [{config_name}]: {acc_after:.4f}\n")

def main():
    set_seed()
    train_dl, val_dl = prepare_mrpc(tokenizer)

    # Конфигурации дистилляции
    configs = [
        ("Intel", [teacher_intel]),
        ("BERT", [teacher_bert]),
        ("Intel+BERT", [teacher_intel, teacher_bert])
    ]

    # Для каждой конфигурации: дистилл., оценка, прунинг, оценка
    for name, teachers in configs:
        train_and_prune(name, teachers, train_dl, val_dl)

    # Оценка оригинальных teacher-моделей
    print("=== Original Teachers (no finetune) ===")
    for model, name in [(teacher_intel, "Intel"), (teacher_bert, "BERT")]:
        acc = evaluate(model, val_dl)
        print(f"{name} Teacher val_acc (unpruned): {acc:.4f}")

    # Оценка pruned teacher-моделей
    print("=== Pruned Teachers (no finetune) ===")
    for model_name, name in [(TEACHER_INTEL, "Intel"), (TEACHER_BERT, "BERT")]:
        m = AutoModelForSequenceClassification.from_pretrained(
            model_name, num_labels=NUM_LABELS
        ).to(device).eval()
        params = [(mod, 'weight') for mod in m.modules() if isinstance(mod, nn.Linear)]
        prune.global_unstructured(
            params, pruning_method=prune.L1Unstructured, amount=PRUNE_AMOUNT
        )
        acc = evaluate(m, val_dl)
        print(f"{name} Teacher val_acc (pruned):   {acc:.4f}")

if __name__ == "__main__":
    main()


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/321 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/711k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/892 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/438M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/438M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


README.md:   0%|          | 0.00/35.3k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/649k [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/75.7k [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/308k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/3668 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/408 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1725 [00:00<?, ? examples/s]

Map:   0%|          | 0/3668 [00:00<?, ? examples/s]

Map:   0%|          | 0/408 [00:00<?, ? examples/s]

Map:   0%|          | 0/1725 [00:00<?, ? examples/s]

=== Student distillation from Intel ===




Epoch 1: val_acc = 0.6887
Epoch 2: val_acc = 0.6838
Epoch 3: val_acc = 0.7010
Epoch 4: val_acc = 0.6863
Epoch 5: val_acc = 0.6887
Student val_acc (before prune) [Intel]: 0.6887
Applied global L1 pruning to student [Intel].
Student val_acc (after prune)  [Intel]: 0.7010

=== Student distillation from BERT ===
Epoch 1: val_acc = 0.6765
Epoch 2: val_acc = 0.6838
Epoch 3: val_acc = 0.6985
Epoch 4: val_acc = 0.5711
Epoch 5: val_acc = 0.6887
Student val_acc (before prune) [BERT]: 0.6887
Applied global L1 pruning to student [BERT].
Student val_acc (after prune)  [BERT]: 0.6569

=== Student distillation from Intel+BERT ===
Epoch 1: val_acc = 0.6887
Epoch 2: val_acc = 0.6985
Epoch 3: val_acc = 0.7010
Epoch 4: val_acc = 0.7230
Epoch 5: val_acc = 0.6912
Student val_acc (before prune) [Intel+BERT]: 0.6912
Applied global L1 pruning to student [Intel+BERT].
Student val_acc (after prune)  [Intel+BERT]: 0.6961

=== Original Teachers (no finetune) ===
Intel Teacher val_acc (unpruned): 0.8603
BERT Teach

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


BERT Teacher val_acc (pruned):   0.6838
