In [1]:
import tensorflow as tf
import tensorflow.contrib.eager as tfe
tf.enable_eager_execution()

import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
from miniscope_utils_tf import *
#import utils as krist
import scipy.misc as sc
from skimage.transform import resize as imresize
%matplotlib inline
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:95% !important; }</style>"))
from IPython import display
import scipy.ndimage as ndim
import scipy.misc as misc
from scipy import signal
from os import listdir
from os.path import isfile, join
import matplotlib.animation as animation
from os import listdir
from os.path import isfile, join
#import copy
#from bridson import poisson_disc_samples


In [2]:
#model and loss
class Model(tf.keras.Model):
    """Lenslet-mask model whose trainable state is the lenslet centre positions.

    The constructor fixes the simulation grid, the optical constants and the
    defocus planes, draws initial lenslet positions with
    ``poissonsampling_circular`` and stores them as the variables ``xpos`` and
    ``ypos``. ``call`` builds the surface with ``make_lenslet_tf`` and returns it.
    """

    def __init__(self):
        super(Model, self).__init__()

        self.samples = (768,768)   #Grid for PSF simulation

        # min and max lenslet focal lengths in mm
        self.fmin = 6
        self.fmax = 20
        self.ior = 1.56
        # Plain-float wavelength; replaced by a tf constant of the same value
        # further down, after the sampling helper has been called.
        self.lam=510e-6
        # Min and max lenslet radii
        self.Rmin = self.fmin*(self.ior-1)
        self.Rmax = self.fmax*(self.ior-1)

        # Convert to curvatures
        self.cmin = 1/self.Rmax
        self.cmax = 1/self.Rmin
        self.xgrng = np.array((-1.8, 1.8)).astype('float32')    #Range, in mm, of grid of the whole plane (not just grin)
        self.ygrng = np.array((-1.8, 1.8)).astype('float32')

        self.t = 10    #Distance to sensor from mask in mm

        #Compute depth range of virtual image that mask sees (this is assuming an objective is doing some magnification)

        self.zmin_virtual = 1/(1/self.t - 1/self.fmin)
        self.zmax_virtual = 1/(1/self.t - 1/self.fmax)
        self.CA = .9 #semi clear aperture of GRIN
        self.mean_lenslet_CA = .2 #average lenslet semi clear aperture in mm.

        #Getting number of lenslets and z planes needed as well as defocus list
        self.ps = (self.xgrng[1] - self.xgrng[0])/self.samples[0]
        # Built-in int: the np.int alias was removed from NumPy (1.24+).
        self.Nlenslets=int(np.floor((self.CA**2)/(self.mean_lenslet_CA**2)))
        self.Nz = np.ceil(np.sqrt(self.Nlenslets*2)).astype('int') #number of Zplanes
        self.defocus_list = 1/(np.linspace(1/self.zmin_virtual, 1/self.zmax_virtual, self.Nz)) #mm or dioptres

        #initializing the x and y positions
        # NOTE(review): poissonsampling_circular is not defined in this notebook;
        # presumably it comes from the miniscope_utils_tf star import - confirm
        # which attributes of `self` it reads.
        [xpos,ypos, rlist]=poissonsampling_circular(self)

        self.rlist = tf.constant(rlist, dtype = tf.float32)
        self.xpos = tfe.Variable(xpos, name='xpos', dtype = tf.float32)
        self.ypos = tfe.Variable(ypos, name='ypos', dtype = tf.float32)

        #parameters for making the lenslet surface
        self.yg = tf.constant(np.linspace(self.ygrng[0], self.ygrng[1], self.samples[0]),dtype=tf.float32)
        self.xg=tf.constant(np.linspace(self.xgrng[0], self.xgrng[1], self.samples[1]),dtype=tf.float32)
        self.px=tf.constant(self.xg[1] - self.xg[0],tf.float32)
        self.py=tf.constant(self.yg[1] - self.yg[0],tf.float32)
        self.xgm, self.ygm = tf.meshgrid(self.xg,self.yg)

        #PSF generation parameters
        self.lam=tf.constant(510.*10.**(-6.),dtype=tf.float32)
        self.k = np.pi*2/self.lam

        fx = tf.constant(np.linspace(-1/(2.*self.ps),1/(2.*self.ps),self.samples[1]),dtype=tf.float32)
        fy = tf.constant(np.linspace(-1/(2.*self.ps),1/(2.*self.ps),self.samples[0]),dtype=tf.float32)
        self.Fx,self.Fy = tf.meshgrid(fx,fy)
        self.field_list = tf.constant(np.array((0, 0)).astype('float32'))

    def call(self, inputs):
        """Return the lenslet surface built from the current model state.

        ``inputs`` is not used; the aperture returned by ``make_lenslet_tf`` is
        discarded.
        """
        T,aper=make_lenslet_tf(self) #ADD offset
        return T



