In [1]:
%matplotlib qt
# %matplotlib inline

In [2]:
import numpy as np
import os
import h5py
import sys
import matplotlib.pyplot as plt

In [3]:
sys.path.append('/dls/science/groups/e02/Mohsen/code/Git_Repos/Staff-notebooks/ptyREX_sim_matrix')

In [4]:
from sim_utils import *
from sim_utils import _sigma
from recon_utils import *

In [5]:
matrix_path = '/dls/e02/data/2020/cm26481-1/processing/pty_simulated_data_MD/sim_matrix_ptyREX_25April2020'
# run_ID: 26042020 - 100 iterations 
# run_ID: 2000iter
# json_files = get_ptyREX_recon_list(matrix_path, run_id = '26042020') 
json_files = get_ptyREX_recon_list(matrix_path, run_id = '2000iter')

In [6]:
len(json_files)

112

In [7]:
# sorting the json files
conv_angles = [0.016, 0.020, 0.024, 0.030, 0.040, 0.050, 0.064, 0.084]
diff_overlap = [-22.5, 1.99, 18.33, 34.70, 51.00, 60.80, 69.37, 76.67]
real_probe_overlap = [-5, 2, 15, 35, 60, 70, 80]


focus_probes_dict = []
defocus_probes_dict = []
for file in json_files:
    j_dict = json_to_dict(file)
    if j_dict['experiment']['optics']['lens']['defocus'][0] == 0.00:
        for i, angle in enumerate(conv_angles):
            focus_probes_dict.append([])
            if j_dict['process']['common']['probe']['convergence'] == angle:
                focus_probes_dict[i].append(j_dict)
    else:
        for i, angle in enumerate(conv_angles):
            defocus_probes_dict.append([])
            if j_dict['process']['common']['probe']['convergence'] == angle:
                defocus_probes_dict[i].append(j_dict)
                
for i, angle in enumerate(conv_angles):
    #print(angle)
    focus_probes_dict[i].sort(key=lambda e: e['process']['common']['scan']['dR'][0], reverse = True)
    defocus_probes_dict[i].sort(key=lambda e: e['process']['common']['scan']['dR'][0], reverse = True)

In [40]:
# focused probe cases - fft only
rows = len(defocus_probes_dict[0])
cols = len(conv_angles)
#print(rows, cols)
fig, axs = plt.subplots(nrows=rows, ncols=cols, sharex=True, sharey=True,figsize=(8, 12))
#fig.suptitle('Focused probe conditions', fontsize = 12)

for idx in range(rows):    
      for idy in range(cols): 
            obj = crop_recon_obj(focus_probes_dict[idy][idx]['json_path'])
            #print(idx, idy)
            j_dict = j_dict = json_to_dict(focus_probes_dict[idy][idx]['json_path'])
            axs[idx,idy].imshow(np.log(abs(get_fft(obj))), cmap = 'viridis')
#             axs[idx,idy].imshow(abs(get_fft(obj)), cmap = 'viridis')

            axs[idx,idy].set_xticks([])
            axs[idx,idy].set_yticks([])
            if idx == 0:
                axs[idx,idy].set_title(str(1e3 * j_dict['process']['common']['probe']['convergence']) + 'mrad \n' + \
                    'step size ($\AA$) %2.2f'%(1e10*j_dict['process']['common']['scan']['dR'][0]) + '\n' + \
                                       str(diff_overlap[idy]) + '%overlap', color = 'red', fontsize = 8)
            elif idy == 0:
                axs[idx,idy].set_title('step size ($\AA$) %2.3f'%(1e10*j_dict['process']['common']['scan']['dR'][0]), color = 'black', fontsize = 5)
                axs[idx, idy].set_ylabel(str(real_probe_overlap[idx]) + '%overlap', color = 'red', fontsize = 8)
            else:
                axs[idx,idy].set_title('%2.3f'%(1e10*j_dict['process']['common']['scan']['dR'][0]), color = 'black', fontsize = 5)
#plt.tight_layout()

