# StegaShield Visualization Dashboard

This notebook inspects pairs of original vs. watermarked images, computes PSNR/SSIM, and produces summary plots similar to the reference screenshots.

**Quick start**
1. Set the dataset paths in the config cell below (it expects `<DATA_ROOT>/<split>/originals` and `<DATA_ROOT>/<split>/watermarked` for each split).
2. Optional: run the `pip install` cell if you're in a fresh Colab/runtime.
3. Execute the notebook top to bottom to regenerate tables and visualizations.



In [None]:
# Optional: install dependencies when running in Colab / fresh env
# !pip install --quiet numpy pandas matplotlib seaborn pillow opencv-python


In [None]:
from pathlib import Path
from typing import List, Dict
import shutil

import cv2
import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
import seaborn as sns
from IPython.display import display

from stegashield_profiles import embed_image
from training.test_harness_det import TestHarness

plt.style.use("seaborn-v0_8")

# --- Configuration ---
ORIGINAL_SUBDIR = "originals"
WATERMARKED_SUBDIR = "watermarked"
MAX_PAIRS_PER_SPLIT = None  # set to an int to subsample
DISPLAY_SAMPLE_COUNT = 3

# Scenario 1: you already have dataset/<split>/{originals,watermarked}
DATA_ROOT = Path("dataset")
SPLITS: List[str] = ["train", "val", "test"]

# Scenario 2: you only have a folder of raw originals
AUTO_BUILD_FROM_ORIGINALS = True
ORIGINALS_ONLY_DIR = Path("originals_unused")
AUTO_DATA_ROOT = Path("dataset_autogen")
AUTO_SPLIT_NAME = "auto"
WATERMARK_MODE = "hybrid"
AUTO_MESSAGE = "AutoProfile"
AUTO_USER_KEY = None  # set to a string if you want hashed payloads
AUTO_MAX_ORIGINALS = 200  # set to None for no limit


def _move_into_place(produced: Path, target: Path) -> None:
    """Move *produced* to *target*, replacing any file already at *target*.

    The two paths are compared after ``resolve()`` so that an absolute and a
    relative spelling of the same file count as identical. Comparing them
    unresolved would treat them as different, delete the freshly produced
    file via ``unlink()`` and then fail on the move.
    """
    if produced.resolve() == target.resolve():
        return
    if target.exists():
        target.unlink()
    shutil.move(produced, target)


if AUTO_BUILD_FROM_ORIGINALS and ORIGINALS_ONLY_DIR.exists():
    auto_orig_dir = AUTO_DATA_ROOT / AUTO_SPLIT_NAME / ORIGINAL_SUBDIR
    auto_wm_dir = AUTO_DATA_ROOT / AUTO_SPLIT_NAME / WATERMARKED_SUBDIR
    auto_orig_dir.mkdir(parents=True, exist_ok=True)
    auto_wm_dir.mkdir(parents=True, exist_ok=True)

    # iterdir() yields entries in arbitrary order; sorting makes the
    # AUTO_MAX_ORIGINALS cut-off pick the same files on every run.
    selected_paths = sorted(p for p in ORIGINALS_ONLY_DIR.iterdir() if p.is_file())
    if AUTO_MAX_ORIGINALS is not None:
        selected_paths = selected_paths[:AUTO_MAX_ORIGINALS]

    for src_path in selected_paths:
        dst_orig = auto_orig_dir / src_path.name
        if not dst_orig.exists():
            shutil.copy2(src_path, dst_orig)

        # Watermarked outputs are always stored under a .png name.
        dst_wm = auto_wm_dir / src_path.with_suffix(".png").name
        if dst_wm.exists():
            print(f"ℹ️ Already watermarked: {dst_wm.name}")
            continue

        try:
            meta = embed_image(
                image_path=str(src_path),
                message=AUTO_MESSAGE,
                mode=WATERMARK_MODE,
                user_key=AUTO_USER_KEY,
                output_dir=str(auto_wm_dir),
            )
        except ValueError as exc:
            # Best effort: one rejected image must not abort the whole build.
            print(f"⚠️ Skipping {src_path.name}: {exc}")
            continue

        # Rename whatever embed_image produced to the names find_pairs expects.
        _move_into_place(Path(meta["image_path"]), dst_wm)
        _move_into_place(
            Path(meta["metadata_path"]),
            auto_wm_dir / f"{dst_wm.stem}_metadata.json",
        )
        print(f"✅ Watermarked {src_path.name} → {dst_wm.name}")

    # From here on the notebook analyses the auto-generated dataset only.
    DATA_ROOT = AUTO_DATA_ROOT
    SPLITS = [AUTO_SPLIT_NAME]
    print(f"✓ Auto-generated dataset with {sum(1 for _ in auto_orig_dir.iterdir())} originals at {AUTO_DATA_ROOT}")

# Explicit raise instead of assert: asserts are stripped under `python -O`.
if not DATA_ROOT.exists():
    raise FileNotFoundError(f"Dataset root not found: {DATA_ROOT.resolve()}")

def find_pairs(split: str) -> List[Dict]:
    """Pair originals with watermarked files of one split by filename stem.

    Looks in ``DATA_ROOT / split / ORIGINAL_SUBDIR`` and
    ``DATA_ROOT / split / WATERMARKED_SUBDIR``. Files whose stem appears in
    both folders form a pair; extensions may differ.

    Returns:
        A list of dicts with keys ``split``, ``image_id``, ``original_path``
        and ``watermarked_path``, ordered by stem and truncated to
        ``MAX_PAIRS_PER_SPLIT`` when that is set to a non-zero value. An empty
        list is returned (with a printed warning) if either folder is missing.
    """
    split_root = DATA_ROOT / split
    orig_dir = split_root / ORIGINAL_SUBDIR
    wm_dir = split_root / WATERMARKED_SUBDIR
    if not (orig_dir.exists() and wm_dir.exists()):
        print(f"⚠️ Skipping split '{split}' (missing {orig_dir} or {wm_dir})")
        return []

    def index_by_stem(folder):
        # Later entries with the same stem overwrite earlier ones.
        return {entry.stem: entry for entry in folder.iterdir() if entry.is_file()}

    originals = index_by_stem(orig_dir)
    watermarked = index_by_stem(wm_dir)

    shared_stems = sorted(originals.keys() & watermarked.keys())
    if MAX_PAIRS_PER_SPLIT:
        shared_stems = shared_stems[:MAX_PAIRS_PER_SPLIT]

    pairs = [
        {
            "split": split,
            "image_id": stem,
            "original_path": originals[stem],
            "watermarked_path": watermarked[stem],
        }
        for stem in shared_stems
    ]
    print(f"✓ {split}: found {len(pairs)} matched pairs")
    return pairs

# Gather the matched pairs of every configured split into one flat list.
all_pairs: List[Dict] = []
for split in SPLITS:
    all_pairs += find_pairs(split)

# Every later cell needs at least one pair, so stop the run here otherwise.
if len(all_pairs) == 0:
    raise SystemExit("No overlapping images found between original/watermarked folders.")

# Last expression of the cell: the notebook displays the total pair count.
len(all_pairs)


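> **Tip:** If you only have raw originals (e.g., `originals_unused/`), leave `AUTO_BUILD_FROM_ORIGINALS = True`. The notebook will copy up to `AUTO_MAX_ORIGINALS` of them into `dataset_autogen/auto/originals`, embed watermarks on the fly (using `WATERMARK_MODE`, `"hybrid"` by default), and use those generated outputs for every visualization cell below. Images that already have a watermarked counterpart are skipped on re-runs. Set the flag to `False` once you have a curated dataset under `dataset/`; otherwise the auto-generated dataset takes precedence whenever `originals_unused/` exists.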


## Optional: Run Attack Harness from Here
Toggle `RUN_HARNESS = True` in the next cell to embed + attack-test a subset of your dataset without leaving this notebook. The harness will populate `outputs/results.csv`, which feeds the attack visualizations later.


In [None]:
# --- Attack harness settings ---
RUN_HARNESS = False  # flip to True to embed + attack-test from this notebook
HARNESS_IMAGE_LIMIT = 50  # applied to each split separately; falsy disables the cap
HARNESS_OUTPUT_DIR = Path("outputs")
HARNESS_RESULTS_NAME = "results.csv"
HARNESS_MESSAGE_LSB = "StegaShield_Dataset2025"
HARNESS_MESSAGE_SEMI = "StegaShield_SemiFragile"

if RUN_HARNESS:
    # Collect the original images of every split that exists on disk.
    harness_images = []
    for split in SPLITS:
        src_dir = DATA_ROOT / split / ORIGINAL_SUBDIR
        if src_dir.exists():
            candidates = sorted(entry for entry in src_dir.iterdir() if entry.is_file())
            if HARNESS_IMAGE_LIMIT:
                candidates = candidates[:HARNESS_IMAGE_LIMIT]
            harness_images += [str(entry) for entry in candidates]

    if not harness_images:
        print("⚠️ No images found to run the harness. Check DATA_ROOT and splits.")
    else:
        harness = TestHarness(
            output_dir=str(HARNESS_OUTPUT_DIR),
            lsb_message=HARNESS_MESSAGE_LSB,
            semi_message=HARNESS_MESSAGE_SEMI,
        )
        csv_path = harness.run_batch(harness_images, csv_name=HARNESS_RESULTS_NAME)
        print(f"✅ Test harness completed. Results stored at {csv_path}")


In [None]:
# Location of the harness results consumed by the attack-robustness cells.
# Kept as a literal so this cell also works when the harness cell was skipped.
HARNESS_RESULTS_CSV = Path("outputs", "results.csv")
print(f"Harness CSV: {HARNESS_RESULTS_CSV}")


In [None]:
def load_rgb(path: Path) -> np.ndarray:
    """Read the image at *path* with OpenCV and return it in RGB channel order.

    Raises:
        FileNotFoundError: if ``cv2.imread`` returns ``None`` for *path*.
    """
    bgr = cv2.imread(str(path), cv2.IMREAD_COLOR)
    if bgr is not None:
        # OpenCV loads colour images as BGR; the plotting cells expect RGB.
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    raise FileNotFoundError(path)


def compute_psnr(img_a: np.ndarray, img_b: np.ndarray) -> float:
    """Return the peak signal-to-noise ratio between two images via ``cv2.PSNR``."""
    psnr_value = cv2.PSNR(img_a, img_b)
    return psnr_value


def compute_ssim(img_a: np.ndarray, img_b: np.ndarray) -> float:
    """Return the mean SSIM between two RGB images, measured on grayscale.

    Both images are converted to grayscale, local statistics are estimated
    with an 11x11 Gaussian window (sigma 1.5), and the per-pixel SSIM map is
    averaged. The stabilising constants assume an 8-bit dynamic range (255).

    Args:
        img_a: First image in RGB channel order.
        img_b: Second image in RGB channel order, same size as *img_a*.

    Returns:
        The mean of the SSIM map as a Python float.
    """
    # Work in float64. On the uint8 arrays returned by cvtColor, products such
    # as gray * gray wrap around modulo 256 and the blurred means are rounded
    # to integers, which corrupts every term of the SSIM formula.
    gray_a = cv2.cvtColor(img_a, cv2.COLOR_RGB2GRAY).astype(np.float64)
    gray_b = cv2.cvtColor(img_b, cv2.COLOR_RGB2GRAY).astype(np.float64)

    C1 = (0.01 * 255) ** 2
    C2 = (0.03 * 255) ** 2
    kernel = (11, 11)
    sigma = 1.5

    # Local means.
    mu1 = cv2.GaussianBlur(gray_a, kernel, sigma)
    mu2 = cv2.GaussianBlur(gray_b, kernel, sigma)
    mu1_sq = mu1 * mu1
    mu2_sq = mu2 * mu2
    mu1_mu2 = mu1 * mu2

    # Local variances and covariance: E[x*y] - E[x]*E[y].
    sigma1_sq = cv2.GaussianBlur(gray_a * gray_a, kernel, sigma) - mu1_sq
    sigma2_sq = cv2.GaussianBlur(gray_b * gray_b, kernel, sigma) - mu2_sq
    sigma12 = cv2.GaussianBlur(gray_a * gray_b, kernel, sigma) - mu1_mu2

    numerator = (2 * mu1_mu2 + C1) * (2 * sigma12 + C2)
    denominator = (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)
    ssim_map = numerator / denominator
    return float(ssim_map.mean())


def compute_diff_heatmap(img_a: np.ndarray, img_b: np.ndarray, gain: float = 6.0) -> np.ndarray:
    """Return an amplified signed difference ``img_b - img_a`` as a uint8 image.

    The difference is multiplied by *gain* and centred on 128, so unchanged
    pixels map to mid-grey; values are clipped to ``[0, 255]`` before the
    conversion to ``uint8``.
    """
    # Subtract in float32 so negative differences are not lost to uint8 wrap-around.
    signed_delta = img_b.astype(np.float32)
    signed_delta -= img_a.astype(np.float32)
    amplified = signed_delta * gain + 128
    return amplified.clip(0, 255).astype(np.uint8)


In [None]:
# Compute PSNR and SSIM for every matched pair; one row per image.
records = []
for pair in all_pairs:
    original_rgb = load_rgb(pair["original_path"])
    watermarked_rgb = load_rgb(pair["watermarked_path"])
    records.append(
        dict(
            split=pair["split"],
            image_id=pair["image_id"],
            psnr=compute_psnr(original_rgb, watermarked_rgb),
            ssim=compute_ssim(original_rgb, watermarked_rgb),
            original_path=pair["original_path"],
            watermarked_path=pair["watermarked_path"],
        )
    )

# Column order follows the key order above and carries over to the exported CSV.
metrics_df = pd.DataFrame(records)
metrics_df.head()


In [None]:
def show_sample_triples(df: pd.DataFrame, n: int = 3):
    """Plot original / watermarked / amplified-difference triples for sampled rows.

    Args:
        df: Metrics frame with ``original_path``, ``watermarked_path``,
            ``image_id``, ``psnr`` and ``ssim`` columns.
        n: Maximum number of rows to show. Rows are drawn with a fixed
            ``random_state`` so repeated runs display the same images.

    Returns:
        None. When there is nothing to show (empty frame or ``n <= 0``) a
        message is printed instead of creating a figure.
    """
    count = min(n, len(df))
    if count <= 0:
        # plt.subplots() rejects a grid with zero rows.
        print("No samples to display.")
        return

    subset = df.sample(count, random_state=42)
    # squeeze=False keeps `axes` two-dimensional even for a single row.
    _, axes = plt.subplots(count, 3, figsize=(14, 4 * count), squeeze=False)

    for row_axes, (_, row) in zip(axes, subset.iterrows()):
        orig = load_rgb(row["original_path"])
        wm = load_rgb(row["watermarked_path"])
        diff = compute_diff_heatmap(orig, wm, gain=8.0)

        row_axes[0].imshow(orig)
        row_axes[0].set_title(f"Original\n{row['image_id']}")
        row_axes[0].axis("off")

        row_axes[1].imshow(wm)
        row_axes[1].set_title(f"Watermarked\nPSNR: {row['psnr']:.2f} dB, SSIM: {row['ssim']:.4f}")
        row_axes[1].axis("off")

        row_axes[2].imshow(diff)
        row_axes[2].set_title("Difference (amplified)")
        row_axes[2].axis("off")

    plt.tight_layout()

show_sample_triples(metrics_df, n=DISPLAY_SAMPLE_COUNT)


In [None]:
# Quality targets drawn as dashed red reference lines on every panel.
PSNR_THRESHOLD = 30.0
SSIM_THRESHOLD = 0.90

fig, axes = plt.subplots(2, 3, figsize=(18, 10))
threshold_line = {"color": "red", "linestyle": "--"}

# Stacked histograms per split, with the threshold labelled in the legend.
# NOTE(review): calling legend() after a hue-based histplot may replace the
# split legend with only the threshold entry; confirm against the rendered plot.
histogram_panels = (
    (axes[0, 0], "psnr", PSNR_THRESHOLD, f"{PSNR_THRESHOLD} dB", "PSNR Distribution by Split"),
    (axes[0, 1], "ssim", SSIM_THRESHOLD, f"{SSIM_THRESHOLD} SSIM", "SSIM Distribution by Split"),
)
for panel, metric, threshold, legend_label, title in histogram_panels:
    sns.histplot(data=metrics_df, x=metric, hue="split", multiple="stack", ax=panel)
    panel.axvline(threshold, label=legend_label, **threshold_line)
    panel.set_title(title)
    panel.legend()

# Box plots per split with the threshold as a horizontal line.
box_panels = (
    (axes[0, 2], "psnr", PSNR_THRESHOLD, "PSNR Box Plot by Split"),
    (axes[1, 0], "ssim", SSIM_THRESHOLD, "SSIM Box Plot by Split"),
)
for panel, metric, threshold, title in box_panels:
    sns.boxplot(data=metrics_df, x="split", y=metric, ax=panel)
    panel.axhline(threshold, **threshold_line)
    panel.set_title(title)

# PSNR against SSIM, with both thresholds marked.
scatter_ax = axes[1, 1]
sns.scatterplot(data=metrics_df, x="psnr", y="ssim", hue="split", ax=scatter_ax, s=25, alpha=0.8)
scatter_ax.axvline(PSNR_THRESHOLD, **threshold_line)
scatter_ax.axhline(SSIM_THRESHOLD, **threshold_line)
scatter_ax.set_title("PSNR vs SSIM Correlation")

# Last panel: mean ± std per split rendered as a table instead of a plot.
table_ax = axes[1, 2]
table_ax.axis("off")
summary = metrics_df.groupby("split").agg(
    psnr_mean=("psnr", "mean"),
    psnr_std=("psnr", "std"),
    ssim_mean=("ssim", "mean"),
    ssim_std=("ssim", "std"),
)
summary_cells = [
    [
        f"{row.psnr_mean:.2f} ± {row.psnr_std:.2f}",
        f"{row.ssim_mean:.3f} ± {row.ssim_std:.3f}",
    ]
    for row in summary.itertuples()
]
table_ax.table(
    cellText=summary_cells,
    rowLabels=summary.index,
    colLabels=["PSNR (dB)", "SSIM"],
    loc="center",
)
table_ax.set_title("Summary Statistics")

plt.tight_layout()



In [None]:
# Persist the per-image metrics next to the dataset that was analysed.
OUTPUT_METRICS_CSV = DATA_ROOT.joinpath("metrics_summary.csv")
metrics_df.to_csv(OUTPUT_METRICS_CSV, index=False)
print(f"Metrics written to {OUTPUT_METRICS_CSV}")


## Attack Robustness Visuals

If you have already run `training/test_harness_det.py`, this section consumes the resulting `outputs/results.csv` file to show how each profile layer survives different attacks.


In [None]:
# Load the harness results if present; later cells check `attacks_df is None`.
if HARNESS_RESULTS_CSV.exists():
    attacks_df = pd.read_csv(HARNESS_RESULTS_CSV)

    def to_bool(val):
        """Normalise one ``decode_success`` cell to ``True`` / ``False`` / NaN.

        Missing values stay NaN so they can be told apart from failures.
        Numeric cells are true only when equal to 1; strings are true when
        they match one of the accepted tokens, ignoring case and whitespace.
        """
        if pd.isna(val):
            return np.nan
        if isinstance(val, (bool, np.bool_)):
            return bool(val)
        if isinstance(val, (int, float, np.integer, np.floating)):
            # A 0/1 column containing blanks is parsed as float64, so the flag
            # arrives as 1.0; its string form "1.0" would not match "1" below.
            return bool(val == 1)
        return str(val).strip().lower() in {"true", "1", "yes", "y"}

    attacks_df["decode_success"] = attacks_df["decode_success"].apply(to_bool)
    # Unparseable accuracy values become NaN rather than raising.
    attacks_df["bit_accuracy"] = pd.to_numeric(attacks_df["bit_accuracy"], errors="coerce")
    print(f"Loaded {len(attacks_df)} attack evaluations from {HARNESS_RESULTS_CSV}")
else:
    attacks_df = None
    print(f"⚠️ Harness CSV not found at {HARNESS_RESULTS_CSV}. Run the test harness first or update the path.")
    print("   Generate it via training/test_harness_det.py → TestHarness.run_batch(...)")


In [None]:
# Aggregate harness results per (layer, attack) and draw the success heatmap.
if attacks_df is None or attacks_df.empty:
    attack_summary = None
    print("No attack data available to plot.")
else:
    per_layer_attack = attacks_df.groupby(["layer", "attack"])
    attack_summary = per_layer_attack.agg(
        success_rate=("decode_success", "mean"),
        avg_bit_accuracy=("bit_accuracy", "mean"),
        samples=("decode_success", "size"),
    ).reset_index()

    # Attacks as rows, layers as columns; combinations without data show as 0.
    success_by_layer = attack_summary.pivot(index="attack", columns="layer", values="success_rate")
    heatmap_data = success_by_layer.fillna(0)

    # Grow the figure with the number of attacks, but never below 6 inches.
    fig_height = max(6, heatmap_data.shape[0] * 0.3)
    plt.figure(figsize=(10, fig_height))
    sns.heatmap(heatmap_data, annot=True, fmt=".2f", vmin=0, vmax=1, cmap="YlGnBu")
    plt.title("Decode Success Rate by Attack / Layer")
    plt.xlabel("Layer")
    plt.ylabel("Attack")
    plt.tight_layout()


In [None]:
# Per-layer drill-down: bit accuracy for SemiFragile, success rate for LSB.
if attack_summary is None or attack_summary.empty:
    print("No attack summary to visualize.")
else:
    is_semi = attack_summary["layer"] == "SemiFragile"
    semi_summary = attack_summary[is_semi].sort_values("avg_bit_accuracy", ascending=False)
    semi_summary = semi_summary.reset_index(drop=True)

    # One bar per attack, best-surviving attacks first.
    chart_height = max(4, semi_summary.shape[0] * 0.3)
    plt.figure(figsize=(10, chart_height))
    sns.barplot(data=semi_summary, x="avg_bit_accuracy", y="attack", palette="magma")
    plt.title("Average Bit Accuracy per Attack (Semi-Fragile)")
    plt.xlabel("Bit Accuracy")
    plt.ylabel("Attack")
    plt.xlim(0, 1)
    plt.tight_layout()

    print("Top attacks harming Semi-Fragile layer (by bit accuracy):")
    semi_columns = ["attack", "avg_bit_accuracy", "success_rate", "samples"]
    weakest_semi = semi_summary.sort_values("avg_bit_accuracy").head(5)
    display(weakest_semi[semi_columns])

    is_lsb = attack_summary["layer"] == "LSB"
    lsb_summary = attack_summary[is_lsb].sort_values("success_rate")
    print("\nWorst-case attacks for LSB layer (ownership-proof):")
    lsb_columns = ["attack", "success_rate", "samples"]
    display(lsb_summary.head(5)[lsb_columns])


## Notes
- The histogram/box-plot thresholds (`PSNR_THRESHOLD`, `SSIM_THRESHOLD`) highlight the quality targets you care about—tweak them per project.
- To inspect additional splits (e.g., `holdout`, `ablation`), just add them to `SPLITS` and rerun the first config cell.
- `metrics_summary.csv` captures per-image stats so you can feed them into other dashboards or reports.

