In [11]:
# =========================
# Cell 1: Imports & Config
# =========================

import torch
import timm
import numpy as np
import pandas as pd

from pathlib import Path
from tqdm import tqdm

# -------------------------
# Global configuration
# -------------------------
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Resolve project root (assuming notebooks/ is one level down)
PROJECT_ROOT = Path.cwd().parents[0]

print("Project root:", PROJECT_ROOT)

K_SLICES = 16  # fixed slices per subject (LOCKED)

FEATURES_DIR = Path("features/vit")
FEATURES_DIR.mkdir(parents=True, exist_ok=True)

print("Device:", DEVICE)
print("K slices per subject:", K_SLICES)


Project root: c:\Users\ADMIN\Documents\Alz_work
Device: cuda
K slices per subject: 16


In [10]:
# Resolve project root (assuming notebooks/ is one level down)
PROJECT_ROOT = Path.cwd().parents[0]

print("Project root:", PROJECT_ROOT)


Project root: c:\Users\ADMIN\Documents\Alz_work


In [4]:
from pathlib import Path

list(Path(".").glob("**/labels.csv"))


[]

In [6]:
import os
print(os.getcwd())


c:\Users\ADMIN\Documents\Alz_work\Notebooks


In [7]:
from pathlib import Path
list(Path(".").glob("**/*.csv"))


[]

In [12]:
# =========================
# Cell 2: Load labels & subjects
# =========================

LABELS_PATH = PROJECT_ROOT / "features" / "labels.csv"


labels_df = pd.read_csv(LABELS_PATH)

# Expect at least: subject_id, label
print("Columns:", labels_df.columns.tolist())

subjects = labels_df["subject_id"].astype(str).tolist()

print("Total subjects:", len(subjects))

# Sanity check
assert len(subjects) == 639, "Subject count mismatch!"


Columns: ['subject_id', 'label']
Total subjects: 639


In [13]:
list((PROJECT_ROOT).glob("**/labels.csv"))


[WindowsPath('c:/Users/ADMIN/Documents/Alz_work/features/labels.csv')]

In [15]:
# =========================
# Robust deterministic slice sampling
# =========================

def sample_slices(image_paths, num_slices=16):
    """
    Deterministically sample num_slices images from a list of image paths.
    If fewer images are available, repeat the last image.
    If no images are available, return an empty list.
    """
    image_paths = sorted(image_paths)

    if len(image_paths) == 0:
        return []

    if len(image_paths) >= num_slices:
        indices = np.linspace(
            0, len(image_paths) - 1, num_slices
        ).astype(int)
        return [image_paths[i] for i in indices]
    else:
        return image_paths + [image_paths[-1]] * (num_slices - len(image_paths))


In [17]:
# =========================
# Cell 4: Image Loading & Preprocessing
# =========================

from PIL import Image
import torchvision.transforms as T

# ImageNet normalization (LOCKED)
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD  = [0.229, 0.224, 0.225]

preprocess = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),  # converts to [0,1] and (C,H,W)
    T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

def load_and_preprocess_image(img_path):
    """
    Load a single MRI slice and return a ViT-ready tensor (3, 224, 224).
    """
    img = Image.open(img_path).convert("RGB")
    img = preprocess(img)
    return img


In [18]:
# Sanity check
example_img = load_and_preprocess_image(
    sample_slices([p for p in Path(".").glob("**/*.png")], num_slices=1)[0]
)

print(example_img.shape, example_img.dtype)


torch.Size([3, 224, 224]) torch.float32


In [19]:
# =========================
# Cell 5: Load ViT Model
# =========================

vit = timm.create_model(
    "vit_base_patch16_224",
    pretrained=True,
    num_classes=0  # returns CLS embedding
)

vit = vit.to(DEVICE)
vit.eval()

# Freeze all parameters (safety)
for param in vit.parameters():
    param.requires_grad = False

print("ViT loaded.")
print("ViT output dim:", vit.num_features)


ViT loaded.
ViT output dim: 768


