In [None]:
%matplotlib inline

import numpy as np
import torch
import matplotlib.pyplot as plt

In [None]:
from torchvision import datasets
import torchvision.transforms as transforms

# Loader configuration: subprocess count for data loading and samples per batch.
num_workers = 0
batch_size = 64

# Convert each image to a torch tensor as it is read.
transform = transforms.ToTensor()

# Both splits share the same storage root, download flag, and transform.
mnist_kwargs = dict(root='data', download=True, transform=transform)
train_data = datasets.MNIST(train=True, **mnist_kwargs)
test_data = datasets.MNIST(train=False, **mnist_kwargs)

# Both loaders share the same batching configuration.
loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers)
train_loader = torch.utils.data.DataLoader(train_data, **loader_kwargs)
test_loader = torch.utils.data.DataLoader(test_data, **loader_kwargs)

In [None]:
# Pull one batch from the training loader and display its first image.
dataiter = iter(train_loader)
# Use the built-in next(): Python 3 iterators expose __next__, not .next(),
# and recent PyTorch releases no longer provide a .next() alias on
# DataLoader iterators.
images, labels = next(dataiter)
images = images.numpy()

# np.squeeze drops length-1 axes so imshow receives a 2-D array.
img = np.squeeze(images[0])
fig = plt.figure(figsize = (3, 3))
ax = fig.add_subplot(111)
ax.imshow(img, cmap = 'gray')

In [None]:
import torch.nn as nn
import torch.nn.functional as F

class NeuralNet(nn.Module):
  """Fully connected classifier with optional batch normalization.

  Layer sizes are input_size -> hidden_dim*2 -> hidden_dim -> output_size,
  with a ReLU after each of the two hidden layers. When ``use_batch_norm`` is
  true, a BatchNorm1d layer is applied to the output of each hidden linear
  layer (before the ReLU) and those linear layers are created without a bias.

  Args:
    use_batch_norm: whether to insert batch normalization after fc1 and fc2.
    input_size: number of features per sample after flattening.
    hidden_dim: width of the second hidden layer; the first is twice as wide.
    output_size: number of output scores (classes).
  """

  def __init__(self, use_batch_norm, input_size = 784, hidden_dim = 256, output_size = 10):
    super(NeuralNet, self).__init__()

    self.input_size = input_size
    self.hidden_dim = hidden_dim
    self.output_size = output_size

    self.use_batch_norm = use_batch_norm

    # The linear bias is dropped when batch norm follows, because BatchNorm1d
    # carries its own learnable shift.
    self.fc1 = nn.Linear(input_size, hidden_dim*2, bias = not use_batch_norm)
    if use_batch_norm:
      self.batch_norm1 = nn.BatchNorm1d(hidden_dim*2)

    # fc2 consumes fc1's hidden_dim*2 outputs and must emit hidden_dim
    # features, which is what fc3 expects as input.
    self.fc2 = nn.Linear(hidden_dim*2, hidden_dim, bias = not use_batch_norm)
    if use_batch_norm:
      self.batch_norm2 = nn.BatchNorm1d(hidden_dim)

    self.fc3 = nn.Linear(hidden_dim, output_size)

  def forward(self, x):
    """Flatten ``x`` to (N, input_size) and return raw class scores (N, output_size)."""
    # Flatten using the configured input size rather than a hard-coded 28*28.
    x = x.view(-1, self.input_size)

    x = self.fc1(x)
    if self.use_batch_norm:
      x = self.batch_norm1(x)
    x = F.relu(x)

    x = self.fc2(x)
    if self.use_batch_norm:
      x = self.batch_norm2(x)
    x = F.relu(x)

    # No activation on the output layer: raw scores are returned.
    x = self.fc3(x)
    return x

In [None]:
# Build one network with batch normalization and one without.
net_batchnorm = NeuralNet(use_batch_norm=True)
net_no_norm = NeuralNet(use_batch_norm=False)

# Show both architectures, separated by a blank line.
print(net_batchnorm, net_no_norm, sep='\n\n')

