# Inspect ptyREX recon - amourphous C sim

[Evaluate using Fourier Ring Correlation](#FRC) <br>

In [1]:
%matplotlib qt
# %matplotlib inline

In [2]:
import numpy as np
import os
import h5py
import sys
import matplotlib.pyplot as plt
import hyperspy.api as hs

In [3]:
sys.path.append('/dls/science/groups/e02/Mohsen/code/Git_Repos/Merlin-Medipix/')
import epsic_tools.api as epsic



In [4]:
matrix_path = '/dls/e02/data/2020/cm26481-1/processing/pty_simulated_data_MD/sim_matrix_ptyREX_amC_16June2020_512pixArray/'
# run_ID: 26042020 - 100 iterations 
# run_ID: 2000iter
# json_files = get_ptyREX_recon_list(matrix_path, run_id = '26042020') 
json_files = epsic.sim_utils.get_ptyREX_recon_list(matrix_path, run_id = '2000iter')

In [7]:
len(json_files)

64

In [6]:
# sorting the json files


conv_angles = [0.016, 0.020, 0.024, 0.030, 0.040, 0.050, 0.064, 0.084]
real_probe_overlap = [0, 5, 15, 35, 60, 70, 80, 90]


data_list_of_dicts = []
for file in json_files:
    j_dict = epsic.sim_utils.json_to_dict_sim(file)

    for i, angle in enumerate(conv_angles):
        data_list_of_dicts.append([])
        if j_dict['process']['common']['probe']['convergence'] == angle:
            data_list_of_dicts[i].append(j_dict)
                
for i, angle in enumerate(conv_angles):
    #print(angle)
    data_list_of_dicts[i].sort(key=lambda e: e['process']['common']['scan']['dR'][0], reverse = True)


In [12]:
probe_overlaps = [0, 5, 15, 35, 60, 70, 80, 90]

In [13]:
rows = len(data_list_of_dicts[0])
cols = len(conv_angles)
fig, axs = plt.subplots(nrows=rows, ncols=cols, sharex=True, sharey=True,figsize=(8, 11))
fig.suptitle('FFT of reconstructed objects')
for idx in range(rows):    
      for idy in range(cols): 
            try:
                obj = epsic.ptycho_utils.crop_recon_obj(data_list_of_dicts[idy][idx]['json_path'])
                j_dict = epsic.sim_utils.json_to_dict_sim(data_list_of_dicts[idy][idx]['json_path'])
                axs[idx,idy].imshow(np.log(10*abs(epsic.ptycho_utils.get_fft(obj, crop = 0.8, apply_hann=True))), cmap = 'RdBu', vmax = 10)

                axs[idx,idy].set_xticks([])
                axs[idx,idy].set_yticks([])
                if idx == 0 and idy == 0:
                    axs[idx,idy].set_title(str(1e3 * j_dict['process']['common']['probe']['convergence']) + 'mrad', 
                                           color = 'red', fontsize = 8)
                    axs[idx, idy].set_ylabel(str(real_probe_overlap[idx]) + '%', color = 'red', fontsize = 8)
                elif idx == 0:
                    axs[idx,idy].set_title(str(1e3 * j_dict['process']['common']['probe']['convergence']) + 'mrad', 
                                           color = 'red', fontsize = 8)
                elif idy == 0:
                    axs[idx,idy].set_title('step size ($\AA$) %2.3f'%(1e10*j_dict['process']['common']['scan']['dR'][0]), color = 'black', fontsize = 5)
                    axs[idx, idy].set_ylabel(str(real_probe_overlap[idx]) + '%', color = 'red', fontsize = 8)
                else:
                    axs[idx,idy].set_title('%2.3f'%(1e10*j_dict['process']['common']['scan']['dR'][0]), color = 'black', fontsize = 5)
            except IndexError:
                pass
#plt.tight_layout()   

In [14]:
# obj phase 
rows = len(data_list_of_dicts[0])
cols = len(conv_angles)
fig, axs = plt.subplots(nrows=rows, ncols=cols, sharex=True, sharey=True,figsize=(8, 11))
for idx in range(rows):    
      for idy in range(cols): 
            try:
                obj = epsic.ptycho_utils.crop_recon_obj(data_list_of_dicts[idy][idx]['json_path'])
                #print(idx, idy)
                j_dict = epsic.sim_utils.json_to_dict_sim(data_list_of_dicts[idy][idx]['json_path'])
                
                img = abs(np.min(np.angle(obj))) + np.angle(obj)
                sh0 = img.shape[0]
                sh1 = 200
                img_crop = img[int(sh0/2 - sh1/2):int(sh0/2 + sh1/2), int(sh0/2 - sh1/2):int(sh0/2 + sh1/2)]
                
                axs[idx,idy].imshow(img_crop, cmap = 'magma_r')
                axs[idx,idy].set_xticks([])
                axs[idx,idy].set_yticks([])
                if idx == 0 and idy == 0:
                    axs[idx,idy].set_title(str(1e3 * j_dict['process']['common']['probe']['convergence']) + 'mrad', 
                                           color = 'red', fontsize = 8)
                    axs[idx, idy].set_ylabel(str(real_probe_overlap[idx]) + '%', color = 'red', fontsize = 8)
                elif idx == 0:
                    axs[idx,idy].set_title(str(1e3 * j_dict['process']['common']['probe']['convergence']) + 'mrad', 
                                           color = 'red', fontsize = 8)
                    
                elif idy == 0:
                    axs[idx,idy].set_title('step size ($\AA$) %2.3f'%(1e10*j_dict['process']['common']['scan']['dR'][0]), color = 'black', fontsize = 5)
                    axs[idx, idy].set_ylabel(str(real_probe_overlap[idx]) + '%', color = 'red', fontsize = 8)
                else:
                    axs[idx,idy].set_title('%2.3f'%(1e10*j_dict['process']['common']['scan']['dR'][0]), color = 'black', fontsize = 5)
            except IndexError:
                pass
#plt.tight_layout()            

# Comparison with sim potential

In [15]:
# we get one of the potentials as ground truth to compare the recons with.
data_list_of_dicts[2][2]['sim_path']

'/dls/e02/data/2020/cm26481-1/processing/pty_simulated_data_MD/sim_matrix_ptyREX_amC_16June2020_512pixArray/amorCarbon_43p5nmCube_12.0mrad_244.21A_def_18.42A_step_size/amorCarbon_43p5nmCube_12.0mrad_244.21A_def_18.42A_step_size.h5'

In [16]:
pot = epsic.sim_utils.get_potential(data_list_of_dicts[2][2]['sim_path'])

In [17]:
# summing the sloces
pot = np.sum(pot, axis = 2)
pot.shape

(1024, 1024)

In [52]:
sh = pot.shape[0]
obj_pot = pot[int(0.33*sh):int(0.66*sh), int(0.33*sh):int(0.66*sh)]
obj_pot_fft = np.fft.fftshift(np.fft.fft2(obj_pot))
test = epsic.ptycho_utils.get_fft(obj_pot, crop=0.33, apply_hann=True)
phase_ideal = epsic.sim_utils._sigma(80000) * obj_pot
fig, ax = plt.subplots(1,3,figsize=(11,4))
im = ax[0].imshow(obj_pot)
fig.colorbar(im, ax = ax[0])
ax[1].imshow(np.log(abs(obj_pot_fft)), cmap = 'viridis')
im2 = ax[2].imshow(phase_ideal)
fig.colorbar(im2, ax = ax[2])
ax[0].set_xticks([])
ax[0].set_yticks([])
ax[1].set_xticks([])
ax[1].set_yticks([])
ax[2].set_xticks([])
ax[2].set_yticks([])
ax[0].set_title('object potential in V$\AA$')
ax[1].set_title('fft of object potential')
ax[2].set_title('ideal phase shift')
# plt.tight_layout()            

Text(0.5, 1.0, 'ideal phase shift')

In [19]:
phase_ideal.dtype

dtype('float32')

In [20]:
obj_pot.shape

(338, 338)

In [47]:
phase_ideal = np.ones_like(phase_ideal)*np.exp(1j*phase_ideal)

In [63]:
# binning the ideal phase by 2 - the sim pixelSize is half the recon pixelSize - due to the way 4DSTEM
# data is saved in pyprimsatic
phase_ideal_hs = hs.signals.Signal2D(phase_ideal)
phase_ideal_bin = phase_ideal_hs.rebin(scale = (2,2))
# phase_ideal_bin = np.angle(phase_ideal_bin.data)
phase_ideal_bin = phase_ideal_bin.data

In [61]:
plt.figure()
plt.imshow(phase_ideal_bin)

<matplotlib.image.AxesImage at 0x7faf5f4e7190>

In [64]:
phase_ideal_bin.shape

(169, 169)

In [18]:
# offset_img_r = np.real(offset_img)

In [50]:
# ratio = np.divide(offset_img_r, phase_ideal_bin)

In [51]:
# plt.figure()
# plt.imshow(ratio, cmap = 'viridis')
# plt.colorbar()
# ratio_mean = np.mean(ratio)

In [52]:
# ratio_mean

0.9884089493296008

<a id='FRC'></a>
# Evaluate by Fourier Ring Correlation

In [23]:
def genAp(*args):
    """### Generate a circular aperture ###
    out:
        ap - Aperture
    in:
        shape - Array size
        r     - Radius of aperture
    """
    if len(args) > 0:
        shape = args[0]
    if len(args) > 1:
        r = args[1]
    if len(args) > 2:
        cent = args[2]

    ap = np.ones(shape) * np.exp(1j*np.zeros(shape))
    r = np.array(r)
    #print "gen_ap_r", r
    if len(args) < 3:
        cent = np.divide(shape,2)

    x = np.arange(0,np.size(ap,0)) - cent[0]
    y = np.arange(0,np.size(ap,1)) - cent[1]
    
#     print("gen ap r and size", r, r.size)
    
    if r.size == 1:
        #print "r dim is 1"
        yy, xx = np.meshgrid(y, x)
        grid = np.sqrt((xx**2)+(yy**2))
        rad = np.mean(r)
        ap[grid>rad] = 0
    elif r.size == 2:
        yy, xx = np.meshgrid(y, x)
        yy = np.abs(yy)
        xx = np.abs(xx)
        ap[yy>r[1]] = 0
        ap[xx>r[0]] = 0
    return ap

def genStop(*args):
    """### Generate a circular aperture ###
    out:
        ap - Aperture
    in:
        shape - Array size
        r     - Radius of aperture
    """
    if len(args) > 0:
        shape = args[0]
    if len(args) > 1:
        r = args[1]
    if len(args) > 2:
        cent = args[2]

    out = np.zeros(shape) * np.exp(1j*np.zeros(shape))

    if len(args) < 3:
        cent = np.divide(shape,2)
        
    x = np.arange(0,np.size(out,0)) - cent[0]
    y = np.arange(0,np.size(out,1)) - cent[1]
    yy, xx = np.meshgrid(y, x)
    grid = np.sqrt((xx**2)+(yy**2))
    
    rad = np.mean(r)
    out[grid>rad] = 1
    return out

In [181]:
def fft(ar):
    ar = np.fft.fftshift(np.fft.fft2(ar))
    return ar
def ifft(ar):
    ar = np.fft.ifft2(np.fft.fftshift(ar))
    return ar

# def fourierDownSample(image, keep_fraction, pixelSize):
#     """
#     Reduces the size of the FFT, returns also the new pixel size
#     """
#     im_fft = np.fft.fft2(image)
#     r, c = im_fft.shape[-2:]
#     im_fft_crop = np.delete(im_fft, np.arange(int(r*keep_fraction), int(r*(1 - keep_fraction))), 1)
#     im_fft_crop = np.delete(im_fft_crop, np.arange(int(c*keep_fraction), int(c*(1 - keep_fraction))), 2)
#     im_ds = np.fft.ifft2(im_fft_crop)
#     # stack_ds_hs = hs.signals.Signal2D(abs(stack_ds))
#     pixelSizeNew = (r / im_ds.shape[-2])*pixelSize

#     return im_ds, pixelSizeNew

def setPower(ar, power):
    P_sz = np.size(ar, -2) * np.size(ar, -1)
    int_in = np.float32(ar.real ** 2 + ar.imag ** 2)
    P_in = np.sum(int_in)
    P_in = np.multiply(P_in, P_sz)
    ratio = np.divide(power, P_in)
    int_out = np.multiply(int_in, ratio)
    mod_out = np.sqrt(int_out)
    ar = np.abs(mod_out) * np.exp(1j * (np.angle(ar)))
    return ar

def get_frc(ar1, ar2, dx, norm = False, plot=False):
    ar1 = fft(ar1)
    ar2 = fft(ar2)
    
    if norm is True:
        ar1 = setPower(ar1, np.sum(np.abs(ar2) ** 2))

    frc = np.zeros(np.uint32(ar1.shape[0]/2))
    two_sig = np.zeros(np.uint32(ar1.shape[0]/2))
    one_t = np.zeros(np.uint32(ar1.shape[0]/2))
    half_t = np.zeros(np.uint32(ar1.shape[0]/2))
    
    two_sig_lim_reached = False
    one_t_lim_reached = False
    half_t_lim_reached = False
    
    res_r = ar1.shape[0]/2
    for r in range(frc.shape[0]):
        ring_mask = np.abs(genAp(ar1.shape, r+1) * genStop(ar1.shape, r))
        npr = np.sum(ring_mask)
        ar1_r = ar1 * ring_mask
        ar2_r = ar2 * ring_mask
        frc[r] = np.sum(ar1_r * np.conj(ar2_r)) / np.sqrt( np.sum(np.square(np.abs(ar1_r))) * np.sum(np.square(np.abs(ar2_r))) )
        two_sig[r] = 2 / np.sqrt(npr/2)
        one_t[r] = (0.5+(2.4142/np.sqrt(npr))) / (1.5+(1.4142/np.sqrt(npr)))
        half_t[r] = (0.2071+(1.9102/np.sqrt(npr))) / (1.2071+(0.9102/np.sqrt(npr)))
        
        if r>1:
            if frc[r] <= two_sig[r] and two_sig_lim_reached == False:
                two_sig_r = r
                two_sig_lim_reached = True
            if frc[r] <= one_t[r] and not one_t_lim_reached:
                one_t_r = r
                one_t_lim_reached = True
            if frc[r] <= half_t[r] and not half_t_lim_reached:
                half_t_r = r
                half_t_lim_reached = True
    
    #print("dx", dx)
    u = 1/dx
    du = u/ar1.shape[0]
    du /= 1e9   
    two_sig_lim = 1e9/(float(two_sig_r) * float(du) * 1e9)
    one_t_lim = 1e9/(float(one_t_r) * float(du) * 1e9)
    half_t_lim = 1e9/(float(half_t_r) * float(du) * 1e9)
    
    if plot is True:
        x_axis = np.arange(2, ar1.shape[0]/2) * du
        plt.figure()
        plt.plot(x_axis, frc[1:], color = 'k')
        plt.plot(x_axis, two_sig[1:], color = 'r')
        plt.plot(x_axis, one_t[1:], color = 'g')
        plt.plot(x_axis, half_t[1:], color = 'b')
        plt.axvline(x=two_sig_r*du, color='r', linestyle='--', label="Two Sigma")
        plt.axvline(x=one_t_r*du, color='g', linestyle='--', label="One Bit")
        plt.axvline(x=half_t_r*du, color='b', linestyle='--', label="Half Bit")
        plt.ylabel('Ring Correlation')
        plt.xlabel('Reciprocal nms')
        plt.title('FRC\n Two Sigma Resolution = %fnm\n One Bit Resolution = %fnm\n Half Bit Resolution = %fnm' %(two_sig_lim, one_t_lim, half_t_lim))
        plt.legend()
        plt.show()
    
    return two_sig_lim

In [25]:
from skimage.feature import register_translation
from scipy.ndimage import fourier_shift

In [173]:
phase_ideal_bin.shape

(169, 169)

In [72]:
phase_ideal_bin.dtype

dtype('float32')

In [183]:
pix_size = epsic.ptycho_utils.get_json_pixelSize(data_list_of_dicts[6][6]['json_path'])
obj_test = epsic.ptycho_utils.crop_recon_obj(data_list_of_dicts[6][6]['json_path'])

In [184]:
test = np.angle(obj_test)

In [185]:
test.shape

(664, 664)

In [186]:
sh = phase_ideal_bin.shape[0] + 50
c = int(test.shape[0] / 2)
test_crop = test[int(c - sh / 2):int(c + sh / 2), int(c - sh / 2):int(c + sh / 2)]

In [187]:
plt.figure()
plt.imshow(test_crop)

<matplotlib.image.AxesImage at 0x7faf66325d10>

In [170]:
plt.figure()
plt.imshow(phase_ideal_bin)

<matplotlib.image.AxesImage at 0x7faf70032350>

In [174]:
test_crop.shape

(219, 219)

In [163]:
def im_resize(im, new_dim):
    im_fft = np.fft.fftshift(np.fft.fft2(im))
#     plt.figure()
#     plt.imshow(np.real(im_fft))
    sh = im.shape[0]
    if new_dim < sh:
        im_fft_new = im_fft[int(sh / 2) - int(new_dim /2):int(new_dim / 2) + int(sh / 2), int(sh / 2) - int(new_dim /2):int(new_dim / 2) + int(sh / 2)]
        print(im_fft_new.shape)
    else:
        im_fft_new = np.pad(im_fft, int((new_dim - sh)/2), pad_with, padder = 1)
        print(im_fft_new.shape)
#     plt.figure()
#     plt.imshow(np.real(im_fft_crop))
    return np.fft.ifft2(np.fft.fftshift(im_fft_new))

def pad_with(vector, pad_width, iaxis, kwargs):
    pad_value = kwargs.get('padder', 10)
    vector[:pad_width[0]] = pad_value
    vector[-pad_width[1]:] = pad_value

In [80]:
def dist(p1, p2):
    return np.sqrt((p1[0] - p2[0])**2 + (p1[1] - p2[1])**2)

In [171]:
dist((172,121), (79,153))

98.35141076771599

In [172]:
dist((137,89), (46,118))

95.50916186418976

In [177]:
test_crop_manual = test_crop[171 - 138: 169 + 171 - 138 , 122 - 90 : 169 + 122 - 90]
test_crop_manual.shape

(169, 169)

In [179]:
plt.imshow(test_crop_manual)

<matplotlib.image.AxesImage at 0x7faf67e52590>

In [182]:
test = get_frc(test_crop_manual, phase_ideal_bin, pix_size, plot = True)



In [192]:
idy = 3
idx = 5
pix_size = epsic.ptycho_utils.get_json_pixelSize(data_list_of_dicts[idy][idx]['json_path'])
obj_test = epsic.ptycho_utils.crop_recon_obj(data_list_of_dicts[idy][idx]['json_path'])
test = np.angle(obj_test)
sh = phase_ideal_bin.shape[0] + 50
c = int(test.shape[0] / 2)
test_crop = test[int(c - sh / 2):int(c + sh / 2), int(c - sh / 2):int(c + sh / 2)]
test_crop_manual = test_crop[171 - 138: 169 + 171 - 138 , 122 - 90 : 169 + 122 - 90]
test_crop_manual.shape

(169, 169)

In [193]:
test = get_frc(test_crop_manual, phase_ideal_bin, pix_size, plot = True)



In [194]:
test

0.14358392100000003

In [195]:
rows = len(data_list_of_dicts[0])
cols = len(conv_angles)
sampling_factors = []
res_lim = []
#print(rows, cols)
for idx in range(rows):    
      for idy in range(cols): 
            pix_size = epsic.ptycho_utils.get_json_pixelSize(data_list_of_dicts[idy][idx]['json_path'])
            obj_test = epsic.ptycho_utils.crop_recon_obj(data_list_of_dicts[idy][idx]['json_path'])
            test = np.angle(obj_test)
            sh = phase_ideal_bin.shape[0] + 50
            c = int(test.shape[0] / 2)
            test_crop = test[int(c - sh / 2):int(c + sh / 2), int(c - sh / 2):int(c + sh / 2)]
            test_crop_manual = test_crop[171 - 138: 169 + 171 - 138 , 122 - 90 : 169 + 122 - 90]
            res_lim.append(get_frc(test_crop_manual, phase_ideal_bin, pix_size, plot = False))



In [198]:
res_lim =np.asarray(res_lim)
res_lim_rs = res_lim.reshape(8,8)

In [None]:
probe_overlaps

In [199]:
import numpy as np
import matplotlib
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
im = ax.imshow(res_lim_rs, cmap = 'viridis')

# We want to show all ticks...
ax.set_xticks(np.arange(len(conv_angles)))
ax.set_yticks(np.arange(len(probe_overlaps)))
# ... and label them with the respective list entries
ax.set_xticklabels(conv_angles)
ax.set_yticklabels(probe_overlaps)

# Rotate the tick labels and set their alignment.
plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
         rotation_mode="anchor")
