In [1]:
import tensorflow as tf
import tensorflow.contrib.eager as tfe
# TF 1.x: eager execution must be enabled before any other TF op runs. The rest
# of the notebook relies on `.numpy()` and Python control flow on tensors.
tf.enable_eager_execution()
# %load_ext autoreload
# %autoreload 2

The goal here is to fit an optical model of the lenslets to the captured PSF data, which yields background- and noise-free impulse responses. This is built on the TensorFlow core code that runs the PSF optimization.

The problem statement is basically:



$$\arg\min_{\theta_l,\theta_M} \sum_{z=0}^{N_z-1} \left\lVert\left|F^{-1}\,\mathrm{diag}(P_L)\,F\left[U_z(x,y|\theta_M)\odot\exp\left(-j\phi(x,y|\theta_l)\right)\right]\right|^2 - b(x,y|z)\right\rVert$$

where $\theta_l$ represents a parameterized phase mask (i.e. lenslet locations, zernikes), $\theta_M$ represents the parameters of the miniscope build, including mask rotations and tip/tilt. $P_L$ is propagation by distance L (in frequency space), $F$ is DFT, $U_z$ is the wavefront in the pupil due to a point source at plane $z$ in front of the GRIN, and $\phi$ is the phase of the pupil-plane mask. $b(x,y|z)$ is the PSF measured from the as-built system with a point source at plane $z$.

To make this work, initialization is crucial. We will initialize the model with the background-subtracted z-stack taken from the as-built nanoscribe-based miniscope (``model.target_psf``). We will then generate PSFs with manually entered rotations and focus until a qualitatively close match is found between the recipe and the measured PSF. Once that has been found, we'll use feature matching + RANSAC (or similar) to find a homography between the measured PSF and the predicted one. The shift and rotation can be directly applied to the coordinates of the lenslets in the design. The magnification will be used to fine-tune the actual object distances (since scale approximately maps to z for small defocuses). With that homography, we'll have the parameters necessary to initialize the lenslet fitting problem fairly well. 

To summarize:
1. Initialize miniscope model with calibration stack as `model.target_psf`

2. Manually find rotation/focus to match a single measured PSF to a simulated one (qualitatively)
    2. Manually rotate `model.xpos` and `model.ypos`
    2. reinitialize model with rotated coordinates
    2. generate rotated surface function using `tf_utils.make_lenslet_tf_zern(model)` 
    2. generate zstacks using `model.gen_psf_stack`
    2. Repeat until good match is found

3. Find the homography between the PSF generated in step 2 and the target PSF
4. Use parameters of homography to update model
5. Fine tune locations, radii (and possibly zernikes) using gradient descent




In [2]:
import matplotlib.pyplot as plt
import numpy as np
import miniscope_utils_tf as tf_utils
from miniscope_model import Model as msu_model
import scipy as sc
import scipy.ndimage as ndim
# NOTE(review): scipy.misc is deprecated in recent SciPy releases.
import scipy.misc as misc
from scipy import signal
# Needed so that `sc.io.loadmat` / `sc.io.savemat` resolve below.
import scipy.io
from skimage.transform import resize as imresize
%matplotlib inline
from IPython.core.display import display, HTML
# Widen the notebook container to 95% of the browser window.
display(HTML("<style>.container { width:95% !important; }</style>"))
# NOTE: this rebinds `display` from the function imported above to the
# IPython.display module. Later cells call `display.display(...)` and
# `display.clear_output(...)` on the module.
from IPython import display
import cv2 as cv2

import os
from os import listdir
from os.path import isfile, join
import matplotlib.animation as animation

In [3]:
# Restrict TF to GPU 0.
# NOTE(review): CUDA_VISIBLE_DEVICES only takes effect if it is set before TF
# initializes its GPU devices. Setting it after `import tensorflow` may be too
# late; consider moving it above the TF import.
os.environ["CUDA_VISIBLE_DEVICES"]="0"

In [4]:
def re_init_model(model_in, xi, yi, ri, zi, di=None):
    """Overwrite the lenslet parameters of an existing model in place.

    Args:
        model_in: model whose tf.Variables are reassigned.
        xi, yi: lenslet x / y positions, assigned to ``model_in.xpos`` /
            ``model_in.ypos``.
        ri: lenslet radii, assigned to ``model_in.rlist``.
        zi: Zernike coefficients. Assigned to ``model_in.zernlist`` only when
            that variable is non-empty (i.e. the model has Zernike terms).
        di: per-plane defocus offsets for ``model_in.defocus_offset``.
            ``None`` (or the legacy sentinel ``-1``) resets them to zeros of
            length ``model_in.Nz``.
    """
    # Compare the sentinel by value, and only for scalars, so array/tensor
    # inputs are never compared elementwise. The previous `di is -1` relied on
    # CPython small-int caching.
    if di is None or (np.isscalar(di) and di == -1):
        di = tf.zeros(model_in.Nz, tf.float64)
    model_in.xpos.assign(xi)
    model_in.ypos.assign(yi)
    model_in.rlist.assign(ri)
    model_in.defocus_offset.assign(di)
    # Check the model that was passed in, not the notebook-global `model`.
    if tf.not_equal(tf.size(model_in.zernlist), 0):
        model_in.zernlist.assign(zi)

In [5]:
# No Zernike terms are fitted in this run.
zernikes_index = []
#model = msu_model(Nlenslets = 29, aberrations = False, zernikes = zernikes_index,loss_type='psf_error',psf_scale=1e2,
#                  lenslet_CA=0.2,lenslet_spacing = 'uniform')  # zsampling options: 'fixed' or 'uniform_random'

# model = msu_model(Nlenslets = 37, aberrations = False, zernikes = zernikes_index,loss_type='psf_error',psf_scale=1e2,
#                   lenslet_CA=0.2e3,lenslet_spacing = 'uniform',psf_file='../psf_meas/nanoscribev1_zstack_june262019_2x_dz40_shifted_center_interp.mat',incoherent_source = True)

#../psf_meas/psf_crop_nanoscribe_v1_re-registered.mat
# 37-lenslet model whose target PSF stack is the re-registered measurement
# (the file written by the commented-out savemat in the registration section).
model = msu_model(Nlenslets = 37, aberrations = False, zernikes = zernikes_index,loss_type='psf_error',psf_scale=1e2,
                  lenslet_CA=0.2e3,lenslet_spacing = 'uniform',psf_file='../psf_meas/psf_crop_nanoscribe_v1_re-registered.mat',incoherent_source = True)

#psf_file='../psf_meas/nanoscribev1_zstack_june262019_2x_dz40_shifted.mat'

