In [None]:
#Heatmap Visualization of average PSNR, SSIM, and OCR Metrics with Interactive Analysis

import os
import json
import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
from matplotlib.widgets import Button
from tqdm import tqdm
import numpy as np
import mlflow
import mlflow.pytorch
import cv2
from pytorch_msssim import ssim
import pytesseract

# --------------------
# Configuration
# --------------------
data_dir = "data_test"
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# MLflow model load: take the most recently started run of the 'Unet' experiment.
mlflow.set_experiment('Unet')
client = mlflow.tracking.MlflowClient()
experiment = client.get_experiment_by_name('Unet')
runs = client.search_runs(
    experiment_ids=[experiment.experiment_id],  # documented as a list of IDs
    order_by=["attributes.start_time DESC"],
    max_results=1,
)
if not runs:
    # set_experiment() creates the experiment when missing, so "no runs" is a
    # real possibility; fail with a clear message instead of an IndexError.
    raise RuntimeError("No runs found in MLflow experiment 'Unet'; log a trained model first.")
run_id = runs[0].info.run_id
model_uri = f"runs:/{run_id}/model"
# map_location is forwarded to torch.load so a model logged from a GPU run can
# still be loaded when only the CPU is available.
model = mlflow.pytorch.load_model(model_uri, map_location=device)
model.eval().to(device)
print(f"Model loaded from run {run_id} in experiment '{experiment.name}' successfully.")

def calculate_psnr(outputs, targets):
    """Return the PSNR in dB between two tensors whose peak value is 1.

    Uses the mean squared error over all elements. Identical inputs
    (zero error) yield ``float('inf')``.
    """
    squared_error = F.mse_loss(outputs, targets)
    if squared_error.item() == 0:
        return float('inf')
    return (10 * torch.log10(1 / squared_error)).item()

# Shared PIL-image -> tensor converter used when loading images below.
to_tensor = transforms.ToTensor()

def align_and_update_bboxes(original_np, reconstructed_np, digit_bboxes):
    """Re-locate each digit box on the reconstruction and score the match.

    For every ``(x, y, w, h)`` box, the original patch is template-matched
    (``cv2.TM_CCOEFF_NORMED`` on grayscale) inside a window of the
    reconstruction that extends the box by 10 px on each side, clipped to the
    image. The best-scoring location becomes the updated box. PSNR and SSIM are
    then computed between the original patch and the reconstructed patch at
    that location.

    Both arrays are indexed ``[row, col, channel]`` with RGB channels and are
    multiplied by 255 before the uint8 conversion. They are presumably floats
    in [0, 1] (SSIM uses ``data_range=1.0``); confirm against callers.

    Returns:
        ``(psnr_values, ssim_values, updated_bboxes)``, one entry per input box,
        in input order.
    """
    margin = 10
    height, width = reconstructed_np.shape[:2]

    def as_gray_u8(rgb):
        # Template matching runs on 8-bit single-channel images.
        return cv2.cvtColor((rgb * 255).astype(np.uint8), cv2.COLOR_RGB2GRAY)

    def as_batch(hwc):
        # HWC numpy patch -> 1xCxHxW tensor on the evaluation device.
        return torch.from_numpy(hwc.transpose(2, 0, 1)).unsqueeze(0).to(device)

    psnr_values, ssim_values, updated_bboxes = [], [], []
    for x, y, w, h in digit_bboxes:
        template = original_np[y:y + h, x:x + w, :]

        left, top = max(0, x - margin), max(0, y - margin)
        right, bottom = min(width, x + w + margin), min(height, y + h + margin)
        window = reconstructed_np[top:bottom, left:right, :]

        scores = cv2.matchTemplate(as_gray_u8(window), as_gray_u8(template), cv2.TM_CCOEFF_NORMED)
        dx, dy = cv2.minMaxLoc(scores)[3]  # (x, y) of the maximum score
        new_x, new_y = left + dx, top + dy
        updated_bboxes.append((new_x, new_y, w, h))

        matched = reconstructed_np[new_y:new_y + h, new_x:new_x + w, :]
        reference_t, candidate_t = as_batch(template), as_batch(matched)
        psnr_values.append(calculate_psnr(candidate_t, reference_t))
        ssim_values.append(ssim(candidate_t, reference_t, data_range=1.0, size_average=True).item())

    return psnr_values, ssim_values, updated_bboxes

def ocr_single_digit(image_bgr):
    """Read one digit from a BGR patch with Tesseract; return '?' if none is read.

    The patch is converted to grayscale and binarised at 128 before OCR.
    Tesseract runs in single-character mode (``--psm 10``) with a 0-9 whitelist.
    """
    tesseract_config = r'--oem 1 --psm 10 -c tessedit_char_whitelist=0123456789'
    grayscale = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2GRAY)
    binary = cv2.threshold(grayscale, 128, 255, cv2.THRESH_BINARY)[1]
    candidate = pytesseract.image_to_string(binary, config=tesseract_config).strip()
    is_single_digit = len(candidate) == 1 and candidate.isdigit()
    return candidate if is_single_digit else '?'