In [21]:
# Diagnostic: locate PNG slices for a known subject
from pathlib import Path

test_subject = "002_S_0295"

png_matches = list(PROJECT_ROOT.glob(f"**/{test_subject}/**/*.png"))
png_matches[:5], len(png_matches)


([WindowsPath('c:/Users/ADMIN/Documents/Alz_work/DATASET/ADNI FULLY PREPROCESSED/002_S_0295/sagittal_slices/sagittal_000.png'),
  WindowsPath('c:/Users/ADMIN/Documents/Alz_work/DATASET/ADNI FULLY PREPROCESSED/002_S_0295/sagittal_slices/sagittal_001.png'),
  WindowsPath('c:/Users/ADMIN/Documents/Alz_work/DATASET/ADNI FULLY PREPROCESSED/002_S_0295/sagittal_slices/sagittal_002.png'),
  WindowsPath('c:/Users/ADMIN/Documents/Alz_work/DATASET/ADNI FULLY PREPROCESSED/002_S_0295/sagittal_slices/sagittal_003.png'),
  WindowsPath('c:/Users/ADMIN/Documents/Alz_work/DATASET/ADNI FULLY PREPROCESSED/002_S_0295/sagittal_slices/sagittal_004.png')],
 200)

In [22]:
# =========================
# Cell 6: ViT Feature Extraction Loop (FINAL)
# =========================

vit_features_dir = FEATURES_DIR
vit_features_dir.mkdir(parents=True, exist_ok=True)

processed = 0
skipped = 0

with torch.no_grad():
    for _, row in tqdm(labels_df.iterrows(), total=len(labels_df)):
        subject_id = str(row["subject_id"])

        # --------------------------------------------------
        # Output path (resume-safe)
        # --------------------------------------------------
        out_path = vit_features_dir / f"{subject_id}.npy"
        if out_path.exists():
            skipped += 1
            continue

        # --------------------------------------------------
        # Subject slice directory (CONFIRMED PATH)
        # --------------------------------------------------
        subject_dir = (
            PROJECT_ROOT
            / "DATASET"
            / "ADNI FULLY PREPROCESSED"
            / subject_id
            / "sagittal_slices"
        )

        if not subject_dir.exists():
            print(f"[WARNING] Subject directory missing: {subject_id}")
            continue

        slice_paths = sorted(subject_dir.glob("*.png"))

        if len(slice_paths) == 0:
            print(f"[WARNING] No slices found for subject {subject_id}")
            continue

        # --------------------------------------------------
        # Deterministic slice sampling (K = 16)
        # --------------------------------------------------
        sampled_slices = sample_slices(slice_paths, num_slices=K_SLICES)

        # --------------------------------------------------
        # Load & preprocess slices
        # --------------------------------------------------
        imgs = []
        for p in sampled_slices:
            try:
                img = load_and_preprocess_image(p)
                imgs.append(img)
            except Exception as e:
                print(f"[ERROR] Failed to load {p}: {e}")
                imgs = []
                break

        if len(imgs) != K_SLICES:
            print(f"[WARNING] Incomplete slice load for subject {subject_id}")
            continue

        imgs = torch.stack(imgs, dim=0).to(DEVICE)  # (K, 3, 224, 224)

        # --------------------------------------------------
        # Forward through frozen ViT
        # --------------------------------------------------
        feats = vit(imgs)  # (K, 768)

        # --------------------------------------------------
        # Subject-level aggregation (mean pooling)
        # --------------------------------------------------
        subj_feat = feats.mean(dim=0)  # (768,)

        # --------------------------------------------------
        # Save subject feature
        # --------------------------------------------------
        np.save(out_path, subj_feat.cpu().numpy())
        processed += 1

print(f"ViT feature extraction completed.")
print(f"Subjects processed: {processed}")
print(f"Subjects skipped (already existed): {skipped}")


100%|██████████| 639/639 [03:36<00:00,  2.95it/s]

ViT feature extraction completed.
Subjects processed: 639
Subjects skipped (already existed): 0





