In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import timm
import os
import numpy as np
from sklearn.model_selection import train_test_split
import kagglehub

import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, f1_score, classification_report

# --- Experiment configuration ---------------------------------------------
# timm identifier of the pretrained DeiT variant whose blocks are reused.
DEIT_MODEL_NAME = 'deit_small_patch16_224'
# Token width fed to the DeiT blocks; has to agree with DEIT_MODEL_NAME
# (384 for the DeiT-small variant).
DEIT_EMBED_DIM = 384
# Spatial size every image is resized to before entering the network.
IMAGE_SIZE = (224, 224)
BATCH_SIZE = 16
NUM_CLASSES = 4
LEARNING_RATE = 1e-4
# NOTE(review): presumably the rate for a later fine-tuning stage; it is not
# used in the visible part of this file -- confirm.
LEARNING_RATE_FINETUNE = 1e-5
EPOCHS = 15
# Prefer the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Directory that receives saved weights and plots.
RESULTS_DIR = "training_results"
# exist_ok=True makes creation idempotent and avoids the check-then-create
# race of testing os.path.exists() first.
os.makedirs(RESULTS_DIR, exist_ok=True)

print(f"Using device: {DEVICE}")
print(f"PyTorch version: {torch.__version__}")
print(f"Timm version: {timm.__version__}")

# Fetch the dataset through kagglehub; only when that fails fall back to a
# local copy. The fallback must live inside the `except` branch -- otherwise
# a successfully downloaded `data_dir` is overwritten unconditionally.
print("Downloading dataset...")
try:
    data_dir = kagglehub.dataset_download("abdullahtauseef2003/adni-4c-alzheimers-mri-classification-dataset")
    print(f"Dataset downloaded to: {data_dir}")
except Exception as e:
    # Best-effort boundary: any download failure (network, credentials, ...)
    # is reported and the local fallback path is tried instead.
    print(f"Error downloading dataset via kagglehub: {e}")
    data_dir = "./adni_data"
    print(f"Attempting to use local path: {data_dir}")
    if not os.path.isdir(data_dir):
        # SystemExit with a message exits with a non-zero status and does not
        # depend on the `exit` helper injected by the `site` module.
        raise SystemExit(f"Error: Dataset directory not found at fallback path: {data_dir}")

# Class folder names listed in label-code order (index == integer label).
# NOTE(review): presumably the ADNI diagnostic groups (cognitively normal,
# early / late mild cognitive impairment, Alzheimer's disease) -- confirm.
class_names = ['CN', 'EMCI', 'LMCI', 'AD']
# Integer label -> class name.
diagnosis_mapping = dict(enumerate(class_names))
# Class (sub-directory) name -> integer label; inverse of diagnosis_mapping.
dir_to_code = {name: code for code, name in diagnosis_mapping.items()}