In [6]:
# Forward pass. The saved output shows a float64 vector of shape (10,), which
# later cells store as `Rmat`.
model(0)

<tf.Tensor: id=3735, shape=(10,), dtype=float64, numpy=
array([98.59672453, 99.24510999, 99.0255113 , 97.81098796, 95.70625811,
       91.62814048, 86.33182633, 81.46963057, 76.66291931, 66.60136769])>

In [7]:
# targ_psf2 = model.target_psf
# targ_psf2 = [cv2.resize(model.target_psf[n].numpy()[model.samples[0]//4:3*model.samples[0]//4,model.samples[1]//4:3*model.samples[1]//4],(model.samples[1],model.samples[0])) for n in range(model.Nz)]
# fig, ax = plt.subplots(1,2,figsize=(15,5))
# n = 0
# print(np.shape(model.target_psf[n].numpy()[model.samples[0]//4:3*model.samples[0]//4,model.samples[1]//4:3*model.samples[1]//4]))
# ax[0].imshow(model.target_psf[n])
# ax[1].imshow(targ_psf2[n])
# out_dict2 = {'zstack' : targ_psf2}
# sc.io.savemat('../psf_meas/nanoscribev1_zstack_june262019_2x_dz40_shifted.mat',out_dict2)

In [19]:
# Load in the x,y,r,zernikes from a previous run


#model_orig = sc.io.loadmat("../psf_meas/zstack_sim_test.mat")
#model_orig = sc.io.loadmat("../psf_meas/zstack_nanoscribe_recipe.mat")
# Load lenslet positions, radii and defocus data from a previous fit.
model_orig = sc.io.loadmat('../l1_GOOD_scaled_nanoscribe_v1_20190813_150813.mat')
# Older recipe files stored 'r' in different units (scaled by 1e3):
# xinit = model_orig['xpos'][0].astype(np.float64)*1e3
# yinit = model_orig['ypos'][0].astype(np.float64)*1e3
# rinit = model_orig['r'][0].astype(np.float64)*1e3

xinit = model_orig['xpos'][0].astype(np.float64)
yinit = model_orig['ypos'][0].astype(np.float64)
rinit = model_orig['rlist'][0].astype(np.float64)
zlist = model_orig['defocus_list'][0]
zerrlist = model_orig['defocus_correction'][0]
# Fit-result files written by this notebook (see `best_dict`) store no
# Zernike coefficients, so fall back to an empty array. re_init_model only
# assigns Zernikes when the model's zernlist is non-empty.
zerninit = model_orig['zern_list'] if 'zern_list' in model_orig else np.array([])



[[(array([[1]], dtype=uint8),)]]


In [None]:
# Manual initialization with nanoscribed recipe
# Randoscope nanoscribe v1 recipe: 

# xpos=np.array([ 0.18578115,  0.74248238,  0.35555949,  0.10187365, -0.22157956,
#        -0.5502,  0.29999206,  0.68421244,  0.11299734,  0.3125789 ,
#         0.69618861, -0.28096646,  0.45888084, -0.27032854,  0.07697116,
#        -0.5625567 , -0.58, -0.22702032,  0.07976543, -0.70372301,
#        -0.34797056, -0.1152197 ,  0.41478314, -0.49736261,  0.39453689,
#        -0.09447947, -0.08091086, -0.40003641,  0.04180626,  0.49347817,
#         0.25394288,  0.02494959,  0.53552421,  0.28558491,  0.61067326,
#        -0.72650046,-0.32])

# ypos=np.array([-0.44942685,  0.0187904 , -0.08720638,  0.7031471 ,  0.11111811,
#         0.2784 ,  0.16797938, -0.22899178, -0.07601691,  0.67472988,
#         0.23505869,  0.65642641,  0.52935778,  0.36883709,  0.1762651 ,
#         0.01505985, -0.4, -0.69954257,  0.41217862, -0.2266678 ,
#        -0.06487182,  0.51443979, -0.30910161,  0.51766252, -0.53589769,
#        -0.11244721, -0.46289246, -0.56147136, -0.29425527,  0.2757895 ,
#        -0.69100594, -0.64317962,  0.01663745,  0.40540164, -0.42807524,
#         0.168447,-0.3  ])

# rlist=np.array([4.000442 , 4.440296 , 3.7237842, 3.5596673, 3.2063835, 5.4982405,
#        4.2089086, 3.0839548, 5.3172565, 3.1439776, 6.364798 , 3.4829168,
#        2.767    , 4.3215075, 4.839368 , 4.988825 , 2.8152227, 5.1478076,
#        3.409406 , 6.9092703, 4.6986055, 3.2713168, 5.691979 , 4.5658   ,
#        4.1020284, 2.970532 , 6.625868 , 7.218    , 3.8116512, 5.8998694,
#        2.865156 , 3.6398768, 3.9037654, 2.9168925, 3.3389342, 3.026181 ,
#        6.123522 ])



In [None]:
# Reset the model to the lenslet parameters loaded from the previous fit.
re_init_model(model,xinit, yinit, rinit, zerninit)
#re_init_model(model, xpos,ypos,rlist,zerninit)   # Initialize model with manually entered values


In [None]:
#model = msu_model(target_res=0.004,aberrations = True)  # zsampling options: 'fixed' or 'uniform_random'
load_init_from_file = False
# Load initialization from file
# NOTE(review): `load_model_from_file` is not defined in this notebook. This
# branch only runs if the flag above is set to True.
if load_init_from_file == True:
    print('loading initilization from file')
    file_best = '/media/hongdata/Kristina/MiniscopeData/best_init.mat'
    file_worst = '/media/hongdata/Kristina/MiniscopeData/worst_init.mat'
    model = load_model_from_file(model, file_best)
    
# Save initial values for later comparison 
# Radii are scaled by 1.1 for this reference surface.
re_init_model(model,xinit, yinit, rinit*1.1, zerninit)
# 0*randn(...) always yields zeros, so the defocus offsets are reset to 0.
# The random draw still advances NumPy's global RNG.
model.defocus_offset.assign(tf.constant(0*np.random.randn(model.Nz)*1e-5, tf.float64))
Rmat=model(0)

