In [None]:
import basic as ib
import torch.nn as nn
import torch
import time
import matplotlib.pyplot as plt

In [None]:
class Conv(nn.Module):
    """Small CNN baseline without any normalization layer ("Prime" in the experiments).

    Layout: conv block -> conv block -> 2x2 max-pool -> conv block (2x width)
    -> global average pool -> linear classifier.
    """

    def __init__(self, num_classes, input_channels=3, act=nn.ReLU):
        '''
        :param num_classes: number of output classes
        :param input_channels: number of input image channels
        :param act: activation module class; a fresh instance is created per conv block
        '''
        base_channels = 16  # base number of intermediate channels
        super().__init__()

        def conv_block(in_ch, out_ch, kernel=3, stride=1, padding=1):
            # No normalization layer follows this conv, so the conv must keep its
            # own bias: bias=False is only redundant when a norm layer with an
            # affine shift (BatchNorm/GroupNorm) comes right after it.
            return nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size=kernel, stride=stride, padding=padding, bias=True),
                act(),
            )

        self.features = nn.Sequential(
            conv_block(input_channels, base_channels),
            conv_block(base_channels, base_channels),
            nn.MaxPool2d(2),

            conv_block(base_channels, base_channels*2),
            nn.AdaptiveAvgPool2d((1,1))
        )
        self.classifier = nn.Linear(base_channels*2, num_classes)

    def forward(self, x):
        x = self.features(x)           # (N, 2*base_channels, 1, 1) after global average pooling
        x = x.view(x.size(0), -1)      # flatten to (N, 2*base_channels)
        x = self.classifier(x)
        return x

In [None]:
class ConvBN(nn.Module):  # Batch normalization
    """Same CNN as the baseline, with BatchNorm2d after every convolution."""

    def __init__(self, num_classes, input_channels=3, act=nn.ReLU):
        '''
        :param num_classes: number of output classes
        :param input_channels: number of input image channels
        :param act: activation module class; a fresh instance is created per conv block
        '''
        width = 16  # channel count of the first two blocks; the last block doubles it
        super().__init__()

        def stage(c_in, c_out, kernel=3, stride=1, padding=1):
            # The conv has no bias because BatchNorm's learnable shift makes it redundant.
            layers = [
                nn.Conv2d(c_in, c_out, kernel_size=kernel, stride=stride, padding=padding, bias=False),
                nn.BatchNorm2d(c_out),
                act(),
            ]
            return nn.Sequential(*layers)

        first = stage(input_channels, width)
        second = stage(width, width)
        downsample = nn.MaxPool2d(2)
        third = stage(width, 2 * width)
        self.features = nn.Sequential(first, second, downsample, third, nn.AdaptiveAvgPool2d((1, 1)))
        self.classifier = nn.Linear(2 * width, num_classes)

    def forward(self, x):
        feats = self.features(x)  # (N, 2*width, 1, 1) after global average pooling
        return self.classifier(feats.view(feats.size(0), -1))

In [None]:
class ConvLN(nn.Module):  # Layer normalization
    """Same CNN as the baseline, with GroupNorm(1, C) after every convolution.

    A GroupNorm with a single group normalizes each sample over all of
    (C, H, W), i.e. it behaves as LayerNorm for conv feature maps (with a
    per-channel affine transform).
    """

    def __init__(self, num_classes, input_channels=3, act=nn.ReLU):
        '''
        :param num_classes: number of output classes
        :param input_channels: number of input image channels
        :param act: activation module class; a fresh instance is created per conv block
        '''
        c = 16  # base number of intermediate channels
        super().__init__()

        def block(n_in, n_out, kernel=3, stride=1, padding=1):
            conv = nn.Conv2d(n_in, n_out, kernel_size=kernel, stride=stride, padding=padding, bias=False)
            norm = nn.GroupNorm(1, n_out)  # one group == layer norm over (C, H, W)
            return nn.Sequential(conv, norm, act())

        # None marks where the 2x2 max-pool sits between the conv blocks.
        plan = [(input_channels, c), (c, c), None, (c, 2 * c)]
        layers = [nn.MaxPool2d(2) if spec is None else block(*spec) for spec in plan]
        layers.append(nn.AdaptiveAvgPool2d((1, 1)))
        self.features = nn.Sequential(*layers)
        self.classifier = nn.Linear(2 * c, num_classes)

    def forward(self, x):
        out = self.features(x)           # (N, 2*c, 1, 1) after global average pooling
        out = out.view(out.size(0), -1)  # flatten to (N, 2*c)
        return self.classifier(out)

