<a href="https://colab.research.google.com/github/ashithkp7/Alzeimer-s-Disease-Classification/blob/main/ProjectFinal.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
# train_slice_attention.py
import os, random, math
from pathlib import Path
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
import torchvision.transforms as T
import torchvision.models as models
from PIL import Image
import nibabel as nib
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings("ignore")


DATA_ROOT = "/content/drive/MyDrive/Project_A"
OUT_DIR = "./slice_attention_output"
MODE = "images"
NUM_SLICES = 8
IMAGE_SIZE = 224
BATCH_SIZE = 8
NUM_EPOCHS = 4
LR = 1e-4
RANDOM_SEED = 42
NUM_WORKERS = 2
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
USE_PRETRAINED = True
SAVE_BEST = True

SAVE_ATTENTION_PRINTS = False
VOLUMES_ROOT = os.path.join(DATA_ROOT, "Volumes")
IMAGES_ROOT = os.path.join(DATA_ROOT, "OriginalDataset")
FOLDER_TO_LABEL = {
    "NonDemented": "CN",
    "MildDemented": "MCI",
    "ModerateDemented": "MCI",
    "VeryMildDemented": "AD"
}

os.makedirs(OUT_DIR, exist_ok=True)
random.seed(RANDOM_SEED); np.random.seed(RANDOM_SEED); torch.manual_seed(RANDOM_SEED)

IMAGE_EXTS = (".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff")

def find_volume_file(subject_folder):
    p = Path(subject_folder)
    if not p.exists(): return None
    patterns = ["**/*.nii", "**/*.nii.gz", "**/*.img", "**/*.hdr", "*.nii", "*.nii.gz"]
    for pat in patterns:
        found = list(p.glob(pat))
        if found:
            found_sorted = sorted(found, key=lambda x: x.stat().st_size, reverse=True)
            return str(found_sorted[0].resolve())
    return None


def load_volume_as_numpy(path):
    img = nib.load(path)
    data = img.get_fdata()
    data = np.nan_to_num(data.astype(np.float32))
    m = np.mean(data); s = np.std(data) + 1e-8
    data = (data - m) / s
    return data


def select_slices_from_volume(vol, num_slices=8, axis=2, mode='uniform'):
    vol = np.moveaxis(vol, axis, -1)
    n = vol.shape[-1]
    if n == 0: return []
    if n <= num_slices:
        idxs = list(range(n))
    else:
        if mode=='uniform':
            step = n / num_slices
            idxs = [int(i*step) for i in range(num_slices)]
        elif mode=='center':
            c = n//2; half = num_slices//2; start = max(0, c-half); idxs = list(range(start, start+num_slices))
        else:
            idxs = sorted(random.sample(range(n), num_slices))
    slices = [np.atleast_2d(np.squeeze(vol[..., i])) for i in idxs]
    return slices


def build_groups_from_images(images_root, num_slices=8):
    class2files = {}
    root = Path(images_root)
    for cls_dir in sorted(root.iterdir()):
        if not cls_dir.is_dir(): continue
        if cls_dir.name not in FOLDER_TO_LABEL: continue
        label = FOLDER_TO_LABEL[cls_dir.name]
        files = []
        for ext in IMAGE_EXTS:
            files += sorted([str(p) for p in cls_dir.rglob(f"*{ext}")])
        if len(files)==0: continue
        class2files.setdefault(label, []).extend(files)
    groups=[]
    label2idx = {c:i for i,c in enumerate(sorted(class2files.keys()))}
    for cls, fls in class2files.items():
        chunks = [fls[i:i+num_slices] for i in range(0, len(fls), num_slices)]
        for i,chunk in enumerate(chunks):
            groups.append({"group_id": f"{cls}_g{i:04d}", "paths": chunk, "label_name": cls, "label_idx": label2idx[cls]})
    return groups, label2idx


