###############################################################################################################################

In [None]:
import os, json, cv2, math
from concurrent.futures import ThreadPoolExecutor, as_completed
from tesserocr import PyTessBaseAPI, OEM, PSM
from PIL import Image

# CONFIG
# NOTE: despite the TRAIN_DIR name, this path points at the "val" split.
TRAIN_DIR = r"C:\Users\stopc\Desktop\LPR_Project\data\A\val"
METADATA_PATH = os.path.join(TRAIN_DIR, "metadata.json")
N_WORKERS = 8  # number of OCR worker threads / chunks

# LOAD METADATA
# JSON is UTF-8 by specification; be explicit so the Windows locale codec is
# never used (this also matches the later cell that loads full_grid metadata).
with open(METADATA_PATH, "r", encoding="utf-8") as f:
    records = json.load(f)


# WORKER FUNCTION
def process_chunk(chunk):
    """
    Run per-digit OCR on every plate described in <chunk>.

    Each record supplies "index", "plate_number" and "digit_bboxes"; the image
    is read from TRAIN_DIR as original_<index>.png, binarised with a fixed
    threshold, and every bbox crop is recognised on its own.

    Returns: (digit_ok, digit_total, plate_ok, plate_total, lines)
    Plates whose image cannot be read get a report line but are not counted.
    """
    digits_right = digits_seen = plates_right = plates_seen = 0
    report = []

    # One Tesseract handle is created per call and reused for every crop.
    with PyTessBaseAPI(oem=OEM.LSTM_ONLY, psm=PSM.SINGLE_WORD) as api:
        for option, value in (
            ("tessedit_char_whitelist", "0123456789"),
            ("load_system_dawg", "0"),
            ("load_freq_dawg", "0"),
        ):
            api.SetVariable(option, value)

        for rec in chunk:
            idx, truth = rec["index"], rec["plate_number"]
            plate_img = cv2.imread(os.path.join(TRAIN_DIR, f"original_{idx}.png"))
            if plate_img is None:
                report.append(f"Plate {idx:4d} | [MISSING IMAGE]")
                continue

            grey = cv2.cvtColor(plate_img, cv2.COLOR_BGR2GRAY)
            binary = cv2.threshold(grey, 150, 255, cv2.THRESH_BINARY)[1]

            read_digits = []
            for pos, (x, y, w, h) in enumerate(rec["digit_bboxes"]):
                # tesserocr takes a PIL image, so wrap the NumPy crop
                api.SetImage(Image.fromarray(binary[y : y + h, x : x + w]))
                text = api.GetUTF8Text().strip()

                # keep only a leading digit; anything else counts as "nothing read"
                char = text[0] if text[:1].isdigit() else ""
                read_digits.append(char)

                digits_seen += 1
                if char == truth[pos]:
                    digits_right += 1

            ocr_str = "".join(read_digits)
            matched = ocr_str == truth
            plates_seen += 1
            if matched:
                plates_right += 1

            report.append(f"Plate {idx:4d} | GT={truth} | OCR={ocr_str:<6} | {'OK' if matched else 'ERR'}")

    return digits_right, digits_seen, plates_right, plates_seen, report


#  SPLIT WORK & RUN THREADS
# max(1, ...) keeps the range() step valid when the metadata list is empty
# (a step of 0 raises ValueError).
chunk_size = max(1, math.ceil(len(records) / N_WORKERS))
chunks = [records[i : i + chunk_size] for i in range(0, len(records), chunk_size)]

digit_ok = digit_tot = plate_ok = plate_tot = 0
all_lines = []

with ThreadPoolExecutor(max_workers=N_WORKERS) as exe:
    futures = [exe.submit(process_chunk, c) for c in chunks]
    # Totals are plain sums, so completion order does not matter.
    for fut in as_completed(futures):
        d_ok, d_tot, p_ok, p_tot, lines = fut.result()
        digit_ok += d_ok
        digit_tot += d_tot
        plate_ok += p_ok
        plate_tot += p_tot
        all_lines.extend(lines)

#  SUMMARY
# Totals stay at zero when no image could be read (e.g. a wrong directory);
# report that instead of dying with ZeroDivisionError.
print("\nRESULTS")
if digit_tot:
    print(f"  Digit-level accuracy: {digit_ok / digit_tot * 100:.2f}% " f"({digit_ok}/{digit_tot})")
else:
    print("  Digit-level accuracy: n/a (0/0)")
if plate_tot:
    print(f"  Plate-level accuracy: {plate_ok / plate_tot * 100:.2f}% " f"({plate_ok}/{plate_tot})")
else:
    print("  Plate-level accuracy: n/a (0/0)")

In [None]:
import os, json, cv2, math
from concurrent.futures import ThreadPoolExecutor, as_completed
from tesserocr import PyTessBaseAPI, OEM, PSM
from PIL import Image

# CONFIG
DIR = r"C:\Users\stopc\Desktop\LPR_Project\results\A\restormer"  # reconstructed_<index>.png images
METADATA_DIR = r"C:\Users\stopc\Desktop\LPR_Project\data\full_grid"
METADATA_PATH = os.path.join(METADATA_DIR, "metadata.json")
N_WORKERS = 8  # number of OCR worker threads / chunks

# LOAD METADATA
# JSON is UTF-8 by specification; be explicit so the Windows locale codec is
# never used (this also matches the later cell that loads the same file).
with open(METADATA_PATH, "r", encoding="utf-8") as f:
    records = json.load(f)


