In [None]:
import sys
import os
from tqdm import tqdm

parent_dir = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
sys.path.append(parent_dir)

import evaluation.evaluation_util as util
from Models.PtychoFormer import PtychoFormer
from Models.ePIE import ePIE
from Models.extendedPtychoFormer import ExtendedPtychoFormer
from Models.PtychoNet import PtychoNet
from Models.PtychoNN import PtychoNN
# Class aliases: the bare names above are rebound to model *instances* later in the notebook,
# so cells that (re)build a model use these aliases as constructors.
from Models.PtychoFormer import PtychoFormer as PtychoFormerModel
from Models.PtychoNet import PtychoNet as PtychoNetModel
from Models.PtychoNN import PtychoNN as PtychoNNModel

import torch
from torch.nn.functional import l1_loss
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import matplotlib.colorbar as mcolorbar
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 

In [None]:
# Load the pretrained networks and switch them to eval mode.
# - Checkpoints are mapped to CPU first so GPU-saved weights also load on CPU-only machines;
#   modules are moved to `device` afterwards.
# - The *Model aliases are used as constructors because the bare class names are rebound to
#   instances below; calling an instance would run forward() and break a second run of this cell.
PtychoFormer_weights = os.path.join(os.pardir, 'Models', 'Weights', 'PtychoFormer_weights.pth')
PtychoNN_weights = os.path.join(os.pardir, 'Models', 'Weights', 'PtychoNN_weights.pth')
PtychoNet_weights = os.path.join(os.pardir, 'Models', 'Weights', 'PtychoNet_weights.pth')

# One PtychoFormer checkpoint initialises both the plain and the extended model.
PtychoFormer_state = torch.load(PtychoFormer_weights, map_location='cpu')['model']

PtychoFormer = PtychoFormerModel()
PtychoFormer.load_state_dict(PtychoFormer_state)
PtychoFormer = PtychoFormer.to(device)
PtychoFormer.eval()

ePtychoFormer = ExtendedPtychoFormer()
ePtychoFormer.load_ptychoformer_weights(PtychoFormer_state)
ePtychoFormer.to(device)
# Put the extended model in inference mode like the other networks.
ePtychoFormer.eval()

PtychoNN = PtychoNNModel()
PtychoNN.load_state_dict(torch.load(PtychoNN_weights, map_location='cpu')['model'])
PtychoNN = PtychoNN.to(device)
PtychoNN.eval()

PtychoNet = PtychoNetModel()
PtychoNet.load_state_dict(torch.load(PtychoNet_weights, map_location='cpu')['model'])
PtychoNet = PtychoNet.to(device)
PtychoNet.eval()
print()

In [None]:
# Memory-map the preprocessed test transmission data (read-only), so images are read lazily.
raw_img_size = 500
num_flickr_img = 31782
n_flickr_tenth = int(num_flickr_img * 0.1)
test_image_len = (5 * n_flickr_tenth) // 2

test_image_path = os.path.join(
    parent_dir, 'data_processing', 'preprocess', 'test_transmission_data_memmap.npy'
)
# Layout given explicitly to np.memmap: (image, 2, raw_img_size, raw_img_size), float32.
memmap_shape = (test_image_len, 2, raw_img_size, raw_img_size)
raw_test_data = np.memmap(test_image_path, dtype='float32', mode='r', shape=memmap_shape)
raw_test_data.shape

Accuracy vs. step size (20–60): NRMSE and MAE for PtychoFormer, ePtychoFormer (ePF), ePIE, PtychoNN, and PtychoNet

In [None]:
# Configuration for the step-size benchmark and random selection of the test images.
num_test_data = 2
grid_dim = 3
offset = 36
diff_size = 128
methods = ['PtychoFormer', 'ePtychoFormer', 'ePIE', 'PtychoNN', 'PtychoNet']
step_arr = [20, 30, 40, 50, 60]
metrics = ['NRMSE', 'MAE']
num_metrics = len(metrics)
num_method = len(methods)
num_step = len(step_arr)
ePIE_iter = 100
SEED = 0

# Seed the global NumPy and torch RNGs so the sampled test images (and the later
# np.random.shuffle calls in this notebook) are reproducible across runs.
np.random.seed(SEED)
torch.manual_seed(SEED)

# results[metric, step, method, channel, image]; channel 0 = phase, 1 = amplitude
results = np.zeros((num_metrics, num_step, num_method, 2, num_test_data))
test_img_indices = np.arange(test_image_len)
np.random.shuffle(test_img_indices)
test_transmissions = raw_test_data[test_img_indices[:num_test_data]]
test_transmissions.shape

In [None]:
# Benchmark all five methods at every step size on the sampled test images.
# Fills results[metric, step, method, channel, image]; metric 0 = NRMSE, 1 = MAE;
# channel 0 = phase, 1 = amplitude.
crop = (slice(None), slice(offset, -offset), slice(offset, -offset))  # trim the last two axes
cnn_crop = (slice(None),) + crop                                          # same trim for 4-D CNN outputs

