
# 🌾 Crop Disease Risk — Experiments (Google Colab)

This notebook runs **four experiments** using your datasets and saves figures + results tables.

**Experiments:**
1. **Exp1**: Random Forest Regression — Weather + Region (one-hot)  
2. **Exp2**: Random Forest Regression — Weather only  
3. **Exp3**: Per-Region models (R² per region table)  
4. **Exp4**: High-Risk Classification (HistGradientBoosting) + **Permutation Importance**  

**Outputs created:**
- `metrics_summary.csv` (summary table)  
- `Exp3_PerRegion_R2.csv` (region-wise R²)  
- `Exp4_PermutationImportance.csv` and `Exp4_PermImp_Top12.png` (permutation importance table and top-12 bar chart for the classifier)  
- SHAP bar plots for Exp1 & Exp2  
- PDP/ICE plots for top features  
- R² and RMSE bar charts for experiment comparison  

---

### 📂 How to provide data
**Option A — Upload files (simple):**
- When prompted, upload:
  - `aggregated_weather_by_region_year.csv`
  - `mean-time-series-wheat-regional-data-september-2024-1.xlsx`

**Option B — Use Google Drive:**
- Mount Drive and set `BASE` to your folder path inside Drive.


## 🔧 Setup & Install

In [None]:

# If running in Colab, install dependencies:
# (Uncomment if needed)
# !pip install pandas numpy scikit-learn shap matplotlib openpyxl

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.ensemble import RandomForestRegressor, HistGradientBoostingClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score, roc_auc_score, f1_score
from sklearn.inspection import permutation_importance, PartialDependenceDisplay
import shap

print("✅ Libraries imported.")


## 📥 Provide Data (Upload or Drive)

In [None]:

from google.colab import files

# === Option A: Upload files ===
# Opens the Colab file picker; the result is used below as a collection of
# uploaded file names.
print("👉 Option A: Click to upload the two files when prompted.")
uploaded = files.upload()

# Expected file names. Each one is kept as-is when it is present in the
# upload; otherwise the first uploaded file with the matching extension is
# used, and if there is none the expected name is left unchanged.
WEATHER_FILE = "aggregated_weather_by_region_year.csv"
DISEASE_FILE = "mean-time-series-wheat-regional-data-september-2024-1.xlsx"

if WEATHER_FILE not in uploaded:
    WEATHER_FILE = next(
        (name for name in uploaded if name.lower().endswith(".csv")),
        WEATHER_FILE,
    )

if DISEASE_FILE not in uploaded:
    DISEASE_FILE = next(
        (name for name in uploaded if name.lower().endswith(".xlsx")),
        DISEASE_FILE,
    )

print("Using files:")
print("  WEATHER_FILE:", WEATHER_FILE)
print("  DISEASE_FILE:", DISEASE_FILE)


### (Optional) Use Google Drive Instead

In [None]:

# If you prefer Drive, uncomment the lines below and run this cell instead of the upload cell above:
# from google.colab import drive
# drive.mount('/content/drive')
# BASE = "/content/drive/MyDrive/YourFolder"  # <-- change to your folder
# WEATHER_FILE = BASE + "/aggregated_weather_by_region_year.csv"
# DISEASE_FILE = BASE + "/mean-time-series-wheat-regional-data-september-2024-1.xlsx"
# print("Using Drive paths:\n ", WEATHER_FILE, "\n ", DISEASE_FILE)


## 📚 Load & Merge Data

In [None]:

TARGET_DISEASE = "Zymoseptoria_tritici"  # change to Yellow_rust etc. if needed
OUTDIR = "./figs_results"
# All figures and CSV tables produced by the notebook are written under OUTDIR.
os.makedirs(OUTDIR, exist_ok=True)

# Read both inputs; the first row of the Excel sheet is skipped before the header.
weather = pd.read_csv(WEATHER_FILE)
disease = pd.read_excel(DISEASE_FILE, skiprows=1)

# Normalise the Excel headers (trim whitespace, spaces -> underscores) so the
# TARGET_DISEASE name can be used directly as a column label.
disease.columns = disease.columns.str.strip().str.replace(" ", "_")

# Reduce the disease table to the join keys plus the chosen target, drop
# incomplete rows, and rename to the lower-case keys used for the merge.
disease = disease.loc[:, ['Region', 'Year', TARGET_DISEASE]].dropna()
disease.columns = ['region', 'year', 'DiseaseSeverity']

