# Maternal Health Risk Prediction - Exploratory Data Analysis

This notebook performs comprehensive exploratory data analysis on the maternal health risk dataset.

## Contents
1. Data Loading
2. Data Quality Assessment
3. Descriptive Statistics
4. Feature Distributions
5. Correlation Analysis
6. Class Balance Analysis
7. Feature Relationships with Target
8. Pair Plot
9. Key Findings Summary


In [None]:
# Import libraries
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
warnings.filterwarnings('ignore')

# Plot appearance: apply the matplotlib theme, then set seaborn's palette
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette('husl')

print("Libraries imported successfully!")


## 1. Data Loading

Load the raw maternal health risk dataset.


In [None]:
# Load the raw dataset into `df`.
# If the CSV is missing, print download instructions and then re-raise, so
# that "Run All" stops at this cell instead of failing later with a
# NameError on `df`.
data_path = '../data/raw/maternal_health.csv'

try:
    df = pd.read_csv(data_path)
except FileNotFoundError:
    print("⚠ Dataset not found!")
    print("Please download the dataset from:")
    print("  - Kaggle: https://www.kaggle.com/datasets/andrewmvd/maternal-health-risk-data")
    print("  - UCI: https://archive.ics.uci.edu/ml/datasets/Maternal+Health+Risk+Data+Set")
    print(f"And place it at: {data_path}")
    raise
else:
    # Only reached when the read succeeded
    print("✓ Data loaded successfully!")
    print(f"  Shape: {df.shape}")
    print(f"  Rows: {df.shape[0]}")
    print(f"  Columns: {df.shape[1]}")


In [None]:
# Preview the first ten rows of the dataset
df.head(n=10)


## 2. Data Quality Assessment

Check for missing values, duplicates, and data types.


In [None]:
# Structural overview of the frame, framed by separator rules
separator = "=" * 60

print("Dataset Information:")
print(separator)
df.info()
print("\n" + separator)


In [None]:
# Per-column count and percentage of missing values
print("Missing Values:")
print("=" * 60)

null_mask = df.isnull()
missing = null_mask.sum()
missing_pct = (null_mask.sum() / len(df)) * 100

missing_df = pd.DataFrame({'Missing Count': missing, 'Percentage': missing_pct})
print(missing_df)

# Report the overall total (a sum of counts, so it is never negative)
total_missing = missing.sum()
if total_missing > 0:
    print(f"\n⚠ Total missing values: {total_missing}")
else:
    print("\n✓ No missing values found!")


In [None]:
# Count rows that exactly repeat an earlier row
duplicates = df.duplicated().sum()
print(f"Duplicate rows: {duplicates}")

if duplicates <= 0:
    print("✓ No duplicate rows found!")
else:
    duplicate_share = duplicates / len(df) * 100
    print(f"⚠ {duplicates} duplicate rows found ({duplicate_share:.2f}%)")


## 3. Descriptive Statistics

Compute summary statistics for all features.


In [None]:
# Summary statistics, rounded to two decimals for display
print("Descriptive Statistics:")
print("=" * 60)

summary_stats = df.describe()
summary_stats.round(2)


In [None]:
# Absolute and relative frequency of each RiskLevel class
print("Target Variable Distribution (RiskLevel):")
print("=" * 60)

risk_column = df['RiskLevel']
risk_counts = risk_column.value_counts()
risk_pct = (risk_column.value_counts(normalize=True) * 100).round(2)

target_df = pd.DataFrame({'Count': risk_counts, 'Percentage': risk_pct})

print(target_df)
print(f"\nTotal samples: {len(df)}")


## 4. Feature Distributions

Visualize the distribution of each feature.


In [None]:
# Histograms for every numerical feature, each with its mean marked.
numerical_cols = df.select_dtypes(include=[np.number]).columns

# Derive the grid from the number of numeric columns rather than assuming
# exactly six: a fixed 2x3 grid raises IndexError with more columns and
# leaves empty panels with fewer. Six columns still give 2x3 at (15, 10).
grid_cols = 3
grid_rows = max(1, int(np.ceil(len(numerical_cols) / grid_cols)))
fig, axes = plt.subplots(grid_rows, grid_cols, figsize=(15, 5 * grid_rows))
axes = axes.ravel()

for ax, col in zip(axes, numerical_cols):
    values = df[col].dropna()  # hist() cannot derive bin edges when NaN is present
    ax.hist(values, bins=30, edgecolor='black', alpha=0.7, color='skyblue')
    ax.set_title(f'Distribution of {col}', fontsize=12, fontweight='bold')
    ax.set_xlabel(col)
    ax.set_ylabel('Frequency')
    ax.grid(True, alpha=0.3)

    # Add mean line
    mean_val = values.mean()
    ax.axvline(mean_val, color='red', linestyle='--', linewidth=2, label=f'Mean: {mean_val:.2f}')
    ax.legend()

# Hide any panels that did not receive a feature
for unused_ax in axes[len(numerical_cols):]:
    unused_ax.set_visible(False)

plt.tight_layout()
plt.show()


In [None]:
# Box plots to identify outliers, one panel per numerical feature.

# Derive the grid from the number of numeric columns rather than assuming
# exactly six: a fixed 2x3 grid raises IndexError with more columns and
# leaves empty panels with fewer. Six columns still give 2x3 at (15, 10).
grid_cols = 3
grid_rows = max(1, int(np.ceil(len(numerical_cols) / grid_cols)))
fig, axes = plt.subplots(grid_rows, grid_cols, figsize=(15, 5 * grid_rows))
axes = axes.ravel()

for ax, col in zip(axes, numerical_cols):
    # Drop NaN so the quartiles are computed from observed values only.
    # Vertical orientation is matplotlib's default, so it is not passed.
    ax.boxplot(df[col].dropna(), patch_artist=True,
               boxprops=dict(facecolor='lightblue', alpha=0.7))
    ax.set_title(f'Box Plot of {col}', fontsize=12, fontweight='bold')
    ax.set_ylabel(col)
    ax.grid(True, alpha=0.3, axis='y')

# Hide any panels that did not receive a feature
for unused_ax in axes[len(numerical_cols):]:
    unused_ax.set_visible(False)

plt.tight_layout()
plt.show()


## 5. Correlation Analysis

Examine relationships between features.


In [None]:
# Pairwise correlation of the numerical columns: heatmap plus top-5 listing
numerical_df = df.select_dtypes(include=[np.number])
correlation = numerical_df.corr()

plt.figure(figsize=(10, 8))
sns.heatmap(
    correlation,
    annot=True,
    fmt='.2f',
    cmap='coolwarm',
    center=0,
    square=True,
    linewidths=1,
    cbar_kws={"shrink": 0.8},
)
plt.title('Feature Correlation Heatmap', fontsize=16, fontweight='bold', pad=20)
plt.tight_layout()
plt.show()

print("\nHighest Correlations (excluding diagonal):")
print("="*60)

# Walk the upper triangle only, so each feature pair is listed once
feature_names = correlation.columns
n_features = len(feature_names)
corr_pairs = [
    (feature_names[row], feature_names[col], correlation.iloc[row, col])
    for row in range(n_features)
    for col in range(row + 1, n_features)
]

# Strongest relationships first, regardless of sign
corr_pairs_sorted = sorted(corr_pairs, key=lambda pair: abs(pair[2]), reverse=True)

for feat1, feat2, corr_val in corr_pairs_sorted[:5]:
    print(f"  {feat1} <-> {feat2}: {corr_val:.3f}")


## 6. Class Balance Analysis

Analyze the distribution of risk levels.


In [None]:
# Risk level distribution: bar chart, pie chart and an imbalance check.
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Tie each colour to its label. value_counts() orders classes by frequency,
# so a positional colour list is only correct when the frequency order
# happens to be low, mid, high.
risk_colors = {'low risk': 'green', 'mid risk': 'orange', 'high risk': 'red'}

risk_counts = df['RiskLevel'].value_counts()
# Known levels in severity order first, then any unexpected labels
known_levels = [level for level in risk_colors if level in risk_counts.index]
other_levels = [level for level in risk_counts.index if level not in risk_colors]
risk_counts = risk_counts.reindex(known_levels + other_levels)
colors = [risk_colors.get(level, 'gray') for level in risk_counts.index]

# Bar plot
bars = axes[0].bar(risk_counts.index, risk_counts.values, color=colors, alpha=0.7, edgecolor='black')
axes[0].set_title('Risk Level Distribution (Count)', fontsize=14, fontweight='bold')
axes[0].set_xlabel('Risk Level')
axes[0].set_ylabel('Count')
axes[0].grid(True, alpha=0.3, axis='y')

# Value labels on the bars; padding is in points, so it is independent of
# the count scale (a fixed data-unit offset is not)
axes[0].bar_label(bars, padding=3, fontweight='bold')

# Pie chart; explode is sized from the data instead of assuming three classes
axes[1].pie(risk_counts.values, labels=risk_counts.index, autopct='%1.1f%%',
            colors=colors, startangle=90, explode=[0.05] * len(risk_counts))
axes[1].set_title('Risk Level Distribution (Percentage)', fontsize=14, fontweight='bold')

plt.tight_layout()
plt.show()

# Check for class imbalance
print("\nClass Balance Analysis:")
print("="*60)
total = len(df)
for risk_level, count in risk_counts.items():
    pct = (count / total) * 100
    print(f"{risk_level}: {count} ({pct:.2f}%)")

# Imbalance ratio: largest class size divided by smallest class size
max_class = risk_counts.max()
min_class = risk_counts.min()
imbalance_ratio = max_class / min_class

print(f"\nImbalance Ratio: {imbalance_ratio:.2f}:1")
if imbalance_ratio > 2:
    print("⚠ Significant class imbalance detected! Consider using:")
    print("  - Class weights")
    print("  - SMOTE or other resampling techniques")
    print("  - Stratified sampling")
else:
    print("✓ Classes are relatively balanced")


## 7. Feature Relationships with Target

Analyze how each feature relates to the risk level.


In [None]:
# Box plots of each numerical feature, grouped by risk level.

# Derive the grid from the number of numeric columns rather than assuming
# exactly six: a fixed 2x3 grid raises IndexError with more columns and
# leaves empty panels with fewer. Six columns still give 2x3 at (16, 10).
grid_cols = 3
grid_rows = max(1, int(np.ceil(len(numerical_cols) / grid_cols)))
fig, axes = plt.subplots(grid_rows, grid_cols, figsize=(16, 5 * grid_rows))
axes = axes.ravel()

for ax, col in zip(axes, numerical_cols):
    df.boxplot(column=col, by='RiskLevel', ax=ax,
               patch_artist=True, grid=True)
    ax.set_title(f'{col} by Risk Level', fontsize=12, fontweight='bold')
    ax.set_xlabel('Risk Level')
    ax.set_ylabel(col)

# Remove the automatic figure-level title added by the grouped boxplot
fig.suptitle('')

# Hide any panels that did not receive a feature
for unused_ax in axes[len(numerical_cols):]:
    unused_ax.set_visible(False)

plt.tight_layout()
plt.show()


In [None]:
# Violin plots of each numerical feature by risk level.

# Derive the grid from the number of numeric columns rather than assuming
# exactly six: a fixed 2x3 grid raises IndexError with more columns and
# leaves empty panels with fewer. Six columns still give 2x3 at (16, 10).
grid_cols = 3
grid_rows = max(1, int(np.ceil(len(numerical_cols) / grid_cols)))
fig, axes = plt.subplots(grid_rows, grid_cols, figsize=(16, 5 * grid_rows))
axes = axes.ravel()

# Explicit x-axis order; the palette list below is matched to it by position
risk_order = ['low risk', 'mid risk', 'high risk']

for ax, col in zip(axes, numerical_cols):
    # NOTE(review): recent seaborn releases appear to deprecate `palette`
    # without `hue` — confirm the installed version before switching to
    # hue='RiskLevel'.
    sns.violinplot(data=df, x='RiskLevel', y=col, order=risk_order,
                   palette=['green', 'orange', 'red'], ax=ax)
    ax.set_title(f'{col} Distribution by Risk Level', fontsize=12, fontweight='bold')
    ax.set_xlabel('Risk Level')
    ax.set_ylabel(col)
    ax.grid(True, alpha=0.3, axis='y')

# Hide any panels that did not receive a feature
for unused_ax in axes[len(numerical_cols):]:
    unused_ax.set_visible(False)

plt.tight_layout()
plt.show()


In [None]:
# Per-class feature means, then the high-minus-low gap for each feature
print("Mean Feature Values by Risk Level:")
print("="*60)

grouped_features = df.groupby('RiskLevel')[numerical_cols]
mean_by_risk = grouped_features.mean().round(2)
print(mean_by_risk)

print("\n\nFeature Differences (High Risk - Low Risk):")
print("="*60)

# The gap is taken on the already-rounded means and only when both classes exist
if {'high risk', 'low risk'} <= set(mean_by_risk.index):
    high_means = mean_by_risk.loc['high risk']
    low_means = mean_by_risk.loc['low risk']
    diff = (high_means - low_means).sort_values(ascending=False)
    print(diff)


## 8. Pair Plot

Comprehensive view of feature relationships colored by risk level.


In [None]:
# Pair plot of a feature subset, coloured by risk level (may take a moment to render)
print("Generating pair plot... (this may take a moment)")

# Select a subset of features for clarity
features_subset = ['Age', 'SystolicBP', 'DiastolicBP', 'BS', 'RiskLevel']

# Map colours by label. A plain colour list is matched to the hue levels in
# the order they first appear in the data, so without an explicit mapping
# "high risk" can end up drawn in green.
risk_palette = {'low risk': 'green', 'mid risk': 'orange', 'high risk': 'red'}

# Check every column used, including the hue column
if all(col in df.columns for col in features_subset):
    present_levels = set(df['RiskLevel'].dropna().unique())
    hue_order = [level for level in risk_palette if level in present_levels]

    pairplot = sns.pairplot(df[features_subset], hue='RiskLevel',
                            hue_order=hue_order, palette=risk_palette,
                            diag_kind='kde', plot_kws={'alpha': 0.6})
    # `.figure` is the supported accessor; `.fig` is deprecated in seaborn
    pairplot.figure.suptitle('Feature Pair Plot by Risk Level', y=1.02, fontsize=16, fontweight='bold')
    plt.show()
    print("✓ Pair plot generated successfully!")
else:
    print("⚠ Some features not found in dataset")


## Key Findings Summary

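The cell below summarizes the findings of the exploratory data analysis: data quality, target distribution, feature characteristics, and modeling recommendations.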


In [None]:
# Final summary of the EDA: data quality, target distribution, per-feature
# statistics and modeling recommendations.
print("="*60)
print("KEY FINDINGS FROM EDA")
print("="*60)

# Compute the quality figures once; they drive both the report and the
# recommendations below.
n_missing = df.isnull().sum().sum()
n_duplicates = df.duplicated().sum()

print("\n1. DATA QUALITY:")
print(f"   - Dataset shape: {df.shape}")
print(f"   - Missing values: {n_missing}")
print(f"   - Duplicate rows: {n_duplicates}")

print("\n2. TARGET DISTRIBUTION:")
risk_dist = df['RiskLevel'].value_counts()
for risk, count in risk_dist.items():
    print(f"   - {risk}: {count} ({count/len(df)*100:.1f}%)")

print("\n3. FEATURE CHARACTERISTICS:")
for col in numerical_cols:
    print(f"   - {col}:")
    print(f"     Mean: {df[col].mean():.2f}, Std: {df[col].std():.2f}")
    print(f"     Range: [{df[col].min():.2f}, {df[col].max():.2f}]")

# Recompute the ratio (largest class / smallest class) here so this cell
# does not rely on a variable left behind by an earlier cell.
imbalance_ratio = risk_dist.max() / risk_dist.min()

print("\n4. RECOMMENDATIONS:")
# Only call the data clean when the checks above actually found nothing;
# otherwise the message would contradict the warnings printed below.
if n_missing == 0 and n_duplicates == 0:
    print("   ✓ Dataset is clean and ready for modeling")
if n_missing > 0:
    print("   ⚠ Handle missing values before modeling")
print("   ✓ Consider feature scaling due to different ranges")
if n_duplicates > 0:
    print("   ⚠ Remove duplicate rows before modeling")
if imbalance_ratio > 2:
    print("   ⚠ Use class weights or resampling for imbalanced classes")
print("   ✓ Stratified splitting recommended for train/val/test")

print("\n" + "="*60)
print("EDA COMPLETE!")
print("="*60)
