In [1]:
# Import statements
import os
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torch.nn as nn
from sklearn.metrics import accuracy_score, confusion_matrix, f1_score, roc_curve, roc_auc_score, precision_recall_curve, auc
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
import seaborn as sns
from sklearn.preprocessing import label_binarize

In [2]:
# Details about the cuda device
print(torch.cuda.is_available())
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

if device.type == 'cuda':
    print(f"Device name: {torch.cuda.get_device_name(0)}")
    print(f"Memory Allocated: {torch.cuda.memory_allocated(0)} bytes")
    print(f"Memory Cached: {torch.cuda.memory_reserved(0)} bytes")
    print(f"Total Memory: {torch.cuda.get_device_properties(0).total_memory} bytes")
else:
    print("CUDA is not available. Using CPU instead.")

True
Using device: cuda
Device name: NVIDIA GeForce RTX 3060 Laptop GPU
Memory Allocated: 0 bytes
Memory Cached: 0 bytes
Total Memory: 6441926656 bytes


In [3]:
class ImageDataset(Dataset):
    def __init__(self, txt_loc, transform=None):
        print("Initializing image dataset.")
        self.image_labels = []
        self.image_paths = self.find_full_paths(txt_loc)
        self.transform = transform
        print(f"Total images found: {len(self.image_paths)}")

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, index):
        img_path, label = self.image_paths[index]
        try:
            image = Image.open(img_path)
            if self.transform:
                image = self.transform(image)
            # print(f"Loaded image {index + 1}/{len(self.image_paths)}: {img_path}")
            return image, torch.tensor(label, dtype=torch.long)
        except IOError as e:
            print(f"Error loading image {img_path}: {e}")
            return None, None

    def find_file_by_suffix(self, directory, filename_suffix):
        for root, dirs, files in os.walk(directory):
            for filename in files:
                if filename.endswith(filename_suffix):
                    return os.path.join(root, filename)
        return None

    def find_full_paths(self, txt_loc):
        data_dirs = os.listdir("color/")
        final_paths = []
        with open(txt_loc, 'r') as infile:
            lines = [line.strip() for line in infile.readlines()]
        
        for line in lines:
            parts = line.rsplit(' ', 1) 
            filename = parts[0]
            label = int(parts[1]) - 1
            directory_index = label
            file_location = f'color/{data_dirs[directory_index]}/'
            full_path = self.find_file_by_suffix(file_location, filename)
            if full_path:
                final_paths.append((full_path, label))
            else:
                print(f"File not found: {filename} in {file_location}")
        return final_paths


In [4]:
class AlexNet(nn.Module):
    def __init__(self, num_classes=31):
        super(AlexNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=0),
            nn.ReLU(inplace=True),
            nn.LocalResponseNorm(size=5, alpha=0.0001, beta=0.75),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(96, 256, kernel_size=5, padding=2, groups=2),
            nn.ReLU(inplace=True),
            nn.LocalResponseNorm(size=5, alpha=0.0001, beta=0.75),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(256, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 384, kernel_size=3, padding=1, groups=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1, groups=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x
    
model = AlexNet().to(device)
print(model)

AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 96, kernel_size=(11, 11), stride=(4, 4))
    (1): ReLU(inplace=True)
    (2): LocalResponseNorm(5, alpha=0.0001, beta=0.75, k=1.0)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(96, 256, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=2)
    (5): ReLU(inplace=True)
    (6): LocalResponseNorm(5, alpha=0.0001, beta=0.75, k=1.0)
    (7): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (8): Conv2d(256, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=2)
    (11): ReLU(inplace=True)
    (12): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=2)
    (13): ReLU(inplace=True)
    (14): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=

In [6]:
transform = transforms.Compose([
    transforms.CenterCrop(227), #alexnet size
    transforms.ToTensor(),
])

train_dataset = ImageDataset(txt_loc='./train.txt', transform=transform)
print(len(train_dataset))
test_dataset = ImageDataset(txt_loc='./test.txt', transform=transform)
print(len(test_dataset))