In [3]:
# Instantiate the Model class defined in the cell above.
model = Model()

In [None]:
# Scratch check: crop a 3x3 grid of FFT sample frequencies down to 2x2.
# NOTE(review): crop2d is not defined in this notebook; presumably it comes
# from the miniscope_utils_tf star import - confirm.
freqs = np.fft.fftfreq(9, d=1./9).reshape(3, 3)
freqout=crop2d(freqs,np.array((2,2)))
freqout

In [None]:

# Display `amp` as an image.
# NOTE(review): `amp` is not assigned at notebook level anywhere in the cells
# shown here, so this cell relies on leftover kernel state - confirm.
plt.figure()
plt.imshow(amp)


In [None]:
# Cast `amp` to float32, print the tensor and show it as an image.
# NOTE(review): `amp` is not assigned at notebook level in the cells shown
# here; this cell depends on leftover kernel state - confirm.
test=tf.to_float(amp)
print(test)
plt.imshow(np.real(test))

In [None]:
# Plot the real part of the difference between two fields.
# NOTE(review): U_out and U_out2 are not assigned at notebook level in the
# cells shown here; this cell depends on leftover kernel state - confirm.
plt.figure()
plt.imshow(np.real(U_out2-U_out))
plt.colorbar()


In [None]:
def propagate_field_freq(U,model,padfrac=0):
    """Propagate the field ``U`` with the angular-spectrum transfer function.

    The original body read ``z``, ``lam``, ``Fx`` and ``Fy`` without defining
    them (and ``Fx``/``Fy`` were locals because of the assignment in the pad
    branch, so the unpadded path raised UnboundLocalError). They are now taken
    from ``model``.

    Parameters
    ----------
    U : array
        Complex or real field; the FFT acts on the last two axes.
    model : object
        Must expose ``lam`` (wavelength), ``t`` (propagation distance) and the
        frequency grids ``Fx`` and ``Fy``.
        NOTE(review): using ``model.t`` as the distance is an assumption based
        on the attribute's comment in the Model class - confirm.
    padfrac : number or nested tuple, optional
        Passed to ``pad_func``; 0 disables padding.

    Returns
    -------
    array
        The propagated field, cropped back to the input shape when padded.
    """
    lam = float(model.lam)
    z = float(model.t)
    Fx = np.asarray(model.Fx)
    Fy = np.asarray(model.Fy)

    if padfrac != 0:
        shape_orig = np.shape(U)
        # NOTE(review): pad_func / crop_func are not defined in this notebook;
        # presumably they come from the miniscope_utils_tf star import.
        U = pad_func(U, padfrac)
        # meshgrid(x, y) returns arrays of shape (len(y), len(x)), so the
        # x axis must get the column count and the y axis the row count.
        n_rows, n_cols = U.shape[-2], U.shape[-1]
        Fx, Fy = np.meshgrid(np.linspace(np.min(Fx), np.max(Fx), n_cols),
                             np.linspace(np.min(Fy), np.max(Fy), n_rows))

    Uf = np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(U)))
    Hf = np.exp(1j*2*np.pi*z/lam * np.sqrt(1-(lam*Fx)**2 - (lam*Fy)**2))
    Up = np.fft.fftshift(np.fft.ifft2(np.fft.ifftshift(Uf*Hf)))
    if padfrac != 0:
        Up = crop_func(Up, shape_orig)
    return Up


