This notebook computes and visualizes evaluation metrics, confusion matrices, and comparison plots for the trained segmentation models.

In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import rasterio
from rasterio.windows import Window
from tensorflow.keras.models import load_model

In [None]:
# Evaluation scene and its ground-truth mask.
RAW_IMAGE = "../data/raw/20241110_053942_45_24f7_3B_AnalyticMS_SR_8b_clip.tif"
RAW_MASK  = "../data/raw/20241110_053942_45_24f7_3B_AnalyticMS_SR_8b_clip_Hybrid_mask.tif"
# Folder holding the trained *.keras checkpoints.
MODEL_DIR = "../experiments_all"
SAVE_PATH = "../experiments_all/full_comparison_scene.png"
# Output folder for every artefact the cells below write (metrics CSV, plots);
# results are kept alongside the models.
OUT_DIR = MODEL_DIR
# Tile (height, width) in pixels fed to the models.
# NOTE(review): (256, 256) is an assumed value — it must match the input size
# the models were trained with; confirm against the training notebook.
IMG_SIZE = (256, 256)

In [None]:
from rasterio.windows import Window
def tile_raster_pair(image_path, mask_path, tile_size=IMG_SIZE):
    """Cut an image raster and its mask into aligned, non-overlapping tiles.

    Tiles are taken on a regular grid starting at the top-left corner; a
    partial tile at the right or bottom edge is dropped.

    Args:
        image_path: Path to the multi-band image raster (all bands are read).
        mask_path: Path to the mask raster (band 1 is read).
        tile_size: ``(height, width)`` of each tile in pixels, or a single
            integer for square tiles.

    Returns:
        ``(imgs, msks)``: ``imgs`` is float32 with shape
        ``(n_tiles, height, width, bands)``, each tile divided by its own
        maximum value across all bands; ``msks`` is uint8 with shape
        ``(n_tiles, height, width, 1)``. When no full tile fits, both arrays
        have ``n_tiles == 0`` but keep their trailing dimensions.

    Raises:
        ValueError: if the image and mask rasters differ in height/width,
            since the same pixel window would then cover different ground.
    """
    if np.ndim(tile_size) == 0:
        tile_size = (int(tile_size), int(tile_size))
    tile_h, tile_w = tile_size
    imgs, msks = [], []
    with rasterio.open(image_path) as img_src, rasterio.open(mask_path) as mask_src:
        if (img_src.height, img_src.width) != (mask_src.height, mask_src.width):
            raise ValueError(
                f"Image {image_path} is {img_src.height}x{img_src.width} but mask "
                f"{mask_path} is {mask_src.height}x{mask_src.width}; they must match."
            )
        n_bands = img_src.count
        for top in range(0, img_src.height - tile_h + 1, tile_h):
            for left in range(0, img_src.width - tile_w + 1, tile_w):
                # rasterio Window takes (col_off, row_off, width, height).
                window = Window(left, top, tile_w, tile_h)
                # (bands, H, W) -> (H, W, bands) for channel-last models.
                img = np.moveaxis(img_src.read(window=window), 0, 2).astype(np.float32)
                mask = mask_src.read(1, window=window).astype(np.uint8)
                # Per-tile max scaling; presumably mirrors the training-time
                # preprocessing — TODO confirm before changing it.
                img = img / (np.max(img) + 1e-8)
                imgs.append(img)
                msks.append(np.expand_dims(mask, -1))
    if not imgs:
        return (
            np.empty((0, tile_h, tile_w, n_bands), dtype=np.float32),
            np.empty((0, tile_h, tile_w, 1), dtype=np.uint8),
        )
    return np.array(imgs), np.array(msks)
# Tile the evaluation scene into (image, mask) pairs at the default tile size.
X, Y = tile_raster_pair(RAW_IMAGE, RAW_MASK)
if len(X) == 0:
    # Every cell below assumes at least one tile; stop here with a clear cause.
    raise ValueError(
        f"No full tiles could be cut from {RAW_IMAGE}; the raster is smaller "
        "than the tile size."
    )
print(f" Loaded {len(X)} tiles for evaluation, shape {X.shape}")