with tqdm(total=num_test_data*num_step) as progressbar, torch.no_grad():
    for data_idx in range(num_test_data):
        for step_idx, step in enumerate(step_arr):
            # The 60 step uses a different grid count / diffraction step than the others.
            n_grid, diff_step = (4, 1) if step == 60 else (2, 3)

            grid_data, single_data, patch_label, entire_label = util.generate_grids(
                test_transmissions[data_idx], step, n_grid, diff_step)
            entire_label = entire_label[crop]
            CNN_data = single_data.unsqueeze(1).to(device)
            single_data = single_data.to(device)
            grid_data = grid_data.to(device)
            n_row = int(np.sqrt(single_data.shape[0]))
            patch_size = diff_size + (grid_dim - 1) * step
            patch_shape = (n_grid, n_grid, 2, patch_size, patch_size)

            # Model calls stay in the original order (PF, ePF, ePIE, PtychoNN, PtychoNet).
            PtychoFormer_patch_pred = PtychoFormer(grid_data)
            PtychoFormer_pred = util.feathered_stitching(
                PtychoFormer_patch_pred.reshape(patch_shape), n_grid, step, diff_step
            )[crop].detach().cpu()

            ePtychoFormer_pred, _, _ = ePtychoFormer(
                grid_data, single_data, step, diff_step, n_row, iteration=ePIE_iter)
            ePtychoFormer_pred = ePtychoFormer_pred[crop].detach().cpu()

            ePIE_pred, _, _ = ePIE(single_data, step, n_row, iteration=ePIE_iter)
            ePIE_pred = ePIE_pred[crop].detach().cpu()

            PtychoNN_patch_pred = PtychoNN(CNN_data)[cnn_crop]
            nn_side = PtychoNN_patch_pred.shape[-1]
            PtychoNN_pred = util.CNN_stitching(
                PtychoNN_patch_pred.reshape(n_row, n_row, 2, nn_side, nn_side), n_row, step
            ).detach().cpu()

            PtychoNet_patch_pred = PtychoNet(CNN_data)[cnn_crop]
            net_side = PtychoNet_patch_pred.shape[-1]
            PtychoNet_pred = util.CNN_stitching(
                PtychoNet_patch_pred.reshape(n_row, n_row, 2, net_side, net_side), n_row, step
            ).detach().cpu()

            predictions = [PtychoFormer_pred, ePtychoFormer_pred, ePIE_pred, PtychoNN_pred, PtychoNet_pred]
            for method_idx, pred in enumerate(predictions):
                # NRMSE: phase uses compute_nrmse's default diff_bool, amplitude passes diff_bool=False
                results[0, step_idx, method_idx, 0, data_idx] = util.compute_nrmse(entire_label[0], pred[0])
                results[0, step_idx, method_idx, 1, data_idx] = util.compute_nrmse(entire_label[1], pred[1], diff_bool=False)
                results[1, step_idx, method_idx, 0, data_idx] = l1_loss(entire_label[0], pred[0]).item()
                results[1, step_idx, method_idx, 1, data_idx] = l1_loss(entire_label[1], pred[1]).item()
            progressbar.update(1)

In [None]:
# Report mean ± std over the test images for every step size, metric and method.
for step_idx, step in enumerate(step_arr):
    print(f"Step Size: {step}")
    for metric_idx, metric_name in enumerate(metrics):
        for method_idx, method_name in enumerate(methods):
            # Rows of `per_image`: 0 = phase, 1 = amplitude (see the results allocation).
            per_image = results[metric_idx, step_idx, method_idx]
            phase_mean, amp_mean = (round(np.mean(row), 3) for row in per_image)
            phase_std, amp_std = (round(np.std(row), 2) for row in per_image)
            print(f"{method_name:20} {metric_name:10} Ph: {phase_mean:6} ± {phase_std:<5} Amp: {amp_mean:6} ± {amp_std:<5}")
    print()

### Visually Compare Predictions

In [None]:
# Reconstruct one held-out image with every method; the figures below reuse these predictions.
step = 20
grid_dim = 3
n_grid = 6
diff_step = 3
offset = 36
diff_size = 128
ePIE_iter = 100
crop = (slice(None), slice(offset, -offset), slice(offset, -offset))  # trim the last two axes
cnn_crop = (slice(None),) + crop                                          # same trim for 4-D CNN outputs

with torch.no_grad():
    grid_data, single_data, patch_label, entire_label = util.generate_grids(
        raw_test_data[test_img_indices[0]], step, n_grid, diff_step)

    entire_label = entire_label[crop]
    CNN_data = single_data.unsqueeze(1).to(device)
    single_data = single_data.to(device)
    grid_data = grid_data.to(device)
    n_row = int(np.sqrt(single_data.shape[0]))
    patch_size = diff_size + (grid_dim - 1) * step
    patch_shape = (n_grid, n_grid, 2, patch_size, patch_size)

    PtychoFormer_patch_pred = PtychoFormer(grid_data)

    # The same patch predictions, stitched with and without feathering for comparison.
    PtychoFormer_unfeathered_pred = util.unfeathered_stitching(
        PtychoFormer_patch_pred.reshape(patch_shape), n_grid, step, diff_step)[crop].detach().cpu()
    PtychoFormer_pred = util.feathered_stitching(
        PtychoFormer_patch_pred.reshape(patch_shape), n_grid, step, diff_step)[crop].detach().cpu()

    ePtychoFormer_pred, _, _ = ePtychoFormer(
        grid_data, single_data, step, diff_step, n_row, iteration=ePIE_iter)
    ePtychoFormer_pred = ePtychoFormer_pred[crop].detach().cpu()

    ePIE_pred, _, _ = ePIE(single_data, step, n_row, iteration=ePIE_iter)
    ePIE_pred = ePIE_pred[crop].detach().cpu()

    PtychoNN_patch_pred = PtychoNN(CNN_data)[cnn_crop]
    nn_side = PtychoNN_patch_pred.shape[-1]
    PtychoNN_pred = util.CNN_stitching(
        PtychoNN_patch_pred.reshape(n_row, n_row, 2, nn_side, nn_side), n_row, step).detach().cpu()

    PtychoNet_patch_pred = PtychoNet(CNN_data)[cnn_crop]
    net_side = PtychoNet_patch_pred.shape[-1]
    PtychoNet_pred = util.CNN_stitching(
        PtychoNet_patch_pred.reshape(n_row, n_row, 2, net_side, net_side), n_row, step).detach().cpu()