In [None]:
# Get psfs
def propagate_field_freq(lam, z, U, Fx, Fy, padfrac=0):
    """Propagate the field ``U`` a distance ``z`` with the angular-spectrum method.

    Parameters
    ----------
    lam : float
        Wavelength, in the same length unit as ``z`` and ``1/Fx``.
    z : float
        Propagation distance.
    U : ndarray
        Field to propagate; the FFT acts on the last two axes.
    Fx, Fy : ndarray
        Spatial-frequency grids matching ``U``. When padding is requested they
        are rebuilt over the same min/max range on the padded grid.
    padfrac : number or nested tuple, optional
        Passed to ``pad_func``; 0 disables padding.

    Returns
    -------
    ndarray
        The propagated complex field, cropped back to the input shape when
        padding was applied.
    """
    if padfrac != 0:
        shape_orig = np.shape(U)
        # NOTE(review): pad_func / crop_func are not defined in this notebook;
        # presumably they come from the miniscope_utils_tf star import.
        U = pad_func(U, padfrac)
        # meshgrid(x, y) returns arrays of shape (len(y), len(x)); the x axis
        # therefore needs the column count and the y axis the row count.
        n_rows, n_cols = U.shape[-2], U.shape[-1]
        Fx, Fy = np.meshgrid(np.linspace(np.min(Fx), np.max(Fx), n_cols),
                             np.linspace(np.min(Fy), np.max(Fy), n_rows))

    Uf = np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(U)))
    Hf = np.exp(1j*2*np.pi*z/lam * np.sqrt(1-(lam*Fx)**2 - (lam*Fy)**2))
    Up = np.fft.fftshift(np.fft.ifft2(np.fft.ifftshift(Uf*Hf)))
    if padfrac != 0:
        Up = crop_func(Up, shape_orig)
    return Up

def gen_psf_ag_tf(T,model, obj_def, field, pupil_phase=0, prop_pad = 0):
    """Compute a normalised PSF for surface ``T`` using the settings on ``model``.

    Parameters
    ----------
    T : tensor
        Surface thickness function; it is multiplied by ``model.k*(model.ior-1)``
        to form the phase.
    model : object
        Supplies ``field_list``, ``xgm``, ``ygm``, ``k``, ``ior``, ``CA`` and
        ``defocus_list``.
    obj_def : str
        ``'angle'`` or ``'obj_height'``.
    field : tuple
        Only read in the ``'obj_height'`` branch.
    pupil_phase : number or tensor, optional
        Extra phase added to the surface phase.
    prop_pad : number, optional
        Currently unused: propagation is called with ``padfrac = 0``.

    Returns
    -------
    tensor
        ``|U_prop|**2`` divided by its sum.

    Raises
    ------
    Exception
        If ``obj_def`` is ``'obj_height'`` and the object is at infinity.
    ValueError
        If ``obj_def`` is neither recognised value.
    """
    # NOTE(review): the defocus plane is hard-coded to index 3 of the list.
    z_defocus = model.defocus_list[3]
    # Compare strings by value; `is` only tests object identity.
    at_infinity = isinstance(z_defocus, str) and z_defocus == 'inf'

    if obj_def == 'angle':
        ramp_coeff_x = -tf.tan(model.field_list[0]*np.pi/180)
        ramp_coeff_y = -tf.tan(model.field_list[1]*np.pi/180)
        ramp = model.xgm*ramp_coeff_x + model.ygm*ramp_coeff_y
        if at_infinity:
            # NOTE(review): this branch uses tf.exp (a real exponential) while
            # the other branch uses tf_exp - confirm which is intended.
            U_in = tf.exp(-model.k*(ramp))
        else:
            # NOTE(review): tf_exp is not defined in this notebook; presumably
            # it comes from the miniscope_utils_tf star import - confirm.
            U_in = tf_exp(model.k*(z_defocus - z_defocus*tf.sqrt(1+tf.square(model.ygm/z_defocus)+tf.square(model.xgm/z_defocus)) + ramp)) #negative already included

    elif obj_def == 'obj_height':
        if at_infinity:
            raise Exception('cannot use obj_height and object at infinity')
        else:
            # NOTE(review): z_obj, k, xg and yg are not parameters or locals of
            # this function; this branch only works if they exist as globals.
            U_in = np.exp(1j*-z_obj*k*np.sqrt(1-((xg-field[0])/z_obj)**2 - ((yg-field[1])/z_obj)**2))
    else:
        raise ValueError("obj_def must be 'angle' or 'obj_height'")

    U_out = U_in * tf_exp(-(model.k*(model.ior-1)*T + pupil_phase))
    amp = tf.to_float(tf.sqrt(tf.square(model.xgm) + tf.square(model.ygm)) <= model.CA)
    # NOTE(review): lmbda, t, Fx and Fy are not defined in this function; the
    # call relies on notebook globals (model.lam / model.t / model.Fx /
    # model.Fy look like the intended sources) - confirm.
    U_prop = propagate_field_freq(lmbda, t, amp*U_out, Fx, Fy, padfrac = 0)

    psf = tf.abs(U_prop)**2
    # Normalise with a TF op so the normalisation stays on the gradient tape.
    return(psf/tf.reduce_sum(psf)) #DO WE NEED TO DO THIS????
    



