In [None]:
# ================================================================
# Project: Fetal Health Classification using Neural Networks
# ================================================================
# Author: [Islam Abdulrahim]
# Description:
# This project builds and evaluates a neural network model to classify fetal health conditions
# (Normal, Suspect, Pathological) based on physiological features recorded from cardiotocograms.
#
# Dataset:
# The dataset "fetal_health.csv" contains medical signal measurements used to assess fetal well-being.
# Each row represents a patient record with 21 numerical features and a target variable (fetal_health).
# The model below uses 19 of these features (see the `features` list defined in step 4).
#
# Objective:
# To accurately predict the fetal health status using a deep learning model.
#
# Steps:
# 1. Load and explore the dataset.
# 2. Visualize feature distributions and relationships with the target.
# 3. Preprocess data (scaling, train-test split).
# 4. Build and train a neural network classifier.
# 5. Evaluate the model using accuracy, confusion matrix, and classification report.
# 6. Save the model and scaler, and serve predictions through a Gradio web app.
# ================================================================

In [None]:
# 1. Import Required Libraries
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
# 2. Load Dataset
DATA_PATH = '/content/fetal_health.csv'
df = pd.read_csv(DATA_PATH)
# Fixed random_state so the preview shows the same rows on every re-run.
df.sample(5, random_state=42)


In [None]:
# Display dataset shape and information
# Only a cell's final expression is rendered, so the shape must be printed
# explicitly (a bare `df.shape` here would be silently discarded).
print(f"Shape: {df.shape}")
df.info()
df.isnull().sum()

In [None]:
# 3. Visualization Setup
# Global look for every figure below: seaborn "whitegrid" style with the
# "Set2" palette, and an 8x5-inch default figure size.
sns.set_theme(style="whitegrid", palette="Set2")
plt.rcParams.update({'figure.figsize': (8, 5)})


In [None]:
# 4. Define Feature Columns
# The order below is the column order the scaler and model are fitted on;
# predict_fetal_health in the Gradio cell assembles its input array in this
# same order, so the two must be kept in sync.
# NOTE(review): the header describes 21 features but only 19 are listed here —
# confirm the omitted columns are excluded on purpose.
features = [
    # Baseline and event counts
    'baseline value',
    'accelerations',
    'fetal_movement',
    'uterine_contractions',
    # Decelerations ('prolongued' is kept as spelled in the dataset column)
    'light_decelerations',
    'severe_decelerations',
    'prolongued_decelerations',
    # Variability metrics
    'abnormal_short_term_variability',
    'mean_value_of_short_term_variability',
    'percentage_of_time_with_abnormal_long_term_variability',
    # Histogram shape
    'histogram_min',
    'histogram_max',
    'histogram_number_of_peaks',
    'histogram_number_of_zeroes',
    # Histogram statistics
    'histogram_mode',
    'histogram_mean',
    'histogram_median',
    'histogram_variance',
    'histogram_tendency',
]

In [None]:
# 5. Visualize the Relationship Between Each Feature and the Target Variable
# One box plot per feature, split by the target class.
x_axis_label = 'Fetal Health (1=Normal, 2=Suspect, 3=Pathological)'
for feature_name in features:
    fig, ax = plt.subplots()
    sns.boxplot(x='fetal_health', y=feature_name, data=df, ax=ax)
    ax.set_title(f'{feature_name} vs Fetal Health', fontsize=13)
    ax.set_xlabel(x_axis_label)
    ax.set_ylabel(feature_name)
    fig.tight_layout()
    plt.show()
    print('\n\n')

In [None]:
# 6. Feature and Target Separation
# Fail fast with a readable message if the CSV lacks any expected column.
missing_cols = [c for c in features + ['fetal_health'] if c not in df.columns]
assert not missing_cols, f"Columns missing from dataset: {missing_cols}"

X = df[features]
# Classes are coded 1/2/3 in the data; shift to 0/1/2 for
# sparse_categorical_crossentropy, and cast to int so the labels are
# integer-typed for Keras and sklearn regardless of how the CSV stores them.
y = (df['fetal_health'] - 1).astype(int)