batch_sizes = [100]
learning_rates = [0.001, 0.0005, 0.0001]
epochs_list = [15]

def get_loader(dataset, batch_size):
    if batch_size > len(dataset):
        batch_size = len(dataset)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)

# Results 
results = []


Initializing image dataset.
Total images found: 34011
34011
Initializing image dataset.
Total images found: 8498
8498


In [17]:
# criterion = nn.CrossEntropyLoss()
# optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

Epoch 1/10, Batch Size 1, LR 0.001:   6%|▌         | 2088/34011 [00:58<14:04, 37.79it/s, loss=3.23]

In [18]:
# model.train()
# for epoch in range(15):
#     loop = tqdm(enumerate(train_loader), total=len(train_loader), leave=True)
#     for i, (inputs, labels) in loop:
#         inputs, labels = inputs.to(device), labels.to(device)
#         optimizer.zero_grad()
#         outputs = model(inputs)
#         loss = criterion(outputs, labels)
#         loss.backward()
#         optimizer.step()
# 
#         loop.set_description(f'Epoch {epoch + 1}')
#         loop.set_postfix(loss=loss.item())


Epoch 1: 100%|██████████| 1063/1063 [04:20<00:00,  4.07it/s, loss=2.33] 
Epoch 2: 100%|██████████| 1063/1063 [02:06<00:00,  8.40it/s, loss=0.774]
Epoch 3: 100%|██████████| 1063/1063 [02:06<00:00,  8.39it/s, loss=0.984]
Epoch 4: 100%|██████████| 1063/1063 [02:02<00:00,  8.69it/s, loss=0.716]
Epoch 5: 100%|██████████| 1063/1063 [02:04<00:00,  8.54it/s, loss=0.832]
Epoch 6: 100%|██████████| 1063/1063 [02:05<00:00,  8.48it/s, loss=0.321]
Epoch 7: 100%|██████████| 1063/1063 [02:37<00:00,  6.74it/s, loss=0.421]
Epoch 8: 100%|██████████| 1063/1063 [02:30<00:00,  7.08it/s, loss=0.591]
Epoch 9: 100%|██████████| 1063/1063 [02:20<00:00,  7.59it/s, loss=0.615]
Epoch 10: 100%|██████████| 1063/1063 [02:18<00:00,  7.66it/s, loss=0.435]
Epoch 11: 100%|██████████| 1063/1063 [02:18<00:00,  7.69it/s, loss=0.294]
Epoch 12: 100%|██████████| 1063/1063 [02:17<00:00,  7.73it/s, loss=0.492]
Epoch 13: 100%|██████████| 1063/1063 [02:19<00:00,  7.65it/s, loss=0.474]
Epoch 14: 100%|██████████| 1063/1063 [02:24<00:

In [7]:
import torch.nn.functional as F
classes = [f'{i}' for i in range(1, 32)]

def plot_confusion_matrix(true_labels, predicted_labels, classes, save_path, file_suffix):
    cm = confusion_matrix(true_labels, predicted_labels)
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, cmap=plt.cm.Purples, fmt='g', 
                 xticklabels=classes, yticklabels=classes)
    plt.xlabel('Predicted Label')
    plt.ylabel('True Label')
    plt.title('Confusion Matrix')
    plt.savefig(f'{save_path}/confusion_matrix_{file_suffix}.png')
    plt.close()
    
def plot_roc_curve_multiclass(num_classes, true_labels, probabilities, save_path, file_suffix):
    true_labels = label_binarize(true_labels, classes=[i for i in range(num_classes)])
    n_classes = true_labels.shape[1]
    
    fpr = dict()
    tpr = dict()
    roc_auc = dict()
    for i in range(n_classes):
        fpr[i], tpr[i], _ = roc_curve(true_labels[:, i], probabilities[:, i])
        roc_auc[i] = auc(fpr[i], tpr[i])

    plt.figure(figsize=(10, 8))
    colors = iter(plt.cm.rainbow(np.linspace(0, 1, n_classes)))
    for i in range(n_classes):
        plt.plot(fpr[i], tpr[i], color=next(colors), label=f'Class {i} (AUC = {roc_auc[i]:.2f})')
    
    plt.plot([0, 1], [0, 1], 'k--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Multi-class ROC Curve')
    plt.legend(loc="lower right")
    plt.savefig(f'{save_path}/roc_curve_multiclass_{file_suffix}.png')
    plt.close()