def compute_ocr_metrics(reconstructed_bgr, updated_bboxes, plate_number_gt, margin=16):
    """OCR every digit box and score the reading against the ground truth.

    Each ``(x, y, w, h)`` box is padded by ``margin`` pixels per side, clipped
    to the image, and passed to ``ocr_single_digit``. Unreadable digits come
    back as '?'.

    Args:
        reconstructed_bgr: BGR image to read. The callers in this file pass a
            uint8 array converted with ``cv2.COLOR_RGB2BGR``.
        updated_bboxes: digit boxes, in the order their characters should appear.
        plate_number_gt: ground-truth plate string.
        margin: padding in pixels around each box. It defaults to the value
            that was previously hard-coded, and matches the ``margin``
            parameter of the other notebook cell's version.

    Returns:
        ``(recognized_text, ocr_accuracy, ocr_binary)``: the concatenated
        reading, the fraction of ground-truth positions matched character by
        character, and 1.0 for an exact match else 0.0.
    """
    img_h, img_w = reconstructed_bgr.shape[:2]
    recognized_digits = []
    for (x, y, w, h) in updated_bboxes:
        digit_patch = reconstructed_bgr[max(0, y - margin):min(img_h, y + h + margin),
                                        max(0, x - margin):min(img_w, x + w + margin)]
        recognized_digits.append(ocr_single_digit(digit_patch))

    recognized_text = "".join(recognized_digits)
    correct_digits = sum(a == b for a, b in zip(plate_number_gt, recognized_text))
    ocr_accuracy = correct_digits / len(plate_number_gt) if len(plate_number_gt) > 0 else 0.0
    ocr_binary = 1.0 if recognized_text == plate_number_gt else 0.0
    return recognized_text, ocr_accuracy, ocr_binary

# -----------------------------------
# Compute metrics for each (alpha,beta)
# -----------------------------------
metadata_files = [f for f in os.listdir(data_dir) if f.startswith('metadata_') and f.endswith('.json')]

psnr_dict_avg = {}
ssim_dict_avg = {}
ocr_acc_dict_avg = {}
ocr_bin_dict_avg = {}

for meta_file in tqdm(metadata_files, desc="Processing images", unit="image"):
    meta_path = os.path.join(data_dir, meta_file)
    with open(meta_path, 'r') as f:
        metadata = json.load(f)

    alpha, beta = metadata['alpha'], metadata['beta']
    # Read digits left to right. OCR output is compared character-by-character
    # with plate_number_gt, and show_image_details_for sorts boxes the same way,
    # so the heatmap and the detail view report the same OCR result.
    digit_bboxes = sorted(metadata['digit_bboxes'], key=lambda bbox: bbox[0])
    plate_number_gt = metadata['plate_number']

    index = metadata['index']
    original_path = os.path.join(data_dir, f"original_{index}.png")
    distorted_path = os.path.join(data_dir, f"distorted_{index}.png")

    if not (os.path.exists(original_path) and os.path.exists(distorted_path)):
        continue

    original_img = to_tensor(Image.open(original_path).convert('RGB')).unsqueeze(0).to(device)
    distorted_img = to_tensor(Image.open(distorted_path).convert('RGB')).unsqueeze(0).to(device)

    with torch.no_grad():
        reconstructed_img = torch.clamp(model(distorted_img), 0, 1)

    original_np = original_img.squeeze(0).permute(1, 2, 0).cpu().numpy()
    reconstructed_np = reconstructed_img.squeeze(0).permute(1, 2, 0).cpu().numpy()

    psnr_per_number, ssim_per_number, updated_bboxes = align_and_update_bboxes(original_np, reconstructed_np, digit_bboxes)
    # Mean over the digits of this image (0.0 when the image has no boxes).
    avg_psnr = np.mean(psnr_per_number) if psnr_per_number else 0.0
    avg_ssim = np.mean(ssim_per_number) if ssim_per_number else 0.0

    image_bgr = cv2.cvtColor((reconstructed_np * 255).astype(np.uint8), cv2.COLOR_RGB2BGR)
    recognized_text, ocr_accuracy, ocr_binary = compute_ocr_metrics(image_bgr, updated_bboxes, plate_number_gt)

    key = (alpha, beta)
    psnr_dict_avg.setdefault(key, []).append(avg_psnr)
    ssim_dict_avg.setdefault(key, []).append(avg_ssim)
    ocr_acc_dict_avg.setdefault(key, []).append(ocr_accuracy)
    ocr_bin_dict_avg.setdefault(key, []).append(ocr_binary)

# Average if multiple images per angle (if any)
for key in psnr_dict_avg:
    psnr_dict_avg[key] = np.mean(psnr_dict_avg[key])
    ssim_dict_avg[key] = np.mean(ssim_dict_avg[key])
    ocr_acc_dict_avg[key] = np.mean(ocr_acc_dict_avg[key])
    ocr_bin_dict_avg[key] = np.mean(ocr_bin_dict_avg[key])

# Sorted, de-duplicated angle values that form the heatmap axes.
alpha_values = sorted({a for a, _ in psnr_dict_avg})
beta_values = sorted({b for _, b in psnr_dict_avg})
num_alphas = len(alpha_values)
num_betas = len(beta_values)

def create_matrix_from_dict(data_dict):
    """Lay out ``{(alpha, beta): value}`` as a ``(num_betas, num_alphas)`` grid.

    Rows follow ``beta_values`` and columns follow ``alpha_values``. Cells
    without an entry stay NaN.
    """
    col_of = dict(zip(alpha_values, range(num_alphas)))
    row_of = dict(zip(beta_values, range(num_betas)))
    grid = np.full((num_betas, num_alphas), np.nan)
    for (a, b), value in data_dict.items():
        grid[row_of[b], col_of[a]] = value
    return grid