Feathering vs. No Feathering

In [None]:
# Ground truth vs. feathered vs. unfeathered PtychoFormer stitching.
# Row 0 shows channel 1 (amplitude, gray, [0, 1]); row 1 shows channel 0 (phase, viridis, [-pi, pi]).
f, ax = plt.subplots(2, 3, figsize=(6, 4))
predictions = [entire_label, PtychoFormer_pred, PtychoFormer_unfeathered_pred]
column_titles = ['Ground Truth', 'PtychoFormer\nFeathered', 'PtychoFormer\nUnfeathered']
row_specs = (
    (1, 0, 1, 'gray'),             # (channel, vmin, vmax, cmap)
    (0, -np.pi, np.pi, 'viridis'),
)

for method_idx, pred in enumerate(predictions):
    for i, (pred_idx, lo, hi, cmap) in enumerate(row_specs):
        panel = ax[i, method_idx]
        panel.imshow(pred[pred_idx], vmin=lo, vmax=hi, cmap=cmap)
        panel.set_xticklabels([])
        panel.set_yticklabels([])
        panel.set_xticks([])
        panel.set_yticks([])
        for spine in panel.spines.values():
            spine.set_visible(False)

# Stand-alone colorbars on the right: amplitude on top, phase at the bottom.
colorbar_specs = (
    ([0.907, 0.515, 0.017, 0.35], 'gray', mcolors.Normalize(vmin=0, vmax=1),
     [0, 0.25, 0.5, 0.75, 1.0], ['0', '0.25', '0.5', '0.75', '1']),
    ([0.907, 0.13, 0.017, 0.35], 'viridis', mcolors.Normalize(vmin=-np.pi, vmax=np.pi),
     [-np.pi, -np.pi/2, 0, np.pi/2, np.pi], [r'$-\pi$', r'$-\pi/2$', '0', r'$\pi/2$', r'$\pi$']),
)
for rect, cmap, norm, ticks, tick_labels in colorbar_specs:
    cbar = mcolorbar.ColorbarBase(f.add_axes(rect), cmap=cmap, norm=norm, orientation='vertical')
    cbar.set_ticks(ticks)
    cbar.set_ticklabels(tick_labels)
    cbar.ax.tick_params(labelsize=10)
    for spine in cbar.ax.spines.values():
        spine.set_linewidth(0.3)

ax[0, 0].set_ylabel('Amplitude', fontsize=16)
ax[1, 0].set_ylabel('Phase', fontsize=16)
for col, title in enumerate(column_titles):
    ax[1, col].set_xlabel(title, fontsize=13)
plt.subplots_adjust(wspace=0.03, hspace=0.0)
plt.show()

Deep Learning Method Comparison

In [None]:
# PtychoFormer vs. the CNN baselines on the visual-comparison image.
# Row 0 shows channel 1 (amplitude, gray, [0, 1]); row 1 shows channel 0 (phase, viridis, [-pi, pi]).
# Every non-ground-truth panel is annotated with its NRMSE against the ground truth.
f, ax = plt.subplots(2, 4, figsize=(8, 4))
predictions = [entire_label, PtychoFormer_pred, PtychoNN_pred, PtychoNet_pred]
column_titles = ['Ground Truth', 'PtychoFormer', 'PtychoNN', 'PtychoNet']
row_specs = (
    (1, 0, 1, 'gray', False),             # (channel, vmin, vmax, cmap, diff_bool for NRMSE)
    (0, -np.pi, np.pi, 'viridis', True),
)

for method_idx, pred in enumerate(predictions):
    for i, (pred_idx, lo, hi, cmap, diff_bool) in enumerate(row_specs):
        panel = ax[i, method_idx]
        panel.imshow(pred[pred_idx], vmin=lo, vmax=hi, cmap=cmap)
        panel.set_xticklabels([])
        panel.set_yticklabels([])
        panel.set_xticks([])
        panel.set_yticks([])
        for spine in panel.spines.values():
            spine.set_visible(False)
        if method_idx == 0:
            continue  # ground-truth column gets no score
        nrmse = util.compute_nrmse(entire_label[pred_idx], pred[pred_idx], diff_bool=diff_bool)
        panel.text(0.96, 0.04, f'{nrmse:.4f}', color='white', fontsize=17,
                   ha='right', va='bottom', transform=panel.transAxes,
                   bbox=dict(facecolor='black', alpha=0.5, edgecolor='none'))