In [None]:
# 7. Feature Scaling
from sklearn.preprocessing import StandardScaler
scaler = StandardScaler()
X = scaler.fit_transform(X)



In [None]:
import joblib
joblib.dump(scaler, '/content/standard_scaler.pkl')
print("✅ Scaler saved successfully.")

In [None]:
# 8. Train-Test Split
from sklearn.model_selection import train_test_split
x_train, x_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)


In [None]:
# 9. Build Neural Network Model
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# Seed Python's `random`, NumPy and TensorFlow so weight initialisation and
# dropout masks repeat across Restart & Run All (GPU kernels may still add
# some nondeterminism).
keras.utils.set_random_seed(42)

model = keras.Sequential([
    # Declare the input explicitly instead of passing the deprecated
    # `input_shape=` argument to the first Dense layer.
    keras.Input(shape=(x_train.shape[1],)),
    layers.Dense(128, activation='relu'),
    layers.Dropout(0.2),
    layers.Dense(64, activation='relu'),
    layers.Dropout(0.2),
    layers.Dense(3, activation='softmax')  # Output layer for 3 fetal health classes
])


In [None]:
# 10. Compile the Model
from tensorflow.keras.optimizers import Adam

# Labels are integer class ids (0, 1, 2), hence the *sparse* categorical
# cross-entropy; accuracy is tracked alongside the loss.
optimizer = Adam(learning_rate=5e-4)
model.compile(optimizer=optimizer,
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

In [None]:
# 11. Set Up Callbacks
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau

# Both callbacks watch validation loss.
# - EarlyStopping: stop after 10 epochs without improvement and roll back to
#   the best weights seen.
# - ReduceLROnPlateau: halve the learning rate after 5 epochs without
#   improvement, never going below 1e-5.
early_stop = EarlyStopping(
    monitor='val_loss',
    patience=10,
    restore_best_weights=True,
)
reduce_lr = ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.5,
    patience=5,
    min_lr=1e-5,
)

In [None]:
# 12. Train the Model
from sklearn.model_selection import train_test_split

# The validation set must not overlap the test set. EarlyStopping with
# restore_best_weights=True keeps the weights with the best val_loss, so
# validating on the full dataset (X, y) would let the test rows steer model
# selection and inflate the test accuracy. Instead, hold out part of the
# training split for validation.
x_fit, x_val, y_fit, y_val = train_test_split(
    x_train, y_train, test_size=0.2, random_state=42, stratify=y_train
)

history = model.fit(
    x_fit, y_fit,
    epochs=100,
    batch_size=32,
    validation_data=(x_val, y_val),
    callbacks=[early_stop, reduce_lr]
)


In [None]:
# 13. Evaluate Model on Test Set
# evaluate() yields [loss, accuracy] because the model was compiled with a
# single 'accuracy' metric.
evaluation = model.evaluate(x_test, y_test)
test_loss, test_acc = evaluation
print(f"\nTest Accuracy: {test_acc*100:.2f}%")

In [None]:
# 14. Training History (accuracy and loss per epoch)
# Two side-by-side panels need more width than the notebook-wide (8, 5)
# default set in the visualization setup cell.
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

for ax, metric, title in [
    (axes[0], 'accuracy', 'Accuracy'),
    (axes[1], 'loss', 'Loss'),
]:
    ax.plot(history.history[metric], label=f'Training {title}', linewidth=2)
    ax.plot(history.history[f'val_{metric}'], label=f'Validation {title}',
            linewidth=2, linestyle='--')
    ax.set_title(f'{title} Over Epochs')
    ax.set_xlabel('Epoch')
    ax.set_ylabel(title)
    ax.legend()
    ax.grid(True)

fig.tight_layout()
plt.show()

In [None]:
# 15. Confusion Matrix
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

# verbose=0 suppresses the Keras progress bar.
y_pred = model.predict(x_test, verbose=0)
y_pred_classes = np.argmax(y_pred, axis=1)