plt.xlabel('convergence angle (rad)')
plt.ylabel('probe real space overlap (%)')

# Loop over data dimensions and create text annotations.
for i in range(len(conv_angles)):
    for j in range(len(probe_overlaps)):
        text = ax.text(j, i, np.round(res_lim_rs[i, j],2),
                       ha="center", va="center", color="w")

ax.set_title("Resolution limit")
fig.tight_layout()
plt.show()

In [69]:
plt.imshow(np.angle(obj_test))

<matplotlib.image.AxesImage at 0x7faf5f163910>

In [34]:
plt.figure()
plt.imshow(phase_ideal_bin)

<matplotlib.image.AxesImage at 0x7faf5efb2310>

In [None]:
pix_size = epsic.ptycho_utils.get_json_pixelSize(data_list_of_dicts[6][6]['json_path'])
obj_test = epsic.ptycho_utils.crop_recon_obj(data_list_of_dicts[6][6]['json_path'])
sh = phase_ideal_bin.shape[0]
c = int(obj_test.shape[0] / 2)
obj_test_crop = obj_test[int(c - sh / 2):int(c + sh / 2), int(c - sh / 2):int(c + sh / 2)]
obj_complex = np.ones_like(np.angle(obj_test_crop))* np.exp(1j*np.angle(obj_test_crop))