R_init= Rmat.numpy()
# Tinit is used later as the 'True surface' / phase-error reference. `aper` is
# presumably the aperture mask (it multiplies surfaces elementwise later).
Tinit,aper,_= tf_utils.make_lenslet_tf_zern(model)
Tinit = Tinit.numpy()
aper = aper.numpy()
#model_init = msu_model(target_res=0.004, aberrations = True, zernikes = zernikes_index)
# model_init = msu_model(Nlenslets = 37, aberrations = False, zernikes = zernikes_index,loss_type='psf_error',psf_scale=1e2,
#                   lenslet_CA=0.2e3,lenslet_spacing = 'uniform',psf_file='../psf_meas/nanoscribev1_zstack_june262019_2x_dz40_shifted.mat')
#re_init_model(model_init,xinit,yinit,rinit,zerninit)

# NOTE: Rmat_init is not used in the visible cells.
Rmat_init = model(0)

fig=plt.figure(figsize=(15,5))
plt.subplot(1,3,1)
plt.imshow(Tinit)
plt.subplot(1,3,2)
plt.plot(R_init)




In [None]:
# NOTE(review): this rebinds `lenslet_offset` to a constant tensor. If the
# model previously held a tf.Variable here, that variable is no longer what the
# attribute refers to — confirm this freeze is intended.
model.lenslet_offset = tf.zeros(model.Nlenslets,tf.float64)

In [None]:
# psf_norm_ave = [np.sum(model.target_psf[n].numpy()) for n in range(model.Nz)]
# print(psf_norm_ave)
# psf_out_normed = [model.target_psf[n].numpy() / psf_norm_ave[n] for n in range(model.Nz)]
# print(np.sum(psf_out_normed[0]))
# out_dict2 = {'zstack': psf_out_normed}
# sc.io.savemat('../psf_meas/nanoscribev1_zstack_june262019_2x_dz40_shifted_center_interp.mat', out_dict2)

In [None]:
# psf_zstack =  model.gen_psf_stack(Tinit,.9,0)
# psf_zstack_lst = [np.array(psf_zstack[n].numpy()/model.psf_scale) for n in range(model.Nz)]
# psf_zstack_arr = np.array(psf_zstack_lst)
# save_dict = {
#     'zstack':psf_zstack_arr,
#     'xpos':model.xpos.numpy(),
#     'ypos':model.ypos.numpy(),
#     'r':model.rlist.numpy(),
#     'zerns':model.zernikes,
#     'zern_list':model.zernlist
# }
# sc.io.savemat(,save_dict)

In [None]:
# Start from the loaded parameters. The random perturbations are disabled but
# kept for reference.
xp = xinit  # + tf.random_normal(tf.shape(xinit), stddev=.02)
yp = yinit  # + tf.random_normal(tf.shape(xinit), stddev=.02)
rp = rinit  # + tf.random_normal(tf.shape(xinit), stddev=.5)
# Clamp the radii to [Rmin, Rmax]. The lower clamp must feed back into `rp`;
# assigning it to a separate name would silently discard it.
rp = tf.minimum(rp, model.Rmax)
rp = tf.maximum(rp, model.Rmin)
re_init_model(model, xp, yp, rp, zerninit)
Rmat = model(0)
Tp, aper, _ = tf_utils.make_lenslet_tf_zern(model)
aper = aper.numpy()

fig, ax = plt.subplots(1, 3, figsize=(15, 5))
ax[0].plot(np.array(Rmat))
ax[0].set_title('Rmat - random purturbation')

ax[1].imshow(Tinit*aper)
ax[1].set_title('True surface')

ax[2].imshow(Tp*aper)
ax[2].set_title('Initialization')

In [None]:
# Quick look at the first measured PSF slice, clipped at a low vmax.
plt.figure()
plt.imshow(model.target_psf[0],vmax = .01)

In [None]:
def rot_xy(xloc,yloc,thetad):
    """Rotate 2-D coordinates about the origin with the standard rotation matrix.

    The rotation is counter-clockwise for a y-up frame.

    Args:
        xloc, yloc: x / y coordinates (scalars or broadcastable numpy arrays).
        thetad: rotation angle in degrees.

    Returns:
        Tuple ``(xnew, ynew)`` of rotated coordinates.
    """
    theta = thetad*np.pi/180
    cos_t = np.cos(theta)
    sin_t = np.sin(theta)
    # [xnew; ynew] = [[cos, -sin], [sin, cos]] @ [x; y]
    return (cos_t*xloc - sin_t*yloc, sin_t*xloc + cos_t*yloc)

In [None]:
def drawnow():
    """Show the current matplotlib figure and mark the cell output for replacement.

    ``clear_output(wait=True)`` defers clearing until new output arrives, so
    calling this repeatedly in a loop updates the figure in place. This relies
    on ``display`` being the ``IPython.display`` module, which
    ``from IPython import display`` in the imports cell provides.
    """
    current_fig = plt.gcf()
    display.display(current_fig)
    display.clear_output(wait=True)

In [None]:

# model_noco = msu_model(Nlenslets = 37, aberrations = False, zernikes = zernikes_index,loss_type='psf_error',psf_scale=1e2,
#                   lenslet_CA=0.2e3,lenslet_spacing = 'uniform',psf_file='../psf_meas/nanoscribev1_zstack_june262019_2x_dz40_shifted.mat',incoherent_source = False)
# re_init_model(model_noco,xinit, yinit, rinit*.6, zerninit)

# Tnoco,_,_ = tf_utils.make_lenslet_tf_zern(model)
# psf_init_zstack = model_noco.gen_psf_stack(Tnoco,aper,0,defocus_list)

In [None]:
#Manually rotate recipe until it qualitatively matches the measured PSF
# Defocus planes: Nz samples spaced uniformly in 1/z between zstart and zend,
# with margins of 1.3x / 1.1x on the model's virtual z range.
zstart = model.zmin_virtual*1.3
zend = model.zmax_virtual*1.1
defocus_list = 1./(np.linspace(1/zstart, 1./zend, model.Nz))
# Mirror y, then rotate the recipe coordinates by -30 degrees.
xrot, yrot = rot_xy(xinit,-yinit,-30)
#xrot,yrot = rot_xy(xinit,yinit,0)
re_init_model(model, xrot,yrot,rinit*1.0,zerninit)
model.defocus_offset.assign(0*tf.ones(model.Nz,tf.float64))

fig,ax = plt.subplots(1,2,figsize=(20,7))

Trot,_,_ = tf_utils.make_lenslet_tf_zern(model)
# Simulate the z-stack. Defocus offsets are added in 1/z before inverting.
psf_init_zstack = model.gen_psf_stack(Trot,aper,0,1./(1./defocus_list+model.defocus_offset))
# Blur each slice with model.source_kern (presumably the finite source size — confirm).
psf_init_zstack = [tf_utils.tf_2d_conv(psf_init_zstack[n], model.source_kern,'SAME') for n in range(model.Nz)]
# Rescale each slice by the least-squares factor <target, sim> / <sim, sim>.
psf_init_zstack = [psf_init_zstack[n]*tf.reduce_sum(model.target_psf[n] * psf_init_zstack[n])/tf.reduce_sum(psf_init_zstack[n]**2) 
                      for n in range(model.Nz)]
