In [1]:
# =========================
# STEP 1: Imports & Config
# =========================

import torch
import torch.nn as nn
import torchvision.models as models
from torchvision import transforms

import numpy as np
import pandas as pd
from pathlib import Path
from PIL import Image
from tqdm import tqdm

# -------------------------
# Device
# -------------------------
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", DEVICE)

# -------------------------
# Paths
# -------------------------
PROJECT_ROOT = Path("..")
DATA_ROOT = PROJECT_ROOT / "DATASET" / "final_balanced_dataset" / "FINAL_BALANCED_DATASET"
FEATURE_ROOT = PROJECT_ROOT / "features"
CNN_FEATURE_DIR = FEATURE_ROOT / "cnn"

CNN_FEATURE_DIR.mkdir(parents=True, exist_ok=True)

print("DATA_ROOT:", DATA_ROOT)
print("CNN_FEATURE_DIR:", CNN_FEATURE_DIR)


Using device: cuda
DATA_ROOT: ..\DATASET\final_balanced_dataset\FINAL_BALANCED_DATASET
CNN_FEATURE_DIR: ..\features\cnn


In [2]:
# =========================
# STEP 2: Build subject-level index
# =========================

records = []

for class_dir in DATA_ROOT.iterdir():
    if not class_dir.is_dir():
        continue

    class_name = class_dir.name  # CN, LMCI, AD

    for subject_dir in class_dir.iterdir():
        if not subject_dir.is_dir():
            continue

        subject_name = subject_dir.name

        # Remove augmentation suffixes (_aug0, _aug1, etc.)
        base_subject_id = subject_name.split("_aug")[0]

        records.append({
            "subject_id": base_subject_id,
            "class": class_name,
            "subject_dir": subject_dir
        })

df = pd.DataFrame(records)

print("Total folders (including augmentations):", len(df))
df.head()


Total folders (including augmentations): 933


Unnamed: 0,subject_id,class,subject_dir
0,002_S_0619,AD,..\DATASET\final_balanced_dataset\FINAL_BALANC...
1,002_S_0619,AD,..\DATASET\final_balanced_dataset\FINAL_BALANC...
2,002_S_0619,AD,..\DATASET\final_balanced_dataset\FINAL_BALANC...
3,002_S_0816,AD,..\DATASET\final_balanced_dataset\FINAL_BALANC...
4,002_S_0816,AD,..\DATASET\final_balanced_dataset\FINAL_BALANC...


In [3]:
# =========================
# Group by real subject
# =========================

subjects_df = (
    df.groupby(["subject_id", "class"])["subject_dir"]
      .apply(list)
      .reset_index(name="subject_dirs")
)

print("Total REAL subjects:", len(subjects_df))
subjects_df["class"].value_counts()


Total REAL subjects: 639


class
LMCI    311
CN      195
AD      133
Name: count, dtype: int64

In [4]:
subjects_df.head()


