In [None]:
"""
toolbox.py (FINAL with robust Grad-CAM)

Usage (examples):
    python toolbox.py audit
    python toolbox.py extract_embeddings --images_dir sample_images --out embeddings.csv
    python toolbox.py runserver --host 127.0.0.1 --port 8000
    python toolbox.py test_request --image sample_images/example.jpg --host 127.0.0.1 --port 8000

Place your models in folder:
    models/Cnn_model.h5
    models/xgboost_soil_health_model.pkl

Notes:
- This version expects XGBoost trained on soil-only features listed in SOIL_FEATURES.
- Soil form keys expected by the API are cleaned (spaces -> '_', '%' -> '_pct', '/' -> '_'):
  PH, EC_ds_m, OC_pct, N_kg_hectre, P_kg_hectre, K_kg_hectre, S_ppm, Zn_ppm, B_ppm, Fe_ppm, Mn_ppm, Cu_ppm
"""

import sys, os, io, base64, argparse, json
from pathlib import Path

# ------- helper to ensure packages -------
def ensure_packages(packages, import_names=None):
    """
    Make sure every pip distribution in ``packages`` is importable, installing the rest.

    A pip distribution name is not always the name you import (``Pillow`` is
    imported as ``PIL``, ``opencv_python`` as ``cv2``). Probing the distribution
    name itself can never succeed for those, which would trigger a
    ``pip install`` on every run, so each name is first translated to its
    candidate import name(s).

    Args:
        packages: iterable of pip distribution names.
        import_names: optional ``{distribution: module}`` or
            ``{distribution: (module, ...)}`` mapping that extends or overrides
            the built-in aliases below.

    Raises:
        subprocess.CalledProcessError: if ``pip install`` exits with a non-zero status.
    """
    import importlib, subprocess, sys

    # distribution name -> import names; one importable candidate is enough.
    aliases = {
        "Pillow": ("PIL",),
        "pillow": ("PIL",),
        "opencv_python": ("cv2",),
        "opencv-python": ("cv2",),
        "python-multipart": ("python_multipart", "multipart"),
        "python_multipart": ("python_multipart", "multipart"),
        "scikit-learn": ("sklearn",),
    }
    for dist, mods in (import_names or {}).items():
        aliases[dist] = (mods,) if isinstance(mods, str) else tuple(mods)

    def is_importable(dist):
        # Unknown distributions fall back to their own name with '-' mapped to '_'.
        for mod in aliases.get(dist) or (dist.replace("-", "_"),):
            try:
                importlib.import_module(mod)
                return True
            except Exception:
                continue
        return False

    missing = [pkg for pkg in packages if not is_importable(pkg)]
    if not missing:
        return
    print("Installing missing packages:", missing)
    subprocess.check_call([sys.executable, "-m", "pip", "install", *missing])

# Third-party requirements shared by all CLI modes. Every name is handed to
# ensure_packages, which pip-installs whatever it fails to import.
_REQUIRED_PACKAGES = (
    "tensorflow numpy Pillow opencv_python joblib fastapi "
    "uvicorn xgboost shap requests pandas python-multipart"
)
ensure_packages(_REQUIRED_PACKAGES.split())

# heavy imports
import numpy as np
from PIL import Image
import cv2
import joblib
import tensorflow as tf
from tensorflow.keras.models import load_model, Model
from fastapi import FastAPI, File, UploadFile, Request
from fastapi.responses import JSONResponse
import uvicorn
import shap
import xgboost
import requests

# ---------------- Config ----------------
# Model files are resolved relative to the current working directory.
MODELS_DIR = Path("models")
CNN_PATH = MODELS_DIR.joinpath("Cnn_model.h5")                     # edit if your CNN file has another name
XGB_PATH = MODELS_DIR.joinpath("xgboost_soil_health_model.pkl")    # edit if your XGBoost file has another name

# (height, width) fallback used when the CNN's input_shape leaves a dimension undefined.
DEFAULT_IMG_SIZE = (224, 224)
# Human-readable label for each CNN output index.
CLASS_LABELS = ["healthy", "diseased"]