In [None]:
def train(model, n_epochs=10, loader=None, lr=0.01):
  """Train ``model`` with SGD and cross-entropy loss; return per-epoch losses.

  Args:
    model: module mapping a batch of inputs to per-class scores.
    n_epochs: number of full passes over ``loader``.
    loader: iterable yielding ``(data, target)`` batches. When omitted, the
      module-level ``train_loader`` is used.
    lr: learning rate passed to ``torch.optim.SGD``.

  Returns:
    A list with one float per epoch: the mean of the batch losses in that epoch.

  Raises:
    ValueError: if ``loader`` yields no batches during an epoch.
  """
  if loader is None:
    # Fall back to the notebook-level training loader.
    loader = train_loader

  losses = []

  criterion = nn.CrossEntropyLoss()
  optimizer = torch.optim.SGD(model.parameters(), lr = lr)

  # The model is left in training mode when this function returns.
  model.train()

  for epoch in range(1, n_epochs+1):
    train_loss = 0.0
    batch_count = 0

    for data, target in loader:
      optimizer.zero_grad()
      output = model(data)
      loss = criterion(output, target)
      loss.backward()
      optimizer.step()

      train_loss += loss.item()
      batch_count += 1

    if batch_count == 0:
      # Fail with a clear message instead of a bare ZeroDivisionError.
      raise ValueError('loader yielded no batches; cannot compute an average loss')

    epoch_loss = train_loss/batch_count
    losses.append(epoch_loss)
    print('Epoch: {} \tTraining Loss: {:.6f}'.format(
        epoch,
        epoch_loss))

  return losses

In [None]:
# Train the batch-norm network first, then the plain one, keeping each
# run's per-epoch loss history for the comparison plot.
losses_batchnorm = train(model=net_batchnorm)
losses_no_norm = train(model=net_no_norm)

In [None]:
# Compare the per-epoch training loss of the two networks on one set of axes.
fig, ax = plt.subplots(figsize = (12,8))

# train() records one value per epoch starting at epoch 1, so plot against
# 1-based epoch numbers rather than the default 0-based list index.
ax.plot(range(1, len(losses_batchnorm) + 1), losses_batchnorm,
        label = 'Using batchnorm', alpha = 0.5)
ax.plot(range(1, len(losses_no_norm) + 1), losses_no_norm,
        label = 'No norm', alpha = 0.5)
ax.set_title("Training Losses")
ax.set_xlabel("Epoch")
ax.set_ylabel("Average training loss")
ax.legend()

In [None]:
def test(model, train):
  class_correct = list(0. for i in range(10))
  class_total = list(0. for i in range(10))
  test_loss = 0.0

  if(train == True):
    model.train()
  if(train == False):
    model.eval()

  criterion = nn.CrossEntropyLoss()

  for batch_idx, (data, target) in enumerate(test_loader):
    batch_size = data.size(0)
    output = model(data)
    loss = criterion(output, target)
    test_loss += loss.item()*batch_size
    _, pred = torch.max(output, 1)
    correct = np.squeeze(pred.eq(target.data.view_as(pred)))
    for i in range(batch_size):
      label = target.data[i]
      class_correct[label] += correct[i].item()
      class_total[label] += 1

  print('Test Loss: {:.6f}\n'.format(test_loss/len(test_loader.dataset)))

  for i in range(10):
    if class_total[i] > 0:
      print('Test Accuracy of %5s: %2d%% (%2d/%2d)' % (
          str(i), 100 * class_correct[i] / class_total[i],
          np.sum(class_correct[i]), np.sum(class_total[i])))
    else:
      print('Test Accuracy of %5s: N/A (no training examples)' % (classes[i]))
  print('\nTest Accuracy (Overall): %2d%% (%2d/%2d)' % (
      100. * np.sum(class_correct) / np.sum(class_total),
      np.sum(class_correct), np.sum(class_total)))

In [None]:
# Evaluate the batch-norm network with train=True: test() calls model.train()
# first, so the model stays in training mode while it is scored on the test
# loader. Per the PyTorch docs, BatchNorm1d in training mode normalizes with
# each batch's own statistics and keeps updating its running estimates.
test(net_batchnorm, train=True)

In [None]:
# Evaluate the same batch-norm network with train=False: test() calls
# model.eval() first. Per the PyTorch docs, BatchNorm1d in evaluation mode
# normalizes with its stored running estimates instead of batch statistics.
test(net_batchnorm, train=False)

In [None]:
# Evaluate the network built without batch normalization, in evaluation mode
# (test() calls model.eval()), for comparison with the batch-norm results.
test(net_no_norm, train = False)