## BERT Fine-Tuning for Russian News Classification

This notebook fine-tunes the `DeepPavlov/rubert-base-cased` model on a preprocessed Russian news corpus.  
  
1. Tokenization
2. DataLoader Initialization  
3. Model Initialization, device and precision setup
4. Weighted cross entropy loss function
5. Training and Evaluation Loops  
6. Training Metrics Visualization  
7. Final evaluation on hold-out test set 

In [None]:
%%capture
!pip install datasets
!pip install 'numpy < 2.0'

In [1]:
import pandas as pd
import numpy as np
from transformers import AutoTokenizer, AutoModelForSequenceClassification, get_cosine_schedule_with_warmup
import multiprocessing
import torch
from torch import nn
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torch.optim.lr_scheduler import StepLR
from torch.optim import AdamW
# import torch_xla.core.xla_model as xm
from datasets import Dataset, Features, ClassLabel, Value, load_from_disk
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

import warnings

warnings.filterwarnings('ignore')

In [None]:
from google.colab import drive

drive.mount('/content/drive')

Mounted at /content/drive


#### Load cleaned data

In [6]:
WORKING_DIR = '/content/drive/MyDrive/news_classifier'

fields = ['text', 'topic']
train_df = pd.read_csv(WORKING_DIR + 'data/train.csv',
                       dtype={'topic': object,
                              'text_clean': object},
                       usecols=fields)
val_df = pd.read_csv(WORKING_DIR + 'data/val.csv',
                     dtype={'topic': object,
                            'text_clean': object},
                     usecols=fields)
test_df = pd.read_csv(WORKING_DIR + 'data/test.csv',
                      dtype={'topic': object,
                             'text_clean': object},
                      usecols=fields)

In [7]:
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

model_name = 'DeepPavlov/rubert-base-cased'
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [8]:
unique_labels = sorted(train_df['topic'].unique())

features = Features({
    'text': Value('string'),
    'topic': ClassLabel(names=unique_labels)
})

## 1. Tokenization

In [None]:
def tokenize_batch(batch):
    toks = tokenizer(
        batch['text'],
        truncation=True,
        padding='max_length',
        max_length=512
    )
    toks['labels'] = batch['topic']
    return toks


train_ds = Dataset.from_pandas(train_df, features=features)
val_ds = Dataset.from_pandas(val_df, features=features)
test_ds = Dataset.from_pandas(test_df, features=features)

train_ds = train_ds.map(tokenize_batch,
                        batched=True,
                        remove_columns=train_df.columns.tolist())
val_ds = val_ds.map(tokenize_batch,
                    batched=True,
                    remove_columns=val_df.columns.tolist())
test_ds = test_ds.map(tokenize_batch,
                      batched=True,
                      remove_columns=test_df.columns.tolist())

train_ds.set_format('torch')
val_ds.set_format('torch')
test_ds.set_format('torch')

Map:   0%|          | 0/715605 [00:00<?, ? examples/s]

Map:   0%|          | 0/39756 [00:00<?, ? examples/s]

Map:   0%|          | 0/39856 [00:00<?, ? examples/s]

In [None]:
train_ds.save_to_disk(WORKING_DIR + 'data/train_ds.hf')
val_ds.save_to_disk(WORKING_DIR + 'data/val_ds.hf')
test_ds.save_to_disk(WORKING_DIR + 'data/test_ds.hf')

Saving the dataset (0/5 shards):   0%|          | 0/715605 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/39756 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/39856 [00:00<?, ? examples/s]

In [20]:
train_ds = load_from_disk(WORKING_DIR + 'data/train_ds.hf')
val_ds = load_from_disk(WORKING_DIR + 'data/val_ds.hf')
test_ds = load_from_disk(WORKING_DIR + 'data/test_ds.hf')

## 2. DataLoaders

In [21]:
num_cores = multiprocessing.cpu_count()
BATCH_SIZE = 128
NUM_EPOCHS = 4

train_loader = DataLoader(
    train_ds,
    sampler=RandomSampler(train_ds),
    batch_size=BATCH_SIZE,
    num_workers=0 # num_cores
)
val_loader = DataLoader(
    val_ds,
    sampler=SequentialSampler(val_ds),
    batch_size=BATCH_SIZE * 2,
    num_workers=0 # num_cores
)
test_loader = DataLoader(
    test_ds,
    sampler=SequentialSampler(test_ds),
    batch_size=BATCH_SIZE * 2,
    num_workers=0 # num_cores
)

## 3. Model Initialization, device and precision setup