# Soil feature names, in the exact column order of the XGBoost booster.
SOIL_FEATURES = [
    "PH",
    "EC ds/m",
    "OC %",
    "N kg/hectre",
    "P kg/hectre",
    "K kg/hectre",
    "S ppm",
    "Zn ppm",
    "B ppm",
    "Fe ppm",
    "Mn ppm",
    "Cu ppm",
]

# cleaned keys mapping for multipart/form-data usage
def clean_key(s):
    """
    Turn a soil feature label into a form-field-friendly key.

    Spaces and '/' become '_', '%' becomes '_pct', and any run of underscores
    produced by adjacent separators is collapsed to a single one, so that
    "OC %" yields "OC_pct" rather than "OC__pct".

    Args:
        s: original feature label, e.g. "EC ds/m".

    Returns:
        The cleaned key, e.g. "EC_ds_m".
    """
    key = s.replace("%", "_pct").replace(" ", "_").replace("/", "_")
    # " %" expands to "__pct"; squeeze repeated underscores to one.
    while "__" in key:
        key = key.replace("__", "_")
    return key

# Cleaned form-field key for each soil feature, index-aligned with SOIL_FEATURES.
SOIL_KEYS = list(map(clean_key, SOIL_FEATURES))

# ----------------------------------------

def load_cnn():
    """
    Load the Keras CNN stored at CNN_PATH.

    Returns:
        The model object returned by ``load_model``.

    Raises:
        FileNotFoundError: if CNN_PATH does not exist.
    """
    if CNN_PATH.exists():
        print("Loading CNN from", CNN_PATH)
        return load_model(str(CNN_PATH))
    raise FileNotFoundError(f"Could not find {CNN_PATH}")

def load_xgb():
    """
    Load the XGBoost model stored at XGB_PATH via joblib.

    Note: joblib.load unpickles the file, so only point XGB_PATH at files you trust.

    Returns:
        The object returned by ``joblib.load``.

    Raises:
        FileNotFoundError: if XGB_PATH does not exist.
    """
    if XGB_PATH.exists():
        print("Loading XGBoost from", XGB_PATH)
        return joblib.load(str(XGB_PATH))
    raise FileNotFoundError(f"Could not find {XGB_PATH}")

# ---------------- Audit ----------------
def audit_models() -> None:
    """
    Print a diagnostic report about the CNN and XGBoost model files.

    Both models are loaded through load_cnn() / load_xgb(); a load failure is
    printed and that model's section is skipped instead of aborting the audit.
    Nothing is returned - all findings go to stdout.

    CNN section: model summary, input shape, the last layer whose name contains
    "conv", the layer named "global_average_pooling2d" (if any), the
    penultimate layer name, and the embedding shape obtained by running a
    zero-filled dummy image through the global-pooling output.

    XGBoost section: booster feature names, whether predict_proba exists, and
    the result of probing predict() with a few guessed input widths.
    """
    print("=== MODEL AUDIT ===")
    # CNN audit
    try:
        cnn = load_cnn()
    except Exception as e:
        # Report and continue so the XGBoost section below can still run.
        print("Error loading cnn:", e)
        cnn = None

    if cnn is not None:
        print("\n--- CNN Summary ---")
        cnn.summary()
        inp_shape = cnn.input_shape
        print(f"\nCNN input_shape: {inp_shape}")
        # last conv layer: first match when walking the layers back-to-front,
        # selected purely by "conv" appearing in the layer name
        last_conv = None
        for layer in reversed(cnn.layers):
            if "conv" in layer.name.lower():
                last_conv = layer.name
                break
        print("Last conv layer (for Grad-CAM):", last_conv)
        # try to inspect global pooling (looked up by this exact layer name only)
        try:
            gap = cnn.get_layer("global_average_pooling2d")
            print("GlobalAveragePooling2D layer found:", gap.output_shape)
        except Exception:
            # Reached when the lookup OR the output_shape access raises.
            print("No named GlobalAveragePooling2D layer found by that name.")
        # penultimate layer
        penultimate = None
        if len(cnn.layers) >= 2:
            penultimate = cnn.layers[-2].name
            print("Penultimate layer:", penultimate)
            # get embedding size with dummy
            try:
                embed_model = Model(inputs=cnn.input, outputs=cnn.get_layer("global_average_pooling2d").output)
                # Dimensions left undefined in input_shape fall back to defaults.
                H = inp_shape[1] or DEFAULT_IMG_SIZE[0]
                W = inp_shape[2] or DEFAULT_IMG_SIZE[1]
                C = inp_shape[3] or 3
                dummy = np.zeros((1, H, W, C), dtype="float32")
                emb = embed_model.predict(dummy)
                # Flattened size of everything after the batch axis.
                emb_dim = int(np.prod(emb.shape[1:]))
                print("Embedding shape (from global_average_pooling2d):", emb.shape, "flattened dim:", emb_dim)
            except Exception as e:
                print("Could not compute embedding on dummy input with global pool:", e)
        else:
            print("CNN too shallow to determine penultimate layer.")

    # XGBoost audit
    try:
        xgb = load_xgb()
    except Exception as e:
        print("Error loading xgboost:", e)
        xgb = None

    if xgb is not None:
        print("\n--- XGBoost Info ---")
        try:
            booster = xgb.get_booster()
            f_names = booster.feature_names
            print("Booster feature names:", f_names)
            if f_names:
                print("Number of features (from booster):", len(f_names))
        except Exception as e:
            print("Could not read booster.feature_names:", e)
        print("Supports predict_proba:", hasattr(xgb, "predict_proba"))
        # try probing with guesses (we won't expect success because booster has names)
        # NOTE(review): none of these widths equals len(SOIL_FEATURES); confirm whether
        # the soil-only feature count should be probed as well.
        guesses = [64, 128, 256, 512, 768, 1024]
        probed = None
        for g in guesses:
            try:
                # A width counts as accepted when predict() does not raise.
                _ = xgb.predict(np.zeros((1, g)))
                probed = g
                break
            except Exception:
                continue
        if probed:
            print("XGBoost accepted input length (probing):", probed)
        else:
            print("Could not infer XGBoost feature length by probing. Check training code or booster feature names.")

    print("\n=== AUDIT COMPLETE ===\n")

