# In [None]:
import os
import csv
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

# =========================
# Config
# =========================
# Directory that receives every figure and CSV produced below.
OUT_DIR = "comment10_vgg16unet_outputs"
os.makedirs(OUT_DIR, exist_ok=True)

thr = 0.5    # probability cut-off used to binarize predictions
N_SHOW = 6   # maximum number of tiles (rows) per figure

# =========================
# Load model + test split
# =========================
# compile=False: the model is only used for inference and gradients here.
model = tf.keras.models.load_model("vgg16_unet_6band_fixedscale_fullmodel.h5", compile=False)

# NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which
# can execute arbitrary code -- only keep it if this .npz is a trusted file.
# The context manager closes the underlying archive once both arrays have
# been read into memory.
with np.load("splits_vgg16_unet_6band.npz", allow_pickle=True) as data:
    X_test = data["X_test"]
    y_test = data["y_test"]

# Channel 0 of the label array is taken as the binary ground-truth mask.
y_true = y_test[..., 0].astype(np.uint8)

# =========================
# Scaling
# =========================
def scale_np_image(img):
    """Return *img* as float32, divided by a scale picked from its maximum.

    The divisor is chosen from the array's largest value:

    * max <= 1.5   -> returned unchanged (treated as already scaled)
    * max <= 300   -> divided by 255
    * otherwise    -> divided by 10000

    The thresholds presumably separate 0-1, 8-bit and reflectance-x10000
    encodings -- confirm against the data producer.
    """
    scaled = img.astype(np.float32)
    peak = float(np.max(scaled))

    # Comparisons are kept in "<=" form so that a NaN maximum falls through
    # to the last branch.
    if peak <= 1.5:
        divisor = None
    elif peak <= 300.0:
        divisor = 255.0
    else:
        divisor = 10000.0

    return scaled if divisor is None else scaled / divisor

# Each tile is scaled on its own (the divisor depends on that tile's maximum),
# then the tiles are re-assembled into one float32 batch.
X01_test = np.stack(list(map(scale_np_image, X_test)), axis=0).astype(np.float32)

# =========================
# Helpers
# =========================
def stretch_rgb_per_channel(rgb, p_low=2, p_high=98, eps=1e-6):
    """Percentile-stretch the first three channels of *rgb* independently.

    For each of channels 0..2 the ``p_low`` and ``p_high`` percentiles are
    mapped to 0 and 1. A channel whose percentile span is below ``eps`` is
    copied through unchanged. The result is float32 and clipped to [0, 1].
    """
    src = rgb.astype(np.float32)
    stretched = np.zeros_like(src, dtype=np.float32)

    for band in range(3):
        plane = src[..., band]
        lo = np.percentile(plane, p_low)
        hi = np.percentile(plane, p_high)
        span = hi - lo
        # A (nearly) constant channel cannot be stretched; keep it as is.
        is_flat = abs(span) < eps
        stretched[..., band] = plane if is_flat else (plane - lo) / span

    return np.clip(stretched, 0, 1)

def safe_index(num, den, eps=1e-6):
    """Element-wise ``num / den`` that yields 0 where the ratio is undefined.

    Denominators with magnitude below ``eps`` are replaced by NaN before the
    division; NaN and +/-inf in the quotient are then mapped to 0.0. The
    result is returned as float32.
    """
    too_small = np.abs(den) < eps
    guarded_den = np.where(too_small, np.nan, den)
    ratio = num / guarded_den
    cleaned = np.nan_to_num(ratio, nan=0.0, posinf=0.0, neginf=0.0)
    return cleaned.astype(np.float32)

def compute_indices(img01):
    """Return ``(ndvi, mndwi, nbr)`` maps for one 6-band image.

    The last axis is read as R, G, B, NIR, SWIR1, SWIR2 (indices 0..5) --
    this band order is assumed from the variable names; confirm against the
    dataset. Each index is a normalized difference ``(a - b) / (a + b)``
    evaluated with ``safe_index``:

    * NDVI  = (NIR - R) / (NIR + R)
    * MNDWI = (G - SWIR1) / (G + SWIR1)
    * NBR   = (NIR - SWIR2) / (NIR + SWIR2)
    """
    # Index 2 (B) is not needed by any of the three indices.
    red, green, nir, swir1, swir2 = (img01[..., k] for k in (0, 1, 3, 4, 5))

    def normalized_difference(a, b):
        return safe_index(a - b, a + b)

    ndvi = normalized_difference(nir, red)
    mndwi = normalized_difference(green, swir1)
    nbr = normalized_difference(nir, swir2)
    return ndvi, mndwi, nbr

