In [1]:
import torch
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer
)
import tabulate
import numpy as np
from sklearn.metrics import accuracy_score, classification_report

# pip install transformers==4.45.2 



  from .autonotebook import tqdm as notebook_tqdm


In [None]:
seed = 42
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


In [2]:
dataset = load_dataset('dair-ai/emotion')
print(dataset)

labels = dataset["train"].features["label"].names
print(labels)

DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 16000
    })
    validation: Dataset({
        features: ['text', 'label'],
        num_rows: 2000
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 2000
    })
})
['sadness', 'joy', 'love', 'anger', 'fear', 'surprise']


In [3]:
model_name = "google-bert/bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)


In [4]:
def tokenize_data(examples):
    return tokenizer(
        examples["text"],
        padding="max_length",
        truncation=True,
        max_length=128
    )

# Токенизация всех данных
tokenized_dataset = dataset.map(tokenize_data, batched=True)
tokenized_dataset.set_format("torch", columns=["input_ids", "attention_mask", "label"])

In [5]:
def compute_metrics(eval_pred):
    logits, true_labels = eval_pred  # Переименовали переменную
    predictions = np.argmax(logits, axis=-1)
    return {
        "accuracy": accuracy_score(true_labels, predictions),
        **classification_report(
            true_labels, 
            predictions,
            target_names=labels,  # Используем глобальный список названий
            output_dict=True,
            zero_division=0
        )["macro avg"]
    }


### Full finetuning

In [None]:
model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    num_labels=len(labels)
)

trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total trainable parameters: {trainable_params:,}")

training_args = TrainingArguments(
    output_dir="./results",
    eval_strategy="steps",
    save_strategy="steps",
    save_steps=1000,
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    # max_steps=1000,
    num_train_epochs=3,
    logging_steps=200,
    weight_decay=0.01,
    metric_for_best_model="f1-score",
    logging_dir="./logs",
    report_to="none",
    load_best_model_at_end=True,
    save_total_limit=2,
    # lr_scheduler_type="cosine",
    # warmup_steps=500
)


trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["validation"],
    compute_metrics=compute_metrics
)

results = trainer.evaluate(tokenized_dataset["test"])
print(tabulate.tabulate(
    results.items(),
    headers=["Метрика", "Значение"],
    tablefmt="grid",
    floatfmt=".4f"
))
trainer.train()

results = trainer.evaluate(tokenized_dataset["test"])
print(tabulate.tabulate(
    results.items(),
    headers=["Метрика", "Значение"],
    tablefmt="grid",
    floatfmt=".4f"
))

# Training time: 31m


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at google-bert/bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Total trainable parameters: 109,486,854


100%|██████████| 125/125 [00:23<00:00,  5.27it/s]


+-----------------------------+------------+
| Метрика                     |   Значение |
| eval_loss                   |     1.8294 |
+-----------------------------+------------+
| eval_model_preparation_time |     0.0000 |
+-----------------------------+------------+
| eval_accuracy               |     0.1230 |
+-----------------------------+------------+
| eval_precision              |     0.0759 |
+-----------------------------+------------+
| eval_recall                 |     0.1425 |
+-----------------------------+------------+
| eval_f1-score               |     0.0772 |
+-----------------------------+------------+
| eval_support                |  2000.0000 |
+-----------------------------+------------+
| eval_runtime                |    29.0365 |
+-----------------------------+------------+
| eval_samples_per_second     |    68.8790 |
+-----------------------------+------------+
| eval_steps_per_second       |     4.3050 |
+-----------------------------+------------+


  7%|▋         | 200/3000 [01:38<22:20,  2.09it/s]

{'loss': 1.1608, 'grad_norm': 12.450788497924805, 'learning_rate': 1.866666666666667e-05, 'epoch': 0.2}



  7%|▋         | 200/3000 [02:00<22:20,  2.09it/s]

{'eval_loss': 0.644923210144043, 'eval_model_preparation_time': 0.0, 'eval_accuracy': 0.7615, 'eval_precision': 0.7601021273808356, 'eval_recall': 0.5647326144337398, 'eval_f1-score': 0.5460492846470477, 'eval_support': 2000.0, 'eval_runtime': 22.6554, 'eval_samples_per_second': 88.279, 'eval_steps_per_second': 5.517, 'epoch': 0.2}


 13%|█▎        | 400/3000 [03:36<20:48,  2.08it/s]  

{'loss': 0.4599, 'grad_norm': 14.408431053161621, 'learning_rate': 1.7333333333333336e-05, 'epoch': 0.4}


                                                  
 13%|█▎        | 400/3000 [03:59<20:48,  2.08it/s]

{'eval_loss': 0.322631299495697, 'eval_model_preparation_time': 0.0, 'eval_accuracy': 0.9005, 'eval_precision': 0.8812252738934707, 'eval_recall': 0.8641717734232692, 'eval_f1-score': 0.8651459615859552, 'eval_support': 2000.0, 'eval_runtime': 22.7832, 'eval_samples_per_second': 87.784, 'eval_steps_per_second': 5.487, 'epoch': 0.4}


 20%|██        | 600/3000 [05:35<19:02,  2.10it/s]  

{'loss': 0.2807, 'grad_norm': 3.124316930770874, 'learning_rate': 1.6000000000000003e-05, 'epoch': 0.6}


                                                  
 20%|██        | 600/3000 [05:57<19:02,  2.10it/s]

{'eval_loss': 0.2428482174873352, 'eval_model_preparation_time': 0.0, 'eval_accuracy': 0.9125, 'eval_precision': 0.8905662177831241, 'eval_recall': 0.8741643316423366, 'eval_f1-score': 0.8819053717437976, 'eval_support': 2000.0, 'eval_runtime': 22.6018, 'eval_samples_per_second': 88.489, 'eval_steps_per_second': 5.531, 'epoch': 0.6}


 27%|██▋       | 800/3000 [07:33<17:34,  2.09it/s]  

{'loss': 0.231, 'grad_norm': 19.92230987548828, 'learning_rate': 1.4666666666666666e-05, 'epoch': 0.8}


                                                  
 27%|██▋       | 800/3000 [07:56<17:34,  2.09it/s]

{'eval_loss': 0.2153342217206955, 'eval_model_preparation_time': 0.0, 'eval_accuracy': 0.9265, 'eval_precision': 0.9000284674097313, 'eval_recall': 0.8979340920323375, 'eval_f1-score': 0.8975037059040692, 'eval_support': 2000.0, 'eval_runtime': 22.5921, 'eval_samples_per_second': 88.527, 'eval_steps_per_second': 5.533, 'epoch': 0.8}


 33%|███▎      | 1000/3000 [09:32<15:59,  2.08it/s] 