# Stand-alone colorbars on the right: amplitude on top, phase at the bottom.
colorbar_specs = (
    ([0.907, 0.515, 0.017, 0.35], 'gray', mcolors.Normalize(vmin=0, vmax=1),
     [0, 0.25, 0.5, 0.75, 1.0], ['0', '0.25', '0.5', '0.75', '1']),
    ([0.907, 0.13, 0.017, 0.35], 'viridis', mcolors.Normalize(vmin=-np.pi, vmax=np.pi),
     [-np.pi, -np.pi/2, 0, np.pi/2, np.pi], [r'$-\pi$', r'$-\pi/2$', '0', r'$\pi/2$', r'$\pi$']),
)
for rect, cmap, norm, ticks, tick_labels in colorbar_specs:
    cbar = mcolorbar.ColorbarBase(f.add_axes(rect), cmap=cmap, norm=norm, orientation='vertical')
    cbar.set_ticks(ticks)
    cbar.set_ticklabels(tick_labels)
    cbar.ax.tick_params(labelsize=10)
    for spine in cbar.ax.spines.values():
        spine.set_linewidth(0.3)

ax[0, 0].set_ylabel('Amplitude', fontsize=16)
ax[1, 0].set_ylabel('Phase', fontsize=16)
for col, title in enumerate(column_titles):
    ax[1, col].set_xlabel(title, fontsize=13)
plt.subplots_adjust(wspace=0.03, hspace=0.0)
plt.show()

Comparison of ePIE with our methods (PtychoFormer and ePtychoFormer)

In [None]:
# PtychoFormer / ePtychoFormer vs. ePIE on the visual-comparison image.
# Row 0 shows channel 1 (amplitude, gray, [0, 1]); row 1 shows channel 0 (phase, viridis, [-pi, pi]).
# Every non-ground-truth panel is annotated with its NRMSE against the ground truth.
f, ax = plt.subplots(2, 4, figsize=(8, 4))
predictions = [entire_label, PtychoFormer_pred, ePtychoFormer_pred, ePIE_pred]
column_titles = ['Ground Truth', 'PtychoFormer', 'ePtychoFormer', 'ePIE']
row_specs = (
    (1, 0, 1, 'gray', False),             # (channel, vmin, vmax, cmap, diff_bool for NRMSE)
    (0, -np.pi, np.pi, 'viridis', True),
)

for method_idx, pred in enumerate(predictions):
    for i, (pred_idx, lo, hi, cmap, diff_bool) in enumerate(row_specs):
        panel = ax[i, method_idx]
        panel.imshow(pred[pred_idx], vmin=lo, vmax=hi, cmap=cmap)
        panel.set_xticklabels([])
        panel.set_yticklabels([])
        panel.set_xticks([])
        panel.set_yticks([])
        for spine in panel.spines.values():
            spine.set_visible(False)
        if method_idx == 0:
            continue  # ground-truth column gets no score
        nrmse = util.compute_nrmse(entire_label[pred_idx], pred[pred_idx], diff_bool=diff_bool)
        panel.text(0.96, 0.04, f'{nrmse:.4f}', color='white', fontsize=17,
                   ha='right', va='bottom', transform=panel.transAxes,
                   bbox=dict(facecolor='black', alpha=0.5, edgecolor='none'))

# Stand-alone colorbars on the right: amplitude on top, phase at the bottom.
colorbar_specs = (
    ([0.907, 0.515, 0.017, 0.35], 'gray', mcolors.Normalize(vmin=0, vmax=1),
     [0, 0.25, 0.5, 0.75, 1.0], ['0', '0.25', '0.5', '0.75', '1']),
    ([0.907, 0.13, 0.017, 0.35], 'viridis', mcolors.Normalize(vmin=-np.pi, vmax=np.pi),
     [-np.pi, -np.pi/2, 0, np.pi/2, np.pi], [r'$-\pi$', r'$-\pi/2$', '0', r'$\pi/2$', r'$\pi$']),
)
for rect, cmap, norm, ticks, tick_labels in colorbar_specs:
    cbar = mcolorbar.ColorbarBase(f.add_axes(rect), cmap=cmap, norm=norm, orientation='vertical')
    cbar.set_ticks(ticks)
    cbar.set_ticklabels(tick_labels)
    cbar.ax.tick_params(labelsize=10)
    for spine in cbar.ax.spines.values():
        spine.set_linewidth(0.3)

ax[0, 0].set_ylabel('Amplitude', fontsize=16)
ax[1, 0].set_ylabel('Phase', fontsize=16)
for col, title in enumerate(column_titles):
    ax[1, col].set_xlabel(title, fontsize=13)
plt.subplots_adjust(wspace=0.03, hspace=0.0)
plt.show()

Line Profile Comparison

In [None]:
# Line-profile comparison via util.phase_line_profile, using the ground truth and the PtychoFormer,
# ePtychoFormer and ePIE predictions produced by the visual-comparison cell.
# NOTE(review): these are notebook globals — if another cell has since overwritten them, rerun the
# visual-comparison cell first.
util.phase_line_profile(entire_label, PtychoFormer_pred, ePtychoFormer_pred, ePIE_pred)

Fourier Ring Correlation (FRC)

In [None]:
# Configuration for the Fourier Ring Correlation (FRC) experiment and a fresh test-image sample.
num_test_data = 2
grid_dim = 3
offset = 36
diff_size = 128
methods = ['PtychoFormer', 'ePtychoFormer', 'ePIE']
num_method = len(methods)
step = 20
n_grid = 6
# Set explicitly: the FRC loop reads diff_step, which would otherwise leak from whichever cell ran
# last (the step-size loop leaves it at 1; the visual-comparison cell sets 3).
diff_step = 3
ePIE_iter = 100

