In [None]:
import numpy as np
import matplotlib.pyplot as plt
import os
from PIL import Image,ImageChops
from torchvision.transforms import Resize, Compose, ToTensor, Normalize, Grayscale, Pad
import matplotlib.pyplot as plt
from skimage import data
import skimage
from skimage.metrics import structural_similarity as ssim

## Load ground truth

In [None]:
# Ground-truth object: keep the last 241 rows/columns, then trim a 30-px border on every side.
obj_act = np.load("data/crystal.npy")[-241:, -241:][30:-30, 30:-30]
# Probe used for the simulation, with a 24-px border trimmed on every side.
probe_act = np.load("data/probe_for_sim.npy")[24:-24, 24:-24]

In [None]:
# Reference amplitude and phase (radians) of the ground-truth object.
amp_act, phase_act = np.abs(obj_act), np.angle(obj_act)

In [None]:
current_path = os.getcwd()  # working directory at this point in the notebook (before any os.chdir)

In [None]:
current_path  # display the directory stored in current_path (a cell's last expression is rendered as output)

In [None]:
def psnr_amp(original, compressed):
    """Return ``[PSNR, SSIM]`` between two amplitude images on a 0..1 scale.

    PSNR uses a peak value of 1.0 and SSIM uses ``data_range=1``.

    Parameters
    ----------
    original, compressed : array-like
        Real-valued images of identical shape.

    Returns
    -------
    list
        ``[psnr_db, ssim_value]``. For identical inputs the MSE is 0 and PSNR
        would be infinite, so it is reported as 100 dB; the SSIM of an image
        with itself is exactly 1, so ``[100, 1.0]`` is returned.
    """
    # Local import: the module-level `from math import ...` lives in a later cell.
    from math import log10, sqrt

    # Subtract in float64 so integer-typed inputs cannot wrap around.
    diff = np.asarray(original, dtype=np.float64) - np.asarray(compressed, dtype=np.float64)
    mse = np.mean(diff ** 2)
    if mse == 0:
        # Keep the [psnr, ssim] list shape on every path; callers index [0] and [1].
        return [100, 1.0]
    max_pixel = 1.0
    psnr = 20 * log10(max_pixel / sqrt(mse))
    return [psnr, ssim(original, compressed, data_range=1)]

In [None]:
def psnr_phase(original, compressed):
    """Return ``[PSNR, SSIM]`` between two phase images in radians.

    PSNR uses a peak value of pi; SSIM uses ``data_range=2*pi``.

    Parameters
    ----------
    original, compressed : array-like
        Real-valued phase images of identical shape.

    Returns
    -------
    list
        ``[psnr_db, ssim_value]``. For identical inputs the MSE is 0 and PSNR
        would be infinite, so it is reported as 100 dB; the SSIM of an image
        with itself is exactly 1, so ``[100, 1.0]`` is returned.
    """
    # Local import: the module-level `from math import ...` lives in a later cell.
    from math import log10, sqrt

    # NOTE(review): differences are not wrapped to [-pi, pi), so pixels straddling
    # the +/-pi boundary count as large errors. Also, the PSNR peak (pi) differs from
    # the SSIM data_range (2*pi). Both are kept as-is so reported numbers do not change.
    diff = np.asarray(original, dtype=np.float64) - np.asarray(compressed, dtype=np.float64)
    mse = np.mean(diff ** 2)
    if mse == 0:
        # Keep the [psnr, ssim] list shape on every path; callers index [0] and [1].
        return [100, 1.0]
    max_pixel = np.pi
    psnr = 20 * log10(max_pixel / sqrt(mse))
    return [psnr, ssim(original, compressed, data_range=2 * np.pi)]

In [None]:
from math import log10, sqrt 

In [None]:
# Move into the folder holding the reconstruction .npy files. The guard makes the
# cell safe to re-run: without it a second run would try to enter result/result.
if os.path.basename(os.path.normpath(os.getcwd())) != "result":
    os.chdir("result")

In [None]:
# Ground-truth amplitude. "gray" is used instead of the "grey" alias, which is not
# available in older Matplotlib releases and raises ValueError there.
fig, ax = plt.subplots()
im = ax.imshow(np.abs(obj_act), cmap="gray")
ax.set_title("Ground-truth object amplitude")
fig.colorbar(im, ax=ax);

