In [1]:
!pip install torchmetrics
!pip install torchvision

Defaulting to user installation because normal site-packages is not writeable
Collecting torchmetrics
  Downloading torchmetrics-1.6.1-py3-none-any.whl.metadata (21 kB)
Collecting torch>=2.0.0 (from torchmetrics)
  Downloading torch-2.6.0-cp312-cp312-win_amd64.whl.metadata (28 kB)
Collecting lightning-utilities>=0.8.0 (from torchmetrics)
  Downloading lightning_utilities-0.12.0-py3-none-any.whl.metadata (5.6 kB)
Collecting sympy==1.13.1 (from torch>=2.0.0->torchmetrics)
  Downloading sympy-1.13.1-py3-none-any.whl.metadata (12 kB)
Downloading torchmetrics-1.6.1-py3-none-any.whl (927 kB)
   ---------------------------------------- 0.0/927.3 kB ? eta -:--:--
   ---------------------------------------- 0.0/927.3 kB ? eta -:--:--
   -- ------------------------------------- 61.4/927.3 kB ? eta -:--:--
   ----------- ---------------------------- 256.0/927.3 kB 3.2 MB/s eta 0:00:01
   ---------------------- ----------------- 522.2/927.3 kB 4.1 MB/s eta 0:00:01
   ------------------------------



Defaulting to user installation because normal site-packages is not writeable
Collecting torchvision
  Downloading torchvision-0.21.0-cp312-cp312-win_amd64.whl.metadata (6.3 kB)
Downloading torchvision-0.21.0-cp312-cp312-win_amd64.whl (1.6 MB)
   ---------------------------------------- 0.0/1.6 MB ? eta -:--:--
   ---------------------------------------- 0.0/1.6 MB ? eta -:--:--
    --------------------------------------- 0.0/1.6 MB 1.3 MB/s eta 0:00:02
   ---- ----------------------------------- 0.2/1.6 MB 2.0 MB/s eta 0:00:01
   ------------ --------------------------- 0.5/1.6 MB 3.9 MB/s eta 0:00:01
   --------------------------- ------------ 1.1/1.6 MB 6.1 MB/s eta 0:00:01
   ---------------------------------------- 1.6/1.6 MB 7.6 MB/s eta 0:00:00
Installing collected packages: torchvision
Successfully installed torchvision-0.21.0


In [4]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchmetrics import Accuracy, Precision, Recall

In [5]:
# Load datasets
from torchvision import datasets
import torchvision.transforms as transforms

train_data = datasets.FashionMNIST(root='./data', train=True, download=True, transform=transforms.ToTensor())
test_data = datasets.FashionMNIST(root='./data', train=False, download=True, transform=transforms.ToTensor())


100%|██████████| 26.4M/26.4M [00:16<00:00, 1.59MB/s]
100%|██████████| 29.5k/29.5k [00:00<00:00, 98.9kB/s]
100%|██████████| 4.42M/4.42M [00:02<00:00, 1.59MB/s]
100%|██████████| 5.15k/5.15k [00:00<?, ?B/s]


In [6]:
# Start coding here
# Use as many cells as you need
classes = train_data.classes
num_classes = len(train_data.classes)

In [7]:
#Define some relevant variables
num_input_channels = 1
num_output_channels = 16
image_size = train_data[0][0].shape[1]

In [8]:
#Define CNN
class MultiClassImageClassifier(nn.Module):
    #Define the init method
    def __init__(self, num_classes):
        super(MultiClassImageClassifier, self).__init__()
        self.conv1 = nn.Conv2d(num_input_channels, num_output_channels, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()
        
        #Create a fully connected layer
        self.fc = nn.Linear(num_output_channels * (image_size//2)**2, num_classes)
     
    def forward(self, x):
        #pass inputs through each layer
        x = self.conv1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.flatten(x)
        x = self.fc(x)
        return x

In [9]:
#Define the training set DataLoader
dataloader_train = DataLoader(
     train_data,
     batch_size=10,
     shuffle = True,
)

In [10]:
# Define training function
def train_model(optimizer, net, num_epochs):
    num_processed = 0
    criterion = nn.CrossEntropyLoss()
    for epoch in range(num_epochs):
        running_loss = 0
        num_processed = 0
        for features, labels in dataloader_train:
            optimizer.zero_grad()
            outputs = net(features)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_processed += len(labels)
        print(f'epoch {epoch}, loss: {running_loss / num_processed}')
        
    train_loss = running_loss / len(dataloader_train)


In [11]:
# Train for 1 epoch
net = MultiClassImageClassifier(num_classes)
optimizer = optim.Adam(net.parameters(), lr=0.001)

train_model(
    optimizer=optimizer,
    net=net,
    num_epochs=1,
)


epoch 0, loss: 0.040946135414523695


In [13]:
# Test the model on the test set
              
# Define the test set DataLoader
dataloader_test = DataLoader(
    test_data,
    batch_size=10,
    shuffle=False,
)
# Define the metrics
accuracy_metric = Accuracy(task='multiclass', num_classes=num_classes)
precision_metric = Precision(task='multiclass', num_classes=num_classes, average=None)
recall_metric = Recall(task='multiclass', num_classes=num_classes, average=None)

# Run model on test set
net.eval()
predictions = []
for i, (features, labels) in enumerate(dataloader_test):
    output = net.forward(features.reshape(-1, 1, image_size, image_size))
    cat = torch.argmax(output, dim=-1)
    predictions.extend(cat.tolist())
    accuracy_metric(cat, labels)
    precision_metric(cat, labels)
    recall_metric(cat, labels)

# Compute the metrics
accuracy = accuracy_metric.compute().item()
precision = precision_metric.compute().tolist()
recall = recall_metric.compute().tolist()
print('Accuracy:', accuracy)
print('Precision (per class):', precision)
print('Recall (per class):', recall)

Accuracy: 0.8808000087738037
Precision (per class): [0.7935779690742493, 0.9691848754882812, 0.8444924354553223, 0.8838383555412292, 0.7656940817832947, 0.9766260385513306, 0.7129186391830444, 0.9203791618347168, 0.960629940032959, 0.9741200804710388]
Recall (per class): [0.8650000095367432, 0.9750000238418579, 0.7820000052452087, 0.875, 0.8659999966621399, 0.9610000252723694, 0.5960000157356262, 0.9710000157356262, 0.9760000109672546, 0.9409999847412109]
