In [1]:
# =========================
# Project Configuration
# =========================

import os
import random
import numpy as np
import pandas as pd
from pathlib import Path

import torch
import torch.nn as nn
import torchvision.transforms as T
from torch.utils.data import Dataset, DataLoader

from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score

import matplotlib.pyplot as plt
import seaborn as sns

# -------------------------
# Reproducibility
# -------------------------
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# -------------------------
# Paths
# -------------------------
PROJECT_ROOT = Path("..")
DATA_ROOT = PROJECT_ROOT / "DATASET" / "final_balanced_dataset" / "FINAL_BALANCED_DATASET"
RESULTS_DIR = PROJECT_ROOT / "results"
RESULTS_DIR.mkdir(exist_ok=True)

# -------------------------
# Device
# -------------------------
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", DEVICE)


Using device: cuda


In [2]:
assert DATA_ROOT.exists(), "‚ùå final_dataset path does not exist"
print("‚úÖ Dataset root found:", DATA_ROOT)

print("Class folders:", [p.name for p in DATA_ROOT.iterdir() if p.is_dir()])


‚úÖ Dataset root found: ..\DATASET\final_balanced_dataset\FINAL_BALANCED_DATASET
Class folders: ['AD', 'CN', 'LMCI', 'results']


In [3]:
print("Current working directory:", os.getcwd())
print("\nParent directory contents:")
print(os.listdir(".."))


Current working directory: c:\Users\ADMIN\Documents\Alz_work\Notebooks

Parent directory contents:
['DATASET', 'Notebooks', 'Results']


In [4]:
print("\nContents of data folder (if exists):")
if os.path.exists("../data"):
    print(os.listdir("../data"))
else:
    print("‚ùå No data folder found at ../data")



Contents of data folder (if exists):
‚ùå No data folder found at ../data


In [5]:
import os
print("Current working directory:")
print(os.getcwd())


Current working directory:
c:\Users\ADMIN\Documents\Alz_work\Notebooks


In [6]:
print("\nContents of parent directory (..):")
print(os.listdir(".."))



Contents of parent directory (..):
['DATASET', 'Notebooks', 'Results']


In [7]:
print("\nContents of grandparent directory (../..):")
print(os.listdir("../.."))



Contents of grandparent directory (../..):
['Alz_work', 'AM Project', 'desktop.ini', 'Discovery Studio', 'Document.rtf', 'Isro_work', 'My Music', 'My Pictures', 'My Videos', 'Nasa_work']


In [8]:
from pathlib import Path

root = Path("..").resolve()
matches = list(root.rglob("final_balanced_dataset"))

print("Found matches:")
for m in matches:
    print(m)


Found matches:
C:\Users\ADMIN\Documents\Alz_work\DATASET\final_balanced_dataset
C:\Users\ADMIN\Documents\Alz_work\DATASET\FINAL_BALANCED_DATASET\final_balanced_dataset


##### PATIENT-WISE DATASET INDEXING & LEAKAGE-PROOF SPLITS

building a subject-level index

In [9]:
# =========================
# Step 2.1: Subject-wise Indexing
# =========================

data = []

for class_name in ["CN", "LMCI", "AD"]:
    class_dir = DATA_ROOT / class_name
    assert class_dir.exists(), f"Missing class folder: {class_name}"
    
    for subject_dir in class_dir.iterdir():
        if subject_dir.is_dir():
            subject_id = f"{class_name}_{subject_dir.name}"
            
            slices = list(subject_dir.glob("*"))
            if len(slices) == 0:
                continue
            
            data.append({
                "subject_id": subject_id,
                "class": class_name,
                "subject_path": subject_dir,
                "num_slices": len(slices)
            })

subjects_df = pd.DataFrame(data)

print("Total subjects:", len(subjects_df))
print(subjects_df["class"].value_counts())
subjects_df.head()


Total subjects: 933
class
CN      311
LMCI    311
AD      311
Name: count, dtype: int64


Unnamed: 0,subject_id,class,subject_path,num_slices
0,CN_002_S_0295,CN,..\DATASET\final_balanced_dataset\FINAL_BALANC...,2
1,CN_002_S_0295_aug0,CN,..\DATASET\final_balanced_dataset\FINAL_BALANC...,50
2,CN_002_S_0295_aug1,CN,..\DATASET\final_balanced_dataset\FINAL_BALANC...,50
3,CN_002_S_0413,CN,..\DATASET\final_balanced_dataset\FINAL_BALANC...,1
4,CN_002_S_0413_aug1,CN,..\DATASET\final_balanced_dataset\FINAL_BALANC...,50


In [10]:
print("DATA_ROOT:", DATA_ROOT)
print("\nTop-level contents:")
for item in DATA_ROOT.iterdir():
    print("-", item.name)