# WORKER FUNCTION
def process_chunk(chunk):
    """
    Run per-digit OCR on every reconstructed plate described in <chunk>.

    Each record supplies "index", "plate_number" and "digit_bboxes"; the image
    is read from DIR as reconstructed_<index>.png, binarised with a fixed
    threshold, and every bbox crop is recognised on its own.

    Returns: (digit_ok, digit_total, plate_ok, plate_total, lines)
    Plates whose image cannot be read get a report line but are not counted.
    """
    digits_right = digits_seen = plates_right = plates_seen = 0
    report = []

    # One Tesseract handle is created per call and reused for every crop.
    with PyTessBaseAPI(oem=OEM.LSTM_ONLY, psm=PSM.SINGLE_WORD) as api:
        for option, value in (
            ("tessedit_char_whitelist", "0123456789"),
            ("load_system_dawg", "0"),
            ("load_freq_dawg", "0"),
        ):
            api.SetVariable(option, value)

        for rec in chunk:
            idx, truth = rec["index"], rec["plate_number"]
            plate_img = cv2.imread(os.path.join(DIR, f"reconstructed_{idx}.png"))
            if plate_img is None:
                report.append(f"Plate {idx:4d} | [MISSING IMAGE]")
                continue

            grey = cv2.cvtColor(plate_img, cv2.COLOR_BGR2GRAY)
            binary = cv2.threshold(grey, 150, 255, cv2.THRESH_BINARY)[1]

            read_digits = []
            for pos, (x, y, w, h) in enumerate(rec["digit_bboxes"]):
                # tesserocr takes a PIL image, so wrap the NumPy crop
                api.SetImage(Image.fromarray(binary[y : y + h, x : x + w]))
                text = api.GetUTF8Text().strip()

                # keep only a leading digit; anything else counts as "nothing read"
                char = text[0] if text[:1].isdigit() else ""
                read_digits.append(char)

                digits_seen += 1
                if char == truth[pos]:
                    digits_right += 1

            ocr_str = "".join(read_digits)
            matched = ocr_str == truth
            plates_seen += 1
            if matched:
                plates_right += 1

            report.append(f"Plate {idx:4d} | GT={truth} | OCR={ocr_str:<6} | {'OK' if matched else 'ERR'}")

    return digits_right, digits_seen, plates_right, plates_seen, report


#  SPLIT WORK & RUN THREADS
# max(1, ...) keeps the range() step valid when the metadata list is empty
# (a step of 0 raises ValueError).
chunk_size = max(1, math.ceil(len(records) / N_WORKERS))
chunks = [records[i : i + chunk_size] for i in range(0, len(records), chunk_size)]

digit_ok = digit_tot = plate_ok = plate_tot = 0
all_lines = []

with ThreadPoolExecutor(max_workers=N_WORKERS) as exe:
    futures = [exe.submit(process_chunk, c) for c in chunks]
    # Totals are plain sums, so completion order does not matter.
    for fut in as_completed(futures):
        d_ok, d_tot, p_ok, p_tot, lines = fut.result()
        digit_ok += d_ok
        digit_tot += d_tot
        plate_ok += p_ok
        plate_tot += p_tot
        all_lines.extend(lines)

#  PRINT DETAIL (comment out if too verbose)
# Every line starts with "Plate <index> |". Sort on the numeric index: plain
# string sorting puts 5-digit indices ahead of 4-digit ones ("10000" < "9999").
for ln in sorted(all_lines, key=lambda ln: int(ln.split("|", 1)[0].split()[1])):
    # only print plates where OCR != GT
    if "| ERR" in ln or "[MISSING IMAGE]" in ln:
        print(ln)

#  SUMMARY
# Totals stay at zero when no image could be read (e.g. a wrong directory);
# report that instead of dying with ZeroDivisionError.
print("\nRESULTS")
if digit_tot:
    print(f"  Digit-level accuracy: {digit_ok / digit_tot * 100:.2f}% " f"({digit_ok}/{digit_tot})")
else:
    print("  Digit-level accuracy: n/a (0/0)")
if plate_tot:
    print(f"  Plate-level accuracy: {plate_ok / plate_tot * 100:.2f}% " f"({plate_ok}/{plate_tot})")
else:
    print("  Plate-level accuracy: n/a (0/0)")

In [None]:
import os
import matplotlib.pyplot as plt
from PIL import Image

#  CONFIG
# Directory holding the reconstructed_<index>.png images; the path is relative
# to the notebook's working directory.
DATA_DIR = "results/A/restormer"

# Hard-coded list of plates to visualise, one (index, OCR_str) pair per plate.
# The index selects the image file and the OCR string is only shown in the
# subplot title.
# NOTE(review): presumably pasted from the "ERR" lines of an earlier OCR run — confirm.
errors = [
    (39522, "602395"),
    (39523, "137218"),
    (39524, "865378"),
    (39525, "653310"),
    (39526, "389910"),
    (39527, "464894"),
    (39528, "289290"),
    (39529, "029805"),
    (39530, "436875"),
    (39531, "738414"),
    (39532, "489838"),
    (39533, "889556"),
    (39534, "082369"),
    (39535, "423725"),
    (39536, "888024"),
    (39537, "845095"),
    (39538, "826932"),
    (39539, "911231"),
    (39540, "230770"),
    (39541, "735023"),
    (39542, "325040"),
    (39543, "323484"),
    (39544, "338354"),
    (39545, "755102"),
    (39546, "707206"),
    (39547, "449489"),
    (39548, "479843"),
    (39549, "133239"),
    (39550, "301370"),
    (39551, "588666"),
    (39552, "733089"),
    (39553, "745937"),
    (39554, "763805"),
    (39555, "881739"),
    (39556, "982350"),
    (39557, "482399"),
    (39558, "147356"),
    (39559, "103312"),
]