def build_groups_from_volumes(volumes_root, num_slices=8):
    groups=[]
    for subj in sorted(Path(volumes_root).iterdir()):
        if not subj.is_dir(): continue
        volfile = find_volume_file(subj)
        if volfile is None: continue
        try:
            vol = load_volume_as_numpy(volfile)
        except Exception as e:
            print("fail load", volfile, e); continue
        slices = select_slices_from_volume(vol, num_slices=num_slices, axis=2, mode='uniform')
        if len(slices)==0: continue
        name = subj.name.upper()
        label_guess = "unknown"
        if "AD" in name or "ALZ" in name: label_guess = "AD"
        elif "MCI" in name: label_guess = "MCI"
        elif "CN" in name or "NORMAL" in name or "NON" in name: label_guess = "CN"
        groups.append({"group_id": subj.name, "slices": slices, "label_name": label_guess})
    label_names = sorted(list({g['label_name'] for g in groups}))
    label2idx = {n:i for i,n in enumerate(label_names)}
    for g in groups: g['label_idx'] = label2idx[g['label_name']]
    return groups, label2idx

class ImageGroupDataset(Dataset):
    def __init__(self, groups, num_slices=8, transform=None, image_size=224):
        self.groups = groups; self.num_slices=num_slices; self.transform=transform; self.image_size=image_size
    def __len__(self): return len(self.groups)
    def sample_slices(self, paths):
        n=len(paths)
        if n >= self.num_slices:
            step = n / self.num_slices
            idxs = [int(i*step) for i in range(self.num_slices)]
        else:
            idxs = [random.randrange(n) for _ in range(self.num_slices)]
        return [paths[i] for i in idxs]
    def __getitem__(self, idx):
        g=self.groups[idx]; paths=g['paths']; sel=self.sample_slices(paths)
        imgs=[]
        for p in sel:
            img = Image.open(p).convert("RGB")
            if self.transform: img = self.transform(img)
            imgs.append(img)
        x = torch.stack(imgs, dim=0)
        return x, int(g['label_idx']), g['group_id']

class VolumeGroupDataset(Dataset):
    def __init__(self, groups, num_slices=8, transform=None, image_size=224):
        self.groups=groups; self.num_slices=num_slices; self.transform=transform; self.image_size=image_size
    def __len__(self): return len(self.groups)
    def sample_slices(self, slices_list):
        n=len(slices_list)
        if n >= self.num_slices:
            step = n/self.num_slices
            idxs=[int(i*step) for i in range(self.num_slices)]
        else:
            idxs=[random.randrange(n) for _ in range(self.num_slices)]
        return [slices_list[i] for i in idxs]
    def __getitem__(self, idx):
        g=self.groups[idx]; slices_list=g['slices']; sel=self.sample_slices(slices_list)
        imgs=[]
        for sl in sel:
            mn, mx = float(np.min(sl)), float(np.max(sl))
            if mx - mn < 1e-6:
                arr = np.zeros_like(sl, dtype=np.uint8)
            else:
                sln = (sl - mn) / (mx-mn)
                arr = (sln*255.0).astype(np.uint8)
            pil = Image.fromarray(arr).convert("RGB")
            if self.transform: pil = self.transform(pil)
            imgs.append(pil)
        x = torch.stack(imgs, dim=0)
        return x, int(g['label_idx']), g['group_id']

class SliceAttentionModel(nn.Module):
    def __init__(self, pretrained=True, embedding_dim=512, num_classes=3):
        super().__init__()
        if hasattr(models, "ResNet18_Weights") and pretrained:
            res = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
        else:
            res = models.resnet18(pretrained=pretrained)
        features = list(res.children())[:-1]
        self.backbone = nn.Sequential(*features)
        self.embedding_dim = res.fc.in_features
        self.att_mlp = nn.Sequential(nn.Linear(self.embedding_dim,128), nn.ReLU(), nn.Linear(128,1))
        self.classifier = nn.Sequential(nn.Linear(self.embedding_dim,256), nn.ReLU(), nn.Dropout(0.3), nn.Linear(256,num_classes))
    def forward(self,x):
        B,S,C,H,W = x.shape
        x = x.view(B*S, C, H, W)
        feats = self.backbone(x)
        feats = feats.view(B, S, self.embedding_dim)
        scores = self.att_mlp(feats).squeeze(-1)
        weights = torch.softmax(scores, dim=1)
        fused = (feats * weights.unsqueeze(-1)).sum(dim=1)
        out = self.classifier(fused)
        return out, weights

