# Fashion_MNIST by ResNet18
##### written by LarryHYQ

## TRAIN

### Import some necessary packages

In [19]:
import os
import sys
import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
from tqdm.notebook import tqdm
import IPython.display as display
%matplotlib inline
from matplotlib import pyplot as plt

# NOTE(review): presumably a workaround for the "duplicate OpenMP runtime"
# abort that can occur when several libraries each load their own OpenMP
# library — confirm it is still needed on the target machine.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

### Setting some Parameters

In [20]:
# Hyper-parameters and bookkeeping values shared, as module-level globals,
# by the functions defined in the cells below.
lr = 0.001  # learning rate handed to the Adam optimizer in set_model()
batch_size_train = 600  # mini-batch size of the training DataLoader
batch_size_test = 1000  # mini-batch size of the test DataLoader
best_acc = 0  # best test accuracy so far, in percent; updated by test()
start_epoch = 0  # first epoch index; set_model() overrides it when resuming from a checkpoint
all_epochs = 200  # number of epochs the main loop runs, counted from start_epoch
writer = SummaryWriter("runs/Fashion_MNIST_ResNet_experiment")  # TensorBoard writer used by plot_tb/train/test

### Setting GPU or CPU

In [21]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

### Build the ResNet

In [22]:
class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions plus a shortcut branch.

    The main branch is conv3x3 -> BN -> ReLU -> conv3x3 -> BN. Its result is
    added to the shortcut branch and passed through a final ReLU. The
    shortcut is an empty ``nn.Sequential`` (identity) unless the stride or
    the channel count changes, in which case it is a 1x1 convolution
    followed by batch normalization.

    Args:
        in_planes: number of input channels.
        planes: number of channels produced by both 3x3 convolutions.
        stride: stride of the first convolution (and of the projection
            shortcut, when one is needed).
    """

    # Ratio between the block's output channels and ``planes``.
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        # Main branch. Submodule names and creation order are kept stable so
        # that state_dict keys and seeded initialisation do not change.
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        # Shortcut branch: project only when the shapes would not match.
        projection = []
        if stride != 1 or in_planes != out_planes:
            projection = [
                nn.Conv2d(in_planes, out_planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            ]
        self.shortcut = nn.Sequential(*projection)

    def forward(self, x):
        """Return ``relu(main_branch(x) + shortcut(x))``."""
        hidden = F.relu(self.bn1(self.conv1(x)))
        summed = self.bn2(self.conv2(hidden))
        summed += self.shortcut(x)
        return F.relu(summed)


class Bottleneck(nn.Module):
    """Residual block with a 1x1 -> 3x3 -> 1x1 convolution stack.

    The first 1x1 convolution maps ``in_planes`` to ``planes`` channels, the
    3x3 convolution applies the stride, and the last 1x1 convolution expands
    to ``expansion * planes`` channels. Each convolution is followed by batch
    normalization. The result is added to the shortcut branch and passed
    through a final ReLU. The shortcut is an empty ``nn.Sequential``
    (identity) unless the stride or the channel count changes, in which case
    it is a 1x1 convolution followed by batch normalization.

    Args:
        in_planes: number of input channels.
        planes: channel width of the two inner convolutions.
        stride: stride of the 3x3 convolution (and of the projection
            shortcut, when one is needed).
    """

    # Ratio between the block's output channels and ``planes``.
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        # Main branch. Submodule names and creation order are kept stable so
        # that state_dict keys and seeded initialisation do not change.
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)

        # Shortcut branch: project only when the shapes would not match.
        projection = []
        if stride != 1 or in_planes != out_planes:
            projection = [
                nn.Conv2d(in_planes, out_planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            ]
        self.shortcut = nn.Sequential(*projection)

    def forward(self, x):
        """Return ``relu(main_branch(x) + shortcut(x))``."""
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = F.relu(self.bn2(self.conv2(hidden)))
        summed = self.bn3(self.conv3(hidden))
        summed += self.shortcut(x)
        return F.relu(summed)


class ResNet(nn.Module):
    """ResNet with a 3x3 stride-1 stem, four residual stages and a linear head.

    Args:
        block: residual block class (or factory). It is called as
            ``block(in_planes, planes, stride)`` and must expose an
            ``expansion`` attribute giving its output-channel multiplier.
        num_blocks: sequence of four ints, the number of blocks per stage.
        num_classes: size of the output logits vector.
        in_channels: number of channels of the input images. Defaults to 3,
            the value that used to be hard-coded; pass 1 for single-channel
            (greyscale) input.
    """

    def __init__(self, block, num_blocks, num_classes=10, in_channels=3):
        super(ResNet, self).__init__()
        # Running channel count, updated by _make_layer as stages are built.
        self.in_planes = 64

        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512*block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage: ``num_blocks`` blocks, only the first one strided."""
        layers = []
        for block_stride in [stride] + [1]*(num_blocks-1):
            layers.append(block(self.in_planes, planes, block_stride))
            # The next block consumes what this one produces.
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        """Map a batch of images to a ``(batch, num_classes)`` logits tensor."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        # Global average pooling. For a 4x4 final feature map this equals the
        # former avg_pool2d(out, 4), but it also yields one value per channel
        # for other input resolutions, so the linear layer always fits.
        out = F.adaptive_avg_pool2d(out, 1)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


def ResNet18():
    """Build a ResNet from BasicBlock with stage depths [2, 2, 2, 2]."""
    stage_depths = [2, 2, 2, 2]
    return ResNet(BasicBlock, stage_depths)


def ResNet34():
    """Build a ResNet from BasicBlock with stage depths [3, 4, 6, 3]."""
    stage_depths = [3, 4, 6, 3]
    return ResNet(BasicBlock, stage_depths)


def ResNet50():
    """Build a ResNet from Bottleneck with stage depths [3, 4, 6, 3]."""
    stage_depths = [3, 4, 6, 3]
    return ResNet(Bottleneck, stage_depths)


def ResNet101():
    """Build a ResNet from Bottleneck with stage depths [3, 4, 23, 3]."""
    stage_depths = [3, 4, 23, 3]
    return ResNet(Bottleneck, stage_depths)


def ResNet152():
    """Build a ResNet from Bottleneck with stage depths [3, 8, 36, 3]."""
    stage_depths = [3, 8, 36, 3]
    return ResNet(Bottleneck, stage_depths)

### Preparing Data

In [23]:
def prepare_data():
    """Download Fashion-MNIST and build the train/test datasets and loaders.

    Publishes four module-level globals: ``trainset``, ``trainloader``,
    ``testset`` and ``testloader``. Batch sizes come from the module-level
    ``batch_size_train`` / ``batch_size_test``.

    Training images are augmented with a padded random 32x32 crop and a
    random horizontal flip; test images are only converted and normalized.
    """
    global trainset, trainloader, testset, testloader
    print('==> Preparing data..')

    # Fashion-MNIST images are single-channel, while the first layer of the
    # ResNet in this file is Conv2d(3, 64, ...) by default. Replicating the
    # grey channel three times lets the images pass through that layer
    # instead of failing with a channel-mismatch error.
    to_three_channels = transforms.Grayscale(num_output_channels=3)
    # The single mean/std pair is broadcast over all channels.
    normalize = transforms.Normalize((0.5,), (0.5,))

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        to_three_channels,
        transforms.ToTensor(),
        normalize,
    ])

    # NOTE(review): no crop/pad is applied here, so test images keep their
    # stored size (28x28 for Fashion-MNIST, as far as I know) while training
    # crops are 32x32 — confirm this mismatch is intended.
    transform_test = transforms.Compose([
        to_three_channels,
        transforms.ToTensor(),
        normalize,
    ])

    trainset = torchvision.datasets.FashionMNIST(
        root='./data', train=True, download=True, transform=transform_train)
    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=batch_size_train, shuffle=True, num_workers=2)

    testset = torchvision.datasets.FashionMNIST(
        root='./data', train=False, download=True, transform=transform_test)
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=batch_size_test, shuffle=False, num_workers=2)

In [24]:
# Human-readable names of the ten Fashion-MNIST categories — presumably
# ordered by integer label (0-9); verify against the dataset before relying
# on the order. Not referenced by the code visible in this notebook.
classes = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

### Setting the Model

In [25]:
def _align_state_dict_keys(state_dict, model):
    """Return *state_dict* with 'module.' prefixes matching *model*'s wrapping.

    ``nn.DataParallel`` stores its parameters under a ``module.`` prefix, so
    a checkpoint written by a wrapped model cannot be loaded into an
    unwrapped one (and vice versa) without renaming the keys.
    """
    prefix = 'module.'
    model_wrapped = isinstance(model, torch.nn.DataParallel)
    saved_wrapped = bool(state_dict) and all(
        key.startswith(prefix) for key in state_dict)
    if saved_wrapped == model_wrapped:
        return state_dict
    if saved_wrapped:
        return {key[len(prefix):]: value for key, value in state_dict.items()}
    return {prefix + key: value for key, value in state_dict.items()}


def set_model():
    """Build the network, optionally resume from checkpoints, set up training.

    Publishes the module-level globals ``net``, ``criterion`` and
    ``optimizer``. When the user answers ``y``/``Y`` at the prompt, the
    weights are restored from the "last" checkpoint, ``start_epoch`` is set
    to the epoch after the saved one, and ``best_acc`` is read from the
    "best" checkpoint.

    Raises:
        FileNotFoundError: if resuming is requested but there is no
            ``checkpoint`` directory.
    """
    global net, optimizer, criterion, start_epoch, best_acc
    print('==> Building model..')
    # NOTE(review): the notebook title mentions ResNet18 but ResNet50 is
    # built here — confirm which architecture is intended.
    net = ResNet50()
    net = net.to(device)
    if device == 'cuda':
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = True

    resume_flag = input("Whether to continue training?(Y/N)\n")

    if resume_flag.strip().lower() == 'y':
        # Load checkpoint.
        print('==> Resuming from checkpoint..')
        # Explicit raise rather than assert: asserts are stripped under -O.
        if not os.path.isdir('checkpoint'):
            raise FileNotFoundError('Error: no checkpoint directory found!')
        # map_location lets a checkpoint saved on GPU be resumed on CPU
        # (and the other way round).
        last_checkpoint = torch.load(
            './checkpoint/fashionmnist_resnet_last.pth', map_location=device)
        best_checkpoint = torch.load(
            './checkpoint/fashionmnist_resnet_best.pth', map_location=device)
        net.load_state_dict(
            _align_state_dict_keys(last_checkpoint['net'], net))
        best_acc = best_checkpoint['acc']
        start_epoch = last_checkpoint['epoch'] + 1

    criterion = F.cross_entropy
    # The optimizer state is not part of the checkpoints, so Adam's moment
    # estimates start fresh after a resume.
    optimizer = optim.Adam(net.parameters(), lr=lr)

### Plot in Tensorboard

In [26]:
def plot_tb():
    """Log one training batch and the model graph to TensorBoard.

    Reads the module-level ``trainloader``, ``net`` and ``writer``. The first
    batch of the loader is written as an image grid under the tag "image"
    and is also used as the example input for ``add_graph``.
    """
    images, _ = next(iter(trainloader))
    grid = torchvision.utils.make_grid(images)
    writer.add_image("image", grid)
    writer.add_graph(net, images)
    # Flush instead of close: the same writer is used again by train() and
    # test(), and writing to a closed SummaryWriter makes it open a new
    # event file.
    writer.flush()

### Training

In [27]:
def train(epoch):
    """Train ``net`` for one pass over ``trainloader`` and log the results.

    Uses the module-level ``net``, ``optimizer``, ``criterion``,
    ``trainloader``, ``device`` and ``writer``. The running mean loss and
    accuracy are shown in the progress bar; the epoch's mean batch loss and
    its accuracy (as a fraction) are written to TensorBoard under
    "TRAIN/Loss" and "TRAIN/Accuracy".

    Args:
        epoch: epoch index, used for the progress-bar label and as the
            TensorBoard step.

    Raises:
        RuntimeError: if ``trainloader`` yields no samples.
    """
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    batches_done = 0
    pbar = tqdm(enumerate(trainloader, start=1), desc=f'TRAIN Epoch {epoch}',
                total=len(trainloader))
    for batches_done, (inputs, targets) in pbar:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        pbar.set_postfix({"Loss":f"{(train_loss/batches_done):.3f}","Acc":f"{100*correct/total:.3f} ({correct}/{total})"})

    if total == 0:
        # Previously this surfaced as a NameError on the loop variable.
        raise RuntimeError('trainloader yielded no samples')

    writer.add_scalar("TRAIN/Loss", train_loss/batches_done, epoch)
    # Tag spelling fixed ("Acuracy" -> "Accuracy") to match "TEST/Accuracy".
    writer.add_scalar("TRAIN/Accuracy", correct/total, epoch)
    # Flush instead of close: the writer is reused every epoch, and writing
    # to a closed SummaryWriter makes it open a new event file each time.
    writer.flush()

### Testing

In [28]:
def test(epoch):
    """Evaluate ``net`` on ``testloader``, log the results, save checkpoints.

    Uses the module-level ``net``, ``criterion``, ``testloader``, ``device``
    and ``writer``. The mean batch loss and the accuracy (as a fraction) are
    written to TensorBoard under "TEST/Loss" and "TEST/Accuracy".

    The model weights, the accuracy in percent and the epoch index are always
    saved to ``./checkpoint/fashionmnist_resnet_last.pth``. When the accuracy
    exceeds the module-level ``best_acc``, the same state is also saved to
    ``./checkpoint/fashionmnist_resnet_best.pth`` and ``best_acc`` is updated.

    Args:
        epoch: epoch index, used for the progress-bar label, as the
            TensorBoard step and stored in the checkpoint.

    Raises:
        RuntimeError: if ``testloader`` yields no samples.
    """
    global best_acc
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    batches_done = 0
    with torch.no_grad():
        pbar = tqdm(enumerate(testloader, start=1), desc=f'TEST Epoch {epoch}',
                    total=len(testloader))
        for batches_done, (inputs, targets) in pbar:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            pbar.set_postfix({"Loss":f"{(test_loss/batches_done):.3f}","Acc":f"{100*correct/total:.3f} ({correct}/{total})"})

    if total == 0:
        # Previously this surfaced as a NameError on the loop variable.
        raise RuntimeError('testloader yielded no samples')

    writer.add_scalar("TEST/Loss", test_loss/batches_done, epoch)
    writer.add_scalar("TEST/Accuracy", correct/total, epoch)
    # Flush instead of close: the writer is reused every epoch, and writing
    # to a closed SummaryWriter makes it open a new event file each time.
    writer.flush()

    # Save checkpoint.
    acc = 100.*correct/total
    print('Saving the last..')
    state = {
        'net': net.state_dict(),
        'acc': acc,
        'epoch': epoch,
    }

    # One directory check covers both saves below.
    os.makedirs('checkpoint', exist_ok=True)
    torch.save(state, './checkpoint/fashionmnist_resnet_last.pth')
    print("The last Acc:" + str(acc))
    if acc > best_acc:
        print('Saving the best..')
        # The "best" checkpoint has exactly the same content as the "last"
        # one written above, so the same dict is reused.
        torch.save(state, './checkpoint/fashionmnist_resnet_best.pth')
        best_acc = acc
    print("The best Acc:" + str(best_acc))

### Main

In [29]:
# One-time setup, in dependency order: build the data loaders, then the
# model/loss/optimizer (this prompts whether to resume from a checkpoint),
# then log a sample batch and the model graph to TensorBoard.
prepare_data()
set_model()
plot_tb()

In [30]:
# Run all_epochs epochs counted from start_epoch, so a resumed run trains
# all_epochs further epochs rather than stopping at a fixed total.
for epoch in range(start_epoch, start_epoch + all_epochs):
    train(epoch)
    test(epoch)  # also writes the "last" checkpoint and, on improvement, the "best" one
    display.clear_output(wait=True)  # clear this cell's output once new output is available