In [10]:
import pandas as pd
import os
import glob
import zipfile

metadata_path = "/content/ADNI1_Complete_1Yr_1.5T_6_20_2025.csv"
df = pd.read_csv(metadata_path)

df = df.rename(columns={"Image Data ID": "ImageID", "Subject": "SubjectID", "Group": "Diagnosis"})
df = df.dropna(subset=["ImageID", "SubjectID", "Diagnosis"])
df["ImageID"] = df["ImageID"].astype(str).str.strip()
df["SubjectID"] = df["SubjectID"].astype(str).str.strip()

print(f"Total records in metadata after cleaning: {len(df)}")

zip_path = "/content/Preprocessed_ADNI_flat.zip"
extract_dir = "/content/Preprocessed_ADNI_flat"

with zipfile.ZipFile(zip_path, 'r') as zip_ref:
    zip_ref.extractall(extract_dir)

all_files = glob.glob(os.path.join(extract_dir, "**", "*.nii"), recursive=True)
all_files += glob.glob(os.path.join(extract_dir, "**", "*.nii.gz"), recursive=True)

print(f"Total NIfTI files found (recursive): {len(all_files)}")

imageid_to_path = {}
miss_count = 0

for img_id in df["ImageID"].unique():
    found = False
    for path in all_files:
        fname = os.path.basename(path)
        if img_id in fname:
            imageid_to_path[img_id] = path
            found = True
            break
    if not found:
        miss_count += 1

print(f"\nImageIDs successfully mapped to files: {len(imageid_to_path)}")
print(f"ImageIDs with NO matching file: {miss_count}")


df["FilePath"] = df["ImageID"].map(imageid_to_path)
df = df.dropna(subset=["FilePath"])

label_map = {"CN": 0, "AD": 1, "MCI": 2}
df["Label"] = df["Diagnosis"].map(label_map)

print(f"\nTotal usable MRI scans: {len(df)}")

class_counts = df["Diagnosis"].value_counts()
print("\nClass Distribution:")
print(class_counts)

print("\nUnique Subjects per Class:")
for diag in ["CN", "AD", "MCI"]:
    unique_subjects = df[df["Diagnosis"] == diag]["SubjectID"].nunique()
    print(f"{diag}: {unique_subjects} unique subjects")

df.head()


Total records in metadata after cleaning: 2294
Total NIfTI files found (recursive): 230

ImageIDs successfully mapped to files: 230
ImageIDs with NO matching file: 2064

Total usable MRI scans: 230

Class Distribution:
Diagnosis
MCI    121
CN      58
AD      51
Name: count, dtype: int64

Unique Subjects per Class:
CN: 54 unique subjects
AD: 44 unique subjects
MCI: 105 unique subjects


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df["Label"] = df["Diagnosis"].map(label_map)


Unnamed: 0,ImageID,SubjectID,Diagnosis,Sex,Age,Visit,Modality,Description,Type,Acq Date,Format,Downloaded,FilePath,Label
0,I112538,941_S_1311,MCI,M,70,m12,MRI,MPR; GradWarp; B1 Correction; N3; Scaled,Processed,6/01/2008,NiFTI,,/content/Preprocessed_ADNI_flat/Preprocessed_A...,2
2,I97327,941_S_1311,MCI,M,69,sc,MRI,MPR; GradWarp; B1 Correction; N3; Scaled,Processed,3/02/2007,NiFTI,,/content/Preprocessed_ADNI_flat/Preprocessed_A...,2
7,I66462,941_S_1197,CN,F,82,sc,MRI,MPR; GradWarp; B1 Correction; N3; Scaled,Processed,1/20/2007,NiFTI,,/content/Preprocessed_ADNI_flat/Preprocessed_A...,0
21,I92415,141_S_1137,AD,F,82,m06,MRI,MPR; GradWarp; B1 Correction; N3; Scaled,Processed,6/26/2007,NiFTI,,/content/Preprocessed_ADNI_flat/Preprocessed_A...,1
46,I81508,141_S_0851,MCI,F,74,m12,MRI,MPR-R; GradWarp; B1 Correction; N3; Scaled,Processed,10/13/2007,NiFTI,,/content/Preprocessed_ADNI_flat/Preprocessed_A...,2


In [11]:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
import nibabel as nib
import numpy as np
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import accuracy_score


class MRIDatasetSelectiveSlices(Dataset):
    def __init__(self, dataframe, slice_indices=[80, 85, 90, 95, 100]):
        self.df = dataframe
        self.slice_indices = slice_indices

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img = nib.load(row.FilePath).get_fdata().astype(np.float32)

        # Normalize with mean-std
        img = (img - np.mean(img)) / (np.std(img) + 1e-8)

        # Select coronal slices: axis=1
        img = img[:, self.slice_indices, :]
        img = np.transpose(img, (1, 0, 2))  # (slices, H, W)
        img = np.expand_dims(img, axis=0)  # (1, D, H, W)

        return torch.tensor(img, dtype=torch.float32), row.Label


class BlockAttention3D(nn.Module):
    def __init__(self, ch):
        super().__init__()
        self.conv = nn.Conv3d(ch, ch, kernel_size=(1, 3, 3), padding=(0, 1, 1))
    def forward(self, x):
        return x + self.conv(x)

