# Phase Retrieval Tutorial

# In [1]:
import numpy as np
import tensorflow as tf

def H_fresnel_prop(Nx,Ny,dx,dy,z,wavelength):
    """Return the Fresnel transfer function for propagation distance z.

    x is the row coordinate (axis 0: Nx samples of pitch dx) and y is the column
    coordinate (axis 1: Ny samples of pitch dy). Each frequency axis starts at
    -1/(2*d) and advances in steps of 1/(N*d), so zero frequency sits at index N//2.

    Returns the complex [Nx, Ny] array exp(-1j*pi*wavelength*z*(fx**2 + fy**2)).
    """
    def centered_freqs(n, d):
        # n samples from -1/(2d) up to 1/(2d) - 1/(n*d), spacing 1/(n*d).
        extent = n * d
        return np.linspace(-1 / (2 * d), 1 / (2 * d) - 1 / extent, n)

    fx_grid, fy_grid = np.meshgrid(centered_freqs(Nx, dx), centered_freqs(Ny, dy), indexing='ij')
    rho_sq = fx_grid**2 + fy_grid**2
    return np.exp(-1j * np.pi * wavelength * z * rho_sq)


def create_H_fresnel_stack(Nx,Ny,dx,dy,z_vec,wavelength):
    """Stack Fresnel transfer functions for every distance in z_vec.

    Returns a complex64 array of shape [Nx, Ny, len(z_vec)] whose slice
    [:, :, i] equals H_fresnel_prop(Nx, Ny, dx, dy, z_vec[i], wavelength)
    (cast to complex64).
    """
    n_planes = len(z_vec)
    stack = np.zeros((Nx, Ny, n_planes), dtype=np.complex64)
    for plane, distance in enumerate(z_vec):
        stack[..., plane] = H_fresnel_prop(Nx, Ny, dx, dy, distance, wavelength)
    return stack


def NA_filter(Nx,Ny,dx,dy,wavelength,NA):
    """Return the binary pupil (numerical-aperture cutoff) on a centered frequency grid.

    wavelength is the free-space wavelength. Every frequency with
    sqrt(fx**2 + fy**2) <= NA * (1 / wavelength) is set to 1 and all others to 0.
    The grid matches H_fresnel_prop (start -1/(2*d), spacing 1/(N*d)).

    Returns a complex64 array of shape [Nx, Ny].
    """
    # Cutoff spatial frequency: NA times 1/wavelength (cycles per unit length,
    # not the angular wavenumber 2*pi/wavelength).
    inv_wavelength = 1. / wavelength
    cutoff = NA * inv_wavelength

    x_extent = Nx * dx
    y_extent = Ny * dy
    fx = np.linspace(-1 / (2 * dx), 1 / (2 * dx) - 1 / x_extent, Nx)
    fy = np.linspace(-1 / (2 * dy), 1 / (2 * dy) - 1 / y_extent, Ny)
    fx_grid, fy_grid = np.meshgrid(fx, fy, indexing='ij')

    inside_pupil = np.sqrt(fx_grid**2 + fy_grid**2) <= cutoff
    return inside_pupil.astype(np.complex64)

def apply_filter_function(u0,H,Nx,Ny,incoherent=False, library=tf):
    """Filter the source-plane field u0 with transfer function H.

    Computes u1 = Ft(H * F(u0)) with this module's centered transforms F/Ft.
    `library` (tf or np) selects the backend they use.

    If incoherent is True, H is first replaced by F(|Ft(H)|**2), i.e. the Fourier
    transform of the squared magnitude of H's inverse transform. Presumably this
    is the incoherent transfer function when H is a coherent pupil, and u0 is then
    an intensity; confirm with callers.

    Returns the filtered field u1.
    """
    if incoherent:
        # Compute the inverse transform once and reuse it for |h|**2.
        h = Ft(H, Nx, Ny, library)
        H = F(h * library.conj(h), Nx, Ny, library)

    U0 = F(u0, Nx, Ny, library)
    U1 = H * U0
    return Ft(U1, Nx, Ny, library)

