We use image classification neural network to categorize new products of an e-commerce clothing retailer.


In [344]:
!pip install torchmetrics

Defaulting to user installation because normal site-packages is not writeable


In [345]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchmetrics import Accuracy, Precision, Recall

In [346]:
# Load datasets
from torchvision import datasets
import torchvision.transforms as transforms

train_data = datasets.FashionMNIST(root='./data', train=True, download=True, transform=transforms.ToTensor())
test_data = datasets.FashionMNIST(root='./data', train=False, download=True, transform=transforms.ToTensor())

# Data Presentation

In [347]:
print(type(train_data), len(train_data))
print(type(train_data[0]), len(train_data[0]))

print(type(train_data[0][0]), type(train_data[0][1]))
print(train_data[0][0].shape)

print(len(train_data.classes))

<class 'torchvision.datasets.mnist.FashionMNIST'> 60000
<class 'tuple'> 2
<class 'torch.Tensor'> <class 'int'>
torch.Size([1, 28, 28])
10


Data shape : 60000 x (28x28 + 1)

Input : 60000 images de taille 28x28, noir&blanc(?)

Label : int de 0 à 9



# Network definition

 - Conv : 1 x 28x28 -> k x 28x28        (Without padding:26)
 - ReLU
 - MaxPool : k x 28x28 -> k x 14x14
 - Linear : k x 14x14 -> 10

In [348]:
k = 8  # output channels of the convolution
n_class = 10

A higher value for k only slighly increases the training time.

In [349]:
class Classifier(nn.Module):
    
    def __init__(self, num_classes):
        super(Classifier, self).__init__()
        self.conv = nn.Conv2d(1, k, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flat = nn.Flatten()
        self.lin = nn.Linear(k * 14**2, n_class)
       
    def forward(self, x):
        x= self.conv(x)
        x= self.relu(x)
        x= self.maxpool(x)
        x= self.flat(x)
        x= self.lin(x)
        return x

# Training

In [350]:
dataloader_train = DataLoader(train_data, shuffle = True, batch_size = 10)

In [351]:
def trainer(opti, model, n_epoch):
    criterion = nn.CrossEntropyLoss()

    for epoch in range(n_epoch):
        running_loss = 0
        nb_done = 0
    
        for imgs, lbls in dataloader_train:
            opti.zero_grad()
            outs = model(imgs)
            loss = criterion(outs, lbls)
            loss.backward()
            opti.step()
            
            running_loss += loss.item()
            nb_done += len(lbls)
        print(f'Epoch: {epoch}, running loss: {running_loss/nb_done}')
        # print average loss every epoch 
        
    final_loss = running_loss / len(dataloader_train)  # average loss on the last epoch

    

In [352]:
model = Classifier(n_class)
opti = optim.Adam(model.parameters(), lr=0.001)

trainer(opti = opti, model = model, n_epoch = 1)

Epoch: 0, running loss: 0.045706265081077196


# Evaluation

In [353]:
dataloader_test = DataLoader(test_data, batch_size=10, shuffle=False)

In [354]:
acc = Accuracy(task='multiclass', num_classes = 10)
prec = Precision(task='multiclass', num_classes = 10, average=None)
rec = Recall(task='multiclass', num_classes = 10, average=None)

model.eval()
predicted = []
for i, (imgs, lbls) in enumerate(dataloader_test):
    out = model.forward(imgs.reshape(-1,1, 28,28))
    cat = torch.argmax(out, dim=-1)
    
    predicted.extend(list(cat))    
    prec(cat, lbls)    
    rec(cat, lbls)
    acc(cat, lbls)

precision = prec.compute().tolist()
recall = rec.compute().tolist()
accuracy = acc.compute().item()

print(f"Precision: {precision} \n Recall: {recall} \n Accuracy : {accuracy}")


Precision: [0.8608893752098083, 0.9936908483505249, 0.7691588997840881, 0.8402646780014038, 0.7644320130348206, 0.971727728843689, 0.6663190722465515, 0.9290060997009277, 0.9582089781761169, 0.9098591804504395] 
 Recall: [0.7549999952316284, 0.9449999928474426, 0.8230000138282776, 0.8889999985694885, 0.8209999799728394, 0.9279999732971191, 0.6389999985694885, 0.9160000085830688, 0.9629999995231628, 0.968999981880188] 
 Accuracy : 0.864799976348877
