Fashion Forward is a new AI-based e-commerce clothing retailer.
They want to use image classification to automatically categorize new product listings, making it easier for customers to find what they're looking for. It will also assist in inventory management by quickly sorting items.

As a data scientist tasked with implementing a garment classifier, your primary objective is to develop a machine learning model capable of accurately categorizing images of clothing items into distinct garment types such as shirts, trousers, shoes, etc.


In [1]:
# Run the cells below first

In [2]:
!pip install torchmetrics



In [3]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchmetrics import Accuracy, Precision, Recall

In [4]:
# Load datasets
from torchvision import datasets
import torchvision.transforms as transforms

train_data = datasets.FashionMNIST(root='./data', train=True, download=True, transform=transforms.ToTensor())
test_data = datasets.FashionMNIST(root='./data', train=False, download=True, transform=transforms.ToTensor())

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ./data/FashionMNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 26421880/26421880 [00:02<00:00, 12453082.07it/s]


Extracting ./data/FashionMNIST/raw/train-images-idx3-ubyte.gz to ./data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 29515/29515 [00:00<00:00, 270579.26it/s]


Extracting ./data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ./data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 4422102/4422102 [00:00<00:00, 5010021.35it/s]


Extracting ./data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to ./data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 5148/5148 [00:00<00:00, 13894644.14it/s]

Extracting ./data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw






In [5]:


class FashionClassifier(nn.Module):
    def __init__(self, num_classes):
        super(FashionClassifier, self).__init__()
        self.feature_extracter = nn.Sequential(
                                    nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1),
                                    nn.ReLU(),
                                    nn.MaxPool2d(kernel_size=2, stride=2),
                                    nn.Flatten())
        self.classifier = nn.Sequential(
                                    nn.Linear(16 * 14 * 14, num_classes),
                                    nn.Softmax())  # Adjust the input size here

    def forward(self, x):
        x = self.feature_extracter(x)
        x = self.classifier(x)
        return x


In [6]:
num_classes = len(train_data.classes)
dataloader_train = DataLoader(
                            train_data,
                            batch_size=10,
                            shuffle=True)

In [7]:
net = FashionClassifier(num_classes)

net

FashionClassifier(
  (feature_extracter): Sequential(
    (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Flatten(start_dim=1, end_dim=-1)
  )
  (classifier): Sequential(
    (0): Linear(in_features=3136, out_features=10, bias=True)
    (1): Softmax(dim=None)
  )
)

In [8]:
optimizer = optim.Adam(net.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

for epoch in range(2):  # Change the loop variable
    losses = 0
    steps = 0
    for images, labels in dataloader_train:
        optimizer.zero_grad()
        outputs = net(images)  # Add unsqueeze to add batch dimension
        loss = criterion(outputs, labels)  # Convert labels to tensor
        loss.backward()
        optimizer.step()
        losses+=loss
        steps += len(labels)
    print(f"Epoch {epoch+1}: loss {losses/steps}")



  input = module(input)


Epoch 1: loss 0.16391554474830627
Epoch 2: loss 0.15867219865322113


In [9]:
dataloader_test = DataLoader(
                            test_data,
                            batch_size=10,
                            shuffle=True)

In [10]:
accuracy_metric = Accuracy(task='multiclass', num_classes=num_classes)
precision_metric = Precision(task='multiclass', num_classes=num_classes, average='none')
recall_metric = Recall(task='multiclass', num_classes=num_classes, average='none')


In [11]:
predictions = []

net.eval()
with torch.no_grad():
    for images, labels in dataloader_test:
        outputs = net(images)
        cat = torch.argmax(outputs, 1)
        predictions.extend(cat.tolist())
        accuracy_metric(cat, labels)
        precision_metric(cat, labels)
        recall_metric(cat, labels)
        
accuracy = accuracy_metric.compute().item()
precision = precision_metric.compute().tolist()
recall = recall_metric.compute().tolist()

print(f"Accuracy: {accuracy}\nPrecision: {precision}\nRecall: {recall}")

Accuracy: 0.8758999705314636
Precision: [0.780662477016449, 0.981670081615448, 0.8213573098182678, 0.8696939945220947, 0.7806072235107422, 0.9844720363616943, 0.7004830837249756, 0.9033148884773254, 0.9505813717842102, 0.9748163819313049]
Recall: [0.871999979019165, 0.9639999866485596, 0.8230000138282776, 0.8809999823570251, 0.796999990940094, 0.9509999752044678, 0.5799999833106995, 0.9810000061988831, 0.9810000061988831, 0.9290000200271606]