def get_image_paths_and_labels(base_data_dir):
    """Locate the class-per-folder image directory and list its images.

    A fixed set of candidate layouts below ``base_data_dir`` is probed; the
    first candidate containing at least one sub-directory named after a key
    of ``dir_to_code`` is used. Every ``.png`` / ``.jpg`` / ``.jpeg`` file
    (case-insensitive) directly inside a class sub-directory is collected.

    Args:
        base_data_dir: Root directory of the downloaded / local dataset.

    Returns:
        A ``(image_paths, labels)`` pair of equally long lists: the file
        paths and the integer class code of each file. Both lists are empty
        when no usable directory is found. Files are returned in sorted
        order per class so that downstream seeded splits are reproducible
        regardless of the filesystem's listing order.
    """
    image_extensions = ('.png', '.jpg', '.jpeg')
    image_paths = []
    labels = []
    possible_image_dirs = [
        os.path.join(base_data_dir, 'ADNI_4C_MRI_Classification_Dataset', 'AugmentedAlzheimerDataset'),
        os.path.join(base_data_dir, 'AugmentedAlzheimerDataset'),
        os.path.join(base_data_dir, 'ADNI_IMAGES'),
    ]

    images_dir = None
    print("--- Searching for Image Directory ---")
    for dir_path in possible_image_dirs:
        print(f"Checking: {dir_path}")
        if not os.path.isdir(dir_path):
            print("  > Directory does not exist.")
            continue
        # One matching class folder is enough to accept the candidate.
        if any(os.path.isdir(os.path.join(dir_path, class_name)) for class_name in dir_to_code):
            images_dir = dir_path
            print(f"Found valid image directory: {images_dir}")
            break
        print("  > Directory exists but lacks expected class subdirectories.")
    print("--- Search Complete ---")

    if images_dir is None:
        print(f"\nERROR: Could not find a valid image directory structure within '{base_data_dir}'")
        return [], []

    print(f"\nUsing image directory: {images_dir}")
    for class_name, class_code in dir_to_code.items():
        class_dir = os.path.join(images_dir, class_name)
        if not os.path.isdir(class_dir):
            print(f"Warning: Class directory '{class_dir}' not found. Skipping this class.")
            continue
        try:
            # os.listdir returns entries in arbitrary order; sort for determinism.
            entries = sorted(os.listdir(class_dir))
        except OSError as e:
            # Unreadable class folder: report it and carry on with the others.
            print(f"Error reading directory {class_dir}: {e}")
            continue
        files = [f for f in entries if f.lower().endswith(image_extensions)]
        print(f"Found {len(files)} images in '{class_name}'")
        image_paths.extend(os.path.join(class_dir, f) for f in files)
        labels.extend([class_code] * len(files))

    if not image_paths:
        print("Warning: No image paths were collected. Check class subdirectories and file types.")
    return image_paths, labels

# Gather every image path with its integer label, then print a per-class count.
print("\nCollecting image paths and labels...")
image_paths, labels = get_image_paths_and_labels(data_dir)

if not image_paths:
    # Nothing to train on: stop with a non-zero exit status (a bare exit()
    # would report success to the calling shell / pipeline).
    raise SystemExit("\nNo images found. Please check the dataset path and structure.")

# NumPy array so the labels can be counted here and indexed later on.
labels = np.array(labels)

print(f"Total images found: {len(image_paths)}")
unique, counts = np.unique(labels, return_counts=True)
print("\nClass distribution:")
# `labels` is non-empty at this point (it is built in lockstep with
# image_paths), so there is always at least one class to report.
for label_code, count in zip(unique, counts):
    print(f"Class {diagnosis_mapping.get(label_code, f'Unknown ({label_code})')}: {count} images")

class AlzheimerMRIDataset(Dataset):
    """Map-style dataset yielding ``(image, label)`` pairs from image files.

    Args:
        image_paths: Sequence of image file paths.
        labels: Sequence of integer class codes, aligned with ``image_paths``.
        transform: Optional callable applied to the RGB PIL image.
    """

    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """Return the number of samples."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            ``(image, label)`` where ``image`` is the (optionally transformed)
            RGB image and ``label`` a ``torch.long`` scalar tensor, or
            ``(None, None)`` when the sample cannot be loaded. The failure is
            printed rather than raised so that a single corrupt file does not
            abort an epoch; the batch collation step is expected to drop
            such entries.
        """
        path = self.image_paths[idx]
        try:
            # The context manager releases the underlying file handle as soon
            # as the pixel data has been copied by convert().
            with Image.open(path) as img:
                image = img.convert("RGB")
            if self.transform is not None:
                image = self.transform(image)
            label = torch.tensor(self.labels[idx], dtype=torch.long)
            return image, label
        except Exception as e:
            # Deliberate best-effort: report and signal "skip this sample".
            print(f"Error loading image {path}: {e}")
            return None, None

# Deterministic preprocessing shared by both splits: resize to the network
# input size, convert to a tensor, and normalise with the ImageNet channel
# statistics that match the ImageNet-pretrained backbone weights.
transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