In [None]:

# Build a list of PSFs, one per (defocus, field) pair, redrawing the figure
# after each one so the stack can be watched as it is computed.
# NOTE(review): defocus_list, field_list, T1, ior, t, xg, yg, Fx, Fy and
# gen_psf_ag are not defined in the cells shown here; this cell depends on
# leftover kernel state and/or the miniscope_utils_tf star import - confirm.
zstack = []
f=plt.figure()
for defocus in defocus_list:
    for field in field_list:
        # Field tuple is (2*field, 0).
        fld = tuple([2*field,0])
        zstack.append(gen_psf_ag(T1, ior, t, defocus, 'angle',fld,.9,510e-6, xg, yg, Fx, Fy, 0, prop_pad = .5))

        plt.imshow(zstack[-1])
        display.display(f)
        display.clear_output(wait=True)


In [None]:
# Pairwise cross-correlation energy between the PSF spectra of all z planes.
psf_spect = np.fft.fft2(zstack, norm='ortho')
Rmat = np.zeros((Nz, Nz))
# Only the upper triangle (z2 >= z1) is computed explicitly.
for z1 in range(Nz):
    for z2 in range(z1, Nz):
        Fcorr = np.conj(psf_spect[z1]) * psf_spect[z2]
        Rmat[z1, z2] = np.sum(np.abs(Fcorr) ** 2)

# Mirror the upper triangle into the entries that are still zero.
Rmat = np.transpose(Rmat) * (Rmat == 0) + Rmat
plt.imshow(Rmat, vmin=0)

In [None]:
#tensor flow functions
tf_pi=np.pi
def tf_fftshift(spectrum, axis=-1):
    """Shift the zero-frequency component to the centre along ``axis``.

    Thin wrapper around ``fftshift`` defined in this cell. The original body
    only bound the function object to a local and returned None.
    """
    return fftshift(spectrum, axis=axis)
def fftshift(spectrum, axis=-1):
    """Swap the two halves of ``spectrum`` along ``axis``, like NumPy's fftshift.

    The static size of the axis is used when it is available; otherwise the
    size is read from ``tf.shape``.
    """
    try:
        shape = spectrum.shape[axis].value
    except Exception:
        # No usable static dimension (e.g. the shape entry has no `.value`);
        # fall through to the dynamic shape below.
        shape = None
    if shape is None:
        shape = tf.shape(spectrum)[axis]
    # Match NumPy's behavior for odd-length input. The number of items to roll is
    # truncated downwards.
    b_size = shape // 2
    a_size = shape - b_size
    a, b = tf.split(spectrum, [a_size, b_size], axis=axis)
    return tf.concat([b, a], axis=axis)
def _complex_exp_tf(phase):
    """Return exp(1j*phase) as a complex64 tensor for a real-valued phase."""
    phase = tf.cast(phase, tf.float32)
    return tf.complex(tf.cos(phase), tf.sin(phase))