obj_test2 = epsic.ptycho_utils.crop_recon_obj(data_list_of_dicts[1][1]['json_path'])
sh = phase_ideal_bin.shape[0]
c = int(obj_test2.shape[0] / 2)
obj_test2_crop = obj_test2[int(c - sh / 2):int(c + sh / 2), int(c - sh / 2):int(c + sh / 2)]
obj2_complex = np.ones_like(np.angle(obj_test2_crop))* np.exp(1j*np.angle(obj_test2_crop))

test = get_frc(obj2_complex, obj_complex, pix_size, plot = True)

In [27]:
pix_size = epsic.ptycho_utils.get_json_pixelSize(data_list_of_dicts[6][6]['json_path'])
obj_test = epsic.ptycho_utils.crop_recon_obj(data_list_of_dicts[6][6]['json_path'])
sh = phase_ideal_bin.shape[0]
c = int(obj_test.shape[0] / 2)
obj_test_crop = obj_test[int(c - sh / 2):int(c + sh / 2), int(c - sh / 2):int(c + sh / 2)]
obj_complex = np.ones_like(np.angle(obj_test_crop))* np.exp(1j*np.angle(obj_test_crop))

obj_test2 = epsic.ptycho_utils.crop_recon_obj(data_list_of_dicts[1][1]['json_path'])
sh = phase_ideal_bin.shape[0]
c = int(obj_test2.shape[0] / 2)
obj_test2_crop = obj_test2[int(c - sh / 2):int(c + sh / 2), int(c - sh / 2):int(c + sh / 2)]
obj2_complex = np.ones_like(np.angle(obj_test2_crop))* np.exp(1j*np.angle(obj_test2_crop))