psf_init_zstack = [psf_init_zstack[n].numpy() for n in range(model.Nz)]

im_init = ax[0].imshow(psf_init_zstack[0],vmax=.005)
cb1 = plt.colorbar(im_init,ax=ax[0])
im_targ = ax[1].imshow(model.target_psf[0])
cb2 = plt.colorbar(im_targ,ax=ax[1])
# Only plane 9 is shown. Edit the list to step through other planes.
for n in [9]:
    
    cb1.remove()
    im_init = ax[0].imshow(psf_init_zstack[n],vmax=.003)
    cb1 = plt.colorbar(im_init,ax=ax[0])
    cb2.remove()
    im_targ = ax[1].imshow(model.target_psf[n],vmax=.003)
    cb2 = plt.colorbar(im_targ,ax=ax[1])
    drawnow()

In [None]:
def tf_2d_conv(x,y,padstr):
    """Filter a single-channel 2-D image ``x`` with a 2-D kernel ``y``.

    Args:
        x: 2-D image (numpy array or tensor).
        y: 2-D kernel (numpy array or tensor). It is cast to ``x``'s dtype, so
            e.g. a float32 kernel can be applied to a float64 image.
        padstr: padding mode for ``tf.nn.convolution`` ('SAME' or 'VALID').

    Returns:
        2-D tensor containing the filtered image.

    Note: ``tf.nn.convolution`` computes cross-correlation (the kernel is not
    flipped). This is identical to convolution for symmetric kernels.
    """
    x = tf.convert_to_tensor(x)
    y = tf.cast(tf.convert_to_tensor(y), x.dtype)
    # NHWC image with batch=1, channels=1; HWIO kernel with in=out=1.
    x_tensor = tf.reshape(x, [1, tf.shape(x)[0], tf.shape(x)[1], 1])
    y_tensor = tf.reshape(y, [tf.shape(y)[0], tf.shape(y)[1], 1, 1])
    out = tf.nn.convolution(x_tensor, y_tensor, padstr)
    # Drop only the batch and channel axes so a size-1 spatial axis survives.
    return tf.squeeze(out, axis=[0, 3])

In [None]:
# xtest = np.random.rand(5,5)
# ytest = np.random.rand(8,9)
# ctest = tf_2d_conv(ytest,xtest,'SAME')
# print(tf.shape(ctest))


In [None]:
# # Blur PSF to see if match improves

# bead_size = 4.8e-3   #mm, diameter
# mag = 6.1   #System magnification
# bead_size_sensor = np.ceil(bead_size*mag/model.px)

# # Construct the circular convolution kernel of size bead_size_sensor (in diameter)
# xkern = np.r_[-np.floor(bead_size_sensor/2):np.ceil(bead_size_sensor/2)]
# Xkern, Ykern = np.meshgrid(xkern,xkern)
# Rkern = np.sqrt(Xkern**2 + Ykern**2)
# kern_numpy = (Rkern<=(bead_size_sensor/2)) /np.sum(Rkern)
# # Rkern = np.atleast_3d(np.sqrt(Xkern**2 + Ykern**2))
# # Rkern = np.moveaxis(Rkern,2,0)

# kern = tf.constant(kern_numpy,tf.float64)


# tz = 0
# target_slice = psf_init_zstack[tz]
# slice_blurred = tf_2d_conv(target_slice, kern,'SAME')
# #slice_blurred = tf.nn.convolution(model.target_psf, kern,'same')
# fit, ax = plt.subplots(2,2,figsize=(20,20))
# ax[0,0].imshow(model.target_psf[tz].numpy(),vmax=.03)

# ax[0,1].imshow(slice_blurred)

# ax[1,0].imshow((target_slice.numpy()/14 - model.target_psf[tz].numpy()))

# ax[1,1].imshow((slice_blurred.numpy() - model.target_psf[tz].numpy()))

In [None]:
# bead_kern = tf.constant(np.array())
# print(tf.reduce_sum(tf.abs(psf_init_zstack[0] - model.target_psf[0])))
# print(tf.reduce_sum(model.target_psf[0]))
# plt.figure(figsize=(20,20))
# plt.imshow((psf_init_zstack[1] - model.target_psf[1]))

In [None]:
# for n in range(model.Nz):
#     model.target_psf[n]*=3

In [None]:
# #Get shifts through zstack
# shifts_target_psf = []
# f, ax = plt.subplots(1,3,figsize=(25,5))
# zstack_shifted = []
# for dp in range(model.Nz):

#     c = np.fft.ifft2((np.fft.fft2(model.target_psf[dp]))* np.conj(np.fft.fft2(psf_init_zstack[dp])))
#     mr = np.unravel_index(np.argmax(c, axis=None), c.shape)
    
#     mrt = tuple([-mr[n] for n in range(2)])
#     im_shifted = np.roll(model.target_psf[dp],mrt,axis=(0,1))

#     zstack_shifted.append(im_shifted)
#     shifts_target_psf.append(mrt)
    
    
#     ax[0].imshow(im_shifted,vmax=.05)
#     ax[1].imshow(psf_init_zstack[dp],vmax=.05)
#     ax[1].set_title(mrt)
#     ax[2].imshow((im_shifted - psf_init_zstack[dp]),vmin=-.1,vmax=.03)
#     drawnow()


In [None]:
# out_dict_2={'zstack':zstack_shifted}
# sc.io.savemat('../psf_meas/nanoscribev1_zstack_june262019_2x_dz40_shifted.mat',out_dict_2)


In [None]:
# psf_zstack =  model.gen_psf_stack(Tinit,.9,0)
# psf_zstack_lst = [np.array(psf_zstack[n].numpy()) for n in range(model.Nz)]
# psf_zstack_arr = np.array(psf_zstack_lst)
# sc.io.savemat("../psf_meas/zstack_sim_test.mat",{'zstack':psf_zstack_arr})

In [None]:
def gradient (model, myloss, inputs):
    """Evaluate ``myloss(model, inputs)`` and its gradient w.r.t. ``model.variables``.

    Returns:
        (grads, lossvalue, Rmat). ``grads`` is the list returned by
        ``tape.gradient``, with one entry per model variable.
    """
    with tf.GradientTape() as tape:
        lossvalue, Rmat = myloss(model, inputs)
    # Take the gradient outside the recording context so the gradient
    # computation itself is not recorded on the tape.
    return tape.gradient(lossvalue, model.variables), lossvalue, Rmat