def plot_precision_recall(labels, probabilities, num_classes, save_path, file_suffix):
    labels = label_binarize(labels, classes=[i for i in range(num_classes)])
    
    precision = dict()
    recall = dict()
    pr_auc = dict()

    for i in range(num_classes):
        precision[i], recall[i], _ = precision_recall_curve(labels[:, i], probabilities[:, i])
        pr_auc[i] = auc(recall[i], precision[i])

    plt.figure(figsize=(10, 8))
    for i in range(num_classes):
        plt.plot(recall[i], precision[i], lw=2, label=f'Class {i} (PR AUC = {pr_auc[i]:.2f})')

    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.title('Class-wise Precision-Recall Curves')
    plt.legend(loc="best")
    plt.savefig(f'{save_path}/precision_recall_curves_{file_suffix}.png')
    plt.close()

    
def evaluate_model(model, device, data_loader, classes):
    model.eval()
    correct = 0
    total = 0
    true_labels = []
    predicted_labels = []
    probabilities_list = []

    with torch.no_grad():
        for inputs, labels in data_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)

            probabilities = F.softmax(outputs, dim=1)
            probabilities_list.append(probabilities.cpu())
            
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            true_labels.extend(labels.cpu().numpy())
            predicted_labels.extend(predicted.cpu().numpy())

    accuracy = 100 * correct / total
    f1 = f1_score(true_labels, predicted_labels, average='weighted')
    print(f'Accuracy: {accuracy:.2f}%')
    print(f'F1 Score: {f1:.2f}')

    probabilities = torch.cat(probabilities_list, dim=0)
    true_labels = torch.tensor(true_labels)
        
    return true_labels, predicted_labels, probabilities

In [8]:
plt.figure()
plt.plot([1, 2, 3], [1, 2, 3])
plt.savefig('results/Alexnet/confusion_matrices/test_plot.png')
plt.savefig('results/Alexnet/precision_recall/test_plot.png')
plt.savefig('results/Alexnet/roc_curves/test_plot.png')
plt.close()

for batch_size in batch_sizes:
    train_loader = get_loader(train_dataset, batch_size)
    test_loader = get_loader(test_dataset, batch_size)
    
    for lr in learning_rates:
        for num_epochs in epochs_list:
            model = AlexNet().to(device)
            optimizer = torch.optim.Adam(model.parameters(), lr=lr)
            criterion = nn.CrossEntropyLoss()

            # Training loop
            model.train()
            for epoch in range(num_epochs):
                loop = tqdm(enumerate(train_loader), total=len(train_loader), leave=True)
                for i, (inputs, labels) in loop:
                    inputs, labels = inputs.to(device), labels.to(device)
                    optimizer.zero_grad()
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                    loss.backward()
                    optimizer.step()

                    loop.set_description(f'Epoch {epoch + 1}/{num_epochs}, Batch Size {batch_size}, LR {lr}')
                    loop.set_postfix(loss=loss.item())
                if (epoch + 1) % 5 == 0:
                    true_labels, predicted_labels, probabilities = evaluate_model(model, device, test_loader, classes)
                    plot_confusion_matrix(true_labels, predicted_labels, classes, 'results/AlexNet/confusion_matrices', f'batch_{batch_size}_lr_{lr}_epoch_{epoch + 1}')
                    plot_roc_curve_multiclass(31, true_labels, probabilities, 'results/AlexNet/roc_curves', f'batch_{batch_size}_lr_{lr}_epoch_{epoch + 1}')
                    plot_precision_recall(true_labels, probabilities, 31, 'results/AlexNet/precision_recall', f'batch_{batch_size}_lr_{lr}_epoch_{epoch + 1}')