psnr_matrix_avg, ssim_matrix_avg, ocr_acc_matrix, ocr_bin_matrix = (
    create_matrix_from_dict(per_angle)
    for per_angle in (psnr_dict_avg, ssim_dict_avg, ocr_acc_dict_avg, ocr_bin_dict_avg)
)

# -----------------------------------
# Interactive Plot with Buttons
# -----------------------------------
current_metric = 'PSNR'  # metric shown first; update_heatmap changes it
fig, ax = plt.subplots(figsize=(10, 8))
plt.subplots_adjust(bottom=0.2)  # reserve the bottom strip of the figure for the metric buttons

im = ax.imshow(psnr_matrix_avg, origin='lower', aspect='auto', cmap="viridis")
cb = plt.colorbar(im, ax=ax, label='PSNR (dB)')
# Only every 5th grid column/row gets an angle label, to keep the axes readable.
ax.set(
    title="Average PSNR per Digit",
    xlabel="Alpha (degrees)",
    ylabel="Beta (degrees)",
    xticks=range(0, num_alphas, 5),
    yticks=range(0, num_betas, 5),
)
ax.set_xticklabels(alpha_values[::5])
ax.set_yticklabels(beta_values[::5])

def format_coord(x, y):
    """Return status-bar text for the heatmap cell nearest to data point (x, y).

    Column ``round(x)`` indexes ``alpha_values`` and row ``round(y)`` indexes
    ``beta_values``. The value comes from the matrix of the metric named by the
    global ``current_metric``. Returns "N/A" outside the grid, for NaN cells,
    or for an unknown metric.
    """
    col = int(round(x))
    row = int(round(y))
    if not (0 <= row < num_betas and 0 <= col < num_alphas):
        return "N/A"
    alpha = alpha_values[col]
    beta = beta_values[row]
    if current_metric == 'PSNR':
        val = psnr_matrix_avg[row, col]
        text = f"PSNR: {val:.2f} dB"
    elif current_metric == 'SSIM':
        val = ssim_matrix_avg[row, col]
        text = f"SSIM: {val:.3f}"
    elif current_metric == 'OCR_Accuracy':
        val = ocr_acc_matrix[row, col]
        text = f"OCR Acc: {val*100:.2f}%"
    elif current_metric == 'OCR_Binary':
        val = ocr_bin_matrix[row, col]
        # This cell holds the mean of 0/1 exact-match outcomes over all images
        # at the angle, i.e. a rate. Formatting with '.0f' would round 0.4 to "0".
        text = f"OCR Binary: {val:.2f}"
    else:
        return "N/A"
    if np.isnan(val):
        return "N/A"
    return f"Alpha: {alpha}, Beta: {beta}, {text}"

ax.format_coord = format_coord

def _find_metadata_for(alpha, beta):
    """Return the first metadata dict for (alpha, beta) whose PNGs exist, else None.

    The main metric loop skips entries whose original/distorted image is
    missing. Applying the same filter keeps the detail view from opening an
    entry that was never scored, which would otherwise crash on load.
    """
    for meta_file in metadata_files:
        with open(os.path.join(data_dir, meta_file), 'r') as f:
            metadata = json.load(f)
        if metadata['alpha'] != alpha or metadata['beta'] != beta:
            continue
        index = metadata['index']
        if (os.path.exists(os.path.join(data_dir, f"original_{index}.png"))
                and os.path.exists(os.path.join(data_dir, f"distorted_{index}.png"))):
            return metadata
    return None


def _draw_numbered_boxes(image_bgr, bboxes):
    """Draw each (x, y, w, h) box in red with its 1-based index in green, in place; return the image."""
    for i, (x, y, w, h) in enumerate(bboxes, start=1):
        cv2.rectangle(image_bgr, (x, y), (x + w, y + h), (0, 0, 255), 1)
        cv2.putText(image_bgr, str(i), (x, y - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 150, 0), 1)
    return image_bgr


