# 🧠 Complete ASD Screening ML Training Pipeline

## SenseAI Project - Machine Learning Model Development

---

### What This Notebook Does:
1. **Binary Classification**: ASD vs Non-ASD
2. **Severity Classification**: Low, Moderate, High cognitive risk
3. **Feature Importance Analysis**: Which markers matter most
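4. **Model Comparison**: Logistic Regression, Random Forest, XGBoost, LightGBM, SVM, and Gradient Boosting for the binary task; Ordinal Regression for severity

### Algorithms Used:
- **Logistic Regression** - Binary classification baseline
- **Random Forest** - Feature importance + non-linear patterns
- **XGBoost** - Gradient-boosted trees, often strong on structured data
- **LightGBM** - Fast gradient boosting
- **Gradient Boosting** - scikit-learn's gradient-boosted trees, for comparison
- **SVM** - Non-linear decision boundaries (RBF kernel)
- **Ordinal Regression** - Severity level prediction (ordered categories)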

---

## 📋 Step 1: Set Up Environment


In [None]:
# ============================================
# STEP 1: Install Required Packages (FIXED VERSION)
# ============================================
# Run this cell first!

!pip install pandas numpy scikit-learn xgboost lightgbm mord matplotlib seaborn joblib imbalanced-learn -q

print("✅ All packages installed successfully!")
print("\n📦 Packages:")
print("  • pandas - Data manipulation")
print("  • numpy - Numerical computing")
print("  • scikit-learn - ML algorithms")
print("  • xgboost - Gradient boosting")
print("  • lightgbm - Fast gradient boosting (NEW)")
print("  • mord - Ordinal regression")
print("  • imbalanced-learn - SMOTE for class imbalance (NEW)")
print("  • matplotlib/seaborn - Visualization")


In [None]:
# ============================================
# STEP 2: Import Libraries (FIXED VERSION)
# ============================================

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split, cross_val_score, StratifiedKFold
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.svm import SVC
from sklearn.metrics import (
    classification_report, confusion_matrix, accuracy_score,
    precision_score, recall_score, f1_score, roc_auc_score, roc_curve
)
import xgboost as xgb
import lightgbm as lgb  # NEW
from mord import LogisticAT  # Ordinal regression (FIXED)
from imblearn.over_sampling import SMOTE  # NEW - for class imbalance
import warnings
warnings.filterwarnings('ignore')

# Set style for plots
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette("husl")

print("✅ All libraries imported successfully!")
print("   • Added: LightGBM, Ordinal Regression (mord), SMOTE")


## 📁 Step 2: Upload Your Dataset

Choose ONE of the following methods to upload your data:


In [None]:
# ============================================
# OPTION A: Direct File Upload (Recommended for first time)
# ============================================
# Run this cell and select your CSV file from your computer

from google.colab import files

print("=" * 70)
print("📤 UPLOAD YOUR DATASET")
print("=" * 70)
print("\n✅ RECOMMENDED DATASET:")
print("   📁 File: improved_merged_dataset.csv")
print("   📊 Rows: 500 (250 ASD + 250 Control)")
print("   🎯 Expected Accuracy: 85-90% (realistic and excellent!)")
print("   📍 Location: SAMPLE_DATASETS/improved_merged_dataset.csv")
print("\n⚠️ ALTERNATIVE (if you don't have the improved one):")
print("   📁 File: merged_complete_dataset.csv")
print("   📊 Rows: 180 (90 ASD + 90 Control)")
print("   ⚠️ May show 95%+ accuracy (overfitting)")
print("\n" + "=" * 70)
print("👉 SELECT: improved_merged_dataset.csv (if available)")
print("=" * 70)

uploaded = files.upload()

print("\n✅ Files uploaded successfully!")
print(f"📁 Uploaded files: {list(uploaded.keys())}")
print("\n⚠️ IMPORTANT: Note the EXACT filename above!")
print("   Example: If you see 'improved_merged_dataset.csv', use that exact name.")
print("   Set 'dataset_filename' in the Step 3 cell to this exact name.")


In [None]:
# ============================================
# OPTION B: Google Drive (For large datasets)
# ============================================
# Uncomment and run this if you prefer Google Drive

# from google.colab import drive
# drive.mount('/content/drive')
# 
# # Set your Google Drive path
# DRIVE_PATH = '/content/drive/MyDrive/SAMPLE_DATASETS/'
# df = pd.read_csv(DRIVE_PATH + 'merged_complete_dataset.csv')
# print(f"✅ Loaded from Google Drive: {len(df)} samples")


## 📊 Step 3: Load and Explore Data


In [None]:
# ============================================
# STEP 3: Load the Dataset
# ============================================

# Load the merged dataset (after upload)
# 
# 📁 WHICH DATASET TO USE:
# 
# ✅ RECOMMENDED: improved_merged_dataset.csv
#    - 500 rows (250 ASD + 250 Control)
#    - Realistic noise and variation
#    - Expected accuracy: 85-90% (realistic and excellent!)
#    - Best for ML training and thesis
#
# ⚠️ ALTERNATIVE: merged_complete_dataset.csv
#    - 180 rows (90 ASD + 90 Control)
#    - May show 95%+ accuracy (overfitting - too perfect)
#    - Use only if you don't have improved_merged_dataset.csv
#
# 🔧 HOW TO CHANGE:
#   1. After uploading in the previous cell, check the filename shown
#   2. Update 'dataset_filename' below to match EXACTLY
#   3. Example: If upload shows 'improved_merged_dataset.csv', use that

# Change this to match your uploaded filename:
dataset_filename = 'improved_merged_dataset.csv'  # ← CHANGE THIS if your file has a different name

# Try to load the dataset
try:
    df = pd.read_csv(dataset_filename)
    print(f"✅ Successfully loaded: {dataset_filename}")
except FileNotFoundError:
    print(f"❌ ERROR: File '{dataset_filename}' not found!")
    print("\n💡 SOLUTIONS:")
    print("   1. Check the filename from the upload output above")
    print("   2. Make sure you uploaded the file in the previous cell")
    print("   3. Update 'dataset_filename' above to match exactly")
    print("\n   Common filenames:")
    print("   - improved_merged_dataset.csv (RECOMMENDED)")
    print("   - merged_complete_dataset.csv (Alternative)")
    raise

print("=" * 60)
print("📊 DATASET OVERVIEW")
print("=" * 60)
print(f"\n📈 Total Samples: {len(df)}")
print(f"📋 Total Features: {len(df.columns)}")

# Show class distribution
print("\n🏷️ Class Distribution:")
if 'asd_label' in df.columns:
    print(f"   ASD (1): {sum(df['asd_label'] == 1)}")
    print(f"   Control (0): {sum(df['asd_label'] == 0)}")

if 'severity_label' in df.columns:
    print("\n📊 Severity Distribution:")
    print(df['severity_label'].value_counts())

# Show age group distribution
if 'age_group' in df.columns:
    print("\n👶 Age Group Distribution:")
    print(df['age_group'].value_counts())

# Display first few rows
print("\n📋 Sample Data:")
df.head()


In [None]:
# ============================================
# STEP 3b: Data Visualization
# ============================================

fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Plot 1: ASD vs Control distribution
if 'asd_label' in df.columns:
    ax1 = axes[0, 0]
    colors = ['#2ecc71', '#e74c3c']
    labels = ['Control (TD)', 'ASD']
    counts = df['asd_label'].value_counts().sort_index()
    ax1.bar(labels, counts.values, color=colors)
    ax1.set_title('🏷️ ASD vs Control Distribution', fontsize=12, fontweight='bold')
    ax1.set_ylabel('Count')
    for i, v in enumerate(counts.values):
        ax1.text(i, v + 0.5, str(v), ha='center', fontweight='bold')

# Plot 2: Age group distribution
if 'age_group' in df.columns:
    ax2 = axes[0, 1]
    age_counts = df['age_group'].value_counts()
    ax2.bar(age_counts.index, age_counts.values, color=['#3498db', '#9b59b6', '#f39c12'])
    ax2.set_title('👶 Age Group Distribution', fontsize=12, fontweight='bold')
    ax2.set_ylabel('Count')

# Plot 3: Severity distribution (ASD only)
if 'severity_label' in df.columns and 'asd_label' in df.columns:
    ax3 = axes[1, 0]
    asd_df = df[df['asd_label'] == 1]
    if len(asd_df) > 0:
        sev_counts = asd_df['severity_label'].value_counts().sort_index()
        colors = ['#27ae60', '#f39c12', '#e74c3c']
        level_names = {1: 'Level 1\n(Mild)', 2: 'Level 2\n(Moderate)', 3: 'Level 3\n(Severe)'}
        # Label each bar from the level actually present, so fewer than three levels still plot
        sev_labels = [level_names.get(lvl, f'Level {lvl}') for lvl in sev_counts.index]
        ax3.bar(sev_labels, sev_counts.values, color=colors[:len(sev_counts)])
        ax3.set_title('📊 ASD Severity Distribution', fontsize=12, fontweight='bold')
        ax3.set_ylabel('Count')

# Plot 4: Risk level distribution
if 'risk_level' in df.columns:
    ax4 = axes[1, 1]
    risk_counts = df['risk_level'].value_counts()
    colors = ['#27ae60', '#f39c12', '#e74c3c']
    ax4.bar(risk_counts.index, risk_counts.values, color=colors[:len(risk_counts)])
    ax4.set_title('⚠️ Risk Level Distribution', fontsize=12, fontweight='bold')
    ax4.set_ylabel('Count')

plt.tight_layout()
plt.show()


## 🔧 Step 4: Feature Engineering

Calculate key ASD markers using clinical equations:


In [None]:
# ============================================
# STEP 4: Feature Engineering - Key ASD Equations (FIXED)
# ============================================

print("🔧 FEATURE ENGINEERING - Key ASD Equations")
print("=" * 60)

# Display the equations being used
equations = """
📐 KEY EQUATIONS FOR ASD DETECTION:

1️⃣ SWITCH COST (Cognitive Flexibility)
   Switch_Cost = RT_PostSwitch - RT_PreSwitch
   ➤ High value (>400ms) indicates cognitive rigidity

2️⃣ PERSEVERATIVE ERROR RATE (Rule Adherence)
   Perseverative_Rate = (Perseverative_Errors / Post_Switch_Trials) × 100
   ➤ High rate (>30%) indicates difficulty adapting to new rules

3️⃣ ACCURACY DROP (Rule Switching Difficulty)
   Accuracy_Drop = ((Pre_Accuracy - Post_Accuracy) / Pre_Accuracy) × 100
   ➤ High drop (>20%) indicates rule-switching problems

4️⃣ INHIBITION ERROR RATE (Impulse Control - Frog Jump)
   Commission_Error_Rate = (Commission_Errors / Total_NoGo_Trials) × 100
   ➤ High rate (>40%) indicates inhibitory control deficit

5️⃣ REACTION TIME VARIABILITY (Attention Consistency)
   RT_Variability = Standard_Deviation(Reaction_Times)
   ➤ High variability (>300ms) indicates attention issues
"""
print(equations)

# ============================================
# CALCULATE DERIVED FEATURES (NEW - FIXED)
# ============================================

print("\n🔧 Calculating derived features...")

# 1. Switch Cost (if DCCS data available)
if 'avg_rt_pre_switch_ms' in df.columns and 'avg_rt_post_switch_correct_ms' in df.columns:
    df['switch_cost_ms'] = df['avg_rt_post_switch_correct_ms'] - df['avg_rt_pre_switch_ms']
    df['switch_cost_ms'] = df['switch_cost_ms'].fillna(0)
    print("   ✅ Added: switch_cost_ms")
else:
    df['switch_cost_ms'] = 0

# 2. Accuracy Drop (if DCCS data available)
if 'pre_switch_accuracy' in df.columns and 'post_switch_accuracy' in df.columns:
    # A zero pre-switch accuracy makes the ratio undefined: map it to NaN, then fill with 0
    pre_acc = df['pre_switch_accuracy'].replace(0, np.nan)
    df['accuracy_drop_percent'] = ((df['pre_switch_accuracy'] - df['post_switch_accuracy']) / pre_acc) * 100
    df['accuracy_drop_percent'] = df['accuracy_drop_percent'].fillna(0)
    print("   ✅ Added: accuracy_drop_percent")
else:
    df['accuracy_drop_percent'] = 0

# 3. Commission Error Rate (if Frog Jump data available)
if 'commission_errors' in df.columns and 'nogo_trials' in df.columns:
    df['commission_error_rate_calc'] = (df['commission_errors'] / df['nogo_trials'].replace(0, np.nan)) * 100
    df['commission_error_rate_calc'] = df['commission_error_rate_calc'].fillna(0)
    print("   ✅ Added: commission_error_rate_calc")
elif 'commission_errors' in df.columns and 'total_trials' in df.columns:
    # Fallback: total_trials includes Go trials, so this rate is lower than the NoGo-based rate
    df['commission_error_rate_calc'] = (df['commission_errors'] / df['total_trials'].replace(0, np.nan)) * 100
    df['commission_error_rate_calc'] = df['commission_error_rate_calc'].fillna(0)
    print("   ✅ Added: commission_error_rate_calc (from total_trials)")
else:
    df['commission_error_rate_calc'] = 0

# 4. Perseverative Error Rate (if DCCS data available)
if 'total_perseverative_errors' in df.columns and 'post_switch_accuracy' in df.columns:
    # Estimate post-switch trials (assuming ~11 trials based on DCCS protocol)
    estimated_post_trials = 11
    df['perseverative_rate_calc'] = (df['total_perseverative_errors'] / estimated_post_trials) * 100
    df['perseverative_rate_calc'] = df['perseverative_rate_calc'].fillna(0)
    print("   ✅ Added: perseverative_rate_calc")
else:
    df['perseverative_rate_calc'] = 0

print("\n✅ Feature engineering complete!")
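
To make the equations above concrete, here is a minimal standalone sketch that applies them to one made-up record. The helper name `derived_markers` and every input value are illustrative assumptions, not values from the dataset; the 11-trial post-switch block mirrors the estimate used in the cell above.

```python
def derived_markers(rt_pre_ms, rt_post_ms, pre_acc, post_acc,
                    commission_errors, nogo_trials,
                    perseverative_errors, post_switch_trials=11):
    """Compute four derived markers for a single record (hypothetical helper)."""
    return {
        "switch_cost_ms": rt_post_ms - rt_pre_ms,
        # Guard the ratios: an empty denominator yields 0 instead of an error
        "accuracy_drop_percent": (pre_acc - post_acc) / pre_acc * 100 if pre_acc else 0.0,
        "commission_error_rate": commission_errors / nogo_trials * 100 if nogo_trials else 0.0,
        "perseverative_rate": perseverative_errors / post_switch_trials * 100,
    }

# Made-up values for illustration only
example = derived_markers(rt_pre_ms=620, rt_post_ms=1080, pre_acc=90.0, post_acc=60.0,
                          commission_errors=5, nogo_trials=12, perseverative_errors=4)
for name, value in example.items():
    print(f"{name}: {value:.1f}")
```

With these inputs every marker lands above the threshold quoted in the equations (460 ms, about 33%, about 42%, and about 36%).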


## 🎯 Step 5: Prepare Data for Training


In [None]:
# ============================================
# STEP 5: Prepare Features for Training (FIXED)
# ============================================

# Comprehensive feature list (all possible features from your games)
all_possible_features = [
    # Demographics
    'age_months',
    'completion_time_sec',
    
    # DCCS Features (Age 5.5-6+)
    'pre_switch_accuracy', 'post_switch_accuracy', 'mixed_block_accuracy',
    'total_perseverative_errors', 'perseverative_error_rate_post_switch',
    'avg_rt_pre_switch_ms', 'avg_rt_post_switch_correct_ms',
    'switch_cost_ms', 'accuracy_drop_percent',  # Derived features
    'number_of_consecutive_perseverations', 'total_rule_switch_errors',
    'longest_streak_correct', 'avg_reaction_time_ms',
    
    # Frog Jump Features (Age 3.5-5)
    'go_accuracy', 'nogo_accuracy', 'overall_accuracy',
    'commission_errors', 'omission_errors',
    'commission_error_rate', 'commission_error_rate_calc',  # Both original and derived
    'omission_error_rate', 'avg_rt_go_ms', 'rt_variability',
    'inhibition_failure_rate', 'anticipatory_responses', 'late_responses',
    'longest_correct_streak', 'longest_error_streak',
    
    # Questionnaire Features (Age 2-3)
    'critical_items_failed', 'critical_items_fail_rate',
    'q1_name_response', 'q4_eye_contact', 'q5_pointing',
    'q7_imitation', 'q9_joint_attention',
    'social_responsiveness_score', 'cognitive_flexibility_score',
    'joint_attention_score', 'social_communication_score',
    'failed_items_total', 'failed_items_rate', 'risk_score',
    
    # Clinical Reflection (Common to all)
    'attention_level', 'engagement_level', 'frustration_tolerance',
    'instruction_following', 'overall_behavior', 'enhanced_risk_score',
    
    # Derived features
    'perseverative_rate_calc',
]

# Filter to only columns that exist in your dataset
available_features = [col for col in all_possible_features if col in df.columns]
print(f"✅ Using {len(available_features)} features:")
for f in available_features:
    print(f"   • {f}")

# Prepare X (features)
X = df[available_features].copy()

# Missing-value handling: median for numeric columns, 0 for anything else.
# ✅ FIXED: Using median instead of 0 prevents distortion (e.g., 0ms RT is impossible)
print("🔧 Handling missing values...")
for col in X.columns:
    if pd.api.types.is_numeric_dtype(X[col]):
        median_val = X[col].median()
        # An all-NaN column has no median; fall back to 0
        X[col] = X[col].fillna(0 if pd.isna(median_val) else median_val)
    else:
        X[col] = X[col].fillna(0)

# Alternative: Use pandas median fill directly (cleaner)
# X = X.fillna(X.median(numeric_only=True)).fillna(0)
print("   ✅ Missing values filled (median for numeric, 0 for others)")

# Prepare labels
y_binary = df['asd_label'].astype(int)  # 0=Control, 1=ASD

# Fix severity label (handle string/numeric mix)
y_severity = df['severity_label'].copy()
if y_severity.dtype == 'object':
    y_severity = y_severity.map({'0': 0, '1': 1, '2': 2, '3': 3, 0: 0, 1: 1, 2: 2, 3: 3})
y_severity = y_severity.fillna(0).astype(int)

print(f"\n📊 Feature Matrix: {X.shape}")
print(f"🏷️ Binary Labels: {dict(y_binary.value_counts())}")
print(f"🏷️ Severity Labels: {dict(y_severity.value_counts())}")
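
The cell above computes fill values over the whole dataset, so rows that later end up in the test split influence the medians. A hedged alternative is to learn the medians from training rows only and reuse them for test rows. The sketch below uses plain Python dictionaries and two hypothetical helpers, `fit_medians` and `apply_medians`; it illustrates the pattern and is not the notebook's own method.

```python
from statistics import median

def fit_medians(rows):
    """Per-column medians over the non-missing (non-None) training values."""
    medians = {}
    for col in rows[0].keys():
        observed = [r[col] for r in rows if r[col] is not None]
        medians[col] = median(observed) if observed else 0.0
    return medians

def apply_medians(rows, medians):
    """Return copies of the rows with None replaced by the stored medians."""
    return [{c: (medians[c] if v is None else v) for c, v in r.items()} for r in rows]

# Invented records for illustration
train = [
    {"avg_rt_go_ms": 520.0, "go_accuracy": 0.90},
    {"avg_rt_go_ms": None,  "go_accuracy": 0.70},
    {"avg_rt_go_ms": 610.0, "go_accuracy": 0.80},
]
test = [{"avg_rt_go_ms": None, "go_accuracy": None}]

medians = fit_medians(train)                 # learned from training rows only
train_filled = apply_medians(train, medians)
test_filled = apply_medians(test, medians)   # test rows reuse the training medians
print(medians)
print(test_filled)
```

In scikit-learn, `SimpleImputer(strategy="median")` fitted on `X_train` follows the same fit-then-apply pattern.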


In [None]:
# ============================================
# STEP 6: Split Data (Train/Test)
# ============================================

X_train, X_test, y_train, y_test = train_test_split(
    X, y_binary, 
    test_size=0.2, 
    random_state=42, 
    stratify=y_binary
)

# Scale features (important for SVM and Logistic Regression)
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)

print("✅ Data Split Complete!")
print(f"   Training: {len(X_train)} samples (ASD: {sum(y_train==1)}, Control: {sum(y_train==0)})")
print(f"   Testing: {len(X_test)} samples (ASD: {sum(y_test==1)}, Control: {sum(y_test==0)})")


## 🤖 Step 6: Train Multiple ML Models

Training six different algorithms to find the best one for ASD detection.


In [None]:
# ============================================
# STEP 7: Train All Models (FIXED - Added LightGBM)
# ============================================

models = {
    'Logistic Regression': LogisticRegression(max_iter=1000, random_state=42),
    'Random Forest': RandomForestClassifier(n_estimators=100, random_state=42),
    'XGBoost': xgb.XGBClassifier(n_estimators=100, random_state=42, eval_metric='logloss', verbosity=0),
    'LightGBM': lgb.LGBMClassifier(n_estimators=100, random_state=42, verbose=-1),  # NEW
    'SVM': SVC(kernel='rbf', probability=True, random_state=42),
    'Gradient Boosting': GradientBoostingClassifier(n_estimators=100, random_state=42)
}

results = {}
trained_models = {}

print("🚀 TRAINING MODELS...")
print("=" * 60)

# Check for class imbalance
class_counts = pd.Series(y_train).value_counts()
print(f"\n📊 Class Distribution (Train):")
print(f"   Control (0): {class_counts.get(0, 0)}")
print(f"   ASD (1): {class_counts.get(1, 0)}")

# Apply SMOTE if imbalanced (minority class < 40% of training samples)
if len(class_counts) == 2:
    minority_ratio = min(class_counts) / len(y_train)
    if minority_ratio < 0.4:
        print(f"\n⚠️ Class imbalance detected (minority: {minority_ratio:.1%})")
        print("   Applying SMOTE to balance classes...")
        smote = SMOTE(random_state=42)
        X_train_scaled, y_train = smote.fit_resample(X_train_scaled, y_train)
        print(f"   ✅ After SMOTE: {len(X_train_scaled)} samples (balanced)")

for name, model in models.items():
    try:
        # Train
        model.fit(X_train_scaled, y_train)
        trained_models[name] = model
        
        # Predict
        y_pred = model.predict(X_test_scaled)
        y_prob = model.predict_proba(X_test_scaled)[:, 1]
        
        # Calculate metrics
        acc = accuracy_score(y_test, y_pred)
        auc = roc_auc_score(y_test, y_prob) if len(np.unique(y_test)) > 1 else 0
        prec = precision_score(y_test, y_pred, zero_division=0)
        rec = recall_score(y_test, y_pred, zero_division=0)
        f1 = f1_score(y_test, y_pred, zero_division=0)
        
        results[name] = {
            'accuracy': acc,
            'auc': auc,
            'precision': prec,
            'recall': rec,
            'f1': f1,
            'predictions': y_pred,
            'probabilities': y_prob
        }
        
        print(f"\n✅ {name}:")
        print(f"   Accuracy: {acc:.2%} | AUC: {auc:.3f} | F1: {f1:.3f}")
    except Exception as e:
        print(f"\n❌ {name}: Error - {str(e)}")

# Best model
if results:
    best_model_name = max(results, key=lambda x: results[x]['accuracy'])
    print(f"\n{'='*60}")
    print(f"🏆 BEST MODEL: {best_model_name}")
    print(f"   Accuracy: {results[best_model_name]['accuracy']:.2%}")
    print(f"   AUC-ROC: {results[best_model_name]['auc']:.3f}")
    print(f"\n⚠️ NOTE: If accuracy >95%, your sample data may be too 'perfect'.")
    print(f"   Real data typically achieves 82-92% accuracy.")
else:
    print("\n❌ No models trained successfully!")
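
SMOTE balances classes by creating synthetic minority samples on the line segment between a minority sample and one of its nearest minority neighbours. The following is a minimal sketch of that interpolation step only, assuming two numeric features and a hypothetical helper `smote_like_sample`; `imblearn`'s `SMOTE` handles the neighbour search and sampling strategy itself.

```python
import math
import random

def smote_like_sample(minority, k=2, rng=None):
    """Create one synthetic point between a random minority sample and one of its k nearest neighbours."""
    rng = rng or random.Random(0)
    base = rng.choice(minority)
    # k nearest minority neighbours of the base point (excluding the base itself)
    others = [p for p in minority if p is not base]
    neighbours = sorted(others, key=lambda p: math.dist(base, p))[:k]
    neighbour = rng.choice(neighbours)
    lam = rng.random()  # interpolation weight in [0, 1)
    synthetic = tuple(b + lam * (n - b) for b, n in zip(base, neighbour))
    return base, neighbour, synthetic

# Invented minority-class points for illustration
minority_points = [(0.0, 0.0), (1.0, 0.0), (0.0, 1.0), (5.0, 5.0)]
base, neighbour, synthetic = smote_like_sample(minority_points, k=2, rng=random.Random(42))
print("base:", base, "neighbour:", neighbour, "synthetic:", synthetic)
```

Because synthetic points are derived from existing samples, resampling belongs after the train/test split, which is where the cell above applies it.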


## 📊 Step 7: Visualize Model Comparison


In [None]:
# ============================================
# STEP 8: Visualize Model Comparison (FIXED - Added ROC Curve)
# ============================================

fig, axes = plt.subplots(1, 3, figsize=(18, 5))

model_names = list(results.keys())
colors = ['#3498db', '#2ecc71', '#e74c3c', '#9b59b6', '#f39c12', '#1abc9c']

# Accuracy comparison
ax1 = axes[0]
accuracies = [results[m]['accuracy'] for m in model_names]
bars = ax1.bar(model_names, accuracies, color=colors[:len(model_names)])
ax1.set_ylabel('Accuracy', fontsize=12)
ax1.set_title('🎯 Model Accuracy Comparison', fontsize=14, fontweight='bold')
ax1.set_ylim([0, 1.1])
ax1.tick_params(axis='x', rotation=45)
for bar, acc in zip(bars, accuracies):
    ax1.text(bar.get_x() + bar.get_width()/2, acc + 0.02, f'{acc:.1%}', 
             ha='center', fontweight='bold', fontsize=10)

# AUC-ROC comparison
ax2 = axes[1]
aucs = [results[m]['auc'] for m in model_names]
bars2 = ax2.bar(model_names, aucs, color=colors[:len(model_names)])
ax2.set_ylabel('AUC-ROC', fontsize=12)
ax2.set_title('📈 Model AUC-ROC Comparison', fontsize=14, fontweight='bold')
ax2.set_ylim([0, 1.1])
ax2.tick_params(axis='x', rotation=45)
for bar, auc in zip(bars2, aucs):
    ax2.text(bar.get_x() + bar.get_width()/2, auc + 0.02, f'{auc:.3f}', 
             ha='center', fontweight='bold', fontsize=10)

# ROC Curve (NEW - Doctors love this!)
ax3 = axes[2]
for name, color in zip(model_names, colors[:len(model_names)]):
    if 'probabilities' in results[name]:
        fpr, tpr, _ = roc_curve(y_test, results[name]['probabilities'])
        ax3.plot(fpr, tpr, label=f"{name} (AUC={results[name]['auc']:.3f})", linewidth=2, color=color)
ax3.plot([0, 1], [0, 1], 'k--', label='Random', linewidth=1)
ax3.set_xlabel('False Positive Rate', fontsize=11)
ax3.set_ylabel('True Positive Rate', fontsize=11)
ax3.set_title('📊 ROC Curves (Binary Classification)', fontsize=14, fontweight='bold')
ax3.legend(loc='lower right', fontsize=9)
ax3.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()
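
AUC-ROC can be read as the probability that a randomly chosen ASD case receives a higher predicted probability than a randomly chosen control. A small sketch that computes it directly from that definition, using made-up scores and a hypothetical helper `pairwise_auc`:

```python
def pairwise_auc(y_true, y_score):
    """AUC as the fraction of (positive, negative) pairs ranked correctly; ties count as half."""
    pos = [s for y, s in zip(y_true, y_score) if y == 1]
    neg = [s for y, s in zip(y_true, y_score) if y == 0]
    if not pos or not neg:
        raise ValueError("AUC needs at least one positive and one negative sample")
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))

# Invented labels and scores for illustration
y_true_demo = [0, 0, 1, 1]
y_score_demo = [0.1, 0.4, 0.35, 0.8]
print(f"AUC = {pairwise_auc(y_true_demo, y_score_demo):.2f}")
```

Three of the four positive/negative pairs are ordered correctly here, so the result is 0.75. This pairwise view is quadratic in the number of samples and is meant for intuition; `roc_auc_score` remains the practical choice.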


## 🔍 Step 8: Feature Importance Analysis

Which cognitive markers are most predictive of ASD?


In [None]:
# ============================================
# STEP 9: Feature Importance
# ============================================

rf_model = trained_models['Random Forest']
importance_df = pd.DataFrame({
    'Feature': available_features,
    'Importance': rf_model.feature_importances_
}).sort_values('Importance', ascending=True)

plt.figure(figsize=(10, 8))
colors = plt.cm.viridis(np.linspace(0.2, 0.8, len(importance_df)))
plt.barh(importance_df['Feature'], importance_df['Importance'], color=colors)
plt.xlabel('Importance Score', fontsize=12)
plt.ylabel('Feature', fontsize=12)
plt.title('📊 Feature Importance for ASD Detection (Random Forest)', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

print("\n🎯 TOP 5 MOST IMPORTANT FEATURES:")
print("=" * 50)
for idx, row in importance_df.tail(5).iloc[::-1].iterrows():
    print(f"   {row['Feature']}: {row['Importance']:.4f}")
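
Impurity-based importances from a Random Forest are computed on the training data and can favour features with many distinct values. As a cross-check, scikit-learn's `permutation_importance` measures how much a held-out score drops when one feature is shuffled. The sketch below runs on synthetic data (an assumption made so that it is self-contained); in this notebook the analogous inputs would be `rf_model`, `X_test_scaled`, and `y_test`.

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.inspection import permutation_importance
from sklearn.model_selection import train_test_split

# Synthetic data: one feature determines the label, the other is pure noise
rng = np.random.default_rng(42)
n = 300
informative = rng.normal(size=n)
noise = rng.normal(size=n)
X_demo = np.column_stack([informative, noise])
y_demo = (informative > 0).astype(int)

X_tr, X_te, y_tr, y_te = train_test_split(
    X_demo, y_demo, test_size=0.3, random_state=42, stratify=y_demo
)
forest = RandomForestClassifier(n_estimators=50, random_state=42).fit(X_tr, y_tr)

# Shuffle each feature on the held-out split and record the drop in accuracy
perm = permutation_importance(forest, X_te, y_te, n_repeats=10, random_state=42)
for name, mean, std in zip(["informative", "noise"], perm.importances_mean, perm.importances_std):
    print(f"{name}: {mean:.3f} +/- {std:.3f}")
```

A feature that ranks high in both views is a more convincing marker than one that appears only in the impurity-based chart.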


## 🔍 Step 9: Confusion Matrix & Classification Report


In [None]:
# ============================================
# STEP 10: Confusion Matrix
# ============================================

best_model = trained_models[best_model_name]
y_pred_best = results[best_model_name]['predictions']

# Plot confusion matrix
fig, ax = plt.subplots(figsize=(8, 6))
cm = confusion_matrix(y_test, y_pred_best)
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
            xticklabels=['Control', 'ASD'],
            yticklabels=['Control', 'ASD'],
            annot_kws={'size': 16})
plt.xlabel('Predicted', fontsize=12)
plt.ylabel('Actual', fontsize=12)
plt.title(f'🔍 Confusion Matrix - {best_model_name}', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

# Classification Report
print(f"\n📋 CLASSIFICATION REPORT ({best_model_name}):")
print("=" * 60)
print(classification_report(y_test, y_pred_best, target_names=['Control', 'ASD']))
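
For a screening tool, sensitivity (recall for the ASD class) and specificity are usually read straight off the confusion matrix. A minimal sketch, assuming the scikit-learn layout `[[TN, FP], [FN, TP]]` for labels `[0, 1]`; the counts are invented for illustration and `screening_metrics` is a hypothetical helper.

```python
def screening_metrics(cm):
    """Derive screening metrics from a 2x2 confusion matrix [[TN, FP], [FN, TP]]."""
    (tn, fp), (fn, tp) = cm
    return {
        "sensitivity": tp / (tp + fn) if (tp + fn) else 0.0,  # share of ASD cases flagged
        "specificity": tn / (tn + fp) if (tn + fp) else 0.0,  # share of controls cleared
        "ppv": tp / (tp + fp) if (tp + fp) else 0.0,          # precision for the ASD class
        "npv": tn / (tn + fn) if (tn + fn) else 0.0,
        "accuracy": (tp + tn) / (tp + tn + fp + fn),
    }

demo_cm = [[45, 5], [10, 40]]  # invented counts: 50 controls, 50 ASD
demo_metrics = screening_metrics(demo_cm)
for name, value in demo_metrics.items():
    print(f"{name}: {value:.2f}")
```

In this invented example accuracy is 85% while one in five ASD cases is missed, which is why recall deserves attention alongside accuracy when choosing the best model.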


## 📊 Step 10: Severity Classification (ASD Risk Levels)

For children diagnosed with ASD, predict the severity level (Level 1, 2, or 3).


In [None]:
# ============================================
# STEP 11: Severity Classification (FIXED - Ordinal Regression)
# ============================================

print("📊 SEVERITY CLASSIFICATION (ORDINAL REGRESSION)")
print("=" * 60)

# Filter ASD children only (severity > 0).
# y_severity is the cleaned integer label from Step 5, so this also works
# when the raw severity_label column holds strings.
asd_df = df[(df['asd_label'] == 1) & (y_severity > 0)].copy()
print(f"ASD samples for severity prediction: {len(asd_df)}")

if len(asd_df) >= 6:  # Need minimum samples
    X_sev = asd_df[available_features].copy()
    
    # Better missing value handling
    for col in X_sev.columns:
        if X_sev[col].dtype in ['float64', 'int64']:
            median_val = X_sev[col].median()
            if pd.isna(median_val):
                X_sev[col] = X_sev[col].fillna(0)
            else:
                X_sev[col] = X_sev[col].fillna(median_val)
        else:
            X_sev[col] = X_sev[col].fillna(0)
    
    # Get severity labels (1, 2, 3 only)
    y_sev = asd_df['severity_label'].copy()
    if y_sev.dtype == 'object':
        y_sev = y_sev.map({'1': 1, '2': 2, '3': 3, 1: 1, 2: 2, 3: 3})
    y_sev = y_sev.fillna(1).astype(int)
    
    # Remove any remaining 0s
    mask = y_sev > 0
    X_sev = X_sev[mask]
    y_sev = y_sev[mask]
    
    print(f"Severity distribution:\n{y_sev.value_counts().sort_index()}")
    
    if len(X_sev) >= 6 and len(y_sev.unique()) >= 2:
        # Split data
        X_train_s, X_test_s, y_train_s, y_test_s = train_test_split(
            X_sev, y_sev, test_size=0.3, random_state=42, stratify=y_sev
        )
        
        # Scale
        scaler_sev = StandardScaler()
        X_train_s_scaled = scaler_sev.fit_transform(X_train_s)
        X_test_s_scaled = scaler_sev.transform(X_test_s)
        
        # ✅ FIXED: Apply SMOTE for severity imbalance (Level 3 is often rare)
        print("\n🔧 Checking severity class balance...")
        sev_class_counts = pd.Series(y_train_s).value_counts()
        print(f"   Severity distribution: {dict(sev_class_counts.sort_index())}")
        
        if len(sev_class_counts) >= 2:
            minority_ratio = min(sev_class_counts) / len(y_train_s)
            # SMOTE needs k_neighbors smaller than the smallest class (default k is 5)
            k_neighbors = min(5, int(min(sev_class_counts)) - 1)
            if minority_ratio < 0.3 and k_neighbors >= 1:  # If any class < 30%
                print(f"   ⚠️ Class imbalance detected (minority: {minority_ratio:.1%})")
                print("   Applying SMOTE to balance severity classes...")
                smote_sev = SMOTE(random_state=42, k_neighbors=k_neighbors)
                X_train_s_scaled, y_train_s = smote_sev.fit_resample(X_train_s_scaled, y_train_s)
                print(f"   ✅ After SMOTE: {len(X_train_s_scaled)} samples (balanced)")
            elif minority_ratio < 0.3:
                print("   ⚠️ Smallest class has a single sample; skipping SMOTE")
        
        # FIXED: Use Ordinal Regression (LogisticAT) instead of Random Forest
        print("\n🔧 Training Ordinal Regression (LogisticAT)...")
        print("   (This treats severity as ordered: Level 1 < Level 2 < Level 3)")
        
        try:
            ordinal_model = LogisticAT(alpha=0.1)
            ordinal_model.fit(X_train_s_scaled, y_train_s)
            
            y_pred_sev = ordinal_model.predict(X_test_s_scaled)
            sev_accuracy = accuracy_score(y_test_s, y_pred_sev)
            
            print(f"\n✅ Severity Classification Accuracy (Ordinal): {sev_accuracy:.2%}")
            
            print("\n📋 Classification Report:")
            sev_levels = sorted(y_sev.unique())
            # Pass labels explicitly so the report still works if a level is absent from the test split
            print(classification_report(y_test_s, y_pred_sev, labels=sev_levels,
                                        target_names=[f'Level {i}' for i in sev_levels],
                                        zero_division=0))
            
            # Save ordinal model
            import joblib
            joblib.dump(ordinal_model, 'severity_ordinal_model.pkl')
            joblib.dump(scaler_sev, 'severity_scaler.pkl')
            print("\n✅ Saved: severity_ordinal_model.pkl, severity_scaler.pkl")
            
        except Exception as e:
            print(f"\n⚠️ Ordinal regression failed: {e}")
            print("   Falling back to Random Forest...")
            rf_severity = RandomForestClassifier(n_estimators=100, random_state=42)
            rf_severity.fit(X_train_s_scaled, y_train_s)
            y_pred_sev = rf_severity.predict(X_test_s_scaled)
            sev_accuracy = accuracy_score(y_test_s, y_pred_sev)
            print(f"✅ Severity Classification Accuracy (RF): {sev_accuracy:.2%}")
    else:
        print("⚠️ Not enough samples or classes for severity classification")
else:
    print("⚠️ Not enough ASD samples for severity classification (need at least 6)")
    print("   Continue collecting data from LRH clinic!")
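
One common way to read an ordinal logistic model with ordered cut-points is the cumulative form P(level ≤ k) = σ(θ_k − score), where the score is a single linear combination of the features. The sketch below shows how that form turns a score into per-level probabilities. The scores and thresholds are invented for illustration rather than fitted values, `ordinal_probabilities` is a hypothetical helper, and the code illustrates the general idea rather than `mord`'s internals.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def ordinal_probabilities(score, thresholds):
    """Per-level probabilities from a linear score and ordered cut-points.

    P(level <= k) = sigmoid(threshold_k - score); each level's probability is
    the difference between consecutive cumulative probabilities.
    """
    cumulative = [sigmoid(t - score) for t in thresholds] + [1.0]
    probs, previous = [], 0.0
    for c in cumulative:
        probs.append(c - previous)
        previous = c
    return probs

thresholds = [-1.0, 1.5]        # invented cut-points: Level 1|2 and Level 2|3
for score in (-2.0, 0.0, 3.0):  # invented linear scores
    probs = ordinal_probabilities(score, thresholds)
    level = 1 + sum(t < score for t in thresholds)  # count cut-points below the score
    print(score, [round(p, 3) for p in probs], "-> Level", level)
```

Because all levels share one score, a higher score can only move a prediction toward a higher level, which is the ordering that a plain multi-class classifier does not enforce.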


## 💾 Step 11: Save Models


In [None]:
# ============================================
# STEP 12: Save Trained Models
# ============================================

import joblib

print("💾 SAVING MODELS...")
print("=" * 60)

# Save best model for ASD detection
joblib.dump(trained_models[best_model_name], 'asd_detection_model.pkl')
print(f"✅ Saved: asd_detection_model.pkl ({best_model_name})")

# Save scaler
joblib.dump(scaler, 'feature_scaler.pkl')
print(f"✅ Saved: feature_scaler.pkl")

# Save all models
for name, model in trained_models.items():
    filename = f"{name.lower().replace(' ', '_')}_model.pkl"
    joblib.dump(model, filename)
    print(f"✅ Saved: {filename}")

print("\n📥 Downloading models to your computer...")

# Download files
from google.colab import files
files.download('asd_detection_model.pkl')
files.download('feature_scaler.pkl')

print("\n✅ Models downloaded successfully!")
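
A saved model and scaler are only usable if new data arrives with the same columns, in the same order, as `available_features`. One hedged option is to store that list next to the `.pkl` files. The sketch below uses the standard-library `json` module, a temporary directory, and a shortened stand-in feature list; the file name `feature_columns.json` is an assumption.

```python
import json
import tempfile
from pathlib import Path

# Shortened stand-in for available_features
feature_columns = ["age_months", "switch_cost_ms", "attention_level"]

with tempfile.TemporaryDirectory() as tmp_dir:
    path = Path(tmp_dir) / "feature_columns.json"
    path.write_text(json.dumps({"features": feature_columns}, indent=2))

    # At prediction time: reload the list before touching the model
    loaded = json.loads(path.read_text())["features"]

# Order an incoming record to match the training columns
incoming = {"attention_level": 2, "age_months": 70}   # switch_cost_ms is missing
row = [incoming.get(name) for name in loaded]         # None marks a value to impute
print(loaded)
print(row)
```

Values marked `None` would then need the same imputation that was used during training before the row is passed to the scaler.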


## 🔮 Step 12: Predict New Child

Test the model with a new child's data.


In [None]:
# ============================================
# STEP 13: Predict New Child (Example)
# ============================================

print("🔮 PREDICT NEW CHILD")
print("=" * 60)

# Example 1: Child with ASD-like features
asd_like_child = {
    'age_months': 70,
    'completion_time_sec': 280,
    'total_score_or_trials': 28,
    'accuracy_overall': 55.0,
    'primary_asd_marker_1': 6,      # High perseverative errors
    'primary_asd_marker_2': 50.0,   # High perseverative rate
    'primary_asd_marker_3': 450,    # High switch cost
    'attention_level': 2,
    'engagement_level': 2,
    'frustration_tolerance': 2,
    'instruction_following': 2,
    'overall_behavior': 2,
    'enhanced_risk_score': 35.0
}

# Example 2: Control-like child
control_like_child = {
    'age_months': 70,
    'completion_time_sec': 190,
    'total_score_or_trials': 28,
    'accuracy_overall': 95.0,
    'primary_asd_marker_1': 0,      # No perseverative errors
    'primary_asd_marker_2': 0.0,    # No perseverative rate
    'primary_asd_marker_3': 90,     # Low switch cost
    'attention_level': 5,
    'engagement_level': 5,
    'frustration_tolerance': 5,
    'instruction_following': 5,
    'overall_behavior': 5,
    'enhanced_risk_score': 92.0
}

def predict_child(child_data, child_name):
    # Align to the training columns. Features the example does not provide are
    # filled with the training-set median so the scaler receives every column.
    child_df = pd.DataFrame([child_data]).reindex(columns=available_features)
    child_df = child_df.fillna(X_train.median())
    
    # Scale and predict
    child_scaled = scaler.transform(child_df)
    prediction = best_model.predict(child_scaled)
    probability = best_model.predict_proba(child_scaled)
    
    print(f"\n📋 {child_name}:")
    print(f"   Screening result: {'🔴 ASD RISK' if prediction[0] == 1 else '🟢 No ASD Concern'}")
    print(f"   Confidence: {max(probability[0]):.1%}")
    print(f"   ASD Probability: {probability[0][1]:.1%}")

predict_child(asd_like_child, "Child A (ASD-like features)")
predict_child(control_like_child, "Child B (Control-like features)")


## 📊 Step 13: Cross-Validation & Final Summary


In [None]:
# ============================================
# STEP 14: Cross-Validation
# ============================================

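print("📊 5-FOLD CROSS VALIDATION")
print("=" * 60)

# Scale inside each fold: a pipeline keeps validation folds out of the scaler's
# statistics and leaves the already-saved `scaler` untouched
from sklearn.pipeline import make_pipeline

cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

# Cross-validation for each model (cross_val_score clones the pipeline per fold)
for name, model in trained_models.items():
    pipeline = make_pipeline(StandardScaler(), model)
    cv_scores = cross_val_score(pipeline, X, y_binary, cv=cv, scoring='accuracy')
    print(f"\n{name}:")
    print(f"   Mean Accuracy: {cv_scores.mean():.2%} (±{cv_scores.std():.2%})")
    print(f"   Folds: {[f'{s:.1%}' for s in cv_scores]}")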
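
For classifiers, `cross_val_score` with an integer `cv` uses stratified folds, so each fold keeps roughly the same ASD/control ratio as the full dataset. A minimal pure-Python sketch of that idea, using a hypothetical helper `stratified_fold_ids` and invented labels (scikit-learn's `StratifiedKFold` is the real implementation):

```python
from collections import Counter, defaultdict

def stratified_fold_ids(labels, n_splits=5):
    """Assign each sample a fold id so that every class is spread evenly across folds."""
    fold_of = [None] * len(labels)
    seen = defaultdict(int)                  # samples of each class assigned so far
    for i, label in enumerate(labels):
        fold_of[i] = seen[label] % n_splits  # deal each class out round-robin
        seen[label] += 1
    return fold_of

demo_labels = [1] * 10 + [0] * 20            # invented: 10 ASD, 20 control
fold_ids = stratified_fold_ids(demo_labels, n_splits=5)
for k in range(5):
    counts = Counter(l for l, f in zip(demo_labels, fold_ids) if f == k)
    print(f"fold {k}: ASD={counts[1]}, control={counts[0]}")
```

Every fold in this example holds two ASD and four control samples, matching the 1:2 ratio of the full label list.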


In [None]:
# ============================================
# 🎉 TRAINING COMPLETE - FINAL SUMMARY
# ============================================

print("\n" + "=" * 60)
print("🎉 ML TRAINING COMPLETE!")
print("=" * 60)

summary = f"""
📊 DATASET SUMMARY:
   • Total Samples: {len(df)}
   • ASD Children: {sum(df['asd_label']==1)}
   • Control Children: {sum(df['asd_label']==0)}
   • Features Used: {len(available_features)}

🏆 BEST MODEL: {best_model_name}
   • Accuracy: {results[best_model_name]['accuracy']:.2%}
   • AUC-ROC: {results[best_model_name]['auc']:.3f}
   • Precision: {results[best_model_name]['precision']:.3f}
   • Recall: {results[best_model_name]['recall']:.3f}
   • F1-Score: {results[best_model_name]['f1']:.3f}

📁 SAVED FILES:
   • asd_detection_model.pkl - Trained {best_model_name}
   • feature_scaler.pkl - StandardScaler for preprocessing

🚀 NEXT STEPS:
   1. Collect more data (target: 100+ ASD, 150+ Control)
   2. Fine-tune hyperparameters for better accuracy
   3. Deploy model to Flutter app via REST API
   4. Continue collecting data from LRH and preschools

📐 KEY ASD MARKERS (from feature importance):
"""
print(summary)

# Show top features
for idx, row in importance_df.tail(3).iloc[::-1].iterrows():
    print(f"   • {row['Feature']}: {row['Importance']:.4f}")

print("\n" + "=" * 60)
print("✅ You can now use the trained model to predict ASD in new children!")
print("=" * 60)
