# Alzheimer's Detection Model with Swin

## Overview
This TensorFlow/Keras model detects Alzheimer’s disease using MRI images. It uses transfer learning with a pre-trained MobileNetV2 as the base network and augments it with a custom Vision Transformer (ViT) head. Training is performed in two phases: initial training with a frozen base model and a fine-tuning phase where certain layers are unfrozen.

## Data Pipeline
- **Dataset Download:**  
  The dataset is downloaded using `kagglehub` and is expected to contain PNG images sorted into folders by diagnosis.
- **Label Mapping:**  
  Two mappings are defined:
  - Numeric-to-diagnosis (e.g., 0: 'CN', 4: 'AD').
  - Folder name to numeric label.
- **Image Processing:**  
  Images are read from disk, decoded, resized to 160×160, normalized, and their labels are one-hot encoded.
- **Data Pipeline:**  
  Constructed using the `tf.data` API, with operations including shuffling, mapping, caching, batching, and prefetching for efficient training.

## Model Architecture
- **Base Network (MobileNetV2):**  
  - Pre-trained on ImageNet (without the top classifier).
  - Input: 160×160×3 images.
  - **Layer Freezing:** Approximately 90% of the layers are frozen during the initial phase.
  
- **Vision Transformer (ViT) Head:**  
  - **Projection Layer:** Projects patch features into a defined embedding space.
  - **Class Token:** A learnable token is prepended to the patch sequence.
  - **Positional Embeddings:** Learnable positional information is added.
  - **Transformer Blocks:** Multiple blocks consisting of layer normalization, multi-head self-attention, MLP layers (with gelu activation and dropout), and residual connections.
  - **Output:** The final class token is extracted as the representative feature.

- **Classification Head:**
  - A Dropout layer (50%) is applied.
  - A Dense layer with softmax activation outputs class probabilities (for 5 classes).

## Training Strategy
- **Phase 1: Training with Frozen Base**
  - **Objective:** Train only the top layers (ViT head and classifier) while keeping most of the MobileNetV2 weights fixed.
  - **Optimizer:** Adam with a learning rate of 1e-4.
  - **Loss:** Categorical crossentropy.
  - **Epochs:** 5 (initial phase; adjustable).
  
- **Phase 2: Fine-Tuning**
  - **Objective:** Unfreeze the last 10% of the MobileNetV2 layers to fine-tune the model.
  - **Optimizer:** Adam with a reduced learning rate of 1e-5.
  - **Epochs:** Additional 5 epochs (adjustable).

## Evaluation and Visualization
- **Callbacks:**  
  EarlyStopping and ModelCheckpoint are used to prevent overfitting and store the best model based on validation loss.
- **Training Visualization:**  
  A plotting function displays the accuracy and loss curves for both phases.
- **Final Evaluation:**  
  The model is evaluated on the validation set, and the final validation accuracy is reported.


In [10]:
import tensorflow as tf
from tensorflow.keras import layers, models, optimizers, callbacks
from tensorflow.keras.applications import MobileNetV2,ResNet50
from tensorflow.keras.mixed_precision import set_global_policy
import math
import matplotlib.pyplot as plt
import pandas as pd
import os
from sklearn.model_selection import train_test_split
import numpy as np
import kagglehub

import os
os.environ['TF_USE_LEGACY_KERAS'] = '1'
from tfswin import SwinTransformerTiny224

In [None]:
# !pip install tfswin


In [None]:

set_global_policy('mixed_float16')
data_dir = "/kaggle/input/adni-for-ad-progression"
print(f"Dataset downloaded to: {data_dir}")

diagnosis_mapping = {0: 'CN', 1: 'MCI', 2: 'EMCI', 3: 'LMCI', 4: 'AD'}
dir_to_code       = {v: k for k, v in diagnosis_mapping.items()}

def get_image_paths_and_labels():
    image_paths, labels = [], []
    images_dir = os.path.join(data_dir, 'ADNI_IMAGES', 'png_images')
    for cls, code in dir_to_code.items():
        folder = os.path.join(images_dir, cls)
        if os.path.exists(folder):
            for f in os.listdir(folder):
                if f.lower().endswith('.png'):
                    image_paths.append(os.path.join(folder, f))
                    labels.append(code)
    return np.array(image_paths), np.array(labels)


train_paths, val_paths, train_labels, val_labels = train_test_split(
    image_paths, labels,
    test_size=0.2,
    random_state=42,
    stratify=labels
)
print(f"\nTraining:   {len(train_paths)} images")
print(f"Validation: {len(val_paths)} images")
IMAGE_SIZE = (160, 160)
BATCH_SIZE  = 32
NUM_CLASSES = len(diagnosis_mapping)

def decode_image(path, label):
    img =tf.io.read_file(path)
    img =tf.image.decode_png(img, channels=3)
    img =tf.image.resize(img, IMAGE_SIZE)
    img =tf.cast(img, tf.float32) / 255.0
    lbl =tf.one_hot(label, NUM_CLASSES)
    return img, lbl