Epoch 1/15, Batch Size 100, LR 0.001: 100%|██████████| 341/341 [01:58<00:00,  2.88it/s, loss=1.41]
Epoch 2/15, Batch Size 100, LR 0.001: 100%|██████████| 341/341 [01:51<00:00,  3.07it/s, loss=0.587]
Epoch 3/15, Batch Size 100, LR 0.001: 100%|██████████| 341/341 [01:47<00:00,  3.18it/s, loss=0.618]
Epoch 4/15, Batch Size 100, LR 0.001: 100%|██████████| 341/341 [01:43<00:00,  3.29it/s, loss=0.527]
Epoch 5/15, Batch Size 100, LR 0.001: 100%|██████████| 341/341 [01:43<00:00,  3.28it/s, loss=0.722]


Accuracy: 83.62%
F1 Score: 0.83


Epoch 6/15, Batch Size 100, LR 0.001: 100%|██████████| 341/341 [02:04<00:00,  2.74it/s, loss=0.427] 
Epoch 7/15, Batch Size 100, LR 0.001: 100%|██████████| 341/341 [02:00<00:00,  2.82it/s, loss=0.98]  
Epoch 8/15, Batch Size 100, LR 0.001: 100%|██████████| 341/341 [01:49<00:00,  3.12it/s, loss=0.0252]
Epoch 9/15, Batch Size 100, LR 0.001: 100%|██████████| 341/341 [01:50<00:00,  3.08it/s, loss=0.0506]
Epoch 10/15, Batch Size 100, LR 0.001: 100%|██████████| 341/341 [01:51<00:00,  3.07it/s, loss=0.022] 


Accuracy: 88.63%
F1 Score: 0.89


Epoch 11/15, Batch Size 100, LR 0.001: 100%|██████████| 341/341 [01:52<00:00,  3.03it/s, loss=0.0258] 
Epoch 12/15, Batch Size 100, LR 0.001: 100%|██████████| 341/341 [01:50<00:00,  3.09it/s, loss=0.165]  
Epoch 13/15, Batch Size 100, LR 0.001: 100%|██████████| 341/341 [01:48<00:00,  3.13it/s, loss=0.0177] 
Epoch 14/15, Batch Size 100, LR 0.001: 100%|██████████| 341/341 [01:47<00:00,  3.17it/s, loss=0.216]  
Epoch 15/15, Batch Size 100, LR 0.001: 100%|██████████| 341/341 [01:46<00:00,  3.20it/s, loss=0.225]  


Accuracy: 88.49%
F1 Score: 0.88


Epoch 1/15, Batch Size 100, LR 0.0005: 100%|██████████| 341/341 [01:48<00:00,  3.16it/s, loss=2.16]
Epoch 2/15, Batch Size 100, LR 0.0005: 100%|██████████| 341/341 [01:48<00:00,  3.15it/s, loss=0.939]
Epoch 3/15, Batch Size 100, LR 0.0005: 100%|██████████| 341/341 [01:44<00:00,  3.25it/s, loss=0.757]
Epoch 4/15, Batch Size 100, LR 0.0005: 100%|██████████| 341/341 [01:43<00:00,  3.29it/s, loss=0.203]
Epoch 5/15, Batch Size 100, LR 0.0005: 100%|██████████| 341/341 [01:43<00:00,  3.29it/s, loss=0.102]


Accuracy: 89.41%
F1 Score: 0.89


