In [None]:
import torch
import random
import numpy as np
import torchvision.datasets
import gc


# Seed every random number generator used in this notebook (Python, NumPy,
# PyTorch CPU and the current CUDA device) so runs are repeatable.
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed(0)
# Ask cuDNN to pick deterministic algorithms.
torch.backends.cudnn.deterministic = True
# From here on, newly created floating-point tensors and module parameters
# default to float16.
# NOTE(review): training entirely in float16 (no mixed precision / loss scaling)
# can be numerically fragile with Adam -- confirm this is intentional.
torch.set_default_dtype(torch.float16)

# Download (if needed) and open the CIFAR-10 train and test splits in the
# current working directory.
CIFAR_train = torchvision.datasets.CIFAR10('./', download=True, train=True)
CIFAR_test = torchvision.datasets.CIFAR10('./', download=True, train=False)


# Images become float16 tensors, labels become int64 tensors.
X_train = torch.HalfTensor(CIFAR_train.data)
y_train = torch.LongTensor(CIFAR_train.targets)
X_test = torch.HalfTensor(CIFAR_test.data)
y_test = torch.LongTensor(CIFAR_test.targets)


# In-place rescale; presumably maps 0..255 pixel values to 0..1 -- TODO confirm
# the raw value range of `.data`.
X_train /= 255.
X_test /= 255.


# Move the last axis to position 1; presumably (N, H, W, C) -> (N, C, H, W),
# the layout torch.nn.Conv2d expects -- verify against the dataset layout.
X_train = X_train.permute(0, 3, 1, 2)
X_test = X_test.permute(0, 3, 1, 2)


# Train on the first CUDA device when one is present; otherwise fall back to the
# CPU so the notebook still runs on a machine without a GPU.
_DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'


class _NetworkTrain:
  """Couples a network with a loss and an optimizer and runs a train/evaluate loop.

  Args:
    net: the ``torch.nn.Module`` to train; it is moved to ``_DEVICE``.
    loss: loss class (not instance); instantiated as ``loss(**loss_kwargs)``.
    optimizer: optimizer class; instantiated over ``net.parameters()`` with
      ``optimizer_kwargs``.
    loss_kwargs: keyword arguments for the loss constructor; ``None`` means none.
    optimizer_kwargs: keyword arguments for the optimizer constructor; ``None``
      means ``dict(lr=1.0e-3)``.
  """

  def __init__(
      self,
      net,
      loss = torch.nn.CrossEntropyLoss,
      optimizer = torch.optim.Adam,
      loss_kwargs: "dict[str, object] | None" = None,
      optimizer_kwargs: "dict[str, object] | None" = None,
  ):
    # Defaults are created per call instead of sharing one mutable dict
    # between every instance.
    if loss_kwargs is None:
      loss_kwargs = {}
    if optimizer_kwargs is None:
      optimizer_kwargs = dict(lr=1.0e-3)

    self.net = net
    self.net.to(_DEVICE)
    self.loss = loss(**loss_kwargs)
    self.optimizer = optimizer(params=self.net.parameters(), **optimizer_kwargs)

  @staticmethod
  def reset_mem():
    """Run the Python garbage collector and release cached CUDA memory."""
    gc.collect()
    torch.cuda.empty_cache()

  @property
  def name(self):
    """Class name of the wrapped network, used as an experiment label."""
    return type(self.net).__name__

  def __call__(self, X_train, y_train, X_test, y_test, n_epochs = 30, batch_size = 100):
    """Train for ``n_epochs`` and evaluate on the test tensors after every epoch.

    The training tensors stay where they are and are moved to ``_DEVICE`` one
    mini-batch at a time; the test tensors are moved to ``_DEVICE`` once and
    evaluated in a single forward pass.

    Args:
      X_train, y_train: training inputs and integer class labels.
      X_test, y_test: evaluation inputs and integer class labels.
      n_epochs: number of passes over the training data.
      batch_size: mini-batch size; the last batch of an epoch may be smaller.

    Returns:
      ``(test_loss_history, test_accuracy_history)``: two lists with one 0-dim
      CPU tensor per epoch.
    """
    self.reset_mem()
    test_accuracy_history = []
    test_loss_history = []

    X_test = X_test.to(_DEVICE)
    y_test = y_test.to(_DEVICE)

    for epoch in range(n_epochs):
      # Evaluation below switches to eval mode, so re-enable train mode once
      # per epoch rather than on every batch.
      self.net.train()
      order = np.random.permutation(len(X_train))

      for start_index in range(0, len(X_train), batch_size):
        batch_indexes = order[start_index:start_index + batch_size]
        X_batch = X_train[batch_indexes].to(_DEVICE)
        y_batch = y_train[batch_indexes].to(_DEVICE)

        self.optimizer.zero_grad()
        # Call the module itself (not .forward) so registered hooks run.
        loss_value = self.loss(self.net(X_batch), y_batch)
        loss_value.backward()
        self.optimizer.step()

      self.net.eval()
      # No autograd graph is needed for evaluation; building one for the whole
      # test set on every epoch only wastes memory.
      with torch.no_grad():
        test_preds = self.net(X_test)
        test_loss = self.loss(test_preds, y_test).cpu()
        accuracy = (test_preds.argmax(dim=1) == y_test).float().mean().cpu()

      test_loss_history.append(test_loss)
      test_accuracy_history.append(accuracy)

      print(f"epoch={epoch} accuracy={accuracy.float()}")

    return test_loss_history, test_accuracy_history


