# MNIST: обучение модели 

In [1]:
import datetime
import yaml
from copy import deepcopy

import numpy as np
from tqdm.notebook import tqdm
from matplotlib import pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter

from torchvision import datasets, transforms
import torchvision.models as models

2023-11-29 02:15:47.917375: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-11-29 02:15:47.941556: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


## Конфиг 

In [2]:
class CFG:
    dataset_name = "mnist"
    model_name = "resnet18"
    process_name = "train"
    full_model_name = f"{dataset_name}_{model_name}_{process_name}"
    
    gpu_num = 0
    device = f"cuda:{gpu_num}" if torch.cuda.is_available() else "cpu"
    num_workers = 16
    
    batch_size = 32
    train_size = 50000
    val_size = 10000
    test_size = 10000
    
    max_epoch_num = 1000
    early_stopping_patience = 10
    lr = 1e-4
    
    save_best_model = True
    seed = 42
    
def seed_everything(seed: int):
    import random, os
    import numpy as np
    import torch
    
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True
    
seed_everything(CFG.seed)

## Подготовка данных 

In [3]:
normalize = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
])

mnist_train_val = datasets.MNIST("../input/", train=True, download=True, transform=normalize)
mnist_test = datasets.MNIST("../input/", train=False, download=True, transform=normalize)

In [4]:
generator = torch.Generator().manual_seed(CFG.seed)
mnist_train, mnist_val = torch.utils.data.random_split(
    mnist_train_val, 
    [CFG.train_size, CFG.val_size],
    generator=generator
)

In [8]:
assert len(mnist_train.indices) == CFG.train_size
assert len(mnist_val.indices) == CFG.val_size
assert len(mnist_test) == CFG.test_size

print(len(mnist_train.indices), len(mnist_val.indices), len(mnist_test))

50000 10000 10000


In [9]:
train_dataloader = torch.utils.data.DataLoader(
    mnist_train, 
    batch_size=CFG.batch_size, 
    shuffle=True, 
    num_workers=CFG.num_workers
)

val_dataloader = torch.utils.data.DataLoader(
    mnist_val, 
    batch_size=CFG.batch_size, 
    shuffle=False, 
    num_workers=CFG.num_workers
)

test_dataloader = torch.utils.data.DataLoader(
    mnist_test, 
    batch_size=CFG.batch_size, 
    shuffle=False, 
    num_workers=CFG.num_workers
)

## Модель 

In [10]:
class MnistModel(nn.Module):
    def __init__(self, base_model):
        super(MnistModel, self).__init__()
        self.base_model = base_model
        self.output_fc = nn.Linear(1000, 10)
        
    def forward(self, x):
        x = self.base_model(x)
        x = self.output_fc(x)
        return x

In [11]:
model = MnistModel(base_model=models.resnet18())
model.to(CFG.device);

## Обучение 

In [19]:
def validate_model(model, dataloader, score_f):
    model.eval()
    
    mean_score = 0
    batches_n = 0
    with torch.no_grad(), torch.cuda.amp.autocast():
        for batch_i, (inputs, targets) in enumerate(dataloader):
            inputs, targets = inputs.to(CFG.device), targets.to(CFG.device)
            pred = model(inputs)

            mean_score += score_f(pred, targets)
            batches_n += 1
            
    mean_score /= batches_n
    return float(mean_score)

In [13]:
loss_f = nn.CrossEntropyLoss()
accuracy_f = lambda pred, targets: float((pred.argmax(axis=1) == targets.to(CFG.device)).float().mean())

optimizer = torch.optim.Adam(model.parameters(), lr=CFG.lr)

In [14]:
curr_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
log_dir = f"../logs/{CFG.dataset_name}/{CFG.model_name}/{CFG.process_name}/{curr_time}"
writer = SummaryWriter(log_dir=log_dir)

In [15]:
best_val_acc = 0
train_time = 0
for epoch_i in range(1, CFG.max_epoch_num + 1):
    ### TRAIN 
    model.train()
        
    mean_train_acc = 0
    train_batches_n = 0
    for batch_i, (inputs, targets) in enumerate(tqdm(train_dataloader)):
        inputs, targets = inputs.to(CFG.device), targets.to(CFG.device)
        
        batch_start_time = datetime.datetime.now()
        
        pred = model(inputs)
        loss = loss_f(pred, targets)

        model.zero_grad()
        loss.backward()
        optimizer.step()
        
        train_time += (datetime.datetime.now() - batch_start_time).total_seconds()
        
        mean_train_acc += accuracy_f(pred, targets)
        train_batches_n += 1

    mean_train_acc /= train_batches_n
    
    ### VAL
    
    mean_val_acc = validate_model(model, val_dataloader, accuracy_f)
    
    ### TENSORBOARD
    
    writer.add_scalars(
        "Accuracy",
        {"train": mean_train_acc, "val": mean_val_acc}, 
        global_step=epoch_i
    )
    writer.flush()

    ### SAVE BEST
    
    mean_train_acc_perc = round(float(mean_train_acc) * 100, 2)
    mean_val_acc_perc = round(float(mean_val_acc) * 100, 2)
    print(f"Epoch {epoch_i}: train_acc = {mean_train_acc_perc}%, val_acc = {mean_val_acc_perc}%", end="; ")

    if mean_val_acc > best_val_acc:
        best_epoch_i = epoch_i
        best_val_acc = mean_val_acc
        best_train_acc = mean_train_acc
        best_model = deepcopy(model)
        best_model_train_time = train_time
        
        print(f'new best model')
    elif epoch_i - best_epoch_i > CFG.early_stopping_patience:
        print(f'early stopping')
        break
    else:
        print("continue")

  0%|          | 0/1563 [00:00<?, ?it/s]