def show_image_details_for(alpha, beta):
    """Show the original and reconstructed image for one (alpha, beta) cell.

    Re-runs the model on the first usable sample for the angle, re-aligns the
    digit boxes, runs OCR, and opens a figure with both images (boxes
    numbered) plus a per-digit PSNR/SSIM table. Prints a message and returns
    when no usable sample exists.
    """
    metadata = _find_metadata_for(alpha, beta)
    if metadata is None:
        print("No images found for that angle.")
        return

    # Left to right, so digit i lines up with character i of the plate number.
    digit_bboxes = sorted(metadata['digit_bboxes'], key=lambda bbox: bbox[0])
    index = metadata['index']
    plate_number_gt = metadata['plate_number']
    original_path = os.path.join(data_dir, f"original_{index}.png")
    distorted_path = os.path.join(data_dir, f"distorted_{index}.png")

    original_img = to_tensor(Image.open(original_path).convert('RGB')).unsqueeze(0).to(device)
    distorted_img = to_tensor(Image.open(distorted_path).convert('RGB')).unsqueeze(0).to(device)
    with torch.no_grad():
        reconstructed_tensor = torch.clamp(model(distorted_img), 0.0, 1.0)

    original_np = original_img.squeeze(0).permute(1, 2, 0).cpu().numpy()
    reconstructed_np = reconstructed_tensor.squeeze(0).permute(1, 2, 0).cpu().numpy()

    psnr_vals, ssim_vals, updated_bboxes = align_and_update_bboxes(original_np, reconstructed_np, digit_bboxes)

    reconstructed_bgr = cv2.cvtColor((reconstructed_np * 255).astype(np.uint8), cv2.COLOR_RGB2BGR)
    recognized_text, ocr_accuracy, ocr_binary = compute_ocr_metrics(reconstructed_bgr, updated_bboxes, plate_number_gt)

    original_image_rgb = cv2.cvtColor(
        _draw_numbered_boxes(cv2.imread(original_path), digit_bboxes), cv2.COLOR_BGR2RGB)
    reconstructed_image_rgb = cv2.cvtColor(
        _draw_numbered_boxes(reconstructed_bgr.copy(), updated_bboxes), cv2.COLOR_BGR2RGB)

    # Built row-per-digit, then transposed so each digit becomes a table column.
    table_data = [["Digit", "PSNR(dB)", "SSIM"]]
    table_data += [[str(i), f"{p:.2f}", f"{s:.3f}"]
                   for i, (p, s) in enumerate(zip(psnr_vals, ssim_vals), start=1)]
    transposed_table_data = list(zip(*table_data))

    plt.figure(figsize=(14, 7))
    plt.subplot(2, 1, 1)
    plt.imshow(original_image_rgb)
    plt.title(f'Original Image (Alpha={alpha}, Beta={beta})')
    plt.axis('off')

    plt.subplot(2, 1, 2)
    plt.imshow(reconstructed_image_rgb)
    plt.title(f'Reconstructed Image\nGT: {plate_number_gt}, Rec: {recognized_text}, OCR Acc: {ocr_accuracy*100:.2f}%, Binary: {ocr_binary}')
    plt.axis('off')

    # Negative bbox offset places the table below the second subplot.
    table = plt.table(cellText=transposed_table_data,
                      cellLoc='center',
                      loc='center',
                      bbox=[0, -0.55, 1, 0.4])
    table.auto_set_font_size(False)
    table.set_fontsize(12)
    plt.tight_layout()
    plt.show()

def on_click(event):
    """Open the per-image detail view for the heatmap cell that was clicked."""
    if event.inaxes != ax or event.xdata is None or event.ydata is None:
        return
    col = int(round(event.xdata))
    row = int(round(event.ydata))
    if 0 <= row < num_betas and 0 <= col < num_alphas:
        show_image_details_for(alpha_values[col], beta_values[row])

cid = fig.canvas.mpl_connect('button_press_event', on_click)

# Metric-switch buttons; each rect is [left, bottom, width, height] in figure coordinates.
ax_psnr = plt.axes([0.1, 0.05, 0.1, 0.05])
ax_ssim = plt.axes([0.22, 0.05, 0.1, 0.05])
ax_ocr_acc = plt.axes([0.34, 0.05, 0.12, 0.05])
ax_ocr_bin = plt.axes([0.48, 0.05, 0.1, 0.05])

btn_psnr, btn_ssim, btn_ocr_acc, btn_ocr_bin = (
    Button(ax_psnr, 'PSNR'),
    Button(ax_ssim, 'SSIM'),
    Button(ax_ocr_acc, 'OCR Acc'),
    Button(ax_ocr_bin, 'OCR Bin'),
)

def update_heatmap(metric):
    """Redraw the heatmap for ``metric`` and retarget the existing colorbar.

    ``metric`` is 'PSNR', 'SSIM' or 'OCR_Accuracy'. Any other value shows the
    OCR binary matrix. The choice is stored in the global ``current_metric``
    so that ``format_coord`` reports the matching value.

    Reuses the colorbar ``cb`` created with the initial plot. Calling
    ``fig.colorbar(im, ax=ax)`` here would carve a new colorbar out of ``ax``
    on every switch, stacking colorbars and shrinking the heatmap.
    """
    global current_metric
    current_metric = metric
    if metric == 'PSNR':
        data, title, cbar_label = psnr_matrix_avg, "Average PSNR per Digit", "PSNR (dB)"
    elif metric == 'SSIM':
        data, title, cbar_label = ssim_matrix_avg, "Average SSIM per Digit", "SSIM"
    elif metric == 'OCR_Accuracy':
        data, title, cbar_label = ocr_acc_matrix, "Average OCR Accuracy", "OCR Acc"
    else:
        data, title, cbar_label = ocr_bin_matrix, "OCR Binary (1=All Correct)", "OCR Binary"

    ax.clear()
    heatmap = ax.imshow(data, origin='lower', aspect='auto', cmap='viridis')
    ax.set_title(title)
    ax.set_xticks(range(0, num_alphas, 5))
    ax.set_xticklabels(alpha_values[::5])
    ax.set_yticks(range(0, num_betas, 5))
    ax.set_yticklabels(beta_values[::5])
    ax.set_xlabel("Alpha (degrees)")
    ax.set_ylabel("Beta (degrees)")
    ax.format_coord = format_coord
    # Point the single existing colorbar at the new image (new norm/limits).
    cb.update_normal(heatmap)
    cb.set_label(cbar_label)
    fig.canvas.draw_idle()

def _metric_switcher(metric):
    """Return a Button callback that redraws the heatmap for ``metric``."""
    def handler(event):
        update_heatmap(metric)
    return handler

# One callback per metric button; the matplotlib event argument is ignored.
on_psnr_clicked = _metric_switcher('PSNR')
on_ssim_clicked = _metric_switcher('SSIM')
on_ocr_acc_clicked = _metric_switcher('OCR_Accuracy')
on_ocr_bin_clicked = _metric_switcher('OCR_Binary')