class GridAttention3D(nn.Module):
    def __init__(self, ch):
        super().__init__()
        self.conv = nn.Conv3d(ch, ch, kernel_size=(3, 1, 1), padding=(1, 0, 0))
    def forward(self, x):
        return x + self.conv(x)

class MBConv3D(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.expand = nn.Conv3d(in_ch, out_ch, 1)
        self.dwconv = nn.Conv3d(out_ch, out_ch, 3, padding=1, groups=out_ch)
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool3d(1),
            nn.Conv3d(out_ch, out_ch // 4, 1),
            nn.ReLU(),
            nn.Conv3d(out_ch // 4, out_ch, 1),
            nn.Sigmoid()
        )
        self.project = nn.Conv3d(out_ch, out_ch, 1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        identity = x
        x = self.relu(self.expand(x))
        x = self.relu(self.dwconv(x))
        x = x * self.se(x)
        x = self.project(x)
        return x + identity

class PConv3D(nn.Module):
    def __init__(self, ch):
        super().__init__()
        partial = ch // 2
        self.pconv = nn.Conv3d(partial, partial, 3, padding=1)
        self.pwconv = nn.Conv3d(ch, ch, 1)

    def forward(self, x):
        partial = x[:, :x.shape[1] // 2]
        partial = self.pconv(partial)
        combined = torch.cat([partial, x[:, x.shape[1] // 2:]], dim=1)
        return self.pwconv(combined)

class InvertedResidualFFN(nn.Module):
    def __init__(self, ch):
        super().__init__()
        self.expand = nn.Conv3d(ch, ch * 4, 1)
        self.dwconv = nn.Conv3d(ch * 4, ch * 4, 3, padding=1, groups=ch * 4)
        self.project = nn.Conv3d(ch * 4, ch, 1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        identity = x
        x = self.relu(self.expand(x))
        x = self.relu(self.dwconv(x))
        x = self.project(x)
        return x + identity


class ADHybridClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.mb = MBConv3D(1, 16)
        self.attn1 = BlockAttention3D(16)
        self.irffn1 = InvertedResidualFFN(16)
        self.pconv = PConv3D(16)
        self.attn2 = GridAttention3D(16)
        self.irffn2 = InvertedResidualFFN(16)
        self.pool = nn.AdaptiveAvgPool3d(1)
        self.fc = nn.Linear(16, 3)

    def forward(self, x):
        x = self.mb(x)
        x = self.attn1(x)
        x = self.irffn1(x)
        x = self.pconv(x)
        x = self.attn2(x)
        x = self.irffn2(x)
        x = self.pool(x).view(x.size(0), -1)
        return self.fc(x)


In [12]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


In [13]:

dataset = MRIDatasetSelectiveSlices(df)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_ds, val_ds = random_split(dataset, [train_size, val_size])

train_loader = DataLoader(train_ds, batch_size=2, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=2)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ADHybridClassifier().to(device)

class_weights = compute_class_weight(
    class_weight='balanced',
    classes=np.array([0, 1, 2]),
    y=df["Label"].values
)
class_weights = torch.tensor(class_weights, dtype=torch.float32).to(device)
criterion = nn.CrossEntropyLoss(weight=class_weights)

optimizer = optim.AdamW(model.parameters(), lr=1e-4)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=2, verbose=True)

best_val_loss = float('inf')
patience = 3
epochs_no_improve = 0
num_epochs = 50

for epoch in range(num_epochs):
    model.train()
    train_loss, train_preds, train_labels = 0, [], []

    for imgs, labels in train_loader:
        imgs, labels = imgs.to(device), labels.to(device)
        preds = model(imgs)
        loss = criterion(preds, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        train_preds.extend(preds.argmax(1).cpu().numpy())
        train_labels.extend(labels.cpu().numpy())

    train_loss /= len(train_loader)
    train_acc = accuracy_score(train_labels, train_preds)

    model.eval()
    val_loss, val_preds, val_labels = 0, [], []

    with torch.no_grad():
        for imgs, labels in val_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            preds = model(imgs)
            loss = criterion(preds, labels)
            val_loss += loss.item()
            val_preds.extend(preds.argmax(1).cpu().numpy())
            val_labels.extend(labels.cpu().numpy())

    val_loss /= len(val_loader)
    val_acc = accuracy_score(val_labels, val_preds)

    print(f"Epoch {epoch+1} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | "
          f"Train Acc: {train_acc:.2%} | Val Acc: {val_acc:.2%}")


    scheduler.step(val_loss)

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        epochs_no_improve = 0
        torch.save(model.state_dict(), "best_ad_model.pth")
        print(" Model saved.")
    else:
        epochs_no_improve += 1
        if epochs_no_improve >= patience:
            print(" Early stopping triggered.")
            break




Epoch 1 | Train Loss: 1.1048 | Val Loss: 1.0881 | Train Acc: 38.59% | Val Acc: 52.17%
 Model saved.
Epoch 2 | Train Loss: 1.0793 | Val Loss: 1.0927 | Train Acc: 52.72% | Val Acc: 52.17%
Epoch 3 | Train Loss: 1.0875 | Val Loss: 1.0919 | Train Acc: 52.72% | Val Acc: 52.17%
Epoch 4 | Train Loss: 1.0773 | Val Loss: 1.0941 | Train Acc: 52.72% | Val Acc: 52.17%
 Early stopping triggered.