# `results` is deliberately not re-allocated here, so the step-size benchmark stays intact.
test_img_indices = np.arange(test_image_len)
np.random.shuffle(test_img_indices)
test_transmissions = raw_test_data[test_img_indices[:num_test_data]]
test_transmissions.shape

In [None]:
# Compute per-image FRC curves (phase and amplitude) of PtychoFormer, ePtychoFormer and ePIE
# against the ground truth, and record the SSE losses returned by the two iterative methods.
# `diff_step` is not set in this cell; it is read from the notebook namespace.
frequencies = None
PtychoFormer_frc_values_amp, PtychoFormer_frc_values_ph = [], []
ePtychoFormer_frc_values_amp, ePtychoFormer_frc_values_ph = [], []
ePIE_frc_values_amp, ePIE_frc_values_ph = [], []
loss_ePF, loss_ePIE = [], []
crop = (slice(None), slice(offset, -offset), slice(offset, -offset))  # trim the last two axes

with tqdm(total=num_test_data) as progressbar, torch.no_grad():
    for data_idx in range(num_test_data):
        grid_data, single_data, patch_label, entire_label = util.generate_grids(
            test_transmissions[data_idx], step, n_grid, diff_step)
        entire_label = entire_label[crop]
        single_data = single_data.to(device)
        grid_data = grid_data.to(device)
        n_row = int(np.sqrt(single_data.shape[0]))
        patch_size = diff_size + (grid_dim - 1) * step

        PtychoFormer_patch_pred = PtychoFormer(grid_data)
        PtychoFormer_pred = util.feathered_stitching(
            PtychoFormer_patch_pred.reshape(n_grid, n_grid, 2, patch_size, patch_size),
            n_grid, step, diff_step)[crop].detach().cpu()

        ePtychoFormer_pred, SSE_loss_ePF, _ = ePtychoFormer(
            grid_data, single_data, step, diff_step, n_row, iteration=ePIE_iter)
        ePtychoFormer_pred = ePtychoFormer_pred[crop].detach().cpu()
        loss_ePF.append(SSE_loss_ePF)

        ePIE_pred, SSE_loss_ePIE, _ = ePIE(single_data, step, n_row, iteration=ePIE_iter)
        ePIE_pred = ePIE_pred[crop].detach().cpu()
        loss_ePIE.append(SSE_loss_ePIE)

        # Channel 0 = phase, channel 1 = amplitude; phase FRC is computed before amplitude FRC.
        per_method = (
            (PtychoFormer_pred, PtychoFormer_frc_values_amp, PtychoFormer_frc_values_ph),
            (ePtychoFormer_pred, ePtychoFormer_frc_values_amp, ePtychoFormer_frc_values_ph),
            (ePIE_pred, ePIE_frc_values_amp, ePIE_frc_values_ph),
        )
        for pred, amp_curves, ph_curves in per_method:
            frequencies, frc_values_ph = util.compute_frc(pred[0], entire_label[0])
            frequencies, frc_values_amp = util.compute_frc(pred[1], entire_label[1])
            amp_curves.append(frc_values_amp)
            ph_curves.append(frc_values_ph)

        progressbar.update(1)

In [None]:
def calculate_stats(amp_values, ph_values):
    """Summarize stacked per-sample curves into element-wise mean and standard deviation.

    Args:
        amp_values: sequence of equal-length amplitude curves (reduced along axis 0).
        ph_values: sequence of equal-length phase curves (reduced along axis 0).

    Returns:
        dict with keys 'mean_amp', 'std_amp', 'mean_ph', 'std_ph' (in that order).
    """
    stats = {}
    for suffix, curves in (('amp', amp_values), ('ph', ph_values)):
        stats[f'mean_{suffix}'] = np.mean(curves, axis=0)
        stats[f'std_{suffix}'] = np.std(curves, axis=0)
    return stats

# Mean/std FRC curves per method, in the order PtychoFormer, ePtychoFormer, ePIE.
frc_curves_per_method = (
    (PtychoFormer_frc_values_amp, PtychoFormer_frc_values_ph),
    (ePtychoFormer_frc_values_amp, ePtychoFormer_frc_values_ph),
    (ePIE_frc_values_amp, ePIE_frc_values_ph),
)
frc_data = [calculate_stats(amp_curves, ph_curves) for amp_curves, ph_curves in frc_curves_per_method]

util.plot_frc_multiple(frequencies, frc_data, stdev_step=2)

### Testing Various Setups (Scan Patterns, Probe Functions, Caltech101/Flower102 Datasets)

Scan Patterns

In [None]:
# Swap in the PtychoFormer checkpoint trained for the scan-pattern experiment.
# `PtychoFormer` is a model *instance* at this point, so the class is reached through the
# PtychoFormerModel alias (calling the instance would run forward() with no input and fail).
PtychoFormer_weights = os.path.join(os.pardir, 'Models', 'Weights', 'PtychoFormer_scan.pth')
PtychoFormer = PtychoFormerModel()
PtychoFormer.load_state_dict(torch.load(PtychoFormer_weights, map_location='cpu')['model'])
PtychoFormer = PtychoFormer.to(device)
PtychoFormer.eval()
print()

In [None]:
# Scan-pattern generalisation: NRMSE of the scan-pattern PtychoFormer on each scan class.
data_per_cls = 2223
trn_data_per_cls = int(data_per_cls*0.9)
tst_data_per_cls = data_per_cls - trn_data_per_cls
data_dir = os.path.join(os.pardir, 'dataset', 'scans', 'data')
num_test_data = 2
num_cls = 8
offset = 55
assert num_test_data <= tst_data_per_cls