In [None]:
class ResLayer(torch.nn.Module):
  """Residual block with a fixed channel count.

  The body is conv3x3 -> norm -> conv3x3 -> norm -> activation, and the block
  output is that body's result plus the unchanged input.

  Args:
    channels: number of input and output channels of both convolutions.
    activation: module class instantiated (with no arguments) as the last
      layer of the body.
    use_batch_norm: when true each norm slot is a ``BatchNorm2d``; otherwise
      it is an ``Identity``.
  """

  def __init__(self, channels, activation=torch.nn.ReLU, use_batch_norm=True):
    super().__init__()

    def make_conv():
      # 3x3 kernel with padding 1 keeps the spatial size unchanged.
      return torch.nn.Conv2d(
          in_channels=channels, out_channels=channels, kernel_size=3, padding=1)

    def make_norm():
      # A fresh module per call, so the two norm slots never share state.
      if use_batch_norm:
        return torch.nn.BatchNorm2d(num_features=channels)
      return torch.nn.Identity()

    self.layers = torch.nn.Sequential(
        make_conv(),
        make_norm(),
        make_conv(),
        make_norm(),
        activation(),
    )

  def forward(self, x):
    """Return ``layers(x) + x`` (skip connection around the body)."""
    body_out = self.layers(x)
    return body_out + x

class ResNetXXX(torch.nn.Module):
  """Small residual CNN that ends in a 10-way linear classifier.

  A 7x7 stem convolution (3 -> 8 channels) is followed by four stages of
  ``ResLayer`` + 2x2 max-pooling; a 1x1 convolution widens the channels
  between stages (8 -> 16 -> 32 -> 64). The flattened features feed a
  256 -> 100 -> 10 fully connected head.

  Args:
    **kwargs: forwarded unchanged to every ``ResLayer``.
  """

  def __init__(self, **kwargs):
    super().__init__()

    # Stem and first residual stage.
    modules = [
        torch.nn.Conv2d(in_channels=3, out_channels=8, kernel_size=7, padding=3),
        ResLayer(8, **kwargs),
        torch.nn.MaxPool2d(2),
    ]

    # Remaining stages: 1x1 conv to widen the channels, residual block, downsample.
    for narrow, wide in ((8, 16), (16, 32), (32, 64)):
      modules.append(torch.nn.Conv2d(in_channels=narrow, out_channels=wide, kernel_size=1))
      modules.append(ResLayer(wide, **kwargs))
      modules.append(torch.nn.MaxPool2d(2))

    # Classifier head. 4*64 input features means a 2x2 map with 64 channels,
    # which assumes 32x32 inputs (four halvings) -- TODO confirm input size.
    modules.extend([
        torch.nn.Flatten(),
        torch.nn.Linear(4*64, 100),
        torch.nn.ReLU(),
        torch.nn.Linear(100, 10),
    ])

    self.layers = torch.nn.Sequential(*modules)

  def forward(self, x):
    """Run ``x`` through the whole stack and return the 10 class scores per sample."""
    return self.layers(x)

In [None]:
from torchvision.models import resnet18

# Build the network explicitly so this cell works after "Restart Kernel -> Run All"
# instead of relying on a `net` variable left over in kernel state.
net = ResNetXXX(use_batch_norm=True)
train = _NetworkTrain(net, optimizer=torch.optim.Adam, optimizer_kwargs=dict(lr=1e-2))

In [None]:
# Per-experiment histories, keyed by the trained network's class name.
accuracies, losses = dict(), dict()

# The trainer returns (loss history, accuracy history), in that order.
(
    losses[train.name],
    accuracies[train.name],
) = train(X_train, y_train, X_test, y_test)

In [None]:
import matplotlib.pyplot as plt

# One curve per experiment; each history holds one accuracy value per epoch.
fig, ax = plt.subplots()
for experiment_id, accuracy in accuracies.items():
    ax.plot(accuracy, label=experiment_id)
ax.set_xlabel('Epoch')
ax.set_ylabel('Accuracy')
ax.set_title('Validation Accuracy')
ax.legend();