def propagate_field_freq_tf(lam, z, U, Fx, Fy, padfrac=0):
    """Propagate the field ``U`` a distance ``z`` with the angular-spectrum method.

    Parameters
    ----------
    lam : scalar or tensor
        Wavelength.
    z : scalar or tensor
        Propagation distance.
    U : tensor
        Real or complex field; it is cast to complex64 before the FFT.
    Fx, Fy : tensor
        Spatial-frequency grids matching ``U``. When padding is requested they
        are rebuilt over the same min/max range on the padded grid.
    padfrac : number or nested tuple, optional
        Passed to ``pad_func``; 0 disables padding.

    Returns
    -------
    tensor
        The propagated complex field.
    """
    if padfrac != 0:
        shape_orig = np.shape(U)
        # NOTE(review): pad_func / crop_func are not defined in this notebook;
        # presumably they come from the miniscope_utils_tf star import.
        U = pad_func(U, padfrac)
        # meshgrid(x, y) returns tensors of shape (len(y), len(x)).
        n_rows, n_cols = int(U.shape[-2]), int(U.shape[-1])
        # tf.minimum / tf.maximum are element-wise binary ops; the range of
        # the grid needs the reductions.
        Fx, Fy = tf.meshgrid(tf.linspace(tf.reduce_min(Fx), tf.reduce_max(Fx), n_cols),
                             tf.linspace(tf.reduce_min(Fy), tf.reduce_max(Fy), n_rows))

    # tf.cast accepts both real and already-complex input.
    U = tf.cast(U, tf.complex64)
    # NOTE(review): tf_ifftshift is not defined in this notebook; presumably
    # it comes from the miniscope_utils_tf star import - confirm.
    Uf = tf_ifftshift(tf.fft2d(tf_ifftshift(U)))
    # Build the transfer function from a real phase: a Python complex literal
    # cannot be multiplied into a float32 tensor.
    Hf = _complex_exp_tf(2*tf_pi*z/lam * tf.sqrt(1-tf.square(lam*Fx) - tf.square(lam*Fy)))
    Up = tf_ifftshift(tf.ifft2d(tf_ifftshift(Uf*Hf)))
    if padfrac != 0:
        Up = crop_func(Up, shape_orig)
    return Up

def gen_psf_ag_tf(surface, ior, t, z_obj, obj_def, field, CA, lmbda, xg, yg, Fx, Fy,pupil_phase=0, prop_pad = 0):
    """Compute the normalised intensity PSF of a single refractive surface.

    Inputs:
    surface: single surface thickness function, units: mm
    ior : index of refraction of bulk material
    t : thickness of surface (i.e. distance to output plane)
    z_obj : distance from object plane. The string 'inf' selects an object at infinity
    obj_def : 'angle' for angular field definition, 'obj_height' for finite
    field : tuple (x,y) wavefront field definition. (0,0) is on-axis. Interpreted in context of obj_def
    CA: radius of clear aperture in mm
    lmbda: wavelength in mm
    xg and yg are the spatial grid (pixel spacing in mm)
    Fx : frequency grid in 1/mm
    Fy : same as Fx
    pupil_phase: additional pupil phase, in radians
    prop_pad: currently unused; propagation is called with padfrac = 0

    Returns the squared magnitude of the propagated field divided by its sum.

    Raises Exception for 'obj_height' with an object at infinity and
    ValueError for an unrecognised obj_def.
    """
    k = tf_pi*2/lmbda
    # Compare strings by value; `is` only tests object identity.
    at_infinity = isinstance(z_obj, str) and z_obj == 'inf'

    if obj_def == 'angle':
        ramp_coeff_x = -tf.tan(field[0]*tf_pi/180)
        ramp_coeff_y = -tf.tan(field[1]*tf_pi/180)
        ramp = xg*ramp_coeff_x + yg*ramp_coeff_y
        if at_infinity:
            U_in = _complex_exp_tf(k*(ramp))
        else:
            U_in = _complex_exp_tf(k*(z_obj - z_obj*tf.sqrt(1+tf.square(yg/z_obj)+tf.square(xg/z_obj)) + ramp))
    elif obj_def == 'obj_height':
        if at_infinity:
            raise Exception('cannot use obj_height and object at infinity')
        U_in = _complex_exp_tf(-z_obj*k*tf.sqrt(1-tf.square((xg-field[0])/z_obj) - tf.square((yg-field[1])/z_obj)))
    else:
        raise ValueError("obj_def must be 'angle' or 'obj_height'")

    U_out = U_in * _complex_exp_tf(k*(ior-1)*surface + pupil_phase)
    amp = tf.to_float(tf.sqrt(tf.square(xg) + tf.square(yg)) <= CA)
    # The aperture mask must share the field's complex dtype to be multiplied.
    U_masked = tf.cast(amp, U_out.dtype)*U_out
    U_prop = propagate_field_freq_tf(lmbda, t, U_masked, Fx, Fy, padfrac = 0)

    psf = tf.square(tf.abs(U_prop))
    return(psf/tf.reduce_sum(psf))
    
