# PyTorch - Fashion-MNIST

### Setup notebook

In [None]:
# from __future__ import print_function, division
# import argparse
# import itertools
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
# import pandas as pd
# from pathlib import Path
# import os
# from PIL import Image
# import random
# import shutil
# import sys
import time


import torch 
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.utils.data as data
from torch.utils.data import DataLoader
import torchvision
import torchvision.datasets as datasets
#from torchvision import datasets
import torchvision.models as models
import torchvision.transforms as transforms

from mymods.lauthom import *

# plot inline
%matplotlib inline

# Seed NumPy's and PyTorch's global RNGs so runs are reproducible.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# Switch matplotlib to interactive mode so figures draw without blocking.
plt.ion()

### Data

In [None]:
# Call the project helper get_path (star-imported from mymods.lauthom above).
# NOTE(review): its behaviour is not visible in this notebook; the arguments
# look like glob patterns for locating the FashionMNIST data — confirm.
get_path('*', 'fashionmnist/*')

In [None]:
# Check data directory
data_dir = '../../_data/fashionmnist'

from pathlib import Path

# List the directory contents with pathlib instead of shelling out to `ls`,
# which is not portable and raises CalledProcessError when the directory is
# missing. The folder may not exist yet: the dataset is only downloaded
# (download=True) in the loaders cell further down.
_data_path = Path(data_dir)
if _data_path.is_dir():
    print('\n'.join(sorted(entry.name for entry in _data_path.iterdir())))
else:
    print(f"{data_dir} does not exist yet; the download below should create it.")

#### Transformations

In [None]:
import torchvision.transforms as transforms

# Training pipeline: light geometric augmentation, then conversion to a tensor.
_train_steps = [
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(5),  # rotate by up to +/- 5 degrees
    transforms.ToTensor(),
]

# Validation pipeline: no augmentation, only tensor conversion.
_val_steps = [transforms.ToTensor()]

# Keyed by phase name so the loaders can pick the pipeline per phase.
data_transforms = {
    'train': transforms.Compose(_train_steps),
    'val': transforms.Compose(_val_steps),
}

#### Loaders

In [None]:
M_BATCH = 8
WORKERS = 0  # load in the main process (disables the multi-worker error message)
phases = ['train', 'val']

# download to processed and raw folders.
# `train` picks the split: True -> training split, False -> held-out test split.
# Without it both phases loaded the training split, so "validation" was measured
# on the very images the model was trained on.
image_datasets = {x: datasets.FashionMNIST(data_dir,
                                           train=(x == 'train'),
                                           download=True,
                                           transform=data_transforms[x])
                  for x in phases}

# Shuffle only for training; evaluation metrics do not depend on order.
dataloaders = {x: data.DataLoader(image_datasets[x],
                                  batch_size=M_BATCH,
                                  shuffle=(x == 'train'),
                                  num_workers=WORKERS)
              for x in phases}

dataset_sizes = {x: len(image_datasets[x]) for x in phases}

print(image_datasets['train'])
dataloaders['train'].dataset

In [None]:
# Peek at the label tensor of one shuffled training batch.
_first_batch = next(iter(dataloaders['train']))
print(_first_batch[1])

#### Labels

In [None]:
# Human-readable name for each label index (position in the list = label).
class_names = [
    'T-shirt/top',  # 0
    'Trouser',      # 1
    'Pullover',     # 2
    'Dress',        # 3
    'Coat',         # 4
    'Sandal',       # 5
    'Shirt',        # 6
    'Sneaker',      # 7
    'Bag',          # 8
    'Ankle boot',   # 9
]

#### Visualise data

In [None]:
def imshow(inp, title=None):
    """Show an image tensor with matplotlib.

    Args:
        inp: image tensor shaped (C, H, W) — e.g. the output of
            ``torchvision.utils.make_grid`` — or (H, W). Single-channel
            images are drawn with a grey colormap.
        title: optional title. A list/tuple (e.g. one label per grid tile)
            is joined with ", " rather than shown as a Python list repr.
    """
    # detach/cpu so tensors that require grad or live on a GPU also work
    img = inp.detach().cpu().numpy()
    if img.ndim == 3:
        img = img.transpose((1, 2, 0))  # (C, H, W) -> (H, W, C) for matplotlib
        if img.shape[2] == 1:
            # Some matplotlib versions reject (H, W, 1); pass plain (H, W).
            img = img[:, :, 0]
    plt.imshow(img, cmap='gray' if img.ndim == 2 else None)
    plt.axis('off')
    if title is not None:
        if isinstance(title, (list, tuple)):
            title = ', '.join(str(t) for t in title)
        plt.title(title)
    plt.pause(0.001)  # pause a bit so that plots are updated


# Draw one shuffled mini-batch from the training loader.
inputs, classes = next(iter(dataloaders['train']))

# Tile the batch into a single image, then display it with one label per tile.
out = torchvision.utils.make_grid(inputs)
batch_titles = [class_names[c] for c in classes]
plt.figure(figsize=(20, 5))
imshow(out, title=batch_titles)

### Create model

