In [31]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from tqdm import tqdm 
import matplotlib.pyplot as plt
import numpy as np

In [32]:
# Define the CNN model
class SCNN(nn.Sequential):
    def __init__(self):
        super(SCNN, self).__init__(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Flatten(),
            nn.Linear(64 * 20 * 20, 128),
            nn.ReLU(),
            nn.Linear(128, 10)
        )
class regCNN(nn.Sequential):
    def __init__(self):
        super(regCNN, self).__init__(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Dropout2d(0.25),
            nn.MaxPool2d(2, 2),
            nn.Flatten(),
            nn.Linear(64 * 20 * 20, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, 10)
        )
    


In [33]:
# Device setup for M3 Pro
if torch.backends.mps.is_available() and torch.backends.mps.is_built():
    device = torch.device("mps")
    print("Using MPS (M3 Pro GPU)")
else:
    device = torch.device("cpu")
    print("Using CPU (MPS not available)")

Using MPS (M3 Pro GPU)


In [34]:
def train(model,num_epochs,patience,train_loader,val_loader,model_path):
    best_val_loss = float('inf')
    patience_counter = 0
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    train_loss_arr =[]
    val_loss_arr = []
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        train_loop = tqdm(train_loader, desc=f"Epoch [{epoch+1}/{num_epochs}]")
        for images, labels in train_loop:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            train_loop.set_postfix(loss=running_loss / (train_loop.n + 1))
        
        avg_loss = running_loss / len(train_loader)
        train_loss_arr.append(avg_loss)
        print(f"Epoch [{epoch+1}/{num_epochs}], Training Loss: {avg_loss:.4f}")

        # Validation
        model.eval()
        val_loss = 0.0
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        
        val_loss /= len(val_loader)
        val_loss_arr.append(val_loss)
        accuracy = 100 * correct / total
        print(f"Validation Loss: {val_loss:.4f}, Accuracy: {accuracy:.2f}% , Best Val Loss: {best_val_loss:.2f}")

        # Early stopping check
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            torch.save(model.state_dict(), f'{model_path}best.pth')
            # print("Best Model saved successfully!")
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print("Early stopping triggered. Training stopped.")
                break
    return train_loss_arr,val_loss_arr

In [35]:
def test(model_class ,model_path,val_loader):
    model = model_class
    model.load_state_dict(torch.load(model_path, map_location=device)) 
    model.to(device)
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f"Validation Accuracy: {accuracy:.2f}%")

In [36]:
def plot(array,filename):
    epochs = np.arange(len(array))  # Epochs are the indices of the array
    plt.figure(figsize=(10, 6))
    plt.plot(epochs, array, marker='o', linestyle='-')
    plt.title('Average Loss vs. Epoch')
    plt.xlabel('Epoch')
    plt.ylabel('Average Loss')
    plt.grid(True)
    plt.savefig(f'plots/{filename}.png')
    plt.close() 

# Simple

In [37]:
simple_transform = transforms.Compose([
    transforms.Resize((160, 160)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])


# Load Imagenette dataset
qual = '160px'
folder = f'./data/Simple{qual}/'
train_dataset = datasets.Imagenette(root=folder, split='train', size=qual, download=True, transform=simple_transform)
val_dataset = datasets.Imagenette(root=folder, split='val', size=qual, download=True, transform=simple_transform)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=4, pin_memory=True)



In [38]:
model = SCNN().to(device)
train_loss_arr,val_loss_arr = train(model,30,5,train_loader,val_loader,'./model/Simple/')
# print(model)
plot(train_loss_arr,'Simple_Train_loss')
plot(val_loss_arr,'Simple_Val_loss')


Epoch [1/30]: 100%|██████████| 148/148 [00:30<00:00,  4.91it/s, loss=1.74]

Epoch [1/30], Training Loss: 1.7234





Validation Loss: 1.4171, Accuracy: 52.31% , Best Val Loss: inf


Epoch [2/30]: 100%|██████████| 148/148 [00:29<00:00,  5.02it/s, loss=1.22]

Epoch [2/30], Training Loss: 1.1993