btn_psnr.on_clicked(on_psnr_clicked)
btn_ssim.on_clicked(on_ssim_clicked)
btn_ocr_acc.on_clicked(on_ocr_acc_clicked)
btn_ocr_bin.on_clicked(on_ocr_bin_clicked)

current_metric = 'PSNR'  # default

plt.tight_layout()
plt.show()


###############################################################################################################################################

In [None]:
#Heatmap Visualization of worst PSNR, worst SSIM, and OCR Metrics with Interactive Analysis

import os
import json
import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
from matplotlib.widgets import Button
from tqdm import tqdm
import numpy as np
import mlflow
import mlflow.pytorch
import cv2
from pytorch_msssim import ssim
import pytesseract

# --------------------
# Configuration
# --------------------
data_dir = "data/full_grid"
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# MLflow model load: take the most recently started run of the 'Unet' experiment.
mlflow.set_experiment('Unet')
client = mlflow.tracking.MlflowClient()
experiment = client.get_experiment_by_name('Unet')
runs = client.search_runs(
    experiment_ids=[experiment.experiment_id],  # documented as a list of IDs
    order_by=["attributes.start_time DESC"],
    max_results=1,
)
if not runs:
    # set_experiment() creates the experiment when missing, so "no runs" is a
    # real possibility; fail with a clear message instead of an IndexError.
    raise RuntimeError("No runs found in MLflow experiment 'Unet'; log a trained model first.")
run_id = runs[0].info.run_id
model_uri = f"runs:/{run_id}/model"
# map_location is forwarded to torch.load so a model logged from a GPU run can
# still be loaded when only the CPU is available.
model = mlflow.pytorch.load_model(model_uri, map_location=device)
model.eval().to(device)
print(f"Model loaded from run {run_id} in experiment '{experiment.name}' successfully.")

# --------------------
# Functions 
# --------------------

def calculate_psnr(outputs, targets):
    """Return the PSNR in dB between two tensors whose peak value is 1.

    Uses the mean squared error over all elements. Identical inputs
    (zero error) yield ``float('inf')``.
    """
    squared_error = F.mse_loss(outputs, targets)
    if squared_error.item() == 0:
        return float('inf')
    return (10 * torch.log10(1 / squared_error)).item()

def align_and_update_bboxes(original_np, generated_np, digit_bboxes):
    """Re-locate each digit box on the generated image and score the match.

    For every ``(x, y, w, h)`` box, the original patch is template-matched
    (``cv2.TM_CCOEFF_NORMED`` on grayscale) inside a window of the generated
    image that extends the box by 10 px on each side, clipped to the image.
    The best-scoring location becomes the updated box. PSNR and SSIM are then
    computed between the original patch and the generated patch at that
    location.

    Both arrays are indexed ``[row, col, channel]`` with RGB channels and are
    multiplied by 255 before the uint8 conversion. They are presumably floats
    in [0, 1] (SSIM uses ``data_range=1.0``); confirm against callers.

    Returns:
        ``(psnr_values, ssim_values, updated_bboxes)``, one entry per input box,
        in input order.
    """
    margin = 10
    height, width = generated_np.shape[:2]

    def as_gray_u8(rgb):
        # Template matching runs on 8-bit single-channel images.
        return cv2.cvtColor((rgb * 255).astype(np.uint8), cv2.COLOR_RGB2GRAY)

    def as_batch(hwc):
        # HWC numpy patch -> 1xCxHxW tensor on the evaluation device.
        return torch.from_numpy(hwc.transpose(2, 0, 1)).unsqueeze(0).to(device)

    psnr_values, ssim_values, updated_bboxes = [], [], []
    for x, y, w, h in digit_bboxes:
        template = original_np[y:y + h, x:x + w, :]

        left, top = max(0, x - margin), max(0, y - margin)
        right, bottom = min(width, x + w + margin), min(height, y + h + margin)
        window = generated_np[top:bottom, left:right, :]

        scores = cv2.matchTemplate(as_gray_u8(window), as_gray_u8(template), cv2.TM_CCOEFF_NORMED)
        dx, dy = cv2.minMaxLoc(scores)[3]  # (x, y) of the maximum score
        new_x, new_y = left + dx, top + dy
        updated_bboxes.append((new_x, new_y, w, h))

        matched = generated_np[new_y:new_y + h, new_x:new_x + w, :]
        reference_t, candidate_t = as_batch(template), as_batch(matched)
        psnr_values.append(calculate_psnr(candidate_t, reference_t))
        ssim_values.append(ssim(candidate_t, reference_t, data_range=1.0, size_average=True).item())

    return psnr_values, ssim_values, updated_bboxes

def ocr_single_digit(image_bgr):
    """Read one digit from a BGR patch with Tesseract; return '?' if none is read.

    The patch is converted to grayscale and binarised at 128 before OCR.
    Tesseract runs in single-character mode (``--psm 10``) with a 0-9 whitelist.
    """
    tesseract_config = r'--oem 1 --psm 10 -c tessedit_char_whitelist=0123456789'
    grayscale = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2GRAY)
    binary = cv2.threshold(grayscale, 128, 255, cv2.THRESH_BINARY)[1]
    candidate = pytesseract.image_to_string(binary, config=tesseract_config).strip()
    is_single_digit = len(candidate) == 1 and candidate.isdigit()
    return candidate if is_single_digit else '?'