# Inner join on (region, year); rows with any remaining missing value are dropped.
df = pd.merge(weather, disease, how='inner', on=['region', 'year']).dropna()

print("Weather shape:", weather.shape)
print("Disease shape:", disease.shape)
print("Merged shape :", df.shape)
df.head()


## 🧰 Helper Functions

In [None]:

def eval_regression(y_true, y_pred):
    """Compute RMSE, MAE and R² for a set of regression predictions.

    Args:
        y_true: Observed target values.
        y_pred: Predicted target values, aligned with ``y_true``.

    Returns:
        dict with the keys ``"RMSE"``, ``"MAE"`` and ``"R2"``.
    """
    # The `squared=False` shortcut of mean_squared_error was deprecated in
    # scikit-learn 1.4 and removed in 1.6, so the root is taken explicitly.
    # For a single target column this gives the same RMSE on every version.
    mse = mean_squared_error(y_true, y_pred)
    return {
        "RMSE": float(np.sqrt(mse)),
        "MAE": mean_absolute_error(y_true, y_pred),
        "R2": r2_score(y_true, y_pred),
    }

def save_metrics_table(results_dict, out_csv):
    """Write one CSV row per experiment and report where it was saved.

    Args:
        results_dict: Mapping of experiment name -> dict of metric values.
        out_csv: Destination path of the CSV file (written without an index).
    """
    # Each record starts with the experiment name, followed by its metrics.
    records = [
        {"Experiment": name, **scores}
        for name, scores in results_dict.items()
    ]
    table = pd.DataFrame(records)
    table.to_csv(out_csv, index=False)
    print("💾 Saved metrics table to:", out_csv)

def plot_metrics_bar(results_dict, out_r2_png):
    """Draw, save and show one bar chart of R² and one of RMSE per experiment.

    Args:
        results_dict: Mapping of experiment name -> metrics dict. Every
            metrics dict must contain the keys ``"R2"`` and ``"RMSE"``.
        out_r2_png: Path of the R² chart. The RMSE chart is saved in the same
            directory, with "R2" in the file name replaced by "RMSE" (or with
            "_RMSE" appended to the stem when the name contains no "R2").
    """
    exps = list(results_dict)
    r2s = [results_dict[k]["R2"] for k in exps]
    rmses = [results_dict[k]["RMSE"] for k in exps]

    # Derive the RMSE path from the file name only: replacing on the full
    # path would also rewrite any directory component containing "R2", and a
    # name without "R2" would make the RMSE chart overwrite the R² chart.
    out_dir, r2_name = os.path.split(out_r2_png)
    if "R2" in r2_name:
        rmse_name = r2_name.replace("R2", "RMSE")
    else:
        stem, ext = os.path.splitext(r2_name)
        rmse_name = f"{stem}_RMSE{ext}"
    out_rmse_png = os.path.join(out_dir, rmse_name)

    charts = (
        ("R² by Experiment", "R²", r2s, out_r2_png),
        ("RMSE by Experiment", "RMSE", rmses, out_rmse_png),
    )
    for title, ylabel, values, png_path in charts:
        fig = plt.figure(figsize=(8, 4))
        plt.bar(exps, values)
        plt.title(title)
        plt.ylabel(ylabel)
        plt.xticks(rotation=20)
        plt.tight_layout()
        plt.savefig(png_path, dpi=200)
        plt.show()
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)

