# AlexNet on Fashion-MNIST

In [1]:
import cv2
import torch
import time
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch import optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision.transforms import ToPILImage
show=ToPILImage()

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
# Pick the compute device: the first CUDA GPU when one is available, else the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print(device)

cuda:0


In [3]:
# Print the name of every CUDA device visible to PyTorch (prints nothing without a GPU).
gpu_count = torch.cuda.device_count()
for i in range(gpu_count):
    gpu_name = torch.cuda.get_device_name(i)
    print(gpu_name)

GeForce RTX 2080 Ti


In [4]:
# Mini-batch size shared by the train and test data loaders.
batch_size = 512
# Images are resized to (resize, resize) before being converted to tensors.
resize = 32
# Single-channel statistics handed to transforms.Normalize.
# NOTE(review): 0.1307 / 0.3081 look like the commonly quoted MNIST statistics,
# not Fashion-MNIST ones -- confirm whether this is intentional.
mean, std = (0.1307,), (0.3081,)

In [5]:
def imshow(img, norm_mean=None, norm_std=None):
    """Display a normalized (C, H, W) image tensor with matplotlib.

    The tensor is copied to the CPU, the normalization is inverted
    (``x * std + mean``), values are clamped to ``[0, 1]`` and the result is
    drawn channel-last with ``plt.imshow``.

    Args:
        img: image tensor laid out as (C, H, W), e.g. the output of
            ``torchvision.utils.make_grid``.
        norm_mean: per-channel means used when the image was normalized.
            Defaults to the notebook-level ``mean``.
        norm_std: per-channel standard deviations used when the image was
            normalized. Defaults to the notebook-level ``std``.
    """
    if norm_mean is None:
        norm_mean = mean
    if norm_std is None:
        norm_std = std

    img_clone = img.clone().cpu()
    # Shape the statistics as (C, 1, 1) so they broadcast over (C, H, W); a
    # single-channel statistic also broadcasts over a 3-channel grid.
    mean_t = torch.tensor(norm_mean, dtype=img_clone.dtype).view(-1, 1, 1)
    std_t = torch.tensor(norm_std, dtype=img_clone.dtype).view(-1, 1, 1)
    # Invert transforms.Normalize with the statistics actually used by the
    # pipeline (the previous `/ 2 + 0.5` only inverts Normalize(0.5, 0.5)).
    img_clone = (img_clone * std_t + mean_t).clamp(0, 1)
    npimg = img_clone.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))

## Load data