DATA_ROOT: ..\DATASET\final_balanced_dataset\FINAL_BALANCED_DATASET

Top-level contents:
- AD
- CN
- final_folder_splits.json
- LMCI
- results
- subject_splits.json
- unique_subjects.json


CANONICAL SUBJECT ID

We will extract the base subject ID by removing augmentation suffixes.

In [11]:
# =========================
# Step 2.1 (FIXED): Subject-wise Indexing with Augmentation Handling
# =========================

data = []

for class_name in ["CN", "LMCI", "AD"]:
    class_dir = DATA_ROOT / class_name
    assert class_dir.exists(), f"Missing class folder: {class_name}"
    
    for subject_dir in class_dir.iterdir():
        if subject_dir.is_dir():
            
            # Remove augmentation suffixes (_aug0, _aug1, etc.)
            base_subject_id = subject_dir.name.split("_aug")[0]
            subject_id = f"{class_name}_{base_subject_id}"
            
            slices = list(subject_dir.glob("*"))
            if len(slices) == 0:
                continue
            
            data.append({
                "subject_id": subject_id,
                "class": class_name,
                "subject_path": subject_dir,
                "num_slices": len(slices)
            })

subjects_df = pd.DataFrame(data)

print("Total folders (including augmentations):", len(subjects_df))
print("\nUnique subjects after grouping:")
print(subjects_df["subject_id"].nunique())

print("\nClass-wise unique subject counts:")
print(subjects_df.drop_duplicates("subject_id")["class"].value_counts())

subjects_df.head()


Total folders (including augmentations): 933

Unique subjects after grouping:
639

Class-wise unique subject counts:
class
LMCI    311
CN      195
AD      133
Name: count, dtype: int64


Unnamed: 0,subject_id,class,subject_path,num_slices
0,CN_002_S_0295,CN,..\DATASET\final_balanced_dataset\FINAL_BALANC...,2
1,CN_002_S_0295,CN,..\DATASET\final_balanced_dataset\FINAL_BALANC...,50
2,CN_002_S_0295,CN,..\DATASET\final_balanced_dataset\FINAL_BALANC...,50
3,CN_002_S_0413,CN,..\DATASET\final_balanced_dataset\FINAL_BALANC...,1
4,CN_002_S_0413,CN,..\DATASET\final_balanced_dataset\FINAL_BALANC...,50


FINAL SUBJECT TABLE (CANONICAL)

We will now create one row per REAL subject,
and attach all augmented folders to it.

In [12]:
# =========================
# Step 2.1 FINAL: Canonical subject indexing
# =========================

records = {}

for class_name in ["CN", "LMCI", "AD"]:
    class_dir = DATA_ROOT / class_name
    
    for subject_dir in class_dir.iterdir():
        if subject_dir.is_dir():
            base_id = subject_dir.name.split("_aug")[0]
            subject_id = f"{class_name}_{base_id}"
            
            if subject_id not in records:
                records[subject_id] = {
                    "subject_id": subject_id,
                    "class": class_name,
                    "subject_dirs": [],
                    "total_slices": 0
                }
            
            slices = list(subject_dir.glob("*"))
            records[subject_id]["subject_dirs"].append(subject_dir)
            records[subject_id]["total_slices"] += len(slices)

subjects_df = pd.DataFrame(records.values())

print("Total REAL subjects:", len(subjects_df))
print("\nClass-wise REAL subject counts:")
print(subjects_df["class"].value_counts())

subjects_df.head()


Total REAL subjects: 639

Class-wise REAL subject counts:
class
LMCI    311
CN      195
AD      133
Name: count, dtype: int64


