In [14]:
# Set to True to skip the training cell and load previously saved weights instead.
skip_training: bool = False

In [15]:
import os
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

import torch
import torchvision
import torchvision.transforms as transforms

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import tools

In [16]:
# Directory where the FashionMNIST data is stored; chosen by the course helper
# `tools.select_data_dir()` (which prints the selected path, see the cell output).
data_dir = tools.select_data_dir()

The data directory is ../data


In [17]:
# Device on which the model and data live. Set DEVICE_NAME to 'cuda:0'
# to use the first GPU instead of the CPU.
DEVICE_NAME = 'cpu'
device = torch.device(DEVICE_NAME)

In [18]:
# Preprocessing: image -> float tensor in [0, 1] (ToTensor), then
# (v - 0.5) / 0.5 maps it to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])


def _load_fashion_mnist(train):
    """Return the FashionMNIST training split (train=True) or test split from data_dir."""
    return torchvision.datasets.FashionMNIST(root=data_dir, train=train, download=True, transform=transform)


trainset = _load_fashion_mnist(train=True)
testset = _load_fashion_mnist(train=False)

# Human-readable names of the ten FashionMNIST labels, indexed by label id.
classes = [
    'T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
    'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot',
]

# Shuffled mini-batches for training; fixed order for evaluation.
trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=5, shuffle=False)