In [6]:
# Preprocessing shared by both splits: resize, convert to tensor, normalize.
transform = transforms.Compose([
    transforms.Resize((resize, resize)),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

# Training split, reshuffled every epoch.
trainset = torchvision.datasets.FashionMNIST(
    root='./data', train=True, download=True, transform=transform)
trainloader = DataLoader(
    trainset, batch_size=batch_size, shuffle=True, num_workers=0)

# Test split, iterated in a fixed order.
testset = torchvision.datasets.FashionMNIST(
    root='./data', train=False, download=True, transform=transform)
testloader = DataLoader(
    testset, batch_size=batch_size, shuffle=False, num_workers=0)

In [7]:
# Label index -> human-readable class name, used when printing batch labels.
classes = dict(enumerate([
    "T-shirt/top",
    "Trouser",
    "Pullover",
    "Dress",
    "Coat",
    "Sandal",
    "Shirt",
    "Sneaker",
    "Bag",
    "Ankle boot",
]))

## Network

In [8]:
class AlexNet(nn.Module):
    """AlexNet-style convolutional classifier for small images.

    Five convolution layers with three 2x2 max-pools form ``features``; a single
    linear layer maps the 256 resulting channels to class scores.

    Args:
        num_classes: number of output logits.
        input_channels: number of channels in the input images.
    """

    def __init__(self, num_classes=10, input_channels=1):
        super().__init__()

        self.features = nn.Sequential(
            nn.Conv2d(input_channels, 64, kernel_size=11, stride=4, padding=5),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.classifier = nn.Linear(256, num_classes)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for image batch ``x``."""
        x = self.features(x)
        # Collapse whatever spatial extent remains to 1x1 so the 256-input
        # classifier also works for inputs larger than 32x32. For 32x32 inputs
        # the feature map is already 1x1 and this is an identity.
        x = F.adaptive_avg_pool2d(x, 1)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

In [9]:
# Instantiate the model on the device selected earlier instead of hard-coding CUDA.
net = AlexNet().to(device)
print(net)

AlexNet(
  (features): Sequential(
    (0): Conv2d(1, 64, kernel_size=(11, 11), stride=(4, 4), padding=(5, 5))
    (1): ReLU(inplace)
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace)
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace)
    (12): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (classifier): Linear(in_features=256, out_features=10, bias=True)
)


## Training

In [10]:
# Cross-entropy objective, optimized by Adam with a weight_decay of 5e-4.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=net.parameters(), lr=1e-3, weight_decay=5e-4)

In [11]:
# When True, the training loop shows each batch and prints labels, shapes and outputs.
debug_mode: bool = False

In [12]:
num_epochs = 100

# Put the network in training mode; the test cell further down calls net.eval(),
# so re-running this cell afterwards would otherwise train in eval mode.
net.train()

for epoch in range(num_epochs):
    start = time.time()
    # Plain Python float: the sum of the per-batch losses over the epoch.
    running_loss = 0.0
    for images, labels in trainloader:
        # Move the batch to the device selected at the top of the notebook.
        images = images.to(device)
        labels = labels.to(device)

        if debug_mode:
            imshow(torchvision.utils.make_grid(images))
            plt.show()
            print([classes[lab] for lab in labels.clone().cpu().numpy()])

        optimizer.zero_grad()

        outputs = net(images)

        if debug_mode:
            print(images.shape)
            print(outputs)

        loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()

        # .item() extracts the scalar value; adding the tensor itself would
        # keep autograd history alive across the whole epoch.
        running_loss += loss.item()

    end = time.time()
    print(f'Epoch: {epoch} | loss: {running_loss:0.7f} | time: {end - start:0.3f} s')

Epoch: 0 | loss: 92.9589996 | time: 5.999 s
Epoch: 1 | loss: 47.4661636 | time: 5.957 s
Epoch: 2 | loss: 40.0920067 | time: 5.978 s
Epoch: 3 | loss: 36.3879890 | time: 5.955 s
Epoch: 4 | loss: 33.0154381 | time: 5.967 s
Epoch: 5 | loss: 31.0527153 | time: 6.009 s
Epoch: 6 | loss: 29.1672039 | time: 5.982 s
Epoch: 7 | loss: 28.9566517 | time: 5.888 s
Epoch: 8 | loss: 26.1431255 | time: 6.032 s
Epoch: 9 | loss: 25.2694702 | time: 5.986 s
Epoch: 10 | loss: 23.3023987 | time: 5.964 s
Epoch: 11 | loss: 22.4270554 | time: 5.930 s
Epoch: 12 | loss: 21.9857197 | time: 5.900 s
Epoch: 13 | loss: 20.6792336 | time: 5.941 s
Epoch: 14 | loss: 19.7683334 | time: 5.962 s
Epoch: 15 | loss: 19.4387627 | time: 6.044 s
Epoch: 16 | loss: 18.3228951 | time: 6.153 s
Epoch: 17 | loss: 17.0260353 | time: 6.207 s
Epoch: 18 | loss: 16.7188644 | time: 6.064 s
Epoch: 19 | loss: 16.3051643 | time: 6.006 s
Epoch: 20 | loss: 14.6625071 | time: 5.987 s
Epoch: 21 | loss: 14.0410147 | time: 6.012 s
Epoch: 22 | loss: 13

## Test

In [13]:
# Switch to inference mode before evaluating.
net.eval()

# `correct` becomes a tensor after the first batch; the reporting cell relies
# on that when it calls correct.to(...), so it is not converted to an int here.
correct=0
total=0
# Gradients are not needed for evaluation; no_grad avoids building autograd graphs.
with torch.no_grad():
    for images, labels in testloader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = net(images)
        # Predicted class = index of the largest logit in each row.
        _, predicted = torch.max(outputs, dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum()

In [14]:
# Fraction of correct predictions; `correct` is cast to float before dividing.
accuracy = correct.to(dtype=torch.float) / float(total)
print(f'Accuracy of the network on the {total} test images: {accuracy:.4%}')

Accuracy of the network on the 10000 test images: 89.1000%