def visualize_attention_stack(x_tensor, weights, out_path):
    inv_norm = T.Normalize(mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225],
                           std=[1/0.229,1/0.224,1/0.225])
    x_vis = inv_norm(x_tensor).clamp(0,1)
    to_pil = T.ToPILImage()
    S = x_vis.shape[0]
    fig, axes = plt.subplots(1, S+1, figsize=(3*(S+1),3))
    axes[0].bar(range(S), weights); axes[0].set_title("weights")
    for i in range(S):
        axes[i+1].imshow(to_pil(x_vis[i].cpu()).convert('L')); axes[i+1].axis('off'); axes[i+1].set_title(f"{weights[i]:.2f}")
    plt.tight_layout(); plt.savefig(out_path); plt.close()

def main():
    if MODE == "nifti":
        groups, label2idx = build_groups_from_volumes(VOLUMES_ROOT, num_slices=NUM_SLICES)
    else:
        groups, label2idx = build_groups_from_images(IMAGES_ROOT, num_slices=NUM_SLICES)
    print("Labels:", label2idx, "Total groups:", len(groups))
    if len(groups) < 4:
        print("Warning: not enough groups.")

    df = pd.DataFrame(groups)
    trainval, test = train_test_split(df, test_size=0.10, stratify=df['label_idx'], random_state=RANDOM_SEED)
    rel_val = 0.15 / (1 - 0.10)
    train, val = train_test_split(trainval, test_size=rel_val, stratify=trainval['label_idx'], random_state=RANDOM_SEED)
    train_g = train.to_dict(orient='records'); val_g = val.to_dict(orient='records'); test_g = test.to_dict(orient='records')
    pd.DataFrame(train_g).to_csv(os.path.join(OUT_DIR,"train_groups.csv"), index=False)
    pd.DataFrame(val_g).to_csv(os.path.join(OUT_DIR,"val_groups.csv"), index=False)
    pd.DataFrame(test_g).to_csv(os.path.join(OUT_DIR,"test_groups.csv"), index=False)

    train_transform = T.Compose([T.Resize((IMAGE_SIZE,IMAGE_SIZE)), T.RandomHorizontalFlip(), T.RandomRotation(8),
                                 T.ColorJitter(0.08,0.08), T.ToTensor(), T.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])])
    val_transform = T.Compose([T.Resize((IMAGE_SIZE,IMAGE_SIZE)), T.ToTensor(), T.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])])

    if MODE=="nifti":
        train_ds = VolumeGroupDataset(train_g, num_slices=NUM_SLICES, transform=train_transform, image_size=IMAGE_SIZE)
        val_ds = VolumeGroupDataset(val_g, num_slices=NUM_SLICES, transform=val_transform, image_size=IMAGE_SIZE)
        test_ds = VolumeGroupDataset(test_g, num_slices=NUM_SLICES, transform=val_transform, image_size=IMAGE_SIZE)
    else:
        train_ds = ImageGroupDataset(train_g, num_slices=NUM_SLICES, transform=train_transform, image_size=IMAGE_SIZE)
        val_ds = ImageGroupDataset(val_g, num_slices=NUM_SLICES, transform=val_transform, image_size=IMAGE_SIZE)
        test_ds = ImageGroupDataset(test_g, num_slices=NUM_SLICES, transform=val_transform, image_size=IMAGE_SIZE)

    train_labels = np.array([g['label_idx'] for g in train_g])
    class_counts = np.bincount(train_labels)
    class_counts = np.where(class_counts==0, 1, class_counts)
    class_weights = 1.0 / class_counts
    samples_weight = class_weights[train_labels]
    sampler = WeightedRandomSampler(torch.from_numpy(samples_weight).float(), num_samples=len(samples_weight), replacement=True)

    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, sampler=sampler, num_workers=NUM_WORKERS)
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)
    test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

    model = SliceAttentionModel(pretrained=USE_PRETRAINED, embedding_dim=512, num_classes=len(label2idx)).to(DEVICE)
    opt = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=1e-5)
    criterion = nn.CrossEntropyLoss(weight=torch.tensor(class_weights, dtype=torch.float32).to(DEVICE))

    best_val_acc = 0.0
    best_path = os.path.join(OUT_DIR, "best_slice_attn.pt")
    for epoch in range(1, NUM_EPOCHS+1):
        model.train()
        running_loss = 0.0
        for X, y, gids in train_loader:
            X = X.to(DEVICE).float(); y = y.to(DEVICE).long()
            opt.zero_grad()
            out, _ = model(X)
            loss = criterion(out, y)
            loss.backward()
            opt.step()
            running_loss += loss.item() * X.size(0)
        train_loss = running_loss / len(train_loader.dataset)

        model.eval()
        preds, trues = [], []
        with torch.no_grad():
            for X, y, gids in val_loader:
                X = X.to(DEVICE).float(); y = y.to(DEVICE).long()
                out, _ = model(X)
                preds.extend(out.argmax(dim=1).cpu().numpy().tolist())
                trues.extend(y.cpu().numpy().tolist())
        val_acc = accuracy_score(trues, preds) if len(trues) else 0.0
        print(f"Epoch {epoch}/{NUM_EPOCHS} train_loss={train_loss:.4f} val_acc={val_acc:.4f}")

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            if SAVE_BEST:
                torch.save(model.state_dict(), best_path)

    if SAVE_BEST and os.path.exists(best_path):
        model.load_state_dict(torch.load(best_path))
    model.eval()
    preds, trues, gids = [], [], []
    with torch.no_grad():
        for X, y, group_ids in test_loader:
            X = X.to(DEVICE).float()
            out, w = model(X)
            preds.extend(out.argmax(dim=1).cpu().tolist())
            trues.extend(y.cpu().tolist())
            gids.extend(group_ids)
    print("Test acc:", accuracy_score(trues, preds))
    print(classification_report(trues, preds, digits=4))
    cm = confusion_matrix(trues, preds)
    np.savetxt(os.path.join(OUT_DIR, "confusion_matrix.csv"), cm, delimiter=",")

    vis_dir = os.path.join(OUT_DIR, "attention_viz"); os.makedirs(vis_dir, exist_ok=True)
    nvis = min(12, len(test_g))
    for i in range(nvis):
        X, y, gid = test_ds[i]
        xb = X.unsqueeze(0).to(DEVICE).float()
        with torch.no_grad():
            out, weights = model(xb)
        weights_np = weights.cpu().numpy().squeeze()
        inv_norm = T.Normalize(mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225], std=[1/0.229,1/0.224,1/0.225])
        X_vis = inv_norm(X)

        visualize_attention_stack(X_vis, weights_np, os.path.join(vis_dir, f"{gid}_attn.png"))

        used_paths = None
        try:
            used_paths = test_g[i].get('paths', None)
        except Exception:
            used_paths = None

        to_pil = T.ToPILImage()

        if used_paths and SAVE_ATTENTION_PRINTS:
            paths_all = used_paths
            n_all = len(paths_all)
            if n_all >= NUM_SLICES:
                step = n_all / NUM_SLICES
                idxs = [int(j*step) for j in range(NUM_SLICES)]
            else:
                idxs = list(range(n_all))
            selected_paths = [paths_all[k] for k in idxs]
            for idx, (p, w) in enumerate(zip(selected_paths, weights_np)):
                print(f"Slice {idx+1}: {p}  --> Attention Weight = {w:.4f}")
            order = np.argsort(weights_np)[::-1]
            print("\n=== Ranking (most -> least) ===")
            for rank, idx in enumerate(order, 1):
                print(f"Rank {rank}: {selected_paths[idx]}  (Weight = {weights_np[idx]:.4f})")
            best_idx = order[0]
            best_path = selected_paths[best_idx]
            print("\n*** MOST INFORMATIVE SLICE ***")
            print(f"Image: {best_path}")
            print(f"Attention Weight: {weights_np[best_idx]:.4f}\n")
            try:
                best_img = to_pil(X_vis[best_idx].cpu())
                best_img.save(os.path.join(vis_dir, f"{gid}_best_slice.png"))
            except Exception as e:
                print("Could not save best slice image from X_vis:", e)
        else:
            order = np.argsort(weights_np)[::-1]
            best_idx = order[0]
            try:
                best_img = to_pil(X_vis[best_idx].cpu())
                best_img.save(os.path.join(vis_dir, f"{gid}_best_slice.png"))
            except Exception:
                pass

    print("Done. Outputs in:", OUT_DIR)