In [41]:
# focused probe cases - obj phase only
rows = len(defocus_probes_dict[0])
cols = len(conv_angles)
#print(rows, cols)
fig, axs = plt.subplots(nrows=rows, ncols=cols, sharex=True, sharey=True,figsize=(18, 10))
#fig.suptitle('Focused probe conditions', fontsize = 12)

for idx in range(rows):    
      for idy in range(cols): 
            obj = crop_recon_obj(focus_probes_dict[idy][idx]['json_path'])
            #print(idx, idy)
            j_dict = j_dict = json_to_dict(focus_probes_dict[idy][idx]['json_path'])
            
            img = abs(np.min(np.angle(obj))) + np.angle(obj)
            sh0 = img.shape[0]
            sh1 = 82
            img_crop = img[int(sh0/2 - sh1/2):int(sh0/2 + sh1/2), int(sh0/2 - sh1/2):int(sh0/2 + sh1/2)]
            
            axs[idx,idy].imshow(img_crop, cmap = 'magma_r')
#             axs[idx,idy].imshow(abs(get_fft(obj)), cmap = 'viridis')

            axs[idx,idy].set_xticks([])
            axs[idx,idy].set_yticks([])
            if idx == 0:
                axs[idx,idy].set_title(str(1e3 * j_dict['process']['common']['probe']['convergence']) + 'mrad \n' + \
                    'step size ($\AA$) %2.2f'%(1e10*j_dict['process']['common']['scan']['dR'][0]) + '\n' + \
                                       str(diff_overlap[idy]) + '%overlap', color = 'red', fontsize = 8)
            elif idy == 0:
                axs[idx,idy].set_title('step size ($\AA$) %2.3f'%(1e10*j_dict['process']['common']['scan']['dR'][0]), color = 'black', fontsize = 8)
                axs[idx, idy].set_ylabel(str(real_probe_overlap[idx]) + '%overlap', color = 'red', fontsize = 8)
            else:
                axs[idx,idy].set_title('%2.3f'%(1e10*j_dict['process']['common']['scan']['dR'][0]), color = 'black', fontsize = 6)
#plt.tight_layout()

In [42]:
# defocused probe cases - fft only
rows = len(defocus_probes_dict[0])
cols = len(conv_angles)
#print(rows, cols)
fig, axs = plt.subplots(nrows=rows, ncols=cols, sharex=True, sharey=True,figsize=(8, 11))
for idx in range(rows):    
      for idy in range(cols): 
            try:
                obj = crop_recon_obj(defocus_probes_dict[idy][idx]['json_path'])
                #print(idx, idy)
                j_dict = j_dict = json_to_dict(defocus_probes_dict[idy][idx]['json_path'])
                axs[idx,idy].imshow(np.sqrt(abs(get_fft(obj))), cmap = 'viridis')
                axs[idx,idy].set_xticks([])
                axs[idx,idy].set_yticks([])
                if idx == 0:
                    axs[idx,idy].set_title(str(1e3 * j_dict['process']['common']['probe']['convergence']) + 'mrad \n' + \
                        'step size ($\AA$) %2.3f'%(1e10*j_dict['process']['common']['scan']['dR'][0])+ '\n' + \
                                       str(diff_overlap[idy]) + '%overlap', color = 'red', fontsize = 8)
                elif idy == 0:
                    axs[idx,idy].set_title('step size ($\AA$) %2.3f'%(1e10*j_dict['process']['common']['scan']['dR'][0]), color = 'black', fontsize = 5)
                    axs[idx, idy].set_ylabel(str(real_probe_overlap[idx]) + '%overlap', color = 'red', fontsize = 8)
                else:
                    axs[idx,idy].set_title('%2.3f'%(1e10*j_dict['process']['common']['scan']['dR'][0]), color = 'black', fontsize = 5)
            except IndexError:
                pass
#plt.tight_layout()            