# measure_NRMSE[class, sample, channel]; channel 0 = phase, 1 = amplitude
measure_NRMSE = np.zeros((num_cls, num_test_data, 2))

with torch.no_grad(), tqdm(total=num_cls*num_test_data) as progressbar:
    for idx, i in enumerate(range(0, data_per_cls*num_cls, data_per_cls)):
        # The indexing assumes each class occupies a contiguous block of data_per_cls files,
        # whose last tst_data_per_cls entries form that class's test split.
        test_indices = list(range(i, i+data_per_cls)[-tst_data_per_cls:])
        np.random.shuffle(test_indices)
        for j, sample_idx in enumerate(test_indices[:num_test_data]):
            data = torch.load(os.path.join(data_dir, f'data_{sample_idx}'))
            model_input = data['input'].to(device).unsqueeze(0)
            # squeeze(0) drops a leading batch dim if the model returns one, so the crop hits H and W
            output = PtychoFormer(model_input).squeeze(0)
            output = output[:, offset:-offset, offset:-offset].cpu().squeeze().detach()
            label = data['label'][:, offset:-offset, offset:-offset]
            measure_NRMSE[idx, j, 0] = util.compute_nrmse(label[0], output[0], diff_bool=True)
            measure_NRMSE[idx, j, 1] = util.compute_nrmse(label[1], output[1], diff_bool=False)
            progressbar.update(1)

# NOTE(review): num_cls is 8 but only 7 scan names are listed, so the last class is never
# printed — confirm which dataset class each name belongs to.
scan_config = ['Five', 'Six', 'Seven', 'Eight', 'Diamond', 'Parallel', 'Random']
print()
# Columns: overall, phase and amplitude NRMSE, each as mean ± std over the sampled images.
for idx, scan in enumerate(scan_config):
    overall = measure_NRMSE[idx]
    phase = measure_NRMSE[idx, :, 0]
    amp = measure_NRMSE[idx, :, 1]
    print("{}\t\t{} ± {}\t{} ± {}\t{} ± {}".format(
        scan,
        round(np.mean(overall), 3), round(np.std(overall), 2),
        round(np.mean(phase), 3), round(np.std(phase), 2),
        round(np.mean(amp), 3), round(np.std(amp), 2)))

Probe Functions

In [None]:
# Swap in the PtychoFormer checkpoint trained for the probe-function experiment.
# `PtychoFormer` is a model *instance* at this point, so the class is reached through the
# PtychoFormerModel alias (calling the instance would run forward() with no input and fail).
PtychoFormer_weights = os.path.join(os.pardir, 'Models', 'Weights', 'PtychoFormer_probe.pth')
PtychoFormer = PtychoFormerModel()
PtychoFormer.load_state_dict(torch.load(PtychoFormer_weights, map_location='cpu')['model'])
PtychoFormer = PtychoFormer.to(device)
PtychoFormer.eval()
print()

In [None]:
# Visualise the probe functions: |probe| (gray) above angle(probe) (viridis), one column per probe.
# Figure 1 shows the first six probes; figure 2 the remaining five plus the held-out test probe C.
probe_arg = [1, 0.6, 1.2, 0.85, 0.8, 0.75, 0.7, 0.65, 0.64, 0.62, 0.58]

max_amp = 1.5
min_amp = 0
max_ph = np.pi
min_ph = -np.pi


def _hide_probe_axis_frame(axis):
    """Remove ticks and spines so only the image is shown."""
    axis.set_xticks([])
    axis.set_yticks([])
    for spine in axis.spines.values():
        spine.set_visible(False)


def _draw_probe_column(amp_axis, ph_axis, wave, title):
    """Draw |wave| on amp_axis and angle(wave) on ph_axis; return both image handles."""
    amp_img = amp_axis.imshow(abs(wave), vmin=min_amp, vmax=max_amp, cmap='gray')
    amp_axis.set_title(title)
    _hide_probe_axis_frame(amp_axis)
    ph_img = ph_axis.imshow(torch.angle(wave), vmin=min_ph, vmax=max_ph, cmap='viridis')
    _hide_probe_axis_frame(ph_axis)
    return amp_img, ph_img


def _finish_probe_figure(fig, axes, amp_img, ph_img):
    """Label the rows, attach amplitude/phase colorbars on the right, and apply the layout."""
    axes[0, 0].set_ylabel('Amplitude', fontsize=12)
    axes[1, 0].set_ylabel('Phase', fontsize=12)

    cax_top = fig.add_axes([0.89, 0.53, 0.01, 0.4])     # [left, bottom, width, height]
    cax_bottom = fig.add_axes([0.89, 0.07, 0.01, 0.4])

    cbar_top = plt.colorbar(amp_img, cax=cax_top, cmap='gray')
    cbar_top.set_ticks([0, 0.5, 1.0, 1.5])
    cbar_top.ax.tick_params(labelsize=8)

    cbar_bottom = plt.colorbar(ph_img, cax=cax_bottom, cmap='viridis')
    cbar_bottom.set_ticks([-np.pi, -np.pi/2, 0, np.pi/2, np.pi])
    cbar_bottom.ax.tick_params(labelsize=8)
    cbar_bottom.set_ticklabels([r'$-\pi$', r'$-\pi/2$', r'$0$', r'$\pi/2$', r'$\pi$'])

    for cbar in (cbar_bottom, cbar_top):
        for spine in cbar.ax.spines.values():
            spine.set_linewidth(0.3)

    fig.subplots_adjust(left=0.05, right=0.88, bottom=0.05, top=0.95, wspace=0.03, hspace=0.03)