In [None]:
# Ground-truth phase (np.angle returns radians), with a colorbar so values can be read off.
fig, ax = plt.subplots()
im = ax.imshow(np.angle(obj_act), cmap="magma_r")
ax.set_title("Ground-truth object phase")
fig.colorbar(im, ax=ax, label="phase (rad)");

## Find the global phase shift of the object using a least-squares search

In [None]:
def find_global_phase_shift_loop(A, B, num_steps=100000):
    """Brute-force search for the global phase shift aligning angle(A) with B.

    Tries ``num_steps`` evenly spaced shifts in [-pi, pi]. For each shift, the
    error is the mean squared difference between ``angle(A * exp(1j*shift))`` and
    ``B``. The first shift with the strictly smallest error is returned. If no
    error is finite and below infinity (e.g. every error is NaN), 0 is returned.
    """
    candidates = np.linspace(-np.pi, np.pi, num_steps)
    best_shift, lowest_error = 0, float('inf')

    for shift in candidates:
        shifted_phase = np.angle(A * np.exp(shift * 1j))
        err = np.mean((shifted_phase - B) ** 2)
        # Strict comparison keeps the earliest candidate on ties and skips NaN errors.
        if err < lowest_error:
            best_shift, lowest_error = shift, err

    return best_shift

## Calculate PSNR and SSIM of the object amplitude and phase (CPU)

In [None]:
import os
import csv
import numpy as np


# Evaluate every PtyINR crystal-object reconstruction (unknown-probe runs) in the
# current directory against the ground truth, printing PSNR/SSIM for amplitude and phase.
recon_size = 241     # reconstructions are cut to 241x241, the size of the ground-truth crop
border = 30          # same 30-px border that was trimmed from obj_act
save_images = False  # set True to also export TIFF previews into images/

# os.listdir order is arbitrary; sort so the printed table is reproducible.
for filename in sorted(os.listdir(".")):
    is_target = (
        filename.endswith(".npy")
        and "crystal" in filename
        and "obj" in filename
        and "_unknown" in filename
        and "PtyINR" in filename
    )
    if not is_target:
        continue

    obj = np.load(filename)[:recon_size, :recon_size]
    inner = (slice(border, -border), slice(border, -border))

    # Amplitude: rescale so its median matches the ground truth, then clip values above 1.
    amp = np.abs(obj)[inner]
    amp = amp * np.median(amp_act) / np.median(amp)
    amp = amp.clip(max=1)

    # Phase: remove the global phase offset relative to the ground truth before comparing.
    phase_shift = find_global_phase_shift_loop(obj[inner], phase_act)
    phase = np.angle(obj * np.exp(phase_shift * 1j))[inner]

    if save_images:
        os.makedirs("images", exist_ok=True)
        stem = "images/" + filename[:-4]
        amp_kw = dict(vmin=amp_act.min(), vmax=amp_act.max(), cmap="gray")
        phase_kw = dict(vmin=phase_act.min(), vmax=phase_act.max(), cmap="magma_r")
        plt.imsave(stem + "_obj_amp.tiff", amp, **amp_kw)
        plt.imsave(stem + "_obj_phase.tiff", phase, **phase_kw)
        # Fixed 60x60 zoom region used for close-up previews.
        plt.imsave(stem + "_obj_amp_croped.tiff", amp[35:95, 62:122], **amp_kw)
        plt.imsave(stem + "_obj_phase_croped.tiff", phase[35:95, 62:122], **phase_kw)

    # Each metric function returns [psnr, ssim].
    psnr_amp_result = psnr_amp(amp, amp_act)
    psnr_phase_result = psnr_phase(phase, phase_act)

    psnr_amp_1 = round(psnr_amp_result[0], 2)
    psnr_amp_2 = round(psnr_amp_result[1], 2)
    psnr_phase_1 = round(psnr_phase_result[0], 2)
    psnr_phase_2 = round(psnr_phase_result[1], 2)
    print(f"{filename} : {psnr_amp_1}/{psnr_amp_2}  {psnr_phase_1}/{psnr_phase_2}")


## GPU version of the metrics calculation (faster)

In [None]:
import cupy as cp