In [43]:
# defocused probe cases - obj phase only
rows = len(defocus_probes_dict[0])
cols = len(conv_angles)
#print(rows, cols)
fig, axs = plt.subplots(nrows=rows, ncols=cols, sharex=True, sharey=True,figsize=(8, 11))
for idx in range(rows):    
      for idy in range(cols): 
            try:
                obj = crop_recon_obj(defocus_probes_dict[idy][idx]['json_path'])
                #print(idx, idy)
                j_dict = json_to_dict(defocus_probes_dict[idy][idx]['json_path'])
                
                img = abs(np.min(np.angle(obj))) + np.angle(obj)
                sh0 = img.shape[0]
                sh1 = 82
                img_crop = img[int(sh0/2 - sh1/2):int(sh0/2 + sh1/2), int(sh0/2 - sh1/2):int(sh0/2 + sh1/2)]
                
                axs[idx,idy].imshow(img_crop, cmap = 'magma_r')
                axs[idx,idy].set_xticks([])
                axs[idx,idy].set_yticks([])
                if idx == 0:
                    axs[idx,idy].set_title(str(1e3 * j_dict['process']['common']['probe']['convergence']) + 'mrad \n' + \
                        'step size ($\AA$) %2.3f'%(1e10*j_dict['process']['common']['scan']['dR'][0])+ '\n' + \
                                       str(diff_overlap[idy]) + '%overlap', color = 'red', fontsize = 8)
                elif idy == 0:
                    axs[idx,idy].set_title('step size ($\AA$) %2.3f'%(1e10*j_dict['process']['common']['scan']['dR'][0]), color = 'black', fontsize = 5)
                    axs[idx, idy].set_ylabel(str(real_probe_overlap[idx]) + '%overlap', color = 'black', fontsize = 5)
                else:
                    axs[idx,idy].set_title('%2.3f'%(1e10*j_dict['process']['common']['scan']['dR'][0]), color = 'black', fontsize = 5)
            except IndexError:
                pass
#plt.tight_layout()            

# Comparison with sim potential

In [8]:
# we get one of the potentials as ground truth to compare the recons with.
focus_probes_dict[4][5]['sim_path']

'/dls/e02/data/2020/cm26481-1/processing/pty_simulated_data_MD/sim_matrix_ptyREX_25April2020/Graphene_defect_20.0mrad_0.00A_def_1.70A_step_size/Graphene_defect_20.0mrad_0.00A_def_1.70A_step_size.h5'

In [9]:
pot = get_potential(focus_probes_dict[4][5]['sim_path'])

In [10]:
sh = pot.shape[0]
obj_pot = pot[int(0.33*sh):int(0.66*sh), int(0.33*sh):int(0.66*sh)]
obj_pot_fft = np.fft.fftshift(np.fft.fft2(obj_pot))
phase_ideal = _sigma(80000) * obj_pot
fig, ax = plt.subplots(1,3,figsize=(11,4))
im = ax[0].imshow(obj_pot)
fig.colorbar(im, ax = ax[0])
ax[1].imshow(np.log(abs(obj_pot_fft)), cmap = 'viridis')
im2 = ax[2].imshow(phase_ideal)
fig.colorbar(im2, ax = ax[2])
ax[0].set_xticks([])
ax[0].set_yticks([])
ax[1].set_xticks([])
ax[1].set_yticks([])
ax[2].set_xticks([])
ax[2].set_yticks([])
ax[0].set_title('object potential in V$\AA$')
ax[1].set_title('fft of object potential')
ax[2].set_title('ideal phase shift')
# plt.tight_layout()            

Text(0.5, 1.0, 'ideal phase shift')

In [11]:
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import normalized_root_mse as nrmse
from skimage.metrics import mean_squared_error as mse
from skimage.feature import register_translation
from scipy.ndimage import fourier_shift


In [12]:
# binning the ideal phase by 2 - the sim pixelSize is half the recn pixelSize - due to the way 4DSTEM
# data is saved in pyprimsatic
phase_ideal_hs = hs.signals.Signal2D(phase_ideal)
phase_ideal_bin = phase_ideal_hs.rebin(scale = (2,2))
phase_ideal_bin = phase_ideal_bin.data

In [49]:
offset_img_r = np.real(offset_img)

In [50]:
ratio = np.divide(offset_img_r, phase_ideal_bin)

In [51]:
plt.figure()
plt.imshow(ratio, cmap = 'viridis')
plt.colorbar()
ratio_mean = np.mean(ratio)

In [52]:
ratio_mean

0.9884089493296008