if __name__ == "__main__":
    main()




Labels: {'AD': 0, 'CN': 1, 'MCI': 2} Total groups: 802
Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


100%|██████████| 44.7M/44.7M [00:00<00:00, 125MB/s]


Epoch 1/4 train_loss=0.6199 val_acc=0.7273
Epoch 2/4 train_loss=0.3143 val_acc=0.7603
Epoch 3/4 train_loss=0.2242 val_acc=0.8430
Epoch 4/4 train_loss=0.1515 val_acc=0.9504
Test acc: 0.9629629629629629
              precision    recall  f1-score   support

           0     0.9310    0.9643    0.9474        28
           1     0.9750    0.9512    0.9630        41
           2     1.0000    1.0000    1.0000        12

    accuracy                         0.9630        81
   macro avg     0.9687    0.9718    0.9701        81
weighted avg     0.9635    0.9630    0.9631        81

Done. Outputs in: ./slice_attention_output


In [3]:

# inference_slice_attention.py

import os, random
from pathlib import Path
import numpy as np
import torch, torch.nn as nn
import torchvision.transforms as T
import torchvision.models as models
from PIL import Image
import nibabel as nib
import matplotlib.pyplot as plt


DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
TEST_FOLDER = "/content/drive/MyDrive/Project_A/Test_images/"
NIFTI_FILE = ""
MODEL_PATH = "./slice_attention_output/best_slice_attn.pt"
OUT_DIR = "./inference_outputs"
NUM_SLICES = 8
IMAGE_SIZE = 224
MEAN = [0.485,0.456,0.406]; STD = [0.229,0.224,0.225]