train_ds = (
    tf.data.Dataset.from_tensor_slices((train_paths, train_labels))
      .shuffle(len(train_paths))
      .map(decode_image, num_parallel_calls=tf.data.AUTOTUNE)
      .cache()
      .batch(BATCH_SIZE)
      .prefetch(tf.data.AUTOTUNE)
)
val_ds = (
    tf.data.Dataset.from_tensor_slices((val_paths, val_labels))
      .map(decode_image, num_parallel_calls=tf.data.AUTOTUNE)
      .cache()
      .batch(BATCH_SIZE)
      .prefetch(tf.data.AUTOTUNE)
)
#using swin as backbone now different approach from before this would be another model in our ensemble models
swin_backbone = SwinTransformerTiny224(
    include_top=False,       
    weights='imagenet',    
    input_shape=(*IMAGE_SIZE, 3),
)

inputs = layers.Input(shape=(*IMAGE_SIZE, 3))
x =swin_backbone(inputs)#4d map
x =layers.GlobalAveragePooling2D()(x)
x =layers.Dropout(0.5)(x)
outputs = layers.Dense(NUM_CLASSES, activation='softmax', dtype='float32')(x)

model= models.Model(inputs, outputs)
model.compile(
    optimizer=optimizers.Adam(1e-4),
    loss='categorical_crossentropy',
    metrics=['accuracy'],
    jit_compile=True,
)
early_stop = callbacks.EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True)
ckpt1= callbacks.ModelCheckpoint('alz_swin_phase1.keras', monitor='val_loss', save_best_only=True)

print("\nPhase 1: Training with Swin frozen")
history1 = model.fit(train_ds, validation_data=val_ds, epochs=5, callbacks=[early_stop, ckpt1])

for layer in swin_backbone.layers:
    layer.trainable = True

model.compile(
    optimizer=optimizers.Adam(1e-5),
    loss='categorical_crossentropy',
    metrics=['accuracy'],
    jit_compile=True,
)
ckpt2 = callbacks.ModelCheckpoint('alz_swin_phase2.keras', monitor='val_loss', save_best_only=True)

print("\nPhase 2: Fine-tuning full Swin")
history2 = model.fit(train_ds, validation_data=val_ds, epochs=5, callbacks=[early_stop, ckpt2])

model.save('final_alzheimer_swin.keras')

def plot_history(h1, h2):
    plt.figure(figsize=(12, 5))
    
    plt.subplot(1,2,1)
    plt.plot(h1.history['accuracy'], label='Phase1 Train')
    plt.plot(h1.history['val_accuracy'], label='Phase1 Val')
    off = len(h1.history['accuracy'])
    x2  = [off + i for i in range(len(h2.history['accuracy']))]
    plt.plot(x2, h2.history['accuracy'], label='Phase2 Train')
    plt.plot(x2, h2.history['val_accuracy'], label='Phase2 Val')
    plt.title('Accuracy'); plt.xlabel('Epoch'); plt.legend()

    plt.subplot(1,2,2)
    plt.plot(h1.history['loss'], label='Phase1 Train')
    plt.plot(h1.history['val_loss'], label='Phase1 Val')
    plt.plot(x2, h2.history['loss'], label='Phase2 Train')
    plt.plot(x2, h2.history['val_loss'], label='Phase2 Val')
    plt.title('Loss'); plt.xlabel('Epoch'); plt.legend()
    plt.tight_layout(); plt.show()

plot_history(history1, history2)
test_loss, test_acc = model.evaluate(val_ds)
print(f"\nFinal Validation Accuracy: {test_acc}")


Dataset downloaded to: /kaggle/input/adni-for-ad-progression

Collecting image paths and labels…

Class distribution:
  CN: 4077 images
  MCI: 4073 images
  EMCI: 3958 images
  LMCI: 4074 images
  AD: 4075 images

Training:   16205 images
Validation: 4052 images
Downloading data from https://github.com/shkarupa-alex/tfswin/releases/download/3.0.0/swin_tiny_patch4_window7_224_22k.h5
[1m177485300/177485300[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 0us/step

Phase 1: Training with Swin frozen
Epoch 1/10


I0000 00:00:1746038567.316768      90 service.cc:148] XLA service 0x7d1af4025550 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1746038567.317624      90 service.cc:156]   StreamExecutor device (0): Tesla P100-PCIE-16GB, Compute Capability 6.0
I0000 00:00:1746038572.998186      90 cuda_dnn.cc:529] Loaded cuDNN version 90300
I0000 00:00:1746038608.896080      90 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


[1m254/254[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m256s[0m 644ms/step - accuracy: 0.2039 - loss: 1.8865 - val_accuracy: 0.2011 - val_loss: 1.7813
Epoch 2/10
[1m254/254[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m91s[0m 357ms/step - accuracy: 0.2011 - loss: 1.8954 - val_accuracy: 0.2011 - val_loss: 1.7813
Epoch 3/10
[1m254/254[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m90s[0m 356ms/step - accuracy: 0.2055 - loss: 1.8879 - val_accuracy: 0.2011 - val_loss: 1.7813
Epoch 4/10
[1m254/254[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m90s[0m 356ms/step - accuracy: 0.2026 - loss: 1.8970 - val_accuracy: 0.2011 - val_loss: 1.7813

Phase 2: Fine-tuning full Swin
Epoch 1/10
[1m254/254[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 477ms/step - accuracy: 0.2023 - loss: 1.8993