# Pin the label set so the matrix is always 3x3 and lines up with the three
# display labels, even if a class is absent from this test split.
cm = confusion_matrix(y_test, y_pred_classes, labels=[0, 1, 2])
labels = [1, 2, 3]  # shown using the dataset's original 1-based class codes

disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=labels)
disp.plot(cmap='Blues', values_format='d')
plt.title("Confusion Matrix for Fetal Health Classification")
plt.show()

In [None]:
# 16. Classification Report
from sklearn.metrics import classification_report

print("Classification Report:")
# labels= pins the three classes so target_names always lines up with them;
# zero_division=0 reports 0 instead of warning when a class is never predicted.
print(classification_report(
    y_test, y_pred_classes,
    labels=[0, 1, 2],
    target_names=['Health 1', 'Health 2', 'Health 3'],
    zero_division=0,
))


In [None]:
# ======================
# 17. Results Summary
# ======================
# Report measured per-class recall rather than fixed prose, so the summary
# stays truthful whatever this training run actually produced.
class_names = ['Normal', 'Suspect', 'Pathological']
y_true_arr = np.asarray(y_test)
y_pred_arr = np.asarray(y_pred_classes)

print("=============================================")
print("RESULTS SUMMARY")
print("=============================================")
print(f"Final Test Accuracy: {test_acc*100:.2f}%")
for class_idx, name in enumerate(class_names):
    in_class = y_true_arr == class_idx
    if in_class.any():
        recall = np.mean(y_pred_arr[in_class] == class_idx)
        print(f"Health {class_idx + 1} ({name}) recall: {recall*100:.2f}% "
              f"({int(in_class.sum())} test samples)")
    else:
        print(f"Health {class_idx + 1} ({name}): no test samples")
print("=============================================")

In [None]:
# ======================
# 18. Save Trained Model
# ======================
MODEL_PATH = '/content/fetal_health_model.h5'

# The .h5 extension selects Keras' HDF5 format; the Gradio cell below loads
# this same path.
model.save(MODEL_PATH)
print(f"✅ Model saved successfully at '{MODEL_PATH}'")

In [None]:
import gradio as gr
import numpy as np
import joblib
from tensorflow.keras.models import load_model

# ============================
# Restore the trained model and the fitted scaler from disk
# ============================
# NOTE: joblib.load unpickles the file — only load artifacts you produced.
model, scaler = (
    load_model('/content/fetal_health_model.h5'),
    joblib.load('/content/standard_scaler.pkl'),
)

# ============================
# Define prediction function
# ============================
def predict_fetal_health(
    baseline_value, accelerations, fetal_movement, uterine_contractions,
    light_decelerations, severe_decelerations, prolongued_decelerations,
    abnormal_short_term_variability, mean_value_of_short_term_variability,
    percentage_of_time_with_abnormal_long_term_variability,
    histogram_min, histogram_max, histogram_number_of_peaks,
    histogram_number_of_zeroes, histogram_mode, histogram_mean,
    histogram_median, histogram_variance, histogram_tendency
):
    """Classify one CTG record and return a labelled verdict for the UI.

    The 19 arguments are raw (unscaled) feature values, in the same order as
    the training ``features`` list, which is the column order the scaler and
    model were fitted on. Uses the module-level ``scaler`` and ``model``.

    Returns:
        One of the three status strings (Normal / Suspect / Pathological), or
        a warning string if any input field is empty.
    """
    values = [
        baseline_value, accelerations, fetal_movement, uterine_contractions,
        light_decelerations, severe_decelerations, prolongued_decelerations,
        abnormal_short_term_variability, mean_value_of_short_term_variability,
        percentage_of_time_with_abnormal_long_term_variability,
        histogram_min, histogram_max, histogram_number_of_peaks,
        histogram_number_of_zeroes, histogram_mode, histogram_mean,
        histogram_median, histogram_variance, histogram_tendency,
    ]

    # A cleared gr.Number field can arrive as None; without this guard
    # np.array would build an object array and scaler.transform would raise.
    if any(v is None for v in values):
        return "⚠️ Please fill in all input fields."

    input_data = np.array([values], dtype=float)

    # If the scaler was fitted on a DataFrame (as in the scaling cell), it
    # records feature names; passing a frame with those names avoids sklearn's
    # "X does not have valid feature names" warning on every call.
    feature_names = getattr(scaler, "feature_names_in_", None)
    if feature_names is not None:
        import pandas as pd
        input_data = pd.DataFrame(input_data, columns=feature_names)

    scaled_data = scaler.transform(input_data)

    # verbose=0: don't print a Keras progress bar on every button click.
    probabilities = model.predict(scaled_data, verbose=0)
    # The model predicts class indices 0..2; the dataset codes them 1..3.
    predicted_class = int(np.argmax(probabilities, axis=1)[0]) + 1

    messages = {
        1: "🟢 Normal (Healthy Fetus)",
        2: "🟡 Suspect (Needs Monitoring)",
        3: "🔴 Pathological (High Risk)",
    }
    return messages[predicted_class]

