In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchmetrics import Accuracy, Precision, Recall

from torchvision import datasets, transforms

# -------------------------------
# Hyperparameters (Tunable)
# -------------------------------
# Optimisation settings.
batch_size: int = 64          # mini-batch size used by the data loaders
learning_rate: float = 0.001  # step size handed to the optimizer
num_epochs: int = 10          # number of passes over the training data

# Model settings.
dropout_rate: float = 0.3      # probability given to the dropout layer
num_output_channels: int = 32  # Increased filters

# -------------------------------
# Load datasets
# -------------------------------
# Both splits are stored under ./data (download=True fetches them when
# needed) and each image goes through ToTensor when it is accessed.
train_data = datasets.FashionMNIST(
    root='./data',
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)
test_data = datasets.FashionMNIST(
    root='./data',
    train=False,
    download=True,
    transform=transforms.ToTensor(),
)

# Values consumed by the model definition.
num_input_channels = 1  # hard-coded, not read from the data
num_classes = len(train_data.classes)
# First element of the first sample is the image tensor; presumably laid out
# as (channels, height, width), so index 1 is the height -- TODO confirm.
# The classifier squares this value, i.e. it assumes square images.
image_size = train_data[0][0].shape[1]

# -------------------------------
# Define CNN with Dropout and BatchNorm
# -------------------------------
class MultiClassImageClassifier(nn.Module):
    """Single-convolution CNN classifier for square images.

    Layer order: Conv2d (3x3, padding 1) -> BatchNorm2d -> ReLU ->
    MaxPool2d (2x2) -> Dropout -> Flatten -> Linear. The forward pass
    returns the output of the linear layer directly (raw logits).

    Args:
        num_classes: Number of output logits.
        in_channels: Channels of the input images. ``None`` (default) uses
            the module-level ``num_input_channels``.
        out_channels: Number of convolution filters. ``None`` (default)
            uses the module-level ``num_output_channels``.
        dropout: Dropout probability. ``None`` (default) uses the
            module-level ``dropout_rate``.
        input_size: Side length of the square input images. ``None``
            (default) uses the module-level ``image_size``.
    """

    def __init__(self, num_classes, *, in_channels=None, out_channels=None,
                 dropout=None, input_size=None):
        super().__init__()

        # Fall back to the module-level settings so the existing
        # one-argument call keeps building exactly the same network.
        # ``is None`` (not truthiness) so that dropout=0.0 is honoured.
        if in_channels is None:
            in_channels = num_input_channels
        if out_channels is None:
            out_channels = num_output_channels
        if dropout is None:
            dropout = dropout_rate
        if input_size is None:
            input_size = image_size

        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()
        # The 3x3/padding-1 convolution keeps the spatial size and the 2x2
        # pool halves each side (rounding down), hence (input_size // 2) ** 2
        # positions per channel after flattening.
        self.fc = nn.Linear(out_channels * (input_size // 2) ** 2, num_classes)

    def forward(self, x):
        """Return class logits for a batch of images.

        Args:
            x: Image batch. For the linear layer's input width to match, it
                is expected to be shaped
                ``(batch, in_channels, input_size, input_size)``.

        Returns:
            Tensor with one row of ``num_classes`` logits per input image.
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.dropout(x)
        x = self.flatten(x)
        x = self.fc(x)
        return x

# -------------------------------
# Define DataLoaders
# -------------------------------
# The training loader is created with shuffling enabled; the test loader
# is created with shuffling disabled.
dataloader_train = DataLoader(
    dataset=train_data,
    batch_size=batch_size,
    shuffle=True,
)
dataloader_test = DataLoader(
    dataset=test_data,
    batch_size=batch_size,
    shuffle=False,
)

# -------------------------------
# Training Function
# -------------------------------
def train_model(optimizer, net, num_epochs, dataloader=None):
    """Train ``net`` with cross-entropy loss, printing the loss per epoch.

    Args:
        optimizer: Optimizer whose ``zero_grad``/``step`` are called once
            per batch; it is expected to manage ``net``'s parameters.
        net: Model mapping a feature batch to class logits.
        num_epochs: Number of passes over the data.
        dataloader: Iterable yielding ``(features, labels)`` batches.
            ``None`` (default) uses the module-level ``dataloader_train``.

    Returns:
        A list holding, for each epoch in order, the mean of the per-batch
        losses (the same value that is printed).

    Raises:
        ValueError: If an epoch yields no batches.
    """
    if dataloader is None:
        dataloader = dataloader_train

    criterion = nn.CrossEntropyLoss()
    epoch_losses = []
    for epoch in range(num_epochs):
        net.train()
        running_loss = 0.0
        # Count batches while iterating instead of calling len(), so loaders
        # that do not define a length can be used as well.
        num_batches = 0
        for features, labels in dataloader:
            optimizer.zero_grad()
            outputs = net(features)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_batches += 1
        if num_batches == 0:
            # Explicit error instead of an opaque ZeroDivisionError below.
            raise ValueError('dataloader produced no batches; cannot train')
        avg_loss = running_loss / num_batches
        epoch_losses.append(avg_loss)
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}')
    return epoch_losses

# -------------------------------
# Train the model
# -------------------------------
# Build the classifier, hand its parameters to Adam, then run the training
# loop for the configured number of epochs.
net = MultiClassImageClassifier(num_classes)
optimizer = optim.Adam(params=net.parameters(), lr=learning_rate)

train_model(optimizer, net, num_epochs)

# -------------------------------
# Evaluate the model
# -------------------------------
# Accuracy uses the metric's default averaging; precision and recall are
# created with average=None so they report one value per class.
accuracy_metric = Accuracy(task='multiclass', num_classes=num_classes)
precision_metric = Precision(task='multiclass', num_classes=num_classes, average=None)
recall_metric = Recall(task='multiclass', num_classes=num_classes, average=None)

net.eval()  # switch dropout / batch-norm layers to inference behaviour
predictions = []  # predicted class index for every test sample, in loader order
with torch.no_grad():
    for features, labels in dataloader_test:
        output = net(features)
        predicted = torch.argmax(output, dim=1)
        predictions.extend(predicted.tolist())
        # update() only accumulates state; calling the metric object itself
        # would additionally compute a per-batch value that is never used.
        accuracy_metric.update(predicted, labels)
        precision_metric.update(predicted, labels)
        recall_metric.update(predicted, labels)

# -------------------------------
# Print Results
# -------------------------------
# compute() reduces the state accumulated over all test batches.
accuracy = accuracy_metric.compute().item()
precision = precision_metric.compute().tolist()
recall = recall_metric.compute().tolist()

print('\nOverall Accuracy:', f'{accuracy:.4f}\n')

# Per-class values are paired with the dataset's class names by position.
print('Precision per class:')
for cls, prec in zip(train_data.classes, precision):
    print(f'{cls:15s}: {prec:.4f}')

print('\nRecall per class:')
for cls, rec in zip(train_data.classes, recall):
    print(f'{cls:15s}: {rec:.4f}')






100%|█████████████████████████████████████████████████████████████████████████████| 5.15k/5.15k [00:00<00:00, 3.43MB/s]

Epoch [1/10], Loss: 0.4511