def pad_func_tf(x, padfrac):
    """Zero-pad ``x`` by a fraction of its size along each axis.

    Parameters
    ----------
    x : tensor or array
        Input with a fully known static shape.
    padfrac : number or sequence of (pre, post) pairs
        A single number pads the last two axes symmetrically by that fraction
        and leaves any leading axes untouched. Otherwise one (pre, post)
        fraction pair per axis is expected.

    Returns
    -------
    tensor
        ``x`` padded with zeros; each pad width is ceil(fraction * axis size).
    """
    dims = [int(d) for d in x.shape]
    rank = len(dims)
    if np.shape(padfrac) == ():
        # Scalar: pad only the trailing (up to two) axes.
        n_padded = min(rank, 2)
        padfrac = ((0, 0),)*(rank - n_padded) + ((padfrac, padfrac),)*n_padded

    # Pad widths are computed from the static shape as plain Python ints.
    padr = []
    for n in range(rank):
        pwpre = int(np.ceil(padfrac[n][0]*dims[n]))
        pwpost = int(np.ceil(padfrac[n][1]*dims[n]))
        padr.append((pwpre,pwpost))
    return tf.pad(x,padr,'constant')

def crop_func_tf(x,crop_size):
    """Return the centred crop of ``x`` whose size is given by ``crop_size``.

    One slice is built per entry of ``crop_size`` and applied to the leading
    axes of ``x``; axes beyond ``len(crop_size)`` are kept whole.

    Parameters
    ----------
    x : tensor or array
        Input with a known static shape.
    crop_size : sequence of int
        Output size along each cropped axis.
    """
    dims = [int(d) for d in x.shape]
    slicer = []
    for n in range(len(crop_size)):
        # Start offset that centres the crop along this axis.
        cstart = (dims[n] - crop_size[n])//2
        slicer.append(slice(cstart, cstart + crop_size[n], 1))
    return(x[tuple(slicer)])


def make_lenslet_tf(Xlist, Ylist, Rlist, xrng, yrng, samples,aperR,r_lenslet):
    """Build a lenslet-array surface as the point-wise maximum of spherical caps.

    Parameters
    ----------
    Xlist, Ylist : sequence
        Lenslet centre coordinates.
    Rlist : sequence
        Radius of curvature of each lenslet.
    xrng, yrng : pair
        (min, max) extent of the grid along x and y.
    samples : pair of int
        Grid size as (rows, columns).
    aperR : number
        Radius of the circular aperture mask.
    r_lenslet : number
        Radius at which each cap height is referenced to zero.

    Returns
    -------
    T : tensor
        Surface height on the grid.
    aper : tensor of bool
        True where sqrt(x**2 + y**2) <= aperR.
    px, py : tensor
        Grid spacing along x and y.
    """
    T = tf.zeros([samples[0],samples[1]])
    Nlenslets = np.shape(Xlist)[0]
    xg = tf.constant(np.linspace(xrng[0], xrng[1], samples[1]),dtype=tf.float32)
    yg = tf.constant(np.linspace(yrng[0], yrng[1], samples[0]),dtype=tf.float32)
    px = xg[1] - xg[0]
    py = yg[1] - yg[0]
    xg, yg = tf.meshgrid(xg,yg)
    aperR_tf=tf.constant(aperR,tf.float32)
    for n in range(Nlenslets):
        # Clamp at zero before the square root: outside the sphere the argument
        # is negative, and the NumPy reference this was ported from,
        # np.real(np.sqrt(0j + ...)), evaluates to 0 there rather than NaN.
        cap_sq = tf.square(Rlist[n]) - tf.square((xg-Xlist[n])) - tf.square((yg-Ylist[n]))
        cap = tf.sqrt(tf.maximum(cap_sq, 0.))
        edge = tf.sqrt(tf.square(Rlist[n])-tf.square(r_lenslet))
        T = tf.maximum(T,cap - edge)
    aper = tf.sqrt(xg**2+yg**2) <= aperR_tf
    return T, aper, px, py

In [None]:
#Using Tensor Flow
#TF constants
# Wrap the plain-Python / NumPy simulation settings as TensorFlow constants.
# NOTE(review): samples, fmin, fmax, ior, lam, t, Nz, CA, mean_lenslet_CA,
# defocus_list, xpos, ypos, rlist, xg, yg, fx, fy, Fx, Fy, field_list and
# zstack are not assigned at notebook level in the cells shown here (zstack is
# built by an earlier scratch cell); this cell depends on leftover kernel
# state - confirm before relying on it after a kernel restart.