#  PLOTTING
# Lay the error plates out on a fixed 5-column grid.
n = len(errors)
cols = 5
# Ceiling division; keep at least one row so an empty list does not request a
# zero-height figure.
rows = max(1, (n + cols - 1) // cols)

plt.figure(figsize=(cols * 4, rows * 1.3))
for i, (idx, ocr_str) in enumerate(errors, start=1):
    img_path = os.path.join(DATA_DIR, f"reconstructed_{idx}.png")
    if not os.path.exists(img_path):
        # Missing files leave their grid slot empty.
        continue

    # Image.open is lazy and keeps the file handle open; copy the pixels out
    # inside the context manager so the handle is released straight away.
    with Image.open(img_path) as img_file:
        img = img_file.copy()

    ax = plt.subplot(rows, cols, i)
    ax.imshow(img)
    # Title includes the plate index before the OCR result
    ax.set_title(f"idx:{idx} | OCR:{ocr_str}", fontsize=10)
    ax.axis("off")

plt.tight_layout()
plt.show()

In [None]:
import os
import json
import cv2
import math
from concurrent.futures import ThreadPoolExecutor, as_completed
from tesserocr import PyTessBaseAPI, OEM, PSM
from PIL import Image

#  CONFIG
DIR = r"C:\Users\stopc\Desktop\LPR_Project\results\A\restormer"  # reconstructed_<index>.png images
METADATA_DIR = r"C:\Users\stopc\Desktop\LPR_Project\data\full_grid"
METADATA_PATH = os.path.join(METADATA_DIR, "metadata.json")
N_WORKERS = 8  # number of worker threads / chunks
UPSCALE_FACTOR = 2  # integer scale applied to each binarised image before OCR
FALLBACK_DIGIT = "0"  # substituted when no recipe yields a digit for a patch

# 1) Load full metadata (each rec must include "index", "plate_number", and "digit_bboxes")
with open(METADATA_PATH, "r", encoding="utf-8") as f:
    records = json.load(f)
# Fail early: the chunking further down cannot work with an empty list.
if not records:
    raise RuntimeError("No records found in metadata.json")

# 2) Binarisation recipes, tried in this order until one yields a digit.
#    Each entry is (threshold mode, invert flag, Tesseract page-segmentation mode);
#    the mode and invert flag are passed to preprocess().
RECIPES = [
    ("fixed", False, PSM.SINGLE_WORD),
    ("otsu", False, PSM.SINGLE_WORD),
    ("adaptive", False, PSM.SINGLE_CHAR),
    ("fixed", True, PSM.SINGLE_CHAR),
    ("otsu", True, PSM.SINGLE_WORD),
]


def preprocess(gray, mode, invert):
    """
    Binarise a grayscale image.

    mode   -- "fixed" (threshold 150), "otsu" or "adaptive" (Gaussian, block 11, C=2)
    invert -- use THRESH_BINARY_INV instead of THRESH_BINARY when true

    Returns the thresholded image; raises ValueError for any other mode.
    """
    thresh_type = cv2.THRESH_BINARY_INV if invert else cv2.THRESH_BINARY

    if mode == "adaptive":
        return cv2.adaptiveThreshold(gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, thresh_type, 11, 2)
    if mode == "otsu":
        # cv2.threshold returns (threshold used, image); only the image is needed
        return cv2.threshold(gray, 0, 255, thresh_type | cv2.THRESH_OTSU)[1]
    if mode == "fixed":
        return cv2.threshold(gray, 150, 255, thresh_type)[1]

    raise ValueError(f"Unknown mode: {mode}")


# 3) Patch-level OCR → (digit_char, used_fallback_flag)
def ocr_single_digit(api, gray_patch):
    """
    OCR one grayscale digit crop, trying every entry of RECIPES in order.

    api        -- an open PyTessBaseAPI handle (its page-seg mode and variables
                  are modified by this call)
    gray_patch -- 2-D grayscale array holding a single digit

    Returns (digit_char, used_fallback): the first digit character Tesseract
    produces and False, or (FALLBACK_DIGIT, True) when no recipe yields a digit.
    """
    # A bbox outside the image slices to an empty array, which OpenCV rejects;
    # treat it like an unreadable patch instead of crashing the worker.
    if gray_patch.size == 0:
        return FALLBACK_DIGIT, True

    for mode, inv, psm in RECIPES:
        bin_p = preprocess(gray_patch, mode, inv)
        h, w = bin_p.shape
        # The interpolation flag must be passed by keyword: the third positional
        # parameter of cv2.resize is `dst`, not `interpolation`.
        up = cv2.resize(
            bin_p,
            (w * UPSCALE_FACTOR, h * UPSCALE_FACTOR),
            interpolation=cv2.INTER_CUBIC,
        )

        api.SetPageSegMode(psm)
        api.SetVariable("tessedit_char_whitelist", "0123456789")
        api.SetVariable("classify_bln_numeric_mode", "1")

        api.SetImage(Image.fromarray(up))
        api.Recognize()
        txt = api.GetUTF8Text() or ""
        # take the first digit anywhere in the output, skipping whitespace/noise
        for ch in txt:
            if ch.isdigit():
                return ch, False

    # none succeeded → fallback
    return FALLBACK_DIGIT, True


# 4) Full-plate OCR rescue → digit string (possibly shorter)
def ocr_full_plate(api, gray_plate):
    """
    OCR the whole grayscale plate, trying every entry of RECIPES in order.

    api        -- an open PyTessBaseAPI handle (its page-seg mode and variables
                  are modified by this call)
    gray_plate -- 2-D grayscale array of the full plate

    Returns the digits of the first recipe that produces any (the string may be
    shorter or longer than the true plate number), or "" when none does.
    """
    # An empty image cannot be thresholded; report "nothing read".
    if gray_plate.size == 0:
        return ""

    for mode, inv, psm in RECIPES:
        bin_p = preprocess(gray_plate, mode, inv)
        h, w = bin_p.shape
        # The interpolation flag must be passed by keyword: the third positional
        # parameter of cv2.resize is `dst`, not `interpolation`.
        up = cv2.resize(
            bin_p,
            (w * UPSCALE_FACTOR, h * UPSCALE_FACTOR),
            interpolation=cv2.INTER_CUBIC,
        )

        api.SetPageSegMode(psm)
        api.SetVariable("tessedit_char_whitelist", "0123456789")
        api.SetVariable("classify_bln_numeric_mode", "1")

        api.SetImage(Image.fromarray(up))
        api.Recognize()
        # keep digits only; whitespace and newlines are dropped
        digits = "".join(ch for ch in (api.GetUTF8Text() or "") if ch.isdigit())
        if digits:
            return digits

    return ""


# 5) Worker: process a list of records
def process_chunk(chunk):
    """
    OCR every plate in <chunk> (list of metadata dicts) with full-plate rescue.

    Each digit bbox is read with ocr_single_digit(). If any digit fell back to
    FALLBACK_DIGIT, the whole plate is read once with ocr_full_plate() and,
    when that string is at least as long as the digit list, the fallback
    positions are replaced by the characters at the same positions.

    Returns: (digit_ok, digit_total, plate_ok, plate_total,
              plates_without_fallback, lines)
    Plates whose image cannot be read get a report line but are not counted.
    """
    with PyTessBaseAPI(oem=OEM.LSTM_ONLY, psm=PSM.SINGLE_WORD) as api:
        dig_ok = dig_tot = pl_ok = pl_tot = 0
        plates_no_fallback = 0
        out_lines = []

        for rec in chunk:
            idx = rec["index"]
            truth = rec["plate_number"]
            img = cv2.imread(os.path.join(DIR, f"reconstructed_{idx}.png"))
            if img is None:
                out_lines.append(f"Plate {idx:4d} | [MISSING IMAGE]")
                continue

            gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

            #  patch-level OCR
            chars = []
            fb_mask = []
            for x, y, w, h in rec["digit_bboxes"]:
                digit, fb = ocr_single_digit(api, gray[y : y + h, x : x + w])
                chars.append(digit)
                fb_mask.append(fb)

            #  full-plate rescue if needed
            if any(fb_mask):
                full = ocr_full_plate(api, gray)
                if len(full) >= len(chars):
                    for i, fb in enumerate(fb_mask):
                        if fb:
                            chars[i] = full[i]
                            fb_mask[i] = False

            # Score digits AFTER the rescue so digit-level accuracy describes
            # the same string that is reported and used for plate accuracy.
            # zip() also tolerates a bbox list longer than the ground truth.
            dig_tot += len(chars)
            dig_ok += sum(1 for got, want in zip(chars, truth) if got == want)

            plate_ocr = "".join(chars)
            if plate_ocr == truth:
                pl_ok += 1
            pl_tot += 1
            # fb_mask entries still set mean the rescue could not replace them
            if not any(fb_mask):
                plates_no_fallback += 1

            out_lines.append(
                f"Plate {idx:4d} | GT={truth} | OCR={plate_ocr:<6} | " f"{'OK' if plate_ocr==truth else 'ERR'}"
            )

        return dig_ok, dig_tot, pl_ok, pl_tot, plates_no_fallback, out_lines


#  SPLIT & RUN
# max(1, ...) keeps the range() step valid even for an empty list.
chunk_size = max(1, math.ceil(len(records) / N_WORKERS))
chunks = [records[i : i + chunk_size] for i in range(0, len(records), chunk_size)]

digit_ok = digit_tot = plate_ok = plate_tot = total_no_fallback = 0
all_lines = []

with ThreadPoolExecutor(max_workers=N_WORKERS) as exe:
    futures = [exe.submit(process_chunk, c) for c in chunks]
    # Totals are plain sums, so completion order does not matter.
    for fut in as_completed(futures):
        d_ok, d_tot, p_ok, p_tot, no_fb, lines = fut.result()
        digit_ok += d_ok
        digit_tot += d_tot
        plate_ok += p_ok
        plate_tot += p_tot
        total_no_fallback += no_fb
        all_lines.extend(lines)

#  DETAIL (only errors)
# Every line starts with "Plate <index> |". Sort on the numeric index: plain
# string sorting puts 5-digit indices ahead of 4-digit ones ("10000" < "9999").
for ln in sorted(all_lines, key=lambda ln: int(ln.split("|", 1)[0].split()[1])):
    if "| ERR" in ln or "[MISSING IMAGE]" in ln:
        print(ln)

#  SUMMARY
print("\nRESULTS")
if digit_tot:
    print(f"  Digit-level accuracy: {digit_ok/ digit_tot*100:.2f}% " f"({digit_ok}/{digit_tot})")
else:
    print("  Digit-level accuracy: n/a (0/0)")
if plate_tot:
    print(f"  Plate-level accuracy: {plate_ok/ plate_tot*100:.2f}% " f"({plate_ok}/{plate_tot})")
else:
    print("  Plate-level accuracy: n/a (0/0)")

# Plates with a missing image are never OCR'd, so they belong in neither the
# "no fallback" nor the "fallback" bucket: base the statistics on the plates
# that were actually scored, not on len(records).
total_plates = plate_tot
fallback_plates = total_plates - total_no_fallback
skipped_plates = len(records) - plate_tot
if skipped_plates:
    print(f"  Plates skipped (missing image): {skipped_plates}/{len(records)}")
if total_plates:
    print(
        f"  Plates with all digits from Tesseract (no fallbacks): "
        f"{total_no_fallback}/{total_plates} "
        f"({total_no_fallback/total_plates*100:.2f}%)"
    )
    print(
        f"  Plates requiring at least one fallback '0': "
        f"{fallback_plates}/{total_plates} "
        f"({fallback_plates/total_plates*100:.2f}%)"
    )
else:
    print("  Fallback statistics: n/a (no plate was processed)")

In [None]:
import random
import numpy as np
import matplotlib.pyplot as plt
from scripts.lp_processing import (
    create_license_plate,
    warp_image,
    simulate_noise,
    dewarp_image,
    crop_to_original_size,
)

# Fix both RNGs so the demo plate and the simulated noise are reproducible.
random.seed(100)
np.random.seed(100)

# --- Test parameters (512×128 plate canvas) ---
alpha, beta = 84, 0
width, height = 512, 128
text_size = 100
focal_length = width  # matches production f=original_width

# 1. Clean plate (converted to a NumPy array below) plus its corner points
plate_pil, src_points, plate_number, _ = create_license_plate(width, height, text_size)
plate_rgb = np.array(plate_pil)

# 2. Warp -> add camera noise -> dewarp back using the two corner sets
warped_rgb, dst_points = warp_image(plate_rgb, np.array(src_points), alpha, beta, focal_length)
noisy_rgb = simulate_noise(warped_rgb)
dewarped_rgb = dewarp_image(noisy_rgb, src_points, dst_points)

# 3. Crop the clean and the dewarped plate to the nominal size for display
orig_crop, dewarp_crop = (
    crop_to_original_size(picture, width, height) for picture in (plate_rgb, dewarped_rgb)
)

# 4. Three stacked panels: clean, warped + noisy, recovered
panels = (
    (orig_crop, "Original Plate"),
    (noisy_rgb, "Warped + Simulated Camera Noise"),
    (dewarp_crop, "Distorted Plate"),
)
fig, axes = plt.subplots(3, 1)
for panel_ax, (picture, caption) in zip(axes, panels):
    panel_ax.imshow(picture)
    panel_ax.axis("off")
    panel_ax.set_title(caption)

plt.tight_layout()
plt.show()


In [None]:
# #Heatmap of worst PSNR, SSIM, and OCR in parallel
import os
import json
import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
from matplotlib.widgets import Button
from tqdm import tqdm
import numpy as np
import mlflow
import mlflow.pytorch
import cv2
from pytorch_msssim import ssim
import pytesseract
from joblib import Parallel, delayed

# --------------------
# Configuration
# --------------------
data_dir = "data/full_grid"  # holds metadata.json plus original_/distorted_<index>.png
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# MLflow model load: take the most recently started run of the "Unet" experiment
mlflow.set_experiment("Unet")
client = mlflow.tracking.MlflowClient()
experiment = client.get_experiment_by_name("Unet")
runs = client.search_runs(
    experiment_ids=[experiment.experiment_id], order_by=["attributes.start_time DESC"], max_results=1
)
# An experiment without runs would otherwise surface as a bare IndexError below.
if not runs:
    raise RuntimeError(f"No runs found in MLflow experiment '{experiment.name}'; nothing to load.")
run_id = runs[0].info.run_id
model_uri = f"runs:/{run_id}/model"
model = mlflow.pytorch.load_model(model_uri)
model.eval().to(device)
print(f"Model loaded from run {run_id} in experiment '{experiment.name}' successfully.")

# --------------------
# Functions
# --------------------


def calculate_psnr(outputs, targets):
    """
    Return the PSNR in dB between two tensors, using a peak value of 1.

    Computed as 10 * log10(1 / MSE); identical inputs (MSE == 0) give
    float("inf"). The result is a plain Python float.
    """
    mean_sq_err = F.mse_loss(outputs, targets)

    # log10(1 / 0) is undefined, so report a perfect match explicitly
    if mean_sq_err == 0:
        return float("inf")

    return (10 * torch.log10(1 / mean_sq_err)).item()


def ocr_single_digit(image_bgr):
    """
    OCR one BGR digit crop with pytesseract.

    The crop is converted to grayscale, binarised at a fixed threshold of 150
    and read in single-character mode with a digits-only whitelist.
    Returns the digit as a one-character string, or "?" when the output is
    not exactly one digit.
    """
    grey = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2GRAY)
    binary = cv2.threshold(grey, 150, 255, cv2.THRESH_BINARY)[1]

    tess_config = r"--oem 1 --psm 10 -c tessedit_char_whitelist=0123456789"
    text = pytesseract.image_to_string(binary, config=tess_config).strip()

    is_single_digit = len(text) == 1 and text.isdigit()
    return text if is_single_digit else "?"


