## Introduction to Image Processing with Python
### Image Processing (RM1-VIS)
### University of Southern Denmark

In [7]:
import numpy as np
import matplotlib.pyplot as plt
import skimage

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

Implement a convolutional neural network for MNIST digit classification using the PyTorch framework.
* Install with "pip install torchvision". Use a virtual environment.
* Let PyTorch (torchvision) download the full MNIST dataset in a format suitable for its DataLoader.
* The training loop should perform the steps outlined in Table 12.6.
* A good starting point for the network structure is given in Example 12.17. 
* Use F.softmax() in the output layer so that each output lies between 0 and 1 and the outputs for one image sum to 1.
* Optional: You may be able to get faster and more stable training by adding nn.Dropout() and nn.BatchNorm2d() layers.

In [None]:
# CNN network for MNIST digit classification.
class CNN(nn.Module):
    """Small LeNet-style convolutional network for 1x28x28 digit images.

    Layer shapes for a batch of B single-channel 28x28 images:
        conv1   1 -> 6 channels, 5x5 kernel, padding 2  -> (B, 6, 28, 28), ReLU
        maxpool 2x2                                     -> (B, 6, 14, 14)
        conv2   6 -> 16 channels, 5x5 kernel            -> (B, 16, 10, 10), ReLU
        maxpool 2x2                                     -> (B, 16, 5, 5)
        flatten                                         -> (B, 400)
        fc1 400 -> 120, ReLU; fc2 120 -> 84, ReLU; fc3 84 -> num_classes
        softmax over the class dimension

    Args:
        num_classes: number of output classes (defaults to 10, one per digit).
    """

    def __init__(self, num_classes=10):
        super(CNN, self).__init__()

        # Feature extractor. padding=2 on conv1 keeps the 28x28 input size so the
        # flattened feature size after two 2x2 poolings is 16 * 5 * 5.
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5, padding=2)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
        self.pool = nn.MaxPool2d(2)

        # Classifier head.
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x):
        """Map an image batch of shape (B, 1, 28, 28) to (B, num_classes) probabilities."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.flatten(start_dim=1)  # keep the batch dimension, flatten the rest
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        # The exercise asks for a softmax output layer: each row becomes a
        # non-negative distribution over the classes that sums to 1.
        return F.softmax(self.fc3(x), dim=1)


In [None]:
# Choose device: the GPU when one is available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using', device, '\n')

# Seed the RNG so weight initialisation and batch shuffling are reproducible.
torch.manual_seed(0)

NUM_CLASSES = 10   # digits 0-9
NUM_EPOCHS = 10

model = CNN()
model.to(device)  # nn.Module.to moves the parameters in place

# MNIST dataset (downloaded to ./data on first run)
train_dataset = datasets.MNIST(root='./data', train=True,
                               download=True, transform=ToTensor())
train_loader = DataLoader(train_dataset, batch_size=50, shuffle=True)

test_dataset = datasets.MNIST(root='./data', train=False,
                              download=True, transform=ToTensor())
test_loader = DataLoader(test_dataset, batch_size=50, shuffle=False)

# MSELoss compares the network output element-wise against a target of the
# same shape and dtype, so the integer class labels are one-hot encoded as
# float vectors below before being passed to the loss.
loss_fun = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.010)

# Train the network
for epoch in range(NUM_EPOCHS):
    correct_train = 0
    total_train = 0
    running_loss = 0.0

    model.train()  # training mode
    print('epoch #{}'.format(epoch+1))
    for inputs, labels in train_loader:
        # Tensor.to() returns a new tensor rather than moving it in place,
        # so the result must be assigned back.
        inputs, labels = inputs.to(device), labels.to(device)
        targets = F.one_hot(labels, num_classes=NUM_CLASSES).float()

        # set gradient's parameter to zero
        optimizer.zero_grad()

        # forward pass & back propagation
        outputs = model(inputs)
        loss = loss_fun(outputs, targets)
        loss.backward()
        optimizer.step()

        # loss and accuracy bookkeeping (loss is a per-batch mean, so weight
        # it by the batch size before averaging over the epoch)
        running_loss += loss.item() * labels.size(0)
        predicted = outputs.argmax(dim=1)
        correct_train += (predicted == labels).sum().item()
        total_train += labels.size(0)

    print('Training Loss: %0.4f' % (running_loss / total_train))
    print('Training Accuracy: %0.3f' % ((100 * correct_train) / total_train))


# Test #
model.eval()
correct_test = 0
total_test = 0
test_loss = 0.0

with torch.no_grad():
    for inputs, labels in test_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        targets = F.one_hot(labels, num_classes=NUM_CLASSES).float()

        outputs = model(inputs)
        loss = loss_fun(outputs, targets)
        test_loss += loss.item() * labels.size(0)

        # test accuracy calculation
        predicted = outputs.argmax(dim=1)
        correct_test += (predicted == labels).sum().item()
        total_test += labels.size(0)

print('\nTest Accuracy: %0.3f' % ((100 * correct_test) / total_test))
print('Test Loss: %0.4f' % (test_loss / total_test))

In [None]:
# Visualize example results