In [23]:
# =========================
# Cell 7: ViT Feature Integrity Checks
# =========================

vit_feature_files = sorted(FEATURES_DIR.glob("*.npy"))

print("Total ViT feature files:", len(vit_feature_files))
assert len(vit_feature_files) == 639, "ViT feature count mismatch!"

# Check shape of a few samples
for f in vit_feature_files[:3]:
    feat = np.load(f)
    print(f.name, feat.shape)
    assert feat.shape == (768,), "Feature shape mismatch!"

print("ViT feature integrity check PASSED.")


Total ViT feature files: 639
002_S_0295.npy (768,)
002_S_0413.npy (768,)
002_S_0619.npy (768,)
ViT feature integrity check PASSED.


In [24]:
# =========================
# Cell 8: CNN–ViT Alignment Check
# =========================

cnn_dir = PROJECT_ROOT / "features" / "cnn"
vit_dir = PROJECT_ROOT / "features" / "vit"

cnn_ids = set(p.stem for p in cnn_dir.glob("*.npy"))
vit_ids = set(p.stem for p in vit_dir.glob("*.npy"))

print("CNN subjects:", len(cnn_ids))
print("ViT subjects:", len(vit_ids))

assert cnn_ids == vit_ids, "CNN and ViT subject IDs do NOT match!"

print("CNN–ViT alignment check PASSED.")


CNN subjects: 639
ViT subjects: 0


AssertionError: CNN and ViT subject IDs do NOT match!

In [25]:
# =========================
# Diagnose CNN–ViT mismatch
# =========================

only_in_cnn = sorted(cnn_ids - vit_ids)
only_in_vit = sorted(vit_ids - cnn_ids)

print("Subjects only in CNN features:", len(only_in_cnn))
print("Subjects only in ViT features:", len(only_in_vit))

print("\nExamples only in CNN:", only_in_cnn[:10])
print("Examples only in ViT:", only_in_vit[:10])


Subjects only in CNN features: 639
Subjects only in ViT features: 0

Examples only in CNN: ['002_S_0295', '002_S_0413', '002_S_0619', '002_S_0685', '002_S_0729', '002_S_0782', '002_S_0816', '002_S_0938', '002_S_0954', '002_S_1018']
Examples only in ViT: []


In [26]:
# Locate actual CNN feature files
from pathlib import Path

list((PROJECT_ROOT / "features").glob("**/*.npy"))[:20]


[WindowsPath('c:/Users/ADMIN/Documents/Alz_work/features/cnn/002_S_0295.npy'),
 WindowsPath('c:/Users/ADMIN/Documents/Alz_work/features/cnn/002_S_0413.npy'),
 WindowsPath('c:/Users/ADMIN/Documents/Alz_work/features/cnn/002_S_0619.npy'),
 WindowsPath('c:/Users/ADMIN/Documents/Alz_work/features/cnn/002_S_0685.npy'),
 WindowsPath('c:/Users/ADMIN/Documents/Alz_work/features/cnn/002_S_0729.npy'),
 WindowsPath('c:/Users/ADMIN/Documents/Alz_work/features/cnn/002_S_0782.npy'),
 WindowsPath('c:/Users/ADMIN/Documents/Alz_work/features/cnn/002_S_0816.npy'),
 WindowsPath('c:/Users/ADMIN/Documents/Alz_work/features/cnn/002_S_0938.npy'),
 WindowsPath('c:/Users/ADMIN/Documents/Alz_work/features/cnn/002_S_0954.npy'),
 WindowsPath('c:/Users/ADMIN/Documents/Alz_work/features/cnn/002_S_1018.npy'),
 WindowsPath('c:/Users/ADMIN/Documents/Alz_work/features/cnn/002_S_1070.npy'),
 WindowsPath('c:/Users/ADMIN/Documents/Alz_work/features/cnn/002_S_1155.npy'),
 WindowsPath('c:/Users/ADMIN/Documents/Alz_work/feat