In [13]:
ssim_scores = []
nrmse_scores = []
mse_scores = []
case_num = 0
convergence = []
step_size = []
for i in range(len(conv_angles)):
    for n in range(len(defocus_probes_dict[i])):
        obj = crop_recon_obj(defocus_probes_dict[i][n]['json_path'])
        convergence.append(defocus_probes_dict[i][n]['process']['common']['probe']['convergence'])
        step_size.append(defocus_probes_dict[i][n]['process']['common']['scan']['dR'][0])
        img = abs(np.min(np.angle(obj))) + np.angle(obj)
        sh0 = img.shape[0]
        sh1 = phase_ideal_bin.shape[0]
        img_crop = img[int(sh0/2 - sh1/2):int(sh0/2 + sh1/2), int(sh0/2 - sh1/2):int(sh0/2 + sh1/2)]
        shift, error, diffphase = register_translation(phase_ideal_bin, img_crop)
        offset_img = fourier_shift(np.fft.fftn(img_crop), shift)
        offset_img = np.fft.ifftn(offset_img)
        ssim_score = ssim(phase_ideal_bin, offset_img)
        nrmse_score = nrmse(phase_ideal_bin, offset_img, normalization= 'min-max')
        mse_score = mse(phase_ideal_bin, offset_img)
        ssim_scores.append(ssim_score)
        nrmse_scores.append(nrmse_score)
        mse_scores.append(mse_score)

convergence = np.asarray(convergence)
fig, ax = plt.subplots(1,2,figsize=(11,4))
for angle in conv_angles:
    inds = np.where(convergence==angle)
    _steps = np.take(step_size, inds[0])
    _ssim = np.take(ssim_scores, inds[0])
    _nrmse = np.take(nrmse_scores, inds[0])
    _mse = np.take(mse_scores, inds[0])

    ax[0].plot(_steps, _ssim,label = str(angle))
    ax[0].set_title('ssim scores')
    ax[0].set_xlabel('step_size(m)')
    ax[1].plot(_steps, _nrmse,label = str(angle))
    ax[1].set_title('nrmse scores')
    ax[1].set_xlabel('step_size(m)')
    ax[0].legend()
    ax[1].legend()
fig.suptitle('ssim and nrmse scores as function of convergence angle \n defocused probe cases')

  im2 = im2.astype(np.float64)
  ret = umr_sum(arr, axis, dtype, out, keepdims)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)


Text(0.5, 0.98, 'ssim and nrmse scores as function of convergence angle \n defocused probe cases')

In [54]:
ssim_scores = []
nrmse_scores = []
mse_scores = []
case_num = 0
convergence = []
step_size = []
for i in range(6):
    for n in range(len(focus_probes_dict[i])):
        obj = crop_recon_obj(focus_probes_dict[i][n]['json_path'])
        convergence.append(focus_probes_dict[i][n]['process']['common']['probe']['convergence'])
        step_size.append(focus_probes_dict[i][n]['process']['common']['scan']['dR'][0])
        img = abs(np.min(np.angle(obj))) + np.angle(obj)
        sh0 = img.shape[0]
        sh1 = phase_ideal_bin.shape[0]
        img_crop = img[int(sh0/2 - sh1/2):int(sh0/2 + sh1/2), int(sh0/2 - sh1/2):int(sh0/2 + sh1/2)]
        shift, error, diffphase = register_translation(phase_ideal_bin, img_crop)
        offset_img = fourier_shift(np.fft.fftn(img_crop), shift)
        offset_img = np.fft.ifftn(offset_img)
        ssim_score = ssim(phase_ideal_bin, offset_img)
        nrmse_score = nrmse(phase_ideal_bin, offset_img, normalization= 'min-max')
        mse_score = mse(phase_ideal_bin, offset_img)
        ssim_scores.append(ssim_score)
        nrmse_scores.append(nrmse_score)
        mse_scores.append(mse_score)

