In [None]:
# ============================================================
# Full training + interpretability (Grad-CAM) with FIXED scaling
# Works for 6-band Landsat-style inputs where values may be 0..10000
# Band order assumed: [R, G, B, NIR, SWIR1, SWIR2]
# ============================================================

import os
import csv
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split
from tensorflow.keras.layers import (
    Input, Conv2D, MaxPooling2D, Conv2DTranspose,
    Concatenate, BatchNormalization, Activation
)
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, LearningRateScheduler

# =========================
# Parameters
# =========================
tile_size = 256        # tile height/width in pixels; passed to load_data and used for the U-Net input shape
input_channels = 6     # number of bands fed to the model (band order given in the file header)
batch_size = 32
epochs = 100           # upper bound on training epochs; EarlyStopping further below may stop sooner
thr = 0.5              # probability threshold used with binarize() for all evaluation and plots

OUT_DIR = "reviewer_comment10_outputs_6band_fixed"
os.makedirs(OUT_DIR, exist_ok=True)

# =========================
# Load your data
# =========================
# tiles_img.shape = (num_samples, 256, 256, 6)
# tiles_mask.shape = (num_samples, 256, 256)
# NOTE(review): load_data is not defined in this view (presumably an earlier cell);
# the shapes above are the author's expectation and are not checked here.
tiles_img, tiles_mask = load_data("datasets_temp/images", "datasets_temp/masks", tile_size)

# Append a trailing singleton channel axis to the masks and store them as uint8.
# Assumes masks arrive without a channel axis (per the shape comment above); an
# array that already ends in a size-1 axis would gain a second one here.
tiles_mask = tiles_mask.reshape(tiles_mask.shape + (1,)).astype(np.uint8)

# =========================
# Sanity checks (critical)
# =========================
# Print raw value ranges so the scaling heuristic (0..1 / 0..255 / 0..10000)
# used by scale_np_image / scale_tf_image can be checked by eye.
print("tiles_img dtype:", tiles_img.dtype)
print("tiles_img global min/max:", float(np.min(tiles_img)), float(np.max(tiles_img)))
print("tiles_mask unique:", np.unique(tiles_mask))

# Per-band statistics over all tiles (at most input_channels bands).
for k in range(min(input_channels, tiles_img.shape[-1])):
    v = tiles_img[..., k]
    print(f"band[{k}] min/max/mean/std:",
          float(v.min()), float(v.max()), float(v.mean()), float(v.std()))

# ============================================================
# Scaling utilities (FIX)
# ============================================================
def scale_np_image(img):
    """
    Cast an image to float32 and divide it by a range-dependent constant.

    The divisor is picked from the maximum of *this* image alone:
      - max <= 1.5           -> returned unchanged (assumed already 0..1)
      - 1.5 < max <= 300     -> divided by 255   (8-bit-like)
      - anything else        -> divided by 10000 (reflectance-like)
    No clipping is applied, so results can exceed 1.

    NOTE(review): the choice is made per call, so a dark reflectance tile whose
    maximum happens to be <= 300 is divided by 255 instead of 10000, and tiles
    from the same sensor can end up on different scales.
    """
    arr = img.astype(np.float32)
    peak = float(np.max(arr))
    if peak <= 1.5:
        # Already in unit range; nothing to do.
        return arr
    divisor = 255.0 if peak <= 300.0 else 10000.0
    return arr / divisor

@tf.function
def scale_tf_image(image):
    """
    Graph-mode counterpart of scale_np_image.

    Casts to float32, then divides by a constant chosen from the tensor's own
    maximum:
      max > 300          -> / 10000 (reflectance-like)
      1.5 < max <= 300   -> / 255   (8-bit-like)
      otherwise          -> unchanged
    """
    raw = tf.cast(image, tf.float32)
    peak = tf.reduce_max(raw)

    after_reflectance = tf.cond(peak > 300.0, lambda: raw / 10000.0, lambda: raw)

    eight_bit_like = tf.logical_and(peak > 1.5, peak <= 300.0)
    return tf.cond(eight_bit_like,
                   lambda: after_reflectance / 255.0,
                   lambda: after_reflectance)

