## Imports

In [1]:
import os
import time
import torch
import numpy as np
from tqdm import tqdm
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from torch import nn
from efficientnet_pytorch import EfficientNet
from torch.cuda.amp import autocast, GradScaler
import matplotlib.pyplot as plt


## Logs
##### Function to log the training metrics
###### (Change the path name in `log_path` as required)

In [8]:
import csv

# CSV file that receives one row of metrics per epoch.
log_path = "logs/trainingFullData_log.csv"

# open(..., 'w') does not create missing parent folders, so make sure the
# directory exists before writing (otherwise a fresh checkout raises
# FileNotFoundError).
os.makedirs(os.path.dirname(log_path) or ".", exist_ok=True)

# Mode 'w' truncates any previous log: re-running this cell leaves a file
# containing only the header row.
with open(log_path, mode='w', newline='') as file:
    writer = csv.writer(file)
    writer.writerow(['epoch', 'train_loss', 'train_accuracy', 'val_loss', 'val_accuracy'])

In [10]:
# Use the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

Using device: cpu


## Data Loader

In [4]:
# Per-slice preprocessing pipeline (adjust the individual steps as needed).
_to_tensor = transforms.ToTensor()  # HWC NumPy array -> CHW tensor
_resize = transforms.Resize((224, 224))
_normalize = transforms.Normalize(
    mean=[0.5, 0.5, 0.5],
    std=[0.5, 0.5, 0.5],
)

# Order matters: convert to a tensor first, then resize, then normalize.
transform = transforms.Compose([_to_tensor, _resize, _normalize])

# Load MRI volumes into a list
def load_npy_data(root_dir, slice_transform=None):
    """Load every ``.npy`` volume found under ``root_dir`` into memory.

    ``root_dir`` must contain one sub-directory per class. Each ``.npy`` file
    inside a class folder is read with ``np.load``, given a trailing channel
    axis that is repeated three times, and passed slice by slice through the
    transform; the transformed slices are stacked into one tensor per volume.

    Args:
        root_dir: Directory whose sub-directories are the class folders.
        slice_transform: Callable applied to each slice of a volume. When
            omitted, the notebook-level ``transform`` pipeline is used.

    Returns:
        A list of ``(volume_tensor, label)`` tuples. The label is 0 for the
        folder named ``"demented"`` and 1 for every other folder.
    """
    if slice_transform is None:
        slice_transform = transform

    samples = []
    print(f"Loading data from {root_dir}...")
    start = time.time()

    # Sorted traversal makes the sample order reproducible across runs and
    # platforms (os.listdir returns entries in arbitrary order).
    for label in sorted(os.listdir(root_dir)):
        label_path = os.path.join(root_dir, label)
        if not os.path.isdir(label_path):
            continue

        label_val = 0 if label == "demented" else 1

        # Only attempt to load NumPy files; anything else in the folder
        # (e.g. OS/sync-tool metadata) would make np.load raise.
        volume_files = sorted(
            name for name in os.listdir(label_path)
            if name.lower().endswith(".npy")
        )

        for file_name in tqdm(volume_files, desc=label):
            full_path = os.path.join(label_path, file_name)
            # The indexing below requires a 3-D array; the original author's
            # note gives the layout as (20, H, W) - TODO confirm slice count.
            vol = np.load(full_path)
            vol = np.repeat(vol[:, :, :, None], 3, axis=3)  # add + triplicate channel axis

            # Apply the transform to each slice, then stack along a new first axis
            vol_tensor = torch.stack([slice_transform(plane) for plane in vol])
            samples.append((vol_tensor, label_val))

    print(f"✅ Loaded {len(samples)} samples in {time.time() - start:.2f}s")
    return samples

# Root folder holding the "train" and "val" splits. It can be overridden with
# the ALZ_DATASET_DIR environment variable so the notebook is not tied to one
# machine; the default is the original absolute location.
DATASET_DIR = os.environ.get(
    "ALZ_DATASET_DIR",
    r"C:/Users/moksh/OneDrive/Desktop/Alzeimers/Alzeimers-detection/dataset",
)

train_dataset = load_npy_data(f"{DATASET_DIR}/train")
val_dataset = load_npy_data(f"{DATASET_DIR}/val")