def fftshift(mat2D, dim0, dim1):
    """Swap the diagonal quadrants of a 2-D tensor (the centered-FFT shift).

    For even dimensions fftshift and ifftshift coincide, so this one function
    serves as both. A 1x1 input is returned unchanged. Any other input with an
    odd dimension raises ValueError.
    """
    if dim0 == 1 and dim1 == 1:
        return mat2D
    if dim0 % 2 or dim1 % 2:
        raise ValueError('Dimensions must be even to use fftshift.')

    half0 = tf.cast(dim0, tf.int32) // 2
    half1 = tf.cast(dim1, tf.int32) // 2

    def quadrant(row_start, col_start):
        # One (dim0/2 x dim1/2) quadrant whose top-left corner is at (row_start, col_start).
        return tf.slice(mat2D, [row_start, col_start], [half0, half1])

    top_left = quadrant(0, 0)
    top_right = quadrant(0, half1)
    bottom_left = quadrant(half0, 0)
    bottom_right = quadrant(half0, half1)

    # Reassemble as [[bottom_right, bottom_left], [top_right, top_left]].
    upper = tf.concat([bottom_right, bottom_left], axis=1)
    lower = tf.concat([top_right, top_left], axis=1)
    return tf.concat([upper, lower], axis=0)

#### Define Fourier and Inverse Fourier transform
    
def F(mat2D,dim0,dim1,library=tf):
    """Centered 2-D forward DFT (zero position and zero frequency at the array center).

    library selects the backend:
      * tf: wraps tf.fft2d with this module's fftshift, so dim0/dim1 must both be
        even (or both 1).
      * np: uses numpy.fft. The input is un-centered with ifftshift and the output
        re-centered with fftshift. This is also correct for odd-sized arrays; for
        even sizes the two shifts are identical.

    Raises ValueError for any other library.
    """
    # Check np first so the numpy path never needs the tf name.
    if library is np:
        return np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(mat2D)))
    if library is tf:
        return fftshift(tf.fft2d(fftshift(mat2D, dim0, dim1)), dim0, dim1)
    raise ValueError('library must be tf or np')

def Ft(mat2D,dim0,dim1,library=tf):
    """Centered 2-D inverse DFT, the inverse of F for the same backend.

    library selects the backend:
      * tf: wraps tf.ifft2d with this module's fftshift, so dim0/dim1 must both be
        even (or both 1).
      * np: uses numpy.fft. The input is un-centered with ifftshift and the output
        re-centered with fftshift. This is also correct for odd-sized arrays; for
        even sizes the two shifts are identical.

    Raises ValueError for any other library.
    """
    # Check np first so the numpy path never needs the tf name.
    if library is np:
        return np.fft.fftshift(np.fft.ifft2(np.fft.ifftshift(mat2D)))
    if library is tf:
        return fftshift(tf.ifft2d(fftshift(mat2D, dim0, dim1)), dim0, dim1)
    raise ValueError('library must be tf or np')

# In [None]:
# In-class tutorial on phase retrieval from 11-08-2018
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from E30_PhaseRetrieval_Functions import NA_filter, create_H_fresnel_stack, apply_filter_function

# Constants
Nx = 2**9
Ny = 2**9
wavelength = 600e-9  # meters
NA = 0.75
dx = 200e-9  # meters
dy = dx
# Defocus distances: -10 nm ... 8 nm in 2 nm steps (np.arange excludes the stop value).
# NOTE(review): the largest Fresnel phase inside the pupil is
# pi*|z|*NA**2/wavelength, about 0.03 rad for |z| = 10 nm. Defocus contrast is
# therefore very weak; confirm whether micrometres (1e-6) were intended.
z_vec = np.arange(-10, 10, 2) * (1e-9)  # meters
Nz = len(z_vec)
num_iter = 100
learning_rate = 1e-3
radius = 2**6  # pixels