In [None]:
# Checkpoint for each architecture, read from MODEL_DIR.
model_paths = {
    "U-Net": os.path.join(MODEL_DIR, "unet.keras"),
    "ResU-Net": os.path.join(MODEL_DIR, "resunet.keras"),
    "Attn-U-Net": os.path.join(MODEL_DIR, "attnunet.keras"),
    "Attn-ResU-Net": os.path.join(MODEL_DIR, "attnresunet.keras"),
    "ASDMS": os.path.join(MODEL_DIR, "asdms.keras"),
}
models_loaded = {}
for name, path in model_paths.items():
    if os.path.exists(path):
        # The notebook only calls predict(), so the model need not be compiled.
        models_loaded[name] = load_model(path, compile=False)
        print(f"Loaded {name}")
    else:
        print(f"Missing model file: {path}")
if not models_loaded:
    # Later cells build tables/plots from models_loaded; fail clearly instead.
    raise FileNotFoundError(f"No model files found in {MODEL_DIR}; nothing to evaluate.")

In [None]:
def evaluate_model(model, X, Y_true, threshold=0.5):
    """Predict on ``X`` and score the thresholded predictions against ``Y_true``.

    Metrics pool all pixels of all tiles together (no per-tile averaging),
    with "foreground" as the positive class. Undefined ratios (zero
    denominator) are reported as 0.0.

    Args:
        model: Object with a Keras-style ``predict(X, verbose=0)`` method whose
            output has as many elements as ``Y_true``.
        X: Input tiles, passed unchanged to ``model.predict``.
        Y_true: Ground-truth masks; any nonzero value counts as foreground.
            NOTE(review): confirm the masks carry no nonzero nodata code.
        threshold: Predictions strictly above this value are foreground.

    Returns:
        ``(metrics, preds_bin, cm)``: ``metrics`` maps "Accuracy",
        "Precision", "Recall", "F1" and "IoU" to floats; ``preds_bin`` is the
        uint8 0/1 prediction array in the model's output shape; ``cm`` is the
        2x2 confusion matrix ``[[TN, FP], [FN, TP]]`` (rows = true class,
        columns = predicted class), always 2x2 even if a class is absent.

    Raises:
        ValueError: if predictions and ground truth differ in element count.
    """
    preds = model.predict(X, verbose=0)
    preds_bin = (preds > threshold).astype(np.uint8)
    if preds_bin.size != np.size(Y_true):
        raise ValueError(
            f"Prediction has {preds_bin.size} elements but ground truth has "
            f"{np.size(Y_true)}."
        )
    # Boolean views make the set operations below exact for 0/1 and 0/255 masks.
    y_true = np.asarray(Y_true).ravel() > 0
    y_pred = preds_bin.ravel().astype(bool)

    tp = int(np.count_nonzero(y_true & y_pred))
    fp = int(np.count_nonzero(~y_true & y_pred))
    fn = int(np.count_nonzero(y_true & ~y_pred))
    tn = int(y_true.size) - tp - fp - fn

    def _ratio(num, den):
        # Same convention as zero_division=0: an undefined ratio is 0.0.
        return num / den if den else 0.0

    metrics = {
        "Accuracy": _ratio(tp + tn, y_true.size),
        "Precision": _ratio(tp, tp + fp),
        "Recall": _ratio(tp, tp + fn),
        "F1": _ratio(2 * tp, 2 * tp + fp + fn),
        # |A ∩ B| / |A ∪ B| with an exact denominator (no per-pixel epsilon).
        "IoU": _ratio(tp, tp + fp + fn),
    }
    cm = np.array([[tn, fp], [fn, tp]], dtype=np.int64)
    return metrics, preds_bin, cm
# Score every loaded model on the same tiles; results are keyed by model name.
metrics_dict, conf_matrices = {}, {}
predictions = {}
for name, model in models_loaded.items():
    print(f"\n Evaluating {name}...")
    metrics, preds_bin, cm = evaluate_model(model, X, Y)
    metrics_dict[name] = metrics
    predictions[name] = preds_bin
    conf_matrices[name] = cm
    print(metrics)
# One row per model, one column per metric.
metrics_df = pd.DataFrame(metrics_dict).T
# to_csv does not create missing folders, so make sure OUT_DIR exists first.
os.makedirs(OUT_DIR, exist_ok=True)
metrics_df.to_csv(os.path.join(OUT_DIR, "model_metrics.csv"))
print("\n Metrics saved to model_metrics.csv")