# ============================================================
# Visualization helpers (FIX)
# ============================================================
def stretch_rgb_per_channel(rgb, p_low=2, p_high=98, eps=1e-6):
    """
    Percentile contrast stretch applied independently to the first three
    channels, for display only.

    Each channel is mapped so that its p_low percentile goes to 0 and its
    p_high percentile goes to 1; the result is clipped to [0, 1]. A channel
    whose percentile span is below eps is passed through unstretched (still
    clipped) to avoid dividing by ~0.
    """
    data = rgb.astype(np.float32)
    stretched = np.zeros_like(data, dtype=np.float32)
    for ch in range(3):
        band = data[..., ch]
        lo = np.percentile(band, p_low)
        hi = np.percentile(band, p_high)
        span = hi - lo
        stretched[..., ch] = band if abs(span) < eps else (band - lo) / span
    return np.clip(stretched, 0.0, 1.0)

def safe_index(num, den, eps=1e-6):
    """
    Element-wise num / den as float32.

    Returns 0 wherever |den| < eps or the quotient is not finite (NaN / +-inf).
    """
    masked_den = np.where(np.abs(den) < eps, np.nan, den)
    ratio = np.nan_to_num(num / masked_den, nan=0.0, posinf=0.0, neginf=0.0)
    return ratio.astype(np.float32)

def compute_indices(img01):
    """
    Normalized-difference indices of a 6-band image with the channel axis last
    and band order [R, G, B, NIR, SWIR1, SWIR2].

    Returns (ndvi, mndwi, nbr):
      NDVI  = (NIR - R)     / (NIR + R)
      MNDWI = (G - SWIR1)   / (G + SWIR1)
      NBR   = (NIR - SWIR2) / (NIR + SWIR2)
    each computed with safe_index (near-zero denominators give 0).
    img01 is expected to be scaled to roughly 0..1 (not clipped).
    """
    def normalized_difference(a, b):
        return safe_index(a - b, a + b)

    red, green = img01[..., 0], img01[..., 1]
    nir, swir1, swir2 = img01[..., 3], img01[..., 4], img01[..., 5]

    return (
        normalized_difference(nir, red),
        normalized_difference(green, swir1),
        normalized_difference(nir, swir2),
    )

def binarize(p, thr=0.5):
    """Return a uint8 mask: 1 where p >= thr, 0 elsewhere."""
    above = p >= thr
    return above.astype(np.uint8)

def error_masks(gt, pred):
    """
    Per-pixel confusion masks for binary labels.

    Returns (tn, fp, fn, tp) as uint8 arrays, each 1 where that outcome occurs.
    Pixels whose gt or pred is neither 0 nor 1 belong to none of the four.
    """
    gt_pos, gt_neg = (gt == 1), (gt == 0)
    pred_pos, pred_neg = (pred == 1), (pred == 0)

    tn = (gt_neg & pred_neg).astype(np.uint8)
    fp = (gt_neg & pred_pos).astype(np.uint8)
    fn = (gt_pos & pred_neg).astype(np.uint8)
    tp = (gt_pos & pred_pos).astype(np.uint8)
    return tn, fp, fn, tp

def error_map(gt, pred):
    """
    Encode the per-pixel confusion outcome as one uint8 label map.

    Labels: 0 = TN (also any pixel with gt/pred outside {0, 1}), 1 = FP,
    2 = FN, 3 = TP. The output has gt's shape.
    """
    labels = np.zeros_like(gt, dtype=np.uint8)
    gt_pos, gt_neg = (gt == 1), (gt == 0)
    pred_pos, pred_neg = (pred == 1), (pred == 0)

    labels[gt_neg & pred_pos] = 1   # false positive
    labels[gt_pos & pred_neg] = 2   # false negative
    labels[gt_pos & pred_pos] = 3   # true positive
    return labels

