# Multi-class Classification


In [1]:
import torch
import torchvision.datasets as dset
import torchvision.transforms as transforms
import numpy as np

## Load the Fashion-MNIST dataset
Fashion-MNIST is an MNIST-like dataset of 70,000 labeled 28x28 fashion images.

https://github.com/zalandoresearch/fashion-mnist

![Samples](https://raw.githubusercontent.com/zalandoresearch/fashion-mnist/master/doc/img/fashion-mnist-sprite.png)

In [2]:
# Hyper-parameters shared by the cells below.
batch_size = 32
lr = 0.01

# Both splits use the same preprocessing: convert each image to a tensor.
to_tensor = transforms.ToTensor()

# Training split: reshuffled every epoch so mini-batches differ between epochs.
train_fashion_mnist = dset.FashionMNIST(root='./', train=True, download=True, transform=to_tensor)
train_data_loader = torch.utils.data.DataLoader(train_fashion_mnist, batch_size=batch_size,
                                                shuffle=True, num_workers=4)

# Test split: only used to count correct predictions, so the iteration order
# is irrelevant; keep it fixed (no shuffling) for deterministic evaluation.
test_fashion_mnist = dset.FashionMNIST(root='./', train=False, download=True, transform=to_tensor)
test_data_loader = torch.utils.data.DataLoader(test_fashion_mnist, batch_size=batch_size,
                                               shuffle=False, num_workers=4)

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ./FashionMNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./FashionMNIST/raw/train-images-idx3-ubyte.gz to ./FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ./FashionMNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./FashionMNIST/raw/train-labels-idx1-ubyte.gz to ./FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ./FashionMNIST/raw/t10k-images-idx3-ubyte.gz



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./FashionMNIST/raw/t10k-images-idx3-ubyte.gz to ./FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ./FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to ./FashionMNIST/raw
Processing...
Done!


  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


## Initialize Model Parameters

In [8]:
# Each 28x28 image is flattened to 784 features; there are 10 output classes.
num_inputs, num_outputs = 784, 10

# Seed the global RNG so the random weight initialisation below is the same
# on every "Restart & Run All".
torch.manual_seed(0)

# Weights start as standard-normal noise, biases start at zero.
W = torch.randn(num_inputs, num_outputs)
b = torch.zeros(num_outputs)
params = [W, b]

## Define the classifier

In [4]:
def linear(X):
    """Apply the affine map ``X @ W + b`` using the module-level parameters.

    ``X`` must be a 2-D tensor whose second dimension matches the first
    dimension of ``W``; the bias ``b`` is broadcast over the rows.
    """
    weighted = X.mm(W)
    return weighted + b

def softmax(z):
    """Row-wise softmax of a 2-D tensor of scores.

    Args:
        z: tensor of shape (rows, classes) holding unnormalised scores.

    Returns:
        Tensor of the same shape whose rows are non-negative and sum to 1.
    """
    # Softmax is invariant to adding a constant per row, so subtract the row
    # maximum: the largest exponent becomes exp(0) = 1 and exp() can no longer
    # overflow to inf (which would give inf / inf = nan).
    shifted = z - torch.max(z, dim=1, keepdim=True)[0]
    exp_z = torch.exp(shifted)
    return exp_z / torch.sum(exp_z, dim=1, keepdim=True)

def to_classlabel(z):
    """Return, for every row of ``z``, the column index of its largest entry."""
    return torch.argmax(z, dim=1)

## Define the loss function

In [5]:
def cross_entropy(output, y_target):
    """Per-sample cross-entropy between predicted and target distributions.

    Args:
        output: tensor of shape (rows, classes) with predicted probabilities.
        y_target: tensor of the same shape with target weights (e.g. one-hot).

    Returns:
        1-D tensor with one loss value per row:
        ``-sum_k y_target[k] * log(output[k])``.
    """
    weighted = y_target * torch.log(output)
    # A class with zero target weight must contribute nothing, even when its
    # predicted probability is exactly 0: log(0) * 0 would otherwise be nan
    # and poison the whole row.
    weighted = torch.where(y_target != 0, weighted, torch.zeros_like(weighted))
    return -torch.sum(weighted, dim=1)

## Train a model

In [9]:
# Mini-batch gradient descent on the softmax classifier, with a test-set
# accuracy check after every epoch.
for epoch in range(10):
    epoch_loss = []
    for inputs, labels in train_data_loader:
        # Flatten every image into one row of num_inputs features.
        inputs = inputs.reshape(-1, num_inputs)
        one_hot_labels = torch.nn.functional.one_hot(labels, num_classes=num_outputs)

        # Forward pass: scores -> probabilities -> per-sample loss.
        pre_softmax = linear(inputs)
        prob_distr = softmax(pre_softmax)
        loss = cross_entropy(prob_distr, one_hot_labels)
        batch_loss = torch.mean(loss)

        # Store a plain Python float so np.mean below works on numbers
        # rather than on a list of 0-d tensors.
        epoch_loss.append(batch_loss.item())

        # Backward pass (manual): softmax - one_hot is the gradient of the
        # cross-entropy w.r.t. the scores. NOTE: it is summed, not averaged,
        # over the batch, so the update follows the summed batch loss while
        # batch_loss above reports the mean.
        dscores = (prob_distr - one_hot_labels)
        dW = inputs.T.mm(dscores)
        db = torch.sum(dscores, axis=0)

        # In-place updates so that linear() sees the new parameters.
        W -= lr * dW
        b -= lr * db

    print(epoch, np.mean(epoch_loss))

    # Evaluation: argmax of the raw scores equals argmax of the softmax,
    # so the probabilities are not needed here.
    accu_number = 0
    for X, y in test_data_loader:
        inputs = X.reshape(-1, num_inputs)
        predicted_class = to_classlabel(linear(inputs))
        accu_number += torch.sum(predicted_class == y).item()
    print('testing accuracy: %.4f' % (accu_number / len(test_data_loader.dataset)))


0 4.530234
testing accuracy: 0.5822
1 2.0363443
testing accuracy: 0.6416
2 1.6637751
testing accuracy: 0.6742
3 1.469287
testing accuracy: 0.6903
4 1.3481125
testing accuracy: 0.7062
5 1.2628102
testing accuracy: 0.7186
6 1.1967351
testing accuracy: 0.7273
7 1.144163
testing accuracy: 0.7364
8 1.1015186
testing accuracy: 0.7404
9 1.0634748
testing accuracy: 0.7488
