In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from bfm_finetune.metrics import compute_ssim_metric, compute_spc

In [None]:
def create_image_with_black_spot(H, W, center, radius, background=1.0, spot_value=0.0):
    """
    Create a float32 image of shape [H, W] containing a single filled circular spot.

    With the default values the background is white (1.0) and the spot is
    black (0.0).

    Args:
        H: Image height in pixels (number of rows).
        W: Image width in pixels (number of columns).
        center: Spot centre given as ``(x, y)``, i.e. ``(column, row)``. It may be
            fractional and may lie outside the image; only in-bounds pixels are set.
        radius: Spot radius in pixels. The pixel at (row r, column c) is filled when
            ``sqrt((c - x)**2 + (r - y)**2) <= radius``. A negative radius fills nothing.
        background: Value for pixels outside the spot (default 1.0, white).
        spot_value: Value for pixels inside the spot (default 0.0, black).

    Returns:
        np.ndarray of shape ``(H, W)`` and dtype ``float32``.
    """
    img = np.full((H, W), background, dtype=np.float32)
    # ogrid yields (H, 1) row and (1, W) column indices that broadcast to (H, W).
    rows, cols = np.ogrid[:H, :W]
    # center[0] is the x (column) coordinate and center[1] the y (row) coordinate.
    dist = np.sqrt((cols - center[0]) ** 2 + (rows - center[1]) ** 2)
    img[dist <= radius] = spot_value
    return img

# Canvas size (rows, columns) and spot radius, all in pixels.
H, W = 152, 320
radius = 20
# Horizontal offset of the predicted spot relative to the ground-truth spot, in pixels.
SHIFT_PX = 40

# Ground truth: spot at the image centre. Centres are (x, y) = (column, row).
gt_center = (W / 2, H / 2)  # (cx, cy)
gt_img = create_image_with_black_spot(H, W, gt_center, radius)

# Prediction: the same spot shifted right by SHIFT_PX pixels.
pred_center = (gt_center[0] + SHIFT_PX, gt_center[1])
assert 0 <= pred_center[0] - radius and pred_center[0] + radius <= W - 1, "shifted spot leaves the canvas"
pred_img = create_image_with_black_spot(H, W, pred_center, radius)

# Prepend three singleton dims (same as three .unsqueeze(0) calls) -> [1, 1, 1, H, W].
# NOTE(review): presumably the layout compute_ssim_metric / compute_spc expect — confirm.
gt_tensor = torch.tensor(gt_img)[None, None, None]
pred_tensor = torch.tensor(pred_img)[None, None, None]
assert gt_tensor.shape == pred_tensor.shape == (1, 1, 1, H, W)

In [None]:
# Score the prediction against the ground truth with both metrics and print the raw results.
ssim_value = compute_ssim_metric(pred_tensor, gt_tensor)
spc_value = compute_spc(pred_tensor, gt_tensor)

print("SSIM between ground truth and prediction:", ssim_value)
print("SPC between ground truth and prediction:", spc_value)

# Side-by-side view of both images. vmin/vmax pin the gray scale to the [0, 1] pixel
# range, so an image with no visible spot (all ones) is not auto-scaled to a degenerate
# range and drawn as one flat colour.
fig, axes = plt.subplots(1, 2, figsize=(12, 6))
panels = [
    (gt_img, "Ground Truth (Black Spot at Center)"),
    (pred_img, "Prediction (Black Spot shifted Right)"),
]
for ax, (panel_img, panel_title) in zip(axes, panels):
    ax.imshow(panel_img, cmap='gray', vmin=0.0, vmax=1.0)
    ax.set_title(panel_title)
    ax.axis("off")

plt.show()