Loading data from C:/Users/moksh/OneDrive/Desktop/Alzeimers/Alzeimers-detection/dataset/train...


demented: 100%|██████████| 146/146 [00:06<00:00, 21.93it/s]
non-demented: 100%|██████████| 152/152 [00:07<00:00, 21.42it/s]


✅ Loaded 298 samples in 13.77s
Loading data from C:/Users/moksh/OneDrive/Desktop/Alzeimers/Alzeimers-detection/dataset/val...


demented: 100%|██████████| 37/37 [00:01<00:00, 21.36it/s]
non-demented: 100%|██████████| 38/38 [00:01<00:00, 21.67it/s]

✅ Loaded 75 samples in 3.49s





In [5]:
# Both loaders share the batch size and worker count; only training shuffles.
_loader_kwargs = dict(batch_size=2, num_workers=0)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)

In [6]:
from collections import Counter

# Collect the label (second tuple element) of every sample in each split
train_labels = [sample[1] for sample in train_dataset]
val_labels = [sample[1] for sample in val_dataset]

# Report how many samples each class contributes to each split
train_counts = Counter(train_labels)
val_counts = Counter(val_labels)
print("Training Set Label Distribution:", train_counts)
print("Validation Set Label Distribution:", val_counts)


Training Set Label Distribution: Counter({1: 152, 0: 146})
Validation Set Label Distribution: Counter({1: 38, 0: 37})


## Load the Model (Efficientnet-B2)

In [7]:
# Backbone: pretrained EfficientNet-B2 whose final fully connected layer is
# replaced by an identity, so it outputs features instead of class scores.
model = EfficientNet.from_pretrained("efficientnet-b2")
feature_dim = model._fc.in_features  # width of the features fed to the old FC layer
model._fc = nn.Identity()

# Classification head: features -> 512 hidden units -> 2 class logits.
classifier = nn.Sequential(
    nn.Flatten(),
    nn.Linear(feature_dim, 512),
    nn.ReLU(),
    nn.Dropout(0.2),
    nn.Linear(512, 2)
)

# model[0] is the backbone and model[1] the head; the training loop indexes
# them separately.
model = nn.Sequential(model, classifier).to(device)


criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
# Gradient scaling is only meaningful for CUDA mixed precision, so switch it
# off explicitly on CPU instead of constructing an enabled scaler there.
scaler = GradScaler(enabled=(device.type == "cuda"))


Loaded pretrained weights for efficientnet-b2


  scaler = GradScaler()


## Train from a save
##### Run the below cell if you want to train from a saved state. Change the path inside `torch.load()`

In [None]:
# map_location lets a checkpoint written on a GPU machine load on a CPU-only one.
checkpoint = torch.load("checkpoint_epoch_10.pt", map_location=device)  # or checkpoint_epoch_8.pt
model.load_state_dict(checkpoint['model_state'])

# Optimizer and scaler state are optional: a checkpoint that stores only
# 'epoch' and 'model_state' would otherwise raise KeyError here.
if 'optimizer_state' in checkpoint:
    optimizer.load_state_dict(checkpoint['optimizer_state'])
else:
    print("Checkpoint has no optimizer state; continuing with a fresh optimizer.")

# An empty scaler state (saved from a disabled GradScaler) is skipped, because
# loading it into an enabled scaler raises RuntimeError.
if checkpoint.get('scaler_state'):
    scaler.load_state_dict(checkpoint['scaler_state'])

start_epoch = checkpoint['epoch'] + 1  # Resume from next epoch

## Training Loop
###### Increase or decrease the patience accordingly. The model tends to plateau at the beginning and will suddenly generalize well, after which it will start to overfit.

In [11]:
def _forward_volume_batch(x, y):
    """Run one batch of slice stacks through the network.

    ``x`` is unpacked as (B, S, C, H, W). The S slices of every volume are
    flattened into the batch axis, passed through the backbone's
    ``extract_features``, globally average-pooled, and then averaged over S so
    that each volume yields a single feature vector for the classifier head.

    Returns:
        ``(logits, loss)`` where ``loss`` is ``criterion(logits, y)``.
    """
    B, S, C, H, W = x.shape
    x = x.view(B * S, C, H, W).float()

    # Request CUDA autocast only when the data actually lives on a CUDA device.
    with autocast(enabled=(device.type == "cuda")):
        features = model[0].extract_features(x)
        # Functional pooling avoids building a new pooling module every batch.
        pooled = nn.functional.adaptive_avg_pool2d(features, 1).view(B, S, -1)
        mean_features = pooled.mean(dim=1)
        out = model[1](mean_features)
        loss = criterion(out, y)
    return out, loss