def gradients_and_scaling(model, loss, inputs, defocus_grad_divisor=1e15):
    """Compute gradients, rescale the 4th gradient and remove NaN entries.

    Args:
        model: model whose variables are differentiated.
        loss: callable ``loss(model, inputs) -> (lossvalue, Rmat)``.
        inputs: forwarded to ``loss`` (the defocus planes in this notebook).
        defocus_grad_divisor: the gradient of ``model.variables[3]`` (read back
            as the defocus offsets in the optimization loop) is divided by this
            factor, presumably to balance step sizes across variables.

    Returns:
        (grads, lossvalue, Rmat), with NaNs handled by
        ``tf_utils.remove_nan_gradients``.
    """
    grad, lossvalue, Rmat = gradient(model, loss, inputs)
    grad = list(grad)
    grad[3] = grad[3] / defocus_grad_divisor

    cleaned = tf_utils.remove_nan_gradients(grad)
    # Return the NaN-cleaned gradients, not the raw list.
    # NOTE(review): tf_utils.remove_nan_gradients is not visible here. If it
    # cleans the list in place and returns None, fall back to the list itself.
    return (grad if cleaned is None else cleaned), lossvalue, Rmat


In [None]:
# Check sensitivity of loss function to lenslet params

# defocus_grid = 1./(np.linspace(1/model.zmin_virtual/.8, 1./model.zmax_virtual/.6, model.Nz))
# delta = .001;
# vars_pert = np.zeros((len(model.variables),model.Nlenslets))
# loss_nopert = []
# loss_pert = []
# for mv in range(len(model.variables)):
#     for vi in range(model.Nlenslets):
#         re_init_model(model,xinit, yinit, rinit, zerninit)
#         loss_nopert.append(tf_utils.loss(model, defocus_grid))
#         vars_pert[0,:] = np.copy(rinit)
#         vars_pert[1,:] = np.copy(xinit)
#         vars_pert[2,:] = np.copy(yinit)
#         vars_pert[mv,vi] += delta
#         rpert = vars_pert[0,:]
#         xpert = vars_pert[1,:]
#         ypert = vars_pert[2,:]
#         re_init_model(model,xpert, ypert, rpert, zerninit)
        
#         loss_pert.append(tf_utils.loss(model, defocus_grid))

In [None]:

# One forward/backward pass at the current parameters, as a check that
# gradients can be computed before running the optimizer.
defocus_grid = 1./(np.linspace(1/zstart, 1./zend, model.Nz))
grad,lossvalue, Rmat=gradient(model,tf_utils.loss, defocus_grid)
# grad_nonan = tf_utils.remove_nan_gradients(grad)
# grad_array = [grad_nonan[n].numpy() for n in range(len(grad_nonan))]


# plt.figure(figsize=(15,5))
# loss_vals_nopert = np.array([loss_nopert[n][0] for n in range(len(loss_nopert))])
# loss_vals_pert = np.array([loss_pert[n][0] for n in range(len(loss_pert))])
# finite_diff_grad = (loss_vals_pert - loss_vals_nopert)/delta
# plt.plot(finite_diff_grad,'r',label='finite diff')
# plt.plot(np.array(grad_array).ravel(),'k-.',label='tf gradient')
# plt.legend()
# plt.title('change in loss function caused by wiggling inputs by {}'.format(delta))
# plt.xlabel('parameter')
# plt.ylabel('d-loss')

In [None]:
# Restore the manually rotated starting point before optimizing.
re_init_model(model, xrot,yrot,rinit*1.0,zerninit)
#re_init_model(model,xinit,yinit,rinit,zerninit)

In [None]:
# Snapshot the starting values. `model.variables` holds live tf.Variables that
# the optimizer updates in place, so keeping that list would alias them.
vars_init = [v.numpy() for v in model.variables]

In [None]:
def diff_tf(arr,ax):
    """Forward difference of ``arr`` along axis ``ax``.

    Every other axis is cropped to its first ``n - 1`` entries. As a result,
    differences taken along different axes of the same array all have shape
    ``(n0 - 1, n1 - 1, ...)`` and can be combined elementwise (e.g. for a
    total-variation term). For 1-D input this equals ``np.diff(arr)``.

    Works on numpy arrays and on tensors that support tuple-of-slices indexing.
    """
    ndims = len(arr.shape)
    # arr[..., 1:, ...] on `ax`, arr[..., :-1, ...] on every other axis.
    ahead = tuple(slice(1, None) if d == ax else slice(0, -1) for d in range(ndims))
    # arr[:-1, :-1, ...] on every axis.
    behind = tuple(slice(0, -1) for _ in range(ndims))
    return arr[ahead] - arr[behind]

In [None]:
# Options:
# NOTE: step_size evaluates to 50.0.
step_size = .1e7/20000. #1e-8 works well for l2
# NOTE(review): `averaged_gradient`, used when this flag is True, is not
# defined in this notebook.
use_averaged_gradient = False  # True: uses averaged gradient, False: uses single gradient 
optimizer_type = 'nesterov'           # options: 'gd': normal gradient descent, 'nesterov': nesterov acceleration
num_iterations = 5000
num_batches = 1
randomize_z = False   #If true, randomize zlist each epoch. If false, leave original order.
Rin = model(0)
nvars = len(model.variables)
#optimizer = tf.train.AdamOptimizer(learning_rate=step_size)
#optimizer=tf.train.GradientDescentOptimizer(learning_rate=step_size)
# Plain gradient-descent step. For 'nesterov', the momentum/extrapolation is
# applied manually inside the optimization loop.
optimizer = tf.train.GradientDescentOptimizer(learning_rate=step_size)
#optimizer=tf.train.MomentumOptimizer(learning_rate=step_size,momentum= 0.9, use_nesterov = True)
# Live dashboard with 2x4 panels. Each panel's (row, col) is kept in a pair of
# variables, and the artist handles are kept so the optimization loop can
# remove and redraw them.
fig, ax = plt.subplots(2,4,figsize=(20,10))

Tir = 0
Tic = 0
ax[Tir,Tic].imshow(Trot)
ax[Tir,Tic].set_title('Initial surface')


Tor = 0
Toc = 1
topt = ax[Tor,Toc].imshow(Trot)
ax[Tor,Toc].set_title('Optimized surface')