def compute_iou(gt, pred, empty_value=1.0):
    """
    Intersection-over-union of the positive (== 1) class for two binary masks.

    Args:
        gt: ground-truth mask (any shape; flattened).
        pred: predicted mask with the same number of elements as gt.
        empty_value: score returned when neither mask has a positive pixel
            (union == 0). That case is a correct all-background prediction,
            so it defaults to 1.0 instead of being scored as a total miss.

    Returns:
        float IoU in [0, 1].
    """
    gt = np.asarray(gt).reshape(-1).astype(np.uint8)
    pred = np.asarray(pred).reshape(-1).astype(np.uint8)
    inter = np.count_nonzero((gt == 1) & (pred == 1))
    union = np.count_nonzero((gt == 1) | (pred == 1))
    if union == 0:
        return float(empty_value)
    return float(inter / union)

def pick_mixed_tiles(y_true, min_frac=0.05, max_frac=0.95):
    """
    Indices of tiles whose positive-pixel fraction lies in [min_frac, max_frac].

    Axis 0 of y_true indexes tiles; all remaining axes are averaged per tile.
    """
    per_tile = y_true.reshape(len(y_true), -1)
    fractions = per_tile.mean(axis=1)
    keep = np.logical_and(fractions >= min_frac, fractions <= max_frac)
    return np.flatnonzero(keep)

# ============================================================
# Model: custom 6-band U-Net
# ============================================================
def custom_unet(input_size=(tile_size, tile_size, input_channels)):
    """
    Build and compile a 4-level U-Net for single-class segmentation.

    Every stage is Conv3x3 -> BatchNorm -> ReLU -> Conv3x3 -> ReLU (batch norm
    only after the first conv). The encoder uses 64/128/256/512 filters with 2x2
    max pooling, the bottleneck uses 1024 filters, and the decoder mirrors the
    encoder with 2x2 stride-2 transposed convs and skip concatenations.

    Named layers:
      "last_decoder_conv"  - first conv of the final decoder stage (Grad-CAM target)
      "segmentation_head"  - 1x1 sigmoid output with one channel

    Compiled with Adam(lr=1e-4), binary cross-entropy and accuracy.
    """
    def double_conv(x, filters, first_conv_name=None):
        # Layer creation order here matches the original stage layout exactly.
        x = Conv2D(filters, (3, 3), padding='same', name=first_conv_name)(x)
        x = BatchNormalization()(x)
        x = Activation("relu")(x)
        x = Conv2D(filters, (3, 3), padding='same')(x)
        return Activation("relu")(x)

    inputs = Input(input_size)

    # Encoder: keep each stage's output as a skip connection.
    skips = []
    x = inputs
    for filters in (64, 128, 256, 512):
        x = double_conv(x, filters)
        skips.append(x)
        x = MaxPooling2D((2, 2))(x)

    # Bottleneck
    x = double_conv(x, 1024)

    # Decoder: upsample, concatenate the matching skip, refine.
    for filters, skip in zip((512, 256, 128, 64), reversed(skips)):
        x = Conv2DTranspose(filters, (2, 2), strides=(2, 2), padding='same')(x)
        x = Concatenate()([x, skip])
        conv_name = "last_decoder_conv" if filters == 64 else None
        x = double_conv(x, filters, first_conv_name=conv_name)

    outputs = Conv2D(1, (1, 1), activation='sigmoid', name="segmentation_head")(x)

    model = Model(inputs=[inputs], outputs=[outputs])
    model.compile(optimizer=Adam(learning_rate=1e-4),
                  loss='binary_crossentropy',
                  metrics=['accuracy'])
    return model

# Build the 6-band U-Net using the module-level tile/channel settings.
# `unet_model` is also read as a global by save_gradcam_panels further below.
unet_model = custom_unet(input_size=(tile_size, tile_size, input_channels))
unet_model.summary()

# =========================
# LR scheduler (fixed to decay progressively)
# =========================
def lr_schedule(epoch, base=1e-4, decay=0.9, step=10):
    """
    Step-decay learning rate: base * decay ** floor(epoch / step).

    The rate is multiplied by `decay` once per completed block of `step` epochs.

    NOTE: `base` is the second positional parameter, so a caller that passes
    (epoch, current_lr) positionally will have current_lr used as `base`.
    """
    completed_steps = epoch // step
    return base * decay ** completed_steps