In [19]:
class Block(nn.Module):
    """Basic residual block: conv3x3-BN-ReLU-conv3x3-BN plus a shortcut, then ReLU.

    The shortcut is the identity when neither the channel count nor the
    resolution changes; otherwise it is a 1x1 convolution (with the same
    stride) followed by batch norm, so that both branches have equal shapes.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        """
        Args:
          in_channels (int):  Number of input channels.
          out_channels (int): Number of output channels.
          stride (int):       Stride of the first 3x3 convolution and of the
                              projection shortcut (if one is needed).
        """
        super().__init__()

        needs_projection = stride != 1 or in_channels != out_channels
        # Projection shortcut (created first, as in the original layout) or None for identity.
        self.skip = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
            nn.BatchNorm2d(out_channels),
        ) if needs_projection else None

        main_layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
        ]
        self.block = nn.Sequential(*main_layers)

    def forward(self, x):
        residual = self.block(x)
        shortcut = x if self.skip is None else self.skip(x)
        return F.relu(residual + shortcut)

In [20]:
def test_Block_shapes():
    """Check Block output shapes for every combination of
    (same / more channels) x (same / halved resolution)."""
    batch_size = 20
    x = torch.zeros(batch_size, 16, 28, 28)

    # (constructor kwargs, expected output shape)
    cases = [
        # The number of channels and resolution do not change
        (dict(in_channels=16, out_channels=16), [batch_size, 16, 28, 28]),
        # Increase the number of channels
        (dict(in_channels=16, out_channels=32), [batch_size, 32, 28, 28]),
        # Decrease the resolution
        (dict(in_channels=16, out_channels=16, stride=2), [batch_size, 16, 14, 14]),
        # Increase the number of channels and decrease the resolution
        (dict(in_channels=16, out_channels=32, stride=2), [batch_size, 32, 14, 14]),
    ]
    for kwargs, expected_shape in cases:
        y = Block(**kwargs)(x)
        assert y.shape == torch.Size(expected_shape), "Bad shape of y: y.shape={}".format(y.shape)

    print('Success')

test_Block_shapes()

Success


In [21]:
# Group of blocks
class GroupOfBlocks(nn.Module):
    """A sequence of residual Blocks applied one after another.

    Only the first block may change the channel count (in_channels ->
    out_channels) and the resolution (via `stride`); the remaining blocks
    map out_channels -> out_channels with stride 1. At least one block is
    always created, even when n_blocks < 1.
    """

    def __init__(self, in_channels, out_channels, n_blocks, stride=1):
        super().__init__()
        blocks = [Block(in_channels, out_channels, stride)]
        while len(blocks) < n_blocks:
            blocks.append(Block(out_channels, out_channels))
        self.group = nn.Sequential(*blocks)

    def forward(self, x):
        return self.group(x)

In [22]:
class ResNet(nn.Module):
    """Small ResNet: conv5x5 stem + max-pool, three groups of residual blocks,
    global average pooling and a linear classifier."""

    def __init__(self, n_blocks, n_channels=64, num_classes=10, in_channels=1):
        """
        Args:
          n_blocks (list):   A list with three elements which contains the number of blocks in 
                             each of the three groups of blocks in ResNet.
                             For instance, n_blocks = [2, 4, 6] means that the first group has two blocks,
                             the second group has four blocks and the third one has six blocks.
          n_channels (int):  Number of channels in the first group of blocks.
          num_classes (int): Number of classes.
          in_channels (int): Number of channels of the input images (1 for grayscale).
        """
        assert len(n_blocks) == 3, "The number of groups should be three."
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=n_channels, kernel_size=5, stride=1, padding=2, bias=False)
        self.bn1 = nn.BatchNorm2d(n_channels)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Groups 2 and 3 double the channels and halve the resolution.
        self.group1 = GroupOfBlocks(n_channels, n_channels, n_blocks[0])
        self.group2 = GroupOfBlocks(n_channels, 2*n_channels, n_blocks[1], stride=2)
        self.group3 = GroupOfBlocks(2*n_channels, 4*n_channels, n_blocks[2], stride=2)

        # Global average pooling to 1x1. For 28x28 inputs group3 outputs 4x4,
        # so this equals the former AvgPool2d(kernel_size=4), but it also
        # works for other input resolutions (no parameters, so saved
        # state_dicts remain compatible).
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(4*n_channels, num_classes)

        # He initialisation of conv weights: N(0, 2 / (kH * kW * out_channels)),
        # BatchNorm scale 1 and shift 0.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x, verbose=False):
        """
        Args:
          x of shape (batch_size, in_channels, H, W): Input images, e.g. (batch_size, 1, 28, 28).
          verbose: True if you want to print the shapes of the intermediate variables.
        
        Returns:
          y of shape (batch_size, num_classes): Outputs of the network.
        """
        if verbose: print(x.shape)
        x = self.conv1(x)
        if verbose: print('conv1:  ', x.shape)
        x = self.bn1(x)
        if verbose: print('bn1:    ', x.shape)
        x = self.relu(x)
        if verbose: print('relu:   ', x.shape)
        x = self.maxpool(x)
        if verbose: print('maxpool:', x.shape)

        x = self.group1(x)
        if verbose: print('group1: ', x.shape)
        x = self.group2(x)
        if verbose: print('group2: ', x.shape)
        x = self.group3(x)
        if verbose: print('group3: ', x.shape)

        x = self.avgpool(x)
        if verbose: print('avgpool:', x.shape)

        # Keep the batch dimension explicit; view(-1, in_features) would
        # silently fold leftover spatial positions into the batch.
        x = torch.flatten(x, 1)
        if verbose: print('x.view: ', x.shape)
        x = self.fc(x)
        if verbose: print('out:    ', x.shape)

        return x

In [23]:
def test_ResNet_shapes():
    """Run a small ResNet on one training batch, printing intermediate shapes,
    and check that it returns one 10-way score vector per image."""
    # Create a network with 2 blocks in each of the three groups
    n_blocks = [2, 2, 2]  # number of blocks in the three groups
    net = ResNet(n_blocks, n_channels=10)
    net.to(device)

    # Feed a batch of images from the training data to test the network
    with torch.no_grad():
        # DataLoader iterators follow the Python 3 iterator protocol; the
        # legacy `.next()` method no longer exists in recent PyTorch.
        images, labels = next(iter(trainloader))
        images = images.to(device)
        print('Shape of the input tensor:', images.shape)

        # Call the module rather than .forward() so registered hooks also run.
        y = net(images, verbose=True)
        print(y.shape)
        # Compare with the actual batch size (robust to a short/last batch).
        expected_shape = torch.Size([images.shape[0], 10])
        assert y.shape == expected_shape, "Bad shape of y: y.shape={}".format(y.shape)

    print('Success')

test_ResNet_shapes()

Shape of the input tensor: torch.Size([32, 1, 28, 28])
torch.Size([32, 1, 28, 28])
conv1:   torch.Size([32, 10, 28, 28])
bn1:     torch.Size([32, 10, 28, 28])
relu:    torch.Size([32, 10, 28, 28])
maxpool: torch.Size([32, 10, 14, 14])
group1:  torch.Size([32, 10, 14, 14])
group2:  torch.Size([32, 20, 7, 7])
group3:  torch.Size([32, 40, 4, 4])
avgpool: torch.Size([32, 40, 1, 1])
x.view:  torch.Size([32, 40])
out:     torch.Size([32, 10])
torch.Size([32, 10])
Success


In [24]:
# This function computes the accuracy on the test dataset
def compute_accuracy(net, testloader, device=None):
    """Return the classification accuracy of `net` on the batches of `testloader`.

    Args:
      net (nn.Module):     Classifier mapping a batch of inputs to class scores
                           of shape (batch_size, n_classes).
      testloader (iterable): Yields (inputs, labels) pairs, e.g. a DataLoader.
      device (torch.device, optional): Device to run on. Defaults to the device
                           of the model's parameters (CPU if it has none).

    Returns:
      float: Fraction of correctly classified samples, in [0, 1].

    Raises:
      ValueError: If `testloader` yields no samples.
    """
    if device is None:
        first_param = next(net.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    # Evaluate in eval mode (BatchNorm uses running statistics) but restore
    # the caller's mode afterwards so a subsequent training loop is unaffected.
    was_training = net.training
    net.eval()
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for images, labels in testloader:
                images, labels = images.to(device), labels.to(device)
                outputs = net(images)
                predicted = outputs.argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
    finally:
        net.train(was_training)

    if total == 0:
        raise ValueError("testloader yielded no samples; accuracy is undefined")
    return correct / total

In [25]:
# Create the network: three groups of two residual blocks, 16 channels in the
# first group, placed on the selected device and displayed below.
n_blocks = [2, 2, 2]  # number of blocks in the three groups
net = ResNet(n_blocks, n_channels=16).to(device)
net

ResNet(
  (conv1): Conv2d(1, 16, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), bias=False)
  (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (group1): GroupOfBlocks(
    (group): Sequential(
      (0): Block(
        (block): Sequential(
          (0): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU()
          (3): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (4): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (1): Block(
        (block): Sequential(
          (0): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): BatchNorm2d(16, eps=1e-05, m

In [26]:
# Training Loop: optimise `net` on the training set with Adam and cross-entropy
# loss, printing the current mini-batch loss every 100 iterations.
if not skip_training:
    optimizer = optim.Adam(net.parameters(), lr=0.01)
    criterion = nn.CrossEntropyLoss()
    epochs = 10
    total_loss = 0.0  # sum of mini-batch losses over all epochs

    # compute_accuracy() puts the model into eval mode; switch back so that
    # BatchNorm normalises with batch statistics and updates its running stats.
    net.train()
    for epoch in range(epochs):
        print('Epoch: {}'.format(epoch))
        for idx, (train_x, train_label) in enumerate(trainloader):
            # Move the batch to the model's device (required when device is a GPU).
            train_x = train_x.to(device, dtype=torch.float)
            train_label = train_label.to(device, dtype=torch.long)

            optimizer.zero_grad()
            predict_y = net(train_x)
            loss = criterion(predict_y, train_label)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            if idx % 100 == 0:
                print('idx: {}, loss: {}'.format(idx, loss.item()))

Epoch: 0
idx: 0, loss: 2.332498550415039
idx: 100, loss: 0.7401666641235352
idx: 200, loss: 0.3168022930622101
idx: 300, loss: 0.6699246764183044
idx: 400, loss: 0.3960414230823517
idx: 500, loss: 0.38464730978012085
idx: 600, loss: 0.2812509536743164
idx: 700, loss: 0.3503393828868866
idx: 800, loss: 0.32765451073646545
idx: 900, loss: 0.4673905372619629
idx: 1000, loss: 0.18668022751808167
idx: 1100, loss: 0.27778181433677673
idx: 1200, loss: 0.42161211371421814
idx: 1300, loss: 0.44894370436668396
idx: 1400, loss: 0.1765088438987732
idx: 1500, loss: 0.34054359793663025
idx: 1600, loss: 0.4407752752304077
idx: 1700, loss: 0.2971867024898529
idx: 1800, loss: 0.19874830543994904
Epoch: 1
idx: 0, loss: 0.3573720455169678
idx: 100, loss: 0.29730889201164246
idx: 200, loss: 0.3870174288749695
idx: 300, loss: 0.18340453505516052
idx: 400, loss: 0.4064258635044098
idx: 500, loss: 0.5987681150436401
idx: 600, loss: 0.3327520489692688
idx: 700, loss: 0.3116636276245117
idx: 800, loss: 0.18396

In [27]:
# Model persistence: when training was skipped, rebuild the architecture and
# load the saved weights into it; otherwise save the freshly trained model.
if skip_training:
    net = ResNet(n_blocks, n_channels=16)
    tools.load_model(net, '3_resnet.pth', device)
else:
    tools.save_model(net, '3_resnet.pth')

Do you want to save the model (type yes to confirm)? yes
Model saved to 3_resnet.pth.


In [28]:
# Compute the accuracy on the test set
accuracy = compute_accuracy(net, testloader)
print('Accuracy of the network on the test images: %.3f' % accuracy)

# Count residual blocks under a dedicated name: `n_blocks` holds the
# architecture list ([2, 2, 2]) that the save/load cell passes to ResNet, and
# overwriting it with an int would make that cell fail when re-run.
n_residual_blocks = sum(type(m) is Block for m in net.modules())
assert n_residual_blocks == 6, f"Wrong number ({n_residual_blocks}) of blocks used in the network."

assert accuracy > 0.9, "Poor accuracy ({:.3f})".format(accuracy)
print('Success')

Accuracy of the network on the test images: 0.917
Success