Unnamed: 0,subject_id,class,subject_dirs,total_slices
0,CN_002_S_0295,CN,[..\DATASET\final_balanced_dataset\FINAL_BALAN...,102
1,CN_002_S_0413,CN,[..\DATASET\final_balanced_dataset\FINAL_BALAN...,51
2,CN_002_S_0685,CN,[..\DATASET\final_balanced_dataset\FINAL_BALAN...,101
3,CN_002_S_1261,CN,[..\DATASET\final_balanced_dataset\FINAL_BALAN...,102
4,CN_002_S_1280,CN,[..\DATASET\final_balanced_dataset\FINAL_BALAN...,101


PATIENT-WISE TRAIN / VAL / TEST SPLIT (REAL SUBJECTS ONLY)

In [13]:
# =========================
# Step 2.2: Patient-wise Train / Val / Test Split (REAL subjects)
# =========================

train_df, temp_df = train_test_split(
    subjects_df,
    test_size=0.30,
    stratify=subjects_df["class"],
    random_state=SEED
)

val_df, test_df = train_test_split(
    temp_df,
    test_size=0.50,
    stratify=temp_df["class"],
    random_state=SEED
)

print("Train subjects:", len(train_df))
print("Validation subjects:", len(val_df))
print("Test subjects:", len(test_df))

print("\nClass distribution:")
print("Train:\n", train_df["class"].value_counts())
print("Val:\n", val_df["class"].value_counts())
print("Test:\n", test_df["class"].value_counts())


Train subjects: 447
Validation subjects: 96
Test subjects: 96

Class distribution:
Train:
 class
LMCI    218
CN      136
AD       93
Name: count, dtype: int64
Val:
 class
LMCI    47
CN      29
AD      20
Name: count, dtype: int64
Test:
 class
LMCI    46
CN      30
AD      20
Name: count, dtype: int64


HARD LEAKAGE CHECK

In [14]:
# =========================
# Leakage Verification
# =========================

assert set(train_df["subject_id"]).isdisjoint(val_df["subject_id"])
assert set(train_df["subject_id"]).isdisjoint(test_df["subject_id"])
assert set(val_df["subject_id"]).isdisjoint(test_df["subject_id"])

print("‚úÖ No subject leakage across splits")


‚úÖ No subject leakage across splits


SAVE & FREEZE SPLITS

In [15]:
splits_dir = RESULTS_DIR / "splits"
splits_dir.mkdir(exist_ok=True)

train_df.to_csv(splits_dir / "train_subjects.csv", index=False)
val_df.to_csv(splits_dir / "val_subjects.csv", index=False)
test_df.to_csv(splits_dir / "test_subjects.csv", index=False)

print("‚úÖ Subject-wise splits saved at:", splits_dir)


‚úÖ Subject-wise splits saved at: ..\results\splits


FROZEN PREPROCESSING + SPLIT-AWARE AUGMENTATION

Create a single, frozen preprocessing + augmentation pipeline such that:

‚úÖ Same preprocessing everywhere (train / val / test)

‚úÖ Augmentation applied only to TRAIN

‚ùå No augmentation leakage into val/test

‚úÖ All augmentations of a subject stay in the same split

‚úÖ Ready for CNN + ViT simultaneously


What we freeze now

Image size: 224 √ó 224

Normalization: ImageNet mean/std
(because you use ImageNet-pretrained CNN + ViT)

Augmentations: medically safe only

Dataset is subject-driven, not slice-driven

defining transformers

In [16]:
# =========================
# Step 3.1: Preprocessing & Augmentation Transforms
# =========================

IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD  = [0.229, 0.224, 0.225]

# -------- TRAIN transforms (with augmentation) --------
train_transforms = T.Compose([
    T.Resize((224, 224)),
    T.RandomApply([
        T.ColorJitter(brightness=0.1, contrast=0.1)
    ], p=0.5),
    T.RandomApply([
        T.GaussianBlur(kernel_size=3)
    ], p=0.3),
    T.ToTensor(),
    T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
])

# -------- VAL / TEST transforms (NO augmentation) --------
eval_transforms = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
])

print("‚úÖ Transforms defined and frozen")


‚úÖ Transforms defined and frozen


In [17]:
from PIL import Image
from torch.utils.data import Dataset
import torch

BUILD SUBJECT-AWARE DATASET CLASS

In [18]:
class AlzheimerSubjectDataset(Dataset):
    def __init__(self, subjects_df, transform=None, cache=True):
        self.subjects_df = subjects_df.reset_index(drop=True)
        self.transform = transform
        self.cache = cache
        self._cache = {}

        self.class_to_label = {
            "CN": 0,
            "LMCI": 1,
            "AD": 2
        }

    def __len__(self):
        return len(self.subjects_df)

    def _load_subject(self, idx):
        row = self.subjects_df.iloc[idx]
        images = []

        for subject_dir in row["subject_dirs"]:
            for img_path in subject_dir.rglob("*.png"):
                img = Image.open(img_path).convert("RGB")
                if self.transform:
                    img = self.transform(img)
                images.append(img)

        if len(images) == 0:
            raise RuntimeError(f"No images found for subject {row['subject_id']}")

        return torch.stack(images)

    def __getitem__(self, idx):
        if self.cache and idx in self._cache:
            images = self._cache[idx]
        else:
            images = self._load_subject(idx)
            if self.cache:
                self._cache[idx] = images

        label = self.class_to_label[self.subjects_df.iloc[idx]["class"]]
        return images, label


CREATE DATASETS (SPLIT-AWARE)

connect splits-> datasets

In [19]:
# =========================
# Step 3.3: Create Datasets
# =========================

train_dataset = AlzheimerSubjectDataset(train_df, transform=train_transforms)
val_dataset   = AlzheimerSubjectDataset(val_df, transform=eval_transforms)
test_dataset  = AlzheimerSubjectDataset(test_df, transform=eval_transforms)

print("Train subjects:", len(train_dataset))
print("Val subjects:", len(val_dataset))
print("Test subjects:", len(test_dataset))


Train subjects: 447
Val subjects: 96
Test subjects: 96


sanity check

In [20]:
# =========================
# Step 3.4: Dataset Sanity Check
# =========================

sample_images, sample_label = train_dataset[0]

print("Sample label:", sample_label)
print("Sample tensor shape:", sample_images.shape)


Sample label: 2
Sample tensor shape: torch.Size([50, 3, 224, 224])


CNN FEATURE EXTRACTOR (ResNet50, SUBJECT-LEVEL)
For each subject:

MRI slices ‚Üí ResNet50 ‚Üí slice features ‚Üí subject embedding

Load ResNet50 (feature extractor)

In [21]:
# =========================
# Step 4.1: Load CNN Backbone (ResNet50)
# =========================

import torchvision.models as models

cnn_backbone = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)

# Remove classification head
cnn_backbone.fc = nn.Identity()

cnn_backbone = cnn_backbone.to(DEVICE)
cnn_backbone.eval()

print("‚úÖ ResNet50 loaded as feature extractor")


‚úÖ ResNet50 loaded as feature extractor


Subject-level CNN feature extraction function

In [22]:
# =========================
# Step 4.2: CNN Subject-level Feature Extraction
# =========================

@torch.no_grad()
def extract_cnn_subject_features(model, dataset, batch_size=16):
    model.eval()
    
    all_features = []
    all_labels = []

    for idx in range(len(dataset)):
        images, label = dataset[idx]      # images: (num_slices, 3, 224, 224)
        images = images.to(DEVICE)

        slice_features = []

        # Process slices in mini-batches
        for i in range(0, images.size(0), batch_size):
            batch = images[i:i+batch_size]
            feats = model(batch)          # (batch, 2048)
            slice_features.append(feats.cpu())

        slice_features = torch.cat(slice_features, dim=0)   # (num_slices, 2048)

        # Subject-level pooling
        subject_feature = slice_features.mean(dim=0)        # (2048,)

        all_features.append(subject_feature)
        all_labels.append(label)

        if (idx + 1) % 50 == 0:
            print(f"Processed {idx+1}/{len(dataset)} subjects")

    X = torch.stack(all_features)
    y = torch.tensor(all_labels)

    return X, y


Test on SMALL subset first

In [23]:
# =========================
# Step 4.3: Small-scale Test Extraction
# =========================

small_subset = torch.utils.data.Subset(train_dataset, range(5))

X_test, y_test = extract_cnn_subject_features(cnn_backbone, small_subset)

print("Feature shape:", X_test.shape)
print("Labels:", y_test.tolist())


Feature shape: torch.Size([5, 2048])
Labels: [2, 1, 2, 1, 2]


Full extraction (train / val / test)

In [24]:
# =========================
# Step 4.4: Full CNN Feature Extraction
# =========================

X_train_cnn, y_train = extract_cnn_subject_features(cnn_backbone, train_dataset)
X_val_cnn, y_val     = extract_cnn_subject_features(cnn_backbone, val_dataset)
X_test_cnn, y_test   = extract_cnn_subject_features(cnn_backbone, test_dataset)

print("Train features:", X_train_cnn.shape)
print("Val features:", X_val_cnn.shape)
print("Test features:", X_test_cnn.shape)


Processed 50/447 subjects
Processed 100/447 subjects
Processed 150/447 subjects
Processed 200/447 subjects
Processed 250/447 subjects
Processed 300/447 subjects
Processed 350/447 subjects
Processed 400/447 subjects
Processed 50/96 subjects
Processed 50/96 subjects
Train features: torch.Size([447, 2048])
Val features: torch.Size([96, 2048])
Test features: torch.Size([96, 2048])


ViT-B/16 FEATURE EXTRACTION (SUBJECT-LEVEL)
For each subject:

MRI slices ‚Üí ViT-B/16 ‚Üí slice embeddings ‚Üí subject embedding

loading vit-B/16 backbone

In [25]:
# =========================
# Step 5.1: Load ViT-B/16 Backbone
# =========================

from torchvision.models import vit_b_16, ViT_B_16_Weights

vit_model = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)

# Replace classification head with identity
vit_model.heads = nn.Identity()

vit_model = vit_model.to(DEVICE)
vit_model.eval()

print("‚úÖ ViT-B/16 loaded as feature extractor")


‚úÖ ViT-B/16 loaded as feature extractor


ViT Subject-level Feature Extraction Function
This mirrors the CNN extraction logic (important for fairness).

In [26]:
# =========================
# Step 5.2: ViT Subject-level Feature Extraction
# =========================

@torch.no_grad()
def extract_vit_subject_features(model, dataset, batch_size=16):
    model.eval()
    
    all_features = []
    all_labels = []

    for idx in range(len(dataset)):
        images, label = dataset[idx]        # (num_slices, 3, 224, 224)
        images = images.to(DEVICE)

        slice_features = []

        for i in range(0, images.size(0), batch_size):
            batch = images[i:i+batch_size]
            feats = model(batch)            # (batch, 768)
            slice_features.append(feats.cpu())

        slice_features = torch.cat(slice_features, dim=0)  # (num_slices, 768)

        # Subject-level pooling
        subject_feature = slice_features.mean(dim=0)       # (768,)

        all_features.append(subject_feature)
        all_labels.append(label)

        if (idx + 1) % 50 == 0:
            print(f"Processed {idx+1}/{len(dataset)} subjects")

    X = torch.stack(all_features)
    y = torch.tensor(all_labels)

    return X, y


SMALL-SCALE SAFETY TEST

In [27]:
# =========================
# Step 5.3: Small-scale ViT Feature Test
# =========================

small_subset = torch.utils.data.Subset(train_dataset, range(5))

X_vit_test, y_vit_test = extract_vit_subject_features(vit_model, small_subset)

print("ViT feature shape:", X_vit_test.shape)
print("Labels:", y_vit_test.tolist())


ViT feature shape: torch.Size([5, 768])
Labels: [2, 1, 2, 1, 2]


FULL ViT FEATURE EXTRACTION

In [28]:
# =========================
# Step 5.4: Full ViT Feature Extraction
# =========================

X_train_vit, y_train_vit = extract_vit_subject_features(vit_model, train_dataset)
X_val_vit, y_val_vit     = extract_vit_subject_features(vit_model, val_dataset)
X_test_vit, y_test_vit   = extract_vit_subject_features(vit_model, test_dataset)

print("Train ViT features:", X_train_vit.shape)
print("Val ViT features:", X_val_vit.shape)
print("Test ViT features:", X_test_vit.shape)


Processed 50/447 subjects
Processed 100/447 subjects
Processed 150/447 subjects
Processed 200/447 subjects
Processed 250/447 subjects
Processed 300/447 subjects
Processed 350/447 subjects
Processed 400/447 subjects
Processed 50/96 subjects
Processed 50/96 subjects
Train ViT features: torch.Size([447, 768])
Val ViT features: torch.Size([96, 768])
Test ViT features: torch.Size([96, 768])


In [29]:
import pandas as pd
import torch
import torchvision
import numpy as np

print("pandas:", pd.__version__)
print("torch:", torch.__version__)
print("cuda:", torch.cuda.is_available())


pandas: 2.3.3
torch: 2.7.1+cu118
cuda: True


BASELINE EXPERIMENTS (CNN, ViT, SIMPLE FUSION)

Train and evaluate three baseline models:

CNN-only (ResNet50 features)

ViT-only (ViT-B/16 features)

Simple CNN + ViT fusion (concatenation)

Using:

the same train/val/test splits

no fancy tricks

clear metrics

In [30]:
# =========================
# Step 6.1: Baseline Utilities
# =========================

from sklearn.svm import SVC
from sklearn.metrics import classification_report, confusion_matrix
import numpy as np


CNN-ONLY BASELINE

In [31]:
# =========================
# Step 6.2: CNN-only Baseline (SVM)
# =========================

svm_cnn = SVC(
    kernel="linear",
    class_weight="balanced",
    probability=False,
    random_state=SEED
)

svm_cnn.fit(X_train_cnn.numpy(), y_train.numpy())

y_val_pred_cnn = svm_cnn.predict(X_val_cnn.numpy())
y_test_pred_cnn = svm_cnn.predict(X_test_cnn.numpy())

print("=== CNN-only (Validation) ===")
print(classification_report(y_val.numpy(), y_val_pred_cnn, target_names=["CN", "LMCI", "AD"]))

print("=== CNN-only (Test) ===")
print(classification_report(y_test.numpy(), y_test_pred_cnn, target_names=["CN", "LMCI", "AD"]))


=== CNN-only (Validation) ===
              precision    recall  f1-score   support

          CN       0.43      0.41      0.42        29
        LMCI       0.58      0.66      0.62        47
          AD       0.47      0.35      0.40        20

    accuracy                           0.52        96
   macro avg       0.49      0.47      0.48        96
weighted avg       0.51      0.52      0.51        96

=== CNN-only (Test) ===
              precision    recall  f1-score   support

          CN       0.41      0.53      0.46        30
        LMCI       0.65      0.65      0.65        46
          AD       0.64      0.35      0.45        20

    accuracy                           0.55        96
   macro avg       0.57      0.51      0.52        96
weighted avg       0.57      0.55      0.55        96



In [32]:
# =========================
# Step 6.3: ViT-only Baseline (SVM)
# =========================

svm_vit = SVC(
    kernel="linear",
    class_weight="balanced",
    probability=False,
    random_state=SEED
)

svm_vit.fit(X_train_vit.numpy(), y_train_vit.numpy())

y_val_pred_vit = svm_vit.predict(X_val_vit.numpy())
y_test_pred_vit = svm_vit.predict(X_test_vit.numpy())

print("=== ViT-only (Validation) ===")
print(classification_report(y_val_vit.numpy(), y_val_pred_vit, target_names=["CN", "LMCI", "AD"]))

print("=== ViT-only (Test) ===")
print(classification_report(y_test_vit.numpy(), y_test_pred_vit, target_names=["CN", "LMCI", "AD"]))


=== ViT-only (Validation) ===
              precision    recall  f1-score   support

          CN       0.50      0.45      0.47        29
        LMCI       0.67      0.83      0.74        47
          AD       0.83      0.50      0.62        20

    accuracy                           0.65        96
   macro avg       0.67      0.59      0.61        96
weighted avg       0.65      0.65      0.64        96

=== ViT-only (Test) ===
              precision    recall  f1-score   support

          CN       0.48      0.47      0.47        30
        LMCI       0.59      0.72      0.65        46
          AD       0.64      0.35      0.45        20

    accuracy                           0.56        96
   macro avg       0.57      0.51      0.52        96
weighted avg       0.57      0.56      0.55        96



In [33]:
# =========================
# Step 6.4: CNN + ViT Simple Fusion Baseline
# =========================

X_train_fused = torch.cat([X_train_cnn, X_train_vit], dim=1)
X_val_fused   = torch.cat([X_val_cnn, X_val_vit], dim=1)
X_test_fused  = torch.cat([X_test_cnn, X_test_vit], dim=1)

print("Fused feature dimension:", X_train_fused.shape[1])

svm_fused = SVC(
    kernel="linear",
    class_weight="balanced",
    probability=False,
    random_state=SEED
)

svm_fused.fit(X_train_fused.numpy(), y_train.numpy())

y_val_pred_fused = svm_fused.predict(X_val_fused.numpy())
y_test_pred_fused = svm_fused.predict(X_test_fused.numpy())

print("=== CNN + ViT Fusion (Validation) ===")
print(classification_report(y_val.numpy(), y_val_pred_fused, target_names=["CN", "LMCI", "AD"]))

print("=== CNN + ViT Fusion (Test) ===")
print(classification_report(y_test.numpy(), y_test_pred_fused, target_names=["CN", "LMCI", "AD"]))


Fused feature dimension: 2816
=== CNN + ViT Fusion (Validation) ===
              precision    recall  f1-score   support

          CN       0.45      0.45      0.45        29
        LMCI       0.67      0.81      0.73        47
          AD       0.70      0.35      0.47        20

    accuracy                           0.60        96
   macro avg       0.60      0.54      0.55        96
weighted avg       0.61      0.60      0.59        96

=== CNN + ViT Fusion (Test) ===
              precision    recall  f1-score   support

          CN       0.43      0.43      0.43        30
        LMCI       0.59      0.74      0.65        46
          AD       0.62      0.25      0.36        20

    accuracy                           0.54        96
   macro avg       0.55      0.47      0.48        96
weighted avg       0.55      0.54      0.52        96



In [34]:
# =========================
# Step 6.5: Save Baseline Predictions
# =========================

np.save(RESULTS_DIR / "y_test_true.npy", y_test.numpy())
np.save(RESULTS_DIR / "y_test_pred_cnn.npy", y_test_pred_cnn)
np.save(RESULTS_DIR / "y_test_pred_vit.npy", y_test_pred_vit)
np.save(RESULTS_DIR / "y_test_pred_fused.npy", y_test_pred_fused)

print("‚úÖ Baseline predictions saved")


‚úÖ Baseline predictions saved


END-TO-END DUAL-STREAM MODEL

(CNN + ViT + Attention Fusion + Classifier)




MRI slices (per subject)
   ‚îú‚îÄ‚îÄ CNN Stream (Med3D ResNet50 / ResNet50 for now)
   ‚îú‚îÄ‚îÄ Transformer Stream (ViT-B/16)
   ‚îî‚îÄ‚îÄ Attention Fusion (learns importance)
           ‚Üì
      Subject-level embedding
           ‚Üì
        Classifier
           ‚Üì
      CN / LMCI / AD

Right now, to avoid breaking the environment, we will:

Implement Step 7 using ResNet50 (ImageNet) as a placeholder

The architecture will be identical

In the next step, we will swap the CNN backbone to Med3D
with zero changes to training code

##### dual - stream model defination

In [35]:
# =========================
# Step 7.1: End-to-End Dual-Stream Model with Attention Fusion
# =========================

class AttentionFusion(nn.Module):
    """
    Gated attention fusion for CNN & ViT features
    """
    def __init__(self, cnn_dim, vit_dim, hidden_dim=512):
        super().__init__()
        
        self.fc_cnn = nn.Linear(cnn_dim, hidden_dim)
        self.fc_vit = nn.Linear(vit_dim, hidden_dim)
        
        self.attn = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 2),
            nn.Softmax(dim=1)
        )

    def forward(self, cnn_feat, vit_feat):
        cnn_h = self.fc_cnn(cnn_feat)
        vit_h = self.fc_vit(vit_feat)

        combined = torch.cat([cnn_h, vit_h], dim=1)
        weights = self.attn(combined)   # (B, 2)

        fused = (
            weights[:, 0:1] * cnn_h +
            weights[:, 1:2] * vit_h
        )

        return fused