def _save_checkpoint(path, epoch):
    """Save model, optimizer and AMP-scaler state for ``epoch`` to ``path``."""
    torch.save({
        'epoch': epoch,
        'model_state': model.state_dict(),
        # Stored so that training can be resumed, not just evaluated.
        'optimizer_state': optimizer.state_dict(),
        'scaler_state': scaler.state_dict(),
    }, path)


def train(num_epochs=25, start_epoch=1, patience=7):
    """Train and validate the model, logging metrics and saving checkpoints.

    Uses the notebook-level ``model``, ``optimizer``, ``scaler``,
    ``criterion``, ``train_loader``, ``val_loader``, ``device`` and
    ``log_path``.

    Args:
        num_epochs: Number of the last epoch to run (inclusive).
        start_epoch: Number of the first epoch to run; use a value greater
            than 1 when resuming.
        patience: Stop once this many consecutive epochs have passed without
            a new best validation accuracy.

    Side effects:
        Appends one row per epoch to ``log_path`` and writes ``best_val.pt``
        (highest validation accuracy), ``best_generalization.pt`` (lowest
        validation loss) and ``last_epoch.pt`` to the working directory.
    """
    best_val_acc = 0
    best_val_loss = float('inf')
    epochs_since_improvement = 0
    # Last completed epoch. Initialised here so the final save still works
    # when the loop body never runs (start_epoch > num_epochs).
    epoch = start_epoch - 1

    # Create the log with a header row only if it does not exist yet; an
    # existing log is appended to, never truncated. The header matches the one
    # written by the log-setup cell.
    if not os.path.exists(log_path):
        os.makedirs(os.path.dirname(log_path) or ".", exist_ok=True)
        with open(log_path, mode='w', newline='') as file:
            writer = csv.writer(file)
            writer.writerow(['epoch', 'train_loss', 'train_accuracy', 'val_loss', 'val_accuracy'])

    for epoch in range(start_epoch, num_epochs + 1):
        model.train()
        epoch_start = time.time()
        total_loss, total_correct, total_samples = 0, 0, 0

        for x, y in tqdm(train_loader, desc=f"Epoch {epoch}/{num_epochs}"):
            x, y = x.to(device), y.to(device)

            optimizer.zero_grad()
            out, loss = _forward_volume_batch(x, y)

            # Scaled backward pass and optimizer step.
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            total_loss += loss.item()
            total_correct += (out.argmax(1) == y).sum().item()
            total_samples += y.size(0)

        # Loss is the mean over batches; accuracy is over individual samples.
        train_loss = total_loss / len(train_loader)
        train_acc = 100 * total_correct / total_samples
        print(f"\n[Epoch {epoch}] Train Loss: {train_loss:.4f} | Acc: {train_acc:.2f}% | Time: {time.time() - epoch_start:.1f}s")

        # === Validation ===
        model.eval()
        val_loss, val_correct, val_samples = 0, 0, 0
        with torch.no_grad():
            for x, y in val_loader:
                x, y = x.to(device), y.to(device)
                out, loss = _forward_volume_batch(x, y)

                val_loss += loss.item()
                val_correct += (out.argmax(1) == y).sum().item()
                val_samples += y.size(0)

        val_loss /= len(val_loader)
        val_acc = 100 * val_correct / val_samples
        print(f"           Val Loss:   {val_loss:.4f} | Acc: {val_acc:.2f}%")

        # === Log to CSV ===
        with open(log_path, mode='a', newline='') as file:
            writer = csv.writer(file)
            writer.writerow([epoch, train_loss, train_acc, val_loss, val_acc])

        # === Save Best Accuracy Model ===
        # Early stopping is driven by validation accuracy only.
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            _save_checkpoint('best_val.pt', epoch)
            print("Best val accuracy model saved.")
            epochs_since_improvement = 0
        else:
            epochs_since_improvement += 1

        # === Save Best Generalization (lowest val loss) ===
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            _save_checkpoint('best_generalization.pt', epoch)
            print("Best generalization model saved.")

        # === Early Stopping Check ===
        if epochs_since_improvement >= patience:
            print(f"Early stopping at epoch {epoch} (no improvement for {patience} epochs).")
            break

    # === Save Final Model ===
    _save_checkpoint('last_epoch.pt', epoch)
    print("Final model saved as last_epoch.pt")

