# Stroke Stage Classification using U-Net + DenseNet
**Dataset:** ISLES 2022 (DWI, ADC, FLAIR)
- Segment ischemic regions with U-Net
- Classify stroke stage with DenseNet (Acute, Subacute, Chronic)


In [None]:
# Install third-party dependencies quietly (-q).
# NOTE: the leading "!" is IPython/Jupyter shell-escape syntax, so this cell only
# runs inside a notebook/IPython kernel, not as a plain Python script.
# numpy and matplotlib are imported below but not installed here; presumably they
# are preinstalled in the notebook runtime — confirm.
!pip install nibabel kagglehub tensorflow scikit-learn -q

In [None]:
# Import libraries
import os
import numpy as np
import nibabel as nib
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, f1_score
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.utils import to_categorical

In [None]:
# Download the ISLES dataset from Kaggle via kagglehub and locate the training split.
import kagglehub

# Local directory of the downloaded dataset.
path = kagglehub.dataset_download("smitgandhi2005/isles-dataset")
# This notebook assumes the download nests the data as ISLES-2022/ISLES-2022/Training.
# Fail fast here with a clear message instead of failing later inside load_case
# on the first missing .nii file.
base_path = os.path.join(path, "ISLES-2022", "ISLES-2022", "Training")
if not os.path.isdir(base_path):
    raise FileNotFoundError(
        f"Expected ISLES training directory not found: {base_path!r}. "
        f"Inspect the contents of {path!r} and adjust base_path."
    )

def load_case(case_id, base_dir=None):
    """Load one case's ADC, DWI and FLAIR volumes as a single multi-channel array.

    Files are read from ``<base_dir>/<case_id>/<case_id>_<modality>.nii`` for
    the modalities ``adc``, ``dwi`` and ``flair``.

    Args:
        case_id: Case directory name; also used as the file-name prefix.
        base_dir: Directory containing the case folders. When ``None``
            (default), the module-level ``base_path`` is used, looked up at
            call time.

    Returns:
        ``np.ndarray``: the three volumes stacked on a new last axis, in
        channel order ``[adc, dwi, flair]``. The values come from nibabel's
        ``get_fdata()``, which returns float64 by default.

    Raises:
        ValueError: If the three volumes do not share the same shape, so they
            cannot be stacked channel-wise. The message names the case and
            each modality's shape, which makes mis-registered inputs easy to spot.
    """
    if base_dir is None:
        base_dir = base_path
    case_path = os.path.join(base_dir, case_id)

    modalities = ("adc", "dwi", "flair")  # order defines the output channel order
    volumes = [
        nib.load(os.path.join(case_path, f"{case_id}_{name}.nii")).get_fdata()
        for name in modalities
    ]

    # np.stack would also raise ValueError on a mismatch, but without saying
    # which case or modality is responsible.
    shapes = {name: vol.shape for name, vol in zip(modalities, volumes)}
    if len(set(shapes.values())) != 1:
        raise ValueError(
            f"Cannot stack modalities for case {case_id!r}: shapes differ {shapes}"
        )
    return np.stack(volumes, axis=-1)

In [None]:
# Define U-Net model
def unet_model(input_shape, base_filters=32):
    """Build a 2-D U-Net (3 down/up levels) with a 1-channel sigmoid output.

    Args:
        input_shape: Per-sample input shape for ``layers.Input``. It is
            assumed to be channels-last, ``(height, width, channels)``, which
            is the Keras default. Height and width may be ``None`` for
            variable-size inputs.
        base_filters: Number of filters at the first encoder level; each
            deeper level doubles it. The default of 32 reproduces the original
            32/64/128 encoder and 256-filter bottleneck.

    Returns:
        An uncompiled ``keras.Model`` mapping the input to a single-channel
        sigmoid map at the input's spatial resolution.

    Raises:
        ValueError: If a known height or width is not divisible by 8. Three
            2x2 max-pools floor odd sizes. After that, the 2x transposed
            convolutions no longer match their skip connections, and
            ``concatenate`` fails with a far less readable error.
    """
    downsample = 2 ** 3  # three MaxPooling2D((2, 2)) stages
    for dim in tuple(input_shape)[:2]:
        if dim is not None and dim % downsample != 0:
            raise ValueError(
                f"unet_model: height and width in input_shape={tuple(input_shape)} "
                f"must be divisible by {downsample} (got {dim})"
            )

    inputs = layers.Input(input_shape)

    def conv_block(x, filters):
        # Two 3x3 ReLU convolutions; 'same' padding keeps the spatial size.
        x = layers.Conv2D(filters, 3, activation="relu", padding="same")(x)
        x = layers.Conv2D(filters, 3, activation="relu", padding="same")(x)
        return x

    def encoder_block(x, filters):
        # Returns the pre-pool features (skip connection) and the 2x-downsampled output.
        f = conv_block(x, filters)
        p = layers.MaxPooling2D((2, 2))(f)
        return f, p

    def decoder_block(x, conv_output, filters):
        # Upsample 2x, concatenate the matching encoder features, then refine.
        x = layers.Conv2DTranspose(filters, (2, 2), strides=(2, 2), padding='same')(x)
        x = layers.concatenate([x, conv_output])
        x = conv_block(x, filters)
        return x

    c1, p1 = encoder_block(inputs, base_filters)
    c2, p2 = encoder_block(p1, base_filters * 2)
    c3, p3 = encoder_block(p2, base_filters * 4)
    bn = conv_block(p3, base_filters * 8)  # bottleneck
    d1 = decoder_block(bn, c3, base_filters * 4)
    d2 = decoder_block(d1, c2, base_filters * 2)
    d3 = decoder_block(d2, c1, base_filters)
    # A 1x1 conv reduces to one channel; the sigmoid gives a per-pixel score in (0, 1).
    outputs = layers.Conv2D(1, (1, 1), activation="sigmoid")(d3)
    return models.Model(inputs, outputs)

In [None]:
# Define DenseNet classifier
from tensorflow.keras.applications import DenseNet121
from tensorflow.keras.layers import GlobalAveragePooling2D, Dense
def build_classifier(input_shape=(64, 64, 3), num_classes=3):
    """Build a DenseNet121-based classifier trained from scratch.

    The DenseNet121 convolutional trunk is built without pretrained weights
    and without its original top. Its feature maps are reduced by global
    average pooling and fed to a softmax dense layer.

    Args:
        input_shape: Per-sample input shape passed to DenseNet121.
        num_classes: Number of output classes (default 3).

    Returns:
        An uncompiled ``keras.Model`` producing class probabilities of shape
        ``(batch, num_classes)``.
    """
    backbone = DenseNet121(
        weights=None,
        include_top=False,
        input_shape=input_shape,
    )
    pooled_features = GlobalAveragePooling2D()(backbone.output)
    class_probabilities = Dense(num_classes, activation="softmax")(pooled_features)
    return models.Model(inputs=backbone.input, outputs=class_probabilities)