# ------------- Embeddings -------------
def extract_embeddings(images_dir: str, out_csv: str):
    """
    Run every image in ``images_dir`` through the CNN and save its embedding to a CSV.

    The embedding is the output of the layer named "global_average_pooling2d";
    if that model cannot be built, the output of ``cnn.layers[-2]`` is used.
    Only .jpg/.jpeg/.png files are processed, in sorted path order so the CSV
    row order is reproducible. Images that fail to load or predict are reported
    and skipped.

    Args:
        images_dir: directory that is scanned (non-recursively) for images.
        out_csv: destination CSV path; columns are "image", "e0", "e1", ...

    Raises:
        FileNotFoundError: propagated from load_cnn() when the model file is missing.
        RuntimeError: if neither embedding model can be constructed.
    """
    cnn = load_cnn()
    inp_shape = cnn.input_shape
    # Dimensions left undefined in input_shape fall back to the default size.
    H = inp_shape[1] or DEFAULT_IMG_SIZE[0]
    W = inp_shape[2] or DEFAULT_IMG_SIZE[1]
    # use global average pooling2d as embed model
    try:
        embed_model = Model(inputs=cnn.input, outputs=cnn.get_layer("global_average_pooling2d").output)
        print("Using global_average_pooling2d for embeddings.")
    except Exception as e:
        print("Error creating embed model from global pool:", e)
        # fallback to penultimate
        try:
            embed_model = Model(inputs=cnn.input, outputs=cnn.layers[-2].output)
            print("Fallback: using layers[-2] for embeddings.")
        except Exception as e2:
            print("Could not create embedding extractor model. Inspect your CNN architecture.")
            # Chain the cause so the underlying Keras error stays visible.
            raise RuntimeError("No embedding extractor available.") from e2

    image_suffixes = (".jpg", ".jpeg", ".png")
    # Filter first so the reported count matches what is actually processed.
    images = sorted(
        p for p in Path(images_dir).glob("*")
        if p.is_file() and p.suffix.lower() in image_suffixes
    )
    rows = []
    print(f"Found {len(images)} image files in {images_dir}. Extracting embeddings...")
    for p in images:
        try:
            # The context manager releases the file handle as soon as pixels are copied.
            with Image.open(p) as src:
                img = src.convert("RGB").resize((W, H))
            arr = np.array(img).astype("float32") / 255.0
            emb = embed_model.predict(np.expand_dims(arr, 0))
            row = {"image": p.name}
            row.update((f"e{i}", float(v)) for i, v in enumerate(emb.reshape(-1)))
            rows.append(row)
        except Exception as e:
            print("Error on", p, e)
    import pandas as pd
    if rows:
        df = pd.DataFrame(rows)
    else:
        # Still write a well-formed (header-only) CSV when nothing was embedded.
        print("No embeddings were extracted; writing an empty table.")
        df = pd.DataFrame(columns=["image"])
    df.to_csv(out_csv, index=False)
    print("Saved embeddings to", out_csv)

