# In [None]:
import pandas as pd
import numpy as np
import gradio as gr
import xgboost as xgb
import joblib
import traceback
import matplotlib.pyplot as plt
import io
from PIL import Image

# === 1. LOAD MODEL + PREPROCESSOR + DATA ===
# Everything below runs when this cell/module executes: three artifacts are read
# via relative paths, so they must exist in the current working directory.
general_model = xgb.XGBClassifier()
general_model.load_model("mooa_xgb_model_v4.json")

# Preprocessing object whose transform() output is fed to general_model.predict_proba;
# later code also reads its get_feature_names_out() and feature_names_in_.
# NOTE(review): joblib.load unpickles the file and can execute arbitrary code --
# only load artifacts from a trusted source.
general_preprocessor = joblib.load("mooa_preprocessor_v4.joblib")

# Per-species attributes; rows are looked up by the 'species' column and the
# UI dropdown is populated from that column's unique values.
super_dataset = pd.read_csv("super_dataset.csv")

# Luzon lake baseline environmental data (updated 2025-10-08)
# One row per lake with a single value per metric; build_input_dataframe() copies
# each value into both the matching wb_*_min and wb_*_max feature.
luzon_lakes = pd.DataFrame({
    "Lake Name": [
        "Laguna_de_Bay", "Lake_Taal", "Sampaloc_Lake", "Yambo_Lake",
        "Pandin_Lake", "Mohicap_Lake", "Palakpakin_Lake", "Nabao_Lake",
        "Tadlac_Lake", "Tikub_Lake", "Lake_Buhi", "Lake_Danao", "Bunot_Lake"
    ],
    "pH": [9.12, 8.32, 7.9, 7.9, 7.8, 7.7, 8.0, 6.33, 7.44, 8.08, 7.95, 7.81, 7.2],
    "Salinity (ppt)": [0.746, 0.85, 0.1, 0.1, 0.1, 0.1, 0.1, 0.25, 0.361, 0.1, 0.7, 0.1, 0.1],
    "Dissolved Oxygen (mg/L)": [7.54, 5.61, 3.1, 5.0, 7.3, 4.1, 5.0, 3.14, 7.27, 5.53, 6.89, 7.15, 7.7],
    "BOD (mg/L)": [1.93, 3.82, 8.0, 2.5, 2.0, 6.8, 3.1, 3.0, 2.33, 2.3, 1.76, 2.49, 10.2],
    "Turbidity (NTU)": [161.88, 28.0, 28.0, 9.8, 6.5, 10.0, 28.0, 3.5, 3.5, 3.5, 6.18, 2.25, 9.0],
    # NOTE(review): this key looks like a mis-decoded "Temperature (°C)". The same
    # literal is used as a lookup key in build_input_dataframe() and
    # predict_invasion_risk_for_lakes(), so rename every occurrence together or none.
    "Temperature (¬∞C)": [28.5, 25.5, 27.8, 26.5, 25.8, 26.2, 24.2, 28.0, 29.5, 30.4, 28.5, 29.5, 28.5]
})

# === 2. FEATURE IMPORTANCE ANALYSIS ===
def _figure_to_image(fig, **savefig_kwargs):
    """Render a Matplotlib figure to an in-memory PNG and return it as a loaded PIL image."""
    buf = io.BytesIO()
    fig.savefig(buf, format='png', dpi=100, **savefig_kwargs)
    buf.seek(0)
    img = Image.open(buf)
    # Image.open is lazy; decode now so the returned image does not depend on `buf`.
    img.load()
    return img