# Pixel index grids centered on the array: xm indexes rows, ym indexes columns.
xm, ym = np.meshgrid(range(-Nx // 2, Nx // 2), range(-Ny // 2, Ny // 2), indexing='ij')

# Pure amplitude object: unit-amplitude disk of `radius` pixels.
obj_0 = np.zeros([Nx, Ny], dtype=np.complex64)
obj_0[(xm**2 + ym**2) < radius**2] = 1

plt.figure()
plt.imshow(np.abs(obj_0))
plt.show()

# Pure phase object: pi/2 phase inside a centered square of side `radius`
# pixels, zero phase elsewhere. Using ym for the columns also works when Nx != Ny.
in_square = (np.abs(xm) < radius / 2) & (np.abs(ym) < radius / 2)
obj_1 = np.exp(1j * in_square.astype(np.complex64) * np.pi / 2)

# Phase and amplitude object
obj_2 = obj_0 * obj_1

# Pupil function
H_NA = NA_filter(Nx, Ny, dx, dy, wavelength, NA)

# Fresnel transfer functions, one per defocus plane
H_fresnel_stack = create_H_fresnel_stack(Nx, Ny, dx, dy, z_vec, wavelength)

# obj = obj_0
# obj = obj_1
obj = obj_2

# Simulate the measured through-focus intensity stack, shape [Nx, Ny, Nz].
intensity_planes = []
for z_ind in range(Nz):
    field = apply_filter_function(obj, H_fresnel_stack[:, :, z_ind] * H_NA, Nx, Ny,
                                  incoherent=False, library=np)
    intensity = np.abs(field)**2
    # Uncomment to inspect each simulated plane:
    # plt.figure()
    # plt.imshow(intensity)
    # plt.colorbar()
    # plt.title('z_ind = ' + str(z_ind))
    # plt.show()
    # THEORETICAL NOTE: even at z = 0 a pure-phase object gives some intensity
    # structure, because the pupil H_NA low-pass filters the field.
    intensity_planes.append(intensity)
intensity_stack_actual = np.stack(intensity_planes, axis=-1)

# Add Poisson (shot) noise, similar to lab data. `pnm` scales intensity to expected
# counts before sampling; dividing back restores the original intensity scale.
pnm = 100
intensity_stack_actual = np.random.poisson(pnm * intensity_stack_actual).astype(float) / pnm
    
# Iterative phase retrieval: fit a complex object guess so that its simulated
# through-focus intensities match the measured stack.

# Both real and imaginary parts start at 1, i.e. an initial guess of 1 + 1j everywhere.
matrix_init = np.ones([Nx, Ny], dtype=np.float32)

with tf.Graph().as_default():
    # Graph constants get their own names so the NumPy arrays H_fresnel_stack,
    # H_NA and intensity_stack_actual remain usable after this cell.
    H_fresnel_stack_tf = tf.constant(H_fresnel_stack, dtype=tf.complex64)
    H_NA_tf = tf.constant(H_NA, dtype=tf.complex64)
    intensity_stack_actual_tf = tf.constant(intensity_stack_actual, dtype=tf.float32)

    # The real and imaginary parts are separate float variables for the optimizer.
    # tf.get_variable (rather than tf.Variable) lets variable scopes share/reuse them.
    obj_guess_real = tf.get_variable('obj_guess_real', dtype=tf.float32,
                                     initializer=matrix_init, trainable=True)
    obj_guess_imag = tf.get_variable('obj_guess_imag', dtype=tf.float32,
                                     initializer=matrix_init, trainable=True)
    obj_guess = tf.complex(obj_guess_real, obj_guess_imag)

    # This Python loop unrolls Nz forward models into the graph (a tf.while_loop
    # would build a single looping op instead).
    guess_planes = []
    for z_ind in range(Nz):
        field = apply_filter_function(obj_guess, H_fresnel_stack_tf[:, :, z_ind] * H_NA_tf,
                                      Nx, Ny, incoherent=False, library=tf)
        # |field|**2 without the square root inside abs(), whose derivative is
        # singular where field == 0.
        guess_planes.append(tf.real(field * tf.conj(field)))
    intensity_stack_guess = tf.stack(guess_planes, axis=-1)

    # Square each per-pixel error before summing, so errors of opposite sign cannot
    # cancel. Normalize by the number of pixels per plane.
    MSE = tf.reduce_sum(tf.square(intensity_stack_guess - intensity_stack_actual_tf)) / float(Nx * Ny)

    optimizer = tf.train.AdamOptimizer(learning_rate)
    train = optimizer.minimize(MSE)
    init_op = tf.global_variables_initializer()

    with tf.Session() as sess:
        sess.run(init_op)

        for i in range(num_iter):
            _, MSE_0 = sess.run([train, MSE])
            print('iter' + str(i))
            print(MSE_0)

        # Read the final state in a separate run. Fetching variables in the same run
        # as `train` does not fix whether they are read before or after the update.
        MSE_0, obj_guess_real_0, obj_guess_imag_0, intensity_stack_guess_0 = sess.run(
            [MSE, obj_guess_real, obj_guess_imag, intensity_stack_guess])
    
# Recompute the NumPy pupil so this cell does not depend on what H_NA holds
# after the retrieval cell.
H_NA = NA_filter(Nx, Ny, dx, dy, wavelength, NA)
obj_guess_0 = obj_guess_real_0 + 1j * obj_guess_imag_0
# NA-limited version of the true object. It is presumably the fairer reference for
# the guess, since only frequencies inside the pupil reach the measurements.
obj_filtered = apply_filter_function(obj, H_NA, Nx, Ny, incoherent=False, library=np)

for title, image in [
    ('Amplitude of Guess', np.abs(obj_guess_0)),
    ('Actual Amplitude', np.abs(obj)),
    ('Phase of Guess', np.angle(obj_guess_0)),
    ('Actual Phase', np.angle(obj)),
]:
    plt.figure()
    plt.title(title)
    plt.imshow(image)
    plt.colorbar()
plt.show()

### Try this code with noise
### Try the other objects (pure phase and pure amplitude)