# Baseline: Обучение Multi-Branch MLP на размеченных данных


In [3]:
!pip install pytorch_lightning

import os
import sys
import warnings
warnings.filterwarnings('ignore')

import torch
import torch.nn as nn
import numpy as np
import pandas as pd
from sklearn.utils.class_weight import compute_class_weight

from model import MultiBranchMLP
from data_module import DataModule
from lightning_module import BaseLightningModule

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

def set_seed(seed=42):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    import random
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(42)




## 1. Загрузка данных


In [4]:
data_dir = '/content/data'

dm = DataModule(
    data_dir=data_dir,
    batch_size=128,
    num_workers=4
)

dm.setup()

print(f'Input dimension: {dm.input_dim}')
print(f'Number of classes: {dm.n_classes}')
print(f'Labeled train samples: {len(dm.train_labeled_dataset)}')
print(f'Test samples: {len(dm.test_dataset)}')


Input dimension: 3072
Number of classes: 10
Labeled train samples: 1440
Validation samples: 160
Test samples: 4000
Input dimension: 3072
Number of classes: 10
Labeled train samples: 1440
Test samples: 4000


## 2. Анализ дисбаланса классов и вычисление весов


In [11]:
train_labels = dm.train_labeled_dataset.dataset.y

unique_labels = np.unique(train_labels)
class_weights = compute_class_weight(
    'balanced',
    classes=unique_labels,
    y=train_labels
)

print(f'Class weights: {dict(zip(unique_labels, class_weights))}')

class_weights_tensor = torch.FloatTensor(class_weights)


Class weights: {np.int64(0): np.float64(1.0738255033557047), np.int64(1): np.float64(0.963855421686747), np.int64(2): np.float64(1.0256410256410255), np.int64(3): np.float64(1.103448275862069), np.int64(4): np.float64(0.9523809523809523), np.int64(5): np.float64(0.935672514619883), np.int64(6): np.float64(0.9523809523809523), np.int64(7): np.float64(1.0596026490066226), np.int64(8): np.float64(0.9248554913294798), np.int64(9): np.float64(1.0457516339869282)}


## 3. Создание модели


In [12]:
model = MultiBranchMLP(
    input_dim=dm.input_dim,
    hidden_dim=256,
    output_dim=dm.n_classes,
    num_blocks=4,
    dropout=0.1,
    combine_mode='concat'
)

print(f'Model parameters: {sum(p.numel() for p in model.parameters()):,}')


Model parameters: 4,080,650


## 4. Создание Lightning модуля


In [13]:
loss_fn = nn.CrossEntropyLoss(weight=class_weights_tensor)

lightning_model = BaseLightningModule(
    model=model,
    loss_fn=loss_fn,
    optimizer_type='adamw',
    learning_rate=1e-3,
    task_type='multiclass'
)


## 5. Обучение модели


In [14]:
checkpoint_callback = ModelCheckpoint(
    dirpath='checkpoints',
    filename='best_model-{epoch:02d}-{val_accuracy:.4f}',
    monitor='val_accuracy',
    mode='max',
    save_top_k=1,
    save_last=True
)

trainer = Trainer(
    max_epochs=100,
    callbacks=[checkpoint_callback],
    enable_checkpointing=True,
    logger=True,
    enable_progress_bar=True,
    enable_model_summary=True,
    accelerator='auto',
    devices='auto'
)

trainer.fit(lightning_model, dm)


INFO:pytorch_lightning.utilities.rank_zero:GPU available: False, used: False
INFO:pytorch_lightning.utilities.rank_zero:TPU available: True, using: 1 TPU cores


Input dimension: 3072
Number of classes: 10
Labeled train samples: 1440
Validation samples: 160
Test samples: 4000