def align_and_update_bboxes(original_np, generated_np, digit_bboxes):
    """
    Re-locate every digit of the original image inside the generated image.

    For each (x, y, w, h) box the original crop is template-matched
    (TM_CCOEFF_NORMED) against a window of the generated image that extends
    16 px beyond the box on every side, clamped to the image. PSNR and SSIM
    are then computed between the original crop and the best-matching crop.

    Both images are indexed as [row, col, channel] and scaled by 255 before
    the uint8 grayscale conversion.
    # assumes float RGB arrays in [0, 1] — TODO confirm against callers

    Returns (psnr_values, ssim_values, updated_bboxes), three lists in the
    order of digit_bboxes; the boxes keep w and h but carry the matched x, y.
    """
    search_margin = 16

    def process_digit_bbox(bbox):
        x, y, w, h = bbox

        # template: the digit as it appears in the original image
        template = original_np[y : y + h, x : x + w, :]
        template_gray = cv2.cvtColor((template * 255).astype(np.uint8), cv2.COLOR_RGB2GRAY)

        # search window in the generated image, clamped to its borders
        left = max(0, x - search_margin)
        top = max(0, y - search_margin)
        right = min(generated_np.shape[1], x + w + search_margin)
        bottom = min(generated_np.shape[0], y + h + search_margin)
        window = generated_np[top:bottom, left:right, :]
        window_gray = cv2.cvtColor((window * 255).astype(np.uint8), cv2.COLOR_RGB2GRAY)

        # minMaxLoc returns (min, max, min_loc, max_loc); the peak is max_loc
        scores = cv2.matchTemplate(window_gray, template_gray, cv2.TM_CCOEFF_NORMED)
        peak = cv2.minMaxLoc(scores)[3]
        best_x = peak[0] + left
        best_y = peak[1] + top

        # HWC arrays -> 1xCxHxW tensors for the metric functions
        matched = generated_np[best_y : best_y + h, best_x : best_x + w, :]
        template_t = torch.from_numpy(template.transpose(2, 0, 1)).unsqueeze(0)
        matched_t = torch.from_numpy(matched.transpose(2, 0, 1)).unsqueeze(0)

        psnr_val = calculate_psnr(matched_t, template_t)
        ssim_val = ssim(matched_t, template_t, data_range=1.0, size_average=True).item()

        return psnr_val, ssim_val, (best_x, best_y, w, h)

    results = Parallel(n_jobs=-1)(delayed(process_digit_bbox)(bbox) for bbox in digit_bboxes)

    psnr_values, ssim_values, updated_bboxes = [], [], []
    for psnr_val, ssim_val, box in results:
        psnr_values.append(psnr_val)
        ssim_values.append(ssim_val)
        updated_bboxes.append(box)

    return psnr_values, ssim_values, updated_bboxes


