In [1]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import nibabel as nib
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.applications import EfficientNetB0
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Dropout, GlobalAveragePooling2D, Input, Concatenate
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, recall_score, roc_auc_score, confusion_matrix

In [2]:
# Parameters
IMG_SIZE = (128, 128)
BATCH_SIZE = 32
DATA_DIR = "./abide2_preprocessed"
CSV_FILE = "./abide2_data.csv" 

# Load CSV
df = pd.read_csv(CSV_FILE, dtype={'SUB_ID': str})

In [3]:
def load_nii_file(file_path):
    try:
        img = nib.load(file_path).get_fdata()
        if len(img.shape) != 2:
            print(f"Warning: {file_path} has unexpected shape {img.shape}")
            return None
        img = tf.image.resize(img[..., np.newaxis], IMG_SIZE).numpy()
        img = (img - img.min()) / (img.max() - img.min() + 1e-10)
        return img
    except Exception as e:
        print(f"Error loading {file_path}: {str(e)}")
        return None


In [6]:
# Load images and labels
images, labels = [], []
for idx, row in df.iterrows():
    sub_id = row["SUB_ID"]
    label = row["DX_GROUP"] - 1
    nii_path = os.path.join(DATA_DIR, sub_id, f"{sub_id}_.nii")
    if os.path.exists(nii_path):
        img = load_nii_file(nii_path)
        if img is not None:
            images.append(img)
            labels.append(label)
    else:
        print(f"File not found: {nii_path}")

images = np.array(images)
labels = np.array(labels)
print(f"Total images loaded: {len(images)}, Total labels: {len(labels)}")

if len(images) == 0 or len(labels) == 0:
    raise ValueError("No images or labels loaded.")

File not found: ./abide2_preprocessed\28682\28682_.nii
File not found: ./abide2_preprocessed\28817\28817_.nii
File not found: ./abide2_preprocessed\29305\29305_.nii
File not found: ./abide2_preprocessed\29327\29327_.nii
Total images loaded: 1110, Total labels: 1110


In [7]:

# Train-validation split
X_train, X_val, y_train, y_val = train_test_split(images, labels, test_size=0.2, random_state=42)
print(f"Training set: {len(X_train)}, Validation set: {len(X_val)}")

# Define EfficientNetB0 Model
base_model = EfficientNetB0(include_top=False, weights='imagenet', input_shape=(128, 128, 3))
base_model.trainable = False

inputs = Input(shape=(128, 128, 1))
x = Concatenate()([inputs, inputs, inputs])  # Grayscale to RGB
x = base_model(x, training=False)
x = GlobalAveragePooling2D()(x)
x = Dropout(0.5)(x)
x = Dense(128, activation='relu')(x)
x = Dropout(0.25)(x)
outputs = Dense(1, activation='sigmoid')(x)

model = Model(inputs, outputs)

model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
model.summary()

Training set: 888, Validation set: 222


In [8]:
# Training
history = model.fit(X_train, y_train,
                    batch_size=BATCH_SIZE,
                    epochs=20,
                    validation_data=(X_val, y_val))

Epoch 1/20


[1m28/28[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m126s[0m 3s/step - accuracy: 0.4971 - loss: 0.7203 - val_accuracy: 0.4685 - val_loss: 0.6959
Epoch 2/20
[1m28/28[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m69s[0m 2s/step - accuracy: 0.5597 - loss: 0.7062 - val_accuracy: 0.5315 - val_loss: 0.7156
Epoch 3/20
[1m28/28[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m34s[0m 680ms/step - accuracy: 0.5094 - loss: 0.7185 - val_accuracy: 0.5315 - val_loss: 0.6963
Epoch 4/20
[1m28/28[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m20s[0m 710ms/step - accuracy: 0.5008 - loss: 0.7074 - val_accuracy: 0.5315 - val_loss: 0.6952
Epoch 5/20
[1m28/28[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m20s[0m 705ms/step - accuracy: 0.5410 - loss: 0.6949 - val_accuracy: 0.4685 - val_loss: 0.7006
Epoch 6/20
[1m28/28[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m19s[0m 663ms/step - accuracy: 0.4719 - loss: 0.7096 - val_accuracy: 0.5315 - val_loss: 0.6912
Epoch 7/20
[1m28/28[0m [32m━━━━━━━━

In [10]:
# Inference
NII_FILE = "./abide2_preprocessed/28681/28681_.nii"
def load_single_image(file_path):
    img = load_nii_file(file_path)
    return np.expand_dims(img, axis=0)

img = load_single_image(NII_FILE)
prediction = model.predict(img)[0][0]
if prediction >= 0.5:
    print(f"Prediction: ASD (Confidence: {prediction:.4f})")
else:
    print(f"Prediction: TD (Confidence: {1 - prediction:.4f})")

# Evaluation
y_pred_val = model.predict(X_val).flatten()
y_pred_binary = (y_pred_val > 0.5).astype(int)

def evaluate_model(y_true, y_pred, y_prob=None, model_name=""):
    if model_name:
        print(f"--- Performance Metrics for {model_name} ---")

    metrics = {}
    accuracy = accuracy_score(y_true, y_pred)
    f1 = f1_score(y_true, y_pred)
    sensitivity = recall_score(y_true, y_pred)
    cm = confusion_matrix(y_true, y_pred)
    tn, fp, fn, tp = cm.ravel()
    specificity = tn / (tn + fp)

    metrics['Accuracy'] = accuracy
    metrics['F1-Score'] = f1
    metrics['Sensitivity'] = sensitivity
    metrics['Specificity'] = specificity

    print(f"Accuracy: {accuracy:.4f}")
    print(f"F1-Score: {f1:.4f}")
    print(f"Sensitivity (TPR): {sensitivity:.4f}")
    print(f"Specificity (TNR): {specificity:.4f}")

    if y_prob is not None:
        try:
            auc = roc_auc_score(y_true, y_prob)
            metrics['AUC-ROC'] = auc
            print(f"AUC-ROC: {auc:.4f}")
        except ValueError:
            print("Warning: AUC-ROC requires both classes present.")
    return metrics

# Call evaluation
evaluation_metrics = evaluate_model(y_val, y_pred_binary, y_pred_val, model_name="EfficientNetB0 ABIDE")
print("\nValidation Metrics:", evaluation_metrics)

[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m14s[0m 14s/step
Prediction: ASD (Confidence: 0.5309)
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m17s[0m 398ms/step
--- Performance Metrics for EfficientNetB0 ABIDE ---
Accuracy: 0.5315
F1-Score: 0.6941
Sensitivity (TPR): 1.0000
Specificity (TNR): 0.0000
AUC-ROC: 0.4639

Validation Metrics: {'Accuracy': 0.5315315315315315, 'F1-Score': 0.6941176470588235, 'Sensitivity': 1.0, 'Specificity': np.float64(0.0), 'AUC-ROC': np.float64(0.4638608213820078)}