# Keras' LearningRateScheduler calls schedule(epoch, current_lr) whenever the
# schedule accepts two arguments. Passing lr_schedule directly would bind
# current_lr to its `base` parameter, so after epoch 10 the rate would be
# multiplied by 0.9**(epoch // 10) on *every* epoch (compounding) instead of
# stepping. Forward only `epoch` so the intended step decay from 1e-4 applies.
lr_scheduler = LearningRateScheduler(lambda epoch, lr: lr_schedule(epoch))

# =========================
# TF Dataset with augmentation (FIXED scaling)
# =========================
def preprocess_tf(image, mask):
    """tf.data map fn: scale the image with scale_tf_image and cast the mask to float32."""
    return scale_tf_image(image), tf.cast(mask, tf.float32)

def augment_tf(image, mask):
    """
    Random geometric augmentation applied identically to an image and its mask.

    Per sample, one horizontal-flip decision, one vertical-flip decision and one
    rotation count k in {0, 1, 2, 3} are drawn and each is applied to BOTH
    tensors, so pixels and labels stay aligned.

    Why: tf.image.random_flip_* draws its own random bit on every call, so
    calling it separately on image and mask flips them independently and
    misaligns the labels on a large share of samples.

    Rotation by k * 90 degrees preserves the shape only for square tiles.
    """
    def flip_both(flag, flip_fn, img, msk):
        # One shared decision for both tensors.
        return tf.cond(flag,
                       lambda: (flip_fn(img), flip_fn(msk)),
                       lambda: (img, msk))

    flip_lr = tf.random.uniform(shape=[]) < 0.5
    image, mask = flip_both(flip_lr, tf.image.flip_left_right, image, mask)

    flip_ud = tf.random.uniform(shape=[]) < 0.5
    image, mask = flip_both(flip_ud, tf.image.flip_up_down, image, mask)

    k = tf.random.uniform(shape=[], minval=0, maxval=4, dtype=tf.int32)
    image = tf.image.rot90(image, k)
    mask = tf.image.rot90(mask, k)
    return image, mask

# Random tile-level split: 30% test, then 10% of the remaining 70% for validation.
X_train, X_test, y_train, y_test = train_test_split(tiles_img, tiles_mask, test_size=0.3, random_state=42)
X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.1, random_state=42)

# Training pipeline: per-tile scaling -> paired random flips/rotation -> shuffle -> batch.
train_dataset = tf.data.Dataset.from_tensor_slices((X_train, y_train))
train_dataset = train_dataset.map(preprocess_tf, num_parallel_calls=tf.data.AUTOTUNE)
train_dataset = train_dataset.map(augment_tf, num_parallel_calls=tf.data.AUTOTUNE)
train_dataset = train_dataset.shuffle(200).batch(batch_size).prefetch(tf.data.AUTOTUNE)

# Validation pipeline: same scaling, no augmentation, no shuffling.
val_dataset = tf.data.Dataset.from_tensor_slices((X_val, y_val))
val_dataset = val_dataset.map(preprocess_tf, num_parallel_calls=tf.data.AUTOTUNE)
val_dataset = val_dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)

# =========================
# Callbacks
# =========================
# The checkpoint writes the best-val_loss model to disk, but nothing below
# reloads it. restore_best_weights=True makes EarlyStopping put the best weights
# back into unet_model. The evaluation, the saved .h5 and the Grad-CAM figures
# then use the best epoch rather than the one `patience` epochs later.
# NOTE(review): older tf.keras releases restore only when early stopping actually
# triggers; if training runs all `epochs`, reload the checkpoint explicitly.
checkpointer = ModelCheckpoint("best_weight_unet_6band_fixedscale.h5",
                               monitor="val_loss", verbose=1, save_best_only=True, mode="min")
earlyStopping = EarlyStopping(monitor="val_loss", patience=5, verbose=1, mode="min",
                              restore_best_weights=True)
callbacks = [lr_scheduler, earlyStopping, checkpointer]

# =========================
# Train
# =========================
history = unet_model.fit(
    train_dataset,
    epochs=epochs,
    validation_data=val_dataset,
    callbacks=callbacks
)

# =========================
# Evaluate
# =========================
# Test pipeline uses the same scaling as training/validation.
test_dataset = tf.data.Dataset.from_tensor_slices((X_test, y_test))
test_dataset = test_dataset.map(preprocess_tf, num_parallel_calls=tf.data.AUTOTUNE)
test_dataset = test_dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)

