# Baseline: Обучение Multi-Branch MLP на размеченных данных


In [1]:
import os
import sys
import warnings
warnings.filterwarnings('ignore')

import torch
import torch.nn as nn
import numpy as np
import pandas as pd
from sklearn.utils.class_weight import compute_class_weight

from model import MultiBranchMLP
from data_module import DataModule
from lightning_module import BaseLightningModule

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

def set_seed(seed=42):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU generator and all CUDA devices), then puts cuDNN into deterministic
    mode with benchmark autotuning switched off.

    Args:
        seed: Integer seed shared by all generators. Defaults to 42.
    """
    import random

    # The seeding calls are independent of each other, so they can be applied in one pass.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)

    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


# Fix the seed once, before any data loading or model construction.
set_seed(42)


## 1. Загрузка данных


In [2]:
data_dir = '../data'

# --- Input normalization (minimal change to the baseline) ---
import data_module as dm_module

# Per-feature mean/std are estimated from the labeled training CSV after the
# raw values have been scaled by 1/255.
train_path = os.path.join(data_dir, 'train_labeled.csv')
train_df = pd.read_csv(train_path)

X_train = train_df.drop(columns='target').to_numpy().astype(np.float32) / 255.0
MEAN_T = torch.as_tensor(X_train.mean(axis=0), dtype=torch.float32)
# The small epsilon keeps zero-variance features from causing a division by zero.
STD_T = torch.as_tensor(X_train.std(axis=0) + 1e-6, dtype=torch.float32)

# Patch CSVDataset so that DataModule hands out normalized tensors automatically.
# A handle to the current class is kept here; it is subclassed right below.
_BaseCSVDataset = dm_module.CSVDataset

class CSVDatasetNormalized(_BaseCSVDataset):
    """CSVDataset variant whose items are scaled to [0, 1]-style units and standardized.

    Relies on the attributes the base class is expected to provide
    (``self.X``, ``self.y``, ``self.has_target``) and on the notebook-level
    ``MEAN_T`` / ``STD_T`` tensors.
    """

    def __getitem__(self, idx):
        """Return the normalized sample at ``idx``.

        Returns:
            ``(x, y)`` with ``y`` as a Python int when the dataset has targets,
            otherwise just ``x``.
        """
        scaled = torch.from_numpy(self.X[idx]).float() / 255.0
        features = (scaled - MEAN_T) / STD_T
        if not self.has_target:
            return features
        return features, int(self.y[idx])

# Rebind the module-level name BEFORE DataModule is constructed. This is intended
# to make DataModule build its datasets from the normalized subclass; it assumes
# DataModule looks up CSVDataset in the data_module namespace at setup time
# (TODO confirm against data_module.py).
dm_module.CSVDataset = CSVDatasetNormalized
# --- end of normalization block ---

dm = DataModule(
    data_dir=data_dir,
    batch_size=128,
    num_workers=0
)

# Called explicitly so that the attributes printed below are available before training.
dm.setup()

print(f'Input dimension: {dm.input_dim}')
print(f'Number of classes: {dm.n_classes}')
print(f'Labeled train samples: {len(dm.train_labeled_dataset)}')
print(f'Test samples: {len(dm.test_dataset)}')


Input dimension: 3072
Number of classes: 10
Labeled train samples: 1600
Test samples: 4000
Input dimension: 3072
Number of classes: 10
Labeled train samples: 1600
Test samples: 4000


## 2. Анализ дисбаланса классов и вычисление весов


In [3]:
# Labels of the labeled training split, read straight from the dataset object.
train_labels = dm.train_labeled_dataset.y

# 'balanced' weighting gives rarer classes proportionally larger weights.
# The weights are ordered like unique_labels (sorted label values).
unique_labels = np.unique(train_labels)
class_weights = compute_class_weight(
    class_weight='balanced',
    classes=unique_labels,
    y=train_labels,
)

print(f'Class weights: {dict(zip(unique_labels, class_weights))}')

# float32 copy of the weights, later passed to the loss function.
class_weights_tensor = torch.tensor(class_weights, dtype=torch.float32)


Class weights: {0: 1.0738255033557047, 1: 0.963855421686747, 2: 1.0256410256410255, 3: 1.103448275862069, 4: 0.9523809523809523, 5: 0.935672514619883, 6: 0.9523809523809523, 7: 1.0596026490066226, 8: 0.9248554913294798, 9: 1.0457516339869282}


## 3. Создание модели


In [4]:
# Build the project's MultiBranchMLP. Input and output sizes are taken from the
# DataModule so the network always matches the loaded data; the remaining
# hyperparameters are fixed baseline choices.
model = MultiBranchMLP(
    input_dim=dm.input_dim,
    hidden_dim=256,
    output_dim=dm.n_classes,
    num_blocks=4,
    dropout=0.1,
    combine_mode='concat'
)

# Total number of parameter elements across all tensors returned by model.parameters().
print(f'Model parameters: {sum(p.numel() for p in model.parameters()):,}')


Model parameters: 4,080,650


## 4. Создание Lightning модуля


In [5]:
# Cross-entropy weighted per class with the weights computed from the training labels.
loss_fn = nn.CrossEntropyLoss(weight=class_weights_tensor)

# Wrap the network, the loss and the optimizer settings in the project's Lightning module.
# The same settings are repeated when the best checkpoint is reloaded for testing,
# so keep the two places in sync.
lightning_model = BaseLightningModule(
    model=model,
    loss_fn=loss_fn,
    optimizer_type='adamw',
    learning_rate=1e-3,
    task_type='multiclass'
)


## 5. Обучение модели


In [6]:
# Keep only the single best checkpoint ranked by 'val_accuracy' (higher is better),
# plus the checkpoint of the last epoch. Files are written to ./checkpoints.
checkpoint_callback = ModelCheckpoint(
    dirpath='checkpoints',
    filename='best_model-{epoch:02d}-{val_accuracy:.4f}',
    monitor='val_accuracy',
    mode='max',
    save_top_k=1,
    save_last=True
)

# Train for a fixed budget of 100 epochs; the accelerator and device count are auto-selected.
# NOTE(review): the checkpoint callback is the only callback, so there is no early
# stopping and all 100 epochs always run.
trainer = Trainer(
    max_epochs=100,
    callbacks=[checkpoint_callback],
    enable_checkpointing=True,
    logger=True,
    enable_progress_bar=True,
    enable_model_summary=True,
    accelerator='auto',
    devices='auto'
)

# Fit the Lightning module on the data provided by the DataModule.
trainer.fit(lightning_model, dm)


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
You are using a CUDA device ('NVIDIA GeForce RTX 5070 Ti') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type             | Params | Mode  | FLOPs
-------------------------------------------------------------
0 | model   | MultiBranchMLP   | 4.1 M  | train | 0    
1 | loss_fn | CrossEntropyLoss | 0      | train | 0    
2 | metrics | ModuleDict       | 0      | train | 0    
-------------------------------------------------------------
4.1 M     Trainable params
0         Non-trainable params
4.1 M     Total params
16.323    Total estimated model params size (MB)
70        Modules in train mode


Input dimension: 3072
Number of classes: 10
Labeled train samples: 1600
Test samples: 4000


Sanity Checking: |                                                                               | 0/? [00:00<…

Epoch 0: accuracy=0.0859, f1_macro=0.0837


Training: |                                                                                      | 0/? [00:00<…

Validation: |                                                                                    | 0/? [00:00<…

Epoch 0: accuracy=0.2347, f1_macro=0.2273


Validation: |                                                                                    | 0/? [00:00<…

Epoch 1: accuracy=0.2689, f1_macro=0.2550


Validation: |                                                                                    | 0/? [00:00<…

Epoch 2: accuracy=0.2838, f1_macro=0.2748


Validation: |                                                                                    | 0/? [00:00<…

Epoch 3: accuracy=0.2987, f1_macro=0.2932


Validation: |                                                                                    | 0/? [00:00<…

Epoch 4: accuracy=0.3073, f1_macro=0.3032


Validation: |                                                                                    | 0/? [00:00<…

Epoch 5: accuracy=0.3153, f1_macro=0.3116


Validation: |                                                                                    | 0/? [00:00<…

Epoch 6: accuracy=0.3181, f1_macro=0.3146


Validation: |                                                                                    | 0/? [00:00<…

Epoch 7: accuracy=0.3209, f1_macro=0.3180


Validation: |                                                                                    | 0/? [00:00<…

Epoch 8: accuracy=0.3228, f1_macro=0.3203


Validation: |                                                                                    | 0/? [00:00<…

Epoch 9: accuracy=0.3252, f1_macro=0.3230


Validation: |                                                                                    | 0/? [00:00<…

Epoch 10: accuracy=0.3268, f1_macro=0.3252


Validation: |                                                                                    | 0/? [00:00<…

Epoch 11: accuracy=0.3279, f1_macro=0.3267


Validation: |                                                                                    | 0/? [00:00<…

Epoch 12: accuracy=0.3283, f1_macro=0.3276


Validation: |                                                                                    | 0/? [00:00<…

Epoch 13: accuracy=0.3294, f1_macro=0.3287


Validation: |                                                                                    | 0/? [00:00<…

Epoch 14: accuracy=0.3295, f1_macro=0.3288


Validation: |                                                                                    | 0/? [00:00<…

Epoch 15: accuracy=0.3309, f1_macro=0.3300


Validation: |                                                                                    | 0/? [00:00<…

Epoch 16: accuracy=0.3316, f1_macro=0.3308


Validation: |                                                                                    | 0/? [00:00<…

Epoch 17: accuracy=0.3315, f1_macro=0.3311


Validation: |                                                                                    | 0/? [00:00<…

Epoch 18: accuracy=0.3315, f1_macro=0.3315


Validation: |                                                                                    | 0/? [00:00<…

Epoch 19: accuracy=0.3317, f1_macro=0.3318


Validation: |                                                                                    | 0/? [00:00<…

Epoch 20: accuracy=0.3325, f1_macro=0.3327


Validation: |                                                                                    | 0/? [00:00<…

Epoch 21: accuracy=0.3327, f1_macro=0.3329


Validation: |                                                                                    | 0/? [00:00<…

Epoch 22: accuracy=0.3327, f1_macro=0.3329


Validation: |                                                                                    | 0/? [00:00<…

Epoch 23: accuracy=0.3335, f1_macro=0.3338


Validation: |                                                                                    | 0/? [00:00<…

Epoch 24: accuracy=0.3341, f1_macro=0.3345


Validation: |                                                                                    | 0/? [00:00<…

Epoch 25: accuracy=0.3351, f1_macro=0.3355


Validation: |                                                                                    | 0/? [00:00<…

Epoch 26: accuracy=0.3351, f1_macro=0.3355


Validation: |                                                                                    | 0/? [00:00<…

Epoch 27: accuracy=0.3360, f1_macro=0.3364


Validation: |                                                                                    | 0/? [00:00<…

Epoch 28: accuracy=0.3363, f1_macro=0.3366


Validation: |                                                                                    | 0/? [00:00<…

Epoch 29: accuracy=0.3367, f1_macro=0.3370


Validation: |                                                                                    | 0/? [00:00<…

Epoch 30: accuracy=0.3375, f1_macro=0.3379


Validation: |                                                                                    | 0/? [00:00<…

Epoch 31: accuracy=0.3374, f1_macro=0.3379


Validation: |                                                                                    | 0/? [00:00<…

Epoch 32: accuracy=0.3375, f1_macro=0.3381


Validation: |                                                                                    | 0/? [00:00<…

Epoch 33: accuracy=0.3375, f1_macro=0.3381


Validation: |                                                                                    | 0/? [00:00<…

Epoch 34: accuracy=0.3370, f1_macro=0.3377


Validation: |                                                                                    | 0/? [00:00<…

Epoch 35: accuracy=0.3367, f1_macro=0.3374


Validation: |                                                                                    | 0/? [00:00<…

Epoch 36: accuracy=0.3363, f1_macro=0.3370


Validation: |                                                                                    | 0/? [00:00<…

Epoch 37: accuracy=0.3369, f1_macro=0.3376


Validation: |                                                                                    | 0/? [00:00<…

Epoch 38: accuracy=0.3366, f1_macro=0.3373


Validation: |                                                                                    | 0/? [00:00<…

Epoch 39: accuracy=0.3370, f1_macro=0.3376


Validation: |                                                                                    | 0/? [00:00<…

Epoch 40: accuracy=0.3365, f1_macro=0.3372


Validation: |                                                                                    | 0/? [00:00<…

Epoch 41: accuracy=0.3366, f1_macro=0.3373


Validation: |                                                                                    | 0/? [00:00<…

Epoch 42: accuracy=0.3364, f1_macro=0.3370


Validation: |                                                                                    | 0/? [00:00<…

Epoch 43: accuracy=0.3367, f1_macro=0.3373


Validation: |                                                                                    | 0/? [00:00<…

Epoch 44: accuracy=0.3366, f1_macro=0.3373


Validation: |                                                                                    | 0/? [00:00<…

Epoch 45: accuracy=0.3363, f1_macro=0.3370


Validation: |                                                                                    | 0/? [00:00<…

Epoch 46: accuracy=0.3362, f1_macro=0.3369


Validation: |                                                                                    | 0/? [00:00<…

Epoch 47: accuracy=0.3360, f1_macro=0.3367


Validation: |                                                                                    | 0/? [00:00<…

Epoch 48: accuracy=0.3356, f1_macro=0.3363


Validation: |                                                                                    | 0/? [00:00<…

Epoch 49: accuracy=0.3353, f1_macro=0.3362


Validation: |                                                                                    | 0/? [00:00<…

Epoch 50: accuracy=0.3347, f1_macro=0.3356


Validation: |                                                                                    | 0/? [00:00<…

Epoch 51: accuracy=0.3346, f1_macro=0.3355


Validation: |                                                                                    | 0/? [00:00<…

Epoch 52: accuracy=0.3345, f1_macro=0.3354


Validation: |                                                                                    | 0/? [00:00<…

Epoch 53: accuracy=0.3344, f1_macro=0.3353


Validation: |                                                                                    | 0/? [00:00<…

Epoch 54: accuracy=0.3344, f1_macro=0.3353


Validation: |                                                                                    | 0/? [00:00<…

Epoch 55: accuracy=0.3344, f1_macro=0.3354


Validation: |                                                                                    | 0/? [00:00<…

Epoch 56: accuracy=0.3343, f1_macro=0.3353


Validation: |                                                                                    | 0/? [00:00<…

Epoch 57: accuracy=0.3344, f1_macro=0.3354


Validation: |                                                                                    | 0/? [00:00<…

Epoch 58: accuracy=0.3343, f1_macro=0.3353


Validation: |                                                                                    | 0/? [00:00<…

Epoch 59: accuracy=0.3343, f1_macro=0.3353


Validation: |                                                                                    | 0/? [00:00<…

Epoch 60: accuracy=0.3339, f1_macro=0.3349


Validation: |                                                                                    | 0/? [00:00<…

Epoch 61: accuracy=0.3338, f1_macro=0.3348


Validation: |                                                                                    | 0/? [00:00<…

Epoch 62: accuracy=0.3338, f1_macro=0.3347


Validation: |                                                                                    | 0/? [00:00<…

Epoch 63: accuracy=0.3338, f1_macro=0.3347


Validation: |                                                                                    | 0/? [00:00<…

Epoch 64: accuracy=0.3340, f1_macro=0.3348


Validation: |                                                                                    | 0/? [00:00<…

Epoch 65: accuracy=0.3338, f1_macro=0.3347


Validation: |                                                                                    | 0/? [00:00<…

Epoch 66: accuracy=0.3337, f1_macro=0.3346


Validation: |                                                                                    | 0/? [00:00<…

Epoch 67: accuracy=0.3339, f1_macro=0.3347


Validation: |                                                                                    | 0/? [00:00<…

Epoch 68: accuracy=0.3341, f1_macro=0.3349


Validation: |                                                                                    | 0/? [00:00<…

Epoch 69: accuracy=0.3342, f1_macro=0.3350


Validation: |                                                                                    | 0/? [00:00<…

Epoch 70: accuracy=0.3343, f1_macro=0.3352


Validation: |                                                                                    | 0/? [00:00<…

Epoch 71: accuracy=0.3344, f1_macro=0.3352


Validation: |                                                                                    | 0/? [00:00<…

Epoch 72: accuracy=0.3344, f1_macro=0.3352


Validation: |                                                                                    | 0/? [00:00<…

Epoch 73: accuracy=0.3344, f1_macro=0.3353


Validation: |                                                                                    | 0/? [00:00<…

Epoch 74: accuracy=0.3346, f1_macro=0.3354


Validation: |                                                                                    | 0/? [00:00<…

Epoch 75: accuracy=0.3347, f1_macro=0.3356


Validation: |                                                                                    | 0/? [00:00<…

Epoch 76: accuracy=0.3349, f1_macro=0.3357


Validation: |                                                                                    | 0/? [00:00<…

Epoch 77: accuracy=0.3349, f1_macro=0.3357


Validation: |                                                                                    | 0/? [00:00<…

Epoch 78: accuracy=0.3348, f1_macro=0.3356


Validation: |                                                                                    | 0/? [00:00<…

Epoch 79: accuracy=0.3348, f1_macro=0.3357


Validation: |                                                                                    | 0/? [00:00<…

Epoch 80: accuracy=0.3349, f1_macro=0.3358


Validation: |                                                                                    | 0/? [00:00<…

Epoch 81: accuracy=0.3348, f1_macro=0.3357


Validation: |                                                                                    | 0/? [00:00<…

Epoch 82: accuracy=0.3350, f1_macro=0.3358


Validation: |                                                                                    | 0/? [00:00<…

Epoch 83: accuracy=0.3351, f1_macro=0.3360


Validation: |                                                                                    | 0/? [00:00<…

Epoch 84: accuracy=0.3352, f1_macro=0.3359


Validation: |                                                                                    | 0/? [00:00<…

Epoch 85: accuracy=0.3352, f1_macro=0.3360


Validation: |                                                                                    | 0/? [00:00<…

Epoch 86: accuracy=0.3353, f1_macro=0.3361


Validation: |                                                                                    | 0/? [00:00<…

Epoch 87: accuracy=0.3353, f1_macro=0.3361


Validation: |                                                                                    | 0/? [00:00<…

Epoch 88: accuracy=0.3352, f1_macro=0.3360


Validation: |                                                                                    | 0/? [00:00<…

Epoch 89: accuracy=0.3353, f1_macro=0.3360


Validation: |                                                                                    | 0/? [00:00<…

Epoch 90: accuracy=0.3354, f1_macro=0.3361


Validation: |                                                                                    | 0/? [00:00<…

Epoch 91: accuracy=0.3351, f1_macro=0.3359


Validation: |                                                                                    | 0/? [00:00<…

Epoch 92: accuracy=0.3349, f1_macro=0.3358


Validation: |                                                                                    | 0/? [00:00<…

Epoch 93: accuracy=0.3351, f1_macro=0.3360


Validation: |                                                                                    | 0/? [00:00<…

Epoch 94: accuracy=0.3349, f1_macro=0.3358


Validation: |                                                                                    | 0/? [00:00<…

Epoch 95: accuracy=0.3349, f1_macro=0.3359


Validation: |                                                                                    | 0/? [00:00<…

Epoch 96: accuracy=0.3350, f1_macro=0.3359


Validation: |                                                                                    | 0/? [00:00<…

Epoch 97: accuracy=0.3347, f1_macro=0.3356


Validation: |                                                                                    | 0/? [00:00<…

Epoch 98: accuracy=0.3346, f1_macro=0.3355


Validation: |                                                                                    | 0/? [00:00<…

`Trainer.fit` stopped: `max_epochs=100` reached.


Epoch 99: accuracy=0.3346, f1_macro=0.3354


## 6. Оценка на тестовой выборке


In [7]:
# Evaluate the checkpoint with the best validation accuracy when one exists.
# best_model_path is an empty string if the callback has not saved anything.
best_model_path = checkpoint_callback.best_model_path

if best_model_path:
    # Announce the load only when a checkpoint is actually going to be loaded.
    print(f'Loading best model from: {best_model_path}')
    # NOTE(review): the same `model` object used for training is passed here, so the
    # checkpoint weights are presumably loaded into the instance that
    # `lightning_model` also wraps - confirm if the last-epoch weights are needed later.
    best_model = BaseLightningModule.load_from_checkpoint(
        best_model_path,
        model=model,
        loss_fn=loss_fn,
        optimizer_type='adamw',
        learning_rate=1e-3,
        task_type='multiclass'
    )
else:
    # Make the fallback explicit instead of printing an empty checkpoint path.
    print('No best checkpoint found; evaluating the in-memory model from the last epoch')
    best_model = lightning_model

test_results = trainer.test(best_model, dm)

# trainer.test returns a list of metric dicts; the first entry is printed here.
print('\n=== Финальные результаты на тестовой выборке ===')
for key, value in test_results[0].items():
    print(f'{key}: {value:.4f}')


Loading best model from: C:\Users\СИВ\Desktop\МИСИС\3 семестр\Глубокое машинное обучение\GIT\dl2025\lesson7\homework\baseline\checkpoints\best_model-epoch=30-val_accuracy=0.3375.ckpt


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Input dimension: 3072
Number of classes: 10
Labeled train samples: 1600
Test samples: 4000


Testing: |                                                                                       | 0/? [00:00<…

Test results: accuracy=0.3630, f1_macro=0.3631
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      test_accuracy         0.3630000054836273
      test_f1_macro         0.36310774087905884
        test_loss            5.019516468048096
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────

=== Финальные результаты на тестовой выборке ===
test_loss: 5.0195
test_accuracy: 0.3630
test_f1_macro: 0.3631
