In [55]:
import time
from datetime import datetime
from typing import Iterator, List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

In [48]:
# Sanity check: the training cell at the end of the notebook passes device_name='cuda',
# so confirm PyTorch can actually see a CUDA device.
torch.cuda.is_available()

True

In [49]:
# Display the installed PyTorch version so the results below can be tied to it.
torch.__version__

'2.0.1+cu117'

In [50]:
from torchvision import datasets
from torchvision.transforms import ToTensor

# Build both MNIST splits from a single recipe: the training split first, then the test split.
# Only the `train` flag differs; both use root '../models/mnist', download=True and their own
# ToTensor() transform instance.
train_data, test_data = (
    datasets.MNIST(
        root='../models/mnist',
        train=is_train_split,
        transform=ToTensor(),
        download=True,
    )
    for is_train_split in (True, False)
)

In [51]:
class TrainHelper:
    """Static helpers for training, evaluating and inspecting classification networks.

    The helpers only require datasets that yield ``(image, label)`` pairs which the default
    ``DataLoader`` collation can batch, so they are not tied to MNIST.
    """

    @staticmethod
    def train(cnn: nn.Module,
              *,
              epochs: int,
              train_dataset: torch.utils.data.Dataset,
              test_dataset: Optional[torch.utils.data.Dataset] = None,
              print_results: bool = True,
              batch_size: int,
              device_name: str,
              writer: Optional["SummaryWriter"] = None,
              num_workers: int = 1) -> List[float]:
        """Train ``cnn`` with Adam (lr=0.001) and cross-entropy loss.

        :param cnn: network to train; it is moved to ``device_name`` in place
        :param epochs: number of passes over ``train_dataset``
        :param train_dataset: dataset yielding ``(image, label)`` pairs; it is reshuffled
            every epoch
        :param test_dataset: if given, the accuracy on it is measured after every epoch
        :param print_results: print accuracy and the last batch loss after every evaluated epoch
        :param batch_size: training batch size
        :param device_name: device string passed to ``torch.device`` (e.g. 'cuda' or 'cpu')
        :param writer: optional TensorBoard writer; receives "training loss" for every
            optimizer step and "test accuracy" for every evaluated epoch
        :param num_workers: number of ``DataLoader`` worker processes (0 loads in-process)
        :return: one accuracy per epoch; empty when ``test_dataset`` is None
        """
        train_loader = DataLoader(train_dataset,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  num_workers=num_workers)

        device = torch.device(device_name)

        cnn.to(device)
        cnn.train()

        optimizer = torch.optim.Adam(cnn.parameters(), lr=0.001)
        loss_func = nn.CrossEntropyLoss()

        eval_results: List[float] = []

        steps = 0
        # Stays None if the loader never yields a batch, so the report below cannot hit an
        # undefined name.
        loss: Optional[torch.Tensor] = None

        for epoch in range(epochs):
            for images, labels in train_loader:
                # Plain tensors carry autograd state themselves; no Variable wrapper is needed.
                images = images.to(device)
                labels = labels.to(device)

                output = cnn(images)
                loss = loss_func(output, labels)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                steps += 1
                if writer is not None:
                    writer.add_scalar("training loss", loss.item(), steps)

            if test_dataset is not None:
                eval_result = TrainHelper.test(cnn, test_dataset, device, num_workers=num_workers)
                eval_results.append(eval_result)
                if writer is not None:
                    writer.add_scalar("test accuracy", eval_result, steps)
                if print_results:
                    # Report the loss as a plain float rather than a tensor repr.
                    last_loss = loss.item() if loss is not None else float("nan")
                    print(f"epoch {epoch}, accuracy = {eval_result}, loss = {last_loss}")
                # test() leaves the network in eval mode; switch back before the next epoch.
                cnn.train()

        return eval_results

    @staticmethod
    def test(cnn: nn.Module,
             test_dataset: torch.utils.data.Dataset,
             device=None,
             batch_size: int = 128,
             num_workers: int = 1) -> float:
        """Return the fraction of samples in ``test_dataset`` that ``cnn`` classifies correctly.

        The network is switched to eval mode and left in it; it is not moved to ``device``,
        only the input batches are.

        :param cnn: network producing one score per class along dimension 1
        :param test_dataset: dataset yielding ``(image, label)`` pairs
        :param device: device the input batches are moved to; None leaves them where the
            loader produced them
        :param batch_size: evaluation batch size
        :param num_workers: number of ``DataLoader`` worker processes (0 loads in-process)
        :return: accuracy in [0, 1] as a Python float
        :raises ZeroDivisionError: if ``test_dataset`` yields no samples
        """
        cnn.eval()
        loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
        correct = 0
        total = 0

        # Inference only: without no_grad every batch would build an autograd graph.
        with torch.no_grad():
            for images, labels in loader:
                if device is not None:
                    images = images.to(device)

                # Labels stay on the CPU, so bring the predictions back before comparing.
                predictions = cnn(images).argmax(dim=1).cpu()
                correct += int((predictions == labels).sum().item())
                total += predictions.shape[0]

        return correct / total

    @staticmethod
    def train_models(models: List[nn.Module],
                     device_name: str,
                     *,
                     epochs: int = 20,
                     batch_size: int = 2048,
                     train_dataset: Optional[torch.utils.data.Dataset] = None,
                     test_dataset: Optional[torch.utils.data.Dataset] = None) -> Iterator[Tuple[int, float]]:
        """
        generator yields pair (total parameters count, best accuracy) for each network

        Nothing runs (including the non-empty check) until the generator is iterated; each
        model is trained when its pair is requested, and a summary line with the wall time
        is printed per model.

        :param models: networks to train one after another; must not be empty
        :param device_name: 'cuda' or 'cpu'
        :param epochs: epochs per model; must be at least 1 so a best accuracy exists
        :param batch_size: training batch size
        :param train_dataset: training data; defaults to the notebook-level ``train_data``
        :param test_dataset: evaluation data; defaults to the notebook-level ``test_data``
        """
        assert len(models) > 0

        # Fall back to the notebook globals so existing two-argument calls keep working.
        if train_dataset is None:
            train_dataset = train_data
        if test_dataset is None:
            test_dataset = test_data

        for model in models:
            # perf_counter is monotonic, so the measured duration cannot be skewed by clock changes.
            start = time.perf_counter()
            eval_results = TrainHelper.train(
                cnn=model,
                epochs=epochs,
                train_dataset=train_dataset,
                test_dataset=test_dataset,
                batch_size=batch_size,
                device_name=device_name,
                print_results=False
            )
            end = time.perf_counter()
            best_acc = max(eval_results)
            params_count = TrainHelper.total_parameters_count(model)
            print(f"best accuracy = {best_acc}, parameters = {params_count}, training time = {end - start}")
            yield params_count, best_acc

    @staticmethod
    def total_parameters_count(model: nn.Module) -> int:
        """Return the number of scalar values across all parameters of ``model`` as an int.

        Every parameter is counted, whether or not it requires gradients.
        """
        # numel() is an exact Python int, also for 0-dim parameters (np.prod of an empty
        # shape is the float 1.0).
        return sum(p.numel() for p in model.parameters())

    @staticmethod
    def print_parameters(model: nn.Module):
        """Print the total parameter count, then the element count and shape of each parameter."""
        print(f"total parameters = {TrainHelper.total_parameters_count(model)}")
        for p in model.parameters():
            print(f"size {p.numel()}: {p.size()}")