def compute_ocr_metrics(image_bgr, updated_bboxes, plate_number_gt, margin):
    """OCR every digit box and score the reading against the ground truth.

    Each ``(x, y, w, h)`` box is padded by ``margin`` pixels per side, clipped
    to the image, and read with ``ocr_single_digit``. Unreadable digits come
    back as '?'.

    Returns:
        ``(recognized_text, ocr_accuracy, ocr_binary)``: the concatenated
        reading, the fraction of ground-truth positions matched character by
        character, and 1.0 for an exact match else 0.0.
    """
    rows, cols = image_bgr.shape[:2]

    def padded_patch(box):
        x, y, w, h = box
        return image_bgr[max(0, y - margin):min(rows, y + h + margin),
                         max(0, x - margin):min(cols, x + w + margin)]

    recognized_text = "".join([ocr_single_digit(padded_patch(box)) for box in updated_bboxes])
    matches = sum(1 for expected, got in zip(plate_number_gt, recognized_text) if expected == got)
    total = len(plate_number_gt)
    ocr_accuracy = matches / total if total > 0 else 0.0
    ocr_binary = 1.0 if recognized_text == plate_number_gt else 0.0
    return recognized_text, ocr_accuracy, ocr_binary

# --------------------------------------
# Compute metrics for each (alpha, beta)
# --------------------------------------
metadata_files = [f for f in os.listdir(data_dir) if f.startswith('metadata_') and f.endswith('.json')]

psnr_dict_worst = {}
ssim_dict_worst = {}
ocr_acc_dict_avg = {}
ocr_bin_dict_avg = {}

to_tensor = transforms.ToTensor()

for meta_file in tqdm(metadata_files, desc="Processing images", unit="image"):
    meta_path = os.path.join(data_dir, meta_file)
    with open(meta_path, 'r') as f:
        metadata = json.load(f)

    alpha, beta = metadata['alpha'], metadata['beta']
    # Read digits left to right. OCR output is compared character-by-character
    # with plate_number_gt, and show_image_details_for sorts boxes the same way,
    # so the heatmap and the detail view report the same OCR result.
    digit_bboxes = sorted(metadata['digit_bboxes'], key=lambda bbox: bbox[0])
    plate_number_gt = metadata['plate_number']

    index = metadata['index']
    original_path = os.path.join(data_dir, f"original_{index}.png")
    distorted_path = os.path.join(data_dir, f"distorted_{index}.png")

    if not (os.path.exists(original_path) and os.path.exists(distorted_path)):
        continue

    original_img = to_tensor(Image.open(original_path).convert('RGB')).unsqueeze(0).to(device)
    distorted_img = to_tensor(Image.open(distorted_path).convert('RGB')).unsqueeze(0).to(device)

    with torch.no_grad():
        generated_img = torch.clamp(model(distorted_img), 0.0, 1.0)

    original_np = original_img.squeeze(0).permute(1, 2, 0).cpu().numpy()
    generated_np = generated_img.squeeze(0).permute(1, 2, 0).cpu().numpy()

    psnr_per_number, ssim_per_number, updated_bboxes = align_and_update_bboxes(original_np, generated_np, digit_bboxes)

    # Take the worst (minimum) PSNR and SSIM values across all digits for this image
    worst_psnr = np.min(psnr_per_number) if psnr_per_number else 0.0
    worst_ssim = np.min(ssim_per_number) if ssim_per_number else 0.0

    image_bgr = cv2.cvtColor((generated_np * 255).astype(np.uint8), cv2.COLOR_RGB2BGR)
    recognized_text, ocr_accuracy, ocr_binary = compute_ocr_metrics(image_bgr, updated_bboxes, plate_number_gt, margin=2)

    key = (alpha, beta)
    psnr_dict_worst.setdefault(key, []).append(worst_psnr)
    ssim_dict_worst.setdefault(key, []).append(worst_ssim)
    ocr_acc_dict_avg.setdefault(key, []).append(ocr_accuracy)
    ocr_bin_dict_avg.setdefault(key, []).append(ocr_binary)

alpha_values = sorted({a for a, _ in psnr_dict_worst})
beta_values = sorted({b for _, b in psnr_dict_worst})
num_alphas, num_betas = len(alpha_values), len(beta_values)

def create_matrix_from_dict(data_dict, reducer=np.min):
    """Lay out ``{(alpha, beta): [values...]}`` as a ``(num_betas, num_alphas)`` grid.

    Each cell is ``reducer(values)`` over all images recorded at that angle
    pair. Cells with no entry or an empty list stay NaN. The default
    ``np.min`` keeps the worst case across images.

    Args:
        data_dict: per-angle lists of per-image values.
        reducer: function that collapses one list to a scalar (e.g. ``np.mean``).
    """
    mat = np.full((num_betas, num_alphas), np.nan)
    alpha_to_index = {val: i for i, val in enumerate(alpha_values)}
    beta_to_index = {val: i for i, val in enumerate(beta_values)}
    for (a, b), val_list in data_dict.items():
        mat[beta_to_index[b], alpha_to_index[a]] = reducer(val_list) if val_list else np.nan
    return mat

# PSNR/SSIM maps show the worst image per angle.
psnr_matrix = create_matrix_from_dict(psnr_dict_worst)
ssim_matrix = create_matrix_from_dict(ssim_dict_worst)
# OCR accuracy is presented as "Average OCR Accuracy", so average it across
# images instead of taking the minimum as the worst-case maps do.
ocr_acc_matrix = create_matrix_from_dict(ocr_acc_dict_avg, reducer=np.mean)
# Binary keeps the minimum: a cell is 1 only if every image at that angle was read exactly.
ocr_bin_matrix = create_matrix_from_dict(ocr_bin_dict_avg)

