In [None]:
import numpy as np
import matplotlib.pyplot as plt
import os
from PIL import Image,ImageChops
from torchvision.transforms import Resize, Compose, ToTensor, Normalize, Grayscale, Pad
import matplotlib.pyplot as plt
from skimage import data
import skimage
from skimage.metrics import structural_similarity as ssim
from scipy.signal import correlate2d

## Load ground truth

In [None]:
# Ground-truth object: take the bottom-right 241x241 corner, then trim a 30-pixel border.
obj_act = np.load("data/crystal.npy")[-241:, -241:][30:-30, 30:-30]
# Ground-truth probe with a 24-pixel border trimmed from every side.
probe_act = np.load("data/probe_for_sim.npy")[24:-24, 24:-24]

In [None]:
np.max(np.abs(probe_act))  # peak probe amplitude (the same value normalises probe_amp_act)

In [None]:
# Amplitude and phase of the ground-truth object; the phase is re-centred on its median.
amp_act, phase_act = np.abs(obj_act), np.angle(obj_act)
phase_act = phase_act - np.median(phase_act)

In [None]:
# Ground-truth probe amplitude normalised to a peak of 1, and its phase in radians.
probe_amp_act = np.abs(probe_act)
probe_amp_act = probe_amp_act / probe_amp_act.max()
probe_phase_act = np.angle(probe_act)
# Disabled alternative: shift the phase so its minimum is zero.
#probe_phase_act = probe_phase_act - probe_phase_act.min()

In [None]:
current_path = os.path.abspath(os.curdir)  # launch directory, recorded before the chdir into "result"

In [None]:
# Display the working directory captured in the previous cell (before the chdir into "result").
current_path

In [None]:
def psnr_amp(original, compressed):
    """Return ``[PSNR, SSIM]`` between two amplitude images.

    Both metrics use a peak value / data range of 1.0, so the inputs are
    expected to be normalised to [0, 1].

    Parameters
    ----------
    original, compressed : array_like
        Real-valued images of the same shape.

    Returns
    -------
    list
        ``[psnr_db, ssim_value]``. When the images are identical (MSE == 0)
        the PSNR is reported as the sentinel 100 instead of infinity. A
        two-element list is returned in every case so callers can always
        index ``[0]`` and ``[1]``.
    """
    # Local import: in the notebook, math is only imported in a later cell.
    from math import log10, sqrt

    # Promote to float64 so integer-typed inputs cannot wrap around on subtraction.
    original = np.asarray(original, dtype=np.float64)
    compressed = np.asarray(compressed, dtype=np.float64)

    mse = np.mean((original - compressed) ** 2)
    similarity = ssim(original, compressed, data_range=1)
    if mse == 0:
        # No difference at all: PSNR is infinite, so report the conventional cap.
        return [100, similarity]
    max_pixel = 1.0
    return [20 * log10(max_pixel / sqrt(mse)), similarity]

In [None]:
def psnr_phase(original, compressed):
    """Return ``[PSNR, SSIM]`` between two phase maps given in radians.

    PSNR uses a peak value of pi, while SSIM uses a data range of 2*pi (the
    full span of a phase wrapped to [-pi, pi]). Phase differences are
    squared as-is; they are not wrapped modulo 2*pi first.

    Parameters
    ----------
    original, compressed : array_like
        Real-valued phase maps of the same shape, presumably wrapped to
        [-pi, pi] (e.g. ``np.angle`` output).

    Returns
    -------
    list
        ``[psnr_db, ssim_value]``. When the maps are identical (MSE == 0)
        the PSNR is reported as the sentinel 100 instead of infinity. A
        two-element list is returned in every case so callers can always
        index ``[0]`` and ``[1]``.
    """
    # Local import: in the notebook, math is only imported in a later cell.
    from math import log10, sqrt

    # Promote to float64 so integer-typed inputs cannot wrap around on subtraction.
    original = np.asarray(original, dtype=np.float64)
    compressed = np.asarray(compressed, dtype=np.float64)

    mse = np.mean((original - compressed) ** 2)
    similarity = ssim(original, compressed, data_range=2 * np.pi)
    if mse == 0:
        # No difference at all: PSNR is infinite, so report the conventional cap.
        return [100, similarity]
    max_pixel = np.pi
    return [20 * log10(max_pixel / sqrt(mse)), similarity]

In [None]:
from math import log10, sqrt 

In [None]:
# All result files below are read relative to result/. The chdir is guarded so
# re-running this cell on a kernel already inside result/ does not fail by
# looking for result/result.
if os.path.basename(os.getcwd()) != "result":
    os.chdir("result")

In [None]:
# Ground-truth probe amplitude, square-root scaled to compress its dynamic range for display.
# "gray" is used because the "grey" spelling is only accepted by recent matplotlib releases.
plt.imshow(np.sqrt(np.abs(probe_act)), cmap="gray")
plt.title("Ground-truth probe amplitude (sqrt)");

