In [24]:
import numpy as np
import torch
import torch.nn.functional as F
import tifffile as tiff
import piq
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
from utils import log

# def load_volume(path):
#     return tiff.imread(path).astype(np.float32)

# def normalize(volume):
#     return volume / 65535.0  # normalize 16-bit to [0, 1]

def evaluate_volume_metrics(gt, pred, mask):
    from skimage.metrics import structural_similarity as skimage_ssim

    assert gt.shape == pred.shape, "Volume shapes must match"
    gt = (torch.from_numpy(gt).float() / 65535.0).numpy()
    pred = (torch.from_numpy(pred).float() / 65535.0).numpy()
    mask = mask.astype(bool)

    D, H, W = gt.shape
    l1_vals = []
    psnr_vals = []
    ssim_vals = {7: [], 11: [], 17: []}
    ncc_vals = {7: [], 11: [], 17: []}

    def local_ncc(a, b, window):
        a, b = torch.from_numpy(a), torch.from_numpy(b)
        pad = window // 2
        a2 = a**2
        b2 = b**2
        ab = a * b
        kernel = torch.ones((1, 1, window, window))
        a_sum = F.conv2d(a.view(1,1,H,W), kernel, padding=pad)
        b_sum = F.conv2d(b.view(1,1,H,W), kernel, padding=pad)
        ab_sum = F.conv2d(ab.view(1,1,H,W), kernel, padding=pad)
        a2_sum = F.conv2d(a2.view(1,1,H,W), kernel, padding=pad)
        b2_sum = F.conv2d(b2.view(1,1,H,W), kernel, padding=pad)

        win_size = window ** 2
        a_mean = a_sum / win_size
        b_mean = b_sum / win_size
        num = ab_sum - a_mean * b_sum - b_mean * a_sum + a_mean * b_mean * win_size
        denom = (a2_sum - a_mean * a_sum - a_mean * a_sum + a_mean * a_mean * win_size).clamp(min=1e-6).sqrt() * \
                (b2_sum - b_mean * b_sum - b_mean * b_sum + b_mean * b_mean * win_size).clamp(min=1e-6).sqrt()
        return (num / denom).clamp(-1, 1).mean().item()

    num_masked = 0

    for i in range(D):
        # if not mask[i]:
        if not mask[i].any():
            continue
        g = gt[i]
        p = pred[i]

        l1_vals.append(np.mean(np.abs(p - g)))
        psnr = piq.psnr(torch.tensor(p).unsqueeze(0).unsqueeze(0),
                        torch.tensor(g).unsqueeze(0).unsqueeze(0),
                        data_range=1.0).item()
        psnr_vals.append(psnr)

        for w in [7, 11, 17]:
            try:
                ssim = skimage_ssim(p, g, data_range=1.0, win_size=w)
                ssim_vals[w].append(ssim)
            except Exception as e:
                print(f"SSIM failed on slice {i} with window {w}: {e}")
                ssim_vals[w].append(float('nan'))

            ncc_vals[w].append(local_ncc(p, g, window=w))

        num_masked += 1

    if num_masked == 0:
        return {
            "L1": None,
            "PSNR": None,
            "MeanIntensityError": round(float(np.abs(pred.mean() - gt.mean())), 4),
            "SSIM": {w: None for w in [7, 11, 17]},
            "NCC": {w: None for w in [7, 11, 17]},
            "Note": "No corrupted slices to evaluate"
        }

    return {
        "L1": round(sum(l1_vals) / num_masked, 4),
        "PSNR": round(sum(psnr_vals) / num_masked, 4),
        "MeanIntensityError": round(float(np.abs(pred.mean() - gt.mean())), 4),
        "SSIM": {w: round(np.nanmean(ssim_vals[w]), 4) for w in ssim_vals},
        "NCC": {w: round(sum(ncc_vals[w]) / num_masked, 4) for w in ncc_vals}
    }


def run_comparison(gt_path, predicted_path_1, predicted_path_2, mask_path):
    # Load volumes
    gt = tiff.imread(gt_path)
    pred1 = tiff.imread(predicted_path_1)
    pred2 = tiff.imread(predicted_path_2)
    mask = tiff.imread(mask_path)  # mask expected as uint8 or uint16 0/1 format

    # Ensure shapes match
    assert gt.shape == pred1.shape == pred2.shape == mask.shape, "Volumes and mask must match in shape!"

    metrics_1 = evaluate_volume_metrics(gt, pred1, mask)

    log(f"Volume Metrics for Predicted Volume 1:")
    log(f" - L1 Loss: {metrics_1['L1']:.4f}")
    log(f" - PSNR: {metrics_1['PSNR']}")
    log(f" - Mean Intensity Diff: {metrics_1['MeanIntensityError']}")
    for w in [7, 11, 17]:
        log(f" - SSIM (win={w}): {metrics_1['SSIM'][w]}")
    for w in [7, 11, 17]:
        log(f" - NCC (win={w}):  {metrics_1['NCC'][w]}")
    log("-" * 40)
    print('\n')


    metrics_2 = evaluate_volume_metrics(gt, pred2, mask)

    log(f"Volume Metrics for Predicted Volume 2:")
    log(f" - L1 Loss: {metrics_2['L1']:.4f}")
    log(f" - PSNR: {metrics_2['PSNR']}")
    log(f" - Mean Intensity Diff: {metrics_2['MeanIntensityError']}")
    for w in [7, 11, 17]:
        log(f" - SSIM (win={w}): {metrics_2['SSIM'][w]}")
    for w in [7, 11, 17]:
        log(f" - NCC (win={w}):  {metrics_2['NCC'][w]}")
    log("-" * 40)
    print('\n')