# ------------- Grad-CAM (robust) -------------
def make_gradcam_heatmap(img_array, model, last_conv_layer_name=None, class_index=None, IMG_SIZE=DEFAULT_IMG_SIZE):
    """
    Compute a Grad-CAM heatmap for one image.

    The forward pass is replayed layer by layer inside a GradientTape so the
    gradient of the class score with respect to the chosen conv layer's output
    can be taken. Each layer in ``model.layers`` is fed the previous layer's
    output, i.e. the layers are treated as one linear chain.
    NOTE(review): a model with branches or skip connections will not replay
    correctly through this loop - confirm the CNN is a plain stack.

    Args:
        img_array: batch holding the preprocessed image; only element 0 is used.
        model: Keras model exposing ``.layers``.
        last_conv_layer_name: layer to explain. Defaults to the last layer whose
            name contains "conv".
        class_index: output index to explain. Defaults to the arg-max prediction.
        IMG_SIZE: (height, width) the heatmap is resized to.

    Returns:
        (heatmap, class_index): heatmap is a float array of shape IMG_SIZE,
        ReLU-ed and scaled by its maximum; class_index is a Python int.

    Raises:
        ValueError: no conv layer could be found, or the named layer is not in the model.
        RuntimeError: TensorFlow returned no gradient for the conv output.
    """
    # pick last conv layer if not provided
    if last_conv_layer_name is None:
        last_conv_layer_name = next(
            (layer.name for layer in reversed(model.layers) if "conv" in layer.name.lower()),
            None,
        )
    if last_conv_layer_name is None:
        raise ValueError("No conv layer found and last_conv_layer_name not provided.")

    # helper to call layer in inference mode (some layers do not accept `training`)
    def call_layer(layer, x):
        try:
            return layer(x, training=False)
        except TypeError:
            return layer(x)

    x = tf.convert_to_tensor(img_array, dtype=tf.float32)

    # Forward pass inside the tape so TensorFlow records ops
    conv_outputs = None
    out = x
    with tf.GradientTape() as tape:
        for layer in model.layers:
            out = call_layer(layer, out)
            if layer.name == last_conv_layer_name:
                conv_outputs = out
                tape.watch(conv_outputs)
        # Fail early with a clear message instead of differentiating w.r.t. None.
        if conv_outputs is None:
            raise ValueError(f"Layer {last_conv_layer_name!r} not found in model.layers.")
        preds = out  # final predictions after full forward pass
        if class_index is None:
            class_index = int(tf.argmax(preds[0]))
        loss = preds[:, class_index]

    # Compute gradients of the loss w.r.t. the conv outputs
    grads = tape.gradient(loss, conv_outputs)
    if grads is None:
        raise RuntimeError("Gradients are None - model may not be differentiable for this output.")

    # Channel weights: gradient averaged over the spatial axes (batch dim dropped).
    pooled_grads = tf.reduce_mean(grads[0], axis=(0, 1))
    # Weighted sum over channels in one broadcasted op instead of a Python loop.
    cam = tf.reduce_sum(conv_outputs[0] * pooled_grads, axis=-1)
    cam = tf.maximum(cam, 0)
    cam = cam / (tf.math.reduce_max(cam) + 1e-8)  # epsilon guards an all-zero map
    # cv2.resize takes (width, height).
    cam = cv2.resize(cam.numpy(), (IMG_SIZE[1], IMG_SIZE[0]))
    return cam, int(class_index)