# Figure 1: probes 1-6 (the first two are the training probes labelled A and B)
fig1, ax1 = plt.subplots(2, 6, figsize=(13, 4))
for i, arg in enumerate(probe_arg[:6]):
    Wave, Hole, Hz2, HzMinus = util.create_probe(arg)
    suffix = {0: ' (A)', 1: ' (B)'}.get(i, '')
    im1, im2 = _draw_probe_column(ax1[0, i], ax1[1, i], Wave, f'Probe {i+1}{suffix}')
_finish_probe_figure(fig1, ax1, im1, im2)
plt.show()

# Figure 2: probes 7-11 followed by the test probe C in the last column
fig2, ax2 = plt.subplots(2, 6, figsize=(13, 4))
for i, arg in enumerate(probe_arg[6:]):
    Wave, Hole, Hz2, HzMinus = util.create_probe(arg)
    im1, im2 = _draw_probe_column(ax2[0, i], ax2[1, i], Wave, f'Probe {i+7}')

Wave2, Hole, Hz2, HzMinus = util.create_probe(0.77)  # Generate Probe C
im1_c, im2_c = _draw_probe_column(ax2[0, -1], ax2[1, -1], Wave2, 'Test Probe (C)')
_finish_probe_figure(fig2, ax2, im1_c, im2_c)
plt.show()

In [None]:
# Probe-function generalisation: NRMSE of the probe PtychoFormer on each training-probe class
# (precomputed dataset) and on probe C (generated on the fly from Flickr test images).
data_dir = os.path.join(os.pardir, 'dataset', 'probes', 'data')

num_sample = 2
offset = 36
# Only len(probe_arg) is used here: it sets the number of probe classes in data_dir.
# NOTE(review): this order differs from the probe-figure cell's list, so "Probe k" in the printed
# results may not be the same probe as "Probe k" in the figure — confirm the dataset class order.
probe_arg = [1.2, 1, 0.85, 0.8, 0.75, 0.7, 0.65, 0.64, 0.62, 0.6, 0.58]
num_cls = len(probe_arg)
data_per_cls = 2223
trn_data_per_cls = int(data_per_cls*0.9)
tst_data_per_cls = data_per_cls - trn_data_per_cls

assert num_sample <= tst_data_per_cls
test_img_indices = np.arange(test_image_len)
np.random.shuffle(test_img_indices)
test_transmissions = raw_test_data[test_img_indices[:num_sample]]

# First num_sample file indices of each class's test split (the last tst_data_per_cls files of
# each contiguous block of data_per_cls files).
probe_idx = [
    list(range(i, i + data_per_cls)[-tst_data_per_cls:][:num_sample])
    for i in range(0, data_per_cls*num_cls, data_per_cls)
]

# probe_NRMSE[probe][sample, channel] and probe_c_NRMSE[sample, channel]; channel 0 = phase, 1 = amplitude
probe_NRMSE = [np.zeros((num_sample, 2)) for _ in range(num_cls)]
probe_c_NRMSE = np.zeros((num_sample, 2))

with torch.no_grad():
    with tqdm(total=num_sample) as progressbar:
        for idx in range(num_sample):
            for probe_id in range(num_cls):
                probe_data = torch.load(os.path.join(data_dir, f'data_{probe_idx[probe_id][idx]}'))

                # squeeze(0) drops a leading batch dim if the model returns one, so the crop hits H and W
                output = PtychoFormer(probe_data['input'].unsqueeze(0).to(device)).squeeze(0)
                output = output[:, offset:-offset, offset:-offset].squeeze().cpu().detach()
                label = probe_data['label'][:, offset:-offset, offset:-offset]

                # (label, prediction) argument order, as in every other compute_nrmse call
                probe_NRMSE[probe_id][idx, 0] = util.compute_nrmse(label[0], output[0])
                probe_NRMSE[probe_id][idx, 1] = util.compute_nrmse(label[1], output[1], diff_bool=False)

            # Probe C data generation and prediction
            probe_c_diff, probe_c_label = util.generate_probe_C_grids_data(test_transmissions[idx], step=20)

            output = PtychoFormer(probe_c_diff.unsqueeze(0).to(device)).squeeze(0)
            output = output[:, offset:-offset, offset:-offset].squeeze().cpu().detach()
            label = probe_c_label[:, offset:-offset, offset:-offset]

            probe_c_NRMSE[idx, 0] = util.compute_nrmse(label[0], output[0])
            probe_c_NRMSE[idx, 1] = util.compute_nrmse(label[1], output[1], diff_bool=False)
            progressbar.update(1)

In [None]:
# Print PtychoFormer NRMSE (mean ± std over the sampled images) for each probe and for probe C.
# Column 0 of each NRMSE array is phase, column 1 amplitude.
method_width = 15
metric_width = 10
value_width = 6


def _probe_nrmse_row(nrmse):
    """Format one 'PtychoFormer NRMSE Ph: m ± s Amp: m ± s' row from a (samples, 2) array."""
    ph_mean = round(np.mean(nrmse[:, 0]), 3)
    ph_std = round(np.std(nrmse[:, 0]), 3)
    amp_mean = round(np.mean(nrmse[:, 1]), 3)
    amp_std = round(np.std(nrmse[:, 1]), 3)
    return (
        f"{'PtychoFormer':{method_width}} {'NRMSE':{metric_width}} "
        f"Ph: {ph_mean:{value_width}} ± {ph_std:<{value_width}} "
        f"Amp: {amp_mean:{value_width}} ± {amp_std:<{value_width}}"
    )