convergence = np.asarray(convergence)
fig, ax = plt.subplots(1,2,figsize=(11,4))
for angle in conv_angles:
    inds = np.where(convergence==angle)
    _steps = np.take(step_size, inds[0])
    _ssim = np.take(ssim_scores, inds[0])
    _nrmse = np.take(nrmse_scores, inds[0])
    _mse = np.take(mse_scores, inds[0])

    ax[0].plot(_steps, _ssim,label = str(angle))
    ax[0].set_title('ssim scores')
    ax[0].set_xlabel('step_size(m)')
    ax[1].plot(_steps, _nrmse,label = str(angle))
    ax[1].set_title('nrmse scores')
    ax[1].set_xlabel('step_size(m)')
    ax[0].legend()
    ax[1].legend()
fig.suptitle('ssim and nrmse scores as function of convergence angle \n focused probe cases')
None

  im2 = im2.astype(np.float64)
  ret = umr_sum(arr, axis, dtype, out, keepdims)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)
  return array(a, dtype, copy=False, order=order)


In [62]:
# To plot objects as a function of convergence angles
conv_angles = np.asarray(conv_angles)
angle = 0.084
ind = np.where(conv_angles==angle)

objects = []
step_sizes = []
sh1 = 84
for json_dict in defocus_probes_dict[ind[0][0]]:
    step_size = json_dict['process']['common']['scan']['dR'][0]
    step_sizes.append(step_size)
    obj = crop_recon_obj(json_dict['json_path'])
    img = abs(np.min(np.angle(obj))) + np.angle(obj)
    sh0 = img.shape[0]
    img_crop = img[int(sh0/2 - sh1/2):int(sh0/2 + sh1/2), int(sh0/2 - sh1/2):int(sh0/2 + sh1/2)]
    objects.append(img_crop)
fig, ax = plt.subplots(4,2, figsize=(8,11))
for i, step in enumerate(step_sizes):

    ax[int(i / 2), i%2].imshow(objects[i], cmap = 'magma_r')
    ax[int(i / 2), i%2].set_title('%2.3f'%(1e10*step)+' $\AA$' )

    ax[int(i / 2), i%2].set_xticks([])
    ax[int(i / 2), i%2].set_yticks([])
fig.suptitle('probe convergence ' + str(angle)+ ' rad')
None

In [32]:
conv_angles

array([0.016, 0.02 , 0.024, 0.03 , 0.04 , 0.05 , 0.064, 0.084])

In [107]:
# To plot errors function of convergence angles
conv_angles = np.asarray(conv_angles)
angle = 0.016
ind = np.where(conv_angles==angle)

errors = []
step_sizes = []
for json_dict in focus_probes_dict[ind[0][0]]:
    step_size = json_dict['process']['common']['scan']['dR'][0]
    step_sizes.append(step_size)
    name, ext = os.path.splitext(json_dict['json_path'])
    error = get_error(name + '.hdf')
    errors.append(error)
fig, ax = plt.subplots(4,2, figsize=(8,12))
for i, step in enumerate(step_sizes):

    ax[int(i / 2), i%2].plot(errors[i])
    ax[int(i / 2), i%2].set_title('%2.3f'%(1e10*step)+'$\AA$' )

fig.suptitle('error vs iter num ' + str(angle)+ 'rad')
None

# Evaluation using atomap

In [14]:
%matplotlib qt
import atomap.api as am

In [15]:
# Getting the reference positions
ref_hs = hs.signals.Signal2D(phase_ideal_bin)
ref_hs_crop = ref_hs.isig[5:70,5:70]
ref_atom_positions = am.get_atom_positions(ref_hs_crop, separation=2)
ref_sublattice = am.Sublattice(ref_atom_positions, image=ref_hs_crop.data)
ref_sublattice.find_nearest_neighbors()
ref_sublattice.refine_atom_positions_using_center_of_mass()
ref_sublattice.refine_atom_positions_using_2d_gaussian()
ref_atom_list = ref_sublattice.atom_list

Center of mass: 100%|██████████| 149/149 [00:00<00:00, 1738.78it/s]
Gaussian fitting: 100%|██████████| 149/149 [00:11<00:00, 13.12it/s]


In [16]:
ref_atom_list[0]

<Atom_Position,  (x:61.7,y:61.1,sx:0.7,sy:0.6,r:1.9,e:1.2)>

In [17]:
# This is not returning the same value as e!
ref_atom_list[0].amplitude_gaussian 