def overlay_heatmap_on_pil(img_pil, heatmap, alpha=0.4):
    """
    Blend a JET-coloured heatmap over a PIL image.

    Args:
        img_pil: source PIL image; it is converted to RGB and resized to the heatmap size.
        heatmap: 2-D float array, expected in [0, 1]; values outside are clipped.
        alpha: weight of the heatmap in the blend (the image gets ``1 - alpha``).

    Returns:
        A PIL RGB image of the same height/width as ``heatmap``.
    """
    height, width = heatmap.shape[:2]
    # Force 3 channels so the blend below never sees RGBA / greyscale input.
    img = np.array(img_pil.convert("RGB").resize((width, height))).astype("uint8")
    # Clip before scaling so out-of-range values cannot wrap around in uint8.
    heatmap_uint8 = np.uint8(255 * np.clip(heatmap, 0.0, 1.0))
    heatmap_bgr = cv2.applyColorMap(heatmap_uint8, cv2.COLORMAP_JET)
    # OpenCV colour maps are BGR while the PIL image is RGB; align the channel
    # order, otherwise hot regions are drawn blue instead of red.
    heatmap_rgb = cv2.cvtColor(heatmap_bgr, cv2.COLOR_BGR2RGB)
    overlay = cv2.addWeighted(heatmap_rgb, alpha, img, 1 - alpha, 0)
    return Image.fromarray(overlay)

# ------------- FastAPI server -------------
def _is_healthy_label(label):
    """Return True when ``label`` denotes the healthy class ("0" when no CLASS_LABELS are configured)."""
    if CLASS_LABELS:
        return label.lower().startswith("healthy")
    return label == "0"


def _severity_for(label, conf, xgb_conf):
    """Map the predicted label and the model confidences to "Low" / "Medium" / "High"."""
    confident = conf >= 0.85 or (xgb_conf is not None and xgb_conf >= 0.85)
    if confident:
        # Decide on the *predicted* label, not on the first configured label.
        return "Low" if _is_healthy_label(label) else "High"
    if conf >= 0.6:
        return "Medium"
    return "Low"


def _soil_vector_from_form(form):
    """Build the (1, n_features) XGBoost input from form fields; missing or unparsable values become 0.0."""
    values = []
    for key, original_name in zip(SOIL_KEYS, SOIL_FEATURES):
        raw = form.get(key)
        if raw is None:
            # allow the frontend to send the original feature label instead
            raw = form.get(original_name)
        try:
            values.append(float(raw) if raw is not None else 0.0)
        except (TypeError, ValueError):
            values.append(0.0)
    return np.array(values, dtype="float32").reshape(1, -1)


def _top_shap_features(explainer, model_input, class_idx, top_k=3):
    """Return the ``top_k`` features with the largest absolute SHAP value for class ``class_idx``."""
    sv = explainer.shap_values(model_input)
    if isinstance(sv, list):
        # One array per class: explain the predicted class when it is available.
        per_class = sv[class_idx] if 0 <= class_idx < len(sv) else sv[0]
        contrib = np.asarray(per_class).reshape(-1)
    else:
        sv = np.asarray(sv)
        if sv.ndim == 3:
            # presumably (samples, features, classes) - pick one class so the
            # vector stays index-aligned with SOIL_FEATURES
            cls = class_idx if 0 <= class_idx < sv.shape[2] else 0
            contrib = sv[0, :, cls]
        else:
            contrib = sv.reshape(-1)
    order = np.argsort(np.abs(contrib))[-top_k:][::-1]
    return [
        {"feature": SOIL_FEATURES[i], "key": SOIL_KEYS[i], "shap_value": float(contrib[i])}
        for i in order
    ]