test = get_frc(obj2_complex, obj_complex, pix_size, plot = True)



In [173]:
rows = len(data_list_of_dicts[0])
cols = len(conv_angles)
frc_results = []
shifts = []

for idx in range(rows):    
      for idy in range(cols): 
            try:
                pix_size = epsic.ptycho_utils.get_json_pixelSize(data_list_of_dicts[idy][idx]['json_path'])
                
                obj = epsic.ptycho_utils.crop_recon_obj(data_list_of_dicts[idy][idx]['json_path'])
                img = abs(np.min(np.angle(obj))) + np.angle(obj)
                sh0 = img.shape[0]
                sh1 = phase_ideal_bin.shape[0]
                img_crop = img[int(sh0/2 - sh1/2):int(sh0/2 + sh1/2), int(sh0/2 - sh1/2):int(sh0/2 + sh1/2)]
                shift, error, diffphase = register_translation(phase_ideal_bin, img_crop)
                shifts.append(shift)
                offset_img = fourier_shift(np.fft.fftn(img), shift)
                offset_img = np.real(np.fft.ifftn(offset_img))
                offset_img_crop = offset_img[int(sh0/2 - sh1/2):int(sh0/2 + sh1/2), int(sh0/2 - sh1/2):int(sh0/2 + sh1/2)]
                
                obj_crop = obj[int(sh0/2 - sh1/2):int(sh0/2 + sh1/2), int(sh0/2 - sh1/2):int(sh0/2 + sh1/2)]
#                 obj_complex = np.ones_like(offset_img_crop) + 1j * offset_img_crop
#                 phase_ideal_complex = np.ones_like(phase_ideal_bin) + 1j * phase_ideal_bin
                
                obj_complex = np.ones_like(offset_img_crop) * np.exp(1j * offset_img_crop)
                phase_ideal_complex = np.ones_like(phase_ideal_bin) * np.exp(1j * phase_ideal_bin)

#                 frc_results.append(get_frc(obj_complex, phase_ideal_complex, pix_size))
                frc_results.append(get_frc(obj_crop, phase_ideal_complex, pix_size))

            except IndexError:
                pass




In [174]:
offset_img_crop

array([[1.21816886, 1.29672846, 1.14799272, ..., 1.09424785, 1.16532647,
        1.00256828],
       [1.16891784, 1.20498251, 1.05230546, ..., 1.00813638, 1.028826  ,
        0.88564277],
       [1.09167513, 1.10325467, 0.99511174, ..., 0.89040908, 0.8612156 ,
        0.78291933],
       ...,
       [1.30920778, 1.42899643, 1.34151094, ..., 1.34447889, 1.33461289,
        1.11353762],
       [1.24552419, 1.34929685, 1.28420547, ..., 1.2192526 , 1.19232808,
        1.02034627],
       [1.11235409, 1.19783328, 1.1804829 , ..., 1.07721939, 1.07903651,
        0.9888357 ]])

In [175]:
np.angle(obj_complex)

array([[1.21816886, 1.29672846, 1.14799272, ..., 1.09424785, 1.16532647,
        1.00256828],
       [1.16891784, 1.20498251, 1.05230546, ..., 1.00813638, 1.028826  ,
        0.88564277],
       [1.09167513, 1.10325467, 0.99511174, ..., 0.89040908, 0.8612156 ,
        0.78291933],
       ...,
       [1.30920778, 1.42899643, 1.34151094, ..., 1.34447889, 1.33461289,
        1.11353762],
       [1.24552419, 1.34929685, 1.28420547, ..., 1.2192526 , 1.19232808,
        1.02034627],
       [1.11235409, 1.19783328, 1.1804829 , ..., 1.07721939, 1.07903651,
        0.9888357 ]])

In [176]:
frc_results = np.asarray(frc_results)
frc_results = frc_results.reshape(8,8)

In [177]:
probe_overlaps = [0, 5, 15, 35, 60, 70, 80, 90]
conv_angles = [0.016, 0.020, 0.024, 0.030, 0.040, 0.050, 0.064, 0.084]

In [178]:
import numpy as np
import matplotlib
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.imshow(frc_results, cmap = 'RdBu')

plt.xlabel('convergence angle (rad)')
plt.ylabel('probe real space overlap (%)')

ax.set_xticks(np.arange(len(conv_angles)))
ax.set_yticks(np.arange(len(probe_overlaps)))
ax.set_xticklabels(conv_angles)
ax.set_yticklabels(probe_overlaps)


plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
         rotation_mode="anchor")

for i in range(len(conv_angles)):
    for j in range(len(probe_overlaps)):
        text = ax.text(j, i, np.round(frc_results[i, j],2),
                       ha="center", va="center", color="w")

ax.set_title("FRC values")

fig.tight_layout()
plt.show()

In [233]:
idx = 6
idy = 6

pix_size = epsic.ptycho_utils.get_json_pixelSize(data_list_of_dicts[idy][idx]['json_path'])
                
obj = epsic.ptycho_utils.crop_recon_obj(data_list_of_dicts[idy][idx]['json_path'])
img = np.angle(obj) - np.min(np.angle(obj))
sh0 = img.shape[0]
sh1 = phase_ideal_bin.shape[0]
img_crop = img[int(sh0/2 - sh1/2):int(sh0/2 + sh1/2), int(sh0/2 - sh1/2):int(sh0/2 + sh1/2)]
shift, error, diffphase = register_translation(phase_ideal_bin, img_crop)
shifts.append(shift)
offset_img = fourier_shift(np.fft.fftn(img), shift)
offset_img = np.angle(np.fft.ifftn(offset_img))
offset_img_crop = offset_img[int(sh0/2 - sh1/2):int(sh0/2 + sh1/2), int(sh0/2 - sh1/2):int(sh0/2 + sh1/2)]