Epoch 1: train_acc = 94.98%, val_acc = 97.84%; new best model


  0%|          | 0/1563 [00:00<?, ?it/s]

Epoch 2: train_acc = 98.05%, val_acc = 98.44%; new best model


  0%|          | 0/1563 [00:00<?, ?it/s]

Epoch 3: train_acc = 98.52%, val_acc = 98.58%; new best model


  0%|          | 0/1563 [00:00<?, ?it/s]

Epoch 4: train_acc = 98.82%, val_acc = 98.35%; continue


  0%|          | 0/1563 [00:00<?, ?it/s]

Epoch 5: train_acc = 99.03%, val_acc = 98.77%; new best model


  0%|          | 0/1563 [00:00<?, ?it/s]

Epoch 6: train_acc = 99.16%, val_acc = 98.7%; continue


  0%|          | 0/1563 [00:00<?, ?it/s]

Epoch 7: train_acc = 99.28%, val_acc = 98.91%; new best model


  0%|          | 0/1563 [00:00<?, ?it/s]

Epoch 8: train_acc = 99.31%, val_acc = 98.71%; continue


  0%|          | 0/1563 [00:00<?, ?it/s]

Epoch 9: train_acc = 99.47%, val_acc = 99.13%; new best model


  0%|          | 0/1563 [00:00<?, ?it/s]

Epoch 10: train_acc = 99.47%, val_acc = 98.95%; continue


  0%|          | 0/1563 [00:00<?, ?it/s]

Epoch 11: train_acc = 99.54%, val_acc = 99.07%; continue


  0%|          | 0/1563 [00:00<?, ?it/s]

Epoch 12: train_acc = 99.61%, val_acc = 98.51%; continue


  0%|          | 0/1563 [00:00<?, ?it/s]

Epoch 13: train_acc = 99.62%, val_acc = 98.93%; continue


  0%|          | 0/1563 [00:00<?, ?it/s]

Epoch 14: train_acc = 99.71%, val_acc = 99.13%; continue


  0%|          | 0/1563 [00:00<?, ?it/s]

Epoch 15: train_acc = 99.68%, val_acc = 98.86%; continue


  0%|          | 0/1563 [00:00<?, ?it/s]

Epoch 16: train_acc = 99.69%, val_acc = 99.12%; continue


  0%|          | 0/1563 [00:00<?, ?it/s]

Epoch 17: train_acc = 99.7%, val_acc = 99.03%; continue


  0%|          | 0/1563 [00:00<?, ?it/s]

Epoch 18: train_acc = 99.73%, val_acc = 99.1%; continue


  0%|          | 0/1563 [00:00<?, ?it/s]

Epoch 19: train_acc = 99.73%, val_acc = 99.07%; continue


  0%|          | 0/1563 [00:00<?, ?it/s]

Epoch 20: train_acc = 99.76%, val_acc = 99.11%; early stopping


In [20]:
mean_test_acc = validate_model(best_model, test_dataloader, accuracy_f)

print(f"Best model train acc = {round(best_train_acc * 100, 2)}%")
print(f"Best model val acc = {round(best_val_acc * 100, 2)}%")
print(f"Best model test acc = {round(mean_test_acc * 100, 2)}%")

test_acc_str = str(round(mean_test_acc * 100, 2)).replace(".", "_")
if CFG.save_best_model:
    best_model_path = f"../models/{CFG.full_model_name}_acc_{test_acc_str}.torch"
    torch.save(best_model.state_dict(), best_model_path)
    print(f"Best model saved: {best_model_path}")

Best model train acc = 99.47%
Best model val acc = 99.13%
Best model test acc = 99.1%
Best model saved: ../models/mnist_resnet18_train_acc_99_1.torch


## Сохранение результатов 

In [23]:
research_info = {
    "dataset_name": CFG.dataset_name,
    "model_name": CFG.model_name,
    "process_name": CFG.process_name,
    "batch_size": CFG.batch_size,
    "train_size": CFG.train_size,
    "val_size": CFG.val_size,
    "test_size": CFG.test_size,
    "best_model": {
        "best_epoch": best_epoch_i,
        "accuracy": {
            "best_train_acc": best_train_acc,
            "best_val_acc": best_val_acc,
            "best_test_acc": mean_test_acc,
        },
        "path": best_model_path,
        "log_dir": log_dir,
        "train_time_sec": best_model_train_time
    }
}

with open(f"../research_info/{CFG.full_model_name}_acc_{test_acc_str}.yaml", "w") as f:
    yaml.dump(research_info, f)