Ter = 0
Tec = 2
imh = ax[Ter,Tec].imshow(Trot - Trot)
ax[Ter,Tec].set_title('Phase error')
cb = plt.colorbar(imh,ax=ax[Ter,Tec])

per = 0
pec = 3
pl = ax[per,pec].plot([])
# ax[per,pec].set_title('phase MSE')
ax[per,pec].set_title('focus error')

lr = 1
lc = 0
l = ax[lr,lc].plot([])
ax[lr,lc].set_title('Loss')

xyr = 1
xyc = 1
lh = ax[xyr,xyc].scatter(model.xpos.numpy(),model.ypos.numpy(),c='k', marker='.',label='init')
lh = ax[xyr,xyc].scatter(model.xpos.numpy(),model.ypos.numpy(),c='b', marker='.',label='opt')
# xinit*0/0 is all NaN (with a divide RuntimeWarning), so no 'target' markers
# are drawn; only the legend label remains.
ax[xyr,xyc].scatter(xinit*0/0,yinit,c='r',marker='x',label='target')
ax[xyr,xyc].axis('equal')

rr = 1
rc = 2
ax[rr,rc].plot(model.rlist.numpy(),'k.',label='init')
rh = ax[rr,rc].plot(model.rlist.numpy(),'b.',label='opt')
# Same NaN trick: label-only 'target' series.
ax[rr,rc].plot(rinit*0/0,'rx',label='target')


Rr = 1
Rc = 3
Rh = ax[Rr,Rc].plot(Rin.numpy())
gradr = ax[Rr,Rc].plot(model.rlist.numpy(),'r',label='R')
gradx = ax[Rr,Rc].plot(model.xpos.numpy(),'g',label='X')
grady  = ax[Rr,Rc].plot(model.ypos.numpy(),'b',label='Y')


# Histories recorded by the optimization loop.
losslist = []
phase_err = []
rmean=[]

# Defocus planes, sampled uniformly in 1/z (same grid as the manual-rotation cell).
defocus_grid = 1./(np.linspace(1/zstart, 1./zend, model.Nz))

if optimizer_type == 'nesterov':
    # Momentum sequence t_k, t_{k+1}.
    tk = tf.constant(1,tf.float64)
    tkp = tf.constant(1,tf.float64)

    # Use len() rather than np.shape(). The variables may have different lengths
    # (per-lenslet vs per-plane), and np.shape would try to build a ragged array.
    nvars = len(model.variables)
    # Previous (xk) and current (xkp) iterates, one flat buffer per variable.
    xk = [tf.Variable(tf.zeros(tf.size(v), tf.float64)) for v in model.variables]
    xkp = [tf.Variable(tf.zeros(tf.size(v), tf.float64)) for v in model.variables]
    for n in range(nvars):
        tf.assign(xkp[n], model.variables[n])

lossbest = 1e10
for i in range(num_iterations):
    # Order in which the defocus planes are visited this epoch.
    if randomize_z:
        defocus_epoch = np.random.permutation(defocus_grid)
    else:
        defocus_epoch = defocus_grid
    for j in range(num_batches):
        # Snapshot the parameters at which this iteration's loss is evaluated.
        # The "best" bookkeeping below must use these values: by the time it
        # runs, the variables have already been stepped (and extrapolated).
        vars_eval = [v.numpy() for v in model.variables]

        if use_averaged_gradient:
            # NOTE(review): `averaged_gradient` is not defined in this notebook.
            grad, lossvalue, Rmat = averaged_gradient(model, tf_utils.loss, num_averages = 10)
        else:
            batch_planes = defocus_epoch[j*model.Nz:j*model.Nz+model.Nz]
            grad, lossvalue, Rmat = gradients_and_scaling(model, tf_utils.loss, batch_planes)

        if optimizer_type == 'gd':
            # Gradient step followed by projection.
            optimizer.apply_gradients(zip(grad,model.variables),global_step=tf.train.get_or_create_global_step())
            tf_utils.project_to_aper_keras(model)

        if optimizer_type == 'nesterov':
            # Gradient step from y_k, then projection -> x_{k+1}.
            optimizer.apply_gradients(zip(grad,model.variables),global_step=tf.train.get_or_create_global_step())
            tf_utils.project_to_aper_keras(model,1100)

            # Shift the iterate history: x_k <- old x_{k+1}, x_{k+1} <- projected state.
            tk = tkp
            for n in range(nvars):
                tf.assign(xk[n], xkp[n])
            for n in range(nvars):
                tf.assign(xkp[n], model.variables[n])

            # FISTA-style momentum: t_{k+1} = (1 + sqrt(1 + 4 t_k^2)) / 2.
            tkp = 0.5*(1.0 + tf.sqrt(1.0 + 4*tf.square(tk)))
            bkp = (tk - 1)/tkp
            ykp = [xkp[n] + bkp*(xkp[n] - xk[n]) for n in range(nvars)]

            # The extrapolated point y_{k+1} becomes the next evaluation point.
            for n in range(nvars):
                model.variables[n].assign(ykp[n])

        T,aper,T2=tf_utils.make_lenslet_tf_zern(model)
        found_better = False
        if lossvalue <= lossbest:
            found_better = True
            # Index order as used throughout this notebook:
            # 0 radii, 1 x positions, 2 y positions, 3 defocus offsets.
            bestvars = vars_eval
            rbest = bestvars[0]
            xbest = bestvars[1]
            ybest = bestvars[2]
            defocus_best = bestvars[3]
            lossbest = lossvalue
        else:
            # Loss went up: restart the momentum sequence.
            tk = tf.constant(1.,tf.float64)
            tkp = tf.constant(1.,tf.float64)

        losslist.append(lossvalue)
        phase_err.append(.5*tf.norm(T - Tinit)**2/model.samples[0]/model.samples[1])

        # Redraw each dashboard panel: remove the previous artist, draw the new one.
        topt.remove()
        topt = ax[Tor,Toc].imshow(T)
        ax[Tor,Toc].set_title("Optimized surface, iter {}".format(i))

        l[0].remove()
        l = ax[lr,lc].plot(losslist,'k')
        ax[lr,lc].set_title('Loss (best = {})'.format(found_better))

        pl[0].remove()
        pl = ax[per,pec].plot(model.defocus_offset.numpy(),'k')
        ax[per,pec].set_title('focus error')

        cb.remove()
        imh.remove()
        imh = ax[Ter,Tec].imshow((T - Trot))
        ax[Ter,Tec].set_title('Phase diff micron')
        cb = plt.colorbar(imh,ax=ax[Ter,Tec])

        lh.remove()
        lh = ax[xyr,xyc].scatter(model.xpos.numpy(),model.ypos.numpy(),c='b', marker='.',label='opt')
        ax[xyr,xyc].legend()
        ax[xyr,xyc].set_title('Positions')

        # Log axis: non-positive gradient entries are not shown.
        Rh[0].remove()
        gradr[0].remove()
        gradx[0].remove()
        grady[0].remove()
        Rh = ax[Rr,Rc].semilogy(Rmat.numpy(),'k')
        gradr = ax[Rr,Rc].semilogy(grad[0].numpy(),'r',label='R')
        gradx = ax[Rr,Rc].semilogy(grad[1].numpy(),'g',label='X')
        grady  = ax[Rr,Rc].semilogy(grad[2].numpy(),'b',label='Y')
        ax[Rr,Rc].legend()
        ax[Rr,Rc].set_title('Rmat and grads')

        rh[0].remove()
        rh = ax[rr,rc].plot(model.rlist.numpy(),'b.',label='opt')
        ax[rr,rc].legend()
        display.display(plt.gcf())
        display.clear_output(wait=True)