1.338411804525878

In [18]:
ref_atom_list[0].pixel_x

61.694970982469904

In [19]:
ref_atom_list[0].sigma_average

0.6677949193333059

In [20]:
ref_coord = []
for i in range(len(ref_atom_list)):
    
    ref_coord.append(list((ref_atom_list[i].pixel_x,
                    ref_atom_list[i].pixel_y,
                    ref_atom_list[i].sigma_average,
                    ref_atom_list[i].amplitude_gaussian)))

ref_coord = np.asarray(ref_coord)
print('Number of atoms in reference: ', ref_coord.shape[0])


Number of atoms in reference:  149


In [22]:
# Getting the experimental positions and comparison


In [None]:
# defocused probe cases 
rows = len(defocus_probes_dict[0])
cols = len(conv_angles)

for idx in range(rows):    
      for idy in range(cols): 
            obj = crop_recon_obj(defocus_probes_dict[idy][idx]['json_path'])
            img = abs(np.min(np.angle(obj))) + np.angle(obj)
            sh0 = img.shape[0]
            sh1 = phase_ideal_bin.shape[0]
            img_crop = img[int(sh0/2 - sh1/2):int(sh0/2 + sh1/2), int(sh0/2 - sh1/2):int(sh0/2 + sh1/2)]
            shift, error, diffphase = register_translation(phase_ideal_bin, img_crop)
            offset_img = fourier_shift(np.fft.fftn(img_crop), shift)
            offset_img = np.real(np.fft.ifftn(offset_img))

            img_hs = hs.signals.Signal2D(offset_img)
            img_hs_crop = img_hs.isig[5:70,5:70]


            exp_positions = am.get_atom_positions(img_hs_crop, separation=2)
            exp_sublattice = am.Sublattice(exp_positions, image=img_hs_crop.data)

            exp_sublattice.find_nearest_neighbors()
            exp_sublattice.refine_atom_positions_using_center_of_mass()
            exp_sublattice.refine_atom_positions_using_2d_gaussian()

            exp_atom_list = exp_sublattice.atom_list


            exp_coord = []
            for i in range(len(exp_atom_list)):

                exp_coord.append(list((exp_atom_list[i].pixel_x,
                                exp_atom_list[i].pixel_y,
                                exp_atom_list[i].sigma_average,
                                exp_atom_list[i].amplitude_gaussian)))

            exp_coord = np.asarray(exp_coord)
            defocus_probes_dict[idy][idx]['atom_pos'] = exp_coord

          

Center of mass: 100%|██████████| 148/148 [00:00<00:00, 955.05it/s]
Gaussian fitting: 100%|██████████| 148/148 [01:00<00:00,  2.43it/s]
Center of mass: 100%|██████████| 125/125 [00:00<00:00, 1506.77it/s]
Gaussian fitting: 100%|██████████| 125/125 [00:28<00:00,  4.45it/s]
Center of mass: 100%|██████████| 136/136 [00:00<00:00, 2353.28it/s]
Gaussian fitting: 100%|██████████| 136/136 [00:49<00:00,  2.75it/s]
Center of mass: 100%|██████████| 135/135 [00:00<00:00, 2700.00it/s]
Gaussian fitting: 100%|██████████| 135/135 [00:33<00:00,  4.03it/s]
Center of mass: 100%|██████████| 143/143 [00:00<00:00, 1982.07it/s]
Gaussian fitting: 100%|██████████| 143/143 [00:35<00:00,  4.03it/s]
Center of mass: 100%|██████████| 141/141 [00:00<00:00, 2770.13it/s]
Gaussian fitting: 100%|██████████| 141/141 [00:31<00:00,  4.48it/s]
Center of mass: 100%|██████████| 140/140 [00:00<00:00, 3070.79it/s]
Gaussian fitting:  64%|██████▍   | 90/140 [00:39<00:17,  2.93it/s]