# obj_complex = np.ones_like(offset_img_crop) + 1j*offset_img_crop
# phase_ideal_complex = np.ones_like(phase_ideal_bin) + 1j*phase_ideal_bin


obj_complex = obj[int(sh0/2 - sh1/2):int(sh0/2 + sh1/2), int(sh0/2 - sh1/2):int(sh0/2 + sh1/2)]

                
# obj_complex = np.ones_like(offset_img_crop) * np.exp(1j * offset_img_crop)
phase_ideal_complex = np.ones_like(phase_ideal_bin) * np.exp(1j * phase_ideal_bin)

test = get_frc(obj_crop, phase_ideal_complex, pix_size, norm= True, plot=True)



In [227]:
pix_size

1.6992180000000002e-11

In [234]:

fig, ax = plt.subplots(3,2)
ax[0,0].imshow(np.angle(obj_complex))
ax[0,1].imshow(np.angle(phase_ideal_complex))
ax[1,0].imshow(np.abs(obj_complex))
ax[1,1].imshow(np.abs(phase_ideal_complex))
obj_fft = np.fft.fft2(obj_complex)
ideal_fft = np.fft.fft2(phase_ideal_complex)
ax[2,0].imshow(np.log(np.abs(np.fft.fftshift(obj_fft))))
ax[2,1].imshow(np.log(np.abs(np.fft.fftshift(ideal_fft))))

<matplotlib.image.AxesImage at 0x7fd5b386c5d0>

In [219]:
plt.close('all')

In [None]:
epsic.ptycho_utils

# Evaluate by radial profiles

In [15]:
type(phase_ideal_bin)

numpy.ndarray