#     pl.remove?
    


#cbar = fig.colorbar(rshow)

In [None]:
# Compare the plane-0 squared error of the initial simulation with and without
# the per-slice least-squares rescaling. NOTE: psf_init_zstack was already
# rescaled this way when it was generated, so these factors should be ~1
# unless it has been recomputed since.
zstack_test = [psf_init_zstack[n]*tf.reduce_sum(model.target_psf[n] * psf_init_zstack[n])/tf.reduce_sum(psf_init_zstack[n]**2) 
                      for n in range(model.Nz)]

scale_list = [tf.reduce_sum(model.target_psf[n] * psf_init_zstack[n])/tf.reduce_sum(psf_init_zstack[n]**2) 
                      for n in range(model.Nz)]
print(np.array(scale_list))
print("Scaled {}".format(tf.reduce_sum(tf.abs(zstack_test[0] - model.target_psf[0])**2)))
print("unscaled {}".format(tf.reduce_sum(tf.abs(psf_init_zstack[0] - model.target_psf[0])**2)))


In [None]:
# Snapshot the optimized values. `model.variables` are live tf.Variables that
# later cells (e.g. re_init_model) overwrite in place.
vars_final = [v.numpy() for v in model.variables]

In [None]:
# model_good = sc.io.loadmat('../l1_scaled_lenslet_fit_nanoscribe_v1_20190812_173906.mat')
# xgood = model_good['xbest'][0].astype(np.float64)
# ygood = model_good['ybest'][0].astype(np.float64)
# rgood = model_good['rbest'][0].astype(np.float64)
# defocus_good = model_good['defocus_list'][0].astype(np.float64)
# zerngood = model_orig['zern_list']



In [None]:
import time
import datetime

bead_size = 4.8   #um, diameter (4.8e-3 mm)
mag = 6.1   #System magnification
# NOTE: bead_size_sensor is not used in the rest of this cell.
bead_size_sensor = np.ceil(bead_size*mag/model.px)
# Load the best parameters recorded by the optimization loop.
re_init_model(model,xbest,ybest,rbest,zerninit,defocus_best)
Topt,_,_ = tf_utils.make_lenslet_tf_zern(model)
# Same simulate -> source blur -> per-slice least-squares rescale pipeline as
# for psf_init_zstack. `defocus_list` comes from the manual-rotation cell.
psf_opt_zstack = model.gen_psf_stack(Topt,aper,0,1./(1./defocus_list + model.defocus_offset))

psf_opt_zstack = [tf_utils.tf_2d_conv(psf_opt_zstack[n], model.source_kern,'SAME') for n in range(model.Nz)]
psf_opt_zstack = [psf_opt_zstack[n]*tf.reduce_sum(model.target_psf[n] * psf_opt_zstack[n])/tf.reduce_sum(psf_opt_zstack[n]**2) 
                      for n in range(model.Nz)]

psf_opt_zstack = [psf_opt_zstack[n].numpy() for n in range(model.Nz)]

# best_dict = {
#     'xpos':xbest,
#     'ypos':ybest,
#     'rlist':rbest,
#     'zstack_best':psf_opt_zstack,
#     'defocus_list':defocus_grid,
#     'defocus_correction':model.defocus_offset,
#     'loss':losslist
# }

# dt = datetime.datetime.now()
# dts = dt.strftime("%Y%m%d_%H%M%S")

# sc.io.savemat('../l1_GOOD_scaled_nanoscribe_v1_' + dts + '.mat',best_dict)
# Show plane `tz`: measured, optimized simulation, and their difference.
tz =0
target_slice = psf_opt_zstack[tz]
# slice_blurred = tf_2d_conv(target_slice, model.source_kern,'SAME')
#slice_blurred = tf.nn.convolution(model.target_psf, kern,'same')
fit, ax = plt.subplots(1,3,figsize=(20,10))
ax[0].imshow(model.target_psf[tz].numpy(),vmax=.005)

ax[1].imshow(psf_opt_zstack[tz],vmax=.005)


ax[2].imshow((psf_opt_zstack[tz] - model.target_psf[tz].numpy()))

In [None]:
# Register each measured PSF slice to the optimized simulation by circular
# cross-correlation, and roll the measurement onto the simulation.
shifts_target_psf = []
f, ax = plt.subplots(1,3,figsize=(25,5))
zstack_shifted = []
for dp in range(model.Nz):
    target = model.target_psf[dp].numpy()
    # FFT-based circular cross-correlation. Its peak is the shift s for which
    # target ~= roll(simulation, s). Both inputs are real, so the result is
    # real up to round-off; take .real explicitly rather than relying on
    # argmax's ordering of complex values.
    xcorr = np.real(np.fft.ifft2(np.fft.fft2(target) * np.conj(np.fft.fft2(psf_opt_zstack[dp]))))
    mr = np.unravel_index(np.argmax(xcorr), xcorr.shape)

    # Undo that shift so the measurement lines up with the simulation.
    mrt = tuple(-int(m) for m in mr)
    im_shifted = np.roll(target, mrt, axis=(0,1))

    zstack_shifted.append(im_shifted/model.psf_scale.numpy())
    shifts_target_psf.append(mrt)

    ax[0].imshow(im_shifted,vmax=.005)
    ax[1].imshow(psf_opt_zstack[dp],vmax=.005)
    ax[1].set_title(mrt)
    ax[2].imshow((im_shifted - psf_opt_zstack[dp]),vmin=-.01,vmax=.01)
    drawnow()

