## Theoretical Analysis of Ptychographic Wavefront Camera

- Author: Ni Chen
- Date: 2024

In [1]:
import sys
sys.path.append(f'../')
from util import *

from FP import *

from fractions import Fraction
from IPython.display import Image

%load_ext autoreload
%autoreload 2

# Run every tensor computation in double precision.
torch.set_default_dtype(torch.float64)

# Prefer the GPU when PyTorch reports one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Band-limit switch forwarded to group_test_exclude by the sweeps in this
# notebook; it also selects the folder that receives the saved .npy results.
is_band_limit = False
out_dir = 'output_band' if is_band_limit else 'output'

# np.save does not create missing folders, so make sure the target exists
# before any sweep tries to write into it.
import os
os.makedirs(out_dir, exist_ok=True)

# Matplotlib defaults shared by all figures: size, resolution and font size.
plt.rcParams['figure.figsize'] = [6, 3]
plt.rcParams['figure.dpi'] = 120
plt.rcParams['savefig.dpi'] = 120
plt.rcParams.update({'font.size': 14})


# Value forwarded as the N_hr argument of every sweep; also the side length
# of the random ground-truth tensors built further down.
N_hr = 256

# Forwarded unchanged as the `method` argument of group_test_exclude.
method = 'gs_gd'

# Ratios 1/8, 2/8, ..., 7/8. The first list is multiplied by N_hr to obtain
# the size passed as N_lr; the second is turned into an overlapping ratio via
# 1 - ratio. They are kept as two independent list objects.
ratio_aperture_obj = [k / 8 for k in range(1, 8)]
ratio_spacing_aperture = [k / 8 for k in range(1, 8)]

# Powers of ten from 1e2 to 1e6 (not referenced by the sweeps shown here).
photons = [float(10 ** exponent) for exponent in range(2, 7)]

## MSE vs. overlapping ratio

In [3]:
%reload_ext autoreload
from FP import *

# Sizes forwarded one at a time as the N_lr argument.
N_lrs = [64, 96, 128, 160, 192, 224]

for N_lr in N_lrs:
    # One metrics record per spacing ratio, kept in sweep order.
    records = []

    for ol in ratio_spacing_aperture:
        # `ol` is a spacing/aperture ratio; the overlap that is actually
        # simulated is its complement, so the banner reports that value
        # instead of `ol` itself.
        overlap = 1 - ol
        print(f'=========== Overlapping ratio is {overlap} ===========')

        metrics = group_test_exclude(try_num=2, N_hr=N_hr, N_lr=N_lr, padding=1,
                                     aperture_num=None,
                                     is_band_limit=is_band_limit,
                                     overlapping_ratio=overlap, method=method,
                                     device=device)
        records.append(metrics)

    # Entries 0-4 of each record become the five saved series. Judging by the
    # names they are complex/phase/amplitude MSE and phase/amplitude PSNR --
    # TODO confirm against FP.group_test_exclude.
    MSE_cpx_list = [m[0] for m in records]
    MSE_ang_list = [m[1] for m in records]
    MSE_amp_list = [m[2] for m in records]
    PSNR_ang_list = [m[3] for m in records]
    PSNR_amp_list = [m[4] for m in records]

    # One file per N_lr, holding the five series in the order listed above.
    np.save(f'./{out_dir}/mse_vs_overlapping_Nlr{N_lr}.npy',
            [MSE_cpx_list, MSE_ang_list, MSE_amp_list, PSNR_ang_list, PSNR_amp_list])


## MSE vs. aperture number

In [None]:
%reload_ext autoreload
from FP import *


# Sizes forwarded as N_lr, and for each of them (same position) the values
# of aperture_num to try.
N_lrs = [64, 96, 128, 160, 192, 224]
N_pupils = [np.arange(2, 30, 1), np.arange(2, 30, 1), np.arange(2, 10, 1),
            np.arange(2, 9, 1), np.arange(2, 8, 1), np.arange(2, 7, 1)]

