In [5]:


import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchmetrics import Accuracy, Precision, Recall

# Load datasets
from torchvision import datasets
import torchvision.transforms as transforms

train_data = datasets.FashionMNIST(root='./data', train=True, download=True, transform=transforms.ToTensor())
test_data = datasets.FashionMNIST(root='./data', train=False, download=True, transform=transforms.ToTensor())

In [25]:

# Get the number of classes
classes = train_data.classes
num_classes = len(train_data.classes)

#For grayscale image it's 1
num_input_channels = 1

num_output_channels = 16

#Get the height of image
image_size = train_data[0][0].shape[1]

# Define CNN
class MultiClassImageClassifier(nn.Module):

    # Define the init method
    def __init__(self, num_classes):
        super(MultiClassImageClassifier, self).__init__()
        self.conv1 = nn.Conv2d(num_input_channels, num_output_channels, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()

        # Create a fully connected layer
        self.fc = nn.Linear(num_output_channels * (image_size//2)**2, num_classes)

    def forward(self, x):
        # Pass inputs through each layer
        x = self.conv1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.flatten(x)
        x = self.fc(x)
        return x

# Define the training set DataLoader
dataloader_train = DataLoader(
    train_data,
    batch_size=10,
    shuffle=True,
)

# Define training function
def train_model(optimizer, net, num_epochs):
    num_processed = 0
    criterion = nn.CrossEntropyLoss()
    for epoch in range(num_epochs):
        running_loss = 0
        num_processed = 0
        for features, labels in dataloader_train:
            optimizer.zero_grad()
            outputs = net(features)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_processed += len(labels)
        print(f'epoch {epoch}, loss: {running_loss / num_processed}')

    train_loss = running_loss / len(dataloader_train)


# Train for 1 epoch
net = MultiClassImageClassifier(num_classes)
optimizer = optim.Adam(net.parameters(), lr=0.001)

train_model(
    optimizer=optimizer,
    net=net,
    num_epochs=1,
)

# Test the model on the test set

# Define the test set DataLoader
dataloader_test = DataLoader(
    test_data,
    batch_size=10,
    shuffle=False,
)
# Define the metrics
accuracy_metric = Accuracy(task='multiclass', num_classes=num_classes)
precision_metric = Precision(task='multiclass', num_classes=num_classes, average=None)
recall_metric = Recall(task='multiclass', num_classes=num_classes, average=None)

# Run model on test set
net.eval()
predictions = []
for i, (features, labels) in enumerate(dataloader_test):
    output = net.forward(features.reshape(-1, 1, image_size, image_size))
    cat = torch.argmax(output, dim=-1)
    predictions.extend(cat.tolist())
    accuracy_metric(cat, labels)
    precision_metric(cat, labels)
    recall_metric(cat, labels)

# Compute the metrics
accuracy = accuracy_metric.compute().item()
precision = precision_metric.compute().tolist()
recall = recall_metric.compute().tolist()
# Print accuracy
print('Overall Accuracy:', accuracy)
print()

# Print precision per class
print('Precision per class:')
for i, (cls, prec) in enumerate(zip(train_data.classes, precision)):
    print(f'{cls:15s}: {prec:.4f}')  # :15s adds padding for alignment, :.4f formats to 4 decimal places
print()

# Print recall per class
print('Recall per class:')
for i, (cls, rec) in enumerate(zip(train_data.classes, recall)):
    print(f'{cls:15s}: {rec:.4f}')

epoch 0, loss: 0.040877450966547865
Overall Accuracy: 0.8859999775886536

Precision per class:
T-shirt/top    : 0.8290
Trouser        : 0.9683
Pullover       : 0.8446
Dress          : 0.8720
Coat           : 0.8163
Sandal         : 0.9904
Shirt          : 0.7134
Sneaker        : 0.9024
Bag            : 0.9642
Ankle boot     : 0.9547

Recall per class:
T-shirt/top    : 0.8580
Trouser        : 0.9780
Pullover       : 0.8100
Dress          : 0.8990
Coat           : 0.8400
Sandal         : 0.9240
Shirt          : 0.6620
Sneaker        : 0.9710
Bag            : 0.9700
Ankle boot     : 0.9480