In [110]:
# identified atom positions versus the known positions
rows = len(defocus_probes_dict[0])
cols = len(conv_angles)
#print(rows, cols)
fig, axs = plt.subplots(nrows=rows, ncols=cols, sharex=True, sharey=True,figsize=(8, 11))
for idx in range(rows):    
      for idy in range(cols): 
            try:
                j_dict = json_to_dict(defocus_probes_dict[idy][idx]['json_path'])
                exp_coord = defocus_probes_dict[idy][idx]['atom_pos']
#                 print(json_dict['json_path'].split('/')[-2])
#                 print('Difference in number f atoms: ', ref_coord.shape[0] - exp_coord.shape[0])
                
                missed_atoms = ref_coord.shape[0] - exp_coord.shape[0]
                
                axs[idx,idy].scatter(ref_coord[:,0], ref_coord[:,1], s = 1, c = 'g')
                axs[idx,idy].scatter(exp_coord[:,0], exp_coord[:,1], s = 1, c = 'r')
                axs[idx,idy].set_xticks([])
                axs[idx,idy].set_yticks([])
                axs[idx,idy].set_title(missed_atoms, color = 'black')
#                 if idx == 0:
#                     axs[idx,idy].set_title(str(1e3 * j_dict['process']['common']['probe']['convergence']) + 'mrad \n' + \
#                         'step size ($\AA$) %2.3f'%(1e10*j_dict['process']['common']['scan']['dR'][0])+ '\n' + \
#                                        str(diff_overlap[idy]) + '%overlap', color = 'black', fontsize = 8)
#                 elif idy == 0:
#                     axs[idx,idy].set_title('step size ($\AA$) %2.3f'%(1e10*j_dict['process']['common']['scan']['dR'][0]), color = 'black', fontsize = 5)
#                     axs[idx, idy].set_ylabel(str(real_probe_overlap[idx]) + '%overlap', color = 'black', fontsize = 5)
#                 else:
#                     axs[idx,idy].set_title('%2.3f'%(1e10*j_dict['process']['common']['scan']['dR'][0]), color = 'black', fontsize = 5)
            except IndexError:
                pass
#plt.tight_layout()            

In [105]:
json_dict['json_path'].split('/')[-2]

'Graphene_defect_8.0mrad_0.00A_def_1.83A_step_size'

# Trials

In [66]:
ref_hs = hs.signals.Signal2D(phase_ideal_bin)
ref_hs_crop = ref_hs.isig[5:70,5:70]
ref_hs_crop.plot()

In [36]:
s_peaks = am.get_feature_separation(ref_hs_crop, separation_range=(2, 20))

100%|██████████| 18/18 [00:00<00:00, 328.02it/s]
100%|██████████| 280/280 [00:00<00:00, 3478.03it/s]


In [37]:
s_peaks.plot()

In [236]:
ref_atom_positions = am.get_atom_positions(ref_hs_crop, separation=2)

In [237]:
ref_sublattice = am.Sublattice(ref_atom_positions, image=ref_hs_crop.data)

In [238]:
ref_sublattice

<Sublattice,  (atoms:149,planes:0)>

In [239]:
ref_sublattice.find_nearest_neighbors()
ref_sublattice.refine_atom_positions_using_center_of_mass()
ref_sublattice.refine_atom_positions_using_2d_gaussian()

Center of mass: 100%|██████████| 149/149 [00:00<00:00, 1791.21it/s]
Gaussian fitting: 100%|██████████| 149/149 [00:10<00:00, 14.47it/s]


In [240]:
ref_sublattice.get_position_history().plot()

100%|██████████| 149/149 [00:00<00:00, 2943.29it/s]


In [241]:
ref_atom_list = ref_sublattice.atom_list

In [242]:
ref_atom_list[0]

<Atom_Position,  (x:61.7,y:61.1,sx:0.7,sy:0.6,r:1.9,e:1.2)>

In [270]:
defocus_probes_dict[0][6]['json_path']

'/dls/e02/data/2020/cm26481-1/processing/pty_simulated_data_MD/sim_matrix_ptyREX_25April2020/Graphene_defect_8.0mrad_281.70A_def_4.35A_step_size/2000iter_20200426-222252.json'

In [218]:
plt.figure()
plt.imshow(np.angle(test_obj))

<matplotlib.image.AxesImage at 0x7f69c1ab6510>