def create_app():
    """
    Build the FastAPI application and load the models it serves.

    The CNN, the XGBoost model and a SHAP TreeExplainer are loaded once, here.
    Any load failure is printed and leaves that component as None; the
    /predict endpoint then answers with an error (CNN) or simply omits the
    corresponding fields (XGBoost / SHAP).

    Returns:
        The configured FastAPI instance exposing GET / and POST /predict.
    """
    cnn = None
    xgb = None
    shap_explainer = None
    last_conv_layer = None

    print("Loading models for server...")
    try:
        cnn = load_cnn()
        # last conv for gradcam: chosen by "conv" appearing in the layer name
        last_conv_layer = next(
            (layer.name for layer in reversed(cnn.layers) if "conv" in layer.name.lower()),
            None,
        )
        print("Last conv layer set to:", last_conv_layer)
    except Exception as e:
        print("CNN load error:", e)

    try:
        xgb = load_xgb()
    except Exception as e:
        print("XGB load error:", e)

    # build shap explainer if xgb present; use background zeros with correct feature count
    if xgb is not None:
        try:
            background = np.zeros((1, len(SOIL_FEATURES)))
            shap_explainer = shap.TreeExplainer(xgb, data=background)
            print("SHAP explainer ready with background shape:", background.shape)
        except Exception as e:
            print("SHAP init error:", e)
            shap_explainer = None

    app = FastAPI()

    @app.get("/")
    async def root():
        return {"status": "ok", "note": "POST /predict with form fields image + soil fields (see docs)"}

    @app.post("/predict")
    async def predict(request: Request, image: UploadFile = File(...)):
        """
        Expects multipart/form-data with:
         - file 'image'
         - soil fields as form fields (cleaned keys): PH, EC_ds_m, OC_pct, N_kg_hectre, P_kg_hectre, K_kg_hectre,
           S_ppm, Zn_ppm, B_ppm, Fe_ppm, Mn_ppm, Cu_ppm
        Returns json with label/confidence/gradcam_overlay_b64/shap_top_features/advice
        """
        if cnn is None:
            return JSONResponse({"error": "CNN not loaded on server."}, status_code=500)

        # read image bytes; a file that is not a decodable image is a client error
        contents = await image.read()
        try:
            img_pil = Image.open(io.BytesIO(contents)).convert("RGB")
        except Exception as e:
            return JSONResponse({"error": f"Could not decode uploaded image: {e}"}, status_code=400)

        inp_shape = cnn.input_shape
        H = inp_shape[1] or DEFAULT_IMG_SIZE[0]
        W = inp_shape[2] or DEFAULT_IMG_SIZE[1]
        pixels = np.array(img_pil.resize((W, H))).astype("float32") / 255.0
        x = np.expand_dims(pixels, 0)

        # cnn predict
        preds = cnn.predict(x)[0]
        idx = int(np.argmax(preds))
        conf = float(np.max(preds))
        # Fall back to the raw index when it has no configured label.
        label = CLASS_LABELS[idx] if CLASS_LABELS and idx < len(CLASS_LABELS) else str(idx)

        # grad-cam (best effort: a failure only drops the overlay from the response)
        overlay_b64 = None
        try:
            heatmap, _ = make_gradcam_heatmap(x, cnn, last_conv_layer, class_index=idx, IMG_SIZE=(H, W))
            overlay = overlay_heatmap_on_pil(img_pil, heatmap)
            buff = io.BytesIO()
            overlay.save(buff, format="PNG")
            overlay_b64 = base64.b64encode(buff.getvalue()).decode("utf-8")
        except Exception as e:
            print("Grad-CAM error:", e)

        # Build soil-only input for XGBoost using SOIL_KEYS order (cleaned names)
        form = await request.form()
        model_input = _soil_vector_from_form(form)

        xgb_conf = None
        shap_top = None
        if xgb is not None:
            try:
                if hasattr(xgb, "predict_proba"):
                    xgb_probs = xgb.predict_proba(model_input)[0]
                    xgb_idx = int(np.argmax(xgb_probs))
                    xgb_conf = float(np.max(xgb_probs))
                else:
                    xgb_idx = int(xgb.predict(model_input)[0])
            except Exception as e:
                print("XGB/SHAP error:", e)
            else:
                # SHAP: top-3 contributions by absolute value for the predicted class
                if shap_explainer is not None:
                    try:
                        shap_top = _top_shap_features(shap_explainer, model_input, xgb_idx)
                    except Exception as e:
                        print("SHAP compute error:", e)
                        shap_top = None

        # simple severity & advice heuristics
        severity = _severity_for(label, conf, xgb_conf)
        if _is_healthy_label(label):
            advice = "No issue detected. Keep regular care."
        else:
            advice = "Remove affected leaves, improve ventilation, consider treatment if it spreads."

        return {
            "label": label,
            "confidence": conf,
            "xgb_confidence": xgb_conf,
            "severity": severity,
            "advice": advice,
            "gradcam_overlay_b64": overlay_b64,
            "shap_top_features": shap_top
        }

    return app