# Setup constants
samples_tf =tf.constant(samples,tf.float32) #Grid for PSF simulation

# min and max lenslet focal lengths in mm
fmin_tf = tf.constant(fmin,tf.float32)
fmax_tf = tf.constant(fmax,tf.float32)
ior_tf = tf.constant(ior,tf.float32)
lam_tf=tf.constant(lam,tf.float32)
# Min and max lenslet radii
Rmin_tf = fmin_tf*(ior_tf-1)
Rmax_tf = fmax_tf*(ior_tf-1)

# Convert to curvatures
cmin_tf = 1/Rmax_tf
cmax_tf = 1/Rmin_tf

# Number of lenslets to start with. These will be either clipped by or snapped to an aperture function!
#Nlenslets = 35

xgrng = (-1.8,1.8)    #Range, in mm, of grid of the whole plane (not just grin)
ygrng = (-1.8,1.8)

t_tf = tf.constant(t,tf.float32)    #Distance to sensor from mask in mm

#Compute depth range of virtual image that mask sees (this is assuming an objective is doing some magnification)
Nz_tf = tf.constant(Nz,tf.int32) #number of Zplanes
zmin_virtual_tf = 1/(1/t_tf - 1/fmin_tf)
zmax_virtual_tf = 1/(1/t_tf - 1/fmax_tf)
CA_tf = tf.constant(CA,tf.float32) #semi clear aperture of GRIN
mean_lenslet_CA_tf = tf.constant(mean_lenslet_CA,tf.float32) #average lenslet semi clear aperture in mm.
defocus_list_tf =tf.constant(defocus_list,tf.float32)
# Grid pixel size: x extent divided by samples[0].
ps = (xgrng[1] - xgrng[0])/samples[0]

# theta is the x positions followed by the y positions in one vector.
theta=np.concatenate((xpos,ypos), axis=0)
xpos_tf=tf.constant(xpos,tf.float32)
ypos_tf=tf.constant(ypos,tf.float32)
rlist_tf=tf.constant(rlist,tf.float32)
xgrng_tf=tf.constant(xgrng,tf.float32)
ygrng_tf=tf.constant(ygrng,tf.float32)

xg_tf = tf.constant(xg,tf.float32)
yg_tf = tf.constant(yg,tf.float32)


fx_tf = tf.constant(fx,tf.float32)
fy_tf = tf.constant(fy,tf.float32)
Fx_tf=tf.constant(Fx,tf.float32)
Fy_tf=tf.constant(Fy,tf.float32)
field_list_tf = tf.constant(field_list,tf.float32)
zstack_tf =tf.constant(zstack,tf.float32)