def compute_ocr_metrics(image_bgr, updated_bboxes, plate_number_gt, margin):
    """
    OCR every digit box of a BGR image and score the result against the truth.

    Each (x, y, w, h) box is grown by `margin` pixels on every side (clamped
    to the image) and passed to ocr_single_digit().

    Returns (recognized_text, ocr_accuracy, ocr_binary):
      recognized_text -- per-box OCR results concatenated in box order
      ocr_accuracy    -- position-wise matches divided by len(plate_number_gt),
                         0.0 for an empty ground truth
      ocr_binary      -- 1.0 when the whole string equals the ground truth, else 0.0
    """

    def process_bbox(bbox):
        x, y, w, h = bbox
        top, bottom = max(0, y - margin), min(image_bgr.shape[0], y + h + margin)
        left, right = max(0, x - margin), min(image_bgr.shape[1], x + w + margin)
        return ocr_single_digit(image_bgr[top:bottom, left:right])

    per_box = Parallel(n_jobs=-1)(delayed(process_bbox)(bbox) for bbox in updated_bboxes)
    recognized_text = "".join(per_box)

    # zip stops at the shorter string, so surplus characters never count as hits
    hits = sum(1 for want, got in zip(plate_number_gt, recognized_text) if want == got)

    if len(plate_number_gt) > 0:
        ocr_accuracy = hits / len(plate_number_gt)
    else:
        ocr_accuracy = 0.0

    ocr_binary = 1.0 if recognized_text == plate_number_gt else 0.0
    return recognized_text, ocr_accuracy, ocr_binary


