In [1]:
# Load IPython's autoreload extension; mode 2 re-imports every (non-excluded)
# module before each cell executes, so edits to the local project modules
# imported below are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [4]:
%reload_ext autoreload
import numpy as np
import matplotlib.pyplot as plt
from functools import partial
from multiprocessing import Pool

from phase_retrieval_jax import run_diffmap_algo, run_altproj_algo
from image_generator import gen_image
from register_to_reference import register_to_reference

In [None]:
def process_images(image, ref_image, support, dft_mag):
    """Turn a final algorithm iterate into a registered reconstruction estimate.

    Applies the support projection to ``image``, then applies the
    Fourier-magnitude projection to the reflection ``2 * projB - image``, and
    finally registers the result against ``ref_image``.

    Parameters
    ----------
    image : array
        Final iterate from one phase-retrieval trial (presumably the same
        shape as ``ref_image`` -- TODO confirm).
    ref_image : array
        Reference image used for registration.
    support : array
        Support mask forwarded to ``proj_supp``.
    dft_mag : array
        Fourier magnitudes forwarded to ``proj_A``.

    Returns
    -------
    Whatever ``register_to_reference`` returns for the projected image.
    """
    # NOTE(review): proj_supp and proj_A are not defined or imported in the
    # visible notebook cells; they must be brought into scope (possibly from
    # phase_retrieval_jax -- confirm) before this function is called.
    proj_b = proj_supp(image, support)
    # Use the dft_mag argument rather than the notebook-global dft_mag_img,
    # which silently ignored whatever magnitudes the caller passed in.
    final_image = proj_A(2 * proj_b - image, dft_mag)
    return register_to_reference(final_image, ref_image)

def process_results(results, initial_images, ref_image, support, dft_mag, n_workers=32):
    """Unpack per-trial phase-retrieval outputs and register them to a reference.

    Parameters
    ----------
    results : sequence
        One entry per trial. Each entry is indexed as
        ``(final_iterate, reconstruction, residue, ...)``; only the first three
        fields are used.
    initial_images : sequence of arrays
        Starting image of each trial; registered against ``ref_image``.
    ref_image : array
        Reference image that every output is registered against.
    support, dft_mag : array
        Forwarded to ``process_images`` for the final-iterate projection.
    n_workers : int, optional
        Number of worker processes in the pool (default 32).

    Returns
    -------
    tuple of numpy arrays
        ``(aligned_initial, aligned_images, aligned_recon, residues)``;
        ``residues`` is returned unregistered.
    """
    # Iterate over `results` itself instead of range(n_trials): the notebook
    # global n_trials may not match len(results), which raised IndexError
    # (fewer results) or silently dropped trials (more results).
    images = np.array([res[0] for res in results])
    recons = np.array([res[1] for res in results])
    residues = np.array([res[2] for res in results])

    align_to_ref = partial(register_to_reference, image_ref=ref_image)
    finalize = partial(process_images, ref_image=ref_image, support=support, dft_mag=dft_mag)

    with Pool(n_workers) as pool:
        aligned_images = pool.map(finalize, images)
        aligned_initial = pool.map(align_to_ref, initial_images)
        aligned_recon = pool.map(align_to_ref, recons)

    return np.array(aligned_initial), np.array(aligned_images), np.array(aligned_recon), residues

def save_results(filename, images, recons, real_image, initial_images, residues):
    """Store every array of one experiment in a single uncompressed ``.npz`` archive.

    Each array is saved under the key equal to its parameter name (for example
    ``np.load(filename)["recons"]``). Returns ``None``.
    """
    # Entries are written in this order: images, recons, initial_images,
    # residues, real_image.
    payload = dict(
        images=images,
        recons=recons,
        initial_images=initial_images,
        residues=residues,
        real_image=real_image,
    )
    np.savez(filename, **payload)

In [29]:
# Experiment configuration.
n_pixels = 32    # the padded image side is n_pixels * n_pad (see the init cell)
n_trials = 2000  # number of random initializations / independent runs
n_disks = 50     # forwarded to gen_image

n_pad = 3        # forwarded to gen_image; presumably the padding factor -- confirm
k = 2            # forwarded to gen_image; also tags the output filename
supp_neigh = 4   # forwarded to gen_image

real_image, support = gen_image(
    n_pixels=n_pixels, n_pad=n_pad, k=k, n_disks=n_disks, seed=0, supp_neigh=supp_neigh
)

# Normalize to unit L2 norm over all pixels. This is done out-of-place, so it
# works for any dtype gen_image returns and does not mutate the array it
# handed back. An all-zero (or non-finite) image would otherwise silently
# become NaNs.
image_norm = np.linalg.norm(real_image)
assert image_norm > 0, "gen_image returned an all-zero image; cannot normalize"
real_image = real_image / image_norm

# Fourier magnitudes of the ground truth: the data the algorithms are fit to.
dft_mag_img = np.abs(np.fft.fftn(real_image))

In [None]:
# Random-phase initializations. Each trial keeps the measured Fourier
# magnitudes and draws uniform phases in [0, 2*pi).
# - A dedicated, seeded generator makes the starting points, and so every
#   downstream result, reproducible (the global np.random was unseeded).
# - The phase array takes its shape from dft_mag_img instead of re-deriving
#   it from n_pixels * n_pad.
INIT_SEED = 1
init_rng = np.random.default_rng(INIT_SEED)
random_phases = init_rng.random((n_trials, *dft_mag_img.shape))
init_images = np.fft.ifft2(
    dft_mag_img[None, :, :] * np.exp(2j * np.pi * random_phases)
).real

In [27]:
# Run the difference-map algorithm once per initialization, in parallel.
# Pool() with no argument uses os.cpu_count() worker processes.
diffmap_worker = partial(
    run_diffmap_algo,
    n_iter=2000,
    exp_data=dft_mag_img,
    ref_image=real_image,
    aux="nonneg",
    support_mask=support,
)
with Pool() as pool:
    results = pool.map(diffmap_worker, init_images)

In [None]:
# Register the initial images, final iterates and reconstructions to the
# ground truth; the residues are passed through unchanged.
initial_images, final_images, final_recons, residues = process_results(
    results=results,
    initial_images=init_images,
    ref_image=real_image,
    support=support,
    dft_mag=dft_mag_img,
)

In [None]:
# Encode the run parameters in the filename. The padding suffix is derived
# from n_pad so the name stays accurate if the configuration changes; the
# "nonneg" tag mirrors the aux="nonneg" option used for the diffmap run.
save_results(
    f"results_k{k}_nonneg_p{n_pad}.npz",
    images=final_images,
    recons=final_recons,
    real_image=real_image,
    initial_images=initial_images,
    residues=residues
)