In [16]:
phase_ideal_bin_fft = epsic.ptycho_utils.get_fft(phase_ideal_bin, crop = None, apply_hann=True)
sh = phase_ideal_bin_fft.shape[0]
ideal_profile = epsic.radial_profile.radial_profile(phase_ideal_bin_fft, center = (sh//2,sh//2))

In [17]:
plt.figure()
plt.plot(ideal_profile)

[<matplotlib.lines.Line2D at 0x7f68eed7a190>]

In [18]:

# fft of obj phase 

rows = len(data_list_of_dicts[0])
cols = len(conv_angles)
fft_line_profiles = []
#print(rows, cols)
# fig, axs = plt.subplots(nrows=rows, ncols=cols, sharex=True, sharey=True,figsize=(8, 11))
for idx in range(rows):    
      for idy in range(cols): 
            try:
                obj = epsic.ptycho_utils.crop_recon_obj(data_list_of_dicts[idy][idx]['json_path'])
                #print(idx, idy)
                obj_fft = epsic.ptycho_utils.get_fft(np.angle(obj), crop = None, apply_hann=True)
                sh = obj_fft.shape[0]
                fft_prof = epsic.radial_profile.radial_profile(obj_fft, center = (sh//2,sh//2))
                fft_line_profiles.append(fft_prof)

            except IndexError:
                pass
#plt.tight_layout()      

In [19]:
fft_line_profiles_np = np.asarray(fft_line_profiles)

In [20]:
fft_line_profiles_reshaped = fft_line_profiles_np.reshape(8,8)

In [21]:
fft_line_profiles_reshaped.shape

(8, 8)

In [27]:
probe_overlaps = [0, 5, 15, 35, 60, 70, 80, 90]
conv_angles = [0.016, 0.020, 0.024, 0.030, 0.040, 0.050, 0.064, 0.084]

In [48]:
plt.figure()
plt.plot(ideal_profile, label='ideal')
for i in range(8):
    plt.plot(fft_line_profiles_reshaped[5][i], label = 'probe overlap '+ str(probe_overlaps[i]))

plt.legend()

<matplotlib.legend.Legend at 0x7f68e71c8f90>

In [50]:
plt.figure()
plt.plot(ideal_profile, label='ideal')
for i in range(8):
    plt.plot(fft_line_profiles_reshaped[i][5], label = 'conv angle '+ str(conv_angles[i]))

plt.legend()

<matplotlib.legend.Legend at 0x7f68e6bac1d0>

In [55]:
import matplotlib.pyplot as plt
import numpy as np
from scipy.fftpack import fft2, fftshift
from skimage import img_as_float
from skimage.color import rgb2gray
from skimage.data import astronaut
from skimage.filters import window

image = img_as_float(rgb2gray(astronaut()))

wimage = image * window('hann', image.shape)

image_f = np.abs(fftshift(fft2(image)))
wimage_f = np.abs(fftshift(fft2(wimage)))

fig, axes = plt.subplots(2, 2, figsize=(8, 8))
ax = axes.ravel()
ax[0].set_title("Original image")
ax[0].imshow(image, cmap='gray')
ax[1].set_title("Windowed image")
ax[1].imshow(wimage, cmap='gray')
ax[2].set_title("Original FFT (frequency)")
ax[2].imshow(np.log(image_f), cmap='magma')
ax[3].set_title("Window + FFT (frequency)")
ax[3].imshow(np.log(wimage_f), cmap='magma')
plt.show()

ImportError: cannot import name 'window' from 'skimage.filters' (/dls_sw/apps/python/anaconda/4.6.14/64/envs/epsic3.7/lib/python3.7/site-packages/skimage/filters/__init__.py)

In [74]:
ssim_scores = []
nrmse_scores = []
mse_scores = []
case_num = 0
convergence = []
step_size = []
for i in range(len(conv_angles)):
    for n in range(len(data_list_of_dicts[i])):
        obj = epsic.ptycho_utils.crop_recon_obj(data_list_of_dicts[i][n]['json_path'])
        convergence.append(data_list_of_dicts[i][n]['process']['common']['probe']['convergence'])
        step_size.append(data_list_of_dicts[i][n]['process']['common']['scan']['dR'][0])
        img = abs(np.min(np.angle(obj))) + np.angle(obj)
        sh0 = img.shape[0]
        sh1 = phase_ideal_bin.shape[0]
        img_crop = img[int(sh0/2 - sh1/2):int(sh0/2 + sh1/2), int(sh0/2 - sh1/2):int(sh0/2 + sh1/2)]
        shift, error, diffphase = register_translation(phase_ideal_bin, img_crop)
        offset_img = fourier_shift(np.fft.fftn(img_crop), shift)
        offset_img = np.fft.ifftn(offset_img)
        ssim_score = ssim(phase_ideal_bin, offset_img)
        nrmse_score = nrmse(phase_ideal_bin, offset_img, normalization= 'min-max')
        mse_score = mse(phase_ideal_bin, offset_img)
        ssim_scores.append(ssim_score)
        nrmse_scores.append(nrmse_score)
        mse_scores.append(mse_score)

convergence = np.asarray(convergence)
fig, ax = plt.subplots(1,2,figsize=(11,4))
for angle in conv_angles:
    inds = np.where(convergence==angle)
    _steps = np.take(step_size, inds[0])
    _ssim = np.take(ssim_scores, inds[0])
    _nrmse = np.take(nrmse_scores, inds[0])
    _mse = np.take(mse_scores, inds[0])

    ax[0].plot(_steps, _ssim,label = str(angle))
    ax[0].set_title('ssim scores')
    ax[0].set_xlabel('step_size(m)')
    ax[1].plot(_steps, _nrmse,label = str(angle))
    ax[1].set_title('nrmse scores')
    ax[1].set_xlabel('step_size(m)')
    ax[0].legend()
    ax[1].legend()
fig.suptitle('ssim and nrmse scores as function of convergence angle')

  im2 = im2.astype(np.float64)
  ret = umr_sum(arr, axis, dtype, out, keepdims)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)


Text(0.5, 0.98, 'ssim and nrmse scores as function of convergence angle')

In [91]:
# To plot objects as a function of convergence angles
conv_angles = np.asarray(conv_angles)
angle = 0.084
ind = np.where(conv_angles==angle)
print(ind)
objects = []
step_sizes = []
sh1 = 200
for json_dict in data_list_of_dicts[ind[0][0]]:
    step_size = json_dict['process']['common']['scan']['dR'][0]
    step_sizes.append(step_size)
    obj = epsic.ptycho_utils.crop_recon_obj(json_dict['json_path'])
    img = abs(np.min(np.angle(obj))) + np.angle(obj)
    sh0 = img.shape[0]
    img_crop = img[int(sh0/2 - sh1/2):int(sh0/2 + sh1/2), int(sh0/2 - sh1/2):int(sh0/2 + sh1/2)]
    objects.append(img_crop)
fig, ax = plt.subplots(4,2, figsize=(8,11))
n_cols = 2
for i, step in enumerate(step_sizes):

    ax[i // n_cols, i%n_cols].imshow(objects[i], cmap = 'magma_r')
    ax[i // n_cols, i%n_cols].set_title('%2.3f'%(1e10*step)+' $\AA$' )

    ax[int(i / n_cols), i%n_cols].set_xticks([])
    ax[int(i / n_cols), i%n_cols].set_yticks([])
fig.suptitle('probe convergence ' + str(angle)+ ' rad')
None

(array([2]),)


In [79]:
conv_angles

array([0.05 , 0.064, 0.084])

In [81]:
# To plot errors function of convergence angles
conv_angles = np.asarray(conv_angles)
angle = 0.050
ind = np.where(conv_angles==angle)

errors = []
step_sizes = []
for json_dict in data_list_of_dicts[ind[0][0]]:
    step_size = json_dict['process']['common']['scan']['dR'][0]
    step_sizes.append(step_size)
    name, ext = os.path.splitext(json_dict['json_path'])
    error = epsic.ptycho_utils.get_error(name + '.hdf')
    errors.append(error)
fig, ax = plt.subplots(4,2, figsize=(8,12))
for i, step in enumerate(step_sizes):

    ax[int(i / 2), i%2].plot(errors[i])
    ax[int(i / 2), i%2].set_title('%2.3f'%(1e10*step)+'$\AA$' )

fig.suptitle('error vs iter num ' + str(angle)+ 'rad')
None

# Evaluation using atomap

In [15]:
%matplotlib qt
import atomap.api as am

In [16]:
ref_hs = hs.signals.Signal2D(phase_ideal_bin)
ref_hs.plot()

In [17]:
# Getting the reference positions
ref_hs = hs.signals.Signal2D(phase_ideal_bin)
# ref_hs_crop = ref_hs.isig[5:70,5:70]
ref_atom_positions = am.get_atom_positions(ref_hs, separation=4)
ref_sublattice = am.Sublattice(ref_atom_positions, image=ref_hs.data)
ref_sublattice.find_nearest_neighbors()
ref_sublattice.refine_atom_positions_using_center_of_mass()
ref_sublattice.refine_atom_positions_using_2d_gaussian()
ref_atom_list = ref_sublattice.atom_list

Center of mass: 100%|██████████| 282/282 [00:00<00:00, 5349.39it/s]
Gaussian fitting: 100%|██████████| 282/282 [00:05<00:00, 48.75it/s]


In [18]:
ref_sublattice.plot()

In [19]:
# This is not returning the same value as e!
ref_atom_list[0].amplitude_gaussian 

5.891599456829207

In [20]:
ref_atom_list[0].pixel_x

134.14714481805655

In [21]:
ref_atom_list[0].sigma_average

1.2057729108139426

In [22]:
ref_coord = []
for i in range(len(ref_atom_list)):
    
    ref_coord.append(list((ref_atom_list[i].pixel_x,
                    ref_atom_list[i].pixel_y,
                    ref_atom_list[i].sigma_average,
                    ref_atom_list[i].amplitude_gaussian)))

ref_coord = np.asarray(ref_coord)
print('Number of atoms in reference: ', ref_coord.shape[0])


Number of atoms in reference:  282


In [23]:
np.save('ground_truth_positions', ref_coord)

In [24]:
# Getting the experimental positions and comparison


In [25]:
# Calculating the sampling factor directly from the json dicts
rows = len(data_list_of_dicts[0])
cols = len(conv_angles)
sampling_factors = []
#print(rows, cols)
for idx in range(rows):    
      for idy in range(cols): 
            l = epsic.ptycho_utils.e_lambda(data_list_of_dicts[idy][idx]['process']['common']['source']['energy'][0])
            CL = data_list_of_dicts[idy][idx]['process']['common']['detector']['distance']
            det_pix_array = data_list_of_dicts[idy][idx]['process']['common']['detector']['crop']
            det_pitch = data_list_of_dicts[idy][idx]['process']['common']['detector']['pix_pitch'][0]
            num_probe_pos = data_list_of_dicts[idy][idx]['process']['common']['scan']['N'][0]
            probe_semi_angle = data_list_of_dicts[idy][idx]['process']['common']['probe']['convergence']
            probe_step_size = data_list_of_dicts[idy][idx]['process']['common']['scan']['dR'][0]
            defocus = data_list_of_dicts[idy][idx]['experiment']['optics']['lens']['defocus'][0]
            
            recon_pix_size = l * CL / (det_pix_array[0] * det_pitch)
            probe_rad = epsic.sim_utils.calc_probe_size(recon_pix_size, det_pix_array, l, defocus, probe_semi_angle, plot_probe=False)
            s = epsic.ptycho_utils.get_sampling_factor( recon_pix_size * det_pix_array[0], 2 * probe_rad, num_probe_pos, probe_step_size)
            sampling_factors.append(s)
            
            data_list_of_dicts[idy][idx]['sampling_factor'] = s

In [26]:
rows = len(data_list_of_dicts[0])
cols = len(conv_angles)

for idx in range(rows):    
      for idy in range(cols): 
            obj = epsic.ptycho_utils.crop_recon_obj(data_list_of_dicts[idy][idx]['json_path'])
            img = abs(np.min(np.angle(obj))) + np.angle(obj)
            sh0 = img.shape[0]
            sh1 = phase_ideal_bin.shape[0]
            img_crop = img[int(sh0/2 - sh1/2):int(sh0/2 + sh1/2), int(sh0/2 - sh1/2):int(sh0/2 + sh1/2)]
            shift, error, diffphase = register_translation(phase_ideal_bin, img_crop)
            offset_img = fourier_shift(np.fft.fftn(img_crop), shift)
            offset_img = np.real(np.fft.ifftn(offset_img))

            img_hs = hs.signals.Signal2D(offset_img)
#             img_hs_crop = img_hs.isig[5:70,5:70]


            exp_positions = am.get_atom_positions(img_hs, separation=4)
            exp_sublattice = am.Sublattice(exp_positions, image=img_hs.data)

            exp_sublattice.find_nearest_neighbors()
            exp_sublattice.refine_atom_positions_using_center_of_mass()
            exp_sublattice.refine_atom_positions_using_2d_gaussian()

            exp_atom_list = exp_sublattice.atom_list


            exp_coord = []
            for i in range(len(exp_atom_list)):

                exp_coord.append(list((exp_atom_list[i].pixel_x,
                                exp_atom_list[i].pixel_y,
                                exp_atom_list[i].sigma_average,
                                exp_atom_list[i].amplitude_gaussian)))

            exp_coord = np.asarray(exp_coord)
            data_list_of_dicts[idy][idx]['atom_pos'] = exp_coord

          

Center of mass: 100%|██████████| 271/271 [00:00<00:00, 5385.36it/s]
Gaussian fitting: 100%|██████████| 271/271 [00:08<00:00, 30.82it/s]
Center of mass: 100%|██████████| 278/278 [00:00<00:00, 5469.07it/s]
Gaussian fitting: 100%|██████████| 278/278 [00:07<00:00, 35.77it/s]
Center of mass: 100%|██████████| 266/266 [00:00<00:00, 6357.94it/s]
Gaussian fitting: 100%|██████████| 266/266 [00:09<00:00, 28.93it/s]
Center of mass: 100%|██████████| 276/276 [00:00<00:00, 6323.76it/s]
Gaussian fitting: 100%|██████████| 276/276 [00:08<00:00, 34.04it/s]
Center of mass: 100%|██████████| 301/301 [00:00<00:00, 5476.64it/s]
Gaussian fitting: 100%|██████████| 301/301 [00:09<00:00, 31.28it/s]
Center of mass: 100%|██████████| 272/272 [00:00<00:00, 6547.13it/s]
Gaussian fitting: 100%|██████████| 272/272 [00:08<00:00, 33.63it/s]
Center of mass: 100%|██████████| 296/296 [00:00<00:00, 6178.32it/s]
Gaussian fitting: 100%|██████████| 296/296 [00:09<00:00, 31.94it/s]
Center of mass: 100%|██████████| 285/285 [00:00<

In [27]:
# identified atom positions versus the known positions
rows = len(data_list_of_dicts[0])
cols = len(conv_angles)
#print(rows, cols)
# fig, axs = plt.subplots(nrows=rows, ncols=cols, sharex=True, sharey=True,figsize=(8, 11))
fig, axs = plt.subplots(nrows=rows, ncols=cols, figsize=(8, 11))

for idx in range(rows):    
      for idy in range(cols): 
            try:
                j_dict = epsic.sim_utils.json_to_dict_sim(data_list_of_dicts[idy][idx]['json_path'])
                exp_coord = data_list_of_dicts[idy][idx]['atom_pos']
                
                missed_atoms = ref_coord.shape[0] - exp_coord.shape[0]
                s = data_list_of_dicts[idy][idx]['sampling_factor']
                axs[idx,idy].scatter(ref_coord[:,0], ref_coord[:,1], s = 1, c = 'g')
                axs[idx,idy].scatter(exp_coord[:,0], exp_coord[:,1], s = 1, c = 'r')
                axs[idx,idy].set_xticks([])
                axs[idx,idy].set_yticks([])
                axs[idx,idy].set_title(str(missed_atoms) + '       ' + str(np.round(s,1)), color = 'black', fontsize = 7)

            except IndexError:
                pass
fig.tight_layout(pad = 1.0)            

In [28]:
plt.figure()
plt.plot(sampling_factors)

[<matplotlib.lines.Line2D at 0x7f7b544be390>]

In [29]:
# plotting individual dataset
plt.figure()
plt.scatter(ref_coord[:,0], ref_coord[:,1], s = 5, c = 'g')
exp_pos = data_list_of_dicts[7][7]['atom_pos']
plt.scatter(exp_pos[:,0], exp_pos[:,1], s = 5, c = 'r')

<matplotlib.collections.PathCollection at 0x7f7b47aad650>

In [32]:
# Saving the fitted data into hdf file
results_folder = '/dls/science/groups/e02/Mohsen/code/Git_Repos/Staff-notebooks/ptyREX_sim_matrix/fitting_results'

rows = len(data_list_of_dicts[0])
cols = len(conv_angles)
for idx in range(rows):    
      for idy in range(cols): 
            h5_file_name = os.path.join(results_folder, data_list_of_dicts[idy][idx]['json_path'].split('/')[-2] + '.h5')
            save_dict_to_hdf5(data_list_of_dicts[idy][idx], h5_file_name)

In [33]:
dd = load_dict_from_hdf5('/dls/science/groups/e02/Mohsen/code/Git_Repos/Staff-notebooks/ptyREX_sim_matrix/fitting_results/graphene_512_64matrix/Graphene_defect_8.0mrad_281.75A_def_2.18A_step_size.h5')


In [34]:
dd

{'atom_pos': array([[121.41388905, 163.86922113,   1.68823719,   0.72288737],
        [113.00532043, 164.16584041,   1.70654456,   1.05055257],
        [ 96.50912662, 164.13181785,   1.97042467,   1.20076947],
        ...,
        [ 98.23902267,  10.10589817,   1.55754718,   0.49673928],
        [ 89.56400168,  10.0962495 ,   1.78996969,   0.84017837],
        [ 64.93901982,  10.33318438,   1.76751978,   0.82241142]]),
 'base_dir': '/dls/e02/data/2020/cm26481-1/processing/pty_simulated_data_MD/sim_matrix_ptyREX_4June2020_512pixArray/Graphene_defect_8.0mrad_281.75A_def_2.18A_step_size',
 'experiment': {'data': {'data_path': '/dls/e02/data/2020/cm26481-1/processing/pty_simulated_data_MD/sim_matrix_ptyREX_4June2020_512pixArray/Graphene_defect_8.0mrad_281.75A_def_2.18A_step_size/Graphene_defect_8.0mrad_281.75A_def_2.18A_step_size.h5',
   'dead_pixel_flag': 0,
   'flat_field_flag': 0,
   'key': '4DSTEM_simulation/data/datacubes/hdose_noisy_data',
   'load_flag': 1,
   'meta_type': 'hdf'},
 

In [31]:
# Saving the atomic fitting data
import numpy as np
import h5py
import os
def save_dict_to_hdf5(dic, filename):
    """
    ....
    """
    if os.path.exists(filename):
        with h5py.File(filename, 'a') as h5file:
            recursively_save_dict_contents_to_group(h5file, '/', dic)
    else:
        
        with h5py.File(filename, 'w') as h5file:
            recursively_save_dict_contents_to_group(h5file, '/', dic)

def recursively_save_dict_contents_to_group(h5file, path, dic):
    """
    ....
    """
    for key, item in dic.items():
        if isinstance(item, (np.ndarray, list, float, int, str)):
            h5file[path + key] = item
        elif isinstance(item, dict):
            recursively_save_dict_contents_to_group(h5file, path + key + '/', item)

            
def load_dict_from_hdf5(filename):
    """
    ....
    """
    with h5py.File(filename, 'r') as h5file:
        return recursively_load_dict_contents_from_group(h5file, '/')

def recursively_load_dict_contents_from_group(h5file, path):
    """
    ....
    """
    ans = {}
    for key, item in h5file[path].items():
        if isinstance(item, h5py._hl.dataset.Dataset):
            ans[key] = item[()]
        elif isinstance(item, h5py._hl.group.Group):
            ans[key] = recursively_load_dict_contents_from_group(h5file, path + key + '/')
    return ans

# Trials on single datasets

In [25]:
ref_hs = hs.signals.Signal2D(phase_ideal_bin)
ref_hs_crop = ref_hs.isig[5:70,5:70]
ref_hs_crop.plot()

In [36]:
s_peaks = am.get_feature_separation(ref_hs_crop, separation_range=(2, 20))

100%|██████████| 18/18 [00:00<00:00, 328.02it/s]
100%|██████████| 280/280 [00:00<00:00, 3478.03it/s]


In [37]:
s_peaks.plot()

In [236]:
ref_atom_positions = am.get_atom_positions(ref_hs_crop, separation=2)

In [237]:
ref_sublattice = am.Sublattice(ref_atom_positions, image=ref_hs_crop.data)

In [238]:
ref_sublattice

<Sublattice,  (atoms:149,planes:0)>

In [239]:
ref_sublattice.find_nearest_neighbors()
ref_sublattice.refine_atom_positions_using_center_of_mass()
ref_sublattice.refine_atom_positions_using_2d_gaussian()

Center of mass: 100%|██████████| 149/149 [00:00<00:00, 1791.21it/s]
Gaussian fitting: 100%|██████████| 149/149 [00:10<00:00, 14.47it/s]


In [240]:
ref_sublattice.get_position_history().plot()

100%|██████████| 149/149 [00:00<00:00, 2943.29it/s]


In [241]:
ref_atom_list = ref_sublattice.atom_list

In [242]:
ref_atom_list[0]

<Atom_Position,  (x:61.7,y:61.1,sx:0.7,sy:0.6,r:1.9,e:1.2)>

In [176]:
data_list_of_dicts[5][7]

{'process': {'gpu_flag': 1,
  'save_interval': 10,
  'PIE': {'iterations': 2000},
  'common': {'source': {'flux': -1,
    'energy': [80000],
    'radiation': 'electron'},
   'detector': {'bin': [1, 1],
    'distance': 0.11470962889594202,
    'orientation': '00',
    'min_max': [0, 1000000],
    'crop': [512, 512],
    'mask_flag': 0,
    'optic_axis': [256.0, 256.0],
    'pix_pitch': [5.5e-05, 5.5e-05]},
   'probe': {'distance': -1,
    'aperture_size': 0.0028677407223985504,
    'focal_dist': -1,
    'load_flag': 0,
    'diffuser': 0,
    'convergence': 0.05,
    'aperture_shape': 'circ'},
   'object': {'load_flag': 0},
   'scan': {'fast_axis': 1,
    'orientation': '00',
    'type': 'tv',
    'N': [13, 13],
    'load_flag': 0,
    'rotation': 0,
    'dR': [2.1737000000000002e-10, 2.1737000000000002e-10]}},
  'save_dir': '/dls/e02/data/2020/cm26481-1/processing/pty_simulated_data_MD/sim_matrix_ptyREX_4June2020_512pixArray/Graphene_defect_25.0mrad_181.36A_def_2.17A_step_size',
  'core

In [218]:
plt.figure()
plt.imshow(np.angle(test_obj))

<matplotlib.image.AxesImage at 0x7f69c1ab6510>

In [271]:
test_obj = crop_recon_obj(data_list_of_dicts[0][6]['json_path'])

img = abs(np.min(np.angle(test_obj))) + np.angle(test_obj)
sh0 = img.shape[0]
sh1 = phase_ideal_bin.shape[0]
img_crop = img[int(sh0/2 - sh1/2):int(sh0/2 + sh1/2), int(sh0/2 - sh1/2):int(sh0/2 + sh1/2)]
shift, error, diffphase = register_translation(phase_ideal_bin, img_crop)
offset_img = fourier_shift(np.fft.fftn(img_crop), shift)
offset_img = np.real(np.fft.ifftn(offset_img))

fig, ax = plt.subplots(1,2)
ax[0].imshow(img_crop)
ax[1].imshow(np.real(offset_img))

img_hs = hs.signals.Signal2D(offset_img)
img_hs_crop = img_hs.isig[5:70,5:70]
img_hs_crop.plot()

In [231]:
img_hs.plot()

In [272]:
s_peaks = am.get_feature_separation(img_hs, separation_range=(2, 20))
s_peaks.plot()

100%|██████████| 18/18 [00:00<00:00, 565.63it/s]
100%|██████████| 188/188 [00:00<00:00, 3748.05it/s]


In [273]:
test_positions = am.get_atom_positions(img_hs_crop, separation=2)

In [274]:
test_sublattice = am.Sublattice(test_positions, image=img_hs_crop.data)

In [275]:
test_sublattice

<Sublattice,  (atoms:104,planes:0)>

In [276]:
test_sublattice.find_nearest_neighbors()
test_sublattice.refine_atom_positions_using_center_of_mass()
test_sublattice.refine_atom_positions_using_2d_gaussian()

Center of mass: 100%|██████████| 104/104 [00:00<00:00, 3044.89it/s]
Gaussian fitting: 100%|██████████| 104/104 [00:13<00:00,  7.64it/s]


In [184]:
test_sublattice.get_position_history().plot()

100%|██████████| 239/239 [00:00<00:00, 1313.30it/s]


In [65]:
test_sublattice.plot_ellipticity_map()

In [277]:
test_atom_list = test_sublattice.atom_list

In [278]:
test_sublattice.plot()

In [279]:
test_atom_list[0]

<Atom_Position,  (x:38.6,y:61.8,sx:1.6,sy:2.2,r:0.4,e:1.4)>

In [280]:
ref_coord = []
test_coord = []
for i in range(len(ref_atom_list)):
    ref_coord.append(list(ref_atom_list[i].get_pixel_position()))
for i in range(len(test_atom_list)):
    test_coord.append(list(test_atom_list[i].get_pixel_position()))
ref_coord = np.asarray(ref_coord)
test_coord = np.asarray(test_coord)

In [281]:
plt.figure()
plt.scatter(ref_coord[:,0], ref_coord[:,1])
plt.scatter(test_coord[:,0], test_coord[:,1])

<matplotlib.collections.PathCollection at 0x7f69c0fb71d0>

In [282]:
def atom_dist(x1, y1, x2, y2):
    return(np.sqrt((y2-y1)**2+(x2-x1)**2))


def check_atom_found(ref_atom, exp_list, tol):
    distance_check = []
    for exp_atom in exp_list:
        distance_check.append(atom_dist(ref_atom[0], ref_atom[1], exp_atom[0], exp_atom[1]) < tol)
    return any(distance_check)

In [283]:
len(ref_coord)

149

In [284]:
len(test_coord)

104

In [293]:
ref_to_compare = []
ind_to_del = []
match_count = 0
for i, atom in enumerate(ref_coord):
    if check_atom_found(atom, test_coord, 2.28):
        match_count += 1
    else: 
        ind_to_del.append(i)
ref_to_compare = np.delete(ref_coord, ind_to_del, 0)
print('number of atoms missing: ', len(ref_coord) - match_count)

number of atoms missing:  45


In [294]:
match_count

104

In [295]:
len(ref_coord) - len(test_coord)

45

In [296]:
len(ref_to_compare)

104

In [297]:
mse = ((ref_to_compare - test_coord)**2).mean()
print(mse)

279.6274438286989