In [12]:
total_classes = train_df['topic'].nunique()

model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    num_labels=total_classes,
    problem_type="single_label_classification"
)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at DeepPavlov/rubert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [13]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

use_fp16 = device.type == 'cuda'
if use_fp16:
    scaler = torch.cuda.amp.GradScaler()

## 4. Define weighted cross entropy loss function

In [None]:
total_samples = train_df.shape[0]

class_weights = torch.tensor([
    total_samples / (total_classes * count)
    for count in train_df['topic'].value_counts().sort_index()
], device=device)

loss_fn = nn.CrossEntropyLoss(weight=class_weights)

#### Define optimizer and scheduler

In [None]:
LEARNING_RATE = 1e-5

total_steps  = (train_df.shape[0] // BATCH_SIZE) * NUM_EPOCHS
warmup_steps = int(0.1 * total_steps)

optimizer = AdamW(
    model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=0.01
)

scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps,
    num_cycles=1.5
)

#### Checkpoint Management

In [None]:
import os
import shutil
import re

def _prune_checkpoints(save_path, max_checkpoints):
    all_items = os.listdir(save_path)
    pattern = re.compile(r"^checkpoint-(\d+)$")
    checkpoints = []
    for item in all_items:
        match = pattern.match(item)
        if match:
            step_num = int(match.group(1))
            checkpoints.append((step_num, item))
    if len(checkpoints) <= max_checkpoints:
        return
    checkpoints = sorted(checkpoints, key=lambda x: x[0])
    num_to_delete = len(checkpoints) - max_checkpoints
    for i in range(num_to_delete):
        old_ckpt_name = checkpoints[i][1]
        old_ckpt_path = os.path.join(save_path, old_ckpt_name)
        shutil.rmtree(old_ckpt_path)

## 5. Training and Evaluation Loops  

In [None]:
MAX_GRAD_NORM = 1.0
LOG_STEP = 1000

def train_epoch(model, data_loader, optimizer, scheduler, loss_fn, device,
                save_path, checkpoint_interval, max_checkpoints, step_start,
                train_batch_losses, train_batch_accs, scaler=None):
    model.train()
    running_loss = 0.0
    running_correct = 0
    running_total = 0

    epoch_loss = 0.0

    loop = tqdm(data_loader, desc="Train", leave=False)
    for step, batch in enumerate(loop, start=step_start):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        optimizer.zero_grad()
        with torch.cuda.amp.autocast(enabled=(scaler is not None)):
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            logits = outputs.logits
            loss = loss_fn(logits, labels)

        epoch_loss += loss.item()

        if scaler is not None:
            scaler.scale(loss).backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)
            optimizer.step()

        scheduler.step()

        batch_loss = loss.item()
        train_batch_losses.append(batch_loss)

        preds = torch.argmax(logits, dim=1)
        batch_correct = (preds == labels).sum().item()
        batch_total = labels.size(0)
        running_correct += batch_correct
        running_total += batch_total

        if step % checkpoint_interval == 0:
            ckpt_dir = os.path.join(save_path, f"checkpoint-{step}")
            os.makedirs(ckpt_dir, exist_ok=True)
            torch.save(model.state_dict(),   os.path.join(ckpt_dir, "model.pt"))
            torch.save(optimizer.state_dict(), os.path.join(ckpt_dir, "optimizer.pt"))
            torch.save(scheduler.state_dict(), os.path.join(ckpt_dir, "scheduler.pt"))
            _prune_checkpoints(save_path, max_checkpoints)

        if step % LOG_STEP == 0:
            avg_acc = 100. * running_correct / running_total
            train_batch_accs.append(avg_acc)
            running_correct = 0
            running_total = 0

        loop.set_postfix(batch_loss=batch_loss)