test_loss, test_accuracy = unet_model.evaluate(test_dataset, verbose=1)
print("Test Loss:", float(test_loss), "Test Accuracy:", float(test_accuracy))

# =========================
# Save model
# =========================
unet_model.save("forest_detection_model_6band_unet_fixedscale.h5")

# =========================
# Plot training history
# =========================
plt.figure(figsize=(8,5))
plt.plot(history.history["loss"], label="Training Loss")
plt.plot(history.history["val_loss"], label="Validation Loss")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.legend()
plt.tight_layout()
plt.show()

# ============================================================
# Grad-CAM for segmentation (FIXED scaling + better visuals)
# ============================================================
def auto_pick_cam_layer(model):
    """
    Choose the layer name to use for Grad-CAM.

    Prefers the layer named "last_decoder_conv". Otherwise it returns the last
    Conv2D (in model.layers order) whose kernel is larger than 1x1 in either
    dimension, falling back to the last Conv2D of any kernel size.
    """
    if "last_decoder_conv" in {layer.name for layer in model.layers}:
        return "last_decoder_conv"

    conv_layers = [layer for layer in model.layers
                   if isinstance(layer, tf.keras.layers.Conv2D)]

    def has_spatial_kernel(layer):
        ksize = getattr(layer, "kernel_size", (1, 1))
        return ksize[0] > 1 or ksize[1] > 1

    return next(
        (layer.name for layer in reversed(conv_layers) if has_spatial_kernel(layer)),
        conv_layers[-1].name,
    )

def gradcam_segmentation(model, img01, cam_layer_name, region_mask=None):
    """
    Grad-CAM heat map for a single-output sigmoid segmentation model.

    The scalar target is the mean predicted probability over the whole tile
    (region_mask=None) or over the pixels where region_mask is 1. Channel
    weights are the spatially averaged gradients of that target with respect to
    the output of `cam_layer_name`. The weighted channel sum is passed through
    ReLU, resized bilinearly to the input size and min-max normalised.

    Args:
        model: Keras model whose output for one tile is indexed as pred[..., 0].
        img01: (H, W, C) float32 image, scaled the same way as the training data.
        cam_layer_name: name of an intermediate layer of `model`.
        region_mask: optional (H, W) {0, 1} array to focus the target on e.g.
            GT / FP / FN pixels. An all-zero mask yields an all-zero map.

    Returns:
        (H, W) float32 array in [0, 1].

    Raises:
        ValueError: if no gradient flows from the output to the chosen layer.
    """
    img = tf.convert_to_tensor(img01[None, ...], dtype=tf.float32)

    cam_layer = model.get_layer(cam_layer_name)
    # Pass model.inputs as-is. Wrapping it in another list gives the functional
    # model a nested input structure that newer Keras versions flag when the
    # model is then called with a single tensor.
    grad_model = tf.keras.Model(inputs=model.inputs, outputs=[cam_layer.output, model.output])

    with tf.GradientTape() as tape:
        conv_out, pred = grad_model(img, training=False)     # pred: (1,H,W,1)
        pred = pred[..., 0]                                  # (1,H,W)
        if region_mask is None:
            target = tf.reduce_mean(pred)
        else:
            rm = tf.convert_to_tensor(region_mask[None, ...], dtype=tf.float32)
            # Mean probability inside the region; epsilon avoids 0/0 for empty masks.
            target = tf.reduce_sum(pred * rm) / (tf.reduce_sum(rm) + 1e-12)

    grads = tape.gradient(target, conv_out)                  # (1,h,w,c)
    if grads is None:
        raise ValueError(
            f"No gradient flows from the model output to layer '{cam_layer_name}'; "
            "choose a layer on the path to the output."
        )
    weights = tf.reduce_mean(grads, axis=(1, 2))             # (1,c)
    cam = tf.reduce_sum(conv_out * weights[:, None, None, :], axis=-1)  # (1,h,w)
    cam = tf.nn.relu(cam)[0].numpy()

    cam = tf.image.resize(cam[..., None], (img01.shape[0], img01.shape[1]), method="bilinear").numpy()[..., 0]
    cam = cam - cam.min()
    cam = cam / (cam.max() + 1e-12)
    return cam.astype(np.float32)