# --------------------------------------
# Compute metrics for each (alpha, beta)
# --------------------------------------
metadata_path = os.path.join(data_dir, "metadata.json")
# JSON is UTF-8 by specification; do not depend on the platform default codec.
with open(metadata_path, "r", encoding="utf-8") as f:
    metadata_list = json.load(f)

# (alpha, beta) -> list of per-image values
psnr_dict_worst = {}
ssim_dict_worst = {}
ocr_acc_dict_avg = {}
ocr_bin_dict_avg = {}

to_tensor = transforms.ToTensor()

for metadata in tqdm(metadata_list, desc="Processing images", unit="plate"):
    alpha = metadata["alpha"]
    beta = metadata["beta"]
    digit_bboxes = metadata["digit_bboxes"]
    plate_number_gt = metadata["plate_number"]
    index = metadata["index"]

    original_path = os.path.join(data_dir, f"original_{index}.png")
    distorted_path = os.path.join(data_dir, f"distorted_{index}.png")

    # Records whose image pair is incomplete are skipped silently.
    if not (os.path.exists(original_path) and os.path.exists(distorted_path)):
        continue

    original_img = to_tensor(Image.open(original_path).convert("RGB")).unsqueeze(0).to(device)
    distorted_img = to_tensor(Image.open(distorted_path).convert("RGB")).unsqueeze(0).to(device)

    with torch.no_grad():
        generated_img = model(distorted_img)
        generated_img = torch.clamp(generated_img, 0.0, 1.0)

    # 1xCxHxW tensors -> HxWxC NumPy arrays for the OpenCV-based helpers
    original_np = original_img.squeeze(0).permute(1, 2, 0).cpu().numpy()
    generated_np = generated_img.squeeze(0).permute(1, 2, 0).cpu().numpy()

    # Parallelized CPU operations
    psnr_per_number, ssim_per_number, updated_bboxes = align_and_update_bboxes(original_np, generated_np, digit_bboxes)
    image_bgr = (generated_np * 255).astype(np.uint8)
    image_bgr = cv2.cvtColor(image_bgr, cv2.COLOR_RGB2BGR)
    recognized_text, ocr_accuracy, ocr_binary = compute_ocr_metrics(
        image_bgr, updated_bboxes, plate_number_gt, margin=2
    )

    # Take worst PSNR and SSIM over the digits of this plate
    worst_psnr = np.min(psnr_per_number) if psnr_per_number else 0.0
    worst_ssim = np.min(ssim_per_number) if ssim_per_number else 0.0

    # setdefault creates the per-angle list on first use
    key = (alpha, beta)
    psnr_dict_worst.setdefault(key, []).append(worst_psnr)
    ssim_dict_worst.setdefault(key, []).append(worst_ssim)
    ocr_acc_dict_avg.setdefault(key, []).append(ocr_accuracy)
    ocr_bin_dict_avg.setdefault(key, []).append(ocr_binary)

# Sorted axis values of the (alpha, beta) grid actually seen
alpha_values = sorted(set(a for (a, b) in psnr_dict_worst.keys()))
beta_values = sorted(set(b for (a, b) in psnr_dict_worst.keys()))
num_alphas, num_betas = len(alpha_values), len(beta_values)


def create_matrix_from_dict(data_dict):
    """
    Turn an {(alpha, beta): [values]} dict into a (num_betas, num_alphas) matrix.

    Rows follow beta_values and columns follow alpha_values. Each cell holds
    the minimum of its value list; cells without data (or with an empty list)
    are NaN.
    """
    col_of = {a: c for c, a in enumerate(alpha_values)}
    row_of = {b: r for r, b in enumerate(beta_values)}

    grid = np.full((num_betas, num_alphas), np.nan)
    for key, samples in data_dict.items():
        a, b = key
        grid[row_of[b], col_of[a]] = np.min(samples) if samples else np.nan
    return grid


# One (beta x alpha) matrix per metric, built from the per-angle value lists.
psnr_matrix, ssim_matrix, ocr_acc_matrix, ocr_bin_matrix = (
    create_matrix_from_dict(per_angle)
    for per_angle in (psnr_dict_worst, ssim_dict_worst, ocr_acc_dict_avg, ocr_bin_dict_avg)
)