# Iterate with a dedicated name so the schedule list `N_pupils` is not
# overwritten by its own elements while looping.
for N_lr, pupil_counts in zip(N_lrs, N_pupils):
    # One metrics record per aperture_num value, kept in sweep order.
    records = []

    for N_p in pupil_counts:
        # The banner shows N_p squared; N_p itself is what gets passed on.
        print(f'=========== Aperture number is {N_p**2} ===========')

        metrics = group_test_exclude(try_num=5, N_hr=N_hr, N_lr=N_lr, padding=1,
                                     aperture_num=N_p, is_band_limit=is_band_limit,
                                     overlapping_ratio=0.75, method=method,
                                     device=device)
        records.append(metrics)

    # Entries 0-4 of each record become the five saved series. Judging by the
    # names they are complex/phase/amplitude MSE and phase/amplitude PSNR --
    # TODO confirm against FP.group_test_exclude.
    MSE_cpx_list = [m[0] for m in records]
    MSE_ang_list = [m[1] for m in records]
    MSE_amp_list = [m[2] for m in records]
    PSNR_ang_list = [m[3] for m in records]
    PSNR_amp_list = [m[4] for m in records]

    # One file per N_lr, holding the five series in the order listed above.
    np.save(f'./{out_dir}/mse_vs_aperture_number_Nlr{N_lr}.npy',
            [MSE_cpx_list, MSE_ang_list, MSE_amp_list, PSNR_ang_list, PSNR_amp_list])

## MSE vs. aperture size

In [None]:
%reload_ext autoreload
from FP import *

# Spacing values; each is converted to an overlapping ratio as 1 - s.
spacings = [0.15, 0.25, 0.35, 0.45, 0.55, 0.65, 0.75]

for s in spacings:
    # One metrics record per aperture ratio, kept in sweep order.
    records = []

    # The loop value is a ratio, not a pixel count, so it gets its own name
    # instead of reusing `N_lr`.
    for ratio in ratio_aperture_obj:
        print(f'=========== Aperture size is {ratio} ===========')

        # ratio * N_hr is a float (e.g. 32.0). The other sweeps pass integer
        # sizes as N_lr, so an integer is passed here as well.
        # NOTE(review): assumes group_test_exclude expects an integer N_lr --
        # confirm against FP.
        N_lr_single = int(round(ratio * N_hr))

        metrics = group_test_exclude(try_num=2, N_hr=N_hr, N_lr=N_lr_single,
                                     padding=1, aperture_num=None,
                                     is_band_limit=is_band_limit,
                                     overlapping_ratio=1-s, method=method,
                                     device=device)
        records.append(metrics)

    # Entries 0-4 of each record become the five saved series. Judging by the
    # names they are complex/phase/amplitude MSE and phase/amplitude PSNR --
    # TODO confirm against FP.group_test_exclude.
    MSE_cpx_list = [m[0] for m in records]
    MSE_ang_list = [m[1] for m in records]
    MSE_amp_list = [m[2] for m in records]
    PSNR_ang_list = [m[3] for m in records]
    PSNR_amp_list = [m[4] for m in records]

    # One file per spacing value, holding the five series in the order above.
    np.save(f'./{out_dir}/mse_vs_aperture_size_s{s}.npy',
            [MSE_cpx_list, MSE_ang_list, MSE_amp_list, PSNR_ang_list, PSNR_amp_list])

## MSE vs. aperture number and object complexity

In [None]:
%reload_ext autoreload
from FP import *


# Fixed N_lr for this sweep and the aperture_num values tried with it.
N_lr = 192
N_pupils = np.arange(2, 8, 1)

# N_phi scales the random phase: torch.rand is uniform on [0, 1), so the
# phase below lies in [-N_phi*pi, N_phi*pi).
# NOTE(review): is_band_limit is hard-coded to True here while out_dir follows
# the notebook-level flag -- confirm the results land in the intended folder.
for N_phi in [0.5, 0.6, 0.7, 0.8, 0.9, 1]:
    # One metrics record per aperture_num value, kept in sweep order.
    records = []

    for N_p in N_pupils:
        # Fresh unit-amplitude, random-phase ground truth for every run.
        noise = torch.rand([N_hr, N_hr])
        x_gt = torch.exp(1j*(N_phi*torch.pi*(2*(noise-0.5))))
        print(f'=========== Aperture number is {N_p**2} ===========')

        metrics = group_test_exclude(
            try_num=2, N_hr=N_hr, N_lr=N_lr, padding=1,
            aperture_num=N_p, is_band_limit=True, gt=x_gt,
            max_it_gs=1000, max_it_gd=1000,
            overlapping_ratio=0.75, method=method, device=device,
        )
        records.append(metrics)

    # Entries 0-4 of each record become the five saved series. Judging by the
    # names they are complex/phase/amplitude MSE and phase/amplitude PSNR --
    # TODO confirm against FP.group_test_exclude.
    MSE_cpx_list = [m[0] for m in records]
    MSE_ang_list = [m[1] for m in records]
    MSE_amp_list = [m[2] for m in records]
    PSNR_ang_list = [m[3] for m in records]
    PSNR_amp_list = [m[4] for m in records]

    # One file per N_phi, holding the five series in the order listed above.
    series = [MSE_cpx_list, MSE_ang_list, MSE_amp_list, PSNR_ang_list, PSNR_amp_list]
    np.save(f'./{out_dir}/mse_vs_aperture_number_phi{N_phi}.npy', series)