def get_feature_importance_plots():
    """Generate feature importance visualizations for ``general_model``.

    Returns:
        A ``(image, table)`` pair for the Gradio outputs:

        - image: a PIL image with two bar charts of the top 20 features by
          XGBoost "gain". The left chart is XGBoost's built-in plot using raw
          booster names. The right chart uses names from the preprocessor's
          ``get_feature_names_out()``.
        - table: a DataFrame of the top 30 features with columns "Feature" and
          "Importance".

        On any failure, the image shows the error text and the table is a
        one-row DataFrame with an "Error" column; the traceback is printed.
    """
    fig = None
    try:
        # Gain importances keyed by the booster's feature names (positional
        # "f0", "f1", ... when the model was trained without column names).
        importance_dict = general_model.get_booster().get_score(importance_type='gain')
        importance_df = pd.DataFrame(
            list(importance_dict.items()),
            columns=['Feature', 'Importance'],
        ).sort_values(by='Importance', ascending=False)

        # Map positional booster names to the preprocessor's output feature names.
        # Names with no mapping are kept as-is rather than becoming NaN labels.
        try:
            feature_names = general_preprocessor.get_feature_names_out()
            feature_map = {f"f{i}": name for i, name in enumerate(feature_names)}
            importance_df["Feature"] = importance_df["Feature"].map(
                lambda feature: feature_map.get(feature, feature)
            )
            mapping_success = True
        except Exception as e:
            print("⚠️ Could not map feature names:", e)
            mapping_success = False

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(18, 8))

        # Left: XGBoost's own importance plot (raw booster feature names).
        xgb.plot_importance(general_model, importance_type='gain', max_num_features=20, ax=ax1)
        ax1.set_title("Top 20 Features (XGBoost Built-in)", fontsize=14, fontweight='bold')

        # Right: same ranking with mapped names; reversed so the largest bar is on top.
        if mapping_success:
            top20 = importance_df.head(20).iloc[::-1]
            positions = range(len(top20))
            ax2.barh(positions, top20["Importance"], color="skyblue")
            ax2.set_yticks(positions)
            ax2.set_yticklabels(top20["Feature"], fontsize=9)
            ax2.set_xlabel("Gain (Average Improvement)", fontsize=11)
            ax2.set_ylabel("Feature", fontsize=11)
            ax2.set_title("Top 20 Features (Mapped Names)", fontsize=14, fontweight='bold')
            ax2.grid(axis="x", linestyle="--", alpha=0.6)
        else:
            ax2.text(0.5, 0.5, "Feature name mapping failed",
                     ha='center', va='center', fontsize=12)
            ax2.set_title("Mapped Names (Unavailable)", fontsize=14)

        fig.tight_layout()
        return _figure_to_image(fig, bbox_inches='tight'), importance_df.head(30)

    except Exception as e:
        traceback.print_exc()
        err_fig, ax = plt.subplots(figsize=(10, 6))
        try:
            ax.text(0.5, 0.5, f"Error generating plots:\n{str(e)}",
                    ha='center', va='center', fontsize=12, color='red')
            ax.axis('off')
            img = _figure_to_image(err_fig)
        finally:
            plt.close(err_fig)
        return img, pd.DataFrame([{"Error": str(e)}])

    finally:
        # Always release the main figure, including when an error occurred after it
        # was created, so repeated clicks do not accumulate open figures.
        if fig is not None:
            plt.close(fig)


# === 3. BUILD INPUT DATAFRAME ===
def build_input_dataframe(species_name, temperature, ph, salinity, do, bod, turbidity):
    """Assemble one model-input row per Luzon lake for the given species.

    Each row merges, in order:
      1. the species' attributes (first matching row of ``super_dataset``),
      2. the lake's name plus ``wb_<metric>_min``/``wb_<metric>_max`` features,
         both set to the lake's single baseline value,
      3. the user's readings as ``input_<metric>`` features.
    Later groups override earlier keys of the same name.

    Raises:
        ValueError: if ``species_name`` does not appear in ``super_dataset['species']``.
    """
    matches = super_dataset.loc[super_dataset['species'] == species_name]
    if len(matches) == 0:
        raise ValueError(f"Species '{species_name}' not found in dataset.")
    biology = matches.iloc[0].to_dict()

    # (feature stem, luzon_lakes column) pairs; the order fixes the output column order.
    lake_sources = (
        ("ph", "pH"),
        ("salinity", "Salinity (ppt)"),
        ("do", "Dissolved Oxygen (mg/L)"),
        ("bod", "BOD (mg/L)"),
        ("turbidity", "Turbidity (NTU)"),
        ("temp", "Temperature (¬∞C)"),
    )
    user_inputs = {
        "input_temp": temperature,
        "input_ph": ph,
        "input_salinity": salinity,
        "input_do": do,
        "input_bod": bod,
        "input_turbidity": turbidity,
    }

    records = []
    for _, lake in luzon_lakes.iterrows():
        lake_features = {"waterbody_name": lake["Lake Name"]}
        for stem, column in lake_sources:
            baseline = lake[column]
            lake_features[f"wb_{stem}_min"] = baseline
            lake_features[f"wb_{stem}_max"] = baseline
        records.append({**biology, **lake_features, **user_inputs})

    return pd.DataFrame(records)


