# 5 - Evaluation of graphene reconstructions using Atomap

In this notebook we use Atomap to identify the atomic positions in each reconstructed object and compare them with the known atomic positions in the ground truth. The detected positions are saved as HDF5 files. Each case is then evaluated with sklearn's `KDTree` nearest-neighbour search. This quantifies the root-mean-square error (RMSE) of the distances between identified and ground-truth atomic positions, as well as the false positives and false negatives, from which precision and recall are derived.
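A minimal sketch of this kind of KD-tree evaluation is shown below. It is not the `epsic.ptycho_utils.kdtree_NN` function used later in this notebook. It uses SciPy's `cKDTree` in place of sklearn's `KDTree`, and the function name `match_positions`, the greedy one-to-one matching and the returned keys are assumptions for illustration. With greedy matching, a detection whose nearest ground-truth atom is already taken counts as a false positive, even if another atom lies within the search radius.

```python
import numpy as np
from scipy.spatial import cKDTree


def match_positions(detected, truth, search_rad):
    """Hypothetical helper: match detected (x, y) positions to ground truth within search_rad pixels.

    Returns counts of true positives (TP), false positives (FP) and false
    negatives (FN), plus the RMSE of matched distances, precision and recall.
    """
    detected = np.asarray(detected, dtype=float).reshape(-1, 2)
    truth = np.asarray(truth, dtype=float).reshape(-1, 2)
    tree = cKDTree(truth)
    # nearest ground-truth atom for each detection; if none lies within
    # search_rad, query returns dist = inf (and an index equal to len(truth))
    dist, idx = tree.query(detected, k=1, distance_upper_bound=search_rad)

    used_truth = set()
    matched_dist = []
    # accept the closest pairs first so each ground-truth atom is matched at most once
    for d_i in np.argsort(dist):
        if not np.isfinite(dist[d_i]):
            break  # the remaining detections have no neighbour within search_rad
        if idx[d_i] in used_truth:
            continue  # nearest atom already taken: this detection stays a false positive
        used_truth.add(idx[d_i])
        matched_dist.append(dist[d_i])

    tp = len(matched_dist)
    fp = len(detected) - tp
    fn = len(truth) - tp
    rmse = np.sqrt(np.mean(np.square(matched_dist))) if tp else np.nan
    precision = tp / len(detected) if len(detected) else np.nan
    recall = tp / len(truth) if len(truth) else np.nan
    return {'TP': tp, 'FP': fp, 'FN': fn, 'RMSE': rmse,
            'Precision': precision, 'Recall': recall}


# toy example: one close match, one exact match and one spurious detection
result = match_positions(detected=[[0.3, 0.4], [10, 0], [50, 50]],
                         truth=[[0, 0], [10, 0], [0, 10]],
                         search_rad=4.0)
print(result)
```

Here the ground-truth atom at (0, 10) is never detected, so it is the single false negative.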

[Finding atomic positions using atomap](#am)<br>
[Comparison with ground truth](#compare)<br>
[Trials on single datasets](#single)

In [1]:
%matplotlib qt

In [2]:
import os
import sys
sys.path.append('/dls/science/groups/e02/Mohsen/code/Git_Repos/Merlin-Medipix/')
import epsic_tools.api as epsic
import numpy as np
import atomap.api as am
import hyperspy.api as hs
hs.set_log_level('ERROR')
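try:
    # scikit-image >= 0.17 (register_translation was removed in 0.19)
    from skimage.registration import phase_cross_correlation as register_translation
except ImportError:
    # older scikit-image releases
    from skimage.feature import register_translation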
from scipy.ndimage import fourier_shift
import matplotlib.pyplot as plt



In [3]:
matrix_path = '/dls/e02/data/2020/cm26481-1/processing/pty_simulated_data_MD/sim_matrix_ptyREX_4June2020_512pixArray/'
json_files = epsic.sim_utils.get_ptyREX_recon_list(matrix_path, run_id = '2000iter')
conv_angles = [0.016, 0.020, 0.024, 0.030, 0.040, 0.050, 0.064, 0.084]
diff_overlap = [-22.5, 1.99, 18.33, 34.70, 51.00, 60.80, 69.37, 76.67]
real_probe_overlap = [0, 5, 15, 35, 60, 70, 80, 90]


# group the reconstructions by convergence angle: one list per angle
data_list_of_dicts = [[] for _ in conv_angles]
for file in json_files:
    j_dict = epsic.sim_utils.json_to_dict_sim(file)
    conv = j_dict['process']['common']['probe']['convergence']
    for i, angle in enumerate(conv_angles):
        # compare floats with a tolerance rather than ==
        if np.isclose(conv, angle):
            data_list_of_dicts[i].append(j_dict)

# within each angle, sort by scan step size (largest step first)
for angle_group in data_list_of_dicts:
    angle_group.sort(key=lambda e: e['process']['common']['scan']['dR'][0], reverse=True)


<a id='am'></a>
# Finding atomic positions using atomap

In [4]:
phase_ideal_bin = hs.load('phase_ideal_bin.hspy')
phase_ideal_bin.axes_manager
# phase_ideal_bin = phase_ideal_bin.data

| Signal axis name | size | offset | scale | units |
|---|---|---|---|---|
|  | 169 | 0.5 | 2.0 |  |
|  | 169 | 0.5 | 2.0 |  |


In [5]:
phase_ideal_bin.data.shape

(169, 169)

Getting the reference positions:

In [6]:
ref_atom_positions = am.get_atom_positions(phase_ideal_bin, separation=4)
ref_sublattice = am.Sublattice(ref_atom_positions, image=phase_ideal_bin.data)
ref_sublattice.find_nearest_neighbors()
ref_sublattice.refine_atom_positions_using_center_of_mass()
ref_sublattice.refine_atom_positions_using_2d_gaussian()
ref_atom_list = ref_sublattice.atom_list

In [7]:
ref_sublattice.plot()

In [8]:
# one row per atom: x, y, mean Gaussian width and Gaussian amplitude
ref_coord = np.asarray([[atom.pixel_x, atom.pixel_y, atom.sigma_average, atom.amplitude_gaussian]
                        for atom in ref_atom_list])
print('Number of atoms in reference: ', ref_coord.shape[0])


Number of atoms in reference:  282


In [9]:
np.save('ground_truth_positions', ref_coord)

Next, we compute the sampling factor for every reconstruction in the matrix. We then take each object phase, crop its central region to the size of the ideal phase image, measure the translation between the two, apply that shift to the full phase image, and crop it again before passing it to Atomap for atomic-position detection:
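The crop-and-align step can be illustrated in isolation with plain NumPy. This is a simplified stand-in, not the pipeline used in this notebook: it estimates an integer-pixel shift by phase correlation (the idea behind `register_translation`, which can also refine the shift to subpixel precision through `upsample_factor`) and applies the shift with the Fourier shift theorem, as `fourier_shift` does. The helper names are hypothetical.

```python
import numpy as np


def center_crop(img, size):
    """Return the central size x size region of a 2D array."""
    r0 = (img.shape[0] - size) // 2
    c0 = (img.shape[1] - size) // 2
    return img[r0:r0 + size, c0:c0 + size]


def phase_correlation_shift(reference, moving):
    """Integer-pixel shift that, applied to `moving`, aligns it with `reference`."""
    cross = np.fft.fft2(reference) * np.conj(np.fft.fft2(moving))
    cross /= np.abs(cross) + 1e-12           # keep only the phase of the cross-power spectrum
    corr = np.abs(np.fft.ifft2(cross))       # a sharp peak at the required shift
    peak = np.array(np.unravel_index(np.argmax(corr), corr.shape), dtype=float)
    shape = np.array(corr.shape)
    wrap = peak > shape // 2                 # peaks past the midpoint are negative shifts
    peak[wrap] -= shape[wrap]
    return peak


def apply_shift(img, shift):
    """Shift a 2D image by `shift` pixels (row, column) via the Fourier shift theorem."""
    ky = np.fft.fftfreq(img.shape[0])[:, None]
    kx = np.fft.fftfreq(img.shape[1])[None, :]
    ramp = np.exp(-2j * np.pi * (ky * shift[0] + kx * shift[1]))
    return np.real(np.fft.ifft2(np.fft.fft2(img) * ramp))


# toy example: a random "phase image" and a circularly shifted copy of it
rng = np.random.default_rng(0)
ref = rng.random((64, 64))
moving = np.roll(ref, (3, -5), axis=(0, 1))
shift = phase_correlation_shift(ref, moving)
aligned = apply_shift(moving, shift)
print(shift, np.abs(aligned - ref).max())
```

In the notebook the shift is measured on central crops of equal size, because the registration needs two inputs of the same shape, and then applied to the full reconstruction before the final crop.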

In [10]:
rows = len(data_list_of_dicts[0])
cols = len(conv_angles)
sampling_factors = []
for idx in range(rows):
    for idy in range(cols):
        entry = data_list_of_dicts[idy][idx]
        common = entry['process']['common']
        l = epsic.ptycho_utils.e_lambda(common['source']['energy'][0])  # electron wavelength
        CL = common['detector']['distance']
        det_pix_array = common['detector']['crop']
        det_pitch = common['detector']['pix_pitch'][0]
        num_probe_pos = common['scan']['N'][0]
        probe_semi_angle = common['probe']['convergence']
        probe_step_size = common['scan']['dR'][0]
        defocus = entry['experiment']['optics']['lens']['defocus'][0]

        # reconstruction pixel size = wavelength * camera length / (number of detector pixels * pixel pitch)
        recon_pix_size = l * CL / (det_pix_array[0] * det_pitch)
        probe_rad = epsic.sim_utils.calc_probe_size(recon_pix_size, det_pix_array, l, defocus,
                                                    probe_semi_angle, plot_probe=False)
        s = epsic.ptycho_utils.get_sampling_factor(recon_pix_size * det_pix_array[0], 2 * probe_rad,
                                                  num_probe_pos, probe_step_size)
        sampling_factors.append(s)
        entry['sampling_factor'] = s
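The reconstruction pixel size above follows dx = λL / (N·p), with λ the electron wavelength, L the camera length, N the number of detector pixels and p the detector pixel pitch. The sketch below computes λ with one standard relativistic formula and then dx. It assumes the beam energy is given in eV and all lengths in metres; it is not necessarily how `epsic.ptycho_utils.e_lambda` is implemented or which units it expects, and the 300 keV value is purely illustrative.

```python
import numpy as np

# CODATA physical constants (SI units)
H = 6.62607015e-34        # Planck constant, J s
M0 = 9.1093837015e-31     # electron rest mass, kg
E = 1.602176634e-19       # elementary charge, C
C = 299792458.0           # speed of light, m/s


def electron_wavelength(energy_ev):
    """Relativistic electron wavelength (m) for a beam energy given in eV."""
    e_kin = E * energy_ev  # kinetic energy in J
    return H / np.sqrt(2 * M0 * e_kin * (1 + e_kin / (2 * M0 * C ** 2)))


def recon_pixel_size(wavelength, camera_length, n_det_pixels, det_pitch):
    """Reconstruction pixel size: wavelength * camera length / (detector pixels * pixel pitch)."""
    return wavelength * camera_length / (n_det_pixels * det_pitch)


lam = electron_wavelength(300e3)   # illustrative beam energy only
print(f"wavelength at 300 keV: {lam * 1e12:.3f} pm")
```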

In [13]:
rows = len(data_list_of_dicts[0])
cols = len(conv_angles)
phase_ideal = phase_ideal_bin.data      # keep the Signal2D intact so the cell can be re-run
crop = phase_ideal.shape[0]
for idx in range(rows):
    for idy in range(cols):
        obj = epsic.ptycho_utils.crop_recon_obj(data_list_of_dicts[idy][idx]['json_path'])
        phase = np.angle(obj)
        img = phase - phase.min()       # offset the phase so that it is non-negative
        # central crop of the same size as the ideal phase (square images assumed)
        start = (img.shape[0] - crop) // 2
        centre = (slice(start, start + crop), slice(start, start + crop))
        # register_translation requires the two inputs to have the same size
        shift, error, diffphase = register_translation(phase_ideal, img[centre])
        # apply the measured shift to the full image, then crop again
        offset_img = np.real(np.fft.ifftn(fourier_shift(np.fft.fftn(img), shift)))
        img_hs = hs.signals.Signal2D(offset_img[centre])

        # same Atomap workflow as for the reference image
        exp_positions = am.get_atom_positions(img_hs, separation=4)
        exp_sublattice = am.Sublattice(exp_positions, image=img_hs.data)
        exp_sublattice.find_nearest_neighbors()
        exp_sublattice.refine_atom_positions_using_center_of_mass()
        exp_sublattice.refine_atom_positions_using_2d_gaussian()

        # one row per atom: x, y, mean Gaussian width and Gaussian amplitude
        exp_coord = np.asarray([[atom.pixel_x, atom.pixel_y, atom.sigma_average, atom.amplitude_gaussian]
                                for atom in exp_sublattice.atom_list])
        data_list_of_dicts[idy][idx]['atom_pos'] = exp_coord

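Plot the identified atom positions (red) against the known ground-truth positions (green). Each panel title gives the difference between the number of reference and detected atoms, followed by the sampling factor: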

In [14]:
rows = len(data_list_of_dicts[0])
cols = len(conv_angles)
fig, axs = plt.subplots(nrows=rows, ncols=cols, figsize=(8, 11))

for idx in range(rows):
    for idy in range(cols):
        try:
            exp_coord = data_list_of_dicts[idy][idx]['atom_pos']
            # net number of atoms not detected (negative if there are extra detections)
            missed_atoms = ref_coord.shape[0] - exp_coord.shape[0]
            s = data_list_of_dicts[idy][idx]['sampling_factor']
            axs[idx, idy].scatter(ref_coord[:, 0], ref_coord[:, 1], s=1, c='g')   # ground truth
            axs[idx, idy].scatter(exp_coord[:, 0], exp_coord[:, 1], s=1, c='r')   # detected
            axs[idx, idy].set_xticks([])
            axs[idx, idy].set_yticks([])
            axs[idx, idy].set_title(f'{missed_atoms}       {np.round(s, 1)}', color='black', fontsize=7)
        except IndexError:
            pass
fig.tight_layout(pad=1.0)

In [28]:
plt.figure()
plt.plot(sampling_factors)


In [29]:
# plotting individual dataset
plt.figure()
plt.scatter(ref_coord[:,0], ref_coord[:,1], s = 5, c = 'g')
exp_pos = data_list_of_dicts[7][7]['atom_pos']
plt.scatter(exp_pos[:,0], exp_pos[:,1], s = 5, c = 'r')


In [32]:
# Save the fitted data (including the detected atom positions) to one HDF5 file per reconstruction
results_folder = '/dls/science/groups/e02/Mohsen/code/Git_Repos/Staff-notebooks/ptyREX_sim_matrix/fitting_results'
os.makedirs(results_folder, exist_ok=True)

rows = len(data_list_of_dicts[0])
cols = len(conv_angles)
for idx in range(rows):
    for idy in range(cols):
        entry = data_list_of_dicts[idy][idx]
        # name the file after the reconstruction's parent folder
        run_name = entry['json_path'].split('/')[-2]
        h5_file_name = os.path.join(results_folder, run_name + '.h5')
        epsic.ptycho_utils.save_dict_to_hdf5(entry, h5_file_name)

In [11]:
dd = epsic.ptycho_utils.load_dict_from_hdf5('/dls/science/groups/e02/Mohsen/code/Git_Repos/Staff-notebooks/ptyREX_sim_matrix/fitting_results/graphene_512_64matrix_backup/Graphene_defect_8.0mrad_281.75A_def_2.18A_step_size.h5')


In [12]:
dd

{'atom_pos': array([[121.41388905, 163.86922113,   1.68823719,   0.72288737],
        [113.00532043, 164.16584041,   1.70654456,   1.05055257],
        [ 96.50912662, 164.13181785,   1.97042467,   1.20076947],
        ...,
        [ 98.23902267,  10.10589817,   1.55754718,   0.49673928],
        [ 89.56400168,  10.0962495 ,   1.78996969,   0.84017837],
        [ 64.93901982,  10.33318438,   1.76751978,   0.82241142]]),
 'base_dir': '/dls/e02/data/2020/cm26481-1/processing/pty_simulated_data_MD/sim_matrix_ptyREX_4June2020_512pixArray/Graphene_defect_8.0mrad_281.75A_def_2.18A_step_size',
 'experiment': {'data': {'data_path': '/dls/e02/data/2020/cm26481-1/processing/pty_simulated_data_MD/sim_matrix_ptyREX_4June2020_512pixArray/Graphene_defect_8.0mrad_281.75A_def_2.18A_step_size/Graphene_defect_8.0mrad_281.75A_def_2.18A_step_size.h5',
   'dead_pixel_flag': 0,
   'flat_field_flag': 0,
   'key': '4DSTEM_simulation/data/datacubes/hdose_noisy_data',
   'load_flag': 1,
   'meta_type': 'hdf'},
 

In [13]:
def get_distance(p1, p2):
    """Euclidean distance between the (x, y) coordinates of two points."""
    return np.hypot(p2[0] - p1[0], p2[1] - p1[1])

In [14]:
import matplotlib.pyplot as plt
def plot_kdtree_results(paired_list, FN, FP, unlabeled_experiment=None, truth=None):
    """Plot matched (TP), false-negative and false-positive positions.

    If `unlabeled_experiment` and `truth` are given, a second panel shows the
    ground truth next to the raw detected positions.
    """
    FN = np.asarray(FN)
    FP = np.asarray(FP)

    if unlabeled_experiment is None and truth is None:
        fig, axs = plt.subplots(nrows=1, ncols=1)
        detected_arr = np.asarray([pair[0] for pair in paired_list])
        axs.scatter(detected_arr[:, 0], detected_arr[:, 1], c='g', marker='s', label='matched')
        axs.scatter(FN[:, 0], FN[:, 1], marker='o', c='r', label='false_neg')
        axs.scatter(FP[:, 0], FP[:, 1], marker='o', c='b', label='false_pos')
        axs.set_xticks([])
        axs.set_yticks([])
        axs.legend()
    else:
        unlabeled_experiment = np.asarray(unlabeled_experiment)
        truth = np.asarray(truth)
        fig, axs = plt.subplots(nrows=1, ncols=2)
        detected_arr = np.asarray([pair[1][0] for pair in paired_list])
        # left panel: matching result; right panel: ground truth vs raw detections
        axs[0].scatter(detected_arr[:, 0], detected_arr[:, 1], c='g', marker='s', label='matched')
        axs[0].scatter(FN[:, 0], FN[:, 1], marker='x', c='r', label='false_neg')
        axs[0].scatter(FP[:, 0], FP[:, 1], marker='o', c='m', label='false_pos')
        axs[1].scatter(truth[:, 0], truth[:, 1], s=3, marker='o', c='r', label='truth')
        axs[1].scatter(unlabeled_experiment[:, 0], unlabeled_experiment[:, 1], marker='x',
                       label='experiment_unlabeled')
        axs[0].legend()
        axs[1].legend()


# Iterating over the fitting results 

In [9]:
# C-C bond is approx. 8 pixels and recon pixel size is 1.7e-11 m
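The 8-pixel figure and the 4-pixel search radius used below can be checked with a quick calculation. This assumes the textbook graphene C-C bond length of about 1.42 Å and the 1.7e-11 m pixel size quoted in the comment above. Keeping the search radius below half the bond length means the search circles around neighbouring atoms cannot overlap, so a detection can lie within the radius of at most one ground-truth atom.

```python
cc_bond = 1.42e-10       # graphene C-C bond length in m (about 1.42 Å)
pixel_size = 1.7e-11     # reconstruction pixel size quoted in the comment above, in m
search_rad = 4.0         # search radius (pixels) used for the KD-tree matching

bond_px = cc_bond / pixel_size
print(f"C-C bond = {bond_px:.2f} px, search radius = {search_rad * pixel_size * 1e10:.2f} Å")
# prints "C-C bond = 8.35 px, search radius = 0.68 Å"
```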

In [15]:
# fitting_folder = '/dls/science/groups/e02/Mohsen/code/Git_Repos/Staff-notebooks/ptyREX_sim_matrix/fitting_results/graphene_512_64matrix/'
fitting_folder = '/dls/science/groups/e02/Mohsen/code/Git_Repos/Staff-notebooks/ptyREX_sim_matrix/fitting_results/graphene_512_64matrix_4pix_searchrad/'
truth = np.load('/dls/science/groups/e02/Mohsen/code/Git_Repos/Staff-notebooks/ptyREX_sim_matrix/ground_truth_positions.npy')
truth_pos = truth[:,:2]
# test = '/dls/science/groups/e02/Mohsen/code/Git_Repos/Staff-notebooks/ptyREX_sim_matrix/fitting_results/test'
search_rad = 4.0
for file in os.listdir(fitting_folder):
    file_path = os.path.join(fitting_folder, file)
    data_dict = epsic.ptycho_utils.load_dict_from_hdf5(file_path)
    exp_pos = data_dict['atom_pos'][:,:2]
    nn_dict = epsic.ptycho_utils.kdtree_NN(exp_pos, truth_pos, search_rad)
    data_dict.update({'NN_Results': nn_dict})
    epsic.ptycho_utils.save_dict_to_hdf5(data_dict, os.path.join(fitting_folder,file))

# Plotting results

In [16]:
hdf_folder = '/dls/science/groups/e02/Mohsen/code/Git_Repos/Staff-notebooks/ptyREX_sim_matrix/fitting_results/graphene_512_64matrix_4pix_searchrad/' 

In [17]:
hdf_files = []
for file in os.listdir(hdf_folder):
    hdf_files.append(os.path.join(hdf_folder, file))

In [18]:
# group the HDF5 result files by convergence angle
conv_angles = [0.016, 0.020, 0.024, 0.030, 0.040, 0.050, 0.064, 0.084]
# diff_overlap = [-22.5, 1.99, 18.33, 34.70, 51.00, 60.80, 69.37, 76.67]
# real_probe_overlap = [0, 5, 15, 35, 60, 70, 80, 90]

data_list_of_dicts = [[] for _ in conv_angles]
for file in hdf_files:
    j_dict = epsic.ptycho_utils.load_dict_from_hdf5(file)
    conv = j_dict['process']['common']['probe']['convergence']
    for i, angle in enumerate(conv_angles):
        if np.isclose(conv, angle):   # tolerant float comparison
            data_list_of_dicts[i].append(j_dict)

# within each angle, sort by scan step size (largest step first)
for angle, angle_group in zip(conv_angles, data_list_of_dicts):
    print(angle)
    angle_group.sort(key=lambda e: e['process']['common']['scan']['dR'][0], reverse=True)


0.016
0.02
0.024
0.03
0.04
0.05
0.064
0.084


In [19]:
data_list_of_dicts[6][6]['NN_Results']['RMSE']

0.1754256989157329

In [20]:
plot_kdtree_results(data_list_of_dicts[6][6]['NN_Results']['TP_list'],
                   data_list_of_dicts[6][6]['NN_Results']['FN_list'],
                   data_list_of_dicts[6][6]['NN_Results']['FP_list'])

In [21]:
import matplotlib.pyplot as plt
def plot_kdtree_results_matrix(paired_list, FN, FP, axs=None, **plt_kwargs):
    """Plot matched (green), false-negative (red) and false-positive (blue) positions on one axes.

    Extra keyword arguments (e.g. the marker size `s`) are passed on to `scatter`.
    """
    FN = np.asarray(FN)
    FP = np.asarray(FP)
    if axs is None:
        axs = plt.gca()

    detected_arr = np.asarray([pair[0] for pair in paired_list])
    axs.scatter(detected_arr[:, 0], detected_arr[:, 1], c='g', marker='s', label='matched', **plt_kwargs)
    axs.scatter(FN[:, 0], FN[:, 1], marker='o', c='r', label='false_neg', **plt_kwargs)
    axs.scatter(FP[:, 0], FP[:, 1], marker='o', c='b', label='false_pos', **plt_kwargs)
    axs.set_xticks([])
    axs.set_yticks([])
    # axs.legend()


In [22]:
# NN results: one panel per reconstruction, titled "RMSE  precision  recall"
rows = len(data_list_of_dicts[0])
cols = len(conv_angles)
fig, axs = plt.subplots(nrows=rows, ncols=cols, figsize=(8, 11))

rmse_list = []
precision_list = []
recall_list = []

# NOTE: this overwrites the conv_angles list defined above with one entry per panel,
# so re-run the grouping cell before re-running this one
conv_angles = []
step_sizes = []

for idx in range(rows):
    for idy in range(cols):
        try:
            entry = data_list_of_dicts[idy][idx]
            nn = entry['NN_Results']
            plot_kdtree_results_matrix(nn['TP_list'], nn['FN_list'], nn['FP_list'],
                                       axs=axs[idx, idy], s=1)
            axs[idx, idy].set_title(f"{np.round(nn['RMSE'], 2)}  {np.round(nn['Precision'], 2)}  "
                                    f"{np.round(nn['Recall'], 2)}", fontsize=8)

            rmse_list.append(nn['RMSE'])
            precision_list.append(nn['Precision'])
            recall_list.append(nn['Recall'])

            step_sizes.append(entry['process']['common']['scan']['dR'][0])
            conv_angles.append(entry['process']['common']['probe']['convergence'])
        except IndexError:
            pass
fig.tight_layout(pad=1.0)

In [83]:
plt.close('all')

In [10]:
min(rmse_list)

0.05760031729525667

In [11]:
max(recall_list)

0.9787234042553191

In [12]:
max(precision_list)

0.9822064056939501

In [13]:
conv_angles = np.asarray(conv_angles)
step_sizes = np.asarray(step_sizes)
rmse_list = np.asarray(rmse_list)
precision_list = np.asarray(precision_list)
recall_list = np.asarray(recall_list)

In [14]:
conv_angles_reshaped = conv_angles.reshape(8,8)
conv_angles_vals = conv_angles_reshaped[0]

In [15]:
conv_angles_vals

array([0.016, 0.02 , 0.024, 0.03 , 0.04 , 0.05 , 0.064, 0.084])

In [16]:
step_sizes_reshaped = step_sizes.reshape(8,8)


In [17]:
probe_overlaps = np.asarray([0, 5, 15, 35, 60, 70, 80, 90])

In [18]:
rmse_vals = rmse_list.reshape(8,8)
precision_vals = precision_list.reshape(8,8)
recall_vals = recall_list.reshape(8,8)
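Why `conv_angles_reshaped[0]` returns the eight convergence angles: the flat lists were filled with step size as the outer loop and angle as the inner loop, and NumPy's default `reshape` is row-major (C order), so each row holds one step size and each column one angle. The toy check below uses a 3×4 grid (an illustrative size, not the notebook's) so that a transposition would be detectable, which a square 8×8 grid would hide:

```python
import numpy as np

n_steps, n_angles = 3, 4
flat_step, flat_angle = [], []
for step in range(n_steps):        # outer loop: rows (step size index)
    for angle in range(n_angles):  # inner loop: columns (convergence angle index)
        flat_step.append(step)
        flat_angle.append(angle)

# Row-major reshape: every row is one step size, every column one angle
step_grid = np.asarray(flat_step).reshape(n_steps, n_angles)
angle_grid = np.asarray(flat_angle).reshape(n_steps, n_angles)
print(angle_grid[0])  # first row lists every angle index once
```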

In [45]:
import numpy as np
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
im = ax.imshow(rmse_vals, cmap='RdBu')

# Show every tick: columns = convergence angle, rows = probe overlap
ax.set_xticks(np.arange(len(conv_angles_vals)))
ax.set_yticks(np.arange(len(probe_overlaps)))
ax.set_xticklabels(conv_angles_vals)
ax.set_yticklabels(probe_overlaps)

# Rotate the x tick labels and anchor them so they line up with their ticks
plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
ax.set_xlabel('convergence angle (rad)')
ax.set_ylabel('probe real space overlap (%)')

# Annotate each cell; imshow puts row i at y=i and column j at x=j
for i in range(len(probe_overlaps)):
    for j in range(len(conv_angles_vals)):
        ax.text(j, i, np.round(rmse_vals[i, j], 2), ha="center", va="center", color="w")

ax.set_title("RMSE values")
fig.tight_layout()
plt.show()
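The RMSE, precision and recall heatmaps repeat the same tick, rotation and annotation code. A small helper keeps each figure to one call; the name `annotated_heatmap` and its parameters are hypothetical (not part of matplotlib or epsic_tools), and it is a sketch built only on standard `Axes` methods:

```python
import numpy as np
import matplotlib.pyplot as plt

def annotated_heatmap(ax, values, x_labels, y_labels, title, fmt="{:.2f}", cmap="RdBu"):
    """Draw a (rows x cols) array with imshow and print each value in its cell."""
    values = np.asarray(values)
    im = ax.imshow(values, cmap=cmap)
    # One tick per column/row, labelled with the parameter values
    ax.set_xticks(np.arange(values.shape[1]))
    ax.set_yticks(np.arange(values.shape[0]))
    ax.set_xticklabels(x_labels, rotation=45, ha="right", rotation_mode="anchor")
    ax.set_yticklabels(y_labels)
    for row in range(values.shape[0]):
        for col in range(values.shape[1]):
            # imshow maps columns to x and rows to y, so the text goes at (col, row)
            ax.text(col, row, fmt.format(values[row, col]),
                    ha="center", va="center", color="w")
    ax.set_title(title)
    return im
```

For example, `annotated_heatmap(ax, rmse_vals, conv_angles_vals, probe_overlaps, "RMSE values")` followed by `ax.set_xlabel(...)` and `ax.set_ylabel(...)` should give essentially the same single-panel figure, apart from the number formatting.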

In [30]:
import numpy as np
import matplotlib.pyplot as plt

fig, ax = plt.subplots(1, 3, sharex=True, sharey=True)
panels = [(rmse_vals, "RMSE values"),
          (precision_vals, "Precision values"),
          (recall_vals, "Recall values")]

for a, (vals, title) in zip(ax, panels):
    a.imshow(vals, cmap='RdBu')
    a.set_title(title)
    a.set_xlabel('convergence angle (rad)')
    # Annotate each cell; imshow puts row i at y=i and column j at x=j
    for i in range(len(probe_overlaps)):
        for j in range(len(conv_angles_vals)):
            a.text(j, i, np.round(vals[i, j], 2), ha="center", va="center", color="w")

# The axes are shared, so setting ticks on the first panel labels all three
ax[0].set_xticks(np.arange(len(conv_angles_vals)))
ax[0].set_yticks(np.arange(len(probe_overlaps)))
ax[0].set_xticklabels(conv_angles_vals)
ax[0].set_yticklabels(probe_overlaps)
ax[0].set_ylabel('probe real space overlap (%)')

# Rotate the x tick labels and anchor them so they line up with their ticks
for a in ax:
    plt.setp(a.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")

fig.tight_layout()
plt.show()

# Trials on three example datasets

In [95]:
results_dir = '/dls/science/groups/e02/Mohsen/code/Git_Repos/Staff-notebooks/ptyREX_sim_matrix/fitting_results/graphene_512_64matrix/'
data_file_good = os.path.join(results_dir, 'Graphene_defect_25.0mrad_181.36A_def_2.17A_step_size.h5')
data_dict_good = load_dict_from_hdf5(data_file_good)
data_file_fair = os.path.join(results_dir, 'Graphene_defect_42.0mrad_144.04A_def_2.17A_step_size.h5')
data_dict_fair = load_dict_from_hdf5(data_file_fair)
data_file_poor = os.path.join(results_dir, 'Graphene_defect_8.0mrad_281.75A_def_21.77A_step_size.h5')
data_dict_poor = load_dict_from_hdf5(data_file_poor)

In [None]:
good = data_dict_good['atom_pos'][:,:2]
fair = data_dict_fair['atom_pos'][:,:2]
poor = data_dict_poor['atom_pos'][:,:2]
truth = np.load('/dls/science/groups/e02/Mohsen/code/Git_Repos/Staff-notebooks/ptyREX_sim_matrix/ground_truth_positions.npy')
truth_pos = truth[:,:2]

In [121]:
data_dict = data_dict_good

plot_kdtree_results(data_dict['NN_Results']['TP_list'],
                    data_dict['NN_Results']['FN_list'],
                    data_dict['NN_Results']['FP_list'],
                    unlabeled_experiment=data_dict['atom_pos'][:,:2],
                    truth=truth_pos)

IndexError: too many indices for array

In [112]:
t = data_dict['atom_pos'][:,:2]

In [118]:
t[0]

array([161.68821904, 164.40397512])

In [105]:
data_dict_good.keys()

dict_keys(['NN_Results', 'atom_pos', 'base_dir', 'experiment', 'json_path', 'process', 'sampling_factor', 'sim_path'])

In [99]:
type(good)

numpy.ndarray

In [8]:
# Overlay ground-truth and detected positions for the 'good' reconstruction
truth_arr = np.asarray(truth)
good_arr = np.asarray(good)
plt.figure()
plt.scatter(truth_arr[:,0], truth_arr[:,1], label='truth')
plt.scatter(good_arr[:,0], good_arr[:,1], label='exp')
plt.legend()

In [25]:
from sklearn.neighbors import KDTree
import numpy as np

search_rad = 2.  # matching radius, same units as the atom positions
exp_list = list(poor)
# truth atoms not yet matched; whatever remains at the end is a false negative
# (an atom in truth_pos that went undetected in the reconstruction)
truth_list = list(truth_pos)
false_pos = []    # detected atoms with no ground-truth atom within search_rad
paired_list = []  # entries: [truth position, [detected position, [distance]]]
distances = []
inds = []

for i in range(len(exp_list)):
    # Tree over this detected atom (row 0) plus all still-unmatched truth atoms
    test = np.vstack((exp_list[i], truth_list))
    tree = KDTree(test, leaf_size=10)
    ind, d = tree.query_radius(test[:1], r=search_rad, count_only=False, return_distance=True)

    if len(ind[0]) == 1:
        # only the detected atom itself lies within search_rad
        false_pos.append(exp_list[i])
    elif len(ind[0]) == 2:
        # exactly one truth atom in range; row e of test is row e - 1 of truth_list
        inds.append([e - 1 for e in ind[0] if e != 0])
        distances.append([e for e in d[0] if e != 0])
        atom_entry = [truth_list.pop(inds[-1][0]), [exp_list[i], distances[-1]]]
        paired_list.append(atom_entry)
    # NOTE: with two or more truth atoms in range, the detected atom is
    # counted as neither TP nor FP.

    del test

TP = len(paired_list)
print('TP: ', TP)
false_pos = np.asarray(false_pos)
FP = len(false_pos)
print('FP: ', FP)
false_neg = np.asarray(truth_list)
FN = len(false_neg)
print('FN: ', FN)

precision = TP / (TP + FP)
recall = TP / (TP + FN)

print('precision: ', precision)
print('recall: ', recall)
print('RMSE: ', get_RMSE(distances))

TP:  47
FP:  224
FN:  235
precision:  0.17343173431734318
recall:  0.16666666666666666
RMSE:  4.696551537163271
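The loop above rebuilds a KDTree for every detected atom (over that atom plus the remaining truth atoms), and a detected atom with two or more truth atoms inside `search_rad` falls through both branches, so it counts as neither a true nor a false positive. A sketch of an alternative, not the notebook's method: build one `KDTree` on the truth positions, collect every (detected, truth) pair within the radius, and accept pairs greedily from the shortest distance up so each atom is used at most once. Greedy closest-pair-first matching is not guaranteed to be globally optimal; `scipy.optimize.linear_sum_assignment` would give an optimal one-to-one assignment. The function name `match_positions` is hypothetical:

```python
import numpy as np
from sklearn.neighbors import KDTree

def match_positions(detected, truth, search_rad=2.0):
    """Greedy one-to-one matching of detected atoms to ground-truth atoms.

    Returns TP/FP/FN counts, precision, recall, RMSE of the matched distances
    (same units as the positions) and the list of (det_idx, truth_idx, dist).
    """
    detected = np.asarray(detected, dtype=float)
    truth = np.asarray(truth, dtype=float)
    tree = KDTree(truth)
    ind, dist = tree.query_radius(detected, r=search_rad, return_distance=True)
    # Every candidate pair within the radius as (distance, detected idx, truth idx)
    pairs = [(d, i, j)
             for i, (js, ds) in enumerate(zip(ind, dist))
             for j, d in zip(js, ds)]
    pairs.sort()  # shortest distances first
    used_det, used_truth, matched = set(), set(), []
    for d, i, j in pairs:
        # Accept a pair only if neither atom has been matched already
        if i not in used_det and j not in used_truth:
            used_det.add(i)
            used_truth.add(j)
            matched.append((i, j, d))
    tp = len(matched)
    fp = len(detected) - tp  # detected atoms left unmatched
    fn = len(truth) - tp     # truth atoms left unmatched
    dists = np.array([m[2] for m in matched])
    rmse = float(np.sqrt(np.mean(dists ** 2))) if tp else float('nan')
    precision = tp / (tp + fp) if (tp + fp) else float('nan')
    recall = tp / (tp + fn) if (tp + fn) else float('nan')
    return {'TP': tp, 'FP': fp, 'FN': fn, 'precision': precision,
            'recall': recall, 'RMSE': rmse, 'matches': matched}
```

Called as `match_positions(poor, truth_pos, search_rad=2.)`, every detected atom ends up as either a TP or an FP, so `TP + FP` always equals the number of detected atoms.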


In [44]:
distances

[[1.813620242193791],
 [1.5240340549740834],
 [1.8158794473964452],
 [1.8918504239565088],
 [1.924412955214493],
 [1.0877746702820588],
 [1.888554457483371],
 [0.6883910758359352],
 [0.19851619319172334],
 [1.378038032807637],
 [0.8940102395976366],
 [1.9068711380041232],
 [1.697555206522311],
 [1.3480604839508514],
 [1.4982839130722612],
 [1.0365523807629742],
 [1.1094273551780591],
 [1.1656045897172729],
 [1.9318610629409174],
 [1.6723768141168398],
 [1.1605140764529431],
 [0.40437386604805103],
 [1.613268538698602],
 [1.421998215577083],
 [1.0803711483420237],
 [1.5452792070518802],
 [1.6160567210403405],
 [1.5572063237842462],
 [1.2697243345458555],
 [1.3968424372183674],
 [1.8441915545188545],
 [1.3479695637911615],
 [1.6066363888910158],
 [1.9878233192958],
 [0.6284453855924453],
 [1.5375351467389697],
 [1.4224620575007696],
 [1.7658292689060713],
 [1.9947522771351225],
 [1.1367729222186793],
 [1.4532954790605834],
 [1.5754557479591076],
 [0.3396763319442411],
 [1.928918404540200

In [49]:
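distances = np.asarray(distances)
# RMSE = sqrt(mean(d**2)). The original line, np.square((distances ** 2).mean()),
# squares the mean squared error instead and printed 4.696551537163271, the same
# value get_RMSE reported above; that cannot be an RMSE, since every matched
# distance is at most search_rad = 2.
rmse = np.sqrt((distances ** 2).mean())
print(rmse)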

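RMSE has the same units as the match distances, and since every accepted distance is at most `search_rad` (2 here), a true RMSE can never exceed 2. A value such as 4.70 therefore indicates the square of the mean squared error rather than its square root. A two-distance example with illustrative values shows the three quantities side by side:

```python
import numpy as np

d = np.array([1.0, 2.0])        # example match distances (illustrative values)
mse = np.mean(d ** 2)           # mean squared error: (1 + 4) / 2
rmse = np.sqrt(mse)             # root mean squared error, same units as d
squared_mse = np.square(mse)    # what np.square((d ** 2).mean()) returns
print(mse, rmse, squared_mse)
```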

In [37]:
plot_kdtree_results(paired_list, false_neg, false_pos,
                    unlabeled_experiment=np.asarray(poor), truth=truth_arr)

'c' argument looks like a single numeric RGB or RGBA sequence, which should be avoided as value-mapping will have precedence in case its length matches with 'x' & 'y'.  Please use a 2-D array with a single row if you really want to specify the same RGB or RGBA value for all points.


In [182]:
get_distance((41.06,53.65),(41.79,54.38))

1.03237590053236
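`get_distance` is defined earlier in the notebook and not shown in this section. The value above equals 0.73·√2, which is consistent with a plain Euclidean distance between the two points. Assuming that is what it computes, a hypothetical NumPy equivalent (`euclidean_distance` is an illustrative name) is:

```python
import numpy as np

def euclidean_distance(p, q):
    """Hypothetical stand-in for get_distance: straight-line distance between two (x, y) points."""
    p, q = np.asarray(p, dtype=float), np.asarray(q, dtype=float)
    dx, dy = q - p
    return float(np.hypot(dx, dy))  # sqrt(dx**2 + dy**2), computed without overflow

print(euclidean_distance((41.06, 53.65), (41.79, 54.38)))
```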

In [112]:
plt.figure()
plt.scatter(test[:,0], test[:,1])
plt.scatter(truth_pos[1][0], truth_pos[1][1])


In [47]:
import numpy as np
from sklearn.neighbors import KDTree
rng = np.random.RandomState(0)
X = rng.random_sample((20, 2))       # 20 points in 2 dimensions
tree = KDTree(X, leaf_size=2)
dist, ind = tree.query(X[:1], k=2)   # 2 nearest neighbours of the first point (itself included)
print(ind)   # indices of the 2 closest neighbours

print(dist)  # distances to the 2 closest neighbours


[[ 0 11]]
[[0.         0.10907128]]


In [53]:
import numpy as np
from sklearn.neighbors import KDTree
rng = np.random.RandomState(0)
X = rng.random_sample((10, 2))   # 10 points in 2 dimensions
tree = KDTree(X, leaf_size=2)
print(tree.query_radius(X[:1], r=0.2, count_only=True))  # number of points within 0.2

ind = tree.query_radius(X[:1], r=0.2)
print(ind)  # indices of neighbours within distance 0.2 (the query point included)


[3]
[array([2, 0, 1])]
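The matching loop earlier relies on the query point appearing in its own neighbour list and filters it out by index. `KDTree.query_radius` also accepts `return_distance=True` and `sort_results=True` (the latter is only allowed together with `return_distance=True`), which orders the neighbours by distance so the self-match always comes first and can be dropped by slicing. Using the same 10 random points:

```python
import numpy as np
from sklearn.neighbors import KDTree

rng = np.random.RandomState(0)
X = rng.random_sample((10, 2))   # same 10 points in 2 dimensions as above
tree = KDTree(X, leaf_size=2)
ind, dist = tree.query_radius(X[:1], r=0.2, return_distance=True, sort_results=True)
# Neighbours are now ordered by increasing distance; the query point (distance 0) is first
neighbours = ind[0][1:]          # drop the self-match
print(neighbours, dist[0][1:])
```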


In [50]:
X

array([[0.5488135 , 0.71518937],
       [0.60276338, 0.54488318],
       [0.4236548 , 0.64589411],
       [0.43758721, 0.891773  ],
       [0.96366276, 0.38344152],
       [0.79172504, 0.52889492],
       [0.56804456, 0.92559664],
       [0.07103606, 0.0871293 ],
       [0.0202184 , 0.83261985],
       [0.77815675, 0.87001215]])