# Stratified 80/20 hold-out split with a fixed seed.
# NOTE(review): the source directory is named "AugmentedAlzheimerDataset"; if
# it holds augmented copies of the same scans, a random split may place
# near-duplicates in both sets -- confirm.
X_train, X_val, y_train, y_val = train_test_split(
    image_paths,
    labels,
    test_size=0.2,
    stratify=labels,
    random_state=42,
)

print(f"\nTraining set size: {len(X_train)}")
print(f"Validation set size: {len(X_val)}")

train_dataset = AlzheimerMRIDataset(X_train, y_train, transform=transform)
val_dataset = AlzheimerMRIDataset(X_val, y_val, transform=transform)

def collate_fn(batch):
    """Collate a batch while discarding samples that failed to load.

    Samples whose first element is ``None`` are dropped. The remaining
    samples are passed to PyTorch's default collate function. When nothing
    remains, a pair of empty tensors is returned instead.
    """
    usable = [sample for sample in batch if sample[0] is not None]
    if usable:
        return torch.utils.data.dataloader.default_collate(usable)
    return torch.Tensor(), torch.Tensor()

# Batch iterators over the two splits. Only the training loader shuffles;
# both use the custom collate function so unreadable samples are dropped.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
    collate_fn=collate_fn,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=2,
    pin_memory=True,
    collate_fn=collate_fn,
)

print("DataLoaders created.")

class MobileNet_DeiT_Hybrid(nn.Module):
    """CNN/transformer hybrid: MobileNetV2 features fed through DeiT blocks.

    The MobileNetV2 feature map is projected to the DeiT embedding width by
    a 1x1 convolution and flattened into one token per spatial position. A
    learnable class token is prepended and a learnable positional embedding
    added. The sequence then passes through the pretrained DeiT blocks and
    final norm, and the class-token output is classified by a linear layer.

    Args:
        num_classes: Number of output logits.
        deit_model_name: timm model name the transformer blocks are taken from.
        deit_embed_dim: Token width expected by those blocks.
        freeze_backbone: If true, MobileNetV2 feature parameters get
            ``requires_grad=False``.
        freeze_deit: If true, the DeiT block parameters get
            ``requires_grad=False`` (the final DeiT norm is left untouched).
    """

    def __init__(self, num_classes=NUM_CLASSES, deit_model_name=DEIT_MODEL_NAME, deit_embed_dim=DEIT_EMBED_DIM, freeze_backbone=True, freeze_deit=True):
        super().__init__()
        self.num_classes = num_classes
        self.deit_embed_dim = deit_embed_dim

        # --- CNN stem: ImageNet-pretrained MobileNetV2 without its classifier.
        backbone = models.mobilenet_v2(weights=models.MobileNet_V2_Weights.IMAGENET1K_V1)
        self.mobilenet_features = backbone.features
        # Hard-coded output geometry of the stem: 1280 channels on a 7x7 grid.
        # NOTE(review): the 7x7 grid assumes 224x224 inputs; other sizes will
        # not match the positional embedding created below.
        self.num_cnn_features = 1280
        self.feature_map_size = 7
        if freeze_backbone:
            self.mobilenet_features.requires_grad_(False)

        # 1x1 convolution mapping CNN channels to the transformer token width.
        self.projection = nn.Conv2d(self.num_cnn_features, self.deit_embed_dim, kernel_size=1)

        # --- Transformer: reuse only the blocks and final norm of the DeiT.
        pretrained_deit = timm.create_model(deit_model_name, pretrained=True)
        self.deit_blocks = pretrained_deit.blocks
        self.deit_norm = pretrained_deit.norm
        if freeze_deit:
            self.deit_blocks.requires_grad_(False)

        # --- Fresh class token and positional embedding for the new grid.
        self.num_patches = self.feature_map_size ** 2
        token_count = self.num_patches + 1  # +1 for the class token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.deit_embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, token_count, self.deit_embed_dim))
        nn.init.trunc_normal_(self.pos_embed, std=.02)
        nn.init.trunc_normal_(self.cls_token, std=.02)

        self.classifier = nn.Linear(self.deit_embed_dim, self.num_classes)

    def forward(self, x):
        """Return class logits for a batch of images ``x``."""
        feature_map = self.projection(self.mobilenet_features(x))
        batch_size, _, _, _ = feature_map.shape
        # (B, D, H, W) -> (B, H*W, D): one token per spatial position.
        patch_tokens = feature_map.flatten(2).transpose(1, 2)
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        tokens = torch.cat((cls_tokens, patch_tokens), dim=1) + self.pos_embed
        encoded = self.deit_norm(self.deit_blocks(tokens))
        # The class token sits at sequence position 0.
        return self.classifier(encoded[:, 0])