Output()

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=100` reached.


## 6. Генерация псевдо-меток

### 6.1 Загрузка лучшей baseline модели

In [21]:
# Загружаем лучшую модель
best_model_path = checkpoint_callback.best_model_path
print(f'Loading best model from: {best_model_path}')

if best_model_path:
    base_model = BaseLightningModule.load_from_checkpoint(
        best_model_path,
        model=lightning_model.model,
        loss_fn=lightning_model.loss_fn)
else:
    base_model = lightning_model

base_model.eval()
device = next(base_model.parameters()).device
base_model.to(device)

# Загружаем неразмеченные данные
unlabeled_loader = dm.unlabeled_dataloader()

pseudo_labels = []
original_features = []

with torch.no_grad():
    for batch in unlabeled_loader:
        features = batch[0].to(device)
        if features.dim() == 1:
          features = features.unsqueeze(0)
        logits = base_model(features)
        probs = torch.softmax(logits, dim=1)
        max_probs, preds = torch.max(probs, dim=1)

        # Отбираем уверенные предсказания
        confidence_threshold = 0.9
        confident_mask = max_probs > confidence_threshold

        confident_preds = preds[confident_mask]
        confident_features = features[confident_mask]

        if confident_preds.numel() > 0:
            pseudo_labels.append(confident_preds.cpu().numpy())
            original_features.append(confident_features.cpu().numpy())

pseudo_labels = np.concatenate(pseudo_labels)
original_features = np.concatenate(original_features)

print(f'Найдено {len(pseudo_labels)} псевдо-меток с уверенностью > {confidence_threshold}')

Loading best model from: /content/checkpoints/best_model-epoch=99-val_accuracy=0.2877.ckpt
Найдено 86 псевдо-меток с уверенностью > 0.9


## 6.2 Создание нового датасета

In [24]:
# Исходные размеченные данные
X_train_labeled = dm.train_labeled_dataset.dataset.X
y_train_labeled = dm.train_labeled_dataset.dataset.y

# Объединяем с псевдо-метками
X_combined = np.concatenate([X_train_labeled, original_features])
y_combined = np.concatenate([y_train_labeled, pseudo_labels])

print(f'Размер нового обучающего датасета: {X_combined.shape}')

Размер нового обучающего датасета: (1686, 3072)


## 7. Переобучение на расширенных данных

In [30]:
# Создаем новый DataModule с расширенными данными
dm_pseudo = DataModule(
    data_dir=data_dir, # Используем тот же data_dir для test
    batch_size=128,
    num_workers=4
)
# Передаем новые данные в setup
dm_pseudo.setup_from_data(X_combined, y_combined, dm.test_dataset.X, dm.test_dataset.y)

# Вычисляем новые веса классов
new_class_weights = compute_class_weight(
    'balanced',
    classes=np.unique(y_combined),
    y=y_combined
)
new_class_weights_tensor = torch.FloatTensor(new_class_weights)

# Создаем новую модель для переобучения
new_model = MultiBranchMLP(
    input_dim=dm.input_dim,
    hidden_dim=256,
    output_dim=dm.n_classes,
    num_blocks=4,
    dropout=0.1,
    combine_mode='concat'
)

new_loss_fn = nn.CrossEntropyLoss(weight=new_class_weights_tensor)

lightning_model_pseudo = BaseLightningModule(
    model=new_model,
    loss_fn=new_loss_fn,
    optimizer_type='adamw',
    learning_rate=1e-3,
    task_type='multiclass'
)

# Новый чекпоинт
checkpoint_callback_pseudo = ModelCheckpoint(
    dirpath='checkpoints_pseudo',
    filename='best_model_pseudo-{epoch:02d}-{val_accuracy:.4f}',
    monitor='val_accuracy',
    mode='max',
    save_top_k=1,
    save_last=True
)

# Новый тренер
trainer_pseudo = Trainer(
    max_epochs=100,
    callbacks=[checkpoint_callback_pseudo],
    enable_checkpointing=True,
    logger=True,
    enable_progress_bar=True,
    enable_model_summary=True,
    accelerator='auto',
    devices='auto'
)

trainer_pseudo.fit(lightning_model_pseudo, dm_pseudo)

INFO:pytorch_lightning.utilities.rank_zero:GPU available: False, used: False
INFO:pytorch_lightning.utilities.rank_zero:TPU available: True, using: 1 TPU cores


Input dimension: 3072
Number of classes: 10
Labeled train samples: 1440
Validation samples: 160
Test samples: 4000


Output()

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=100` reached.


## 8. Оценка на тестовой выборке


In [31]:
best_model_path_pseudo = checkpoint_callback_pseudo.best_model_path
print(f'Loading best pseudo-label model from: {best_model_path_pseudo}')

if best_model_path_pseudo:
    final_model = BaseLightningModule.load_from_checkpoint(
        best_model_path_pseudo,
        model=model,
        loss_fn=loss_fn,
        optimizer_type='adamw',
        learning_rate=1e-3,
        task_type='multiclass'
    )
else:
    final_model = lightning_model_pseudo

final_test_results = trainer_pseudo.test(final_model, dm_pseudo)

print('\n=== Финальные результаты на тестовой выборке (после псевдо-лейблинга) ===')
for key, value in final_test_results[0].items():
    print(f'{key}: {value:.4f}')


Loading best pseudo-label model from: /content/checkpoints_pseudo/best_model_pseudo-epoch=99-val_accuracy=0.3218.ckpt
Input dimension: 3072
Number of classes: 10
Labeled train samples: 1440
Validation samples: 160
Test samples: 4000


Output()


=== Финальные результаты на тестовой выборке (после псевдо-лейблинга) ===
test_loss: 16.3741
test_accuracy: 0.3343
test_f1_macro: 0.3343