# Start training. If the optional "Train from a save" cell was run it defined
# `start_epoch`; pass it on so a resumed run continues at the next epoch
# instead of silently restarting the epoch count at 1.
train(start_epoch=globals().get("start_epoch", 1))

  with autocast():
Epoch 1/25: 100%|██████████| 149/149 [17:52<00:00,  7.20s/it]
  with autocast():



[Epoch 1] Train Loss: 0.6747 | Acc: 57.38% | Time: 1072.2s
           Val Loss:   0.7560 | Acc: 50.67%
Best val accuracy model saved.
Best generalization model saved.


Epoch 2/25: 100%|██████████| 149/149 [17:45<00:00,  7.15s/it]



[Epoch 2] Train Loss: 0.5875 | Acc: 68.79% | Time: 1065.5s
           Val Loss:   0.8076 | Acc: 50.67%


Epoch 3/25: 100%|██████████| 149/149 [17:45<00:00,  7.15s/it]



[Epoch 3] Train Loss: 0.5572 | Acc: 70.13% | Time: 1065.2s
           Val Loss:   0.6494 | Acc: 54.67%
Best val accuracy model saved.
Best generalization model saved.


Epoch 4/25: 100%|██████████| 149/149 [17:32<00:00,  7.07s/it]



[Epoch 4] Train Loss: 0.5041 | Acc: 73.83% | Time: 1052.8s
           Val Loss:   0.8562 | Acc: 50.67%


Epoch 5/25:   2%|▏         | 3/149 [00:26<21:32,  8.86s/it]


KeyboardInterrupt: 

In [None]:
from torchvision import datasets

# Preprocessing for test images: resize, convert to a tensor, then normalize
# every channel with mean 0.5 and std 0.5.
_test_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
]
test_transform = transforms.Compose(_test_steps)

test_dir = "dataset/test/"

# ImageFolder derives the class of each image from its sub-directory.
test_dataset = datasets.ImageFolder(root=test_dir, transform=test_transform)

_test_loader_options = {
    "batch_size": 32,
    "shuffle": False,
    "num_workers": 2,
    "pin_memory": True,
}
test_loader = DataLoader(test_dataset, **_test_loader_options)


In [None]:
# Rebinds test_dataset / test_loader with a single-process loader
# (default worker count, no pinned memory).
test_dataset = datasets.ImageFolder(
    "dataset/test",
    transform=test_transform,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=32,
    shuffle=False,
)

In [None]:
from efficientnet_pytorch import EfficientNet
import torch.nn as nn
import torch

if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# === Rebuild the architecture used during training ===
# Backbone with its final fully connected layer swapped for an identity.
model = EfficientNet.from_pretrained("efficientnet-b2")
feature_dim = model._fc.in_features
model._fc = nn.Identity()

# Head: flatten -> 512 hidden units -> ReLU -> dropout -> 2 logits.
_head_layers = [
    nn.Flatten(),
    nn.Linear(feature_dim, 512),
    nn.ReLU(),
    nn.Dropout(0.2),
    nn.Linear(512, 2),
]
classifier = nn.Sequential(*_head_layers)

# Chain backbone and head, then move the whole network to the chosen device.
model = nn.Sequential(model, classifier)
model = model.to(device)



Loaded pretrained weights for efficientnet-b2


In [None]:
import glob
import torch
import numpy as np
from tqdm import tqdm