In [271]:
test_obj = crop_recon_obj(defocus_probes_dict[0][6]['json_path'])

img = abs(np.min(np.angle(test_obj))) + np.angle(test_obj)
sh0 = img.shape[0]
sh1 = phase_ideal_bin.shape[0]
img_crop = img[int(sh0/2 - sh1/2):int(sh0/2 + sh1/2), int(sh0/2 - sh1/2):int(sh0/2 + sh1/2)]
shift, error, diffphase = register_translation(phase_ideal_bin, img_crop)
offset_img = fourier_shift(np.fft.fftn(img_crop), shift)
offset_img = np.real(np.fft.ifftn(offset_img))

fig, ax = plt.subplots(1,2)
ax[0].imshow(img_crop)
ax[1].imshow(np.real(offset_img))

img_hs = hs.signals.Signal2D(offset_img)
img_hs_crop = img_hs.isig[5:70,5:70]
img_hs_crop.plot()

In [231]:
img_hs.plot()

In [272]:
s_peaks = am.get_feature_separation(img_hs, separation_range=(2, 20))
s_peaks.plot()

100%|██████████| 18/18 [00:00<00:00, 565.63it/s]
100%|██████████| 188/188 [00:00<00:00, 3748.05it/s]


In [273]:
test_positions = am.get_atom_positions(img_hs_crop, separation=2)

In [274]:
test_sublattice = am.Sublattice(test_positions, image=img_hs_crop.data)

In [275]:
test_sublattice

<Sublattice,  (atoms:104,planes:0)>

In [276]:
test_sublattice.find_nearest_neighbors()
test_sublattice.refine_atom_positions_using_center_of_mass()
test_sublattice.refine_atom_positions_using_2d_gaussian()

Center of mass: 100%|██████████| 104/104 [00:00<00:00, 3044.89it/s]
Gaussian fitting: 100%|██████████| 104/104 [00:13<00:00,  7.64it/s]


In [184]:
test_sublattice.get_position_history().plot()

100%|██████████| 239/239 [00:00<00:00, 1313.30it/s]


In [65]:
test_sublattice.plot_ellipticity_map()

In [277]:
test_atom_list = test_sublattice.atom_list

In [278]:
test_sublattice.plot()

In [279]:
test_atom_list[0]

<Atom_Position,  (x:38.6,y:61.8,sx:1.6,sy:2.2,r:0.4,e:1.4)>

In [280]:
ref_coord = []
test_coord = []
for i in range(len(ref_atom_list)):
    ref_coord.append(list(ref_atom_list[i].get_pixel_position()))
for i in range(len(test_atom_list)):
    test_coord.append(list(test_atom_list[i].get_pixel_position()))
ref_coord = np.asarray(ref_coord)
test_coord = np.asarray(test_coord)

In [281]:
plt.figure()
plt.scatter(ref_coord[:,0], ref_coord[:,1])
plt.scatter(test_coord[:,0], test_coord[:,1])

<matplotlib.collections.PathCollection at 0x7f69c0fb71d0>

In [282]:
def atom_dist(x1, y1, x2, y2):
    return(np.sqrt((y2-y1)**2+(x2-x1)**2))


def check_atom_found(ref_atom, exp_list, tol):
    distance_check = []
    for exp_atom in exp_list:
        distance_check.append(atom_dist(ref_atom[0], ref_atom[1], exp_atom[0], exp_atom[1]) < tol)
    return any(distance_check)

In [283]:
len(ref_coord)

149

In [284]:
len(test_coord)

104

In [293]:
ref_to_compare = []
ind_to_del = []
match_count = 0
for i, atom in enumerate(ref_coord):
    if check_atom_found(atom, test_coord, 2.28):
        match_count += 1
    else: 
        ind_to_del.append(i)
ref_to_compare = np.delete(ref_coord, ind_to_del, 0)
print('number of atoms missing: ', len(ref_coord) - match_count)

number of atoms missing:  45


In [294]:
match_count

104

In [295]:
len(ref_coord) - len(test_coord)

45

In [296]:
len(ref_to_compare)

104

In [297]:
mse = ((ref_to_compare - test_coord)**2).mean()
print(mse)

279.6274438286989