def save_and_report_results(model, val_loader, device, history, class_names, output_prefix):
    print("\n--- Generating Final Report and Saving Results ---")
    weights_path = os.path.join(RESULTS_DIR, f"{output_prefix}_final_weights.pth")
    torch.save(model.state_dict(), weights_path)
    print(f"Model weights saved to {weights_path}")
    plt.figure(figsize=(10, 5))
    plt.plot(history['train_loss'], label='Training Loss')
    plt.plot(history['val_loss'], label='Validation Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.title('Training and Validation Loss')
    plt.legend()
    loss_plot_path = os.path.join(RESULTS_DIR, f"{output_prefix}_loss_plot.png")
    plt.savefig(loss_plot_path)
    plt.close()
    print(f"Loss plot saved to {loss_plot_path}")
    plt.figure(figsize=(10, 5))
    plt.plot(history['train_acc'], label='Training Accuracy')
    plt.plot(history['val_acc'], label='Validation Accuracy')
    plt.xlabel('Epochs')
    plt.ylabel('Accuracy (%)')
    plt.title('Training and Validation Accuracy')
    plt.legend()
    acc_plot_path = os.path.join(RESULTS_DIR, f"{output_prefix}_accuracy_plot.png")
    plt.savefig(acc_plot_path)
    plt.close()
    print(f"Accuracy plot saved to {acc_plot_path}")
    model.eval()
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for images, labels in val_loader:
            if images.numel() == 0: continue
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    if not all_labels or not all_preds:
        print("Could not generate F1 score or confusion matrix due to empty validation predictions.")
        return

    f1_macro = f1_score(all_labels, all_preds, average='macro', zero_division=0)
    f1_weighted = f1_score(all_labels, all_preds, average='weighted', zero_division=0)
    print(f"\nValidation F1 Score (Macro): {f1_macro:.4f}")
    print(f"Validation F1 Score (Weighted): {f1_weighted:._


Using device: cuda
PyTorch version: 2.5.1+cu124
Timm version: 1.0.14
Downloading dataset...
Dataset downloaded to: /kaggle/input/adni-4c-alzheimers-mri-classification-dataset

Collecting image paths and labels...
--- Searching for Image Directory ---
Checking: /kaggle/input/adni-4c-alzheimers-mri-classification-dataset/ADNI_4C_MRI_Classification_Dataset/AugmentedAlzheimerDataset
  > Directory does not exist.
Checking: /kaggle/input/adni-4c-alzheimers-mri-classification-dataset/AugmentedAlzheimerDataset
Found valid image directory: /kaggle/input/adni-4c-alzheimers-mri-classification-dataset/AugmentedAlzheimerDataset
--- Search Complete ---

Using image directory: /kaggle/input/adni-4c-alzheimers-mri-classification-dataset/AugmentedAlzheimerDataset
Found 6464 images in 'CN'
Found 9600 images in 'EMCI'
Found 8960 images in 'LMCI'
Found 8960 images in 'AD'
Total images found: 33984

Class distribution:
Class CN: 6464 images
Class EMCI: 9600 images
Class LMCI: 8960 images
Class AD: 8960 ima