In [1]:
from torchvision.datasets import MNIST, mnist
from torchvision import transforms

In [2]:
import torch.nn.functional as F

In [3]:
import torch
from torch import nn
from torch.autograd import Variable
from torch.distributions import Categorical
from torch.utils.data import DataLoader

In [4]:
from tqdm.notebook import tqdm_notebook

In [5]:
import matplotlib.pyplot as plt
import seaborn as sns

In [6]:
from itertools import chain
import pandas as pd

In [7]:
from torch.utils.tensorboard import SummaryWriter
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

device(type='cuda')

In [8]:
class CustomTargetTransform:
    """Target transform that one-hot encodes a class index.

    The encoded vector is a float tensor of length ``num_classes`` allocated
    on the module-level ``device``.
    """

    def __init__(self, num_classes=10):
        # Length of the one-hot vector produced by ``__call__``.
        self.num_classes = num_classes

    def __call__(self, target):
        """Return a zero vector with a 1 written at position ``target``."""
        one_hot = torch.zeros(
            self.num_classes,
            dtype=torch.float,
            device=device,
        )
        one_hot[target] = 1
        return one_hot

def _to_float_on_device(image):
    """Cast an image tensor to float and move it to ``device``."""
    return image.float().to(device)


# PIL image -> tensor -> float tensor on the target device.
transform = transforms.Compose([
    transforms.PILToTensor(),
    transforms.Lambda(_to_float_on_device),
])

# FashionMNIST training split; targets are one-hot encoded by CustomTargetTransform.
dataset = mnist.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=transform,
    target_transform=CustomTargetTransform(),
)
data_loader = DataLoader(dataset=dataset, batch_size=800, shuffle=True)

In [9]:
# FashionMNIST split selected with train=False, used by write_summary for evaluation.
dataset_target = mnist.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=transforms.PILToTensor(),
)
# Insert a channel axis at position 1, cast to float and move to ``device``.
target_data = dataset_target.data[:, None].to(torch.float).to(device)
# Labels are kept as floats on ``device``; write_summary casts them with .long() where needed.
target_labels = dataset_target.targets.to(torch.float).to(device)

In [10]:
def create_model(
        img_size,
        blocks_out_channels,
        blocks_kernel_size,
        blocks_stride,
        pool_kernel_size,
        pool_stride,
        lr,
    ) -> tuple[nn.Sequential, nn.CrossEntropyLoss, torch.optim.SGD, SummaryWriter]:
    """Build a small CNN classifier with its loss, optimizer and TensorBoard writer.

    The network is a stack of unpadded ``Conv2d`` + ``ReLU`` blocks (first
    block takes 1 input channel), followed by one ``MaxPool2d``, a flatten,
    a 10-way linear layer and a softmax over the class dimension.

    Args:
        img_size: Side length of the (square) input image in pixels.
        blocks_out_channels: Output channel count of each conv block.
        blocks_kernel_size: Integer kernel size of each conv block.
        blocks_stride: Integer stride of each conv block.
        pool_kernel_size: Kernel size of the final max-pool.
        pool_stride: Stride of the final max-pool.
        lr: Learning rate for the SGD optimizer.

    Returns:
        ``(model, loss_fn, optimizer, writer)``. The model is moved to the
        module-level ``device``. The writer's run-directory suffix encodes
        the hyper-parameters; the caller is responsible for closing it.

    Raises:
        ValueError: If the configuration shrinks the feature map to a
            non-positive size, so no valid network can be built.
    """
    blocks = []
    in_channels = 1
    outs = img_size
    for out_channels, kernel_size, stride in zip(blocks_out_channels, blocks_kernel_size, blocks_stride):
        blocks.append(
            nn.Conv2d(in_channels, out_channels, kernel_size, stride)
        )
        blocks.append(nn.ReLU())
        # Spatial size after an unpadded convolution.
        outs = (outs - kernel_size) // stride + 1
        if outs <= 0:
            raise ValueError(
                f"conv block (kernel_size={kernel_size}, stride={stride}) "
                f"reduces the feature map to a non-positive size ({outs})"
            )
        in_channels = out_channels
    # Spatial size after the final max-pool.
    outs = (outs - pool_kernel_size) // pool_stride + 1
    if outs <= 0:
        raise ValueError(
            f"max-pool (kernel_size={pool_kernel_size}, stride={pool_stride}) "
            f"reduces the feature map to a non-positive size ({outs})"
        )
    # Number of features entering the linear layer: height * width * channels.
    outs = outs * outs * in_channels
    model = nn.Sequential(
        *blocks,
        nn.MaxPool2d(kernel_size=pool_kernel_size, stride=pool_stride),
        nn.Flatten(1),
        nn.Linear(outs, 10),
        # Explicit class dimension: the flattened input is always 2-D, so this
        # matches the previous implicit choice while silencing the
        # "implicit dimension choice for softmax" deprecation warning.
        # NOTE(review): nn.CrossEntropyLoss applies log-softmax itself, so the
        # loss sees doubly-normalised scores, which flattens gradients.
        # Dropping this layer is the usual fix, but the inputs are unscaled
        # 0-255 pixels and training on raw logits may need input
        # normalisation / a smaller lr first - confirm before changing.
        nn.Softmax(dim=1),
    ).to(device)
    er_f = nn.CrossEntropyLoss()
    optim = torch.optim.SGD(model.parameters(), lr=lr)
    writer = SummaryWriter(comment=f"_{len(blocks_out_channels)}_{lr}_{blocks_kernel_size}_{blocks_stride}_{pool_kernel_size}_{pool_stride}")
    return model, er_f, optim, writer