{'loss': 0.2157, 'grad_norm': 9.32910442352295, 'learning_rate': 1.3333333333333333e-05, 'epoch': 1.0}


                                                   
 33%|███▎      | 1000/3000 [09:54<15:59,  2.08it/s]

{'eval_loss': 0.2090146541595459, 'eval_model_preparation_time': 0.0, 'eval_accuracy': 0.9245, 'eval_precision': 0.8837928522368763, 'eval_recall': 0.922862172876505, 'eval_f1-score': 0.9000972049913175, 'eval_support': 2000.0, 'eval_runtime': 22.615, 'eval_samples_per_second': 88.437, 'eval_steps_per_second': 5.527, 'epoch': 1.0}


 40%|████      | 1200/3000 [11:37<14:19,  2.09it/s]  

{'loss': 0.1414, 'grad_norm': 9.07302474975586, 'learning_rate': 1.2e-05, 'epoch': 1.2}


                                                   
 40%|████      | 1200/3000 [11:59<14:19,  2.09it/s]

{'eval_loss': 0.18555593490600586, 'eval_model_preparation_time': 0.0, 'eval_accuracy': 0.937, 'eval_precision': 0.9139159415747301, 'eval_recall': 0.9115310924951586, 'eval_f1-score': 0.9122375104390636, 'eval_support': 2000.0, 'eval_runtime': 22.5539, 'eval_samples_per_second': 88.676, 'eval_steps_per_second': 5.542, 'epoch': 1.2}


 47%|████▋     | 1400/3000 [13:35<12:44,  2.09it/s]  

{'loss': 0.136, 'grad_norm': 10.079730987548828, 'learning_rate': 1.0666666666666667e-05, 'epoch': 1.4}


                                                   
 47%|████▋     | 1400/3000 [13:57<12:44,  2.09it/s]

{'eval_loss': 0.17194339632987976, 'eval_model_preparation_time': 0.0, 'eval_accuracy': 0.938, 'eval_precision': 0.9240609767463629, 'eval_recall': 0.8994831643040618, 'eval_f1-score': 0.9102885815804113, 'eval_support': 2000.0, 'eval_runtime': 22.5748, 'eval_samples_per_second': 88.594, 'eval_steps_per_second': 5.537, 'epoch': 1.4}


 53%|█████▎    | 1600/3000 [15:33<11:12,  2.08it/s]  

{'loss': 0.1347, 'grad_norm': 27.92207908630371, 'learning_rate': 9.333333333333334e-06, 'epoch': 1.6}


                                                   
 53%|█████▎    | 1600/3000 [15:57<11:12,  2.08it/s]

{'eval_loss': 0.18055330216884613, 'eval_model_preparation_time': 0.0, 'eval_accuracy': 0.935, 'eval_precision': 0.9290738093772801, 'eval_recall': 0.8938237433904922, 'eval_f1-score': 0.9081864328775313, 'eval_support': 2000.0, 'eval_runtime': 23.9373, 'eval_samples_per_second': 83.551, 'eval_steps_per_second': 5.222, 'epoch': 1.6}


 60%|██████    | 1800/3000 [17:37<10:02,  1.99it/s]  

{'loss': 0.1479, 'grad_norm': 3.723676919937134, 'learning_rate': 8.000000000000001e-06, 'epoch': 1.8}


                                                   
 60%|██████    | 1800/3000 [18:01<10:02,  1.99it/s]

{'eval_loss': 0.1632811278104782, 'eval_model_preparation_time': 0.0, 'eval_accuracy': 0.9345, 'eval_precision': 0.9243434837212448, 'eval_recall': 0.8953627125900215, 'eval_f1-score': 0.9075461113805877, 'eval_support': 2000.0, 'eval_runtime': 23.7992, 'eval_samples_per_second': 84.036, 'eval_steps_per_second': 5.252, 'epoch': 1.8}


 67%|██████▋   | 2000/3000 [19:42<08:19,  2.00it/s]  

{'loss': 0.1262, 'grad_norm': 7.469344139099121, 'learning_rate': 6.666666666666667e-06, 'epoch': 2.0}


                                                   
 67%|██████▋   | 2000/3000 [20:05<08:19,  2.00it/s]

{'eval_loss': 0.1696544736623764, 'eval_model_preparation_time': 0.0, 'eval_accuracy': 0.9395, 'eval_precision': 0.9227163567548965, 'eval_recall': 0.9132120184940509, 'eval_f1-score': 0.9159539834204358, 'eval_support': 2000.0, 'eval_runtime': 23.8509, 'eval_samples_per_second': 83.854, 'eval_steps_per_second': 5.241, 'epoch': 2.0}


 73%|███████▎  | 2200/3000 [21:46<06:23,  2.08it/s]  

{'loss': 0.1003, 'grad_norm': 9.503037452697754, 'learning_rate': 5.333333333333334e-06, 'epoch': 2.2}


                                                   
 73%|███████▎  | 2200/3000 [22:09<06:23,  2.08it/s]

{'eval_loss': 0.17316670715808868, 'eval_model_preparation_time': 0.0, 'eval_accuracy': 0.938, 'eval_precision': 0.938420135046726, 'eval_recall': 0.893202889017849, 'eval_f1-score': 0.9123046656849496, 'eval_support': 2000.0, 'eval_runtime': 22.6416, 'eval_samples_per_second': 88.333, 'eval_steps_per_second': 5.521, 'epoch': 2.2}


 80%|████████  | 2400/3000 [23:45<04:47,  2.08it/s]  

{'loss': 0.1007, 'grad_norm': 3.27270770072937, 'learning_rate': 4.000000000000001e-06, 'epoch': 2.4}


                                                   
 80%|████████  | 2400/3000 [24:09<04:47,  2.08it/s]

{'eval_loss': 0.1694089025259018, 'eval_model_preparation_time': 0.0, 'eval_accuracy': 0.9385, 'eval_precision': 0.9294616329299622, 'eval_recall': 0.9049360563654827, 'eval_f1-score': 0.9138466295290327, 'eval_support': 2000.0, 'eval_runtime': 23.9532, 'eval_samples_per_second': 83.496, 'eval_steps_per_second': 5.219, 'epoch': 2.4}


 87%|████████▋ | 2600/3000 [25:51<03:25,  1.95it/s]  

{'loss': 0.0797, 'grad_norm': 0.03714936599135399, 'learning_rate': 2.666666666666667e-06, 'epoch': 2.6}


                                                   
 87%|████████▋ | 2600/3000 [26:15<03:25,  1.95it/s]

