In [1]:
# Imports for Dataset class
import os
from PIL import Image
import torch
from torch.utils.data import Dataset
from torchvision import transforms

# Imports for dataloader
from torch.utils.data import DataLoader

# Imports for CNN
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from tqdm import tqdm

# Imports for plotting
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import numpy as np

In [2]:
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

print(f"Using device: {device}")

Using device: mps


In [3]:
class ASLDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        """
        Args:
            root_dir (str): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                                            on a sample.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = []
        self.labels = []
        self.label_mapping = {}  # Mapping from original labels to continuous range

        # Collect all image paths and labels
        for filename in os.listdir(root_dir):
            if filename.startswith('color_'):
                # Extract label from filename
                parts = filename.split('_')
                label = int(parts[1])  # Assuming the label is the second part of the filename
                self.image_paths.append(os.path.join(root_dir, filename))
                self.labels.append(label)

        # Create a continuous label mapping
        unique_labels = sorted(set(self.labels))
        self.label_mapping = {label: idx for idx, label in enumerate(unique_labels)}
        self.labels = [self.label_mapping[label] for label in self.labels]

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_path = self.image_paths[idx]
        image = Image.open(img_path).convert('L')  # Ensure image is in grayscale
        label = self.labels[idx]

        if self.transform:
            image = self.transform(image)

        return image, label


root_dir = 'dataset5_50x50_split/'
transform = transforms.Compose([
    transforms.ToTensor(),  # Convert image to tensor
    transforms.Normalize((0.5,), (0.5,))  # Normalize to [-1, 1] range for grayscale
])

# Create dataset instances
train_dataset = ASLDataset(root_dir=os.path.join(root_dir, 'train'), transform=transform)
val_dataset = ASLDataset(root_dir=os.path.join(root_dir, 'val'), transform=transform)
test_dataset = ASLDataset(root_dir=os.path.join(root_dir, 'test'), transform=transform)


In [4]:
# Create DataLoader instances
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [5]:
# Class for CNN following paper 1 architecture
class PaperCNN(nn.Module):
    def __init__(self, num_classes=24, p_dropout = 0.3):
        super(PaperCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.dropout1 = nn.Dropout2d(p_dropout)
        
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        self.dropout2 = nn.Dropout2d(p_dropout)
        
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        
        self.conv3 = nn.Conv2d(128, 512, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(512)
        self.dropout3 = nn.Dropout2d(0.3)
        
        self.conv4 = nn.Conv2d(512, 64, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(64)
        self.dropout4 = nn.Dropout2d(p_dropout)
        
        self.conv5 = nn.Conv2d(64, 32, kernel_size=3, padding=1)
        self.bn5 = nn.BatchNorm2d(32)
        self.dropout5 = nn.Dropout2d(p_dropout)
        
        self.fc1 = nn.Linear(32 * 3 * 3, 512)
        self.dropout_fc1 = nn.Dropout(p_dropout)
        
        self.fc2 = nn.Linear(512, num_classes)
        
    def forward(self, x):
        # convrelu - dropout - bn
        x = self.bn1(self.dropout1(F.relu(self.conv1(x))))
#         print(f"Shape after conv1 + relu + dropout + bn1: {x.shape}")
        
        x = self.bn2(self.dropout2(F.relu(self.conv2(x))))
#         print(f"Shape after conv2 + relu + dropout + bn2: {x.shape}")
        x = self.pool(x)
#         print(f"Shape after pool: {x.shape}")
        
        x = self.bn3(self.dropout3(F.relu(self.conv3(x))))
#         print(f"Shape after conv3 + relu + dropout + bn3: {x.shape}")
        x = self.pool(x)
#         print(f"Shape after pool: {x.shape}")
        
        x = self.bn4(self.dropout4(F.relu(self.conv4(x))))
#         print(f"Shape after conv4 + relu + dropout + bn4: {x.shape}")
        x = self.pool(x)
#         print(f"Shape after pool: {x.shape}")
        
        x = self.bn5(self.dropout5(F.relu(self.conv5(x))))
#         print(f"Shape after conv5 + relu + dropout + bn4: {x.shape}")
        x = self.pool(x)
#         print(f"Shape after pool: {x.shape}")
        
        x = x.view(-1, 32 * 3 * 3)  # Flatten the output of the last conv layer
#         print(f"Shape after view: {x.shape}")
        
        x = self.dropout_fc1(F.relu(self.fc1(x)))
        x = self.fc2(x)
        
        return x

In [6]:
# Function to infer number of classes
def get_num_classes(dataset):
    """
    Get the number of unique classes in the dataset.

    Args:
        dataset (Dataset): The dataset to analyze.

    Returns:
        int: The number of unique classes.
    """
    unique_labels = set()
    for _, label in dataset:
        unique_labels.add(label)
    return len(unique_labels)

In [None]:
# Initialize the model, loss function, and optimizer

p_dropouts = [0.1, 0.2, 0.3, 0.4, 0.5]
lrs = [0.0001, 0.0005, 0.001, 0.005, 0.01]
for p_dropout in p_dropouts:
    for lr in lrs: 
        num_classes = get_num_classes(train_dataset)
        model = PaperCNN(num_classes=num_classes, p_dropout = p_dropout)
        device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
        model.to(device)
        #print("Model initialized")
        
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=lr)
        #print("Optimizer initialized")
        
        # Initialize lists to track metrics
        train_losses = []
        val_losses = []
        train_accuracies = []
        val_accuracies = []
        
        
        # Training loop with tqdm progress bar
        num_epochs = 10
        for epoch in range(num_epochs):
            model.train()
            running_loss = 0.0
            correct_train = 0
            total_train = 0
            train_loader_tqdm = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", unit="batch")
            for images, labels in train_loader_tqdm:
                # move images and labels to device
                images, labels = images.to(device), labels.to(device)
                
                optimizer.zero_grad()
                outputs = model(images)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()
                running_loss += loss.item()
                train_loader_tqdm.set_postfix(loss=running_loss/len(train_loader))
                
                # Calculate training accuracy
                _, predicted = torch.max(outputs.data, 1)
                total_train += labels.size(0)
                correct_train += (predicted == labels).sum().item()
        
            epoch_loss = running_loss / len(train_loader)
            train_losses.append(epoch_loss)
            train_accuracy = 100 * correct_train / total_train
            train_accuracies.append(train_accuracy)
            
            #print(f"Epoch {epoch+1}, Loss: {epoch_loss}, Accuracy: {train_accuracy}%")
        
            # Validation loop
            model.eval()
            running_val_loss = 0.0
            correct_val = 0
            total_val = 0
            all_labels = []
            all_predictions = []
            with torch.no_grad():
                for images, labels in val_loader:
                    # move images and labels to device
                    images, labels = images.to(device), labels.to(device)
                    
                    outputs = model(images)
                    loss = criterion(outputs, labels)
                    running_val_loss += loss.item()
                    _, predicted = torch.max(outputs.data, 1)
                    total_val += labels.size(0)
                    correct_val += (predicted == labels).sum().item()
                    all_labels.extend(labels.cpu().numpy())
                    all_predictions.extend(predicted.cpu().numpy())
            
            epoch_val_loss = running_val_loss / len(val_loader)
            val_losses.append(epoch_val_loss)
            val_accuracy = 100 * correct_val / total_val
            val_accuracies.append(val_accuracy)
        final_train_loss = train_losses[-1]
        final_val_loss = val_losses[-1]
        final_train_accuracy = train_accuracies[-1]
        final_val_accuracy = val_accuracies[-1]
        print(f"Dropout: {p_dropout}, LR: {lr}, Validation Loss: {final_val_loss}, Validation Accuracy: {final_val_accuracy}%")
        
        
        print("Training complete")

Epoch 1/10: 100%|██████████| 1439/1439 [02:18<00:00, 10.37batch/s, loss=1.16]
Epoch 2/10: 100%|██████████| 1439/1439 [02:27<00:00,  9.73batch/s, loss=0.289]
Epoch 3/10: 100%|██████████| 1439/1439 [02:52<00:00,  8.32batch/s, loss=0.155]
Epoch 4/10: 100%|██████████| 1439/1439 [02:56<00:00,  8.13batch/s, loss=0.1]   
Epoch 5/10: 100%|██████████| 1439/1439 [02:59<00:00,  8.03batch/s, loss=0.073] 
Epoch 6/10: 100%|██████████| 1439/1439 [03:02<00:00,  7.87batch/s, loss=0.0567]
Epoch 7/10: 100%|██████████| 1439/1439 [03:11<00:00,  7.51batch/s, loss=0.0476]
Epoch 8/10: 100%|██████████| 1439/1439 [03:17<00:00,  7.30batch/s, loss=0.0382]
Epoch 9/10: 100%|██████████| 1439/1439 [03:06<00:00,  7.71batch/s, loss=0.0274]
Epoch 10/10: 100%|██████████| 1439/1439 [03:06<00:00,  7.72batch/s, loss=0.0312]


Dropout: 0.1, LR: 0.0001, Validation Loss: 0.01621346259610215, Validation Accuracy: 99.44252990066896%
Training complete


Epoch 1/10: 100%|██████████| 1439/1439 [03:03<00:00,  7.86batch/s, loss=0.572]
Epoch 2/10: 100%|██████████| 1439/1439 [02:55<00:00,  8.21batch/s, loss=0.125] 
Epoch 3/10: 100%|██████████| 1439/1439 [02:58<00:00,  8.07batch/s, loss=0.0756]
Epoch 4/10: 100%|██████████| 1439/1439 [03:06<00:00,  7.72batch/s, loss=0.0545]
Epoch 5/10: 100%|██████████| 1439/1439 [03:01<00:00,  7.92batch/s, loss=0.0447]
Epoch 6/10: 100%|██████████| 1439/1439 [02:53<00:00,  8.30batch/s, loss=0.0364]
Epoch 7/10: 100%|██████████| 1439/1439 [02:57<00:00,  8.11batch/s, loss=0.0307]
Epoch 8/10: 100%|██████████| 1439/1439 [02:57<00:00,  8.10batch/s, loss=0.0269]
Epoch 9/10: 100%|██████████| 1439/1439 [03:31<00:00,  6.80batch/s, loss=0.0256]
Epoch 10/10: 100%|██████████| 1439/1439 [03:26<00:00,  6.96batch/s, loss=0.0221]


Dropout: 0.1, LR: 0.0005, Validation Loss: 0.017602893635198538, Validation Accuracy: 99.49320900060815%
Training complete


Epoch 1/10: 100%|██████████| 1439/1439 [03:27<00:00,  6.93batch/s, loss=0.509]
Epoch 2/10: 100%|██████████| 1439/1439 [03:26<00:00,  6.96batch/s, loss=0.119] 
Epoch 3/10: 100%|██████████| 1439/1439 [03:30<00:00,  6.84batch/s, loss=0.073] 
Epoch 4/10: 100%|██████████| 1439/1439 [03:34<00:00,  6.70batch/s, loss=0.0631]
Epoch 5/10: 100%|██████████| 1439/1439 [03:24<00:00,  7.05batch/s, loss=0.0533]
Epoch 6/10: 100%|██████████| 1439/1439 [03:44<00:00,  6.41batch/s, loss=0.041] 
Epoch 7/10: 100%|██████████| 1439/1439 [03:55<00:00,  6.12batch/s, loss=0.0418]
Epoch 8/10: 100%|██████████| 1439/1439 [03:43<00:00,  6.44batch/s, loss=0.0357]
Epoch 9/10: 100%|██████████| 1439/1439 [03:53<00:00,  6.17batch/s, loss=0.0321]
Epoch 10/10: 100%|██████████| 1439/1439 [03:37<00:00,  6.62batch/s, loss=0.0296]


Dropout: 0.1, LR: 0.001, Validation Loss: 0.02140147427868964, Validation Accuracy: 99.41212244070546%
Training complete


Epoch 1/10: 100%|██████████| 1439/1439 [03:40<00:00,  6.53batch/s, loss=0.605]
Epoch 2/10: 100%|██████████| 1439/1439 [03:44<00:00,  6.40batch/s, loss=0.176]
Epoch 3/10: 100%|██████████| 1439/1439 [03:41<00:00,  6.51batch/s, loss=0.116] 
Epoch 4/10: 100%|██████████| 1439/1439 [03:24<00:00,  7.05batch/s, loss=0.0974]
Epoch 5/10: 100%|██████████| 1439/1439 [03:42<00:00,  6.47batch/s, loss=0.0824]
Epoch 6/10: 100%|██████████| 1439/1439 [03:36<00:00,  6.64batch/s, loss=0.071] 
Epoch 7/10: 100%|██████████| 1439/1439 [04:44<00:00,  5.06batch/s, loss=0.0664]
Epoch 8/10: 100%|██████████| 1439/1439 [05:01<00:00,  4.78batch/s, loss=0.0601]
Epoch 9/10: 100%|██████████| 1439/1439 [04:10<00:00,  5.73batch/s, loss=0.0563]
Epoch 10/10: 100%|██████████| 1439/1439 [03:17<00:00,  7.28batch/s, loss=0.0587]


Dropout: 0.1, LR: 0.005, Validation Loss: 0.022899432797216894, Validation Accuracy: 99.24994932090006%
Training complete


Epoch 1/10: 100%|██████████| 1439/1439 [03:16<00:00,  7.31batch/s, loss=0.796]
Epoch 2/10: 100%|██████████| 1439/1439 [03:41<00:00,  6.49batch/s, loss=0.265]
Epoch 3/10: 100%|██████████| 1439/1439 [03:35<00:00,  6.66batch/s, loss=0.209]
Epoch 4/10: 100%|██████████| 1439/1439 [03:35<00:00,  6.67batch/s, loss=0.187] 
Epoch 5/10: 100%|██████████| 1439/1439 [03:32<00:00,  6.76batch/s, loss=0.165] 
Epoch 6/10: 100%|██████████| 1439/1439 [03:34<00:00,  6.72batch/s, loss=0.156] 
Epoch 7/10: 100%|██████████| 1439/1439 [03:25<00:00,  7.00batch/s, loss=0.156]
Epoch 8/10: 100%|██████████| 1439/1439 [03:15<00:00,  7.35batch/s, loss=0.145]
Epoch 9/10: 100%|██████████| 1439/1439 [03:26<00:00,  6.98batch/s, loss=0.145] 
Epoch 10/10: 100%|██████████| 1439/1439 [03:39<00:00,  6.55batch/s, loss=0.134] 


Dropout: 0.1, LR: 0.01, Validation Loss: 0.04520075789433101, Validation Accuracy: 98.8242448814109%
Training complete


Epoch 1/10: 100%|██████████| 1439/1439 [03:44<00:00,  6.42batch/s, loss=1.59]
Epoch 2/10: 100%|██████████| 1439/1439 [03:53<00:00,  6.16batch/s, loss=0.527]
Epoch 3/10: 100%|██████████| 1439/1439 [03:41<00:00,  6.49batch/s, loss=0.312]
Epoch 4/10: 100%|██████████| 1439/1439 [04:05<00:00,  5.85batch/s, loss=0.217]
Epoch 5/10: 100%|██████████| 1439/1439 [03:50<00:00,  6.25batch/s, loss=0.166]
Epoch 6/10: 100%|██████████| 1439/1439 [03:27<00:00,  6.92batch/s, loss=0.133] 
Epoch 7/10: 100%|██████████| 1439/1439 [03:42<00:00,  6.48batch/s, loss=0.111] 
Epoch 8/10: 100%|██████████| 1439/1439 [03:44<00:00,  6.41batch/s, loss=0.0915]
Epoch 9/10: 100%|██████████| 1439/1439 [03:51<00:00,  6.22batch/s, loss=0.0765]
Epoch 10/10: 100%|██████████| 1439/1439 [03:48<00:00,  6.29batch/s, loss=0.0694]


Dropout: 0.2, LR: 0.0001, Validation Loss: 0.022877586364835008, Validation Accuracy: 99.40198662071762%
Training complete


Epoch 1/10: 100%|██████████| 1439/1439 [03:53<00:00,  6.16batch/s, loss=0.808]
Epoch 2/10: 100%|██████████| 1439/1439 [03:50<00:00,  6.25batch/s, loss=0.23] 
Epoch 3/10: 100%|██████████| 1439/1439 [03:41<00:00,  6.50batch/s, loss=0.137] 
Epoch 4/10: 100%|██████████| 1439/1439 [03:44<00:00,  6.42batch/s, loss=0.105] 
Epoch 5/10: 100%|██████████| 1439/1439 [03:41<00:00,  6.50batch/s, loss=0.0845]
Epoch 6/10: 100%|██████████| 1439/1439 [03:21<00:00,  7.15batch/s, loss=0.0653]
Epoch 7/10: 100%|██████████| 1439/1439 [03:48<00:00,  6.29batch/s, loss=0.059] 
Epoch 8/10: 100%|██████████| 1439/1439 [03:41<00:00,  6.49batch/s, loss=0.0541]
Epoch 9/10: 100%|██████████| 1439/1439 [03:42<00:00,  6.48batch/s, loss=0.0473]
Epoch 10/10: 100%|██████████| 1439/1439 [03:45<00:00,  6.40batch/s, loss=0.0409]


Dropout: 0.2, LR: 0.0005, Validation Loss: 0.016107218583195165, Validation Accuracy: 99.62497466045004%
Training complete


Epoch 1/10: 100%|██████████| 1439/1439 [03:37<00:00,  6.60batch/s, loss=0.694]
Epoch 2/10: 100%|██████████| 1439/1439 [03:42<00:00,  6.48batch/s, loss=0.203]
Epoch 3/10: 100%|██████████| 1439/1439 [03:41<00:00,  6.48batch/s, loss=0.133] 
Epoch 4/10: 100%|██████████| 1439/1439 [03:45<00:00,  6.39batch/s, loss=0.108] 
Epoch 5/10: 100%|██████████| 1439/1439 [03:45<00:00,  6.37batch/s, loss=0.088] 
Epoch 6/10: 100%|██████████| 1439/1439 [03:46<00:00,  6.36batch/s, loss=0.0762]
Epoch 7/10: 100%|██████████| 1439/1439 [03:45<00:00,  6.38batch/s, loss=0.0667]
Epoch 8/10: 100%|██████████| 1439/1439 [03:46<00:00,  6.37batch/s, loss=0.0588]
Epoch 9/10: 100%|██████████| 1439/1439 [03:49<00:00,  6.27batch/s, loss=0.054] 
Epoch 10/10: 100%|██████████| 1439/1439 [03:51<00:00,  6.22batch/s, loss=0.053] 


Dropout: 0.2, LR: 0.001, Validation Loss: 0.014715108135195755, Validation Accuracy: 99.6148388404622%
Training complete


Epoch 1/10: 100%|██████████| 1439/1439 [04:21<00:00,  5.50batch/s, loss=0.813]
Epoch 2/10: 100%|██████████| 1439/1439 [04:17<00:00,  5.58batch/s, loss=0.277]
Epoch 3/10: 100%|██████████| 1439/1439 [04:03<00:00,  5.91batch/s, loss=0.193]
Epoch 4/10: 100%|██████████| 1439/1439 [04:07<00:00,  5.82batch/s, loss=0.156]
Epoch 5/10: 100%|██████████| 1439/1439 [04:20<00:00,  5.53batch/s, loss=0.134] 
Epoch 6/10: 100%|██████████| 1439/1439 [04:25<00:00,  5.41batch/s, loss=0.119] 
Epoch 7/10: 100%|██████████| 1439/1439 [04:10<00:00,  5.75batch/s, loss=0.11]  
Epoch 8/10: 100%|██████████| 1439/1439 [03:48<00:00,  6.29batch/s, loss=0.1]   
Epoch 9/10: 100%|██████████| 1439/1439 [03:45<00:00,  6.39batch/s, loss=0.0955]
Epoch 10/10: 100%|██████████| 1439/1439 [03:31<00:00,  6.79batch/s, loss=0.0913]


Dropout: 0.2, LR: 0.005, Validation Loss: 0.027726741982543993, Validation Accuracy: 99.16886276099737%
Training complete


Epoch 1/10: 100%|██████████| 1439/1439 [04:06<00:00,  5.83batch/s, loss=0.999]
Epoch 2/10: 100%|██████████| 1439/1439 [03:43<00:00,  6.43batch/s, loss=0.423]
Epoch 3/10: 100%|██████████| 1439/1439 [03:40<00:00,  6.53batch/s, loss=0.319]
Epoch 4/10: 100%|██████████| 1439/1439 [03:49<00:00,  6.27batch/s, loss=0.286]
Epoch 5/10: 100%|██████████| 1439/1439 [03:39<00:00,  6.56batch/s, loss=0.257]
Epoch 6/10: 100%|██████████| 1439/1439 [03:41<00:00,  6.49batch/s, loss=0.248]
Epoch 7/10: 100%|██████████| 1439/1439 [03:45<00:00,  6.38batch/s, loss=0.232]
Epoch 8/10: 100%|██████████| 1439/1439 [03:39<00:00,  6.54batch/s, loss=0.221]
Epoch 9/10: 100%|██████████| 1439/1439 [03:40<00:00,  6.51batch/s, loss=0.221]
Epoch 10/10: 100%|██████████| 1439/1439 [03:49<00:00,  6.28batch/s, loss=0.2]  


Dropout: 0.2, LR: 0.01, Validation Loss: 0.055638915639848546, Validation Accuracy: 98.5505777417393%
Training complete


Epoch 1/10: 100%|██████████| 1439/1439 [03:39<00:00,  6.55batch/s, loss=1.99]
Epoch 2/10: 100%|██████████| 1439/1439 [03:24<00:00,  7.05batch/s, loss=0.825]
Epoch 3/10: 100%|██████████| 1439/1439 [03:32<00:00,  6.78batch/s, loss=0.536]
Epoch 4/10: 100%|██████████| 1439/1439 [03:20<00:00,  7.19batch/s, loss=0.392]
Epoch 5/10: 100%|██████████| 1439/1439 [03:19<00:00,  7.21batch/s, loss=0.314]
Epoch 6/10: 100%|██████████| 1439/1439 [03:18<00:00,  7.25batch/s, loss=0.253]
Epoch 7/10: 100%|██████████| 1439/1439 [03:29<00:00,  6.87batch/s, loss=0.215]
Epoch 8/10: 100%|██████████| 1439/1439 [03:20<00:00,  7.18batch/s, loss=0.185]
Epoch 9/10: 100%|██████████| 1439/1439 [03:16<00:00,  7.33batch/s, loss=0.157] 
Epoch 10/10: 100%|██████████| 1439/1439 [03:21<00:00,  7.12batch/s, loss=0.142] 


Dropout: 0.3, LR: 0.0001, Validation Loss: 0.038922868408511384, Validation Accuracy: 98.84451652138658%
Training complete


Epoch 1/10: 100%|██████████| 1439/1439 [03:20<00:00,  7.17batch/s, loss=1.13] 
Epoch 2/10: 100%|██████████| 1439/1439 [03:24<00:00,  7.03batch/s, loss=0.388]
Epoch 3/10: 100%|██████████| 1439/1439 [03:21<00:00,  7.15batch/s, loss=0.25] 
Epoch 4/10: 100%|██████████| 1439/1439 [03:22<00:00,  7.12batch/s, loss=0.187]
Epoch 5/10: 100%|██████████| 1439/1439 [03:20<00:00,  7.19batch/s, loss=0.151]
Epoch 6/10: 100%|██████████| 1439/1439 [03:35<00:00,  6.69batch/s, loss=0.132] 
Epoch 7/10: 100%|██████████| 1439/1439 [03:22<00:00,  7.10batch/s, loss=0.115] 
Epoch 8/10: 100%|██████████| 1439/1439 [03:28<00:00,  6.89batch/s, loss=0.102] 
Epoch 9/10: 100%|██████████| 1439/1439 [03:20<00:00,  7.17batch/s, loss=0.0915]
Epoch 10/10: 100%|██████████| 1439/1439 [03:11<00:00,  7.51batch/s, loss=0.0903]


Dropout: 0.3, LR: 0.0005, Validation Loss: 0.020920142484042368, Validation Accuracy: 99.46280154064463%
Training complete


Epoch 1/10: 100%|██████████| 1439/1439 [02:35<00:00,  9.24batch/s, loss=0.931]
Epoch 2/10: 100%|██████████| 1439/1439 [02:49<00:00,  8.48batch/s, loss=0.323]
Epoch 3/10: 100%|██████████| 1439/1439 [02:37<00:00,  9.14batch/s, loss=0.229]
Epoch 4/10: 100%|██████████| 1439/1439 [02:46<00:00,  8.62batch/s, loss=0.184] 
Epoch 5/10: 100%|██████████| 1439/1439 [02:55<00:00,  8.18batch/s, loss=0.156] 
Epoch 6/10: 100%|██████████| 1439/1439 [02:46<00:00,  8.63batch/s, loss=0.134] 
Epoch 7/10: 100%|██████████| 1439/1439 [02:49<00:00,  8.48batch/s, loss=0.127] 
Epoch 8/10: 100%|██████████| 1439/1439 [02:42<00:00,  8.84batch/s, loss=0.119] 
Epoch 9/10: 100%|██████████| 1439/1439 [02:46<00:00,  8.66batch/s, loss=0.104] 
Epoch 10/10: 100%|██████████| 1439/1439 [02:48<00:00,  8.53batch/s, loss=0.1]   


Dropout: 0.3, LR: 0.001, Validation Loss: 0.01679817287595584, Validation Accuracy: 99.49320900060815%
Training complete


Epoch 1/10: 100%|██████████| 1439/1439 [02:47<00:00,  8.61batch/s, loss=1.11] 
Epoch 2/10: 100%|██████████| 1439/1439 [02:49<00:00,  8.49batch/s, loss=0.442]
Epoch 3/10: 100%|██████████| 1439/1439 [02:49<00:00,  8.50batch/s, loss=0.319]
Epoch 4/10: 100%|██████████| 1439/1439 [02:51<00:00,  8.41batch/s, loss=0.261]
Epoch 5/10: 100%|██████████| 1439/1439 [02:51<00:00,  8.39batch/s, loss=0.224]
Epoch 6/10: 100%|██████████| 1439/1439 [02:53<00:00,  8.30batch/s, loss=0.201]
Epoch 7/10: 100%|██████████| 1439/1439 [02:54<00:00,  8.26batch/s, loss=0.185]
Epoch 8/10: 100%|██████████| 1439/1439 [02:55<00:00,  8.20batch/s, loss=0.176]
Epoch 9/10: 100%|██████████| 1439/1439 [02:56<00:00,  8.13batch/s, loss=0.16] 
Epoch 10/10: 100%|██████████| 1439/1439 [02:57<00:00,  8.10batch/s, loss=0.155] 


Dropout: 0.3, LR: 0.005, Validation Loss: 0.031036254780460355, Validation Accuracy: 99.0168254611798%
Training complete


Epoch 1/10: 100%|██████████| 1439/1439 [02:55<00:00,  8.21batch/s, loss=1.33]
Epoch 2/10: 100%|██████████| 1439/1439 [02:56<00:00,  8.17batch/s, loss=0.642]
Epoch 3/10: 100%|██████████| 1439/1439 [02:56<00:00,  8.15batch/s, loss=0.495]
Epoch 4/10: 100%|██████████| 1439/1439 [02:57<00:00,  8.13batch/s, loss=0.441]
Epoch 5/10: 100%|██████████| 1439/1439 [02:56<00:00,  8.16batch/s, loss=0.406]
Epoch 6/10: 100%|██████████| 1439/1439 [02:56<00:00,  8.17batch/s, loss=0.371]
Epoch 7/10: 100%|██████████| 1439/1439 [02:58<00:00,  8.05batch/s, loss=0.374]
Epoch 8/10: 100%|██████████| 1439/1439 [03:04<00:00,  7.82batch/s, loss=0.345]
Epoch 9/10: 100%|██████████| 1439/1439 [03:03<00:00,  7.83batch/s, loss=0.329]
Epoch 10/10: 100%|██████████| 1439/1439 [03:04<00:00,  7.80batch/s, loss=0.322]


Dropout: 0.3, LR: 0.01, Validation Loss: 0.07334741837729276, Validation Accuracy: 97.92215690249341%
Training complete


Epoch 1/10: 100%|██████████| 1439/1439 [02:54<00:00,  8.25batch/s, loss=2.44]
Epoch 2/10: 100%|██████████| 1439/1439 [02:59<00:00,  8.04batch/s, loss=1.27] 
Epoch 3/10: 100%|██████████| 1439/1439 [02:57<00:00,  8.10batch/s, loss=0.881]
Epoch 4/10: 100%|██████████| 1439/1439 [02:57<00:00,  8.09batch/s, loss=0.681]
Epoch 5/10: 100%|██████████| 1439/1439 [03:01<00:00,  7.95batch/s, loss=0.563]
Epoch 6/10: 100%|██████████| 1439/1439 [2:44:42<00:00,  6.87s/batch, loss=0.481]    
Epoch 7/10: 100%|██████████| 1439/1439 [4:47:00<00:00, 11.97s/batch, loss=0.412]    
Epoch 8/10: 100%|██████████| 1439/1439 [02:15<00:00, 10.61batch/s, loss=0.362]
Epoch 9/10: 100%|██████████| 1439/1439 [02:18<00:00, 10.41batch/s, loss=0.321]
Epoch 10/10: 100%|██████████| 1439/1439 [02:40<00:00,  8.97batch/s, loss=0.292]


Dropout: 0.4, LR: 0.0001, Validation Loss: 0.08977317958648247, Validation Accuracy: 97.39509426312588%
Training complete


Epoch 1/10: 100%|██████████| 1439/1439 [02:44<00:00,  8.76batch/s, loss=1.44]
Epoch 2/10: 100%|██████████| 1439/1439 [02:51<00:00,  8.40batch/s, loss=0.595]
Epoch 3/10: 100%|██████████| 1439/1439 [02:55<00:00,  8.21batch/s, loss=0.404]
Epoch 4/10: 100%|██████████| 1439/1439 [02:59<00:00,  8.03batch/s, loss=0.317]
Epoch 5/10: 100%|██████████| 1439/1439 [03:05<00:00,  7.77batch/s, loss=0.267]
Epoch 6/10: 100%|██████████| 1439/1439 [03:05<00:00,  7.74batch/s, loss=0.229]
Epoch 7/10: 100%|██████████| 1439/1439 [02:56<00:00,  8.16batch/s, loss=0.213]
Epoch 8/10: 100%|██████████| 1439/1439 [02:47<00:00,  8.61batch/s, loss=0.185]
Epoch 9/10: 100%|██████████| 1439/1439 [02:51<00:00,  8.37batch/s, loss=0.175]
Epoch 10/10: 100%|██████████| 1439/1439 [02:57<00:00,  8.13batch/s, loss=0.157]


Dropout: 0.4, LR: 0.0005, Validation Loss: 0.03115209849812407, Validation Accuracy: 99.13845530103386%
Training complete


Epoch 1/10: 100%|██████████| 1439/1439 [02:56<00:00,  8.18batch/s, loss=1.28]
Epoch 2/10: 100%|██████████| 1439/1439 [02:58<00:00,  8.07batch/s, loss=0.557]
Epoch 3/10: 100%|██████████| 1439/1439 [02:58<00:00,  8.07batch/s, loss=0.405]
Epoch 4/10: 100%|██████████| 1439/1439 [02:58<00:00,  8.05batch/s, loss=0.33] 
Epoch 5/10: 100%|██████████| 1439/1439 [03:01<00:00,  7.95batch/s, loss=0.28] 
Epoch 6/10: 100%|██████████| 1439/1439 [03:07<00:00,  7.68batch/s, loss=0.25] 
Epoch 7/10: 100%|██████████| 1439/1439 [03:11<00:00,  7.50batch/s, loss=0.224]
Epoch 8/10: 100%|██████████| 1439/1439 [03:06<00:00,  7.72batch/s, loss=0.2]  
Epoch 9/10: 100%|██████████| 1439/1439 [03:06<00:00,  7.72batch/s, loss=0.195]
Epoch 10/10: 100%|██████████| 1439/1439 [03:09<00:00,  7.60batch/s, loss=0.174]


Dropout: 0.4, LR: 0.001, Validation Loss: 0.03141675984771759, Validation Accuracy: 99.13845530103386%
Training complete


Epoch 1/10: 100%|██████████| 1439/1439 [03:17<00:00,  7.30batch/s, loss=1.46]
Epoch 2/10: 100%|██████████| 1439/1439 [03:12<00:00,  7.49batch/s, loss=0.706]
Epoch 3/10: 100%|██████████| 1439/1439 [03:04<00:00,  7.80batch/s, loss=0.535]
Epoch 4/10: 100%|██████████| 1439/1439 [03:03<00:00,  7.86batch/s, loss=0.439]
Epoch 5/10: 100%|██████████| 1439/1439 [03:42<00:00,  6.46batch/s, loss=0.389]
Epoch 6/10: 100%|██████████| 1439/1439 [03:57<00:00,  6.06batch/s, loss=0.357]
Epoch 7/10: 100%|██████████| 1439/1439 [03:15<00:00,  7.36batch/s, loss=0.327]
Epoch 8/10:  78%|███████▊  | 1124/1439 [03:05<01:34,  3.32batch/s, loss=0.241]

In [7]:
# Initialize the model, loss function, and optimizer

p_dropouts = [0.4, 0.5]
lrs = [0.0001, 0.0005, 0.001, 0.005, 0.01]
for p_dropout in p_dropouts:
    for lr in lrs: 
        num_classes = get_num_classes(train_dataset)
        model = PaperCNN(num_classes=num_classes, p_dropout = p_dropout)
        device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
        model.to(device)
        #print("Model initialized")
        
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=lr)
        #print("Optimizer initialized")
        
        # Initialize lists to track metrics
        train_losses = []
        val_losses = []
        train_accuracies = []
        val_accuracies = []
        
        
        # Training loop with tqdm progress bar
        num_epochs = 10
        for epoch in range(num_epochs):
            model.train()
            running_loss = 0.0
            correct_train = 0
            total_train = 0
            train_loader_tqdm = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", unit="batch")
            for images, labels in train_loader_tqdm:
                # move images and labels to device
                images, labels = images.to(device), labels.to(device)
                
                optimizer.zero_grad()
                outputs = model(images)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()
                running_loss += loss.item()
                train_loader_tqdm.set_postfix(loss=running_loss/len(train_loader))
                
                # Calculate training accuracy
                _, predicted = torch.max(outputs.data, 1)
                total_train += labels.size(0)
                correct_train += (predicted == labels).sum().item()
        
            epoch_loss = running_loss / len(train_loader)
            train_losses.append(epoch_loss)
            train_accuracy = 100 * correct_train / total_train
            train_accuracies.append(train_accuracy)
            
            #print(f"Epoch {epoch+1}, Loss: {epoch_loss}, Accuracy: {train_accuracy}%")
        
            # Validation loop
            model.eval()
            running_val_loss = 0.0
            correct_val = 0
            total_val = 0
            all_labels = []
            all_predictions = []
            with torch.no_grad():
                for images, labels in val_loader:
                    # move images and labels to device
                    images, labels = images.to(device), labels.to(device)
                    
                    outputs = model(images)
                    loss = criterion(outputs, labels)
                    running_val_loss += loss.item()
                    _, predicted = torch.max(outputs.data, 1)
                    total_val += labels.size(0)
                    correct_val += (predicted == labels).sum().item()
                    all_labels.extend(labels.cpu().numpy())
                    all_predictions.extend(predicted.cpu().numpy())
            
            epoch_val_loss = running_val_loss / len(val_loader)
            val_losses.append(epoch_val_loss)
            val_accuracy = 100 * correct_val / total_val
            val_accuracies.append(val_accuracy)
        final_train_loss = train_losses[-1]
        final_val_loss = val_losses[-1]
        final_train_accuracy = train_accuracies[-1]
        final_val_accuracy = val_accuracies[-1]
        print(f"Dropout: {p_dropout}, LR: {lr}, Validation Loss: {final_val_loss}, Validation Accuracy: {final_val_accuracy}%")
        
        
        print("Training complete")

Epoch 1/10: 100%|██████████| 1439/1439 [06:08<00:00,  3.91batch/s, loss=2.38]
Epoch 2/10: 100%|██████████| 1439/1439 [05:08<00:00,  4.66batch/s, loss=1.24] 
Epoch 3/10: 100%|██████████| 1439/1439 [05:04<00:00,  4.72batch/s, loss=0.862]
Epoch 4/10: 100%|██████████| 1439/1439 [04:57<00:00,  4.83batch/s, loss=0.678]
Epoch 5/10: 100%|██████████| 1439/1439 [04:47<00:00,  5.01batch/s, loss=0.555]
Epoch 6/10: 100%|██████████| 1439/1439 [04:31<00:00,  5.30batch/s, loss=0.473]
Epoch 7/10: 100%|██████████| 1439/1439 [04:01<00:00,  5.95batch/s, loss=0.405]
Epoch 8/10: 100%|██████████| 1439/1439 [03:56<00:00,  6.07batch/s, loss=0.35] 
Epoch 9/10: 100%|██████████| 1439/1439 [03:40<00:00,  6.54batch/s, loss=0.317]
Epoch 10/10: 100%|██████████| 1439/1439 [03:42<00:00,  6.46batch/s, loss=0.284]


Dropout: 0.4, LR: 0.0001, Validation Loss: 0.09374710329142204, Validation Accuracy: 97.30387188323536%
Training complete


Epoch 1/10: 100%|██████████| 1439/1439 [03:35<00:00,  6.67batch/s, loss=1.44]
Epoch 2/10: 100%|██████████| 1439/1439 [03:30<00:00,  6.84batch/s, loss=0.609]
Epoch 3/10: 100%|██████████| 1439/1439 [03:33<00:00,  6.73batch/s, loss=0.415]
Epoch 4/10: 100%|██████████| 1439/1439 [03:30<00:00,  6.83batch/s, loss=0.323]
Epoch 5/10: 100%|██████████| 1439/1439 [03:20<00:00,  7.18batch/s, loss=0.276]
Epoch 6/10: 100%|██████████| 1439/1439 [03:37<00:00,  6.61batch/s, loss=0.234]
Epoch 7/10: 100%|██████████| 1439/1439 [03:48<00:00,  6.29batch/s, loss=0.21]  
Epoch 8/10: 100%|██████████| 1439/1439 [03:40<00:00,  6.53batch/s, loss=0.197]
Epoch 9/10: 100%|██████████| 1439/1439 [03:20<00:00,  7.17batch/s, loss=0.177]
Epoch 10/10: 100%|██████████| 1439/1439 [03:27<00:00,  6.95batch/s, loss=0.167]


Dropout: 0.4, LR: 0.0005, Validation Loss: 0.032358311380702, Validation Accuracy: 99.08777620109467%
Training complete


Epoch 1/10: 100%|██████████| 1439/1439 [03:09<00:00,  7.58batch/s, loss=1.3] 
Epoch 2/10: 100%|██████████| 1439/1439 [03:10<00:00,  7.55batch/s, loss=0.551]
Epoch 3/10: 100%|██████████| 1439/1439 [03:33<00:00,  6.73batch/s, loss=0.403]
Epoch 4/10: 100%|██████████| 1439/1439 [03:52<00:00,  6.20batch/s, loss=0.32] 
Epoch 5/10: 100%|██████████| 1439/1439 [03:39<00:00,  6.54batch/s, loss=0.278]
Epoch 6/10: 100%|██████████| 1439/1439 [03:25<00:00,  7.02batch/s, loss=0.248]
Epoch 7/10: 100%|██████████| 1439/1439 [03:31<00:00,  6.80batch/s, loss=0.224]
Epoch 8/10: 100%|██████████| 1439/1439 [03:25<00:00,  6.99batch/s, loss=0.205]
Epoch 9/10: 100%|██████████| 1439/1439 [03:36<00:00,  6.65batch/s, loss=0.192]
Epoch 10/10: 100%|██████████| 1439/1439 [03:46<00:00,  6.36batch/s, loss=0.177]


Dropout: 0.4, LR: 0.001, Validation Loss: 0.03769684836615882, Validation Accuracy: 98.81410906142307%
Training complete


Epoch 1/10: 100%|██████████| 1439/1439 [03:56<00:00,  6.10batch/s, loss=1.47]
Epoch 2/10: 100%|██████████| 1439/1439 [03:59<00:00,  6.02batch/s, loss=0.692]
Epoch 3/10: 100%|██████████| 1439/1439 [03:56<00:00,  6.09batch/s, loss=0.512]
Epoch 4/10: 100%|██████████| 1439/1439 [03:41<00:00,  6.49batch/s, loss=0.43] 
Epoch 5/10: 100%|██████████| 1439/1439 [03:34<00:00,  6.70batch/s, loss=0.381]
Epoch 6/10: 100%|██████████| 1439/1439 [03:35<00:00,  6.67batch/s, loss=0.336]
Epoch 7/10: 100%|██████████| 1439/1439 [03:20<00:00,  7.19batch/s, loss=0.311]
Epoch 8/10: 100%|██████████| 1439/1439 [03:32<00:00,  6.78batch/s, loss=0.297]
Epoch 9/10: 100%|██████████| 1439/1439 [03:38<00:00,  6.58batch/s, loss=0.278]
Epoch 10/10: 100%|██████████| 1439/1439 [03:29<00:00,  6.86batch/s, loss=0.264]


Dropout: 0.4, LR: 0.005, Validation Loss: 0.05996843089429062, Validation Accuracy: 98.22623150212853%
Training complete


Epoch 1/10: 100%|██████████| 1439/1439 [03:04<00:00,  7.79batch/s, loss=1.8] 
Epoch 2/10: 100%|██████████| 1439/1439 [03:38<00:00,  6.59batch/s, loss=1.04] 
Epoch 3/10: 100%|██████████| 1439/1439 [03:43<00:00,  6.44batch/s, loss=0.846]
Epoch 4/10: 100%|██████████| 1439/1439 [03:12<00:00,  7.49batch/s, loss=0.746]
Epoch 5/10: 100%|██████████| 1439/1439 [03:28<00:00,  6.89batch/s, loss=0.672]
Epoch 6/10: 100%|██████████| 1439/1439 [03:54<00:00,  6.13batch/s, loss=0.63] 
Epoch 7/10: 100%|██████████| 1439/1439 [03:53<00:00,  6.18batch/s, loss=0.617]
Epoch 8/10: 100%|██████████| 1439/1439 [04:05<00:00,  5.86batch/s, loss=0.578]
Epoch 9/10: 100%|██████████| 1439/1439 [04:04<00:00,  5.89batch/s, loss=0.545]
Epoch 10/10: 100%|██████████| 1439/1439 [03:39<00:00,  6.55batch/s, loss=0.543]


Dropout: 0.4, LR: 0.01, Validation Loss: 0.14017892329769757, Validation Accuracy: 95.84431380498683%
Training complete


Epoch 1/10: 100%|██████████| 1439/1439 [03:38<00:00,  6.60batch/s, loss=2.79]
Epoch 2/10: 100%|██████████| 1439/1439 [03:25<00:00,  7.00batch/s, loss=1.77]
Epoch 3/10: 100%|██████████| 1439/1439 [03:10<00:00,  7.56batch/s, loss=1.31] 
Epoch 4/10: 100%|██████████| 1439/1439 [03:33<00:00,  6.74batch/s, loss=1.05] 
Epoch 5/10: 100%|██████████| 1439/1439 [03:50<00:00,  6.24batch/s, loss=0.905]
Epoch 6/10: 100%|██████████| 1439/1439 [03:36<00:00,  6.65batch/s, loss=0.786]
Epoch 7/10: 100%|██████████| 1439/1439 [03:43<00:00,  6.43batch/s, loss=0.7]  
Epoch 8/10: 100%|██████████| 1439/1439 [03:58<00:00,  6.04batch/s, loss=0.632]
Epoch 9/10: 100%|██████████| 1439/1439 [03:16<00:00,  7.32batch/s, loss=0.575]
Epoch 10/10: 100%|██████████| 1439/1439 [03:55<00:00,  6.12batch/s, loss=0.52] 


Dropout: 0.5, LR: 0.0001, Validation Loss: 0.21728032804471004, Validation Accuracy: 93.59416176768701%
Training complete


Epoch 1/10: 100%|██████████| 1439/1439 [04:03<00:00,  5.92batch/s, loss=1.84]
Epoch 2/10: 100%|██████████| 1439/1439 [03:34<00:00,  6.72batch/s, loss=0.923]
Epoch 3/10: 100%|██████████| 1439/1439 [03:30<00:00,  6.83batch/s, loss=0.679]
Epoch 4/10: 100%|██████████| 1439/1439 [03:20<00:00,  7.17batch/s, loss=0.552]
Epoch 5/10: 100%|██████████| 1439/1439 [09:21<00:00,  2.56batch/s, loss=0.474]  
Epoch 6/10: 100%|██████████| 1439/1439 [02:16<00:00, 10.58batch/s, loss=0.416]
Epoch 7/10: 100%|██████████| 1439/1439 [02:19<00:00, 10.33batch/s, loss=0.378]
Epoch 8/10: 100%|██████████| 1439/1439 [02:25<00:00,  9.90batch/s, loss=0.348]
Epoch 9/10: 100%|██████████| 1439/1439 [02:24<00:00,  9.99batch/s, loss=0.328]
Epoch 10/10: 100%|██████████| 1439/1439 [02:26<00:00,  9.86batch/s, loss=0.307]


Dropout: 0.5, LR: 0.0005, Validation Loss: 0.0715509975335823, Validation Accuracy: 97.94242854246909%
Training complete


Epoch 1/10: 100%|██████████| 1439/1439 [02:24<00:00,  9.94batch/s, loss=1.66]
Epoch 2/10: 100%|██████████| 1439/1439 [02:28<00:00,  9.72batch/s, loss=0.848]
Epoch 3/10: 100%|██████████| 1439/1439 [02:27<00:00,  9.75batch/s, loss=0.642]
Epoch 4/10: 100%|██████████| 1439/1439 [02:29<00:00,  9.64batch/s, loss=0.537]
Epoch 5/10: 100%|██████████| 1439/1439 [02:30<00:00,  9.54batch/s, loss=0.469]
Epoch 6/10: 100%|██████████| 1439/1439 [02:29<00:00,  9.62batch/s, loss=0.423]
Epoch 7/10: 100%|██████████| 1439/1439 [02:32<00:00,  9.45batch/s, loss=0.385]
Epoch 8/10: 100%|██████████| 1439/1439 [02:30<00:00,  9.58batch/s, loss=0.353]
Epoch 9/10: 100%|██████████| 1439/1439 [02:33<00:00,  9.36batch/s, loss=0.335]
Epoch 10/10: 100%|██████████| 1439/1439 [02:32<00:00,  9.41batch/s, loss=0.315]


Dropout: 0.5, LR: 0.001, Validation Loss: 0.07225743699238657, Validation Accuracy: 97.66876140279749%
Training complete


Epoch 1/10: 100%|██████████| 1439/1439 [02:34<00:00,  9.28batch/s, loss=1.86]
Epoch 2/10: 100%|██████████| 1439/1439 [02:37<00:00,  9.15batch/s, loss=1.05] 
Epoch 3/10: 100%|██████████| 1439/1439 [02:37<00:00,  9.12batch/s, loss=0.829]
Epoch 4/10: 100%|██████████| 1439/1439 [02:37<00:00,  9.13batch/s, loss=0.714]
Epoch 5/10: 100%|██████████| 1439/1439 [02:38<00:00,  9.10batch/s, loss=0.64] 
Epoch 6/10: 100%|██████████| 1439/1439 [02:42<00:00,  8.85batch/s, loss=0.593]
Epoch 7/10: 100%|██████████| 1439/1439 [02:35<00:00,  9.25batch/s, loss=0.55] 
Epoch 8/10: 100%|██████████| 1439/1439 [02:35<00:00,  9.24batch/s, loss=0.51] 
Epoch 9/10: 100%|██████████| 1439/1439 [02:37<00:00,  9.15batch/s, loss=0.495]
Epoch 10/10: 100%|██████████| 1439/1439 [02:34<00:00,  9.34batch/s, loss=0.483]


Dropout: 0.5, LR: 0.005, Validation Loss: 0.12832137884144162, Validation Accuracy: 96.0571660247314%
Training complete


Epoch 1/10: 100%|██████████| 1439/1439 [02:31<00:00,  9.53batch/s, loss=2.29]
Epoch 2/10: 100%|██████████| 1439/1439 [02:33<00:00,  9.37batch/s, loss=1.5] 
Epoch 3/10: 100%|██████████| 1439/1439 [02:33<00:00,  9.36batch/s, loss=1.28] 
Epoch 4/10: 100%|██████████| 1439/1439 [02:34<00:00,  9.29batch/s, loss=1.16] 
Epoch 5/10: 100%|██████████| 1439/1439 [02:33<00:00,  9.40batch/s, loss=1.08] 
Epoch 6/10: 100%|██████████| 1439/1439 [02:34<00:00,  9.32batch/s, loss=1.01] 
Epoch 7/10: 100%|██████████| 1439/1439 [02:38<00:00,  9.05batch/s, loss=0.965]
Epoch 8/10: 100%|██████████| 1439/1439 [02:57<00:00,  8.11batch/s, loss=0.93] 
Epoch 9/10: 100%|██████████| 1439/1439 [03:05<00:00,  7.77batch/s, loss=0.909]
Epoch 10/10: 100%|██████████| 1439/1439 [03:11<00:00,  7.52batch/s, loss=0.878]


Dropout: 0.5, LR: 0.01, Validation Loss: 0.30431127762813787, Validation Accuracy: 90.22906953172512%
Training complete


In [None]:
# Plot training and validation loss
plt.figure(figsize=(12, 5))

plt.subplot(1, 2, 1)
plt.plot(train_losses, label='Training Loss')
plt.plot(val_losses, label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training and Validation Loss Over Epochs')
plt.legend()

# Plot training and validation accuracy
plt.subplot(1, 2, 2)
plt.plot(train_accuracies, label='Training Accuracy')
plt.plot(val_accuracies, label='Validation Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.title('Training and Validation Accuracy Over Epochs')
plt.legend()

plt.tight_layout()
plt.show()

In [None]:
# Generate confusion matrix
cm = confusion_matrix(all_labels, all_predictions)
cmd = ConfusionMatrixDisplay(cm, display_labels=[i for i in range(num_classes)])
fig, ax = plt.subplots(figsize=(10, 10))  # Adjust the size
cmd.plot(ax=ax, xticks_rotation='vertical')
plt.title('Confusion Matrix')
plt.show()

In [None]:
# debugging for mps speedup