In [None]:
class FashionMnistNet(nn.Module):
    """Small VGG-style CNN for single-channel 28x28 images.

    Two blocks of (3x3 conv -> ReLU -> 3x3 conv -> ReLU -> 2x2 max-pool) take
    an (N, 1, 28, 28) batch to (N, 64, 7, 7). A head of dropout and a
    two-layer MLP then maps the flattened features to ``num_classes`` logits.
    The spatial size is fixed by the ``64 * 7 * 7`` input of the first linear
    layer, so inputs must be 28x28.

    Args:
        num_classes: number of output logits (default 10, one per
            FashionMNIST class).
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.features = nn.Sequential(
            # padding=1 keeps H and W unchanged through each 3x3 conv
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),  # 28x28 -> 14x14

            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),  # 14x14 -> 7x7
        )
        self.classifier = nn.Sequential(
            nn.Dropout(),  # p=0.5 (PyTorch default); active only in train mode
            nn.Linear(64 * 7 * 7, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, num_classes),
        )

    def forward(self, x):
        """Return raw (un-softmaxed) class scores of shape (N, num_classes)."""
        x = self.features(x)
        # Flatten all but the batch dimension. Note `view` is NOT the same as
        # `reshape`: `view` never copies and fails on incompatible strides,
        # whereas `flatten` copies when it must.
        x = torch.flatten(x, 1)
        return self.classifier(x)

In [None]:
# Instantiate the model and print its layer summary.
net = FashionMnistNet()
print(net)

# Number of full passes over the training data.
EPOCHS = 4

### Loss and optimiser

In [None]:
# Cross-entropy on raw logits (the network ends in a Linear layer, no softmax).
criterion = nn.CrossEntropyLoss()
# SGD with momentum over all model parameters.
optimizer = optim.SGD(net.parameters(), momentum=0.9, lr=0.01)
# Lower the learning rate when the monitored value (the validation loss passed
# to scheduler.step) stops decreasing.
# NOTE(review): ReduceLROnPlateau's default patience is 10 epochs, so with
# EPOCHS = 4 the learning rate is never reduced; consider a smaller `patience`.
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min')

### Train and validate model

#### Print stats - helper functions

In [None]:
def time_format(secs):
    """Format a duration in seconds as ``h:mm:ss``.

    Fractional seconds are truncated, and hours are not wrapped at 24.
    """
    total_minutes, seconds = divmod(secs, 60)
    hours, minutes = divmod(total_minutes, 60)
    return "{}:{:02d}:{:02d}".format(int(hours), int(minutes), int(seconds))


def print_header():
    """Print a blank line, then the column titles of the per-phase stats table."""
    # Widths/tabs chosen to line up roughly with the rows printed by print_stat.
    h_template = "{:8}\t\t {:8}\t\t    {:12}\t {:8}\t\t {:8}"
    print()
    print(h_template.format('Phase', 'Epoch', 'Loss', 'Accuracy', 'Duration'))
            

def print_stat(phase, epoch, loss, acc, duration):
    """Print one row of the stats table.

    Args:
        phase: phase label, e.g. 'train' or 'val'.
        epoch: epoch index.
        loss: loss value to report (4 decimals).
        acc: accuracy as a fraction; printed as a percentage (1 decimal).
        duration: elapsed seconds, rendered as h:mm:ss via time_format.
    """
    row = "{:8}\t\t {:8}\t\t {:8.4f}\t\t    {:8.1f}\t\t {:8}".format(
        phase, epoch, loss, 100 * acc, time_format(duration))
    print(row)

In [None]:
def train_val(net, loader, scheduler, criterion, optimizer, phase):
    """Run one pass over ``loader``; update ``net`` only when ``phase == 'train'``.

    Args:
        net: model to run; switched to train or eval mode here.
        loader: iterable of ``(inputs, targets)`` batches. Targets are assumed
            to be class indices (they are compared against the argmax of the
            model output).
        scheduler: unused here; the caller steps it. Kept for interface
            compatibility.
        criterion: loss function called as ``criterion(output, targets)``.
        optimizer: optimizer for ``net``'s parameters (used in 'train' only).
        phase: ``'train'`` to back-propagate and step; anything else only
            evaluates.

    Returns:
        ``(mean loss per sample, accuracy as a fraction in [0, 1])``.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    is_train = phase == 'train'
    net.train(is_train)
    # With a 'mean' reduction (CrossEntropyLoss's default), each loss value is
    # a batch average. Weight it by the batch size so the epoch figure is a
    # true per-sample mean. The old code summed batch means and divided by the
    # sample count, which understated the loss by ~batch_size.
    per_batch_mean = getattr(criterion, 'reduction', 'mean') == 'mean'

    running_loss = 0.0
    running_correct = 0
    count = 0

    # Only build the autograd graph when we will back-propagate; evaluation
    # skips gradient bookkeeping. Inputs never need requires_grad.
    with torch.set_grad_enabled(is_train):
        for X, y in loader:
            output = net(X)
            loss = criterion(output, y)

            if is_train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            batch_size = y.size(0)
            batch_loss = loss.item()
            running_loss += (batch_loss * batch_size) if per_batch_mean else batch_loss
            # Compare as tensors. The old squeeze().numpy() path crashed on a
            # final batch of size 1 (a 0-d array has no len()).
            pred = output.argmax(dim=1)
            running_correct += (pred == y).sum().item()
            count += batch_size

    if count == 0:
        raise ValueError("loader yielded no samples")
    return running_loss / count, running_correct / count

In [None]:
# Train for EPOCHS epochs. Each epoch runs a training pass and then a
# validation pass, printing one stats row per phase.
for epoch in range(EPOCHS):
    print_header()

    for phase in phases:
        # Time each phase on its own. Previously `start` was set once per
        # epoch, so the 'val' row reported cumulative train + val time.
        start = time.time()
        loss, acc = train_val(net, dataloaders[phase], scheduler, criterion, optimizer, phase)

        if phase == 'val':
            # ReduceLROnPlateau ('min' mode) monitors the validation loss.
            scheduler.step(loss)

        print_stat(phase, epoch, loss, acc, time.time() - start)