def eval_epoch(model, data_loader, loss_fn, device):
    model.eval()
    epoch_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        loop = tqdm(data_loader, desc="Validation", leave=False)
        for batch in loop:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            logits = outputs.logits

            loss = loss_fn(logits, labels)
            epoch_loss += loss.item()

            preds = torch.argmax(logits, dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

            loop.set_postfix(val_loss=loss.item(), val_acc=100. * correct / total)

    avg_loss = epoch_loss / len(data_loader)
    avg_acc = 100. * correct / total
    return avg_loss, avg_acc

#### Fine-Tuning

In [None]:
best_val_acc = 0.0

train_batch_losses = []
train_batch_accs   = []
val_epoch_losses   = []
val_epoch_accs     = []

checkpoint_interval = 1000
max_checkpoints = 5
global_step = 0
checkpoint_path = '/content/checkpoints'

for epoch in range(1, NUM_EPOCHS + 1):
    print(f"\nEpoch {epoch}/{NUM_EPOCHS}")

    train_epoch(
        model, train_loader, optimizer, scheduler, loss_fn, device,
        checkpoint_path, checkpoint_interval, max_checkpoints, global_step,
        train_batch_losses, train_batch_accs, scaler=scaler if use_fp16 else None
    )

    val_loss, val_acc = eval_epoch(
        model, val_loader, loss_fn, device
    )
    print(f"Val   loss: {val_loss:.4f} | Val   acc: {val_acc:.2f}%")

    val_epoch_losses.append(val_loss)
    val_epoch_accs.append(val_acc)

    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), WORKING_DIR + "models/best_bert.pt")
        print(f"Best model saved (val_acc = {best_val_acc:.2f}%)")

In [None]:
plt.plot(range(1, len(train_batch_losses) + 1), train_batch_losses, label='Train Batch Loss')
plt.xlabel('Batch number')
plt.ylabel('Loss')
plt.title('Train Loss per Batch')
plt.legend()
plt.grid()
plt.show()

x_acc = [i * LOG_STEP for i in range(1, len(train_batch_accs) + 1)]
plt.plot(x_acc, train_batch_accs, label=f'Train Accuracy (every {LOG_STEP} batches)')
plt.xlabel('Batch number')
plt.ylabel('Accuracy (%)')
plt.title('Train Accuracy over Batches')
plt.legend()
plt.grid()
plt.show()

plt.plot(range(1, len(val_epoch_losses) + 1), val_epoch_losses, label='Val Loss (per Epoch)')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Validation Loss per Epoch')
plt.legend()
plt.grid()
plt.show()

plt.plot(range(1, len(val_epoch_accs) + 1), val_epoch_accs, label='Val Accuracy (per Epoch)')
plt.xlabel('Epoch')
plt.ylabel('Accuracy (%)')
plt.title('Validation Accuracy per Epoch')
plt.legend()
plt.grid()
plt.show()

#### Test

In [14]:
model.load_state_dict(torch.load(WORKING_DIR + "models/best_bert.pt"))
model.to(device)
model.eval()

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(119547, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1

In [15]:
def predict_one(text: str):
    inputs = tokenizer(
        text,
        truncation=True,
        padding='max_length',
        max_length=512,
        return_tensors='pt'
    )
    return inputs

In [16]:
def classify(text: str):
    inputs = predict_one(text).to(device)
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits
        probs = nn.functional.softmax(logits, dim=-1)
        pred_idx = torch.argmax(probs, dim=-1).item()
        confidence = probs[0, pred_idx].item()
    return pred_idx, confidence, probs.squeeze().tolist()

In [17]:
texts = [
    "Экономические итоги первого квартала перевыполнили прогнозы.",
    "Новый фильм режиссёра выйдет в прокат этим летом."
]
[unique_labels[classify(text)[0]] for text in texts]

['Экономика', 'Культура']

## 7. Final evaluation on hold-out test set 

In [22]:
from torch.utils.data import DataLoader

val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE * 2)
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE * 2)

In [25]:
from tqdm.auto import tqdm

model.eval()
all_preds = []
all_labels = []

with torch.no_grad():
    for batch in tqdm(test_loader):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits
        preds = torch.argmax(logits, dim=-1)

        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())


  0%|          | 0/2485 [00:00<?, ?it/s]

In [26]:
from sklearn.metrics import accuracy_score, classification_report

print(
    classification_report(
        all_labels,
        all_preds,
        target_names=unique_labels
    )
)

                   precision    recall  f1-score   support

      Бывший СССР       0.90      0.91      0.91      2905
              Дом       0.54      0.88      0.67       363
         Из жизни       0.69      0.87      0.77      2857
   Интернет и СМИ       0.71      0.83      0.77      2534
         Культура       0.90      0.82      0.86      2267
              Мир       0.93      0.76      0.84      7177
  Наука и техника       0.78      0.81      0.79      3455
      Путешествия       0.74      0.86      0.80      1278
           Россия       0.89      0.67      0.76      7021
Силовые структуры       0.44      0.86      0.59      1531
            Спорт       0.98      0.98      0.98      3207
         Ценности       0.83      0.76      0.79      1153
        Экономика       0.83      0.86      0.85      4008

         accuracy                           0.80     39756
        macro avg       0.78      0.84      0.80     39756
     weighted avg       0.84      0.81      0.82     3