In [None]:
def show_image_details_for(alpha, beta):
    """
    Open a detail figure for the first metadata record with this (alpha, beta).

    Re-runs the model on the distorted image, re-aligns the digit boxes,
    repeats the OCR, and shows the distorted, original and generated images
    stacked vertically, with a per-digit PSNR/SSIM table under the last one.
    Prints a message and returns when no record matches.
    """
    # first record in metadata_list with matching angles (None when absent)
    found = next((m for m in metadata_list if m["alpha"] == alpha and m["beta"] == beta), None)
    if found is None:
        print("No images found for that angle.")
        return

    # unpack everything from `found`
    index = found["index"]
    plate_number_gt = found["plate_number"]
    # order boxes left-to-right so digit numbers follow reading order
    digit_bboxes = sorted(found["digit_bboxes"], key=lambda b: b[0])

    # file paths
    original_path = os.path.join(data_dir, f"original_{index}.png")
    distorted_path = os.path.join(data_dir, f"distorted_{index}.png")

    # load & run the model
    orig_t = to_tensor(Image.open(original_path).convert("RGB")).unsqueeze(0).to(device)
    dist_t = to_tensor(Image.open(distorted_path).convert("RGB")).unsqueeze(0).to(device)
    with torch.no_grad():
        gen_t = model(dist_t).clamp(0.0, 1.0)

    # 1xCxHxW tensors -> HxWxC NumPy arrays
    orig_np = orig_t.squeeze(0).permute(1, 2, 0).cpu().numpy()
    gen_np = gen_t.squeeze(0).permute(1, 2, 0).cpu().numpy()

    # compute per-digit PSNR/SSIM + updated bboxes
    psnr_vals, ssim_vals, updated_bboxes = align_and_update_bboxes(orig_np, gen_np, digit_bboxes)

    # load for display
    # NOTE(review): cv2.imread returns None for an unreadable file; not checked here.
    distorted_image_cv = cv2.imread(distorted_path)
    distorted_image_rgb = cv2.cvtColor(distorted_image_cv, cv2.COLOR_BGR2RGB)

    # Original image (with rectangles and 1-based digit numbers drawn 5 px above each box)
    original_image_cv = cv2.imread(original_path)
    for i, (x, y, w, h) in enumerate(digit_bboxes, start=1):
        cv2.rectangle(original_image_cv, (x, y), (x + w, y + h), (0, 0, 255), 1)
        cv2.putText(original_image_cv, str(i), (x, y - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 150, 0), 1)
    original_image_rgb = cv2.cvtColor(original_image_cv, cv2.COLOR_BGR2RGB)

    # Generated image (with rectangles and text); annotations go on a copy so
    # the OCR below runs on the unannotated pixels
    generated_bgr = (gen_np * 255).astype(np.uint8)
    generated_bgr = cv2.cvtColor(generated_bgr, cv2.COLOR_RGB2BGR)
    generated_show = generated_bgr.copy()
    for i, (x, y, w, h) in enumerate(updated_bboxes, start=1):
        cv2.rectangle(generated_show, (x, y), (x + w, y + h), (0, 0, 255), 1)
        cv2.putText(generated_show, str(i), (x, y - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 150, 0), 1)
    generated_image_rgb = cv2.cvtColor(generated_show, cv2.COLOR_BGR2RGB)

    recognized_text, ocr_accuracy, ocr_binary = compute_ocr_metrics(
        generated_bgr, updated_bboxes, plate_number_gt, margin=2
    )

    # Prepare table: built one row per digit, then transposed so each digit is a column
    table_data = [["Digit", "PSNR(dB)", "SSIM"]]
    for i, (p, s) in enumerate(zip(psnr_vals, ssim_vals), start=1):
        table_data.append([str(i), f"{p:.2f}", f"{s:.3f}"])
    transposed_table_data = list(zip(*table_data))

    # Plot in three rows
    # NOTE(review): fig2 is never used afterwards; the pyplot calls below act on the current figure.
    fig2 = plt.figure(figsize=(11, 9))
    plt.subplot(3, 1, 1)
    plt.imshow(distorted_image_rgb)
    plt.title(f"Distorted (α={alpha}, β={beta})")
    plt.axis("off")
    plt.subplot(3, 1, 2)
    plt.imshow(original_image_rgb)
    plt.title("Original")
    plt.axis("off")
    plt.subplot(3, 1, 3)
    plt.imshow(generated_image_rgb)
    plt.title(
        f"Generated (GT={plate_number_gt}, Rec={recognized_text}, OCR Acc={ocr_accuracy*100:.1f}%, Bin={int(ocr_binary)})"
    )
    plt.axis("off")

    # Table underneath the third subplot (bbox is given in that subplot's axes coordinates)
    tbl = plt.table(cellText=transposed_table_data, cellLoc="center", loc="center", bbox=[0, -0.5, 1, 0.4])
    tbl.auto_set_font_size(False)
    tbl.set_fontsize(12)

    plt.tight_layout()
    plt.show()


def format_coord(x, y):
    """
    Build the status-bar text for the heatmap at data coordinates (x, y).

    The coordinates are rounded to the nearest cell. The text shows that
    cell's alpha and beta plus the value of the metric currently displayed
    (tracked in `current_metric`), so the readout matches the heatmap after a
    metric button is pressed. Cells without data show "N/A"; positions
    outside the grid return a generic N/A string.
    """
    col = int(round(x))
    row = int(round(y))
    if not (0 <= row < num_betas and 0 <= col < num_alphas):
        return "Alpha: N/A, Beta: N/A"

    alpha = alpha_values[col]
    beta = beta_values[row]

    # Label, source matrix and number format for the metric being displayed.
    # Anything unrecognised falls back to PSNR, the initial view.
    if current_metric == "SSIM":
        label, matrix, template = "SSIM", ssim_matrix, "{:.3f}"
    elif current_metric == "OCR_Accuracy":
        label, matrix, template = "OCR Acc", ocr_acc_matrix, "{:.2f}"
    elif current_metric == "OCR_Binary":
        label, matrix, template = "OCR Bin", ocr_bin_matrix, "{:.0f}"
    else:
        label, matrix, template = "PSNR", psnr_matrix_clipped, "{:.2f} dB"

    value = matrix[row, col]
    shown = "N/A" if np.isnan(value) else template.format(value)
    return f"Alpha: {alpha:.0f}, Beta: {beta:.0f}, {label}: {shown}"