def find_global_phase_shift_loop_cuda(A, B, num_steps=10000, chunk_size=256, xp=None):
    """
    Find the global phase shift that best aligns the phase of ``A`` with ``B``.

    Candidate shifts ``phi`` are ``num_steps`` evenly spaced values in [-pi, pi].
    Each candidate's error is ``mean((angle(A * exp(1j*phi)) - B) ** 2)``. The
    candidate with the smallest error is returned; on ties the first one wins.

    Candidates are evaluated ``chunk_size`` at a time. Evaluating all of them at
    once materialises a (num_steps, *A.shape) complex array. For 10000 steps on
    the 181x181 images used here, that shifted array alone is about 5 GB, plus
    temporaries of similar size, which can exhaust GPU memory.

    Parameters:
        A (array-like): Complex array (object), moved to the device with ``xp.asarray``.
        B (array-like): Real-valued reference phase, broadcastable against ``A``.
        num_steps (int): Number of candidate shifts (>= 1).
        chunk_size (int): Candidates evaluated per batch (>= 1). Larger batches are
            faster but use proportionally more device memory.
        xp (module, optional): Array module. Defaults to CuPy; pass ``numpy`` to
            run the same search on the CPU.

    Returns:
        float: The best phase shift minimizing the least squares error.

    Raises:
        ValueError: If ``num_steps`` or ``chunk_size`` is smaller than 1.
    """
    if xp is None:
        xp = cp
    if num_steps < 1:
        raise ValueError("num_steps must be >= 1")
    if chunk_size < 1:
        raise ValueError("chunk_size must be >= 1")

    A_dev = xp.asarray(A)  # A is complex
    B_dev = xp.asarray(B)  # B is real (reference phase)

    phase_range = xp.linspace(-xp.pi, xp.pi, num_steps)

    # One trailing singleton axis per dimension of A, so each candidate broadcasts over all of A.
    candidate_shape = (-1,) + (1,) * A_dev.ndim

    chunk_errors = []
    for start in range(0, num_steps, chunk_size):
        phis = phase_range[start:start + chunk_size].reshape(candidate_shape)
        phase_diff = xp.angle(A_dev * xp.exp(1j * phis)) - B_dev
        # Reduce over every axis except the candidate axis.
        reduce_axes = tuple(range(1, phase_diff.ndim))
        chunk_errors.append(xp.mean(phase_diff ** 2, axis=reduce_axes))
    errors = xp.concatenate(chunk_errors)

    best_index = int(xp.argmin(errors))
    return float(phase_range[best_index])  # plain Python float (copied back from the device)

In [None]:
# GPU variant: same evaluation as the CPU loop, but the global phase shift is
# searched with find_global_phase_shift_loop_cuda.
recon_size = 241  # reconstructions are cut to 241x241, the size of the ground-truth crop
border = 30       # same 30-px border that was trimmed from obj_act

# os.listdir order is arbitrary; sort so the printed table is reproducible.
for filename in sorted(os.listdir(".")):
    is_target = (
        filename.endswith(".npy")
        and "crystal" in filename
        and "obj" in filename
        and "_unknown" in filename
        and "PtyINR" in filename
    )
    if not is_target:
        continue

    obj = np.load(filename)[:recon_size, :recon_size]
    inner = (slice(border, -border), slice(border, -border))

    # Amplitude: rescale so its median matches the ground truth, then clip values above 1.
    amp = np.abs(obj)[inner]
    amp = amp * np.median(amp_act) / np.median(amp)
    amp = amp.clip(max=1)

    # Phase: remove the global phase offset relative to the ground truth before comparing.
    phase_shift = find_global_phase_shift_loop_cuda(obj[inner], phase_act)
    phase = np.angle(obj * np.exp(phase_shift * 1j))[inner]

    # Each metric function returns [psnr, ssim].
    psnr_amp_result = psnr_amp(amp, amp_act)
    psnr_phase_result = psnr_phase(phase, phase_act)

    psnr_amp_1 = round(psnr_amp_result[0], 2)
    psnr_amp_2 = round(psnr_amp_result[1], 2)
    psnr_phase_1 = round(psnr_phase_result[0], 2)
    psnr_phase_2 = round(psnr_phase_result[1], 2)
    print(f"{filename} : {psnr_amp_1}/{psnr_amp_2}  {psnr_phase_1}/{psnr_phase_2}")