idx2label = {0:"AD", 1:"CN", 2:"MCI"}
IMAGE_EXTS = (".jpg",".jpeg",".png",".bmp",".tif",".tiff")
os.makedirs(OUT_DIR, exist_ok=True)

RANDOM_SEED = 42
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

transform = T.Compose([T.Resize((IMAGE_SIZE,IMAGE_SIZE)), T.ToTensor(), T.Normalize(MEAN,STD)])

class SliceAttentionModel(nn.Module):
    def __init__(self, pretrained=False, embedding_dim=512, num_classes=3):
        super().__init__()
        if hasattr(models, "ResNet18_Weights") and pretrained:
            res = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
        else:
            res = models.resnet18(pretrained=pretrained)
        layers = list(res.children())[:-1]
        self.backbone = nn.Sequential(*layers)
        self.embedding_dim = res.fc.in_features
        self.att_mlp = nn.Sequential(nn.Linear(self.embedding_dim,128), nn.ReLU(), nn.Linear(128,1))
        self.classifier = nn.Sequential(nn.Linear(self.embedding_dim,256), nn.ReLU(), nn.Dropout(0.3), nn.Linear(256,num_classes))
    def forward(self, x):
        B,S,C,H,W = x.shape
        x = x.view(B*S, C, H, W)
        feats = self.backbone(x)
        feats = feats.view(B, S, self.embedding_dim)
        scores = self.att_mlp(feats).squeeze(-1)
        weights = torch.softmax(scores, dim=1)
        fused = (feats * weights.unsqueeze(-1)).sum(dim=1)
        out = self.classifier(fused)
        return out, weights

# ---------- helpers ----------
def build_stack_from_folder(folder, num_slices=NUM_SLICES):
    files = sorted([str(p) for p in Path(folder).glob("*") if p.suffix.lower() in IMAGE_EXTS])
    if len(files) == 0:
        raise ValueError("No images found in folder: " + str(folder))
    n = len(files)

    if n < num_slices:
        selected = files[:]
    else:
        step = n / num_slices
        idxs = [int(i*step) for i in range(num_slices)]
        selected = [files[i] for i in idxs]

    tensors = []
    for p in selected:
        img = Image.open(p).convert("RGB")
        tensors.append(transform(img))
    stack = torch.stack(tensors, dim=0)
    return stack, selected


