In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from tqdm import tqdm
import math

# Load Dataset

In [None]:
# Convert PIL images to tensors, then standardise with the mean / std values
# passed to Normalize (the statistics commonly used for MNIST).
normalize = transforms.Normalize(mean=(0.1307,), std=(0.3081,))
transform = transforms.Compose([transforms.ToTensor(), normalize])

# Both splits share the same storage root, download flag and preprocessing.
mnist_kwargs = dict(root='./data', download=True, transform=transform)
train_dataset = datasets.MNIST(train=True, **mnist_kwargs)
test_dataset = datasets.MNIST(train=False, **mnist_kwargs)

# Training batches are shuffled; evaluation keeps dataset order and uses larger batches.
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=1000, shuffle=False)

## Display sample images

In [None]:
# Inspect the tensor shape of the first training sample after preprocessing.
image, label = train_dataset[0]
print(image.size())

In [None]:
# Show the first five training images side by side with their labels.
n_samples = 5
fig, axes = plt.subplots(1, n_samples, figsize=(10, 15))

for idx, ax in enumerate(axes):
    image, label = train_dataset[idx]

    # Drop the singleton channel dimension so imshow receives a 2-D array.
    ax.imshow(image.squeeze(), cmap='gray')
    ax.set_title(f"Label: {label}")

plt.tight_layout()
plt.grid(False)
plt.show()

# Model

In [None]:
class MyModel(nn.Module):
    """Small three-convolution CNN classifier for square images.

    Architecture: conv(16) -> maxpool/2 -> conv(32) -> maxpool/2 -> conv(64)
    -> flatten -> linear. All convolutions use 3x3 kernels with padding 1, so
    only the two pooling layers change the spatial size. No activation function
    is applied between the convolutions; max pooling is the only non-linearity.

    Args:
        input_size: Height/width of the (square) input images in pixels.
        channels: Number of channels of the input images.
        num_classes: Number of output classes (size of the logit vector).
    """

    def __init__(self, input_size, channels, num_classes):
        super().__init__()

        self.input_size = input_size
        self.channels = channels
        self.num_classes = num_classes

        # layer 1 -- honours `channels` instead of assuming grayscale input
        self.cnn1 = nn.Conv2d(in_channels=self.channels, out_channels=16, kernel_size=3, padding=1)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        # layer 2
        self.cnn2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        # layer 3
        self.cnn3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)

        # Two stride-2 poolings shrink each spatial side by a factor of 4;
        # the last convolution produces 64 feature maps of that size.
        self.flat_features = 64 * (self.input_size // 4) * (self.input_size // 4)

        # output layer
        self.linear = nn.Linear(self.flat_features, self.num_classes)

    def forward(self, x):
        """Run the network and return raw class scores.

        Args:
            x: Image tensor whose last three dimensions are
                (channels, input_size, input_size).

        Returns:
            Tensor of shape (batch, num_classes) holding unnormalised logits;
            no softmax is applied.
        """
        x = self.pool1(self.cnn1(x))
        x = self.pool2(self.cnn2(x))
        x = self.cnn3(x)

        # Collapse everything except the batch dimension into one feature vector.
        x = x.reshape(-1, self.flat_features)
        return self.linear(x)

In [None]:
# Build the model and run a single untrained forward pass as a smoke test.
channels = 1
input_size = 28
num_classes = 10

model = MyModel(input_size=input_size, channels=channels, num_classes=num_classes)
image, label = train_dataset[0]

# Add an explicit batch dimension, and skip autograd bookkeeping since this
# prediction is only displayed, never back-propagated.
with torch.no_grad():
    predicted = model(image.unsqueeze(0))

plt.imshow(image.squeeze(), cmap='gray')
plt.title(f"{label}")
plt.show()

# Raw logits of the untrained network, shape (1, num_classes).
print(f"Predicted {predicted}")

# Training

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

In [None]:
# Optimisation hyper-parameters.
num_epochs, learning_rate = 10, 1e-4

# Cross-entropy loss over the model's output scores.
criterion = nn.CrossEntropyLoss()

# Place the model on the selected device and optimise its parameters with Adam.
model = model.to(device)
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)

In [None]:
from sklearn.metrics import accuracy_score, f1_score

# Per-epoch metric histories, one entry appended per epoch.
training_losses = []
training_accuracies = []
training_f1_scores = []

test_losses = []
test_accuracies = []
test_f1_scores = []


for epoch in range(num_epochs):

    # ---- Training pass ----
    model.train()
    train_loss = 0.0
    train_preds = []
    train_labels = []

    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")
    for images, labels in progress_bar:
        images = images.float().to(device)
        labels = labels.long().to(device)

        # Gradients accumulate in .grad across backward() calls, so they must
        # be cleared before every step; otherwise each update would use the
        # sum of all previous batches' gradients.
        optimizer.zero_grad()

        predicted_logits = model(images)

        loss = criterion(predicted_logits, labels)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        progress_bar.set_postfix(loss=batch_loss)
        # criterion returns the batch mean; weight by batch size so the epoch
        # average stays correct when the last batch is smaller.
        train_loss += batch_loss * images.size(0)

        preds = torch.argmax(predicted_logits, dim=1)
        train_preds.extend(preds.cpu().numpy())
        train_labels.extend(labels.cpu().numpy())

    avg_train_loss = train_loss / len(train_loader.dataset)
    train_acc = accuracy_score(train_labels, train_preds)
    train_f1 = f1_score(train_labels, train_preds, average='weighted')

    training_losses.append(avg_train_loss)
    training_accuracies.append(train_acc)
    training_f1_scores.append(train_f1)

    # ---- Evaluation pass ----
    model.eval()
    test_loss = 0.0
    test_preds = []
    test_labels = []

    # No parameter updates happen here, so autograd tracking is disabled.
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device).float()
            labels = labels.to(device).long()

            outputs = model(images)
            loss = criterion(outputs, labels)

            test_loss += loss.item() * images.size(0)

            preds = torch.argmax(outputs, dim=1)
            test_preds.extend(preds.cpu().numpy())
            test_labels.extend(labels.cpu().numpy())

    avg_test_loss = test_loss / len(test_loader.dataset)
    test_acc = accuracy_score(test_labels, test_preds)
    test_f1 = f1_score(test_labels, test_preds, average='weighted')

    test_losses.append(avg_test_loss)
    test_accuracies.append(test_acc)
    test_f1_scores.append(test_f1)

    print(f"\nEpoch {epoch+1}/{num_epochs} Summary:")
    print(f"Train Loss: {avg_train_loss:.4f} | Acc: {train_acc:.4f} | F1: {train_f1:.4f}")
    print(f"Test  Loss: {avg_test_loss:.4f} | Acc: {test_acc:.4f} | F1: {test_f1:.4f}")

In [None]:
# Persist only the model's state_dict (not the whole module object).
torch.save(obj=model.state_dict(), f='model_weights.pth')

In [None]:
scores = ['losses', 'accuracies', 'f1_scores']

# Metric name -> (training history, test history), one value per epoch.
histories = {
    'losses': (training_losses, test_losses),
    'accuracies': (training_accuracies, test_accuracies),
    'f1_scores': (training_f1_scores, test_f1_scores),
}

# One figure per metric: training curve on the left, test curve on the right.
for score in scores:
    train_history, test_history = histories[score]
    fig, axes = plt.subplots(1, 2, figsize=(15, 5))

    panels = (("Training", train_history), ("Testing", test_history))
    for ax, (split, history) in zip(axes, panels):
        # Plot against 1-based epoch numbers so the x-axis matches the training log.
        ax.plot(range(1, len(history) + 1), history)
        ax.set_title(f"{split} {score}")
        ax.set_xlabel("Epoch")
        ax.set_ylabel(score)

    plt.show()
    

# Checking what's in the intermediate layers

In [None]:
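# NOTE(review): disabled weight-visualisation sketch. As written it cannot simply
# be uncommented: the sp x sp grid only has room for every filter when
# param.shape[0] is a perfect square (e.g. 32 filters -> sp = 5 -> 25 axes), and
# imshow needs 2-D data, which only single-input-channel kernels provide after
# squeeze().
# for name, param in model.named_parameters():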
#     if 'weight' in name:
#         print(f"Layer: {name}, Shape: {param.shape}")
#         # print(param.data)

#         sp = int(math.sqrt(param.shape[0]))
#         fig, axes = plt.subplots(sp, sp, figsize=(10, 15))

#         for index, d in enumerate(param.data):
#             row = index // sp
#             col = index % sp
            
#             axes[row, col].imshow(d.cpu().numpy().squeeze(), cmap='gray')
#         plt.show()    