def save_gradcam_panels(X, y_true, y_prob, idxs, cam_layer, out_png, focus="gt", model=None):
    """
    Save a figure with one row per tile index in `idxs` and five columns:
    stretched RGB | ground truth | prediction | error map | Grad-CAM overlay.

    Args:
        X: raw tiles indexed by tile; each is scaled with scale_np_image here.
        y_true: per-tile binary ground-truth masks, indexed by tile.
        y_prob: per-tile predicted probabilities, binarized with the global thr.
        idxs: tile indices to plot. An empty sequence writes nothing.
        cam_layer: layer name passed to gradcam_segmentation.
        out_png: output image path.
        focus: "gt", "fp" or "fn" restricts the Grad-CAM target to those pixels;
            any other value uses the whole tile.
        model: model to explain; defaults to the module-level unet_model.
    """
    n = len(idxs)
    if n == 0:
        # plt.subplots cannot build a figure with zero rows; nothing to draw.
        return
    if model is None:
        model = unet_model

    # squeeze=False keeps axs 2-D even when n == 1.
    fig, axs = plt.subplots(n, 5, figsize=(14, 2.7 * n), squeeze=False)

    for r, idx in enumerate(idxs):
        img01 = scale_np_image(X[idx])
        rgb_vis = stretch_rgb_per_channel(img01[..., :3])

        gt = y_true[idx].astype(np.uint8)
        prob = y_prob[idx].astype(np.float32)
        pred = binarize(prob, thr)
        em = error_map(gt, pred)
        _, fp, fn, _ = error_masks(gt, pred)

        if focus == "gt":
            region = gt.astype(np.float32)
        elif focus == "fp":
            region = fp.astype(np.float32)
        elif focus == "fn":
            region = fn.astype(np.float32)
        else:
            region = None  # whole-tile target

        cam = gradcam_segmentation(model, img01, cam_layer, region_mask=region)

        row = axs[r]
        row[0].imshow(rgb_vis)
        row[0].set_title(f"RGB (tile {idx})")
        row[1].imshow(gt, vmin=0, vmax=1)
        row[1].set_title("Ground truth")
        row[2].imshow(pred, vmin=0, vmax=1)
        row[2].set_title("Prediction")
        row[3].imshow(em, vmin=0, vmax=3)
        row[3].set_title("Errors (FP/FN/TP)")
        row[4].imshow(rgb_vis)
        row[4].imshow(cam, alpha=0.45)
        row[4].set_title(f"Grad-CAM ({focus})")

        for ax in row:
            ax.set_xticks([])
            ax.set_yticks([])

    fig.tight_layout()
    fig.savefig(out_png, dpi=300)
    plt.close(fig)

def save_indices_panel(X, idxs, out_png):
    """
    Save a figure with one row per tile index in `idxs` and four columns:
    stretched RGB | NDVI | MNDWI | NBR (indices displayed on [-1, 1]).

    Args:
        X: raw tiles indexed by tile; each is scaled with scale_np_image before
            the indices are computed.
        idxs: tile indices to plot. An empty sequence writes nothing.
        out_png: output image path.
    """
    n = len(idxs)
    if n == 0:
        # plt.subplots cannot build a figure with zero rows; nothing to draw.
        return

    # squeeze=False keeps axs 2-D even when n == 1.
    fig, axs = plt.subplots(n, 4, figsize=(12, 2.7 * n), squeeze=False)

    for r, idx in enumerate(idxs):
        img01 = scale_np_image(X[idx])
        rgb_vis = stretch_rgb_per_channel(img01[..., :3])
        ndvi, mndwi, nbr = compute_indices(img01)

        row = axs[r]
        row[0].imshow(rgb_vis)
        row[0].set_title(f"RGB (tile {idx})")
        for ax, (title, index_map) in zip(row[1:], (("NDVI", ndvi), ("MNDWI", mndwi), ("NBR", nbr))):
            ax.imshow(index_map, vmin=-1, vmax=1)
            ax.set_title(title)

        for ax in row:
            ax.set_xticks([])
            ax.set_yticks([])

    fig.tight_layout()
    fig.savefig(out_png, dpi=300)
    plt.close(fig)