In [None]:
# Ground-truth probe phase (np.angle returns radians).
plt.imshow(np.angle(probe_act), cmap="magma")
plt.colorbar(label="phase (rad)")
plt.title("Ground-truth probe phase");

## Find the global phase shift of the reconstructed probe by a least-squares grid search

In [None]:
def find_global_phase_shift_loop(A, B, num_steps=100000):
    """Grid-search the constant phase offset that best aligns ``A`` with ``B``.

    For each candidate ``phi`` on an evenly spaced grid over [-pi, pi]
    (both endpoints included), the complex array ``A`` is multiplied by
    ``exp(1j * phi)``. Its wrapped phase is then compared with ``B`` by mean
    squared difference (the difference itself is not wrapped).

    Parameters
    ----------
    A : array_like of complex
        Complex field whose global phase is arbitrary.
    B : array_like of float
        Reference phase in radians, broadcast-compatible with ``A``.
    num_steps : int, optional
        Number of grid points. The search makes ``num_steps`` full passes
        over ``A``, so large values are slow on large arrays.

    Returns
    -------
    float
        The first grid value attaining the smallest error. Returns 0 if the
        grid is empty or no error is smaller than infinity (e.g. all NaN).
    """
    best_phi = 0
    min_error = float('inf')

    for phi in np.linspace(-np.pi, np.pi, num_steps):
        # Compute the phase factor once per candidate and actually use it.
        phase_factor = np.exp(1j * phi)
        # Least-squares error between the shifted, wrapped phase and B.
        error = np.mean((np.angle(A * phase_factor) - B) ** 2)
        if error < min_error:  # strict '<' keeps the first minimum on ties
            min_error = error
            best_phi = phi

    return best_phi

## Calculate PSNR and SSIM of the reconstructed probe amplitude and phase (CPU)

In [None]:
# Loop through the files and calculate metrics
# Evaluate every reconstructed-probe .npy file in the working directory against
# the ground-truth probe. sorted() makes the report order reproducible, because
# os.listdir returns entries in arbitrary order.
for j in sorted(os.listdir(".")):
    if not (j.endswith(".npy") and all(tag in j for tag in ("crystal", "probe.", "_unknown", "PtyINR"))):
        continue

    obj = np.load(j)

    # Amplitude: normalise to a peak of 1, then rescale so its median matches
    # the median of the ground-truth probe amplitude.
    obj_abs = np.abs(obj)
    amp = obj_abs / obj_abs.max()
    amp = amp * np.median(probe_amp_act) / np.median(amp)

    # Phase: remove the arbitrary global phase offset before comparing.
    # NOTE: the default grid search is 100000 passes over the array, so this is the slow step.
    phase_shift = find_global_phase_shift_loop(obj, probe_phase_act)
    phase = np.angle(obj * np.exp(phase_shift * 1j))

    # Optional image export (writes into an images/ directory):
    # plt.imsave("images/"+(j[:-4]+"_probe_amp.tiff"),np.sqrt(amp),vmin=np.sqrt(probe_amp_act).min(), vmax=np.sqrt(probe_amp_act).max(),cmap="grey")
    # plt.imsave(("images/"+j[:-4]+"_probe_phase.tiff"),phase,vmin=probe_phase_act.min(), vmax=probe_phase_act.max(),cmap="magma")
    # plt.imsave("images/"+(j[:-4]+"_amp_croped.tiff"),np.sqrt(amp[18:-18,18:-18]),vmin=np.sqrt(probe_amp_act).min()
    #            , vmax=np.sqrt(probe_amp_act).max(),cmap="grey")
    # plt.imsave(("images/"+j[:-4]+"_phase_croped.tiff"),phase[18:-18,18:-18],vmin=probe_phase_act.min()
    #            , vmax=probe_phase_act.max(),cmap="magma")

    # Calculate metrics; each helper returns [PSNR, SSIM].
    psnr_amp_result = psnr_amp(amp, probe_amp_act)
    psnr_phase_result = psnr_phase(phase, probe_phase_act)

    # *_1 holds the PSNR (dB) and *_2 the SSIM, both rounded to 2 decimals.
    psnr_amp_1 = round(psnr_amp_result[0], 2)
    psnr_amp_2 = round(psnr_amp_result[1], 2)
    psnr_phase_1 = round(psnr_phase_result[0], 2)
    psnr_phase_2 = round(psnr_phase_result[1], 2)

    # Report: file name, amplitude PSNR/SSIM, phase PSNR/SSIM.
    print(f"{j} : {psnr_amp_1}/{psnr_amp_2}  {psnr_phase_1}/{psnr_phase_2}")