In [52]:
class MyConvModel(nn.Module):
    """Small all-convolutional classifier made of unpadded Conv-BatchNorm-LeakyReLU blocks.

    Six conv blocks and two 2x2 max-pools are followed by a 1x1 convolution that produces
    the class scores and a Flatten. The spatial sizes in the comments below assume 28x28
    inputs, for which the result has shape (batch, num_classes). Other input sizes are not
    checked here: they either fail inside a convolution or flatten to more than
    ``num_classes`` values per sample.

    :param channels: width of the first stage; later stages use 2x and 4x this value
    :param in_channels: number of channels in the input images (1 for grayscale)
    :param num_classes: number of output scores per sample
    """

    def __init__(self, channels: int, in_channels: int = 1, num_classes: int = 10):
        super().__init__()

        c = channels
        # Layer order and nesting are unchanged so state_dict keys stay 'layers.<i>...'.
        self.layers = nn.Sequential(
            self.conv(in_channels, c, kernel_size=3),   # 28 - 26
            self.conv(c, c, kernel_size=3),             # 26 - 24
            nn.MaxPool2d(2),                            # 24 - 12

            self.conv(c, c * 2, kernel_size=3),         # 12 - 10
            self.conv(c * 2, c * 2, kernel_size=3),     # 10 - 8
            nn.MaxPool2d(2),                            # 8 - 4

            self.conv(c * 2, c * 4, kernel_size=3),     # 4 - 2
            self.conv(c * 4, c * 4, kernel_size=2),     # 2 - 1

            # 1x1 convolution acts as the classification head on the 1x1 feature map.
            nn.Conv2d(c * 4, num_classes, kernel_size=1, padding='valid', bias=True),
            nn.Flatten(),
        )

    @staticmethod
    def conv(in_ch: int, out_ch: int, *, kernel_size):
        """Build one unpadded Conv2d -> BatchNorm2d -> LeakyReLU(0.1) block.

        The convolution has no bias because the following BatchNorm2d has its own shift term.
        """
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=kernel_size, padding='valid', bias=False),
            nn.BatchNorm2d(out_ch),
            nn.LeakyReLU(0.1),
        )

    def forward(self, x: torch.Tensor):
        """Run ``x`` through the layer stack and return the flattened class scores."""
        return self.layers(x)

In [56]:
# Train a 16-channel model for 10 epochs, logging to a fresh timestamped TensorBoard run.
model = MyConvModel(16)

# Fall back to the CPU so the cell still runs on machines without a CUDA device.
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'

# The timestamp avoids ':' because it is not allowed in Windows directory names.
run_dir = f'my_mnist/{datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}'

with SummaryWriter(run_dir) as writer:
    TrainHelper.train(
        model,
        epochs=10,
        train_dataset=train_data,
        test_dataset=test_data,
        print_results=True,
        batch_size=2000,
        device_name=device_name,
        writer=writer,
    )

epoch 0, accuracy = 0.6563, loss = 0.4836249053478241
epoch 1, accuracy = 0.9804, loss = 0.2099757343530655
epoch 2, accuracy = 0.9862, loss = 0.12426448613405228
epoch 3, accuracy = 0.9896, loss = 0.08274900168180466
epoch 4, accuracy = 0.9918, loss = 0.06063356623053551
epoch 5, accuracy = 0.9921, loss = 0.04493088275194168
epoch 6, accuracy = 0.9925, loss = 0.04389992728829384
epoch 7, accuracy = 0.9931, loss = 0.03025810979306698
epoch 8, accuracy = 0.9939, loss = 0.026903457939624786
epoch 9, accuracy = 0.9945, loss = 0.018505867570638657
