# AlexNet of Fashion-MNIST

In [1]:
import cv2
import torch
import time
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch import optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision.transforms import ToPILImage
show=ToPILImage()

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)

cuda:0


In [3]:
torch.cuda.get_device_name(0)

'GeForce GTX 1080 Ti'

In [4]:
batch_size = 32 #512
resize = 32
mean = (0.1307,)
std = (0.3081,)

In [5]:
def imshow(img):
    img_clone = img.clone().cpu()
    img_clone = img_clone / 2 + 0.5
    npimg = img_clone.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))

## Load data

In [6]:
transform = transforms.Compose([transforms.Resize((resize, resize)),
                                transforms.ToTensor(),
                                transforms.Normalize(mean, std)])

trainset = torchvision.datasets.FashionMNIST(root='./data', train=True,
                                             download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=0)

testset = torchvision.datasets.FashionMNIST(root='./data', train=False,
                                            download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=0)

In [7]:
classes = {0: "T-shirt/top",
           1: "Trouser",
           2: "Pullover",
           3: "Dress",
           4: "Coat",
           5: "Sandal",
           6: "Shirt",
           7: "Sneaker",
           8: "Bag",
           9: "Ankle boot"}

## Network

In [8]:
class AlexNet(nn.Module):

    def __init__(self, num_classes=10, input_channels=1):
        super().__init__()
        
        self.features = nn.Sequential(
            nn.Conv2d(input_channels, 64, kernel_size=11, stride=4, padding=5),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.classifier = nn.Linear(256, num_classes)

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

In [9]:
net = AlexNet().cuda()
print(net)

AlexNet(
  (features): Sequential(
    (0): Conv2d(1, 64, kernel_size=(11, 11), stride=(4, 4), padding=(5, 5))
    (1): ReLU(inplace)
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace)
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace)
    (12): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (classifier): Linear(in_features=256, out_features=10, bias=True)
)


## Training

In [10]:
criterion=nn.CrossEntropyLoss()
optimizer=optim.Adam(net.parameters(),lr=0.001, weight_decay=5e-4)

In [11]:
debug_mode = False

In [12]:
num_epochs = 100

for epoch in range(num_epochs):
    start = time.time()
    running_loss = 0
    for i, data in enumerate(trainloader):
        images, labels = data

        images = images.cuda()
        labels = labels.cuda()

        if debug_mode:
            imshow(torchvision.utils.make_grid(images))
            plt.show()
            print([classes[lab] for lab in labels.clone().cpu().numpy()])
        
        optimizer.zero_grad()

        outputs = net(images)
        
        if debug_mode:
            print(images.shape)
            print(outputs)
            
        loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()

        running_loss += loss
        
    end = time.time()
    print(f'Epoch: {epoch} | loss: {running_loss:0.7f} | time: {end - start:0.3f} s')

Epoch: 0 | loss: 968.8456421 | time: 28.950 s
Epoch: 1 | loss: 658.5689697 | time: 28.312 s
Epoch: 2 | loss: 590.6651001 | time: 28.398 s
Epoch: 3 | loss: 545.9324951 | time: 28.494 s
Epoch: 4 | loss: 519.1873169 | time: 28.524 s
Epoch: 5 | loss: 496.8948669 | time: 28.171 s
Epoch: 6 | loss: 475.8588257 | time: 28.318 s
Epoch: 7 | loss: 462.8978271 | time: 28.262 s
Epoch: 8 | loss: 448.2656860 | time: 28.296 s
Epoch: 9 | loss: 441.5274048 | time: 28.209 s
Epoch: 10 | loss: 431.1635437 | time: 28.318 s
Epoch: 11 | loss: 426.8786316 | time: 28.433 s
Epoch: 12 | loss: 415.0750732 | time: 28.239 s
Epoch: 13 | loss: 409.8858948 | time: 28.173 s
Epoch: 14 | loss: 398.1006775 | time: 28.371 s
Epoch: 15 | loss: 394.2907104 | time: 28.322 s
Epoch: 16 | loss: 389.6551208 | time: 28.241 s
Epoch: 17 | loss: 383.8843994 | time: 28.348 s
Epoch: 18 | loss: 378.9812012 | time: 28.282 s
Epoch: 19 | loss: 372.3363342 | time: 28.289 s
Epoch: 20 | loss: 368.4988708 | time: 28.330 s
Epoch: 21 | loss: 371.4

## Test

In [13]:
net.eval()

correct=0
total=0
for data in testloader:
    images,labels = data
    images = images.cuda()
    labels = labels.cuda()
    outputs = net(images)
    _, predicted = torch.max(outputs, dim=1)
    total += labels.size(0)
    correct += (predicted == labels).sum()

In [14]:
print(f'Accuracy of the network on the {total} test images: {correct.to(dtype=torch.float) / float(total):.4%}')

Accuracy of the network on the 10000 test images: 88.8500%
