In [None]:
import os
import pickle as pkl

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

# --- Configuration -----------------------------------------------------------
batch_size_train = 64
batch_size_test = 1000
learning_rate = 0.001

log_interval = 10
dataset = 'SVHN'  # one of 'MNIST' or 'SVHN'

# --- Data loaders ------------------------------------------------------------
# Each branch defines `train_loader`, `test_loader` and `img_size`
# ([channels, height, width] of a single sample after the transform).
if dataset == 'MNIST':
    # The same preprocessing is applied to both splits.
    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.1307,), (0.3081,)),
    ])
    train_loader = torch.utils.data.DataLoader(
        torchvision.datasets.MNIST('./files/', train=True, download=True,
                                   transform=transform),
        batch_size=batch_size_train, shuffle=True)
    test_loader = torch.utils.data.DataLoader(
        torchvision.datasets.MNIST('./files/', train=False, download=True,
                                   transform=transform),
        batch_size=batch_size_test, shuffle=False)
    img_size = [1, 28, 28]
elif dataset == 'SVHN':
    # Images are downscaled to 16x16; the single mean/std pair is broadcast
    # over all three colour channels by Normalize.
    transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize((16, 16)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.4519,), (0.1919,)),
    ])
    train_loader = torch.utils.data.DataLoader(
        torchvision.datasets.SVHN('./files/', split='train', download=True,
                                  transform=transform),
        batch_size=batch_size_train, shuffle=True, pin_memory=True,
    )
    # The test split uses the (larger) evaluation batch size, matching the
    # MNIST branch; it previously reused batch_size_train by mistake.
    test_loader = torch.utils.data.DataLoader(
        torchvision.datasets.SVHN('./files/', split='test', download=True,
                                  transform=transform),
        batch_size=batch_size_test, shuffle=False, pin_memory=True,
    )
    img_size = [3, 16, 16]
else:
    # Fail with an explicit message (a bare assert is stripped under -O).
    raise ValueError(
        f"Unsupported dataset {dataset!r}; expected 'MNIST' or 'SVHN'")

In [None]:
# Pull the first batch from the test loader and report its tensor shape.
examples = enumerate(test_loader)
batch_idx, (example_data, example_targets) = next(examples)
print(example_data.shape)

import matplotlib.pyplot as plt

# Show the first six samples on a 2x3 grid; only channel 0 of each image is
# drawn, rendered with a grey colour map.
fig = plt.figure()
for sample_idx in range(6):
    ax = fig.add_subplot(2, 3, sample_idx + 1)
    fig.tight_layout()
    ax.imshow(example_data[sample_idx][0], cmap='gray', interpolation='none')
    ax.set_title("Ground Truth: {}".format(example_targets[sample_idx]))
    # Hide the tick marks on both axes.
    ax.set_xticks([])
    ax.set_yticks([])
fig

In [None]:
class Net(nn.Module):
    """Fully connected classifier with a single hidden layer.

    The input width is derived from the module-level ``img_size``
    ([channels, height, width]); the output layer has 10 units.

    Args:
        size: Number of units in the hidden layer.
    """

    def __init__(self, size=10):
        super().__init__()
        channels, height, width = img_size
        self.fc1 = nn.Linear(channels * height * width, size)
        self.fc2 = nn.Linear(size, 10)

    def forward(self, x):
        """Map a batch of images shaped ``(batch, *img_size)`` to 10 logits each."""
        # Every sample must have exactly the configured image shape.
        assert x.shape[1:] == torch.Size(img_size)
        flat = x.view(x.shape[0], img_size[0] * img_size[1] * img_size[2])
        hidden = F.relu(self.fc1(flat))
        return self.fc2(hidden)

In [None]:
def train_test(size=10, epochs=20, step=5, weight_decay=0.0001):
    """Train a fresh ``Net`` and pickle its weights every ``step`` epochs.

    The network is evaluated on ``test_loader`` once before training and again
    after every epoch; loss and accuracy for both splits are printed.

    Args:
        size: Hidden-layer width passed to ``Net``.
        epochs: Number of passes over ``train_loader``.
        step: A checkpoint is written whenever ``epoch % step == 0``.
        weight_decay: Weight-decay coefficient given to AdamW; it is also
            embedded in the checkpoint file name.

    Side effects:
        Creates the ``weights/`` directory if it is missing and writes
        ``weights/{dataset}_{size}_{epoch}_{weight_decay}.pkl`` files, each a
        pickled dict mapping state-dict keys to NumPy arrays.
    """
    network = Net(size=size)
    network_loss = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(
        network.parameters(),
        lr=learning_rate,
        weight_decay=weight_decay)

    # Create the checkpoint directory up front so a missing folder cannot
    # abort the run after several epochs have already been trained.
    os.makedirs("weights", exist_ok=True)

    def report(split, loss_sum, correct, total):
        """Print the per-sample average loss and the accuracy for one split."""
        print('\n{} set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
            split, loss_sum / total, correct, total, 100. * correct / total))

    def train():
        """Run one optimisation pass over the training set."""
        network.train()
        loss_sum = 0.0
        correct = 0
        for data, target in train_loader:
            optimizer.zero_grad()
            output = network(data)
            loss = network_loss(output, target)
            loss.backward()
            optimizer.step()

            # CrossEntropyLoss returns the batch mean; scale it back to a sum
            # so dividing by the dataset size yields a true per-sample mean.
            loss_sum += loss.item() * data.size(0)
            correct += (output.argmax(dim=1) == target).sum().item()

        report('Train', loss_sum, correct, len(train_loader.dataset))

    def test():
        """Evaluate the network on the test set without tracking gradients."""
        network.eval()
        loss_sum = 0.0
        correct = 0
        with torch.no_grad():
            for data, target in test_loader:
                output = network(data)
                loss_sum += network_loss(output, target).item() * data.size(0)
                correct += (output.argmax(dim=1) == target).sum().item()

        report('Test', loss_sum, correct, len(test_loader.dataset))

    # Baseline evaluation of the untrained network.
    test()
    for epoch in range(1, epochs + 1):
        train()
        test()
        if epoch % step == 0:
            weights = {
                name: tensor.cpu().numpy()
                for name, tensor in network.state_dict().items()
            }
            with open(f"weights/{dataset}_{size}_{epoch}_{weight_decay}.pkl", "wb") as file:
                pkl.dump(weights, file)

In [None]:
# Sweep the hidden-layer width over 5, 10, 15, 20 and 25; each run trains for
# 40 epochs and checkpoints every 10th epoch.
hidden_sizes = range(5, 30, 5)
for size in hidden_sizes:
    print(f"Size: {size}")
    train_test(size=size, epochs=40, step=10, weight_decay=0.0001)