In [11]:
def write_summary(writer: SummaryWriter, model: nn.Module, params):
    """Evaluate ``model`` on the held-out data and log the results to TensorBoard.

    Logged items: a confusion-matrix heatmap, the accuracy (as a scalar and
    as text), a histogram of correctly predicted labels, a histogram of the
    true labels of misclassified samples, and a markdown table of the
    hyper-parameters.

    Args:
        writer: TensorBoard writer receiving every figure, scalar and text.
        model: Classifier called on the module-level ``target_data``; its
            output must hold one score per class along dimension 1.
        params: Indexable of six hyper-parameters in the order
            (blocks_out_channels, blocks_kernel_size, blocks_stride,
            pool_kernel_size, pool_stride, lr).
    """
    num_classes = 10
    param_dict = {
            "blocks_out_channels": params[0],
            "blocks_kernel_size": params[1],
            "blocks_stride": params[2],
            "pool_kernel_size": params[3],
            "pool_stride": params[4],
            "lr": params[5],
        }
    param_table = pd.DataFrame(param_dict)

    # Pure evaluation: use eval mode and skip autograd so no graph is kept
    # for the whole evaluation set; restore the previous mode afterwards.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            predicted = model(target_data)
    finally:
        model.train(was_training)

    predicted_labels = predicted.max(1).indices
    true_labels = target_labels.long()

    # Rows are predicted classes, columns are true classes. Encoding each
    # (row, col) pair as a single flat index lets one bincount replace a
    # Python loop of per-sample tensor indexing.
    flat_index = predicted_labels * num_classes + true_labels
    confusion_matrix = (
        torch.bincount(flat_index, minlength=num_classes * num_classes)
        .reshape(num_classes, num_classes)
        .float()
        .cpu()
    )

    is_correct = predicted_labels == true_labels
    hist = predicted_labels[is_correct]
    error_hist = target_labels[~is_correct]
    percent: torch.Tensor = is_correct.sum() / len(target_labels)

    # Each plot gets its own figure instead of reusing plt.gcf(), so an
    # unrelated current figure is never cleared. add_figure closes the
    # figure it is given by default.
    fig, ax = plt.subplots()
    sns.heatmap(confusion_matrix.numpy(), annot=True, fmt="g", ax=ax)
    ax.set(xlabel="true label", ylabel="predicted label")
    writer.add_figure("confusion_matrix", fig)
    writer.add_scalar("accuracy", percent.item())

    fig, ax = plt.subplots()
    sns.histplot(hist.cpu(), stat="count", discrete=True, bins=range(num_classes), ax=ax)
    writer.add_figure("right hist", fig)

    fig, ax = plt.subplots()
    sns.histplot(error_hist.cpu(), stat="count", discrete=True, bins=range(num_classes), ax=ax)
    writer.add_figure("error hist", fig)

    writer.add_text("param_table", param_table.to_markdown())
    writer.add_text("accuracy", str(percent.item()))
    # Make sure everything logged above reaches the event file.
    writer.flush()