# ============================
# Custom CSS for Professional Dark Theme
# ============================
# Passed to gr.Blocks(css=...) in the UI cell below.
# Used by the gr.HTML snippets: .main-header, .info-box, .output-box.
# Attached via elem_classes: .input-section, .btn-primary, .output-text.
# NOTE(review): .section-title, .form-group, .normal-result, .suspect-result
# and .pathological-result are not referenced anywhere in this notebook (the
# prediction is shown as plain Textbox text) — wire them up or drop them.
# NOTE(review): .footer, .tabs, .tab-nav and .tab-button appear to target
# Gradio's internal markup, whose class names may differ between Gradio
# versions — verify they still match.
custom_css = """
/* General Styling */
.gradio-container {
    font-family: 'Segoe UI', 'Roboto', 'Helvetica Neue', sans-serif;
    background: linear-gradient(135deg, #1a1a2e 0%, #16213e 100%);
    color: #e0e0e0;
    max-width: 1200px !important;
    margin: auto;
    padding: 20px;
}

/* Header Styling */
.main-header {
    text-align: center;
    margin-bottom: 30px;
    padding: 20px;
    background: rgba(22, 33, 62, 0.7);
    border-radius: 15px;
    box-shadow: 0 8px 32px rgba(0, 0, 0, 0.3);
    backdrop-filter: blur(10px);
    border: 1px solid rgba(255, 255, 255, 0.1);
}

h1 {
    color: #64ffda;
    font-size: 2.5rem;
    margin-bottom: 10px;
    text-shadow: 0 0 10px rgba(100, 255, 218, 0.5);
}

h2, h3 {
    color: #bb86fc;
    margin-top: 15px;
}

/* Input Sections */
.input-section {
    background: rgba(22, 33, 62, 0.6);
    border-radius: 12px;
    padding: 20px;
    margin-bottom: 20px;
    box-shadow: 0 4px 15px rgba(0, 0, 0, 0.2);
    border: 1px solid rgba(255, 255, 255, 0.05);
}

.section-title {
    color: #03dac6;
    font-size: 1.2rem;
    margin-bottom: 15px;
    padding-bottom: 8px;
    border-bottom: 1px solid rgba(3, 218, 198, 0.3);
}

/* Input Fields */
.form-group {
    margin-bottom: 15px;
}

label {
    color: #e1bee7;
    font-weight: 500;
}

input[type="number"] {
    background: rgba(30, 30, 45, 0.8) !important;
    border: 1px solid rgba(187, 134, 252, 0.3) !important;
    border-radius: 8px !important;
    color: #e0e0e0 !important;
    padding: 10px !important;
    transition: all 0.3s ease !important;
}

input[type="number"]:focus {
    border-color: #bb86fc !important;
    box-shadow: 0 0 0 2px rgba(187, 134, 252, 0.2) !important;
    outline: none !important;
}

/* Button Styling */
.btn-primary {
    background: linear-gradient(45deg, #bb86fc, #03dac6) !important;
    border: none !important;
    border-radius: 8px !important;
    color: #121212 !important;
    font-weight: 600 !important;
    padding: 12px 30px !important;
    transition: all 0.3s ease !important;
    box-shadow: 0 4px 15px rgba(187, 134, 252, 0.3) !important;
}

.btn-primary:hover {
    transform: translateY(-2px) !important;
    box-shadow: 0 6px 20px rgba(187, 134, 252, 0.4) !important;
}

/* Output Box */
.output-box {
    background: rgba(22, 33, 62, 0.7);
    border-radius: 12px;
    padding: 20px;
    margin-top: 20px;
    box-shadow: 0 4px 15px rgba(0, 0, 0, 0.2);
    border: 1px solid rgba(255, 255, 255, 0.05);
    text-align: center;
}

.output-text {
    font-size: 1.5rem;
    font-weight: 600;
    padding: 15px;
    border-radius: 8px;
    margin-top: 10px;
}

.normal-result {
    background: rgba(76, 175, 80, 0.2);
    border: 1px solid rgba(76, 175, 80, 0.5);
    color: #81c784;
}

.suspect-result {
    background: rgba(255, 193, 7, 0.2);
    border: 1px solid rgba(255, 193, 7, 0.5);
    color: #ffd54f;
}

.pathological-result {
    background: rgba(244, 67, 54, 0.2);
    border: 1px solid rgba(244, 67, 54, 0.5);
    color: #e57373;
}

/* Footer */
.footer {
    display: none !important;
}

/* Tabs */
.tabs {
    background: rgba(22, 33, 62, 0.6) !important;
    border-radius: 12px !important;
    padding: 10px !important;
    margin-bottom: 20px !important;
}

.tab-nav {
    border-bottom: 1px solid rgba(187, 134, 252, 0.3) !important;
}

.tab-button {
    color: #bb86fc !important;
    font-weight: 500 !important;
}

.tab-button.selected {
    color: #03dac6 !important;
    border-bottom: 2px solid #03dac6 !important;
}

/* Info Box */
.info-box {
    background: rgba(3, 218, 198, 0.1);
    border-left: 4px solid #03dac6;
    border-radius: 8px;
    padding: 15px;
    margin: 20px 0;
}

/* Responsive Design */
@media (max-width: 768px) {
    .gradio-container {
        padding: 10px;
    }

    h1 {
        font-size: 2rem;
    }

    .input-section {
        padding: 15px;
    }
}
"""

