In [1]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import timm
from torchvision import transforms, datasets
from torch.utils.data import Dataset, DataLoader, random_split, Subset
import PIL
from PIL import Image
import tqdm
from tqdm import tqdm
from sklearn.preprocessing import label_binarize
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, jaccard_score, matthews_corrcoef, roc_auc_score, roc_curve, auc

In [2]:
# Device

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [3]:
# Constants

data_dir = '/kaggle/input/fdata-adni-dataset/AugmentedAlzheimerDataset'

TEST_SIZE = 0.15
VAL_SIZE = 0.15

MEAN = 0.2956
STD = 0.3069

MODEL_TAG = 'hf_hub:timm/vgg16.tv_in1k'
NUM_CLASSES = 4

BATCH_SIZE = 32
LEARNING_RATE = 5e-5
EPOCHS = 75
DECAY = 1e-4

PATIENCE = 5
MIN_DELTA = 0.05

RANDOM_SEED = 43

transform = transforms.Compose([
    transforms.Resize((224, 224), interpolation=Image.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize(mean=[MEAN] * 3, std=[STD] * 3)
])

In [4]:
# Seeds
torch.manual_seed(RANDOM_SEED)
generator = torch.Generator().manual_seed(RANDOM_SEED) # Surity

In [5]:
# Datasets

master = datasets.ImageFolder(root=data_dir, transform=transform)

total_size = len(master)
test_size = int(TEST_SIZE * total_size)
val_size = int(VAL_SIZE * total_size)
train_size = total_size - test_size - val_size

train_set, val_set, test_set = random_split(master, [train_size, val_size, test_size], generator=generator)

In [6]:
print(f'Train: {len(train_set)}\nVal: {len(val_set)}\nTest: {len(test_set)}')

Train: 23790
Val: 5097
Test: 5097


In [7]:
# Model

model = timm.create_model(MODEL_TAG, pretrained=True, num_classes=NUM_CLASSES)
model.to(device)

optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=DECAY)
criterion = nn.CrossEntropyLoss()

config.json:   0%|          | 0.00/567 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/553M [00:00<?, ?B/s]

In [None]:
# Training

train_loss_array = []
val_loss_array = []

bad_epochs = 0
min_loss = float('inf')

train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_set, batch_size=BATCH_SIZE, shuffle=False)

for epoch in range(EPOCHS):
    model.train()
    print(f'Epoch {epoch + 1}/{EPOCHS}:')

    cum_loss = 0.0
    
    for images, labels in tqdm(train_loader):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)

        cum_loss += loss.item()
        loss.backward()
        optimizer.step()

    avg_train_loss = cum_loss / len(train_loader)
    train_loss_array.append(avg_train_loss)
    print(f'Training Loss: {avg_train_loss}')

    model.eval()
    with torch.no_grad():
        val_loss = 0.0

        for images, labels in tqdm(val_loader):
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            val_loss += loss.item()

        avg_val_loss = val_loss / len(val_loader)
        val_loss_array.append(avg_val_loss)
        print(f'Validation Loss: {avg_val_loss}')

    if avg_val_loss < min_loss * (1.0 + MIN_DELTA):
        if avg_val_loss < min_loss:
            torch.save(model.state_dict(), 'vgg_best.pth')
            min_loss = avg_val_loss
        bad_epochs = 0
    else:
        bad_epochs += 1

    if bad_epochs >= PATIENCE:
        print('Early stopping triggered')
        break

if os.path.exists('/kaggle/working/vgg_best.pth'):
    model.load_state_dict(torch.load('/kaggle/working/vgg_best.pth'))
    
torch.save(model, 'vgg_model.pth')

if os.path.exists('/kaggle/working/vgg_best.pth'):
    os.remove('/kaggle/working/vgg_best.pth')

Epoch 1/75:


100%|██████████| 744/744 [03:50<00:00,  3.23it/s]


Training Loss: 0.38869661370402464


100%|██████████| 160/160 [00:47<00:00,  3.34it/s]


Validation Loss: 0.09810981146874838
Epoch 2/75:


100%|██████████| 744/744 [03:20<00:00,  3.71it/s]


Training Loss: 0.07064413629628415


100%|██████████| 160/160 [00:26<00:00,  5.94it/s]


Validation Loss: 0.061846991107449865
Epoch 3/75:


100%|██████████| 744/744 [03:20<00:00,  3.71it/s]


Training Loss: 0.04128650788383585


100%|██████████| 160/160 [00:26<00:00,  6.03it/s]


Validation Loss: 0.044911644794456154
Epoch 4/75:


100%|██████████| 744/744 [03:20<00:00,  3.71it/s]


Training Loss: 0.026013689051498454


100%|██████████| 160/160 [00:26<00:00,  5.97it/s]


Validation Loss: 0.026643085639716447
Epoch 5/75:


100%|██████████| 744/744 [03:20<00:00,  3.71it/s]


Training Loss: 0.025339211278200095


100%|██████████| 160/160 [00:26<00:00,  5.98it/s]


Validation Loss: 0.02114568973911446
Epoch 6/75:


100%|██████████| 744/744 [03:20<00:00,  3.71it/s]


Training Loss: 0.02179764711581933


100%|██████████| 160/160 [00:26<00:00,  5.93it/s]


Validation Loss: 0.057912581978757774
Epoch 7/75:


100%|██████████| 744/744 [03:20<00:00,  3.71it/s]


Training Loss: 0.01780346246953506


100%|██████████| 160/160 [00:26<00:00,  5.99it/s]


Validation Loss: 0.061882843079729355
Epoch 8/75:


100%|██████████| 744/744 [03:20<00:00,  3.71it/s]


Training Loss: 0.019293586270431397


100%|██████████| 160/160 [00:26<00:00,  5.94it/s]