for probe_id in range(11):
    print(f"Probe {probe_id + 1}")
    print(_probe_nrmse_row(probe_NRMSE[probe_id]))

print("\nProbe C")
print(_probe_nrmse_row(probe_c_NRMSE))


Caltech 101

In [None]:
# Caltech 101 generalisation: NRMSE of the currently loaded PtychoFormer on random Caltech samples.
num_test_data = 2
num_caltech_data = 6000  # Change this
# Defined here rather than inherited from earlier cells, so this cell runs on its own
# (values match what a top-to-bottom run had in scope).
offset = 36
method_width = 15
metric_width = 10
value_width = 6
assert num_test_data <= num_caltech_data
caltech_indices = np.arange(num_caltech_data)
# Shuffle the Caltech indices themselves so a random subset is evaluated.
np.random.shuffle(caltech_indices)

data_dir = os.path.join(os.pardir, 'dataset', 'caltech', 'data')

# measure_NRMSE[channel, sample]; channel 0 = phase, 1 = amplitude
measure_NRMSE = np.zeros((2, num_test_data))

# NOTE(review): `PtychoFormer` is whichever checkpoint was loaded last (the probe-function weights
# in a top-to-bottom run) — confirm that is the intended model for this experiment.
with torch.no_grad():
    with tqdm(total=num_test_data) as progressbar:
        for idx in range(num_test_data):
            data = torch.load(os.path.join(data_dir, f'data_{caltech_indices[idx]}'))
            model_input = data['input'].to(device).unsqueeze(0)
            # squeeze(0) drops a leading batch dim if the model returns one, so the crop hits H and W
            output = PtychoFormer(model_input).squeeze(0)
            output = output[:, offset:-offset, offset:-offset].cpu().squeeze().detach()
            label = data['label'][:, offset:-offset, offset:-offset]
            measure_NRMSE[0, idx] = util.compute_nrmse(label[0], output[0], diff_bool=True)
            measure_NRMSE[1, idx] = util.compute_nrmse(label[1], output[1], diff_bool=False)
            progressbar.update(1)


caltech_mean_ph = round(np.mean(measure_NRMSE[0, :]), 3)
caltech_std_ph = round(np.std(measure_NRMSE[0, :]), 3)
caltech_mean_amp = round(np.mean(measure_NRMSE[1, :]), 3)
caltech_std_amp = round(np.std(measure_NRMSE[1, :]), 3)

print("\nCaltech 101")
print(f"{'PtychoFormer':{method_width}} {'NRMSE':{metric_width}} Ph: {caltech_mean_ph:{value_width}} ± {caltech_std_ph:<{value_width}} Amp: {caltech_mean_amp:{value_width}} ± {caltech_std_amp:<{value_width}}")


Flower 102

In [None]:
# Flower 102 generalisation: NRMSE of the currently loaded PtychoFormer on random Flower samples.
num_test_data = 2
num_flower_data = 6000  # Change this
# Defined here rather than inherited from earlier cells, so this cell runs on its own
# (values match what a top-to-bottom run had in scope).
offset = 36
method_width = 15
metric_width = 10
value_width = 6
assert num_test_data <= num_flower_data
flower_indices = np.arange(num_flower_data)
# Shuffle the Flower indices themselves so a random subset is evaluated.
np.random.shuffle(flower_indices)

data_dir = os.path.join(os.pardir, 'dataset', 'flower', 'data')
# measure_NRMSE[channel, sample]; channel 0 = phase, 1 = amplitude
measure_NRMSE = np.zeros((2, num_test_data))

# NOTE(review): `PtychoFormer` is whichever checkpoint was loaded last (the probe-function weights
# in a top-to-bottom run) — confirm that is the intended model for this experiment.
with torch.no_grad():
    with tqdm(total=num_test_data) as progressbar:
        for idx in range(num_test_data):
            data = torch.load(os.path.join(data_dir, f'data_{flower_indices[idx]}'))
            model_input = data['input'].to(device).unsqueeze(0)
            # squeeze(0) drops a leading batch dim if the model returns one, so the crop hits H and W
            output = PtychoFormer(model_input).squeeze(0)
            output = output[:, offset:-offset, offset:-offset].cpu().squeeze().detach()
            label = data['label'][:, offset:-offset, offset:-offset]
            measure_NRMSE[0, idx] = util.compute_nrmse(label[0], output[0], diff_bool=True)
            measure_NRMSE[1, idx] = util.compute_nrmse(label[1], output[1], diff_bool=False)
            progressbar.update(1)


flower_mean_ph = round(np.mean(measure_NRMSE[0, :]), 3)
flower_std_ph = round(np.std(measure_NRMSE[0, :]), 3)
flower_mean_amp = round(np.mean(measure_NRMSE[1, :]), 3)
flower_std_amp = round(np.std(measure_NRMSE[1, :]), 3)

print("\nFlower 102")
print(f"{'PtychoFormer':{method_width}} {'NRMSE':{metric_width}} Ph: {flower_mean_ph:{value_width}} ± {flower_std_ph:<{value_width}} Amp: {flower_mean_amp:{value_width}} ± {flower_std_amp:<{value_width}}")