In [None]:
# Interactive Plot with Buttons
current_metric = 'PSNR'  # metric shown first; update_heatmap changes it
fig, ax = plt.subplots(figsize=(10, 8))
plt.subplots_adjust(bottom=0.2)  # reserve the bottom strip of the figure for the metric buttons

im = ax.imshow(psnr_matrix, origin='lower', aspect='auto', cmap="viridis")
cb = plt.colorbar(im, ax=ax, label='PSNR (dB)')
# Only every 5th grid column/row gets an angle label, to keep the axes readable.
ax.set(
    title="Worst PSNR per Image (Minimum Digit PSNR)",
    xlabel="Alpha (degrees)",
    ylabel="Beta (degrees)",
    xticks=range(0, num_alphas, 5),
    yticks=range(0, num_betas, 5),
)
ax.set_xticklabels(alpha_values[::5])
ax.set_yticklabels(beta_values[::5])

def format_coord(x, y):
    """Return status-bar text for the heatmap cell nearest to data point (x, y).

    Column ``round(x)`` indexes ``alpha_values`` and row ``round(y)`` indexes
    ``beta_values``. The metric comes from the global ``current_metric``.
    Returns "N/A" outside the grid, for NaN cells, or for an unknown metric.
    """
    col, row = int(round(x)), int(round(y))
    if not (0 <= row < num_betas and 0 <= col < num_alphas):
        return "N/A"
    # metric name -> (matrix to read, formatter for the value part of the label)
    formatters = {
        'PSNR': (psnr_matrix, lambda v: f"Worst PSNR: {v:.2f} dB"),
        'SSIM': (ssim_matrix, lambda v: f"Worst SSIM: {v:.3f}"),
        'OCR_Accuracy': (ocr_acc_matrix, lambda v: f"OCR Acc: {v*100:.2f}%"),
        'OCR_Binary': (ocr_bin_matrix, lambda v: f"OCR Binary: {v:.0f}"),
    }
    if current_metric not in formatters:
        return "N/A"
    matrix, describe = formatters[current_metric]
    val = matrix[row, col]
    if np.isnan(val):
        return "N/A"
    return f"Alpha: {alpha_values[col]}, Beta: {beta_values[row]}, {describe(val)}"

ax.format_coord = format_coord

def _find_metadata_for(alpha, beta):
    """Return the first metadata dict for (alpha, beta) whose PNGs exist, else None.

    The main metric loop skips entries whose original/distorted image is
    missing. Applying the same filter keeps the detail view from opening an
    entry that was never scored, which would otherwise crash on load.
    """
    for meta_file in metadata_files:
        with open(os.path.join(data_dir, meta_file), 'r') as f:
            metadata = json.load(f)
        if metadata['alpha'] != alpha or metadata['beta'] != beta:
            continue
        index = metadata['index']
        if (os.path.exists(os.path.join(data_dir, f"original_{index}.png"))
                and os.path.exists(os.path.join(data_dir, f"distorted_{index}.png"))):
            return metadata
    return None


def _draw_numbered_boxes(image_bgr, bboxes):
    """Draw each (x, y, w, h) box in red with its 1-based index in green, in place; return the image."""
    for i, (x, y, w, h) in enumerate(bboxes, start=1):
        cv2.rectangle(image_bgr, (x, y), (x + w, y + h), (0, 0, 255), 1)
        cv2.putText(image_bgr, str(i), (x, y - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 150, 0), 1)
    return image_bgr


def show_image_details_for(alpha, beta):
    """Show the original and generated image for one (alpha, beta) cell.

    Re-runs the model on the first usable sample for the angle, re-aligns the
    digit boxes, runs OCR (2 px padding), and opens a figure with both images
    (boxes numbered) plus a per-digit PSNR/SSIM table. Prints a message and
    returns when no usable sample exists.
    """
    metadata = _find_metadata_for(alpha, beta)
    if metadata is None:
        print("No images found for that angle.")
        return

    # Left to right, so digit i lines up with character i of the plate number.
    digit_bboxes = sorted(metadata['digit_bboxes'], key=lambda bbox: bbox[0])
    index = metadata['index']
    plate_number_gt = metadata['plate_number']
    original_path = os.path.join(data_dir, f"original_{index}.png")
    distorted_path = os.path.join(data_dir, f"distorted_{index}.png")

    original_img = to_tensor(Image.open(original_path).convert('RGB')).unsqueeze(0).to(device)
    distorted_img = to_tensor(Image.open(distorted_path).convert('RGB')).unsqueeze(0).to(device)
    with torch.no_grad():
        generated_tensor = torch.clamp(model(distorted_img), 0.0, 1.0)

    original_np = original_img.squeeze(0).permute(1, 2, 0).cpu().numpy()
    generated_np = generated_tensor.squeeze(0).permute(1, 2, 0).cpu().numpy()

    psnr_vals, ssim_vals, updated_bboxes = align_and_update_bboxes(original_np, generated_np, digit_bboxes)

    generated_bgr = cv2.cvtColor((generated_np * 255).astype(np.uint8), cv2.COLOR_RGB2BGR)
    recognized_text, ocr_accuracy, ocr_binary = compute_ocr_metrics(generated_bgr, updated_bboxes, plate_number_gt, margin=2)

    original_image_rgb = cv2.cvtColor(
        _draw_numbered_boxes(cv2.imread(original_path), digit_bboxes), cv2.COLOR_BGR2RGB)
    generated_image_rgb = cv2.cvtColor(
        _draw_numbered_boxes(generated_bgr.copy(), updated_bboxes), cv2.COLOR_BGR2RGB)

    # Built row-per-digit, then transposed so each digit becomes a table column.
    table_data = [["Digit", "PSNR(dB)", "SSIM"]]
    table_data += [[str(i), f"{p:.2f}", f"{s:.3f}"]
                   for i, (p, s) in enumerate(zip(psnr_vals, ssim_vals), start=1)]
    transposed_table_data = list(zip(*table_data))

    plt.figure(figsize=(14, 7))
    plt.subplot(2, 1, 1)
    plt.imshow(original_image_rgb)
    plt.title(f'Original Image (Alpha={alpha}, Beta={beta})')
    plt.axis('off')

    plt.subplot(2, 1, 2)
    plt.imshow(generated_image_rgb)
    plt.title(f'Generated Image\nGT: {plate_number_gt}, Rec: {recognized_text}, OCR Acc: {ocr_accuracy*100:.2f}%, Binary: {ocr_binary}')
    plt.axis('off')

    # Negative bbox offset places the table below the second subplot.
    table = plt.table(cellText=transposed_table_data,
                      cellLoc='center',
                      loc='center',
                      bbox=[0, -0.55, 1, 0.4])
    table.auto_set_font_size(False)
    table.set_fontsize(12)
    plt.tight_layout()
    plt.show()