In [None]:
# NOTE: relies on `dp` and `mrt` left over from the last iteration of the
# registration loop.
print(np.roll(model.target_psf[dp].numpy(),mrt,axis=(0,1)))

In [None]:
# sc.io.savemat('../psf_meas/psf_crop_nanoscribe_v1_re-registered.mat',{'zstack':zstack_shifted})

In [None]:
# Snapshot of the best fit, in the layout read back by the "load previous fit"
# cell. Live TF objects are converted to numpy so the dict does not track later
# changes to the model and holds only savemat-compatible arrays.
best_dict = {
    'xpos':xbest,
    'ypos':ybest,
    'rlist':rbest,
    'zstack_best':psf_opt_zstack,
    'defocus_list':defocus_grid,
    'defocus_correction':model.defocus_offset.numpy(),
    'loss':np.array([float(v) for v in losslist])
}

In [None]:
# Inspect the fitted per-plane defocus offsets.
best_dict['defocus_correction']

In [None]:
bead_size = 4.8   #um, diameter
mag = 6.1   #System magnification
bead_size_sensor = np.ceil(bead_size*mag/model.px)
# Disk kernel with the bead's projected diameter, in sensor pixels.
xkern = np.r_[-np.floor(bead_size_sensor/2):np.ceil(bead_size_sensor/2)]
Xkern, Ykern = np.meshgrid(xkern,xkern)
Rkern = np.sqrt(Xkern**2 + Ykern**2)
disk = (Rkern <= (bead_size_sensor/2)).astype(np.float64)
# Normalize to unit sum so blurring preserves total intensity (the sum of the
# radii, sum(Rkern), has no meaning as a normalizer).
kern_numpy = disk / np.sum(disk)
# float64 to match the dtype of the simulated PSF slices.
kern = tf.constant(kern_numpy,tf.float64)
test_zstack = [tf_2d_conv(psf_opt_zstack[n], kern,'SAME') for n in range(len(psf_opt_zstack))]
plt.figure(figsize=(20,20))
plt.imshow(np.abs(test_zstack[3]),vmax=.0001)

In [None]:
# Inspect the 1-D pixel grid used to build the bead kernel.
np.r_[-np.floor(bead_size_sensor/2):np.ceil(bead_size_sensor/2)]

In [None]:
# NOTE: uses `defocus_epoch` and `j` left over from the optimization loop.
grad, lossvalue, Rmat = gradients_and_scaling(model, tf_utils.loss, defocus_epoch[j*model.Nz:j*model.Nz+model.Nz])  # initial value 
grad

In [None]:
# Difference between model.variables[0] (the radii, per the index order used in
# the optimization loop) and 0.6x the loaded radii.
model.variables[0] - rinit*.6

In [None]:
T,aper,T2=tf_utils.make_lenslet_tf_zern(model)

print(model.min_r)
print(model.max_r)

# Ratio of clipped to unclipped radii. Entries != 1 mark radii outside [min_r, max_r].
a = lambda t: tf.clip_by_value(t,model.min_r, model.max_r)
print(a(model.rlist)/ model.rlist)


In [None]:
# Planes used in the last batch of the optimization loop vs. the full grid.
print(defocus_epoch[j*model.Nz:j*model.Nz+model.Nz])
print(defocus_grid)

In [None]:
# Residual of the plane-0 initial fit.
# NOTE(review): vmin/vmax of +-100 are far outside the ~1e-3-scale limits used
# for PSF images elsewhere in this notebook; likely stale.
plt.imshow(psf_init_zstack[0]-model.target_psf[0],vmin=-100,vmax=100)

In [None]:
# Loss history recorded by the optimization loop.
print(np.array(losslist))

In [None]:
# Scratch: overlay the current positions on the dashboard's position panel,
# then remove the phase-difference image from the optimization figure.
lh = ax[xyr,xyc].scatter(model.xpos.numpy(),model.ypos.numpy(),c='b', marker='x')

print(nvars)
imh.remove()

In [None]:
# Scratch: sphere surface height sqrt(r^2 - dx^2 - dy^2) and piston term for a
# single lenslet `n`. On real tensors tf.real is a no-op and tf.sqrt of a
# negative value is NaN, so the radicand is clamped at 0 instead. This equals
# the real part of the complex square root (0 outside radius rlist[n]).
n = 5
sphere = tf.sqrt(tf.maximum(
    tf.square(model.rlist[n])
    - tf.square((model.xgm-model.xpos[n]))
    - tf.square((model.ygm-model.ypos[n])), 0.))

# Keep the call in one expression. If the argument starts on the next line,
# Python binds `piston` to the function itself and evaluates the sqrt separately.
piston = tf.sqrt(tf.maximum(
    tf.square(model.rlist[n]) - tf.square(model.mean_lenslet_CA), 0.))

In [None]:
# Largest current lenslet radius.
max(model.rlist.numpy())

In [None]:
# Current lenslet radii (variable repr).
print(model.rlist)

In [None]:
# Scratch comparison of the measured PSF vs. a freshly generated stack.
# NOTE(review): gen_psf_stack is called here as (Tgd, .9, .5), with no defocus
# list, unlike the 4-argument calls elsewhere. The result is also neither
# source-blurred nor rescaled, so this cell may be stale.
Tgd,_,_= tf_utils.make_lenslet_tf_zern(model)
psf_zstack =  model.gen_psf_stack(Tgd,.9,.5)
vup =.05
#zdisp = 0
f, ax = plt.subplots(2,2,figsize=(12,12))
# range(1) shows plane 0 only, so the colorbar/image removal branch never runs.
for zdisp in range(1):
    ax[0,0].imshow(model.target_psf[zdisp],vmax = vup)
    ax[0,0].set_title('target')
    ax[0,1].imshow(psf_zstack[zdisp],vmax=vup)
    ax[0,1].set_title('Optimized')
    ax[1,0].imshow(Tinit)
    ax[1,0].set_title('target')
    
    if zdisp != 0:
        cb.remove()
        e.remove()

    e = ax[1,1].imshow(model.target_psf[zdisp] - psf_zstack[zdisp])
   
    ax[1,1].set_title('error')
    cb = plt.colorbar(e,ax=ax[1,1])
    
    
    display.display(f)
    display.clear_output(wait=True)


In [None]:
# Compare the loaded radii with the current optimized radii. `rinit` is
# already a numpy array, so it is plotted directly.
plt.plot(rinit,label='init')
plt.plot(model.rlist.numpy(),label='opt')
plt.legend()
# Change in y positions relative to the loaded initialization. This replaces a
# reference to `model_init`, which is never created in this notebook.
yinit - model.ypos.numpy()