def run_rf_regression_and_plots(X, y, label):
    """Fit a random-forest regressor on a 70/30 split and save explanation plots.

    The model is evaluated on the held-out 30%. A SHAP bar summary and PDP/ICE
    plots for the three most important features are then attempted; a failure
    in either plotting stage is reported as a warning and does not affect the
    returned metrics.

    Args:
        X: Feature table; its column labels are used to name the features.
        y: Target values aligned with ``X``.
        label: Experiment label used in plot titles and output file names.

    Returns:
        The metrics dict produced by ``eval_regression`` on the test split.
    """
    # Fixed seeds keep the split and the forest reproducible between runs.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.3, random_state=42
    )
    model = RandomForestRegressor(random_state=42)
    model.fit(X_train, y_train)
    metrics = eval_regression(y_test, model.predict(X_test))

    # Stage 1: SHAP bar summary computed on the test split.
    try:
        shap_values = shap.TreeExplainer(model).shap_values(X_test)
        shap.summary_plot(shap_values, X_test, plot_type="bar", show=False)
        plt.title(f"SHAP Summary (bar) – {label}")
        plt.tight_layout()
        shap_path = os.path.join(OUTDIR, f"SHAP_bar_{label}.png")
        plt.savefig(shap_path, dpi=200)
        plt.show()
        print("💾 Saved SHAP bar to:", shap_path)
    except Exception as err:
        print(f"[WARN] SHAP plot failed for {label}: {err}")

    # Stage 2: PDP/ICE curves (computed on the training split) for the three
    # features ranked highest by the forest's own importances.
    try:
        ranked = pd.Series(model.feature_importances_, index=X.columns)
        ranked = ranked.sort_values(ascending=False)
        for feature in ranked.index[:3]:
            PartialDependenceDisplay.from_estimator(
                model, X_train, [feature], kind="both"
            )
            plt.suptitle(f"PDP/ICE – {label} – {feature}")
            plt.tight_layout()
            pdp_path = os.path.join(OUTDIR, f"PDP_ICE_{label}_{feature}.png")
            plt.savefig(pdp_path, dpi=200)
            plt.show()
            print("💾 Saved PDP/ICE to:", pdp_path)
    except Exception as err:
        print(f"[WARN] PDP/ICE failed for {label}: {err}")

    return metrics


## 🚀 Run Experiments

### Experiment 1 — Baseline (Weather + Region one-hot)

In [None]:

# Collects the metrics dict of every experiment, keyed by experiment label.
results = {}

E1 = "Exp1_Baseline_Weather+Region"

# Features: every merged column except the target, with `region` one-hot
# encoded (first level dropped).
X1 = pd.get_dummies(
    df.drop(columns=['DiseaseSeverity']),
    columns=['region'],
    drop_first=True,
)
y = df['DiseaseSeverity']

results[E1] = run_rf_regression_and_plots(X1, y, E1)
results[E1]


### Experiment 2 — Weather Only

In [None]:

E2 = "Exp2_WeatherOnly"

# Features: whatever remains after removing the target and the region/year
# identifiers (presumably only the weather variables; confirm against the
# weather file's columns).
X2 = df.drop(['DiseaseSeverity', 'region', 'year'], axis=1)

results[E2] = run_rf_regression_and_plots(X2, y, E2)
results[E2]


### Experiment 3 — Per-Region Models

In [None]:

E3 = "Exp3_PerRegion"
region_scores = []
for reg in sorted(df['region'].unique()):
    sub = df[df['region']==reg]
    # Regions with fewer than 30 rows are skipped rather than fitted.
    if len(sub) < 30:
        continue
    Xr = sub.drop(columns=['DiseaseSeverity','region','year'])
    yr = sub['DiseaseSeverity']
    X_tr, X_te, y_tr, y_te = train_test_split(Xr, yr, test_size=0.3, random_state=42)
    rf = RandomForestRegressor(random_state=42).fit(X_tr, y_tr)
    pred = rf.predict(X_te)
    r2 = r2_score(y_te, pred)
    region_scores.append({"region": reg, "R2": r2})

if not region_scores:
    print("[WARN] No region has at least 30 rows; the per-region R² table is empty.")

# Explicit columns keep the table well-formed when every region was skipped;
# without them the empty frame has no "R2" column and sort_values raises KeyError.
region_df = pd.DataFrame(region_scores, columns=["region", "R2"]).sort_values("R2", ascending=False)
path = os.path.join(OUTDIR, "Exp3_PerRegion_R2.csv")
region_df.to_csv(path, index=False)
print("💾 Saved per-region R² table to:", path)

# Store the mean per-region R² for the summary; RMSE/MAE are not computed here.
results[E3] = {"RMSE": np.nan, "MAE": np.nan, "R2": region_df["R2"].mean() if len(region_df) else np.nan}
region_df.head(10)


### Experiment 4 — High-Risk Classification + Permutation Importance

In [None]:

E4 = "Exp4_HighRisk_Classifier_PermImp"

# Label a row "high risk" when its severity is at or above the 75th percentile
# of the merged data set.
thr = df['DiseaseSeverity'].quantile(0.75)
df_cls = df.assign(HighRisk=(df["DiseaseSeverity"] >= thr).astype(int))