# === 4. PREDICTION FUNCTION ===
def _categorize_risk(score):
    """Map a score to a label: < 0.34 Low, < 0.67 Medium, otherwise High."""
    if score < 0.34:
        return "🟢 Low Risk"
    if score < 0.67:
        return "🟡 Medium Risk"
    return "🔴 High Risk"


def _add_derived_features(input_df):
    """Add the engineered columns the preprocessor expects; mutates and returns *input_df*."""
    input_df["temp_pref_range"] = input_df["temp_pref_max"] - input_df["temp_pref_min"]
    input_df["wb_ph_range"] = input_df["wb_ph_max"] - input_df["wb_ph_min"]
    input_df["wb_temp_range"] = input_df["wb_temp_max"] - input_df["wb_temp_min"]

    # 1 when the user's temperature lies inside the species' preferred range (inclusive).
    input_df["temp_in_pref_range"] = (
        (input_df["input_temp"] >= input_df["temp_pref_min"])
        & (input_df["input_temp"] <= input_df["temp_pref_max"])
    ).astype(int)

    # NOTE(review): despite its name, "fish_ph_pref" is the midpoint of the *lake's*
    # pH (wb_ph_min/wb_ph_max), not a species preference. Kept unchanged so the
    # features match the fitted preprocessor -- confirm against the training code.
    input_df["fish_ph_pref"] = (input_df["wb_ph_min"] + input_df["wb_ph_max"]) / 2
    input_df["ph_difference"] = (input_df["fish_ph_pref"] - input_df["input_ph"]).abs()
    return input_df


def _print_input_debug(input_df, species_name, temperature, ph, salinity, do, bod, turbidity):
    """Print model inputs and any column mismatch against the preprocessor (debug mode)."""
    first = input_df.iloc[0]
    print("\n=== DEBUGGING INFO ===")
    print(f"\nSpecies: {species_name}")
    print(f"User Inputs: temp={temperature}, pH={ph}, sal={salinity}, DO={do}, BOD={bod}, turb={turbidity}")
    print(f"\nSpecies temperature preference: {first['temp_pref_min']}-{first['temp_pref_max']}°C")
    print(f"Temperature in preference range: {bool(first['temp_in_pref_range'])}")
    print(f"\nCreated features: {input_df.columns.tolist()}")
    print(f"\nSample row for {first['waterbody_name']}:")
    print(first[['waterbody_name', 'wb_ph_min', 'wb_temp_min', 'input_temp', 'input_ph', 'fish_ph_pref', 'ph_difference']])

    try:
        expected_features = general_preprocessor.feature_names_in_
    except AttributeError:
        print("\n⚠️ Preprocessor doesn't have feature_names_in_ attribute")
        return
    missing = set(expected_features) - set(input_df.columns)
    extra = set(input_df.columns) - set(expected_features)
    if missing:
        print(f"\n⚠️ WARNING: Missing features: {missing}")
    if extra:
        print(f"\n⚠️ WARNING: Extra features (will be ignored): {extra}")


def _lake_similarities(user_env):
    """Return exp(-d / 10) per lake, where d is the Euclidean distance to *user_env*.

    *user_env* must follow the column order pH, salinity, DO, BOD, turbidity,
    temperature.
    """
    env_cols = ['pH', 'Salinity (ppt)', 'Dissolved Oxygen (mg/L)', 'BOD (mg/L)', 'Turbidity (NTU)', 'Temperature (¬∞C)']
    lake_env = luzon_lakes[env_cols].to_numpy(dtype=float)
    # NOTE(review): metrics are not scaled, so large-range ones (turbidity up to
    # hundreds of NTU) likely dominate the distance.
    distances = np.linalg.norm(lake_env - user_env, axis=1)
    return np.exp(-distances / 10)