Epoch 6/15, Batch Size 100, LR 0.0005: 100%|██████████| 341/341 [01:44<00:00,  3.27it/s, loss=0.159] 
Epoch 7/15, Batch Size 100, LR 0.0005: 100%|██████████| 341/341 [01:43<00:00,  3.30it/s, loss=0.0185]
Epoch 8/15, Batch Size 100, LR 0.0005: 100%|██████████| 341/341 [01:43<00:00,  3.29it/s, loss=0.485] 
Epoch 9/15, Batch Size 100, LR 0.0005: 100%|██████████| 341/341 [01:44<00:00,  3.25it/s, loss=0.0279]
Epoch 10/15, Batch Size 100, LR 0.0005: 100%|██████████| 341/341 [01:45<00:00,  3.24it/s, loss=0.000475]


Accuracy: 92.36%
F1 Score: 0.92


Epoch 11/15, Batch Size 100, LR 0.0005: 100%|██████████| 341/341 [01:45<00:00,  3.24it/s, loss=0.0214] 
Epoch 12/15, Batch Size 100, LR 0.0005: 100%|██████████| 341/341 [01:45<00:00,  3.25it/s, loss=0.205]  
Epoch 13/15, Batch Size 100, LR 0.0005: 100%|██████████| 341/341 [01:44<00:00,  3.25it/s, loss=0.0245] 
Epoch 14/15, Batch Size 100, LR 0.0005: 100%|██████████| 341/341 [01:44<00:00,  3.27it/s, loss=0.000691]
Epoch 15/15, Batch Size 100, LR 0.0005: 100%|██████████| 341/341 [01:42<00:00,  3.33it/s, loss=0.00015]


Accuracy: 91.77%
F1 Score: 0.92


Epoch 1/15, Batch Size 100, LR 0.0001: 100%|██████████| 341/341 [01:46<00:00,  3.21it/s, loss=1.68]
Epoch 2/15, Batch Size 100, LR 0.0001: 100%|██████████| 341/341 [01:47<00:00,  3.18it/s, loss=0.546]
Epoch 3/15, Batch Size 100, LR 0.0001: 100%|██████████| 341/341 [01:47<00:00,  3.17it/s, loss=0.582]
Epoch 4/15, Batch Size 100, LR 0.0001: 100%|██████████| 341/341 [01:46<00:00,  3.20it/s, loss=0.214]
Epoch 5/15, Batch Size 100, LR 0.0001: 100%|██████████| 341/341 [01:45<00:00,  3.24it/s, loss=0.306]


Accuracy: 88.31%
F1 Score: 0.88


Epoch 6/15, Batch Size 100, LR 0.0001: 100%|██████████| 341/341 [01:44<00:00,  3.25it/s, loss=0.3]   
Epoch 7/15, Batch Size 100, LR 0.0001: 100%|██████████| 341/341 [01:45<00:00,  3.23it/s, loss=0.0476]
Epoch 8/15, Batch Size 100, LR 0.0001: 100%|██████████| 341/341 [01:45<00:00,  3.24it/s, loss=0.155] 
Epoch 9/15, Batch Size 100, LR 0.0001: 100%|██████████| 341/341 [01:42<00:00,  3.32it/s, loss=0.0157]
Epoch 10/15, Batch Size 100, LR 0.0001: 100%|██████████| 341/341 [01:46<00:00,  3.21it/s, loss=0.05]  


Accuracy: 91.55%
F1 Score: 0.91


Epoch 11/15, Batch Size 100, LR 0.0001: 100%|██████████| 341/341 [01:51<00:00,  3.06it/s, loss=0.00833]
Epoch 12/15, Batch Size 100, LR 0.0001: 100%|██████████| 341/341 [01:46<00:00,  3.20it/s, loss=0.018]  
Epoch 13/15, Batch Size 100, LR 0.0001: 100%|██████████| 341/341 [01:44<00:00,  3.26it/s, loss=0.00398]
Epoch 14/15, Batch Size 100, LR 0.0001: 100%|██████████| 341/341 [01:44<00:00,  3.26it/s, loss=0.132]  
Epoch 15/15, Batch Size 100, LR 0.0001: 100%|██████████| 341/341 [01:44<00:00,  3.26it/s, loss=0.000245]


Accuracy: 91.49%
F1 Score: 0.92