In [None]:
# One bar chart per headline metric, one bar per model. Drawn with matplotlib
# directly: seaborn is not imported anywhere in this notebook.
for metric_name, filename in (
    ("Accuracy", "accuracy_graph.png"),
    ("Recall", "recall_graph.png"),
):
    plt.figure(figsize=(8, 5))
    plt.bar(metrics_df.index, metrics_df[metric_name])
    plt.ylabel(metric_name)
    plt.title(f"Model {metric_name} Comparison")
    plt.xticks(rotation=30)
    plt.tight_layout()
    plt.savefig(os.path.join(OUT_DIR, filename), dpi=300)
    plt.show()

In [None]:
# Annotated confusion-matrix heatmap per model (rows = true, columns =
# predicted). Drawn with matplotlib since seaborn is not imported here.
for name, cm in conf_matrices.items():
    fig, ax = plt.subplots(figsize=(4, 4))
    im = ax.imshow(cm, cmap="Blues")
    fig.colorbar(im, ax=ax)
    # Write each integer count in its cell; white text on the darker half.
    half = cm.max() / 2 if cm.size else 0
    for row, col in np.ndindex(cm.shape):
        count = int(cm[row, col])
        ax.text(col, row, str(count), ha="center", va="center",
                color="white" if count > half else "black")
    ax.set_xticks(range(cm.shape[1]))
    ax.set_yticks(range(cm.shape[0]))
    ax.set_title(f"Confusion Matrix — {name}")
    ax.set_xlabel("Predicted")
    ax.set_ylabel("True")
    fig.tight_layout()
    fig.savefig(os.path.join(OUT_DIR, f"{name}_confusion_matrix.png"), dpi=300)
    plt.show()

In [None]:
# Build a natural-colour RGB composite of the scene and load its mask.
with rasterio.open(RAW_IMAGE) as src:
    n_bands = src.count
    if n_bands >= 8:
        # PlanetScope 8-band SR order: Coastal Blue, Blue, Green I, Green,
        # Yellow, Red, Red Edge, NIR -> Red/Green/Blue are 0-based 5/3/1.
        rgb_idx = [5, 3, 1]
    elif n_bands > 4:
        # NOTE(review): band choice for 5-7 band inputs kept unchanged; confirm.
        rgb_idx = [4, 2, 1]
    else:
        rgb_idx = list(range(min(n_bands, 3)))
    # Read only the display bands (rasterio band indexes are 1-based) rather
    # than the whole multi-band scene; (bands, H, W) -> (H, W, bands).
    rgb = np.moveaxis(src.read([i + 1 for i in rgb_idx]), 0, 2).astype(np.float32)
    # Joint min-max stretch of the three channels into [0, 1] for display.
    rgb = (rgb - np.min(rgb)) / (np.max(rgb) - np.min(rgb) + 1e-6)
with rasterio.open(RAW_MASK) as src:
    mask = src.read(1)
# Plain natural-colour view of the scene.
plt.figure(figsize=(10, 6))
plt.imshow(rgb)
plt.title("Satellite Image (Natural Color)")
plt.axis("off")
plt.savefig(os.path.join(OUT_DIR, "satellite_only.png"), dpi=300)
plt.show()

# Mask overlay. Background pixels (mask == 0) are masked out so they stay
# transparent; otherwise the near-white low end of "Reds" drawn at alpha=0.5
# veils the whole scene. vmin/vmax are pinned so a single-valued mask still
# maps to a saturated red instead of the colormap's lightest shade.
mask_overlay = np.ma.masked_equal(mask, 0)
plt.figure(figsize=(10, 6))
plt.imshow(rgb, alpha=0.8)
plt.imshow(mask_overlay, cmap="Reds", alpha=0.5, vmin=0, vmax=max(1, int(mask.max())))
plt.title("Satellite vs Mask Overlay")
plt.axis("off")
plt.savefig(os.path.join(OUT_DIR, "satellite_vs_mask.png"), dpi=300)
plt.show()
# Closing status: tell the user where every artefact above was written.
print(
    "\n All evaluations and visualizations completed!",
    f" Check results in folder: {OUT_DIR}",
    sep="\n",
)