In [None]:
#model and loss
class Model(tf.keras.Model):
    """Model whose trainable state is ``theta_tf``: x positions followed by y positions.

    ``call`` builds the lenslet surface from the two halves of ``theta_tf`` and
    returns a list with one PSF per (defocus, field) pair. It reads the
    notebook-level constants rlist_tf, xgrng_tf, ygrng_tf, samples, CA_tf,
    mean_lenslet_CA_tf, defocus_list_tf, field_list_tf, ior_tf, t_tf, lam_tf,
    xg_tf, yg_tf, Fx_tf and Fy_tf.
    """
    def __init__(self):
        super(Model, self).__init__()
        # `theta` is the notebook-level concatenation of xpos and ypos.
        self.theta_tf = tfe.Variable(theta,dtype=tf.float32)
    def call(self, inputs):
        """Return the list of PSFs for the current positions; ``inputs`` is unused."""
        # Split theta into its x and y halves instead of hard-coding 20/40.
        n_lenslets = int(self.theta_tf.shape[0]) // 2
        T1_tf, r, px, py = make_lenslet_tf(self.theta_tf[0:n_lenslets], self.theta_tf[n_lenslets:2*n_lenslets],rlist_tf, xgrng_tf,ygrng_tf,samples, CA_tf,mean_lenslet_CA_tf)
        # Collect into a fresh list: appending to the notebook-level `zstack`
        # made the result grow across calls and mixed in PSFs from other cells.
        psf_stack = []
        for defocus in defocus_list_tf:
            for field in field_list_tf:
                fld = tuple([2*field,0])
                psf_stack.append(gen_psf_ag_tf(T1_tf, ior_tf, t_tf, defocus, 'angle',fld,CA_tf,lam_tf, xg_tf, yg_tf, Fx_tf, Fy_tf, 0, prop_pad = .5))
        return psf_stack
        #aper_mask=tf.cast(r, tf.float32)
        ##ui_plane=aper_mask*tf.exp(1j*2*pi_tf*(ior_tf-1)*lsurf/lam_tf)
        #fx= np.linspace(-1/2/px,1/2/px,1000)
        #fy= np.linspace(-1/2/py,1/2/py,500)
        #Fx,Fy = np.meshgrid(fx,fy)    
        #Fx_tf = tf.constant(Fx,dtype=tf.float32)
        #Fy_tf = tf.constant(Fy,dtype=tf.float32)
        #U_tf_r = tf.cos(k_tf * lsurf)
        #U_tf_i = tf.sin(k_tf * lsurf)
        #U_tf = tf.complex(aper_mask*U_tf_r, aper_mask*U_tf_i)
        #U_prop_tens = prop_tensorflow(z_tf, U_tf,Fx_tf,Fy_tf)
        #I_tf = tf.real(tf.conj(U_prop_tens) * U_prop_tens)
        #I_tf = I_tf/tf.norm(I_tf) #this step does not matter as much
        #H_tensor = tf.fft2d(tf.complex(tf.pad(I_tf, PADDING_2d, 'constant'), tf.zeros([DIMS0*2, DIMS1*2], dtype=tf.float32)))
        #H_tensor_adj = tf.conj(H_tensor)
        #HT_tensor = tf.conj(H_tensor)
        #sim=f(targets,H_tensor)
        #sim=sim/tf.reduce_max(sim)
        #sim=sim/tf.norm(sim)
        #return inputs*self.theta_tf   
# Instantiate the model, run one forward pass and plot the result.
# NOTE(review): Model.call in this cell returns a list of PSFs, so `y` is a
# list rather than a single 2-D image - confirm that plt.imshow(y) is the
# intended call (y[0] may have been meant).
model = Model()
inputs=1
y=model(inputs)
plt.figure()
plt.imshow(y)

In [None]:
# Reference result: NumPy's fftshift on a 3x3 grid of FFT sample frequencies.
freqs = np.fft.fftfreq(9, d=1./9).reshape(3, 3)
# A bare expression only renders when it is the last line of a cell.
freqs
freqnpshift=np.fft.fftshift(freqs)
freqnpshift

In [None]:
# Apply the TF shift along axis 0 and then axis 1, for comparison with the
# NumPy result in the previous cell.
# NOTE(review): tf_fftshift as defined in this notebook takes a single
# argument, so these two-argument calls fail unless it is changed to accept
# an axis - confirm.
freqshift=tf_fftshift(freqs,0)
freqshift=tf_fftshift(freqshift,1)
freqshift
#back=np.fft.ifftshift(freqshift)
#back
#back=np.fft.ifftshift(freqshift)
#back

In [None]:
# Scratch check of the propagation phase factor; the commented line is the
# full transfer function, the live line drops the square-root term.
# NOTE(review): z_obj is not assigned at notebook level in the cells shown
# here; this cell depends on leftover kernel state - confirm.
#Hf = tf.exp(1j*2*tf_pi*z_obj/lam_tf * tf.sqrt(1-tf.square(lam_tf*Fx_tf) - tf.square(lam_tf*Fy_tf)))
Hf = tf.exp(1j*2*tf_pi*z_obj/lam_tf)
print(Hf)

In [None]:
def loss(model, inputs, targets):
    """Return the sum of squared differences between ``model(inputs)`` and ``targets``."""
    residual = model(inputs) - targets
    squared_residual = tf.square(residual)
    return tf.reduce_sum(squared_residual)
def grad(model, inputs, targets):
    """Return the gradients of ``loss`` with respect to the model's variables.

    ``model.Rlist`` is differentiated when the model has that attribute.
    Neither Model class in this notebook defines it, so the function otherwise
    falls back to ``model.trainable_variables``.
    NOTE(review): confirm which variables are meant to be optimised.
    """
    with tf.GradientTape() as tape:
        loss_value = loss(model, inputs, targets)
    if hasattr(model, 'Rlist'):
        variables = [model.Rlist]
    else:
        variables = model.trainable_variables
    return tape.gradient(loss_value, variables)