def binarize(p, thr=0.5):
    """Threshold array *p*: 1 where ``p >= thr``, else 0, as uint8."""
    above = np.greater_equal(p, thr)
    return above.astype(np.uint8)

def error_masks(gt, pred):
    """Split a binary prediction into confusion-matrix masks.

    Returns ``(tn, fp, fn, tp)``, each a uint8 array shaped like the inputs
    with 1 where the pixel belongs to that category. Only values exactly
    equal to 0 or 1 in ``gt`` / ``pred`` are counted.
    """
    gt_pos, gt_neg = (gt == 1), (gt == 0)
    pred_pos, pred_neg = (pred == 1), (pred == 0)

    def as_u8(mask):
        return mask.astype(np.uint8)

    return (
        as_u8(gt_neg & pred_neg),  # true negatives
        as_u8(gt_neg & pred_pos),  # false positives
        as_u8(gt_pos & pred_neg),  # false negatives
        as_u8(gt_pos & pred_pos),  # true positives
    )

def error_map(gt, pred):
    """Encode per-pixel outcomes in one uint8 map shaped like *gt*.

    Codes: 0 = everything else (including true negatives), 1 = false
    positive, 2 = false negative, 3 = true positive.
    """
    _, fp, fn, tp = error_masks(gt, pred)
    coded = np.zeros_like(gt, dtype=np.uint8)
    for code, mask in ((1, fp), (2, fn), (3, tp)):
        coded[mask == 1] = code
    return coded

def compute_iou(gt, pred):
    """Intersection-over-union of the positive class (value 1).

    Both inputs are flattened and cast to uint8 first. A 1e-12 term in the
    denominator avoids division by zero, so two masks without any positive
    pixel give 0.0 rather than an error.
    """
    gt_pos = gt.reshape(-1).astype(np.uint8) == 1
    pred_pos = pred.reshape(-1).astype(np.uint8) == 1

    overlap = np.count_nonzero(gt_pos & pred_pos)
    combined = np.count_nonzero(gt_pos | pred_pos)
    return float(overlap / (combined + 1e-12))

def pick_mixed_tiles(y_true, min_frac=0.05, max_frac=0.95):
    """Return indices of tiles whose mean label value lies in a window.

    ``y_true`` is flattened per tile (axis 0 is the tile axis) and averaged;
    tiles with ``min_frac <= mean <= max_frac`` are kept. The result is a
    1-D integer index array in ascending order.
    """
    n_tiles = y_true.shape[0]
    positive_frac = y_true.reshape(n_tiles, -1).mean(axis=1)
    in_window = (positive_frac >= min_frac) & (positive_frac <= max_frac)
    return np.flatnonzero(in_window)