def predict_invasion_risk_for_lakes(species_name, temperature, ph, salinity, do, bod, turbidity, debug=False):
    """Predict invasion risk of *species_name* for every lake in ``luzon_lakes``.

    Each lake's raw model probability is multiplied by a similarity weight that
    measures how close the user's readings are to that lake's baseline.

    Args:
        species_name: A value of ``super_dataset['species']``.
        temperature, ph, salinity, do, bod, turbidity: User-supplied water readings.
        debug: When truthy, print intermediate features and scores to stdout.

    Returns:
        A DataFrame with columns "Lake Name", "Raw Score", "Raw Risk Level",
        "Similarity", "Adjusted Score" and "Adjusted Risk Level", sorted by
        "Adjusted Score" descending. A warning row is prepended when every lake's
        similarity is below 0.05.

        On failure, returns a one-row DataFrame whose "Lake Name" is "❌ ERROR"
        and whose "Error" column holds the message; the full traceback is printed
        to stderr.
    """
    try:
        # --- 1. Build base input (one row per lake) + derived features ---
        input_df = build_input_dataframe(species_name, temperature, ph, salinity, do, bod, turbidity)
        input_df = _add_derived_features(input_df)

        if debug:
            _print_input_debug(input_df, species_name, temperature, ph, salinity, do, bod, turbidity)

        # --- 2. Transform + predict ---
        try:
            X_processed = general_preprocessor.transform(input_df)
            y_pred_proba = general_model.predict_proba(X_processed)[:, 1]
        except Exception as transform_error:
            traceback.print_exc()
            error_msg = f"Transform Error: {str(transform_error)}"
            return pd.DataFrame([{"Lake Name": "❌ ERROR", "Error": error_msg}])

        if debug:
            print(f"\nRaw predictions: {y_pred_proba}")

        # --- 3. Lake-specific similarity weights (order must match _lake_similarities) ---
        similarities = _lake_similarities(np.array([ph, salinity, do, bod, turbidity, temperature]))

        if debug:
            print("\nSimilarity weights (how close user input is to each lake):")
            for name, weight in zip(luzon_lakes["Lake Name"], similarities):
                print(f"  {name}: {weight:.3f}")

        adjusted_scores = y_pred_proba * similarities

        result_df = pd.DataFrame({
            "Lake Name": luzon_lakes["Lake Name"],
            "Raw Score": np.round(y_pred_proba, 3),
            "Raw Risk Level": [_categorize_risk(s) for s in y_pred_proba],
            "Similarity": np.round(similarities, 3),
            "Adjusted Score": np.round(adjusted_scores, 3),
            "Adjusted Risk Level": [_categorize_risk(s) for s in adjusted_scores],
        }).sort_values(by="Adjusted Score", ascending=False)

        # --- 4. Warn if input is far from all Luzon lake conditions ---
        if similarities.max() < 0.05:
            warning = "⚠️ WARNING: Your inputs are very different from all Luzon lake baselines. Predictions may be unreliable."
            result_df = pd.concat([
                pd.DataFrame([{
                    "Lake Name": warning,
                    "Raw Score": 0,
                    "Raw Risk Level": "N/A",
                    "Similarity": 0,
                    "Adjusted Score": 0,
                    "Adjusted Risk Level": "N/A",
                }]),
                result_df,
            ], ignore_index=True)

        if debug:
            print("\n=== RISK INTERPRETATION ===")
            print("Thresholds:")
            print("  🟢 Low Risk: 0.00 - 0.33")
            print("  🟡 Medium Risk: 0.34 - 0.66")
            print("  🔴 High Risk: 0.67 - 1.00")
            print("\nTop 3 lakes by adjusted risk:")
            # Skip the warning row first so three real lakes are always listed.
            is_warning = result_df["Lake Name"].astype(str).str.contains("WARNING", regex=False)
            for _, row in result_df[~is_warning].head(3).iterrows():
                print(f"  {row['Lake Name']}: {row['Adjusted Score']:.3f} ({row['Adjusted Risk Level']})")

        return result_df

    except Exception as e:
        # The returned row tells the user to check the console, so actually print it.
        traceback.print_exc()
        error_msg = str(e)
        if len(error_msg) > 200:
            error_msg = error_msg[:200] + "..."
        return pd.DataFrame([{
            "Lake Name": "❌ ERROR",
            "Error": error_msg,
            "Details": "Check console for full traceback",
        }])


# === 5. GRADIO UI ===
def gradio_predict(species_name, temperature, ph, salinity, do, bod, turbidity, enable_debug):
    """Gradio click handler: forward the widget values to predict_invasion_risk_for_lakes.

    Parameter order matches the ``inputs`` list wired to the predict button.
    """
    readings = dict(
        temperature=temperature,
        ph=ph,
        salinity=salinity,
        do=do,
        bod=bod,
        turbidity=turbidity,
    )
    return predict_invasion_risk_for_lakes(species_name, debug=enable_debug, **readings)

# Drop missing species: NaN cannot be sorted alongside strings and never matches a lookup.
species_list = sorted(super_dataset['species'].dropna().unique())