class DualStreamAlzheimerModel(nn.Module):
    def __init__(self, cnn_backbone, vit_backbone, num_classes=3):
        super().__init__()

        self.cnn = cnn_backbone
        self.vit = vit_backbone

        self.fusion = AttentionFusion(
            cnn_dim=2048,
            vit_dim=768,
            hidden_dim=512
        )

        self.classifier = nn.Sequential(
            nn.LayerNorm(512),
            nn.Dropout(0.4),
            nn.Linear(512, num_classes)
        )

    def forward(self, images):
        """
        images: (B, S, 3, 224, 224)
        """
        B, S, C, H, W = images.shape
        images = images.view(B * S, C, H, W)

        # CNN stream
        cnn_feats = self.cnn(images)          # (B*S, 2048)
        cnn_feats = cnn_feats.view(B, S, -1).mean(dim=1)

        # ViT stream
        vit_feats = self.vit(images)          # (B*S, 768)
        vit_feats = vit_feats.view(B, S, -1).mean(dim=1)

        fused = self.fusion(cnn_feats, vit_feats)
        logits = self.classifier(fused)

        return logits


initialize the model

In [36]:
# =========================
# Step 7.2: Model Initialization
# =========================

model = DualStreamAlzheimerModel(
    cnn_backbone=cnn_backbone,
    vit_backbone=vit_model
).to(DEVICE)