{'eval_loss': 0.1637253612279892, 'eval_model_preparation_time': 0.0, 'eval_accuracy': 0.9415, 'eval_precision': 0.9215808349656992, 'eval_recall': 0.9128564128424931, 'eval_f1-score': 0.9167776953515946, 'eval_support': 2000.0, 'eval_runtime': 23.8742, 'eval_samples_per_second': 83.772, 'eval_steps_per_second': 5.236, 'epoch': 2.6}


 93%|█████████▎| 2800/3000 [27:53<01:35,  2.09it/s]

{'loss': 0.0841, 'grad_norm': 6.354440689086914, 'learning_rate': 1.3333333333333334e-06, 'epoch': 2.8}


                                                   
 93%|█████████▎| 2800/3000 [28:15<01:35,  2.09it/s]

{'eval_loss': 0.17408889532089233, 'eval_model_preparation_time': 0.0, 'eval_accuracy': 0.941, 'eval_precision': 0.9320635248049668, 'eval_recall': 0.9038654907118119, 'eval_f1-score': 0.9162336492036562, 'eval_support': 2000.0, 'eval_runtime': 22.698, 'eval_samples_per_second': 88.113, 'eval_steps_per_second': 5.507, 'epoch': 2.8}


100%|██████████| 3000/3000 [29:51<00:00,  2.09it/s]

{'loss': 0.0927, 'grad_norm': 0.522882342338562, 'learning_rate': 0.0, 'epoch': 3.0}


                                                   
100%|██████████| 3000/3000 [30:14<00:00,  2.09it/s]

{'eval_loss': 0.16834279894828796, 'eval_model_preparation_time': 0.0, 'eval_accuracy': 0.9405, 'eval_precision': 0.9206623386735681, 'eval_recall': 0.9136635887827942, 'eval_f1-score': 0.9165278983047488, 'eval_support': 2000.0, 'eval_runtime': 22.6711, 'eval_samples_per_second': 88.218, 'eval_steps_per_second': 5.514, 'epoch': 3.0}


100%|██████████| 3000/3000 [30:21<00:00,  1.65it/s]


{'train_runtime': 1821.7417, 'train_samples_per_second': 26.348, 'train_steps_per_second': 1.647, 'train_loss': 0.23278478399912517, 'epoch': 3.0}


100%|██████████| 125/125 [00:22<00:00,  5.59it/s]

+-----------------------------+------------+
| Метрика                     |   Значение |
| eval_loss                   |     0.1777 |
+-----------------------------+------------+
| eval_model_preparation_time |     0.0000 |
+-----------------------------+------------+
| eval_accuracy               |     0.9275 |
+-----------------------------+------------+
| eval_precision              |     0.8902 |
+-----------------------------+------------+
| eval_recall                 |     0.8693 |
+-----------------------------+------------+
| eval_f1-score               |     0.8788 |
+-----------------------------+------------+
| eval_support                |  2000.0000 |
+-----------------------------+------------+
| eval_runtime                |    22.7870 |
+-----------------------------+------------+
| eval_samples_per_second     |    87.7690 |
+-----------------------------+------------+
| eval_steps_per_second       |     5.4860 |
+-----------------------------+------------+
| epoch   




### Linear probing
Выбор трансформерной головы обусловлен попыткой улучшить контекстуализацию фич через механизм внимания, так как классическая полносвязная сеть (Linear → ReLU → Dropout → Linear) не справлялась с задачей. Однако это не помогло повысить качество, несмотря на теоретический потенциал.



In [None]:
from torch import nn

from torch import nn
from transformers import BertConfig

class TransformerClassifierHead(nn.Module):
    def __init__(self, hidden_size, num_labels, num_layers=1, dropout=0.1):
        super().__init__()
        self.config = BertConfig(
            hidden_size=hidden_size,
            num_attention_heads=12,
            intermediate_size=3072,
            num_hidden_layers=num_layers,
            hidden_dropout_prob=dropout
        )
        
        self.transformer = nn.TransformerEncoder(
            encoder_layer=nn.TransformerEncoderLayer(
                d_model=hidden_size,
                nhead=self.config.num_attention_heads,
                dim_feedforward=self.config.intermediate_size,
                dropout=dropout,
                activation="gelu",
                batch_first=True
            ),
            num_layers=num_layers
        )
        
        self.pooler = nn.Linear(hidden_size, hidden_size)
        self.tanh = nn.Tanh()
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(hidden_size, num_labels)

    def forward(self, hidden_states):
        if hidden_states.dim() == 2:
            hidden_states = hidden_states.unsqueeze(1)  # [batch, 1, hidden]
        
        transformer_out = self.transformer(hidden_states)  # [batch, 1, hidden]
        
        pooled = self.tanh(self.pooler(transformer_out[:, 0, :]))  # [batch, hidden]
        pooled = self.dropout(pooled)
        
        return self.classifier(pooled)

# Модифицируем модель BERT
model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    num_labels=len(labels)
)

# Замораживаем все слои BERT
for param in model.parameters():
    param.requires_grad = False

# Добавляем кастомную голову
hidden_size = model.config.hidden_size  # 768 для bert-base
num_labels = len(labels)                        
model.classifier = TransformerClassifierHead(hidden_size, num_labels)

trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total trainable parameters: {trainable_params:,}")

# Проверяем, что обучаются только параметры классификатора
for name, param in model.named_parameters():
    if param.requires_grad:
        print(f"Обучается: {name}")

training_args = TrainingArguments(
    output_dir="./results",
    eval_strategy="steps",
    save_strategy="steps",
    save_steps=1000,
    learning_rate=1e-4, # Больше lr, так как обучаем только голову
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    # max_steps=10000,
    num_train_epochs=10,
    logging_steps=1000,
    weight_decay=0.01,
    metric_for_best_model="f1-score",
    logging_dir="./logs",
    report_to="none",
    load_best_model_at_end=True,
    save_total_limit=2,
    lr_scheduler_type="cosine",
    # warmup_steps=500
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["validation"],
    compute_metrics=compute_metrics
)

trainer.train()
results = trainer.evaluate(tokenized_dataset["test"])
print(tabulate.tabulate(
    results.items(),
    headers=["Метрика", "Значение"],
    tablefmt="grid",
    floatfmt=".4f"
))

# Training time: 37m

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at google-bert/bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Total trainable parameters: 7,683,078
Обучается: classifier.transformer.layers.0.self_attn.in_proj_weight
Обучается: classifier.transformer.layers.0.self_attn.in_proj_bias
Обучается: classifier.transformer.layers.0.self_attn.out_proj.weight
Обучается: classifier.transformer.layers.0.self_attn.out_proj.bias
Обучается: classifier.transformer.layers.0.linear1.weight
Обучается: classifier.transformer.layers.0.linear1.bias
Обучается: classifier.transformer.layers.0.linear2.weight
Обучается: classifier.transformer.layers.0.linear2.bias
Обучается: classifier.transformer.layers.0.norm1.weight
Обучается: classifier.transformer.layers.0.norm1.bias
Обучается: classifier.transformer.layers.0.norm2.weight
Обучается: classifier.transformer.layers.0.norm2.bias
Обучается: classifier.pooler.weight
Обучается: classifier.pooler.bias
Обучается: classifier.classifier.weight
Обучается: classifier.classifier.bias


 10%|█         | 1000/10000 [03:13<29:05,  5.16it/s]
 10%|█         | 1000/10000 [03:13<29:05,  5.16it/s]