with gr.Blocks(theme="soft") as demo:
    gr.Markdown("# 🐟 Invasive Species Risk Predictor - Luzon Lakes")

    with gr.Tabs():
        # === TAB 1: PREDICTION ===
        with gr.Tab("🔮 Risk Prediction"):
            gr.Markdown("### Adjust the sliders to simulate lake conditions and see invasion risks.")
            gr.Markdown("**Debug mode prints detailed info to console**")

            with gr.Row():
                species_dropdown = gr.Dropdown(
                    label="Species",
                    choices=species_list,
                    # Avoid IndexError when the dataset has no species.
                    value=species_list[0] if species_list else None,
                )
                debug_checkbox = gr.Checkbox(label="Enable Debug Output (check console)", value=False)

            with gr.Row():
                temp_slider = gr.Slider(minimum=0, maximum=40, value=27.0, label="Temperature (°C)")
                ph_slider = gr.Slider(minimum=0, maximum=14, value=7.0, label="pH")
                salinity_slider = gr.Slider(minimum=0, maximum=10, value=0.1, label="Salinity (ppt)")

            with gr.Row():
                do_slider = gr.Slider(minimum=0, maximum=15, value=5.0, label="Dissolved Oxygen (mg/L)")
                bod_slider = gr.Slider(minimum=0, maximum=20, value=5.0, label="BOD (mg/L)")
                turbidity_slider = gr.Slider(minimum=0, maximum=500, value=100, label="Turbidity (NTU)")

            predict_btn = gr.Button("🔮 Predict Invasion Risk", variant="primary")

            gr.Markdown("### Results Explanation:")
            gr.Markdown("""
            - **Raw Score**: Direct XGBoost model prediction (0-1 probability)
            - **Similarity**: How close your inputs match each lake's baseline (1=exact, 0=very different)
            - **Adjusted Score**: Raw prediction × similarity (penalizes dissimilar conditions)
            - **Risk Categories**: 🟢 Low (0-0.33) | 🟡 Medium (0.34-0.66) | 🔴 High (0.67-1.00)
            """)

            output_table = gr.Dataframe()

            # Input order must match gradio_predict's parameter order.
            predict_btn.click(
                fn=gradio_predict,
                inputs=[species_dropdown, temp_slider, ph_slider, salinity_slider, do_slider, bod_slider, turbidity_slider, debug_checkbox],
                outputs=output_table
            )

            # Add test scenarios
            gr.Markdown("### 🧪 Quick Test Scenarios:")
            with gr.Row():
                test1_btn = gr.Button("Test 1: Match Lake Taal")
                test2_btn = gr.Button("Test 2: Extreme Values")
                test3_btn = gr.Button("Test 3: Optimal Conditions")

            # Each scenario returns values in this slider order:
            # temperature, pH, salinity, DO, BOD, turbidity.
            scenario_outputs = [temp_slider, ph_slider, salinity_slider, do_slider, bod_slider, turbidity_slider]

            def test_lake_taal():
                # Same values as the Lake_Taal row of luzon_lakes.
                return 25.5, 8.32, 0.85, 5.61, 3.82, 28.0

            def test_extreme():
                return 5.0, 3.0, 8.0, 2.0, 15.0, 400.0

            def test_optimal():
                return 26.0, 7.2, 0.2, 8.0, 2.0, 10.0

            test1_btn.click(fn=test_lake_taal, outputs=scenario_outputs)
            test2_btn.click(fn=test_extreme, outputs=scenario_outputs)
            test3_btn.click(fn=test_optimal, outputs=scenario_outputs)

        # === TAB 2: FEATURE IMPORTANCE ===
        with gr.Tab("📊 Feature Importance"):
            gr.Markdown("### Understanding What Drives the Model's Predictions")
            gr.Markdown("""
            This analysis shows which features have the most impact on the model's invasion risk predictions.
            **Gain** measures the average improvement in prediction accuracy when using each feature.
            """)

            analyze_btn = gr.Button("📈 Generate Feature Importance Analysis", variant="primary")

            with gr.Row():
                importance_plot = gr.Image(label="Feature Importance Visualization")

            importance_table = gr.Dataframe(label="Top 30 Most Important Features")

            gr.Markdown("""
            **How to interpret:**
            - Higher gain = more important for predictions
            - Features at the top have the strongest influence on risk assessment
            - Compare both visualizations to understand feature naming conventions
            """)

            analyze_btn.click(
                fn=get_feature_importance_plots,
                inputs=[],
                outputs=[importance_plot, importance_table]
            )

demo.launch()