print("‚úÖ Dual-stream model initialized")


‚úÖ Dual-stream model initialized


Loss function and optimizer

In [37]:
# =========================
# Step 7.3: Loss, Optimizer, Scheduler
# =========================

class_counts = train_df["class"].value_counts().sort_index()
class_weights = 1.0 / torch.tensor(class_counts.values, dtype=torch.float)
class_weights = class_weights / class_weights.sum()
class_weights = class_weights.to(DEVICE)

criterion = nn.CrossEntropyLoss(weight=class_weights)

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=3e-4,
    weight_decay=1e-4
)

scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="min",
    factor=0.5,
    patience=5
)


print("‚úÖ Loss, optimizer, scheduler ready")


‚úÖ Loss, optimizer, scheduler ready


Subject-Batch DataLoader (CRITICAL)

In [38]:
# =========================
# STEP 7.4.1: Slice Sampling Utility
# =========================

import random

def sample_slices(images, num_slices=50, training=True):
    """
    images: Tensor of shape (S, 3, 224, 224)
    returns: Tensor of shape (num_slices, 3, 224, 224)
    """
    S = images.shape[0]

    # Case 1: exact match
    if S == num_slices:
        return images

    # Case 2: more slices than needed
    if S > num_slices:
        if training:
            indices = random.sample(range(S), num_slices)
        else:
            indices = torch.linspace(0, S - 1, steps=num_slices).long()
        return images[indices]

    # Case 3: fewer slices ‚Üí repeat
    repeat_factor = (num_slices // S) + 1
    images = images.repeat(repeat_factor, 1, 1, 1)
    return images[:num_slices]


In [41]:
# =========================
# Windows-safe collate wrappers
# =========================

def train_collate_fn(batch):
    return subject_collate(
        batch,
        num_slices=16,
        training=True
    )

def eval_collate_fn(batch):
    return subject_collate(
        batch,
        num_slices=16,
        training=False
    )


In [42]:
# =========================
# STEP 7.4.2: Fixed-size Subject Collate Function
# =========================

def subject_collate(batch, num_slices=50, training=True):
    images_list = []
    labels = []

    for images, label in batch:
        # images: (S, 3, 224, 224)
        images = sample_slices(
            images,
            num_slices=num_slices,
            training=training
        )
        images_list.append(images)
        labels.append(label)

    images = torch.stack(images_list)   # (B, 50, 3, 224, 224)
    labels = torch.tensor(labels)

    return images, labels


In [43]:
# =========================
# STAGE 2: Safe multiprocessing DataLoaders (Windows)
# =========================

train_loader = DataLoader(
    train_dataset,
    batch_size=2,
    shuffle=True,
    num_workers=2,          # üîë safe on Windows
    pin_memory=True,        # üîë faster CPU -> GPU
    persistent_workers=False,
    collate_fn=train_collate_fn
)

val_loader = DataLoader(
    val_dataset,
    batch_size=2,
    shuffle=False,
    num_workers=2,
    pin_memory=True,
    persistent_workers=False,
    collate_fn=eval_collate_fn
)

test_loader = DataLoader(
    test_dataset,
    batch_size=2,
    shuffle=False,
    num_workers=2,
    pin_memory=True,
    persistent_workers=False,
    collate_fn=eval_collate_fn
)

print("‚úÖ Stage 2 DataLoaders initialized (num_workers=2)")


‚úÖ Stage 2 DataLoaders initialized (num_workers=2)


end-to-end training loop

metrix utilities

In [44]:
# =========================
# Step 7.5.1: Metrics utilities
# =========================

from sklearn.metrics import accuracy_score
import numpy as np

def compute_accuracy(logits, labels):
    preds = torch.argmax(logits, dim=1)
    return accuracy_score(labels.cpu().numpy(), preds.cpu().numpy())


In [45]:
# =========================
# Step 7.5.2: Early Stopping
# =========================

class EarlyStopping:
    def __init__(self, patience=10, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = np.inf
        self.counter = 0
        self.should_stop = False

    def step(self, val_loss):
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.should_stop = True


In [46]:
# =========================
# Step 7.5.3: Train / Validate functions
# =========================

def train_one_epoch(model, loader):
    model.train()
    total_loss = 0
    total_acc = 0

    for images, labels in loader:
        images = images.to(DEVICE)
        labels = labels.to(DEVICE)

        optimizer.zero_grad()

        logits = model(images)
        loss = criterion(logits, labels)

        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total_acc += compute_accuracy(logits, labels)

    return total_loss / len(loader), total_acc / len(loader)


@torch.no_grad()
def validate_one_epoch(model, loader):
    model.eval()
    total_loss = 0
    total_acc = 0

    for images, labels in loader:
        images = images.to(DEVICE)
        labels = labels.to(DEVICE)

        logits = model(images)
        loss = criterion(logits, labels)

        total_loss += loss.item()
        total_acc += compute_accuracy(logits, labels)

    return total_loss / len(loader), total_acc / len(loader)


In [47]:
# =========================
# PHASE 1: Freeze heavy backbones
# =========================

for param in model.cnn.parameters():
    param.requires_grad = False

for param in model.vit.parameters():
    param.requires_grad = False

print("‚úÖ CNN and ViT frozen (Phase 1)")


‚úÖ CNN and ViT frozen (Phase 1)


In [48]:
import torch

print("CUDA available:", torch.cuda.is_available())
print("Current device:", torch.cuda.current_device() if torch.cuda.is_available() else "CPU")
print("Device name:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU")


CUDA available: True
Current device: 0
Device name: NVIDIA GeForce RTX 3060


In [49]:
print(next(model.parameters()).device)


cuda:0


In [None]:
# =========================
# Step 7.5.4: Full training loop
# =========================

EPOCHS = 100
early_stopping = EarlyStopping(patience=10)

history = {
    "train_loss": [],
    "val_loss": [],
    "train_acc": [],
    "val_acc": []
}

best_val_acc = 0.0
checkpoint_path = RESULTS_DIR / "best_dualstream_model.pt"

for epoch in range(1, EPOCHS + 1):

    train_loss, train_acc = train_one_epoch(model, train_loader)
    val_loss, val_acc = validate_one_epoch(model, val_loader)

    scheduler.step(val_loss)

    history["train_loss"].append(train_loss)
    history["val_loss"].append(val_loss)
    history["train_acc"].append(train_acc)
    history["val_acc"].append(val_acc)

    print(f"Epoch {epoch:03d} | "
          f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | "
          f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}")

    # Save best model
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), checkpoint_path)
        print("‚úÖ Best model updated")

    # Early stopping
    early_stopping.step(val_loss)
    if early_stopping.should_stop:
        print("üõë Early stopping triggered")
        break

print("Training complete.")
print("Best validation accuracy:", best_val_acc)