{'loss': 1.5412, 'grad_norm': 4.264253616333008, 'learning_rate': 9.755282581475769e-05, 'epoch': 1.0}



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A

[A[A                                           
                                                    
 10%|█         | 1000/10000 [03:35<29:05,  5.16it/s]
[A

{'eval_loss': 1.4600449800491333, 'eval_accuracy': 0.4615, 'eval_precision': 0.24825122904141975, 'eval_recall': 0.25091909662664375, 'eval_f1-score': 0.19574071741170282, 'eval_support': 2000.0, 'eval_runtime': 22.7767, 'eval_samples_per_second': 87.809, 'eval_steps_per_second': 5.488, 'epoch': 1.0}


 20%|██        | 2000/10000 [06:50<25:56,  5.14it/s]   
 20%|██        | 2000/10000 [06:50<25:56,  5.14it/s]

{'loss': 1.4626, 'grad_norm': 5.81187629699707, 'learning_rate': 9.045084971874738e-05, 'epoch': 2.0}



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A

[A[A                                           
                                                    
 20%|██        | 2000/10000 [07:12<25:56,  5.14it/s]
[A

{'eval_loss': 1.3982820510864258, 'eval_accuracy': 0.4865, 'eval_precision': 0.30662195687980315, 'eval_recall': 0.26083351200686106, 'eval_f1-score': 0.20303710101790726, 'eval_support': 2000.0, 'eval_runtime': 22.7885, 'eval_samples_per_second': 87.764, 'eval_steps_per_second': 5.485, 'epoch': 2.0}


 30%|███       | 3000/10000 [10:27<22:14,  5.25it/s]   
 30%|███       | 3000/10000 [10:27<22:14,  5.25it/s]

{'loss': 1.4275, 'grad_norm': 4.465834617614746, 'learning_rate': 7.938926261462366e-05, 'epoch': 3.0}



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A

[A[A                                           
                                                    
 30%|███       | 3000/10000 [10:49<22:14,  5.25it/s]
[A

{'eval_loss': 1.355415940284729, 'eval_accuracy': 0.483, 'eval_precision': 0.44928630153475235, 'eval_recall': 0.28275959456768235, 'eval_f1-score': 0.24539208779641578, 'eval_support': 2000.0, 'eval_runtime': 22.8486, 'eval_samples_per_second': 87.533, 'eval_steps_per_second': 5.471, 'epoch': 3.0}


 40%|████      | 4000/10000 [14:04<19:03,  5.25it/s]   
 40%|████      | 4000/10000 [14:04<19:03,  5.25it/s]

{'loss': 1.4088, 'grad_norm': 3.8054888248443604, 'learning_rate': 6.545084971874738e-05, 'epoch': 4.0}



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A

[A[A                                           
                                                    
 40%|████      | 4000/10000 [14:27<19:03,  5.25it/s]
[A

{'eval_loss': 1.3530161380767822, 'eval_accuracy': 0.4845, 'eval_precision': 0.3946630974621512, 'eval_recall': 0.28946987564322474, 'eval_f1-score': 0.24720986085571361, 'eval_support': 2000.0, 'eval_runtime': 22.8285, 'eval_samples_per_second': 87.61, 'eval_steps_per_second': 5.476, 'epoch': 4.0}


 50%|█████     | 5000/10000 [17:41<15:58,  5.21it/s]   
 50%|█████     | 5000/10000 [17:41<15:58,  5.21it/s]

{'loss': 1.3885, 'grad_norm': 5.1490864753723145, 'learning_rate': 5e-05, 'epoch': 5.0}



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A

[A[A                                           
                                                    
 50%|█████     | 5000/10000 [18:04<15:58,  5.21it/s]
[A

{'eval_loss': 1.325097918510437, 'eval_accuracy': 0.4975, 'eval_precision': 0.30872630924097727, 'eval_recall': 0.3037898441966838, 'eval_f1-score': 0.2690423340874244, 'eval_support': 2000.0, 'eval_runtime': 22.9317, 'eval_samples_per_second': 87.216, 'eval_steps_per_second': 5.451, 'epoch': 5.0}


 60%|██████    | 6000/10000 [21:19<12:45,  5.23it/s]   
 60%|██████    | 6000/10000 [21:19<12:45,  5.23it/s]

{'loss': 1.3714, 'grad_norm': 3.827255964279175, 'learning_rate': 3.4549150281252636e-05, 'epoch': 6.0}



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A

[A[A                                           
                                                    
 60%|██████    | 6000/10000 [21:41<12:45,  5.23it/s]
[A

{'eval_loss': 1.31624174118042, 'eval_accuracy': 0.4915, 'eval_precision': 0.4903148553874495, 'eval_recall': 0.28035955337239254, 'eval_f1-score': 0.2436167785672114, 'eval_support': 2000.0, 'eval_runtime': 22.8284, 'eval_samples_per_second': 87.61, 'eval_steps_per_second': 5.476, 'epoch': 6.0}


 70%|███████   | 7000/10000 [25:06<10:25,  4.79it/s]  
 70%|███████   | 7000/10000 [25:06<10:25,  4.79it/s]

{'loss': 1.3596, 'grad_norm': 5.635991096496582, 'learning_rate': 2.061073738537635e-05, 'epoch': 7.0}



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A

[A[A                                           
                                                    
 70%|███████   | 7000/10000 [25:31<10:25,  4.79it/s]
[A

{'eval_loss': 1.2678087949752808, 'eval_accuracy': 0.5285, 'eval_precision': 0.4395739224256516, 'eval_recall': 0.29910331786703154, 'eval_f1-score': 0.26510746448778316, 'eval_support': 2000.0, 'eval_runtime': 24.477, 'eval_samples_per_second': 81.709, 'eval_steps_per_second': 5.107, 'epoch': 7.0}


 80%|████████  | 8000/10000 [29:02<06:28,  5.15it/s]  
 80%|████████  | 8000/10000 [29:02<06:28,  5.15it/s]

{'loss': 1.3395, 'grad_norm': 5.142706871032715, 'learning_rate': 9.549150281252633e-06, 'epoch': 8.0}



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A

[A[A                                           
                                                    
 80%|████████  | 8000/10000 [29:24<06:28,  5.15it/s]
[A

{'eval_loss': 1.2587741613388062, 'eval_accuracy': 0.5225, 'eval_precision': 0.41661475189747293, 'eval_recall': 0.3077334279299246, 'eval_f1-score': 0.2810760540139106, 'eval_support': 2000.0, 'eval_runtime': 22.7735, 'eval_samples_per_second': 87.821, 'eval_steps_per_second': 5.489, 'epoch': 8.0}


 90%|█████████ | 9000/10000 [32:39<03:11,  5.22it/s]  
 90%|█████████ | 9000/10000 [32:39<03:11,  5.22it/s]

{'loss': 1.3319, 'grad_norm': 6.297039985656738, 'learning_rate': 2.4471741852423237e-06, 'epoch': 9.0}



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A

[A[A                                           
                                                    
 90%|█████████ | 9000/10000 [33:01<03:11,  5.22it/s]
[A

{'eval_loss': 1.2516666650772095, 'eval_accuracy': 0.5285, 'eval_precision': 0.46075284316127824, 'eval_recall': 0.32322573573342067, 'eval_f1-score': 0.30103824443112087, 'eval_support': 2000.0, 'eval_runtime': 22.7749, 'eval_samples_per_second': 87.816, 'eval_steps_per_second': 5.489, 'epoch': 9.0}


100%|██████████| 10000/10000 [36:24<00:00,  4.82it/s] 
100%|██████████| 10000/10000 [36:24<00:00,  4.82it/s]

{'loss': 1.3275, 'grad_norm': 7.101324081420898, 'learning_rate': 0.0, 'epoch': 10.0}



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A

[A[A                                           
                                                     
100%|██████████| 10000/10000 [36:48<00:00,  4.82it/s]
[A

{'eval_loss': 1.248012900352478, 'eval_accuracy': 0.527, 'eval_precision': 0.45005689601292165, 'eval_recall': 0.3218399257842363, 'eval_f1-score': 0.3024922018707328, 'eval_support': 2000.0, 'eval_runtime': 24.2452, 'eval_samples_per_second': 82.49, 'eval_steps_per_second': 5.156, 'epoch': 10.0}



100%|██████████| 10000/10000 [36:49<00:00,  4.53it/s]


{'train_runtime': 2209.9339, 'train_samples_per_second': 72.4, 'train_steps_per_second': 4.525, 'train_loss': 1.3958416015625, 'epoch': 10.0}


100%|██████████| 125/125 [00:23<00:00,  5.26it/s]

+-------------------------+------------+
| Метрика                 |   Значение |
| eval_loss               |     1.2296 |
+-------------------------+------------+
| eval_accuracy           |     0.5315 |
+-------------------------+------------+
| eval_precision          |     0.4409 |
+-------------------------+------------+
| eval_recall             |     0.3190 |
+-------------------------+------------+
| eval_f1-score           |     0.3073 |
+-------------------------+------------+
| eval_support            |  2000.0000 |
+-------------------------+------------+
| eval_runtime            |    23.9604 |
+-------------------------+------------+
| eval_samples_per_second |    83.4710 |
+-------------------------+------------+
| eval_steps_per_second   |     5.2170 |
+-------------------------+------------+
| epoch                   |    10.0000 |
+-------------------------+------------+





### Prompt tuning
Выбор параметров Prompt Tuning продиктован спецификой задачи классификации текста и особенностями базовой модели BERT:
- num_virtual_tokens=20 — длина промпта, достаточная для кодирования контекста задачи (слишком короткий промпт не передаёт смысл, длинный — усложняет обучение).
- token_dim=768 — соответствует размерности эмбеддингов BERT-base, чтобы избежать конфликтов в архитектуре.
- prompt_tuning_init="TEXT" + prompt_tuning_init_text=... — инициализация промпта осмысленной фразой ("Classify the emotion...") ускоряет сходимость, так как задаёт семантическую направленность.
- base_model_name_or_path и tokenizer_name_or_path — использование предобученного BERT-base-uncased обеспечивает стартовые веса, адаптированные для NLP-задач.

Параметры направлены на баланс между контролем над промптом (через явную текстовую инициализацию) и адаптивностью (20 виртуальных токенов). Это снижает риск "холодного старта" и улучшает интерпретируемость по сравнению со случайной инициализацией.

In [15]:
from peft import (
    PromptTuningConfig,
    get_peft_model,
    TaskType
)
from transformers import AutoModelForSequenceClassification

peft_config = PromptTuningConfig(
    task_type=TaskType.SEQ_CLS,
    num_virtual_tokens=20,  # Длина мягкого промпта
    token_dim=768,          # Размерность эмбеддингов BERT
    prompt_tuning_init="TEXT",
    prompt_tuning_init_text="Classify the emotion in the text:",
    base_model_name_or_path="bert-base-uncased",
    tokenizer_name_or_path="bert-base-uncased"
)

model = AutoModelForSequenceClassification.from_pretrained(
    "google-bert/bert-base-uncased",
    num_labels=len(labels),
    return_dict=True
)

model = get_peft_model(model, peft_config)
model.print_trainable_parameters()  # Обучаемые параметры: ~0.1%


training_args = TrainingArguments(
    output_dir="./peft_results",
    learning_rate=1e-4,           # Выше обычного для промптов
    per_device_train_batch_size=32,
    num_train_epochs=5,          # Нужно больше эпох
    logging_steps=100,
    save_strategy="no"
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["validation"],
    compute_metrics=compute_metrics
)

trainer.train()
results = trainer.evaluate(tokenized_dataset["test"])
print(tabulate.tabulate(
    results.items(),
    headers=["Метрика", "Значение"],
    tablefmt="grid",
    floatfmt=".4f"
))

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at google-bert/bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


trainable params: 15,360 || all params: 109,502,214 || trainable%: 0.0140


 20%|█▉        | 496/2500 [21:48<1:28:07,  2.64s/it]
  4%|▍         | 100/2500 [01:21<32:01,  1.25it/s]
  4%|▍         | 100/2500 [01:21<32:01,  1.25it/s] 

{'loss': 1.7466, 'grad_norm': 0.10798969119787216, 'learning_rate': 9.6e-05, 'epoch': 0.2}


  8%|▊         | 200/2500 [02:42<30:42,  1.25it/s]
  8%|▊         | 200/2500 [02:42<30:42,  1.25it/s] 

{'loss': 1.7287, 'grad_norm': 0.10176584869623184, 'learning_rate': 9.200000000000001e-05, 'epoch': 0.4}


 12%|█▏        | 300/2500 [04:02<29:30,  1.24it/s]
 12%|█▏        | 300/2500 [04:02<29:30,  1.24it/s] 

{'loss': 1.7022, 'grad_norm': 0.09194259345531464, 'learning_rate': 8.800000000000001e-05, 'epoch': 0.6}


 16%|█▌        | 400/2500 [05:22<28:04,  1.25it/s]
 16%|█▌        | 400/2500 [05:22<28:04,  1.25it/s] 

{'loss': 1.7071, 'grad_norm': 0.0901457816362381, 'learning_rate': 8.4e-05, 'epoch': 0.8}


 20%|██        | 500/2500 [06:42<26:46,  1.24it/s]
 20%|██        | 500/2500 [06:42<26:46,  1.24it/s] 

{'loss': 1.6995, 'grad_norm': 0.09638141095638275, 'learning_rate': 8e-05, 'epoch': 1.0}


 24%|██▍       | 600/2500 [08:03<25:20,  1.25it/s]
 24%|██▍       | 600/2500 [08:03<25:20,  1.25it/s] 

{'loss': 1.6944, 'grad_norm': 0.09495647996664047, 'learning_rate': 7.6e-05, 'epoch': 1.2}


 28%|██▊       | 700/2500 [09:23<24:11,  1.24it/s]
 28%|██▊       | 700/2500 [09:23<24:11,  1.24it/s] 

{'loss': 1.6977, 'grad_norm': 0.13558658957481384, 'learning_rate': 7.2e-05, 'epoch': 1.4}


 32%|███▏      | 800/2500 [10:43<22:48,  1.24it/s]
 32%|███▏      | 800/2500 [10:43<22:48,  1.24it/s] 

{'loss': 1.6881, 'grad_norm': 0.15118098258972168, 'learning_rate': 6.800000000000001e-05, 'epoch': 1.6}


 36%|███▌      | 900/2500 [12:04<22:58,  1.16it/s]
 36%|███▌      | 900/2500 [12:04<22:58,  1.16it/s] 

{'loss': 1.6924, 'grad_norm': 0.26179033517837524, 'learning_rate': 6.400000000000001e-05, 'epoch': 1.8}


 40%|████      | 1000/2500 [13:31<21:30,  1.16it/s]
 40%|████      | 1000/2500 [13:31<21:30,  1.16it/s]

{'loss': 1.6839, 'grad_norm': 0.33373138308525085, 'learning_rate': 6e-05, 'epoch': 2.0}


 44%|████▍     | 1100/2500 [14:56<19:13,  1.21it/s]
 44%|████▍     | 1100/2500 [14:56<19:13,  1.21it/s]

{'loss': 1.6747, 'grad_norm': 0.40722575783729553, 'learning_rate': 5.6000000000000006e-05, 'epoch': 2.2}


 48%|████▊     | 1200/2500 [16:16<17:20,  1.25it/s]
 48%|████▊     | 1200/2500 [16:16<17:20,  1.25it/s]

{'loss': 1.6691, 'grad_norm': 0.4235415756702423, 'learning_rate': 5.2000000000000004e-05, 'epoch': 2.4}


 52%|█████▏    | 1300/2500 [17:36<16:04,  1.24it/s]
 52%|█████▏    | 1300/2500 [17:36<16:04,  1.24it/s]

{'loss': 1.6569, 'grad_norm': 0.6730015873908997, 'learning_rate': 4.8e-05, 'epoch': 2.6}


 56%|█████▌    | 1400/2500 [18:56<14:39,  1.25it/s]
 56%|█████▌    | 1400/2500 [18:56<14:39,  1.25it/s]

{'loss': 1.6578, 'grad_norm': 0.5719950199127197, 'learning_rate': 4.4000000000000006e-05, 'epoch': 2.8}


 60%|██████    | 1500/2500 [20:16<13:20,  1.25it/s]
 60%|██████    | 1500/2500 [20:16<13:20,  1.25it/s]

{'loss': 1.6529, 'grad_norm': 0.6595889925956726, 'learning_rate': 4e-05, 'epoch': 3.0}


 64%|██████▍   | 1600/2500 [21:37<12:02,  1.25it/s]
 64%|██████▍   | 1600/2500 [21:37<12:02,  1.25it/s]

{'loss': 1.6537, 'grad_norm': 0.804837703704834, 'learning_rate': 3.6e-05, 'epoch': 3.2}


 68%|██████▊   | 1700/2500 [22:57<10:40,  1.25it/s]
 68%|██████▊   | 1700/2500 [22:57<10:40,  1.25it/s]

{'loss': 1.6528, 'grad_norm': 0.8689733743667603, 'learning_rate': 3.2000000000000005e-05, 'epoch': 3.4}


 72%|███████▏  | 1800/2500 [24:17<09:21,  1.25it/s]
 72%|███████▏  | 1800/2500 [24:17<09:21,  1.25it/s]

{'loss': 1.6499, 'grad_norm': 0.6514583230018616, 'learning_rate': 2.8000000000000003e-05, 'epoch': 3.6}


 76%|███████▌  | 1900/2500 [25:37<07:58,  1.25it/s]
 76%|███████▌  | 1900/2500 [25:37<07:58,  1.25it/s]

{'loss': 1.6479, 'grad_norm': 1.0249298810958862, 'learning_rate': 2.4e-05, 'epoch': 3.8}


 80%|████████  | 2000/2500 [26:58<07:14,  1.15it/s]
 80%|████████  | 2000/2500 [26:58<07:14,  1.15it/s]

{'loss': 1.6319, 'grad_norm': 0.5863732099533081, 'learning_rate': 2e-05, 'epoch': 4.0}


 84%|████████▍ | 2100/2500 [28:24<05:47,  1.15it/s]
 84%|████████▍ | 2100/2500 [28:24<05:47,  1.15it/s]

{'loss': 1.6487, 'grad_norm': 0.8496687412261963, 'learning_rate': 1.6000000000000003e-05, 'epoch': 4.2}


 88%|████████▊ | 2200/2500 [29:49<04:00,  1.25it/s]
 88%|████████▊ | 2200/2500 [29:49<04:00,  1.25it/s]

{'loss': 1.6391, 'grad_norm': 0.9528376460075378, 'learning_rate': 1.2e-05, 'epoch': 4.4}


 92%|█████████▏| 2300/2500 [31:09<02:39,  1.26it/s]
 92%|█████████▏| 2300/2500 [31:09<02:39,  1.26it/s]

{'loss': 1.6434, 'grad_norm': 0.6751308441162109, 'learning_rate': 8.000000000000001e-06, 'epoch': 4.6}


 96%|█████████▌| 2400/2500 [32:29<01:19,  1.25it/s]
 96%|█████████▌| 2400/2500 [32:29<01:19,  1.25it/s]

{'loss': 1.6386, 'grad_norm': 1.1440510749816895, 'learning_rate': 4.000000000000001e-06, 'epoch': 4.8}


100%|██████████| 2500/2500 [33:49<00:00,  1.25it/s]
100%|██████████| 2500/2500 [33:49<00:00,  1.25it/s]
100%|██████████| 2500/2500 [33:49<00:00,  1.23it/s]


{'loss': 1.6401, 'grad_norm': 0.8795555830001831, 'learning_rate': 0.0, 'epoch': 5.0}
{'train_runtime': 2029.4097, 'train_samples_per_second': 39.42, 'train_steps_per_second': 1.232, 'train_loss': 1.6719270263671875, 'epoch': 5.0}


100%|██████████| 250/250 [00:26<00:00,  9.50it/s]

+-------------------------+------------+
| Метрика                 |   Значение |
| eval_loss               |     1.6225 |
+-------------------------+------------+
| eval_accuracy           |     0.3755 |
+-------------------------+------------+
| eval_precision          |     0.1232 |
+-------------------------+------------+
| eval_recall             |     0.1901 |
+-------------------------+------------+
| eval_f1-score           |     0.1456 |
+-------------------------+------------+
| eval_support            |  2000.0000 |
+-------------------------+------------+
| eval_runtime            |    26.4330 |
+-------------------------+------------+
| eval_samples_per_second |    75.6630 |
+-------------------------+------------+
| eval_steps_per_second   |     9.4580 |
+-------------------------+------------+
| epoch                   |     5.0000 |
+-------------------------+------------+





### Lora
Выбор параметров LoRA обусловлен балансом между эффективностью, стабильностью обучения и вычислительными затратами:
- r=8 (ранг) — оптимален для захвата основных паттернов данных без избыточной параметризации (слишком низкий r теряет информацию, высокий — увеличивает риск переобучения).
- lora_alpha=16 — коэффициент масштабирования, согласованный с r (часто используют alpha = 2*r), чтобы сохранить соотношение влияния оригинальных и адаптивных весов.
- lora_dropout=0.1 — умеренная регуляризация для улучшения обобщающей способности.
- target_modules=["query", "value"] — слои, связанные с механизмом внимания, наиболее критичны для адаптации модели к задаче.
- bias="none" — исключение смещений уменьшает число параметров и упрощает обучение.

Параметры следуют рекомендациям оригинальной работы по LoRA и эмпирическим практикам для задач классификации (TaskType.SEQ_CLS), обеспечивая воспроизводимость и стабильность результатов.

In [None]:

from peft import (
    LoraConfig, 
    TaskType, 
    get_peft_model
)
from transformers import AutoModelForSequenceClassification

lora_config = LoraConfig(
        task_type=TaskType.SEQ_CLS,
        inference_mode=False,
        r=8,                # Ранг адаптеров
        lora_alpha=16,      # Коэффициент масштабирования
        lora_dropout=0.1,   # Дропаут для регуляризации
        target_modules=["query", "value"],  # Слои для применения LoRA
        bias="none"
    )


model = AutoModelForSequenceClassification.from_pretrained(
    "google-bert/bert-base-uncased",
    num_labels=len(labels)
)

lora_model = get_peft_model(model, lora_config)
lora_model.print_trainable_parameters()


training_args = TrainingArguments(
    output_dir=f"./lora_results",
    learning_rate=3e-4,
    per_device_train_batch_size=32,
    num_train_epochs=5,
    evaluation_strategy="epoch",
    logging_steps=200
)

trainer = Trainer(
    model=lora_model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["validation"],
    compute_metrics=compute_metrics
)

trainer.train()
results = trainer.evaluate(tokenized_dataset["test"])
print(tabulate.tabulate(
    results.items(),
    headers=["Метрика", "Значение"],
    tablefmt="grid",
    floatfmt=".4f"
))

# Training time: 36m

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at google-bert/bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


trainable params: 299,526 || all params: 109,786,380 || trainable%: 0.2728


  2%|▏         | 105/5000 [01:33<1:13:00,  1.12it/s]
  8%|▊         | 200/2500 [02:48<31:26,  1.22it/s]
  8%|▊         | 200/2500 [02:48<31:26,  1.22it/s]

{'loss': 1.349, 'grad_norm': 2.0876481533050537, 'learning_rate': 0.000276, 'epoch': 0.4}


 16%|█▌        | 400/2500 [05:33<28:45,  1.22it/s]
 16%|█▌        | 400/2500 [05:33<28:45,  1.22it/s]

{'loss': 0.8296, 'grad_norm': 3.4698805809020996, 'learning_rate': 0.00025199999999999995, 'epoch': 0.8}


 20%|██        | 500/2500 [06:55<27:22,  1.22it/s]
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A


{'eval_loss': 0.47649213671684265, 'eval_accuracy': 0.8365, 'eval_precision': 0.8267511170691882, 'eval_recall': 0.7541225013167661, 'eval_f1-score': 0.7820812942751005, 'eval_support': 2000.0, 'eval_runtime': 28.4486, 'eval_samples_per_second': 70.302, 'eval_steps_per_second': 8.788, 'epoch': 1.0}


 24%|██▍       | 600/2500 [08:46<25:54,  1.22it/s]  
 24%|██▍       | 600/2500 [08:46<25:54,  1.22it/s]

{'loss': 0.5413, 'grad_norm': 4.005258083343506, 'learning_rate': 0.00022799999999999999, 'epoch': 1.2}


 32%|███▏      | 800/2500 [11:31<23:23,  1.21it/s]
 32%|███▏      | 800/2500 [11:31<23:23,  1.21it/s]

{'loss': 0.3978, 'grad_norm': 3.885700225830078, 'learning_rate': 0.000204, 'epoch': 1.6}


 40%|████      | 1000/2500 [14:15<20:26,  1.22it/s]
 40%|████      | 1000/2500 [14:15<20:26,  1.22it/s]

{'loss': 0.3599, 'grad_norm': 4.011159896850586, 'learning_rate': 0.00017999999999999998, 'epoch': 2.0}



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
                   

{'eval_loss': 0.2431042194366455, 'eval_accuracy': 0.9215, 'eval_precision': 0.8936483063804851, 'eval_recall': 0.9011226544768344, 'eval_f1-score': 0.8970911641926804, 'eval_support': 2000.0, 'eval_runtime': 28.3684, 'eval_samples_per_second': 70.501, 'eval_steps_per_second': 8.813, 'epoch': 2.0}


 48%|████▊     | 1200/2500 [17:27<16:56,  1.28it/s]  
 48%|████▊     | 1200/2500 [17:27<16:56,  1.28it/s]

{'loss': 0.2753, 'grad_norm': 2.9487192630767822, 'learning_rate': 0.000156, 'epoch': 2.4}


 56%|█████▌    | 1400/2500 [20:04<14:21,  1.28it/s]
 56%|█████▌    | 1400/2500 [20:04<14:21,  1.28it/s]

{'loss': 0.2438, 'grad_norm': 3.2253737449645996, 'learning_rate': 0.00013199999999999998, 'epoch': 2.8}


 60%|██████    | 1500/2500 [21:22<13:02,  1.28it/s]
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A


{'eval_loss': 0.19373109936714172, 'eval_accuracy': 0.9285, 'eval_precision': 0.8998014687986707, 'eval_recall': 0.9034757221810993, 'eval_f1-score': 0.9011861731340605, 'eval_support': 2000.0, 'eval_runtime': 27.0616, 'eval_samples_per_second': 73.906, 'eval_steps_per_second': 9.238, 'epoch': 3.0}


 64%|██████▍   | 1600/2500 [23:09<11:45,  1.28it/s]  
 64%|██████▍   | 1600/2500 [23:09<11:45,  1.28it/s]

{'loss': 0.2356, 'grad_norm': 2.7473952770233154, 'learning_rate': 0.00010799999999999998, 'epoch': 3.2}


 72%|███████▏  | 1800/2500 [25:45<09:08,  1.28it/s]
 72%|███████▏  | 1800/2500 [25:45<09:08,  1.28it/s]

{'loss': 0.2073, 'grad_norm': 1.419308066368103, 'learning_rate': 8.4e-05, 'epoch': 3.6}


 80%|████████  | 2000/2500 [28:22<06:31,  1.28it/s]
 80%|████████  | 2000/2500 [28:22<06:31,  1.28it/s]

{'loss': 0.1996, 'grad_norm': 5.523075580596924, 'learning_rate': 5.9999999999999995e-05, 'epoch': 4.0}



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
                                                   

 80%|████████  | 2000/2500 [28:49<06:31,  1.28it/s]
[A
[A

{'eval_loss': 0.1785881668329239, 'eval_accuracy': 0.936, 'eval_precision': 0.9144537697201054, 'eval_recall': 0.9291076530980072, 'eval_f1-score': 0.920991715768876, 'eval_support': 2000.0, 'eval_runtime': 25.8254, 'eval_samples_per_second': 77.443, 'eval_steps_per_second': 9.68, 'epoch': 4.0}


 88%|████████▊ | 2200/2500 [31:25<03:54,  1.28it/s]  
 88%|████████▊ | 2200/2500 [31:25<03:54,  1.28it/s]

{'loss': 0.18, 'grad_norm': 4.138106822967529, 'learning_rate': 3.5999999999999994e-05, 'epoch': 4.4}


 96%|█████████▌| 2400/2500 [34:02<01:18,  1.28it/s]
 96%|█████████▌| 2400/2500 [34:02<01:18,  1.28it/s]

{'loss': 0.1844, 'grad_norm': 2.3966174125671387, 'learning_rate': 1.1999999999999999e-05, 'epoch': 4.8}


100%|██████████| 2500/2500 [35:20<00:00,  1.28it/s]
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A


{'eval_loss': 0.1654692143201828, 'eval_accuracy': 0.938, 'eval_precision': 0.9193392246932405, 'eval_recall': 0.9178323360827396, 'eval_f1-score': 0.9184753633939998, 'eval_support': 2000.0, 'eval_runtime': 27.0808, 'eval_samples_per_second': 73.853, 'eval_steps_per_second': 9.232, 'epoch': 5.0}
{'train_runtime': 2148.5027, 'train_samples_per_second': 37.235, 'train_steps_per_second': 1.164, 'train_loss': 0.40780635681152344, 'epoch': 5.0}


100%|██████████| 250/250 [00:26<00:00,  9.35it/s]

+-------------------------+------------+
| Метрика                 |   Значение |
| eval_loss               |     0.1978 |
+-------------------------+------------+
| eval_accuracy           |     0.9220 |
+-------------------------+------------+
| eval_precision          |     0.8764 |
+-------------------------+------------+
| eval_recall             |     0.8873 |
+-------------------------+------------+
| eval_f1-score           |     0.8813 |
+-------------------------+------------+
| eval_support            |  2000.0000 |
+-------------------------+------------+
| eval_runtime            |    26.8544 |
+-------------------------+------------+
| eval_samples_per_second |    74.4760 |
+-------------------------+------------+
| eval_steps_per_second   |     9.3090 |
+-------------------------+------------+
| epoch                   |     5.0000 |
+-------------------------+------------+





### Вывод


In [16]:
from tabulate import tabulate

data = [
    ["Full Finetuning", 0.1777, 0.9275, 0.8902, 0.8693, 0.8788, 2000, 22.7870, 87.769, 5.486, 3],
    ["Linear Probing", 1.2296, 0.5315, 0.4409, 0.3190, 0.3073, 2000, 23.9604, 83.471, 5.217, 10],
    ["Prompt Tuning", 1.6225, 0.3755, 0.1232, 0.1901, 0.1456, 2000, 26.4330, 75.663, 9.458, 5],
    ["LoRA", 0.1978, 0.9220, 0.8764, 0.8873, 0.8813, 2000, 26.8544, 74.476, 9.309, 5]
]

headers = [
    "Метод", "eval_loss", "eval_accuracy", "eval_precision", 
    "eval_recall", "eval_f1", "eval_support", "eval_runtime", 
    "samples/sec", "steps/sec", "epoch"
]

print(tabulate(data, headers=headers, tablefmt="grid", floatfmt=".4f"))

+-----------------+-------------+-----------------+------------------+---------------+-----------+----------------+----------------+---------------+-------------+---------+
| Метод           |   eval_loss |   eval_accuracy |   eval_precision |   eval_recall |   eval_f1 |   eval_support |   eval_runtime |   samples/sec |   steps/sec |   epoch |
| Full Finetuning |      0.1777 |          0.9275 |           0.8902 |        0.8693 |    0.8788 |           2000 |        22.7870 |       87.7690 |      5.4860 |       3 |
+-----------------+-------------+-----------------+------------------+---------------+-----------+----------------+----------------+---------------+-------------+---------+
| Linear Probing  |      1.2296 |          0.5315 |           0.4409 |        0.3190 |    0.3073 |           2000 |        23.9604 |       83.4710 |      5.2170 |      10 |
+-----------------+-------------+-----------------+------------------+---------------+-----------+----------------+----------------+---

Full Finetuning и LoRA демонстрируют высокие показатели (F1-score ~0.88), что говорит об их эффективности. LoRA почти не уступает Full Finetuning, что делает его оптимальным выбором для задач, где важно сохранить вычислительную эффективность.      
Prompt Tuning (F1-score 0.1456) и Linear Probing (F1-score 0.3073) значительно уступают в качестве. Эти методы требуют пересмотра (например, изменения архитектуры или гиперпараметров) или отказа от их использования в текущей постановке задачи.