# ============================================================
# Predict on test set (IMPORTANT: scaled correctly)
# ============================================================
# Scale every test tile with the same per-tile heuristic the tf.data pipeline
# applies via scale_tf_image, then predict probabilities for all of them.
X01_test = np.stack([scale_np_image(x) for x in X_test], axis=0).astype(np.float32)
# Drop the trailing channel axis that was added to the masks after loading.
y_true = y_test[..., 0].astype(np.uint8)

y_prob = unet_model.predict(X01_test, batch_size=16, verbose=1)[..., 0]  # (N,H,W)

# pick mixed tiles so panels are meaningful
# (first 5%..95% positive pixels, then a relaxed 1%..99% if none qualify).
# NOTE(review): if even the relaxed range matches nothing, `mixed` is empty and
# the panel calls below receive empty index arrays.
mixed = pick_mixed_tiles(y_true, min_frac=0.05, max_frac=0.95)
if mixed.size == 0:
    mixed = pick_mixed_tiles(y_true, min_frac=0.01, max_frac=0.99)

# Per-tile IoU of the thresholded prediction, only for the mixed tiles.
ious = []
for idx in mixed:
    pred = binarize(y_prob[idx], thr)
    ious.append(compute_iou(y_true[idx], pred))
ious = np.array(ious, dtype=np.float32)

# Up to N_SHOW highest- and lowest-IoU mixed tiles; the two sets can overlap
# when there are fewer than 2 * N_SHOW mixed tiles.
N_SHOW = 6
best_idxs  = mixed[np.argsort(-ious)[:N_SHOW]]
worst_idxs = mixed[np.argsort(ious)[:N_SHOW]]

cam_layer = auto_pick_cam_layer(unet_model)
print("Grad-CAM layer:", cam_layer)
print("Best idx:", best_idxs)
print("Worst idx:", worst_idxs)

# Grad-CAM figures: best/worst tiles with the target on GT pixels, plus worst
# tiles with the target restricted to false-positive / false-negative pixels.
save_gradcam_panels(X_test, y_true, y_prob, best_idxs,  cam_layer,
                    os.path.join(OUT_DIR, "gradcam_best_mixed_gtfocus.png"), focus="gt")

save_gradcam_panels(X_test, y_true, y_prob, worst_idxs, cam_layer,
                    os.path.join(OUT_DIR, "gradcam_worst_mixed_gtfocus.png"), focus="gt")

save_gradcam_panels(X_test, y_true, y_prob, worst_idxs, cam_layer,
                    os.path.join(OUT_DIR, "gradcam_worst_mixed_fp_focus.png"), focus="fp")

save_gradcam_panels(X_test, y_true, y_prob, worst_idxs, cam_layer,
                    os.path.join(OUT_DIR, "gradcam_worst_mixed_fn_focus.png"), focus="fn")

# Spectral-index context (NDVI / MNDWI / NBR) for the worst tiles.
save_indices_panel(X_test, worst_idxs, os.path.join(OUT_DIR, "spectral_indices_worst_mixed.png"))

# ============================================================
# Simple per-tile proxies + stratified summary (optional but helpful)
# ============================================================
def tile_proxies(img01):
    """
    Cheap per-tile scene descriptors used to stratify segmentation errors.

    Brightness is the per-pixel mean of the first three channels. Returned keys:
      mean_brightness / std_brightness  - statistics of that brightness
      bright_frac       - fraction of pixels with brightness > 0.85 (cloud/haze proxy)
      dark_frac         - fraction with brightness < 0.15 (shadow proxy)
      water_frac_mndwi  - fraction with MNDWI > 0.2 (water/inundation proxy)
      veg_frac_ndvi     - fraction with NDVI > 0.4 (strong-vegetation proxy)
      ndvi_mean / mndwi_mean / nbr_mean - tile means of the indices
    """
    brightness = img01[..., :3].mean(axis=-1)
    ndvi, mndwi, nbr = compute_indices(img01)

    return {
        "mean_brightness": float(brightness.mean()),
        "std_brightness": float(brightness.std()),
        "bright_frac": float(np.mean(brightness > 0.85)),
        "dark_frac": float(np.mean(brightness < 0.15)),
        "water_frac_mndwi": float(np.mean(mndwi > 0.2)),
        "veg_frac_ndvi": float(np.mean(ndvi > 0.4)),
        "ndvi_mean": float(ndvi.mean()),
        "mndwi_mean": float(mndwi.mean()),
        "nbr_mean": float(nbr.mean()),
    }