# ------------- Test client -------------
def send_test_request(server_url, image_path, soil_dict, timeout=60):
    """
    POST one image plus soil values to ``<server_url>/predict`` and print the reply.

    Args:
        server_url: base URL of the running server, e.g. "http://127.0.0.1:8000".
        image_path: path of the image uploaded as the ``image`` file part.
        soil_dict: form fields to send; keys should be the cleaned keys in SOIL_KEYS.
        timeout: seconds to wait for the server before requests gives up.
    """
    print("Sending request to", server_url)
    # Keep the file open only for the duration of the upload so the handle is always released.
    with open(image_path, "rb") as image_file:
        r = requests.post(
            server_url + "/predict",
            files={"image": image_file},
            data=soil_dict,
            timeout=timeout,
        )
    try:
        print("Status", r.status_code)
        print("Response JSON:")
        print(json.dumps(r.json(), indent=2))
    except Exception as e:
        # Non-JSON body (e.g. an HTML error page): show it verbatim.
        print("Error reading response:", e)
        print(r.text)

# ---------------- CLI ----------------
def main(argv=None):
    """
    Command-line entry point: audit | extract_embeddings | runserver | test_request.

    Args:
        argv: argument list to parse. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``; passing a list such as ``["audit"]`` lets the
            function be called from a notebook cell or a test, where
            ``sys.argv`` belongs to the host process.
    """
    # Default soil values for test_request; the keys double as flag names and form-field names.
    soil_defaults = {
        "PH": 6.8,
        "EC_ds_m": 0.35,
        "OC_pct": 1.0,
        "N_kg_hectre": 100.0,
        "P_kg_hectre": 20.0,
        "K_kg_hectre": 120.0,
        "S_ppm": 10.0,
        "Zn_ppm": 2.0,
        "B_ppm": 0.5,
        "Fe_ppm": 30.0,
        "Mn_ppm": 5.0,
        "Cu_ppm": 1.0,
    }

    p = argparse.ArgumentParser()
    p.add_argument("mode", choices=["audit", "extract_embeddings", "runserver", "test_request"], help="mode to run")
    p.add_argument("--images_dir", default="sample_images", help="images folder for extract_embeddings")
    p.add_argument("--out", default="embeddings.csv", help="output csv for embeddings")
    p.add_argument("--host", default="127.0.0.1")
    p.add_argument("--port", default=8000, type=int)
    p.add_argument("--image", help="image path for test_request")
    # convenience: allow passing basic soil values for test_request (optional)
    for key, default in soil_defaults.items():
        p.add_argument(f"--{key}", type=float, default=default)

    args = p.parse_args(argv)

    # argparse's `choices` guarantees that mode is one of the four handled below.
    if args.mode == "audit":
        audit_models()
    elif args.mode == "extract_embeddings":
        if not Path(args.images_dir).exists():
            print("images_dir does not exist:", args.images_dir)
            return
        extract_embeddings(args.images_dir, args.out)
    elif args.mode == "runserver":
        app = create_app()
        print(f"Starting server on {args.host}:{args.port} ...")
        uvicorn.run(app, host=args.host, port=args.port)
    elif args.mode == "test_request":
        if not args.image:
            print("Provide --image for test_request")
            return
        # assemble soil dict using cleaned keys
        soil = {key: getattr(args, key) for key in soil_defaults}
        server_url = f"http://{args.host}:{args.port}"
        send_test_request(server_url, args.image, soil)

if __name__ == "__main__":
    main()


In [None]:
# POST /predict declares the `image` file part as required, so it is sent together with the soil fields.
curl.exe -X POST -F "image=@sample_images/1.21.jpeg" -F "PH=6.8" -F "EC_ds_m=0.35" -F "OC_pct=1.2" -F "N_kg_hectre=100" -F "P_kg_hectre=20" -F "K_kg_hectre=120" -F "S_ppm=10" -F "Zn_ppm=2" -F "B_ppm=0.5" -F "Fe_ppm=30" -F "Mn_ppm=5" -F "Cu_ppm=1" "http://127.0.0.1:8000/predict"

In [None]:
 # Image-only request: uploads the file as the `image` part and sends no soil form fields.
 curl.exe -X POST -F "image=@sample_images/1.21.jpeg" http://127.0.0.1:8000/predict