## MSE vs. aperture size and object complexity

In [None]:
%reload_ext autoreload
from FP import *

# N_phi scales the random phase: torch.rand is uniform on [0, 1), so the
# phase below lies in [-N_phi*pi, N_phi*pi).
# NOTE(review): is_band_limit is hard-coded to True here while out_dir follows
# the notebook-level flag -- confirm the results land in the intended folder.
for N_phi in [0.5, 0.6, 0.7, 0.8, 0.9, 1]:
    # One metrics record per aperture ratio, kept in sweep order.
    records = []

    # The loop value is a ratio, not a pixel count, so it gets its own name
    # instead of reusing `N_lr`.
    for ratio in ratio_aperture_obj:
        print(f'=========== Aperture size is {ratio} ===========')

        # ratio * N_hr is a float (e.g. 32.0). The sweeps that take N_lr from
        # a list pass integer sizes, so an integer is passed here as well.
        # NOTE(review): assumes group_test_exclude expects an integer N_lr --
        # confirm against FP.
        N_lr_single = int(round(ratio * N_hr))

        # Fresh unit-amplitude, random-phase ground truth for every run.
        x_gt = torch.exp(1j*(N_phi*torch.pi*(2*(torch.rand([N_hr, N_hr])-0.5))))

        metrics = group_test_exclude(try_num=2, N_hr=N_hr, N_lr=N_lr_single,
                                     padding=1, aperture_num=None,
                                     is_band_limit=True, gt=x_gt,
                                     overlapping_ratio=0.65, method=method,
                                     device=device)
        records.append(metrics)

    # Entries 0-4 of each record become the five saved series. Judging by the
    # names they are complex/phase/amplitude MSE and phase/amplitude PSNR --
    # TODO confirm against FP.group_test_exclude.
    MSE_cpx_list = [m[0] for m in records]
    MSE_ang_list = [m[1] for m in records]
    MSE_amp_list = [m[2] for m in records]
    PSNR_ang_list = [m[3] for m in records]
    PSNR_amp_list = [m[4] for m in records]

    # One file per N_phi, holding the five series in the order listed above.
    np.save(f'./{out_dir}/mse_vs_aperture_size_phi{N_phi}.npy',
            [MSE_cpx_list, MSE_ang_list, MSE_amp_list, PSNR_ang_list, PSNR_amp_list])

## MSE vs. overlapping ratio with 3x3 layout

In [None]:
%reload_ext autoreload
from FP import *


# Spacing/aperture ratios and the sizes forwarded as N_lr for this sweep.
ratio_spacing_aperture = [1/8, 2/8, 3/8, 4/8, 5/8, 6/8, 7/8]
N_lrs = [64, 96, 128, 160, 192, 224]

# NOTE(review): is_band_limit is not forwarded in this sweep, so
# group_test_exclude falls back to its own default while out_dir still follows
# the notebook-level flag -- confirm this is intended.
for N_lr in N_lrs:
    # One metrics record per spacing ratio, kept in sweep order.
    records = []

    for ol in ratio_spacing_aperture:
        # `ol` is a spacing/aperture ratio; the overlap that is actually
        # simulated is its complement, so the banner reports that value
        # instead of `ol` itself.
        overlap = 1 - ol
        print(f'=========== Overlapping ratio is {overlap} ===========')

        # aperture_num is fixed to 3 for every run of this sweep.
        metrics = group_test_exclude(try_num=2, N_hr=N_hr, N_lr=N_lr, padding=1,
                                     aperture_num=3,
                                     max_it_gs=1000, max_it_gd=1000,
                                     overlapping_ratio=overlap, method=method,
                                     device=device)
        records.append(metrics)

    # Entries 0-4 of each record become the five saved series. Judging by the
    # names they are complex/phase/amplitude MSE and phase/amplitude PSNR --
    # TODO confirm against FP.group_test_exclude.
    MSE_cpx_list = [m[0] for m in records]
    MSE_ang_list = [m[1] for m in records]
    MSE_amp_list = [m[2] for m in records]
    PSNR_ang_list = [m[3] for m in records]
    PSNR_amp_list = [m[4] for m in records]

    # One file per N_lr, holding the five series in the order listed above.
    np.save(f'./{out_dir}/mse_vs_overlapping_Nlr{N_lr}_A3x3.npy',
            [MSE_cpx_list, MSE_ang_list, MSE_amp_list, PSNR_ang_list, PSNR_amp_list])