# ============================
# Build the Gradio UI: a header, three tabs of numeric inputs, a predict
# button and a read-only text output, styled by custom_css.
# ============================
# The layout is defined by the order in which components are created inside
# the nested `with` blocks, so statement order here matters.
# NOTE(review): custom_css applies a dark background on top of
# gr.themes.Soft() — verify text/background contrast in the rendered app.
with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as app:
    # Header banner (styled by .main-header in custom_css)
    gr.HTML("""
    <div class="main-header">
        <h1>🩺 Fetal Health Classification System</h1>
        <p style="color: #b0bec5; font-size: 1.1rem;">Advanced AI-powered fetal health assessment tool</p>
    </div>
    """)

    # Usage instructions (styled by .info-box in custom_css)
    gr.HTML("""
    <div class="info-box">
        <p style="margin: 0; color: #e0e0e0;">
            <strong>Instructions:</strong> Enter the fetal monitoring parameters below to get an instant health assessment.
            The system uses a neural network trained on comprehensive fetal health data.
        </p>
    </div>
    """)

    # Inputs are spread over three tabs for layout only. Each component
    # variable is named after the matching predict_fetal_health parameter;
    # the `value=` numbers are just pre-filled defaults shown to the user.
    with gr.Tabs():
        with gr.Tab("Basic Parameters"):
            with gr.Row():
                with gr.Column():
                    gr.Markdown("### 📊 Basic Fetal Monitoring")
                    with gr.Group(elem_classes=["input-section"]):
                        baseline_value = gr.Number(label="Baseline Value (bpm)", value=120)
                        accelerations = gr.Number(label="Accelerations", value=0.0)
                        fetal_movement = gr.Number(label="Fetal Movement", value=0.0)
                        uterine_contractions = gr.Number(label="Uterine Contractions", value=0.0)

                with gr.Column():
                    gr.Markdown("### ⚠️ Decelerations")
                    with gr.Group(elem_classes=["input-section"]):
                        light_decelerations = gr.Number(label="Light Decelerations", value=0.0)
                        severe_decelerations = gr.Number(label="Severe Decelerations", value=0.0)
                        prolongued_decelerations = gr.Number(label="Prolonged Decelerations", value=0.0)

        with gr.Tab("Advanced Parameters"):
            with gr.Row():
                with gr.Column():
                    gr.Markdown("### 📈 Variability Metrics")
                    with gr.Group(elem_classes=["input-section"]):
                        abnormal_short_term_variability = gr.Number(label="Abnormal Short Term Variability", value=50)
                        mean_value_of_short_term_variability = gr.Number(label="Mean Value of Short Term Variability", value=0.5)
                        percentage_of_time_with_abnormal_long_term_variability = gr.Number(label="Percentage of Time with Abnormal Long Term Variability", value=30)

                with gr.Column():
                    gr.Markdown("### 📊 Histogram Parameters")
                    with gr.Group(elem_classes=["input-section"]):
                        histogram_min = gr.Number(label="Histogram Min", value=50)
                        histogram_max = gr.Number(label="Histogram Max", value=180)
                        histogram_number_of_peaks = gr.Number(label="Histogram Number of Peaks", value=2)
                        histogram_number_of_zeroes = gr.Number(label="Histogram Number of Zeroes", value=0)

        with gr.Tab("Statistical Parameters"):
            with gr.Row():
                with gr.Column():
                    gr.Markdown("### 📉 Histogram Statistics")
                    with gr.Group(elem_classes=["input-section"]):
                        histogram_mode = gr.Number(label="Histogram Mode", value=120)
                        histogram_mean = gr.Number(label="Histogram Mean", value=120)
                        histogram_median = gr.Number(label="Histogram Median", value=120)
                        histogram_variance = gr.Number(label="Histogram Variance", value=10)
                        histogram_tendency = gr.Number(label="Histogram Tendency", value=0)

    # Prediction Button (styled by .btn-primary in custom_css)
    predict_btn = gr.Button("🔍 Predict Fetal Health", elem_classes=["btn-primary"], size="lg")

    # Output Section: a static heading box, followed by the read-only textbox
    # that receives the string returned by predict_fetal_health.
    gr.HTML("""
    <div class="output-box">
        <h3 style="color: #03dac6; margin-top: 0;">Prediction Result</h3>
    </div>
    """)

    output = gr.Textbox(label="", elem_classes=["output-text"], interactive=False)

    # Wire the button to the model. Gradio passes the input values
    # positionally, so this list must stay in the same order as
    # predict_fetal_health's parameters (which is also the training feature order).
    predict_btn.click(
        fn=predict_fetal_health,
        inputs=[
            baseline_value, accelerations, fetal_movement, uterine_contractions,
            light_decelerations, severe_decelerations, prolongued_decelerations,
            abnormal_short_term_variability, mean_value_of_short_term_variability,
            percentage_of_time_with_abnormal_long_term_variability,
            histogram_min, histogram_max, histogram_number_of_peaks,
            histogram_number_of_zeroes, histogram_mode, histogram_mean,
            histogram_median, histogram_variance, histogram_tendency
        ],
        outputs=output
    )

    # Footer with the medical disclaimer
    gr.HTML("""
    <div style="text-align: center; margin-top: 30px; color: #757575; font-size: 0.9rem;">
        <p>Fetal Health Classification System | Powered by Deep Learning</p>
        <p style="margin-top: 5px;">This tool is for informational purposes only and should not replace professional medical advice.</p>
    </div>
    """)

# Launch the app
app.launch()