In [25]:
# Hyper-parameter configurations for the training loop. Each entry is unpacked
# as create_model(28, *entry), i.e. in the order
# (blocks_out_channels, blocks_kernel_size, blocks_stride, pool_kernel_size, pool_stride, lr).
# NOTE(review): every entry below is commented out, so `params` is an empty
# tuple: the training loop does nothing and `params[0]` raises IndexError.
# Uncomment at least one configuration before running the following cells.
# The trailing "# NN" numbers presumably record the accuracy (%) obtained with
# that configuration - TODO confirm.
params = (
    # ((5, ), (3, ), (3, ), 3, 3, 0.01),
    # ((5, ), (3, ), (1, ), 3, 1, 0.1),
    # ((5, ), (3, ), (1, ), 3, 1, 0.001),
    # ((7, ), (3, ), (1, ), 3, 1, 0.1),
    # ((10, ), (3, ), (1, ), 3, 1, 0.1),
    # ((15, ), (3, ), (1, ), 3, 1, 0.1),
    # ((5, ), (3, ), (1, ), 3, 3, 0.01), # 54
    # ((5, ), (3, ), (3, ), 3, 3, 0.01), # 46
    # ((10, ), (2, ), (1, ), 3, 3, 0.01), # 49
    # ((10, ), (4, ), (1, ), 3, 3, 0.01), # 38
    # ((10, ), (2, ), (2, ), 3, 3, 0.01), # 60
    # ((10, ), (4, ), (2, ), 3, 3, 0.01), # 47

    # ((10, ), (2, ), (2, ), 3, 1, 0.01), # 20
    # ((10, ), (2, ), (2, ), 2, 2, 0.01), # 55
    # ((10, ), (2, ), (2, ), 4, 4, 0.01), # 28
)

In [26]:
# Train one model per hyper-parameter configuration and log a summary for each.
for param in tqdm_notebook(params):
    # 28 is the image side length passed to create_model as img_size.
    model, er_f, optim, writer = create_model(28, *param)
    # Despite the name, this counts optimizer steps (mini-batches), not
    # epochs; it is the x-axis of the logged loss curve.
    total_epochs = 0
    try:
        for epoch in tqdm_notebook(range(20)):
            for image, target in tqdm_notebook(data_loader, leave=False):
                optim.zero_grad()
                outs = model(image)
                loss = er_f(outs, target)
                loss.backward()
                optim.step()
                writer.add_scalar("loss", loss.item(), total_epochs)
                total_epochs += 1
        write_summary(writer, model, param)
    finally:
        # Flush and release this run's event file even if training is
        # interrupted; a new writer is created for every configuration.
        writer.close()

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/75 [00:00<?, ?it/s]

  return self._call_impl(*args, **kwargs)


  0%|          | 0/75 [00:00<?, ?it/s]

  0%|          | 0/75 [00:00<?, ?it/s]

  0%|          | 0/75 [00:00<?, ?it/s]

  0%|          | 0/75 [00:00<?, ?it/s]

  0%|          | 0/75 [00:00<?, ?it/s]

  0%|          | 0/75 [00:00<?, ?it/s]

  0%|          | 0/75 [00:00<?, ?it/s]

  0%|          | 0/75 [00:00<?, ?it/s]

  0%|          | 0/75 [00:00<?, ?it/s]

  0%|          | 0/75 [00:00<?, ?it/s]

  0%|          | 0/75 [00:00<?, ?it/s]

  0%|          | 0/75 [00:00<?, ?it/s]

  0%|          | 0/75 [00:00<?, ?it/s]

  0%|          | 0/75 [00:00<?, ?it/s]

  0%|          | 0/75 [00:00<?, ?it/s]

  0%|          | 0/75 [00:00<?, ?it/s]

  0%|          | 0/75 [00:00<?, ?it/s]

  0%|          | 0/75 [00:00<?, ?it/s]

  0%|          | 0/75 [00:00<?, ?it/s]

  return self._call_impl(*args, **kwargs)


In [127]:
# NOTE(review): create_model builds a brand-new network, loss, optimizer and
# SummaryWriter, so this rebinds `model`, `er_f`, `optim` and `writer` and
# drops the objects left behind by the training loop. It also raises
# IndexError while `params` is empty. Confirm this cell is still needed.
model, er_f, optim, writer = create_model(28, *params[0])

In [57]:
# Serialises the module object currently bound to `model` to "model90.pt".
# NOTE(review): when the notebook is run top to bottom, the preceding cell has
# just rebound `model` to a newly created network, so this would not save the
# model produced by the training loop - confirm the intended cell order.
torch.save(model, "model90.pt")