In [46]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
from torchvision.datasets import FashionMNIST

from collections import defaultdict


In [47]:
NAMES = ["T-shirt or top", "Trouser", "Pullover", "Dress", "Coat", "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"]
COLORS = ["red", "green", "blue", "yellow", "aqua", "navy", "maroon", "magenta", "orange", "crimson"]
ENUM = 100

In [48]:
class ConvBlock(nn.Module):
    def __init__(self, ch_in, ch_out, stride):
        super().__init__()
        self.conv = nn.Conv2d(ch_in, ch_out,
                              kernel_size=(3, 3), stride=stride)
        self.bn = nn.BatchNorm2d(ch_out)
        self.relu = nn.ReLU()

    def forward(self, input):
        x = self.conv(input)
        x = self.bn(x)
        x = self.relu(x)
        return x


class NeuralNet(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()

        layer_config = ((64, 2), (64, 1), (128, 2), (128, 1))

        ch_in = 1
        block_list = []
        for ch_out, stride in layer_config:
            block = ConvBlock(ch_in, ch_out, stride)
            block_list.append(block)
            ch_in = ch_out

        self.backbone = nn.Sequential(*block_list)

        ### Add bottleneck layer  ###
        ch_bn = 2  # number of channels in bottleneck
        # it is called neck because it's between backbone and head
        self.neck = nn.Linear(layer_config[-1][0], ch_bn)

        self.head = nn.Linear(ch_bn, num_classes)

    def forward(self, input):
        featuremap = self.backbone(input)
        squashed = F.adaptive_avg_pool2d(featuremap, output_size=(1, 1))
        squeezed = squashed.view(squashed.shape[0], -1)

        # Save bottleneck result in class member
        self.neck_res = self.neck(squeezed)

        pred = self.head(self.neck_res)
        return pred

    @classmethod
    def loss(cls, pred, gt):
        return F.cross_entropy(pred, gt)

In [49]:
class Trainer:
    def __init__(self):

        self.train_transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomResizedCrop(size=(28, 28), scale=(0.7, 1.1)),
            transforms.ToTensor(),
        ])
        self.val_transform = transforms.Compose([
            transforms.ToTensor(),
        ])

        train_dataset = FashionMNIST("./data", train=True,
                                     transform=self.train_transform,
                                     download=True)
        val_dataset = FashionMNIST("./data", train=False,
                                   transform=self.val_transform,
                                   download=True)

        # Save val_dataset in class member
        self.val_dataset = val_dataset

        # Save 10 samples of each class
        self.samples = defaultdict(list)
        for idx, sample in enumerate(val_dataset):
            _, label = sample
            if len(self.samples[label]) < 10:
                self.samples[label].append(val_dataset[idx])

        self.plots_data = []

        batch_size = 1024
        self.train_loader = data.DataLoader(train_dataset,
                                            batch_size=batch_size,
                                            shuffle=True, num_workers=4)
        self.val_loader = data.DataLoader(val_dataset, batch_size=batch_size,
                                          shuffle=False, num_workers=4)

        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        self.net = NeuralNet()
        self.net.to(self.device)

        self.logger = SummaryWriter()
        self.i_batch = 0

    def train(self):

        num_epochs = ENUM

        optimizer = torch.optim.Adam(self.net.parameters(), lr=1e-3)

        for i_epoch in range(num_epochs):
            self.i_epoch = i_epoch
            self.net.train()

            for feature_batch, gt_batch in self.train_loader:
                feature_batch = feature_batch.to(self.device)
                gt_batch = gt_batch.to(self.device)

                pred_batch = self.net(feature_batch)

                loss = NeuralNet.loss(pred_batch, gt_batch)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                self.logger.add_scalar("train/loss", loss.item(), self.i_batch)

                if self.i_batch % 100 == 0:
                    print(f"batch={self.i_batch} loss={loss.item():.6f}")

                self.i_batch += 1

            self.validate()

    def validate(self):
        self.net.eval()

        loss_all = []
        pred_all = []
        gt_all = []

        self.save_plots_data()

        for feature_batch, gt_batch in self.val_loader:
            feature_batch = feature_batch.to(self.device)
            gt_batch = gt_batch.to(self.device)

            with torch.no_grad():
                pred_batch = self.net(feature_batch)
                loss = NeuralNet.loss(pred_batch, gt_batch)

            loss_all.append(loss.item())
            pred_all.append(pred_batch.cpu().numpy())
            gt_all.append(gt_batch.cpu().numpy())

        loss_mean = np.mean(np.array(loss_all))
        pred_all = np.argmax(np.concatenate(pred_all, axis=0), axis=1)
        gt_all = np.concatenate(np.array(gt_all))

        accuracy = np.sum(np.equal(pred_all, gt_all)) / len(pred_all)

        self.logger.add_scalar("val/loss", loss_mean, self.i_batch)
        self.logger.add_scalar("val/accuracy", accuracy, self.i_batch)

        print(f"Val_loss={loss_mean} val_accu={accuracy:.6f}")

    def save_plots_data(self):
        # Run through saved samples and save plot data
        for label, samples in self.samples.items():
            all_samples = [sample[0][None, :, :, :] for sample in samples]
            feature_batch = torch.cat(all_samples, 0).to(self.device)

            with torch.no_grad():
                self.net(feature_batch).to(self.device)

            xs = self.net.neck_res[:, 0].cpu().detach().numpy()
            ys = self.net.neck_res[:, 1].cpu().detach().numpy()

            for idx, (x, y) in enumerate(zip(xs, ys)):
                self.plots_data.append((self.i_epoch, label, idx, x, y))
                # print(self.plots_data[-1])