Unnamed: 0,subject_id,class,subject_dirs
0,002_S_0295,CN,[..\DATASET\final_balanced_dataset\FINAL_BALAN...
1,002_S_0413,CN,[..\DATASET\final_balanced_dataset\FINAL_BALAN...
2,002_S_0619,AD,[..\DATASET\final_balanced_dataset\FINAL_BALAN...
3,002_S_0685,CN,[..\DATASET\final_balanced_dataset\FINAL_BALAN...
4,002_S_0729,LMCI,[..\DATASET\final_balanced_dataset\FINAL_BALAN...


In [6]:
# =========================
# STEP 3: Load frozen ResNet50
# =========================

# Load pretrained ResNet50
resnet = models.resnet50(
    weights=models.ResNet50_Weights.IMAGENET1K_V1
)

# Remove the classification layer
resnet.fc = nn.Identity()

# Freeze all parameters
for param in resnet.parameters():
    param.requires_grad = False

# Move to device and set eval mode
resnet = resnet.to(DEVICE)
resnet.eval()

print("✅ ResNet50 loaded as frozen feature extractor")


✅ ResNet50 loaded as frozen feature extractor


In [7]:
# Check output feature dimension
dummy = torch.randn(1, 3, 224, 224).to(DEVICE)
with torch.no_grad():
    out = resnet(dummy)

print("Output feature shape:", out.shape)


Output feature shape: torch.Size([1, 2048])


In [8]:
# =========================
# STEP 4: CNN Image Preprocessing
# =========================

cnn_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])

print("✅ CNN preprocessing pipeline defined")


✅ CNN preprocessing pipeline defined


In [9]:
# Sanity check transform output
dummy_img = Image.new("RGB", (300, 300))
tensor_img = cnn_transform(dummy_img)

print("Transformed image shape:", tensor_img.shape)
print("Min / Max:", tensor_img.min().item(), tensor_img.max().item())


Transformed image shape: torch.Size([3, 224, 224])
Min / Max: -2.1179039478302 -1.804444432258606


In [14]:
# =========================
# Robust deterministic slice sampling
# =========================

def sample_slices(image_paths, num_slices=16):
    """
    Deterministically sample num_slices images from a list of image paths.
    If fewer images are available, repeat the last image.
    If no images are available, return an empty list.
    """
    image_paths = sorted(image_paths)

    if len(image_paths) == 0:
        return []

    if len(image_paths) >= num_slices:
        indices = np.linspace(
            0, len(image_paths) - 1, num_slices
        ).astype(int)
        return [image_paths[i] for i in indices]
    else:
        return image_paths + [image_paths[-1]] * (num_slices - len(image_paths))


In [18]:
# Check CNN feature directory
print("CNN feature directory exists:", CNN_FEATURE_DIR.exists())

files = list(CNN_FEATURE_DIR.glob("*.npy"))
print("Number of CNN feature files found:", len(files))

if len(files) > 0:
    print("Example feature file:", files[0].name)


CNN feature directory exists: True
Number of CNN feature files found: 146
Example feature file: 002_S_0295.npy


In [19]:
# =========================
# STEP 6: CNN feature extraction per subject
# =========================

all_subject_ids = []
all_labels = []

for idx, row in tqdm(subjects_df.iterrows(), total=len(subjects_df)):
    subject_id = row["subject_id"]
    label = row["class"]
    subject_dirs = row["subject_dirs"]

    save_path = CNN_FEATURE_DIR / f"{subject_id}.npy"

    # Skip if already processed (resume-safe)
    if save_path.exists():
        all_subject_ids.append(subject_id)
        all_labels.append(label)
        continue

    # Collect all slice paths
    image_paths = []
    for d in subject_dirs:
        image_paths.extend(sorted(d.glob("*.png")))

    if len(image_paths) == 0:
        print(f"⚠️ No images found for subject {subject_id}, skipping.")
        continue

    # Sample slices
    sampled_paths = sample_slices(image_paths, num_slices=16)

    images = []
    for p in sampled_paths:
        img = Image.open(p).convert("RGB")
        img = cnn_transform(img)
        images.append(img)

    images = torch.stack(images).to(DEVICE)  # (16, 3, 224, 224)

    # Extract features
    with torch.no_grad():
        feats = resnet(images)               # (16, 2048)
        subject_feat = feats.mean(dim=0)     # (2048,)
        subject_feat = subject_feat.cpu().numpy()

    # Save feature
    np.save(save_path, subject_feat)

    all_subject_ids.append(subject_id)
    all_labels.append(label)

print("✅ CNN feature extraction loop completed")


100%|██████████| 639/639 [00:00<00:00, 6802.66it/s]

⚠️ No images found for subject 002_S_0729, skipping.
⚠️ No images found for subject 002_S_0782, skipping.
⚠️ No images found for subject 002_S_0954, skipping.
⚠️ No images found for subject 002_S_1070, skipping.
⚠️ No images found for subject 002_S_1155, skipping.
⚠️ No images found for subject 002_S_1268, skipping.
⚠️ No images found for subject 003_S_0908, skipping.
⚠️ No images found for subject 003_S_1057, skipping.
⚠️ No images found for subject 003_S_1122, skipping.
⚠️ No images found for subject 005_S_0222, skipping.
⚠️ No images found for subject 005_S_0324, skipping.
⚠️ No images found for subject 005_S_0448, skipping.
⚠️ No images found for subject 005_S_0546, skipping.
⚠️ No images found for subject 005_S_1224, skipping.
⚠️ No images found for subject 006_S_0675, skipping.
⚠️ No images found for subject 006_S_1130, skipping.
⚠️ No images found for subject 007_S_0041, skipping.
⚠️ No images found for subject 007_S_0101, skipping.
⚠️ No images found for subject 007_S_0128, ski




In [20]:
files = list(CNN_FEATURE_DIR.glob("*.npy"))
print("Final CNN feature count:", len(files))


Final CNN feature count: 146


In [21]:
# =========================
# STEP 7: Save subject labels index
# =========================

labels_df = pd.DataFrame({
    "subject_id": all_subject_ids,
    "label": all_labels
})

labels_path = FEATURE_ROOT / "labels.csv"
labels_df.to_csv(labels_path, index=False)

print("✅ labels.csv saved at:", labels_path)
print("Total labeled subjects:", len(labels_df))
labels_df["label"].value_counts()


✅ labels.csv saved at: ..\features\labels.csv
Total labeled subjects: 146


label
AD    76
CN    70
Name: count, dtype: int64

In [22]:
import pandas as pd

labels_df = pd.read_csv(FEATURE_ROOT / "labels.csv")
print(labels_df["label"].value_counts())


label
AD    76
CN    70
Name: count, dtype: int64


In [23]:
# Inspect LMCI folders to see what files they contain
lmci_dirs = subjects_df[subjects_df["class"] == "LMCI"].iloc[0]["subject_dirs"]

for d in lmci_dirs:
    print("\nFolder:", d)
    for f in list(d.iterdir())[:10]:
        print("  ", f.name)



Folder: ..\DATASET\final_balanced_dataset\FINAL_BALANCED_DATASET\LMCI\002_S_0729
   sagittal_slices


In [24]:
# Inspect actual slice file extensions in one LMCI subject
lmci_subject = subjects_df[subjects_df["class"] == "LMCI"].iloc[0]
dirs = lmci_subject["subject_dirs"]

extensions = set()

for d in dirs:
    for f in d.rglob("*"):
        if f.is_file():
            extensions.add(f.suffix)

print("Found file extensions:", extensions)


Found file extensions: {'.png'}


In [27]:
# =========================
# STEP 6: CNN Feature Extraction (FINAL)
# =========================

from tqdm import tqdm

all_subject_ids = []
all_labels = []

print("Starting CNN feature extraction...")

# Ensure feature directory exists (bulletproof)
CNN_FEATURE_DIR.mkdir(parents=True, exist_ok=True)


for idx, row in tqdm(subjects_df.iterrows(), total=len(subjects_df)):

    subject_id = row["subject_id"]
    label = row["class"]
    subject_dirs = row["subject_dirs"]

    save_path = CNN_FEATURE_DIR / f"{subject_id}.npy"

    # -------------------------
    # Resume-safe: skip if already done
    # -------------------------
    if save_path.exists():
        all_subject_ids.append(subject_id)
        all_labels.append(label)
        continue

    # -------------------------
    # Collect ALL slice images (recursive)
    # -------------------------
    image_paths = []
    for d in subject_dirs:
        image_paths.extend(sorted(d.rglob("*.png")))

    # If still no images, skip subject
    if len(image_paths) == 0:
        print(f"⚠️ No valid slices for subject {subject_id}, skipping.")
        continue

    # -------------------------
    # Deterministic slice sampling
    # -------------------------
    sampled_paths = sample_slices(image_paths, num_slices=16)

    if len(sampled_paths) == 0:
        print(f"⚠️ Sampling failed for subject {subject_id}, skipping.")
        continue

    # -------------------------
    # Load & preprocess images
    # -------------------------
    images = []
    for p in sampled_paths:
        img = Image.open(p).convert("RGB")
        img = cnn_transform(img)
        images.append(img)

    images = torch.stack(images).to(DEVICE)   # (16, 3, 224, 224)

    # -------------------------
    # CNN feature extraction
    # -------------------------
    with torch.no_grad():
        feats = resnet(images)               # (16, 2048)
        subject_feat = feats.mean(dim=0)     # (2048,)
        subject_feat = subject_feat.cpu().numpy()

    # -------------------------
    # Save feature
    # -------------------------
    np.save(save_path, subject_feat)

    all_subject_ids.append(subject_id)
    all_labels.append(label)

print("✅ CNN feature extraction completed successfully.")
print("Total subjects processed:", len(all_subject_ids))


Starting CNN feature extraction...


100%|██████████| 639/639 [02:36<00:00,  4.07it/s]

✅ CNN feature extraction completed successfully.
Total subjects processed: 639





In [29]:
# =========================
# STEP 7: Save subject labels index (FINAL)
# =========================

import pandas as pd

labels_df = pd.DataFrame({
    "subject_id": all_subject_ids,
    "label": all_labels
})

labels_path = FEATURE_ROOT / "labels.csv"
labels_df.to_csv(labels_path, index=False)

print("✅ labels.csv saved at:", labels_path)
print("Total labeled subjects:", len(labels_df))
print("\nClass distribution:")
print(labels_df["label"].value_counts())


✅ labels.csv saved at: ..\features\labels.csv
Total labeled subjects: 639

Class distribution:
label
LMCI    311
CN      195
AD      133
Name: count, dtype: int64


In [30]:
# Final verification of CNN features and labels

import pandas as pd

cnn_files = list(CNN_FEATURE_DIR.glob("*.npy"))
print("Total CNN feature files:", len(cnn_files))

labels_df = pd.read_csv(FEATURE_ROOT / "labels.csv")
print("\nClass distribution:")
print(labels_df["label"].value_counts())


Total CNN feature files: 639

Class distribution:
label
LMCI    311
CN      195
AD      133
Name: count, dtype: int64