# =========================
# Grad-CAM (segmentation)
# =========================
def gradcam_segmentation(model, img01, cam_layer_name, region_mask=None):
    """Compute a Grad-CAM heat map for one image of a segmentation model.

    The scalar that is explained is built from channel 0 of the model
    output: its mean over all pixels when ``region_mask`` is None, otherwise
    its ``region_mask``-weighted mean.

    Args:
        model: Keras model. Assumed to return a (1, H, W, 1) map for a single
            (H, W, C) input -- TODO confirm against the model definition.
        img01: One already-scaled image without batch axis.
        cam_layer_name: Name of the layer whose output feeds the CAM.
        region_mask: Optional array broadcastable to the (H, W) prediction;
            restricts the target to the pixels it weights.

    Returns:
        float32 array of shape ``img01.shape[:2]``, min-max normalised to
        [0, 1]. If the raw map is constant (e.g. all zeros) the result is
        all zeros.
    """
    img = tf.convert_to_tensor(img01[None, ...], dtype=tf.float32)

    cam_layer = model.get_layer(cam_layer_name)
    # model.inputs is already a list of input tensors; wrapping it in another
    # list nests the input structure, which newer Keras versions report as a
    # structure mismatch when the model is called with a single tensor.
    grad_model = tf.keras.Model(model.inputs, [cam_layer.output, model.output])

    with tf.GradientTape() as tape:
        conv_out, pred = grad_model(img, training=False)  # pred: (1,H,W,1)
        pred = pred[..., 0]                               # (1,H,W)

        if region_mask is None:
            target = tf.reduce_mean(pred)
        else:
            rm = tf.convert_to_tensor(region_mask[None, ...], dtype=tf.float32)
            # Weighted mean; the epsilon keeps an all-zero mask from dividing by 0.
            target = tf.reduce_sum(pred * rm) / (tf.reduce_sum(rm) + 1e-12)

    grads = tape.gradient(target, conv_out)               # (1,h,w,c)
    # Channel importance = spatially averaged gradient.
    weights = tf.reduce_mean(grads, axis=(1, 2))          # (1,c)
    cam = tf.reduce_sum(conv_out * weights[:, None, None, :], axis=-1)  # (1,h,w)
    # Keep only activations that push the target up.
    cam = tf.nn.relu(cam)[0].numpy()

    # Upsample the layer-resolution map to the input's spatial size.
    cam = tf.image.resize(
        cam[..., None], (img01.shape[0], img01.shape[1]), method="bilinear"
    ).numpy()[..., 0]
    cam = cam - cam.min()
    cam = cam / (cam.max() + 1e-12)
    return cam.astype(np.float32)

# Layer whose activations drive Grad-CAM; the name has to exist in the loaded
# model (it was taken from the model file).
CAM_LAYER = "decoder_cam_conv"

# =========================
# Predict
# =========================
# Channel 0 of the network output, one map per tile: (N,H,W).
y_prob = model.predict(X01_test, batch_size=16, verbose=1)[..., 0]

# Prefer tiles whose ground truth is 5-95 % positive; widen the window to
# 1-99 % when the strict one matches nothing.
mixed = pick_mixed_tiles(y_true, 0.05, 0.95)
if not mixed.size:
    mixed = pick_mixed_tiles(y_true, 0.01, 0.99)

# IoU of the thresholded prediction for every selected tile.
ious = np.array(
    [compute_iou(y_true[tile], binarize(y_prob[tile], thr)) for tile in mixed],
    dtype=np.float32,
)

# Up to N_SHOW tiles with the highest / lowest IoU.
rank_desc = np.argsort(-ious)
rank_asc = np.argsort(ious)
best_idxs = mixed[rank_desc[:N_SHOW]]
worst_idxs = mixed[rank_asc[:N_SHOW]]

# =========================
# Save panels
# =========================
def save_gradcam_panels(idxs, out_png, focus="gt"):
    """Save a figure with one row per tile and five diagnostic columns.

    Columns: stretched RGB, ground truth, thresholded prediction, error map
    (0 = TN, 1 = FP, 2 = FN, 3 = TP) and the Grad-CAM overlay on the RGB.

    Args:
        idxs: Indices into the module-level test arrays.
        out_png: Path of the image file to write.
        focus: Region the Grad-CAM target is averaged over: "gt" (ground
            truth positives), "fp" (false positives), "fn" (false
            negatives); any other value uses the whole tile.

    When ``idxs`` is empty nothing is written and a message is printed.
    """
    n = len(idxs)
    if n == 0:
        # plt.subplots cannot build a grid with zero rows.
        print(f"No tiles to plot; skipping {out_png}")
        return

    # squeeze=False keeps axs two-dimensional even for a single row.
    fig, axs = plt.subplots(n, 5, figsize=(14, 2.7 * n), squeeze=False)

    for r, idx in enumerate(idxs):
        img01 = X01_test[idx]
        rgb_vis = stretch_rgb_per_channel(img01[..., :3])

        gt = y_true[idx]
        prob = y_prob[idx]
        pred = binarize(prob, thr)
        em = error_map(gt, pred)
        _, fp, fn, _ = error_masks(gt, pred)

        if focus == "gt":
            region = gt.astype(np.float32)
        elif focus == "fp":
            region = fp.astype(np.float32)
        elif focus == "fn":
            region = fn.astype(np.float32)
        else:
            region = None

        cam = gradcam_segmentation(model, img01, CAM_LAYER, region_mask=region)

        axs[r, 0].imshow(rgb_vis)
        axs[r, 0].set_title(f"RGB (tile {idx})")
        axs[r, 1].imshow(gt, vmin=0, vmax=1)
        axs[r, 1].set_title("GT")
        axs[r, 2].imshow(pred, vmin=0, vmax=1)
        axs[r, 2].set_title("Pred")
        axs[r, 3].imshow(em, vmin=0, vmax=3)
        axs[r, 3].set_title("Error (FP/FN/TP)")
        # Heat map drawn semi-transparently on top of the RGB view.
        axs[r, 4].imshow(rgb_vis)
        axs[r, 4].imshow(cam, alpha=0.45)
        axs[r, 4].set_title(f"Grad-CAM ({focus})")

        for ax in axs[r]:
            ax.set_xticks([])
            ax.set_yticks([])

    fig.tight_layout()
    fig.savefig(out_png, dpi=300)
    # Close this specific figure so repeated calls do not accumulate figures.
    plt.close(fig)