# Cap PSNR at 20 dB so a few very high values do not flatten the colour scale.
psnr_matrix_clipped = np.clip(psnr_matrix, a_min=None, a_max=20)

current_metric = "PSNR"
fig, ax = plt.subplots(figsize=(11, 9))
plt.subplots_adjust(bottom=0.15)  # leave room for the button row

# Initial view: worst-digit PSNR per (alpha, beta) cell
im = ax.imshow(psnr_matrix_clipped, origin="lower", aspect="auto", cmap="viridis")
cbar = plt.colorbar(im, ax=ax, label="PSNR (dB)")
ax.set_title("Worst PSNR per Image (Minimum Digit PSNR)")

# Label every fifth grid position with its angle value
ax.set_xlabel("Alpha (degrees)")
ax.set_xticks(range(0, num_alphas, 5))
ax.set_xticklabels(alpha_values[::5])
ax.set_ylabel("Beta (degrees)")
ax.set_yticks(range(0, num_betas, 5))
ax.set_yticklabels(beta_values[::5])

ax.format_coord = format_coord  # custom status-bar readout

# Button geometry, in figure-fraction units
button_width, button_height, button_spacing = 0.1, 0.05, 0.02
x_start, y_position = 0.2, 0.03

# Left edge of each button, laid out left to right
x_psnr = x_start
x_ssim = x_psnr + button_width + button_spacing
x_ocr_acc = x_ssim + button_width + button_spacing
x_ocr_bin = x_ocr_acc + button_width + button_spacing

# One small axes per button, created in left-to-right order
ax_psnr, ax_ssim, ax_ocr_acc, ax_ocr_bin = (
    plt.axes([left_edge, y_position, button_width, button_height])
    for left_edge in (x_psnr, x_ssim, x_ocr_acc, x_ocr_bin)
)

btn_psnr, btn_ssim, btn_ocr_acc, btn_ocr_bin = (
    Button(button_ax, caption)
    for button_ax, caption in (
        (ax_psnr, "PSNR"),
        (ax_ssim, "SSIM"),
        (ax_ocr_acc, "OCR Acc"),
        (ax_ocr_bin, "OCR Bin"),
    )
)


def update_heatmap(metric):
    """
    Redraw the heatmap for `metric` and refresh the colorbar.

    Recognised values are "PSNR", "SSIM" and "OCR_Accuracy"; any other value
    selects the OCR binary matrix. The choice is stored in the global
    `current_metric`.
    """
    global current_metric
    current_metric = metric

    # pick the matrix, title and colorbar label for the requested metric
    if metric == "PSNR":
        data, title, cbar_label = (
            psnr_matrix_clipped,
            "Worst PSNR per Image (Minimum Digit PSNR)",
            "PSNR (dB)",
        )
    elif metric == "SSIM":
        data, title, cbar_label = (
            ssim_matrix,
            "Worst SSIM per Image (Minimum Digit SSIM)",
            "SSIM",
        )
    elif metric == "OCR_Accuracy":
        data, title, cbar_label = ocr_acc_matrix, "OCR Accuracy", "OCR Acc"
    else:
        data, title, cbar_label = (
            ocr_bin_matrix,
            "OCR Binary Accuracy (1=All Correct)",
            "OCR Binary",
        )

    # ax.clear() drops the title, ticks and labels, so all are set again below
    ax.clear()
    heat = ax.imshow(data, origin="lower", aspect="auto", cmap="viridis")
    ax.set_title(title)
    ax.set_xticks(range(0, num_alphas, 5))
    ax.set_xticklabels(alpha_values[::5])
    ax.set_yticks(range(0, num_betas, 5))
    ax.set_yticklabels(beta_values[::5])
    ax.set_xlabel("Alpha (degrees)")
    ax.set_ylabel("Beta (degrees)")

    # point the existing colorbar at the new image
    cbar.mappable = heat
    cbar.set_label(cbar_label)
    cbar.update_normal(heat)

    fig.canvas.draw_idle()


def on_psnr_clicked(event):
    """Button callback: show the PSNR heatmap."""
    update_heatmap("PSNR")


def on_ssim_clicked(event):
    """Button callback: show the SSIM heatmap."""
    update_heatmap("SSIM")


def on_ocr_acc_clicked(event):
    """Button callback: show the OCR accuracy heatmap."""
    update_heatmap("OCR_Accuracy")


def on_ocr_bin_clicked(event):
    """Button callback: show the OCR binary heatmap."""
    update_heatmap("OCR_Binary")


# Attach each callback to its button.
for button_widget, click_handler in (
    (btn_psnr, on_psnr_clicked),
    (btn_ssim, on_ssim_clicked),
    (btn_ocr_acc, on_ocr_acc_clicked),
    (btn_ocr_bin, on_ocr_bin_clicked),
):
    button_widget.on_clicked(click_handler)


# Connect the click event after setting up the entire figure
def on_click(event):
    """
    Mouse handler: open the detail view for the heatmap cell that was clicked.

    Clicks outside the heatmap axes, without data coordinates, or outside the
    grid are ignored.
    """
    if event.inaxes != ax:
        return
    if event.xdata is None or event.ydata is None:
        return

    # nearest cell to the click position
    col, row = int(round(event.xdata)), int(round(event.ydata))
    if not (0 <= row < num_betas and 0 <= col < num_alphas):
        return

    show_image_details_for(alpha_values[col], beta_values[row])


# Register the handler once the whole figure has been assembled, then show it.
cid = fig.canvas.mpl_connect("button_press_event", on_click)

plt.show()