In [25]:
base_path = "/media/admin/Expansion/Mosaic_Data_for_Ipeks_Group/OCT_Inpainting_Testing/"

# run_comparison(
#     gt_path=f"{base_path}1.1_OCTA_Vol1_Processed_Cropped_gt.tif",
#     predicted_path_1=f"{base_path}1.1_OCTA_Vol1_Processed_Cropped_inpainted_2p5DUNet_fold4_v2.tif",
#     predicted_path_2=f"{base_path}1.1_OCTA_Vol1_Processed_Cropped_inpainted_2p5DUNet_fold4_v4.tif",
#     mask_path=f"{base_path}1.1_OCTA_Vol1_Processed_Cropped_mask.tif"
# )

run_comparison(
    gt_path=f"{base_path}1.2_OCTA_Vol2_Processed_Cropped_gt.tif",
    predicted_path_1=f"{base_path}1.2_OCTA_Vol2_Processed_Cropped_inpainted_2p5DUNet_fold1_v2.tif",
    predicted_path_2=f"{base_path}1.2_OCTA_Vol2_Processed_Cropped_inpainted_2p5DUNet_fold1_v2_brightcorr.tif",
    mask_path=f"{base_path}1.2_OCTA_Vol2_Processed_Cropped_mask.tif"
)

# run_comparison(
#     gt_path=f"{base_path}1.4_OCTA_Vol1_Processed_Cropped_gt.tif",
#     predicted_path_1=f"{base_path}1.4_OCTA_Vol1_Processed_Cropped_inpainted_2p5DUNet_fold3_v2.tif",
#     predicted_path_2=f"{base_path}1.4_OCTA_Vol1_Processed_Cropped_inpainted_2p5DUNet_fold3_v4.tif",
#     mask_path=f"{base_path}1.4_OCTA_Vol1_Processed_Cropped_mask.tif"
# )

# run_comparison(
#     gt_path=f"{base_path}3.4_OCT_uint16_Cropped_Reflected_VolumeSplit_2_RegSeq_seqSVD_gt.tif",
#     predicted_path_1=f"{base_path}3.4_OCT_uint16_Cropped_Reflected_VolumeSplit_2_RegSeq_seqSVD_inpainted_2p5DUNet_fold5_v2.tif",
#     predicted_path_2=f"{base_path}3.4_OCT_uint16_Cropped_Reflected_VolumeSplit_2_RegSeq_seqSVD_inpainted_2p5DUNet_fold5_v4.tif",
#     mask_path=f"{base_path}3.4_OCT_uint16_Cropped_Reflected_VolumeSplit_2_RegSeq_seqSVD_mask.tif"
# )

# run_comparison(
#     gt_path=f"{base_path}5.3_OCT_uint16_Cropped_Reflected_VolumeSplit_1_RegSeq_seqSVD_gt.tif",
#     predicted_path_1=f"{base_path}5.3_OCT_uint16_Cropped_Reflected_VolumeSplit_1_RegSeq_seqSVD_inpainted_2p5DUNet_fold2_v2.tif",
#     predicted_path_2=f"{base_path}5.3_OCT_uint16_Cropped_Reflected_VolumeSplit_1_RegSeq_seqSVD_inpainted_2p5DUNet_fold2_v4.tif",
#     mask_path=f"{base_path}5.3_OCT_uint16_Cropped_Reflected_VolumeSplit_1_RegSeq_seqSVD_mask.tif"
# )

Volume Metrics for Predicted Volume 1:
 - L1 Loss: 0.0135
 - PSNR: 33.5569
 - Mean Intensity Diff: 0.0003
 - SSIM (win=7): 0.7665
 - SSIM (win=11): 0.7629
 - SSIM (win=17): 0.7603
 - NCC (win=7):  0.1838
 - NCC (win=11):  0.2677
 - NCC (win=17):  0.3781
----------------------------------------


Volume Metrics for Predicted Volume 2:
 - L1 Loss: 0.0136
 - PSNR: 33.5949
 - Mean Intensity Diff: 0.0001
 - SSIM (win=7): 0.7674
 - SSIM (win=11): 0.7641
 - SSIM (win=17): 0.7619
 - NCC (win=7):  0.1838
 - NCC (win=11):  0.2677
 - NCC (win=17):  0.3781
----------------------------------------