Validation Loss: 0.027131200575723824
Epoch 9/75:


100%|██████████| 744/744 [03:20<00:00,  3.71it/s]


Training Loss: 0.020158901801508172


100%|██████████| 160/160 [00:26<00:00,  5.95it/s]


Validation Loss: 0.03976157348916729
Epoch 10/75:


100%|██████████| 744/744 [03:20<00:00,  3.71it/s]


Training Loss: 0.015310666214557591


100%|██████████| 160/160 [00:27<00:00,  5.90it/s]


Validation Loss: 0.018700307142353267
Epoch 11/75:


100%|██████████| 744/744 [03:20<00:00,  3.71it/s]


Training Loss: 0.012092652490912953


100%|██████████| 160/160 [00:27<00:00,  5.87it/s]


Validation Loss: 0.022261551993307906
Epoch 12/75:


100%|██████████| 744/744 [03:20<00:00,  3.71it/s]


Training Loss: 0.018179162954272587


100%|██████████| 160/160 [00:26<00:00,  6.02it/s]


Validation Loss: 0.015617178958251543
Epoch 13/75:


100%|██████████| 744/744 [03:20<00:00,  3.71it/s]


Training Loss: 0.011202804730004895


100%|██████████| 160/160 [00:27<00:00,  5.90it/s]


Validation Loss: 0.018494897300263345
Epoch 14/75:


100%|██████████| 744/744 [03:20<00:00,  3.71it/s]


Training Loss: 0.009429896817367117


100%|██████████| 160/160 [00:26<00:00,  6.05it/s]


Validation Loss: 0.02453140503848772
Epoch 15/75:


100%|██████████| 744/744 [03:20<00:00,  3.71it/s]


Training Loss: 0.012090439873292684


100%|██████████| 160/160 [00:26<00:00,  5.95it/s]


Validation Loss: 0.016122514230897878
Epoch 16/75:


100%|██████████| 744/744 [03:20<00:00,  3.71it/s]


Training Loss: 0.013395626139644026


100%|██████████| 160/160 [00:26<00:00,  5.99it/s]


Validation Loss: 0.019506847868615295
Epoch 17/75:


100%|██████████| 744/744 [03:20<00:00,  3.71it/s]


Training Loss: 0.008518499785397652


100%|██████████| 160/160 [00:27<00:00,  5.90it/s]


Validation Loss: 0.005620758490437083
Epoch 18/75:


100%|██████████| 744/744 [03:20<00:00,  3.71it/s]


Training Loss: 0.012750803051367426


100%|██████████| 160/160 [00:26<00:00,  5.94it/s]


Validation Loss: 0.010007196611377367
Epoch 19/75:


 19%|█▉        | 144/744 [00:38<02:41,  3.71it/s]

In [None]:
# Plot

t = [i + 1 for i in range(len(train_loss_array))]

plt.figure(figsize=(12, 8))

plt.plot(t, train_loss_array, color='#0f80bd',linestyle='--', marker='o', label='Training Loss')
plt.plot(t, val_loss_array, color='#d68a18', linestyle='--', marker='o', label='Validation Loss')

plt.title(f'Training and Validation Loss\nVGG-16')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.grid(True, alpha=0.5)

plt.savefig('Loss_VGG.png', dpi=600)

plt.show()

In [None]:
# Testing

csv_path = 'vgg_metrics.csv'
test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False)

model.eval()
y_true = []
y_pred_prob_list = []

with torch.no_grad():
    for inputs, labels in tqdm(test_loader):
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        probs = torch.softmax(outputs, dim=1)
        y_pred_prob_list.append(probs.cpu().numpy())
        y_true.append(labels.cpu().numpy())

y_pred_prob = np.concatenate(y_pred_prob_list, axis=0)
y_true = np.concatenate(y_true, axis=0)
y_pred = np.argmax(y_pred_prob, axis=1)

y_true_bin = label_binarize(y_true, classes=np.unique(y_true))
n_classes = y_true_bin.shape[1]

accuracy = accuracy_score(y_true, y_pred)
precision = precision_score(y_true, y_pred, average='weighted')
recall = recall_score(y_true, y_pred, average='weighted')
f1 = f1_score(y_true, y_pred, average='weighted')
jaccard = jaccard_score(y_true, y_pred, average='weighted')
mcc = matthews_corrcoef(y_true, y_pred)
auc_score = roc_auc_score(y_true_bin, y_pred_prob, multi_class='ovr')

print(f"Accuracy: {accuracy:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
print(f"F1 Score: {f1:.4f}")
print(f"Jaccard Index: {jaccard:.4f}")
print(f"MCC: {mcc:.4f}")
print(f"AUC: {auc_score:.4f}")

metrics = {
    'Metric': ['Accuracy', 'Precision', 'Recall', 'F1 Score', 'Jaccard Index', 'MCC', 'AUC'],
    'Value': [accuracy, precision, recall, f1, jaccard, mcc, auc_score]
}
metrics_df = pd.DataFrame(metrics)
metrics_df.to_csv(csv_path, index=False)

fpr = dict()
tpr = dict()
roc_auc = dict()

plt.figure(figsize=(10, 8))
plt.plot([0, 1], [0, 1], 'k--', lw=2)

for i in range(n_classes):
    fpr[i], tpr[i], _ = roc_curve(y_true_bin[:, i], y_pred_prob[:, i])
    roc_auc[i] = auc(fpr[i], tpr[i])

for i in range(n_classes):
    plt.plot(fpr[i], tpr[i], label=f"Class {i} (AUC = {roc_auc[i]:.2f})")

plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title(f'ROC Curve - VGG-16')
plt.legend(loc='lower right')
plt.savefig('ROC_VGG.png', dpi=600)
plt.show()