def save_indices_panel(idxs, out_png):
    """Save a figure with one row per tile: RGB, NDVI, MNDWI and NBR.

    Args:
        idxs: Indices into the module-level ``X01_test`` array.
        out_png: Path of the image file to write.

    The three index maps are displayed on a fixed [-1, 1] colour range.
    When ``idxs`` is empty nothing is written and a message is printed.
    """
    n = len(idxs)
    if n == 0:
        # plt.subplots cannot build a grid with zero rows.
        print(f"No tiles to plot; skipping {out_png}")
        return

    # squeeze=False keeps axs two-dimensional even for a single row.
    fig, axs = plt.subplots(n, 4, figsize=(12, 2.7 * n), squeeze=False)

    for r, idx in enumerate(idxs):
        img01 = X01_test[idx]
        rgb_vis = stretch_rgb_per_channel(img01[..., :3])
        ndvi, mndwi, nbr = compute_indices(img01)

        axs[r, 0].imshow(rgb_vis)
        axs[r, 0].set_title(f"RGB (tile {idx})")

        labelled_maps = (("NDVI", ndvi), ("MNDWI", mndwi), ("NBR", nbr))
        for c, (title, index_map) in enumerate(labelled_maps, start=1):
            axs[r, c].imshow(index_map, vmin=-1, vmax=1)
            axs[r, c].set_title(title)

        for ax in axs[r]:
            ax.set_xticks([])
            ax.set_yticks([])

    fig.tight_layout()
    fig.savefig(out_png, dpi=300)
    # Close this specific figure so repeated calls do not accumulate figures.
    plt.close(fig)

# Grad-CAM figures: (tile indices, output file name, focus region).
for panel_tiles, panel_file, panel_focus in (
    (best_idxs, "gradcam_best_mixed_gtfocus.png", "gt"),
    (worst_idxs, "gradcam_worst_mixed_gtfocus.png", "gt"),
    (worst_idxs, "gradcam_worst_mixed_fp_focus.png", "fp"),
    (worst_idxs, "gradcam_worst_mixed_fn_focus.png", "fn"),
):
    save_gradcam_panels(panel_tiles, os.path.join(OUT_DIR, panel_file), focus=panel_focus)

# Spectral-index view of the lowest-IoU tiles.
save_indices_panel(worst_idxs, os.path.join(OUT_DIR, "spectral_indices_worst_mixed.png"))