def on_click(event):
    """Open the per-image detail view for the heatmap cell that was clicked."""
    if event.inaxes != ax or event.xdata is None or event.ydata is None:
        return
    col = int(round(event.xdata))
    row = int(round(event.ydata))
    if 0 <= row < num_betas and 0 <= col < num_alphas:
        show_image_details_for(alpha_values[col], beta_values[row])

cid = fig.canvas.mpl_connect('button_press_event', on_click)

# Metric-switch buttons; each rect is [left, bottom, width, height] in figure coordinates.
ax_psnr = plt.axes([0.1, 0.05, 0.1, 0.05])
ax_ssim = plt.axes([0.22, 0.05, 0.1, 0.05])
ax_ocr_acc = plt.axes([0.34, 0.05, 0.12, 0.05])
ax_ocr_bin = plt.axes([0.48, 0.05, 0.1, 0.05])

btn_psnr, btn_ssim, btn_ocr_acc, btn_ocr_bin = (
    Button(ax_psnr, 'PSNR'),
    Button(ax_ssim, 'SSIM'),
    Button(ax_ocr_acc, 'OCR Acc'),
    Button(ax_ocr_bin, 'OCR Bin'),
)

def update_heatmap(metric):
    """Redraw the heatmap for ``metric`` and retarget the existing colorbar.

    ``metric`` is 'PSNR', 'SSIM' or 'OCR_Accuracy'. Any other value shows the
    OCR binary matrix. The choice is stored in the global ``current_metric``
    so that ``format_coord`` reports the matching value.

    Reuses the colorbar ``cb`` created with the initial plot. Calling
    ``fig.colorbar(im, ax=ax)`` here would carve a new colorbar out of ``ax``
    on every switch, stacking colorbars and shrinking the heatmap.
    """
    global current_metric
    current_metric = metric
    if metric == 'PSNR':
        data, title, cbar_label = psnr_matrix, "Worst PSNR per Image (Minimum Digit PSNR)", "PSNR (dB)"
    elif metric == 'SSIM':
        data, title, cbar_label = ssim_matrix, "Worst SSIM per Image (Minimum Digit SSIM)", "SSIM"
    elif metric == 'OCR_Accuracy':
        data, title, cbar_label = ocr_acc_matrix, "Average OCR Accuracy", "OCR Acc"
    else:
        data, title, cbar_label = ocr_bin_matrix, "OCR Binary (1=All Correct)", "OCR Binary"

    ax.clear()
    heatmap = ax.imshow(data, origin='lower', aspect='auto', cmap='viridis')
    ax.set_title(title)
    ax.set_xticks(range(0, num_alphas, 5))
    ax.set_xticklabels(alpha_values[::5])
    ax.set_yticks(range(0, num_betas, 5))
    ax.set_yticklabels(beta_values[::5])
    ax.set_xlabel("Alpha (degrees)")
    ax.set_ylabel("Beta (degrees)")
    ax.format_coord = format_coord
    # Point the single existing colorbar at the new image (new norm/limits).
    cb.update_normal(heatmap)
    cb.set_label(cbar_label)
    fig.canvas.draw_idle()

def _metric_switcher(metric):
    """Return a Button callback that redraws the heatmap for ``metric``."""
    def handler(event):
        update_heatmap(metric)
    return handler

# One callback per metric button; the matplotlib event argument is ignored.
on_psnr_clicked = _metric_switcher('PSNR')
on_ssim_clicked = _metric_switcher('SSIM')
on_ocr_acc_clicked = _metric_switcher('OCR_Accuracy')
on_ocr_bin_clicked = _metric_switcher('OCR_Binary')

btn_psnr.on_clicked(on_psnr_clicked)
btn_ssim.on_clicked(on_ssim_clicked)
btn_ocr_acc.on_clicked(on_ocr_acc_clicked)
btn_ocr_bin.on_clicked(on_ocr_bin_clicked)

current_metric = 'PSNR'  # default

plt.tight_layout()
plt.show()

###############################################################################################################################################