In [5]:
# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Устройство: {device}")

Устройство: cuda


In [7]:
# Result stores, filled by the training cell as store[dataset][model_name]
# = the corresponding value returned by ib.train (presumably a per-epoch series).
train_losses = {}
train_accuracies = {}
val_losses = {}
val_accuracies = {}

In [None]:
def plot_download(val_data, name_db, metric, out_dir="graphs"):  # Plot and save a model-comparison figure
    """Plot one validation metric for the three model variants, display it and save it as PNG.

    :param val_data: dict with keys "LayerNorm", "BatchNorm" and "Prime"; each value is
        a sequence plotted against the epoch index
    :param name_db: dataset name, used as the plot title and as the file-name prefix
    :param metric: metric name (e.g. "Loss", "Accuracy"), used for the y-axis label,
        the legend and the file name
    :param out_dir: directory the PNG is written to; created if it does not exist
    """
    import os

    # Presumably IPython's clear_output re-exported by `basic`: replaces the previous figure.
    ib.clear_output(wait=True)

    # (legend label, key in val_data)
    series = [("Layer", "LayerNorm"), ("Batch", "BatchNorm"), ("Without", "Prime")]

    fig, ax = plt.subplots(figsize=(6, 4))
    for label, key in series:
        # The legend names the actual metric (previously it always said "CEL", even for accuracy).
        ax.plot(val_data[key], label=f"val {metric} {label}")
    ax.set_xlabel("epoch")
    ax.set_ylabel(metric)
    ax.set_title(f"{name_db}")
    ax.legend()
    ax.grid(True)
    display(fig)  # `display` is provided by the IPython/Jupyter namespace

    # Without this, savefig fails on a fresh checkout where the output folder is missing.
    os.makedirs(out_dir, exist_ok=True)
    fig.savefig(os.path.join(out_dir, f"{name_db}.{metric}.png"), dpi=500, bbox_inches='tight')

    plt.close(fig)  # free the figure; this function is called repeatedly in a loop

In [None]:
%%time

# Architecture -> key its results are stored under; plot_download() looks up
# exactly these keys ("Prime", "BatchNorm", "LayerNorm").
names = {
    Conv: "Prime",
    ConvBN: "BatchNorm",
    ConvLN: "LayerNorm",
}

DATASETS = ["CIFAR10", "FashionMNIST", "KMNIST", "MNIST", "SVHN"]
NUM_WORKERS = 16  # DataLoader worker processes
SEED = 0  # re-applied before every model so all variants start from the same RNG state

for data_base in DATASETS:
    # Fresh per-dataset result dicts: store[data_base][model_name] = value from ib.train
    for store in (train_losses, train_accuracies, val_losses, val_accuracies):
        store[data_base] = {}

    # Load data
    train_loader, val_loader, test_loader, num_classes, input_size, is_gray = ib.get_dataloaders(
        data_base, num_workers=NUM_WORKERS
    )
    # NOTE(review): the channel count is fixed at 3 (RGB) even when `is_gray` is true.
    # This presumably relies on ib.get_dataloaders expanding grayscale images to
    # 3 channels; confirm, otherwise use `1 if is_gray else 3`.
    channels = 3

    for model_cls, model_name in names.items():
        # Seeding here makes weight init (and torch-driven shuffling) reproducible and
        # identical across variants, so differences come from the normalization layer.
        torch.manual_seed(SEED)
        model = model_cls(num_classes, input_channels=channels).to(device)
        (
            train_losses[data_base][model_name],
            train_accuracies[data_base][model_name],
            val_losses[data_base][model_name],
            val_accuracies[data_base][model_name],
        ) = ib.train(model, device, train_loader, val_loader, model_name, data_base)

    plot_download(val_losses[data_base], data_base, "Loss")
    plot_download(val_accuracies[data_base], data_base, "Accuracy")