Validation Loss: 1.1643, Accuracy: 62.24% , Best Val Loss: 1.42


Epoch [3/30]: 100%|██████████| 148/148 [00:29<00:00,  5.03it/s, loss=0.937]

Epoch [3/30], Training Loss: 0.9245





Validation Loss: 1.1167, Accuracy: 64.66% , Best Val Loss: 1.16


Epoch [4/30]: 100%|██████████| 148/148 [00:29<00:00,  4.99it/s, loss=0.727]

Epoch [4/30], Training Loss: 0.7168





Validation Loss: 1.1253, Accuracy: 66.01% , Best Val Loss: 1.12


Epoch [5/30]: 100%|██████████| 148/148 [00:29<00:00,  5.02it/s, loss=0.53] 

Epoch [5/30], Training Loss: 0.5225





Validation Loss: 1.1492, Accuracy: 66.37% , Best Val Loss: 1.12


Epoch [6/30]: 100%|██████████| 148/148 [00:29<00:00,  5.02it/s, loss=0.336]

Epoch [6/30], Training Loss: 0.3358





Validation Loss: 1.3781, Accuracy: 62.60% , Best Val Loss: 1.12


Epoch [7/30]: 100%|██████████| 148/148 [00:29<00:00,  5.08it/s, loss=0.196]

Epoch [7/30], Training Loss: 0.1951





Validation Loss: 1.5362, Accuracy: 65.30% , Best Val Loss: 1.12


Epoch [8/30]: 100%|██████████| 148/148 [00:30<00:00,  4.87it/s, loss=0.0929]

Epoch [8/30], Training Loss: 0.0923





Validation Loss: 1.7763, Accuracy: 64.46% , Best Val Loss: 1.12
Early stopping triggered. Training stopped.


In [39]:
test(SCNN(),'./model/Simple/best.pth',val_loader)

Validation Accuracy: 64.66%


# Regularisation