# One record per test tile: error rates, IoU and the scene proxies from tile_proxies.
rows = []
for i in range(X_test.shape[0]):
    img01 = X01_test[i]
    gt = y_true[i]
    pred = binarize(y_prob[i], thr)

    # FP / FN rates are normalised by the tile's total pixel count,
    # not by the number of negative / positive pixels.
    tn, fp, fn, tp = error_masks(gt, pred)
    fp_rate = float(fp.sum() / (gt.size + 1e-12))
    fn_rate = float(fn.sum() / (gt.size + 1e-12))
    iou = compute_iou(gt, pred)

    prox = tile_proxies(img01)

    rows.append({
        "tile_id": i,
        "gt_forest_frac": float(gt.mean()),
        "iou": iou,
        "fp_rate": fp_rate,
        "fn_rate": fn_rate,
        **prox
    })

# CSV columns follow the key order of the first record; this requires at least one test tile.
csv_path = os.path.join(OUT_DIR, "per_tile_error_proxies_6band_fixed.csv")
with open(csv_path, "w", newline="") as f:
    w = csv.DictWriter(f, fieldnames=list(rows[0].keys()))
    w.writeheader()
    w.writerows(rows)

def summarize_group(name, selector):
    """
    Aggregate the module-level per-tile `rows` over the tiles accepted by `selector`.

    Returns None when no tile is selected. Otherwise it returns a dict with the
    group name, the tile count and the mean fp_rate / fn_rate / iou of the
    selected tiles.
    """
    flags = np.array([selector(row) for row in rows], dtype=bool)
    if not flags.any():
        return None

    chosen = [row for row, keep in zip(rows, flags) if keep]
    return {
        "group": name,
        "n_tiles": int(flags.sum()),
        "mean_fp_rate": float(np.mean([row["fp_rate"] for row in chosen])),
        "mean_fn_rate": float(np.mean([row["fn_rate"] for row in chosen])),
        "mean_iou": float(np.mean([row["iou"] for row in chosen])),
    }

# Stratify tiles by the proxy thresholds. The high/low water groups are
# complementary (> 0.10 vs <= 0.10); the cloud and shadow groups may overlap
# with them and with each other. Groups that select no tile are dropped.
groups = []
groups.append(summarize_group("high_cloud_haze (bright_frac>0.10)", lambda r: r["bright_frac"] > 0.10))
groups.append(summarize_group("high_shadow (dark_frac>0.10)", lambda r: r["dark_frac"] > 0.10))
groups.append(summarize_group("high_water (water_frac_mndwi>0.10)", lambda r: r["water_frac_mndwi"] > 0.10))
groups.append(summarize_group("low_water (water_frac_mndwi<=0.10)", lambda r: r["water_frac_mndwi"] <= 0.10))
groups = [g for g in groups if g is not None]

# The summary CSV is written only when at least one group is non-empty.
sum_path = os.path.join(OUT_DIR, "stratified_error_summary_6band_fixed.csv")
if len(groups) > 0:
    with open(sum_path, "w", newline="") as f:
        w = csv.DictWriter(f, fieldnames=list(groups[0].keys()))
        w.writeheader()
        w.writerows(groups)

# Report outputs. The figure list is printed unconditionally; it is not checked
# against the files actually written.
print("Saved outputs to:", OUT_DIR)
print("Figures:")
print(" - gradcam_best_mixed_gtfocus.png")
print(" - gradcam_worst_mixed_gtfocus.png")
print(" - gradcam_worst_mixed_fp_focus.png")
print(" - gradcam_worst_mixed_fn_focus.png")
print(" - spectral_indices_worst_mixed.png")
print("CSVs:")
print(" -", csv_path)
if len(groups) > 0:
    print(" -", sum_path)