def evaluate_checkpoint(ckpt_path, model, test_loader, device):
    """Load a checkpoint into ``model`` and return its test accuracy in percent.

    Args:
        ckpt_path: Path to a checkpoint dict containing a ``'model_state'`` entry.
        model: Network whose weights are overwritten in place by the checkpoint.
        test_loader: Iterable yielding ``(inputs, labels)`` batches.
        device: Device on which the evaluation runs.

    Returns:
        Accuracy over all samples yielded by ``test_loader``, as a float in
        the range 0-100.

    Raises:
        ValueError: If ``test_loader`` yields no samples.
    """
    # weights_only=True restricts unpickling to tensors and plain containers,
    # so a tampered checkpoint file cannot execute arbitrary code on load.
    checkpoint = torch.load(ckpt_path, map_location=device, weights_only=True)
    model.load_state_dict(checkpoint['model_state'])
    model.to(device)
    model.eval()

    correct = 0
    total = 0

    # NOTE(review): each batch goes through model(x) directly, whereas the
    # training loop averages backbone features over the slices of a volume
    # before the classifier head - confirm the test inputs match what the
    # network was trained on.
    with torch.no_grad():
        for x, y in tqdm(test_loader, desc=f"Evaluating {ckpt_path}"):
            x, y = x.to(device), y.to(device)

            # Softmax is monotonic, so the argmax of the raw logits selects
            # the same class as the argmax of the probabilities.
            predicted = model(x).argmax(dim=1)
            correct += (predicted == y).sum().item()
            total += y.size(0)

    if total == 0:
        raise ValueError(f"test_loader yielded no samples while evaluating {ckpt_path}")

    return 100.0 * correct / total


def check_all_checkpoints(model, test_loader, device, ckpt_paths=None):
    """Score several checkpoints on the test set and report the most accurate.

    When ``ckpt_paths`` is None, the three default files under
    ``checkpoints/`` are evaluated. Every checkpoint's accuracy is printed,
    followed by the best one.

    Returns:
        ``(results, best_ckpt)`` where ``results`` is a list of
        ``(path, accuracy)`` tuples in evaluation order and ``best_ckpt`` is
        the path with the highest accuracy (None if no accuracy exceeds 0).
    """
    default_paths = [
        "checkpoints/best_val.pt",
        "checkpoints/best_generalization.pt",
        "checkpoints/last_epoch.pt"
    ]
    paths = default_paths if ckpt_paths is None else ckpt_paths

    print(f"Evaluating {len(paths)} checkpoint(s)...")

    # First pass: evaluate and report every checkpoint.
    results = []
    for path in paths:
        score = evaluate_checkpoint(path, model, test_loader, device)
        results.append((path, score))
        print(f"{path}: Test Acc = {score:.2f}%")

    # Second pass: pick the winner. The strict '>' keeps the earliest entry
    # among ties and leaves the winner as None when nothing scores above 0.
    best_acc, best_ckpt = 0, None
    for path, score in results:
        if score > best_acc:
            best_acc, best_ckpt = score, path

    print("\n🏆 Best Checkpoint:")
    print(f"{best_ckpt} → {best_acc:.2f}%")

    return results, best_ckpt

# Evaluate the default checkpoints; the returned (results, best_ckpt) tuple is
# shown as the cell output.
check_all_checkpoints(
    model=model,
    test_loader=test_loader,
    device=device,
)


Evaluating 3 checkpoint(s)...


  checkpoint = torch.load(ckpt_path, map_location=device)
Evaluating checkpoints/best_val.pt: 100%|██████████| 26/26 [00:08<00:00,  3.19it/s]


checkpoints/best_val.pt: Test Acc = 36.39%


Evaluating checkpoints/best_generalization.pt: 100%|██████████| 26/26 [00:07<00:00,  3.46it/s]


checkpoints/best_generalization.pt: Test Acc = 54.82%


Evaluating checkpoints/last_epoch.pt: 100%|██████████| 26/26 [00:08<00:00,  2.90it/s]

checkpoints/last_epoch.pt: Test Acc = 33.46%

🏆 Best Checkpoint:
checkpoints/best_generalization.pt → 54.82%





([('checkpoints/best_val.pt', 36.385836385836384),
  ('checkpoints/best_generalization.pt', 54.82295482295483),
  ('checkpoints/last_epoch.pt', 33.45543345543345)],
 'checkpoints/best_generalization.pt')