# =========================
# CSV diagnostics
# =========================
def tile_proxies(img01):
    """Summarise one tile with brightness and spectral-index statistics.

    Brightness is the per-pixel mean of channels 0..2. The returned dict
    holds plain floats; its key order defines the column order of the CSV
    written from it.

    * ``bright_frac`` / ``dark_frac``: share of pixels with brightness
      above 0.85 / below 0.15.
    * ``water_frac_mndwi``: share of pixels with MNDWI > 0.2.
    * ``veg_frac_ndvi``: share of pixels with NDVI > 0.4.
    * ``*_mean``: tile mean of the respective index.
    """
    ndvi, mndwi, nbr = compute_indices(img01)
    brightness = img01[..., :3].mean(axis=-1)

    def share(mask):
        # Fraction of pixels for which the boolean mask is True.
        return float(np.mean(mask))

    proxies = {}
    proxies["mean_brightness"] = float(brightness.mean())
    proxies["std_brightness"] = float(brightness.std())
    proxies["bright_frac"] = share(brightness > 0.85)
    proxies["dark_frac"] = share(brightness < 0.15)
    proxies["water_frac_mndwi"] = share(mndwi > 0.2)
    proxies["veg_frac_ndvi"] = share(ndvi > 0.4)
    proxies["ndvi_mean"] = float(ndvi.mean())
    proxies["mndwi_mean"] = float(mndwi.mean())
    proxies["nbr_mean"] = float(nbr.mean())
    return proxies

# One record per test tile: segmentation quality plus image proxies.
rows = []
for i in range(X01_test.shape[0]):
    gt = y_true[i]
    pred = binarize(y_prob[i], thr)
    _, fp, fn, _ = error_masks(gt, pred)

    # Note: both "rates" are fractions of ALL pixels in the tile, not of the
    # negative / positive ground-truth pixels.
    n_pixels = gt.size + 1e-12
    record = {
        "tile_id": i,
        "gt_forest_frac": float(gt.mean()),
        "iou": compute_iou(gt, pred),
        "fp_rate": float(fp.sum() / n_pixels),
        "fn_rate": float(fn.sum() / n_pixels),
    }
    record.update(tile_proxies(X01_test[i]))
    rows.append(record)

# Column order follows the key order of the first record.
csv_path = os.path.join(OUT_DIR, "per_tile_error_proxies_vgg16unet.csv")
with open(csv_path, "w", newline="") as f:
    w = csv.DictWriter(f, fieldnames=list(rows[0]))
    w.writeheader()
    for record in rows:
        w.writerow(record)

def summarize_group(name, selector):
    """Average the error metrics over the rows accepted by *selector*.

    Args:
        name: Label stored under the "group" key.
        selector: Callable taking one entry of the module-level ``rows``
            list and returning a truthy value to include it.

    Returns:
        Dict with the group label, the number of matching tiles and the mean
        ``fp_rate``, ``fn_rate`` and ``iou``; None when no row matches.
    """
    chosen = [row for row in rows if selector(row)]
    if not chosen:
        return None

    def mean_of(key):
        return float(np.mean([row[key] for row in chosen]))

    return {
        "group": name,
        "n_tiles": len(chosen),
        "mean_fp_rate": mean_of("fp_rate"),
        "mean_fn_rate": mean_of("fn_rate"),
        "mean_iou": mean_of("iou"),
    }

# Strata are defined on the per-tile proxies; groups without any matching
# tile are left out of the summary.
group_specs = (
    ("high_cloud_haze (bright_frac>0.10)", lambda r: r["bright_frac"] > 0.10),
    ("high_shadow (dark_frac>0.10)", lambda r: r["dark_frac"] > 0.10),
    ("high_water (water_frac_mndwi>0.10)", lambda r: r["water_frac_mndwi"] > 0.10),
    ("low_water (water_frac_mndwi<=0.10)", lambda r: r["water_frac_mndwi"] <= 0.10),
)
groups = []
for group_label, group_selector in group_specs:
    group_summary = summarize_group(group_label, group_selector)
    if group_summary is not None:
        groups.append(group_summary)

# The summary CSV is only written when at least one group is non-empty.
sum_path = os.path.join(OUT_DIR, "stratified_error_summary_vgg16unet.csv")
if groups:
    with open(sum_path, "w", newline="") as f:
        w = csv.DictWriter(f, fieldnames=list(groups[0]))
        w.writeheader()
        for group_summary in groups:
            w.writerow(group_summary)

print("Saved outputs to:", OUT_DIR)
print("Figures saved: gradcam_* and spectral_indices_*")
print("CSVs saved:", csv_path, sum_path if groups else "")