# Features exclude the identifiers, the raw severity and the label itself.
Xc = df_cls.drop(columns=['region','year','DiseaseSeverity','HighRisk'])
yc = df_cls['HighRisk']

# Stratified 70/30 split so both parts keep the same high-risk share.
Xc_tr, Xc_te, yc_tr, yc_te = train_test_split(
    Xc, yc, test_size=0.3, random_state=42, stratify=yc
)
clf = HistGradientBoostingClassifier(random_state=42).fit(Xc_tr, yc_tr)

# Positive-class probability, thresholded at 0.5 for the hard prediction.
proba = clf.predict_proba(Xc_te)[:, 1]
pred = (proba >= 0.5).astype(int)

roc = roc_auc_score(yc_te, proba)
f1 = f1_score(yc_te, pred)
print("ROC-AUC:", roc)
print("F1     :", f1)

# Permutation importance on the test split: 15 shuffles per feature, scored by ROC-AUC.
perm = permutation_importance(
    clf, Xc_te, yc_te, n_repeats=15, random_state=42, scoring='roc_auc'
)
perm_series = pd.Series(perm.importances_mean, index=Xc.columns)
perm_series = perm_series.sort_values(ascending=False)
perm_csv = os.path.join(OUTDIR, "Exp4_PermutationImportance.csv")
perm_series.to_csv(perm_csv)
print("💾 Saved permutation importance CSV to:", perm_csv)

# Bar chart of the 12 highest mean importances.
plt.figure(figsize=(8,5))
perm_series.iloc[:12].plot.bar()
plt.title("Permutation Importance (ROC-AUC) – High-Risk Classifier")
plt.ylabel("Mean Importance (Δ ROC-AUC)")
plt.tight_layout()
perm_png = os.path.join(OUTDIR, "Exp4_PermImp_Top12.png")
plt.savefig(perm_png, dpi=200)
plt.show()
print("💾 Saved permutation importance plot to:", perm_png)

# The summary table/plot only knows RMSE, MAE and R2, so the ROC-AUC is
# stored in the "R2" slot; it is not an R² value.
results[E4] = dict(RMSE=np.nan, MAE=np.nan, R2=roc)


## 💾 Save Metrics Summary + Bar Charts

In [None]:

def plot_metrics_bar(results_dict, out_r2_png):
    """Draw, save and show one bar chart of R² and one of RMSE per experiment.

    Args:
        results_dict: Mapping of experiment name -> metrics dict. Every
            metrics dict must contain the keys ``"R2"`` and ``"RMSE"``.
        out_r2_png: Path of the R² chart. The RMSE chart is saved in the same
            directory, with "R2" in the file name replaced by "RMSE" (or with
            "_RMSE" appended to the stem when the name contains no "R2").
    """
    exps = list(results_dict)
    r2s = [results_dict[k]["R2"] for k in exps]
    rmses = [results_dict[k]["RMSE"] for k in exps]

    # Derive the RMSE path from the file name only: replacing on the full
    # path would also rewrite any directory component containing "R2", and a
    # name without "R2" would make the RMSE chart overwrite the R² chart.
    out_dir, r2_name = os.path.split(out_r2_png)
    if "R2" in r2_name:
        rmse_name = r2_name.replace("R2", "RMSE")
    else:
        stem, ext = os.path.splitext(r2_name)
        rmse_name = f"{stem}_RMSE{ext}"
    out_rmse_png = os.path.join(out_dir, rmse_name)

    charts = (
        ("R² by Experiment", "R²", r2s, out_r2_png),
        ("RMSE by Experiment", "RMSE", rmses, out_rmse_png),
    )
    for title, ylabel, values, png_path in charts:
        fig = plt.figure(figsize=(8, 4))
        plt.bar(exps, values)
        plt.title(title)
        plt.ylabel(ylabel)
        plt.xticks(rotation=20)
        plt.tight_layout()
        plt.savefig(png_path, dpi=200)
        plt.show()
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)

# Write the summary table, then the comparison charts, all under OUTDIR.
metrics_csv = os.path.join(OUTDIR, "metrics_summary.csv")
save_metrics_table(results_dict=results, out_csv=metrics_csv)

plot_metrics_bar(
    results_dict=results,
    out_r2_png=os.path.join(OUTDIR, "R2_by_Experiment.png"),
)
print("✅ Done. See figures and tables in:", OUTDIR)