def extract_slices_from_nifti(nifti_path, num_slices=NUM_SLICES, axis=2):
    img = nib.load(nifti_path)
    vol = img.get_fdata().astype(float)
    vol = np.nan_to_num(vol)
    vol = (vol - vol.mean()) / (vol.std() + 1e-8)
    vol = np.moveaxis(vol, axis, -1)
    n = vol.shape[-1]
    if n == 0: raise ValueError("No slices in NIfTI")
    if n >= num_slices:
        step = n / num_slices
        idxs = [int(i*step) for i in range(num_slices)]
    else:
        idxs = [min(n-1, int(i)) for i in range(num_slices)]
    slices = []
    for i in idxs:
        sl = vol[..., i]
        mn, mx = float(sl.min()), float(sl.max())
        if mx-mn < 1e-6:
            arr = np.zeros((sl.shape[0], sl.shape[1]), dtype=np.uint8)
        else:
            sln = (sl - mn) / (mx - mn)
            arr = (sln * 255.0).astype(np.uint8)
        pil = Image.fromarray(arr).convert("RGB")
        slices.append(transform(pil))
    stack = torch.stack(slices, dim=0)
    return stack, idxs


def save_topk(stack, weights, prefix, paths=None, topk=6):
    """
    stack: tensor (S,C,H,W) before batch dim and before normalization inversion
    weights: 1D numpy array of length S
    paths: optional list of original filenames corresponding to stack order
    """
    inv = T.Normalize(mean=[-m/s for m,s in zip(MEAN,STD)], std=[1/s for s in STD])
    x_vis = inv(stack).clamp(0,1)
    to_pil = T.ToPILImage()
    order = np.argsort(weights)[::-1]
    S = len(weights)
    topk = min(topk, S)
    for r, idx in enumerate(order[:topk], start=1):
        pil = to_pil(x_vis[idx].cpu())
        if paths:
            base = os.path.basename(paths[idx])
            name = f"{os.path.splitext(base)[0]}"
        else:
            name = f"slice{idx}"
        pil.save(os.path.join(OUT_DIR, f"{prefix}_rank{r}_{name}_w{weights[idx]:.4f}.png"))
    plt.figure(figsize=(6,2)); plt.bar(range(S), weights); plt.title("attention weights")
    plt.savefig(os.path.join(OUT_DIR, f"{prefix}_weights.png")); plt.close()

def run_inference_from_folder(folder):
    if not os.path.exists(MODEL_PATH):
        raise FileNotFoundError(f"Model not found at {MODEL_PATH}. Train and save the model first or set MODEL_PATH correctly.")
    stack, used_paths = build_stack_from_folder(folder)
    xb = stack.unsqueeze(0).to(DEVICE)
    model = SliceAttentionModel(pretrained=False, embedding_dim=512, num_classes=len(idx2label)).to(DEVICE)
    model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
    model.eval()
    with torch.no_grad():
        out, weights = model(xb)
        probs = torch.softmax(out, dim=1).cpu().numpy().squeeze()
        weights_np = weights.cpu().numpy().squeeze()
    pred_idx = int(np.argmax(probs)); pred_label = idx2label[pred_idx]
    save_topk(stack, weights_np, "subject_folder", paths=used_paths, topk=min(6, len(weights_np)))

    print("\n=== Slice Attention Ranking ===")
    for i, (p,w) in enumerate(zip(used_paths, weights_np), start=1):
        print(f"Slice {i}: {p}  --> Attention Weight = {w:.4f}")
    order = np.argsort(weights_np)[::-1]
    print("\n=== Ranking (most -> least) ===")
    for rank, idx in enumerate(order, start=1):
        print(f"Rank {rank}: {used_paths[idx]}  (Weight = {weights_np[idx]:.4f})")
    best_idx = order[0]
    print("\n*** MOST INFORMATIVE SLICE ***")
    print(f"Image: {used_paths[best_idx]}")
    print(f"Attention Weight: {weights_np[best_idx]:.4f}\n")
    return {"pred_label": pred_label, "probs": probs, "weights": weights_np, "used_paths": used_paths}