In [50]:

trainer = Trainer()
trainer.train()
print("Done!")


batch=0 loss=2.360684
Val_loss=1.6073572397232057 val_accu=0.378800


  gt_all = np.concatenate(np.array(gt_all))


batch=100 loss=1.429117
Val_loss=1.3689235329627991 val_accu=0.471400
Val_loss=1.1820245742797852 val_accu=0.565700
batch=200 loss=1.154199
Val_loss=1.0412736535072327 val_accu=0.625300
Val_loss=1.0489192128181457 val_accu=0.612600
batch=300 loss=0.893327
Val_loss=0.8580572843551636 val_accu=0.711100
batch=400 loss=0.869773
Val_loss=0.8164847195148468 val_accu=0.773900
Val_loss=0.8158915162086486 val_accu=0.772000
batch=500 loss=0.782626
Val_loss=0.7533595383167266 val_accu=0.788600
Val_loss=0.753343665599823 val_accu=0.787700
batch=600 loss=0.745082
Val_loss=0.7192492365837098 val_accu=0.794200
batch=700 loss=0.714647
Val_loss=0.788089120388031 val_accu=0.771100
Val_loss=0.8862158298492432 val_accu=0.738300
batch=800 loss=0.728464
Val_loss=0.8215635299682618 val_accu=0.768000
Val_loss=0.6726749956607818 val_accu=0.809900
batch=900 loss=0.667508
Val_loss=0.7473507940769195 val_accu=0.795000
batch=1000 loss=0.671206
Val_loss=0.673649662733078 val_accu=0.810300
Val_loss=0.623761647939682

In [51]:
import matplotlib.pyplot as plt
from matplotlib.pyplot import figure

import pathlib as pl
from datetime import datetime

plt.style.use('Solarize_Light2')

plots_data = trainer.plots_data

# save plots
plots_path = pl.Path("./plots/")
plots_path = plots_path / datetime.now().strftime("%Y-%m-%d_%H:%M:%S")
plots_path.mkdir()

for i in range(ENUM):
    cur_path = plots_path / f"e{i:02}.png"

    # get all plots' data for current epoch
    cur_epoch_data = [plot_data for j, plot_data in enumerate(
        plots_data) if plot_data[0] == i]

    figure(figsize=(12, 10), dpi=80)
    plt.grid()

    plt.xlim(-60, 60)
    plt.ylim(-130, 90)

    for label, (color, name) in enumerate(zip(COLORS, NAMES)):
        x, y = [], []

        for plot_data in cur_epoch_data:
            if plot_data[1] != label:
                continue

            for idx in range(10):
                if plot_data[2] != idx:
                    continue

                x.append(plot_data[3])
                y.append(plot_data[4])

        plt.scatter(x, y, c=color, label=name)

    plt.title(f"epoch {i:02}", fontsize=24)
    plt.legend(loc="upper right")
    plt.savefig(cur_path)
    plt.close()