In [40]:
# Define transforms
reg_transform = transforms.Compose([
    transforms.Resize((160, 160)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(degrees=15),  
    transforms.ColorJitter( brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),  
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])


# Load Imagenette dataset
qual = '160px'
folder = f'./data/reg{qual}/'
train_dataset_reg = datasets.Imagenette(root=folder, split='train', size=qual, download=True, transform=reg_transform)
val_dataset_reg = datasets.Imagenette(root=folder, split='val', size=qual, download=True, transform=reg_transform)

train_loader_reg = DataLoader(train_dataset_reg, batch_size=64, shuffle=True, num_workers=4, pin_memory=True)
val_loader_reg = DataLoader(val_dataset_reg, batch_size=64, shuffle=False, num_workers=4, pin_memory=True)



In [41]:
model2 = regCNN().to(device)
train_loss_arr,val_loss_arr = train(model2,30,5,train_loader_reg,val_loader_reg,'./model/Reg/')
plot(train_loss_arr,'Reg_Train_loss')
plot(val_loss_arr,'Reg_Val_loss')

Epoch [1/30]: 100%|██████████| 148/148 [00:29<00:00,  4.95it/s, loss=2.02]

Epoch [1/30], Training Loss: 1.9955





Validation Loss: 1.7037, Accuracy: 43.52% , Best Val Loss: inf


Epoch [2/30]: 100%|██████████| 148/148 [00:29<00:00,  4.99it/s, loss=1.74]

Epoch [2/30], Training Loss: 1.7176





Validation Loss: 1.5160, Accuracy: 49.83% , Best Val Loss: 1.70


Epoch [3/30]: 100%|██████████| 148/148 [00:29<00:00,  4.97it/s, loss=1.59]

Epoch [3/30], Training Loss: 1.5683





Validation Loss: 1.3746, Accuracy: 54.34% , Best Val Loss: 1.52


Epoch [4/30]: 100%|██████████| 148/148 [00:29<00:00,  4.95it/s, loss=1.49]

Epoch [4/30], Training Loss: 1.4732





Validation Loss: 1.2762, Accuracy: 58.78% , Best Val Loss: 1.37


Epoch [5/30]: 100%|██████████| 148/148 [00:29<00:00,  4.97it/s, loss=1.38]

Epoch [5/30], Training Loss: 1.3607





Validation Loss: 1.2120, Accuracy: 60.48% , Best Val Loss: 1.28


Epoch [6/30]: 100%|██████████| 148/148 [00:29<00:00,  4.97it/s, loss=1.34]

Epoch [6/30], Training Loss: 1.3213





Validation Loss: 1.2080, Accuracy: 60.94% , Best Val Loss: 1.21


Epoch [7/30]: 100%|██████████| 148/148 [00:29<00:00,  4.98it/s, loss=1.29]

Epoch [7/30], Training Loss: 1.2704





Validation Loss: 1.1611, Accuracy: 61.68% , Best Val Loss: 1.21


Epoch [8/30]: 100%|██████████| 148/148 [00:29<00:00,  4.95it/s, loss=1.25]

Epoch [8/30], Training Loss: 1.2305





Validation Loss: 1.0957, Accuracy: 64.84% , Best Val Loss: 1.16


Epoch [9/30]: 100%|██████████| 148/148 [00:29<00:00,  4.98it/s, loss=1.2] 

Epoch [9/30], Training Loss: 1.1820





Validation Loss: 1.0611, Accuracy: 65.89% , Best Val Loss: 1.10


Epoch [10/30]: 100%|██████████| 148/148 [00:29<00:00,  4.97it/s, loss=1.17]

Epoch [10/30], Training Loss: 1.1513





Validation Loss: 1.0430, Accuracy: 66.78% , Best Val Loss: 1.06


Epoch [11/30]: 100%|██████████| 148/148 [00:29<00:00,  4.96it/s, loss=1.13]

Epoch [11/30], Training Loss: 1.1098





Validation Loss: 1.0567, Accuracy: 66.65% , Best Val Loss: 1.04


Epoch [12/30]: 100%|██████████| 148/148 [00:29<00:00,  4.99it/s, loss=1.09]

Epoch [12/30], Training Loss: 1.0774





Validation Loss: 0.9882, Accuracy: 68.10% , Best Val Loss: 1.04


Epoch [13/30]: 100%|██████████| 148/148 [00:29<00:00,  4.96it/s, loss=1.09]

Epoch [13/30], Training Loss: 1.0781





Validation Loss: 0.9884, Accuracy: 68.31% , Best Val Loss: 0.99


Epoch [14/30]: 100%|██████████| 148/148 [00:30<00:00,  4.91it/s, loss=1.04]

Epoch [14/30], Training Loss: 1.0260





Validation Loss: 0.9680, Accuracy: 68.79% , Best Val Loss: 0.99


Epoch [15/30]: 100%|██████████| 148/148 [00:29<00:00,  4.98it/s, loss=1.03]

Epoch [15/30], Training Loss: 1.0170





Validation Loss: 0.9593, Accuracy: 68.66% , Best Val Loss: 0.97


Epoch [16/30]: 100%|██████████| 148/148 [00:29<00:00,  4.99it/s, loss=1.01] 

Epoch [16/30], Training Loss: 0.9929





Validation Loss: 0.9445, Accuracy: 69.35% , Best Val Loss: 0.96


Epoch [17/30]: 100%|██████████| 148/148 [00:29<00:00,  4.94it/s, loss=0.991]

Epoch [17/30], Training Loss: 0.9780





Validation Loss: 0.9203, Accuracy: 70.93% , Best Val Loss: 0.94


Epoch [18/30]: 100%|██████████| 148/148 [00:29<00:00,  4.97it/s, loss=0.955]

Epoch [18/30], Training Loss: 0.9422





Validation Loss: 0.9262, Accuracy: 70.39% , Best Val Loss: 0.92


Epoch [19/30]: 100%|██████████| 148/148 [00:29<00:00,  4.98it/s, loss=0.952]

Epoch [19/30], Training Loss: 0.9391





Validation Loss: 0.9247, Accuracy: 70.73% , Best Val Loss: 0.92


Epoch [20/30]: 100%|██████████| 148/148 [00:29<00:00,  4.94it/s, loss=0.942]

Epoch [20/30], Training Loss: 0.9294





Validation Loss: 0.9265, Accuracy: 69.83% , Best Val Loss: 0.92


Epoch [21/30]: 100%|██████████| 148/148 [00:29<00:00,  4.95it/s, loss=0.906]

Epoch [21/30], Training Loss: 0.8938





Validation Loss: 0.8913, Accuracy: 71.64% , Best Val Loss: 0.92


Epoch [22/30]: 100%|██████████| 148/148 [00:29<00:00,  4.97it/s, loss=0.89] 

Epoch [22/30], Training Loss: 0.8777





Validation Loss: 0.9022, Accuracy: 70.52% , Best Val Loss: 0.89


Epoch [23/30]: 100%|██████████| 148/148 [00:29<00:00,  4.97it/s, loss=0.892]

Epoch [23/30], Training Loss: 0.8798





Validation Loss: 0.8778, Accuracy: 72.10% , Best Val Loss: 0.89


Epoch [24/30]: 100%|██████████| 148/148 [00:29<00:00,  4.98it/s, loss=0.859]

Epoch [24/30], Training Loss: 0.8476





Validation Loss: 0.8835, Accuracy: 71.46% , Best Val Loss: 0.88


Epoch [25/30]: 100%|██████████| 148/148 [00:29<00:00,  4.96it/s, loss=0.86] 

Epoch [25/30], Training Loss: 0.8488





Validation Loss: 0.8695, Accuracy: 72.10% , Best Val Loss: 0.88


Epoch [26/30]: 100%|██████████| 148/148 [00:29<00:00,  4.98it/s, loss=0.858]

Epoch [26/30], Training Loss: 0.8464





Validation Loss: 0.8856, Accuracy: 71.46% , Best Val Loss: 0.87


Epoch [27/30]: 100%|██████████| 148/148 [00:30<00:00,  4.81it/s, loss=0.836]

Epoch [27/30], Training Loss: 0.8247





Validation Loss: 0.8705, Accuracy: 71.82% , Best Val Loss: 0.87


Epoch [28/30]: 100%|██████████| 148/148 [00:29<00:00,  4.93it/s, loss=0.823]

Epoch [28/30], Training Loss: 0.8122





Validation Loss: 0.8744, Accuracy: 72.51% , Best Val Loss: 0.87


Epoch [29/30]: 100%|██████████| 148/148 [00:29<00:00,  4.98it/s, loss=0.808]

Epoch [29/30], Training Loss: 0.7968





Validation Loss: 0.8675, Accuracy: 72.66% , Best Val Loss: 0.87


Epoch [30/30]: 100%|██████████| 148/148 [00:29<00:00,  4.96it/s, loss=0.818]

Epoch [30/30], Training Loss: 0.8066





Validation Loss: 0.8714, Accuracy: 72.64% , Best Val Loss: 0.87


In [42]:
test(regCNN(),'./model/Reg/best.pth',val_loader_reg)

Validation Accuracy: 72.82%


# Transfer Learning

In [45]:
cifar_transform = transforms.Compose([
    transforms.Resize((160, 160)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(degrees=15),  
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])


cifar_train = datasets.CIFAR10(root='./data/cifar/', train=True, download=True, transform=cifar_transform)


cifar_test = datasets.CIFAR10(root='./data/cifar/', train=False, download=True, transform=cifar_transform)


cifar_train_loader = DataLoader(cifar_train, batch_size=64, shuffle=True)
cifar_test_loader = DataLoader(cifar_test, batch_size=64, shuffle=False)

In [46]:
best_model = regCNN().to(device)
best_model.load_state_dict(torch.load("./model/Reg/best.pth"))
for layer in best_model.modules():
    if isinstance(layer, nn.Conv2d):
        for param in layer.parameters():
            param.requires_grad = False

train_loss_arr,val_loss_arr = train(best_model,15,3, cifar_train_loader, cifar_test_loader,'./model/cifar/')

plot(train_loss_arr,'cifar_Train_loss')
plot(val_loss_arr,'cifar_Val_loss')

Epoch [1/15]: 100%|██████████| 782/782 [00:32<00:00, 24.01it/s, loss=1.91]


Epoch [1/15], Training Loss: 1.9046
Validation Loss: 1.5260, Accuracy: 47.51% , Best Val Loss: inf


Epoch [2/15]: 100%|██████████| 782/782 [00:37<00:00, 20.65it/s, loss=1.64]


Epoch [2/15], Training Loss: 1.6381
Validation Loss: 1.4205, Accuracy: 51.20% , Best Val Loss: 1.53


Epoch [3/15]: 100%|██████████| 782/782 [00:37<00:00, 21.07it/s, loss=1.56]


Epoch [3/15], Training Loss: 1.5571
Validation Loss: 1.3521, Accuracy: 52.81% , Best Val Loss: 1.42


Epoch [4/15]: 100%|██████████| 782/782 [00:37<00:00, 20.82it/s, loss=1.51]


Epoch [4/15], Training Loss: 1.5127
Validation Loss: 1.2923, Accuracy: 55.16% , Best Val Loss: 1.35


Epoch [5/15]: 100%|██████████| 782/782 [00:36<00:00, 21.45it/s, loss=1.48]


Epoch [5/15], Training Loss: 1.4773
Validation Loss: 1.2769, Accuracy: 55.80% , Best Val Loss: 1.29


Epoch [6/15]: 100%|██████████| 782/782 [00:37<00:00, 20.77it/s, loss=1.46]


Epoch [6/15], Training Loss: 1.4559
Validation Loss: 1.2610, Accuracy: 56.66% , Best Val Loss: 1.28


Epoch [7/15]: 100%|██████████| 782/782 [00:36<00:00, 21.69it/s, loss=1.44]


Epoch [7/15], Training Loss: 1.4404
Validation Loss: 1.2495, Accuracy: 56.93% , Best Val Loss: 1.26


Epoch [8/15]: 100%|██████████| 782/782 [00:36<00:00, 21.32it/s, loss=1.43]


Epoch [8/15], Training Loss: 1.4298
Validation Loss: 1.2344, Accuracy: 57.23% , Best Val Loss: 1.25


Epoch [9/15]: 100%|██████████| 782/782 [00:35<00:00, 21.81it/s, loss=1.41]


Epoch [9/15], Training Loss: 1.4101
Validation Loss: 1.2164, Accuracy: 58.03% , Best Val Loss: 1.23


Epoch [10/15]: 100%|██████████| 782/782 [00:35<00:00, 22.31it/s, loss=1.4] 


Epoch [10/15], Training Loss: 1.4017
Validation Loss: 1.2089, Accuracy: 58.02% , Best Val Loss: 1.22


Epoch [11/15]: 100%|██████████| 782/782 [00:34<00:00, 22.48it/s, loss=1.39]


Epoch [11/15], Training Loss: 1.3837
Validation Loss: 1.2014, Accuracy: 58.04% , Best Val Loss: 1.21


Epoch [12/15]: 100%|██████████| 782/782 [00:34<00:00, 22.52it/s, loss=1.38]


Epoch [12/15], Training Loss: 1.3827
Validation Loss: 1.2053, Accuracy: 58.38% , Best Val Loss: 1.20


Epoch [13/15]: 100%|██████████| 782/782 [00:37<00:00, 20.90it/s, loss=1.37]


Epoch [13/15], Training Loss: 1.3663
Validation Loss: 1.1824, Accuracy: 58.78% , Best Val Loss: 1.20


Epoch [14/15]: 100%|██████████| 782/782 [00:34<00:00, 22.57it/s, loss=1.37]


Epoch [14/15], Training Loss: 1.3659
Validation Loss: 1.1852, Accuracy: 58.90% , Best Val Loss: 1.18


Epoch [15/15]: 100%|██████████| 782/782 [00:36<00:00, 21.53it/s, loss=1.36]


Epoch [15/15], Training Loss: 1.3570
Validation Loss: 1.1776, Accuracy: 59.62% , Best Val Loss: 1.18


In [47]:
test(regCNN(),'./model/cifar/best.pth',cifar_test_loader)

Validation Accuracy: 59.43%