def run_inference_from_nifti(nifti_path):
    if not os.path.exists(MODEL_PATH):
        raise FileNotFoundError(f"Model not found at {MODEL_PATH}. Train and save the model first or set MODEL_PATH correctly.")
    stack, slice_idxs = extract_slices_from_nifti(nifti_path)
    xb = stack.unsqueeze(0).to(DEVICE)
    model = SliceAttentionModel(pretrained=False, embedding_dim=512, num_classes=len(idx2label)).to(DEVICE)
    model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
    model.eval()
    with torch.no_grad():
        out, weights = model(xb)
        probs = torch.softmax(out, dim=1).cpu().numpy().squeeze()
        weights_np = weights.cpu().numpy().squeeze()
    pred_idx = int(np.argmax(probs)); pred_label = idx2label[pred_idx]
    save_topk(stack, weights_np, "subject_nifti", paths=None, topk=min(6, len(weights_np)))
    print("\n=== Slice Attention Ranking (by index) ===")
    for i, w in enumerate(weights_np):
        print(f"Slice index {slice_idxs[i]}  --> Attention Weight = {w:.4f}")
    order = np.argsort(weights_np)[::-1]
    print("\n=== Ranking (most -> least) ===")
    for rank, idx in enumerate(order, start=1):
        print(f"Rank {rank}: slice index {slice_idxs[idx]}  (Weight = {weights_np[idx]:.4f})")
    best_idx = order[0]
    print("\n*** MOST INFORMATIVE SLICE (index) ***")
    print(f"Slice index: {slice_idxs[best_idx]}, Weight: {weights_np[best_idx]:.4f}\n")
    return {"pred_label": pred_label, "probs": probs, "weights": weights_np, "slice_indices": slice_idxs}

if __name__ == "__main__":
    if TEST_FOLDER and Path(TEST_FOLDER).exists():
        res = run_inference_from_folder(TEST_FOLDER)
        print("Folder prediction:", res["pred_label"])
        print("Probabilities:", res["probs"])
        print("Saved attention images to:", OUT_DIR)
    elif NIFTI_FILE and Path(NIFTI_FILE).exists():
        res = run_inference_from_nifti(NIFTI_FILE)
        print("NIfTI prediction:", res["pred_label"])
        print("Probabilities:", res["probs"])
        print("Saved attention images to:", OUT_DIR)
    else:
        print("No valid TEST_FOLDER or NIFTI_FILE found. Edit the script CONFIG paths and try again.")



=== Slice Attention Ranking ===
Slice 1: /content/drive/MyDrive/Project_A/Test_images/verymildDem948.jpg  --> Attention Weight = 0.0355
Slice 2: /content/drive/MyDrive/Project_A/Test_images/verymildDem949.jpg  --> Attention Weight = 0.1102
Slice 3: /content/drive/MyDrive/Project_A/Test_images/verymildDem95.jpg  --> Attention Weight = 0.2304
Slice 4: /content/drive/MyDrive/Project_A/Test_images/verymildDem96.jpg  --> Attention Weight = 0.3449
Slice 5: /content/drive/MyDrive/Project_A/Test_images/verymildDem97.jpg  --> Attention Weight = 0.1269
Slice 6: /content/drive/MyDrive/Project_A/Test_images/verymildDem98.jpg  --> Attention Weight = 0.0804
Slice 7: /content/drive/MyDrive/Project_A/Test_images/verymildDem99.jpg  --> Attention Weight = 0.0718

=== Ranking (most -> least) ===
Rank 1: /content/drive/MyDrive/Project_A/Test_images/verymildDem96.jpg  (Weight = 0.3449)
Rank 2: /content/drive/MyDrive/Project_A/Test_images/verymildDem95.jpg  (Weight = 0.2304)
Rank 3: /content/drive/MyDrive/