## Import modules

In [None]:
import hyperspy.api as hs   # To import dm4 TEM image
import copy

import numpy as np
import matplotlib.pyplot as plt

import tools
import Gaussian_functions as gf
import optimization

import tensorflow as tf

#%matplotlib widget

## Import TEM images

In [None]:
# User settings ------------------------------------------------------------------------------------------------
# base_dir is prepended as-is to every file name (no separator is inserted), so include the trailing
# path separator. It is also the base directory for all files being saved.
base_dir = " "
file = " "                      # TEM image file name (dm4)


# Operation code -----------------------------------------------------------------------------------------------

tem_path = base_dir + file      # full path of the TEM image to load
TEM = hs.load(tem_path)         # TEM is the loaded dm4 file
TEM.plot()
TEM_array = TEM.data            # TEM_array is the 2D-array form of the TEM image

## Image rotation, clip, normalization, threshold

In [None]:
# Rotation applied to the image, counter-clockwise, in degrees.
rotate_angle = 1

# Boundary information for clipping.
up, down = 100, 500
left, right = 100, 500
line_width = 1

# Signals below this ratio are removed to suppress noise.
low_intensity_threshold = 0.15
# Normalized maximum intensity. A value similar to the atomic radius in pixels is recommended.
normalized_maximum_intensity = 10


# Operation code -----------------------------------------------------------------------------------------------

# im_analyzed is the processed 2D array of the image; tf_im_analyzed is the second value returned by the
# same call (presumably its TensorFlow counterpart - confirm in tools.image_preprocess).
im_analyzed, tf_im_analyzed = tools.image_preprocess(
    TEM_array,
    rotate_angle,
    normalized_maximum_intensity,
    low_intensity_threshold,
    up,
    down,
    left,
    right,
    line_width,
)

# im_shape is the shape of the processed image.
im_shape = im_analyzed.shape

## Lattice positions generator and intensity investigator

In [None]:
# Lattice information

x = 4                          # lateral lattice constant, in angstrom
y = 10                         # vertical lattice constant, in angstrom
x_off = 0                      # x offset of the upper-left-most lattice point, in pixels
y_off = 0                      # y offset of the upper-left-most lattice point, in pixels
len_pix = 0.1                  # length of one pixel, in angstrom
row = 10                       # number of lattice rows to be analyzed
col = 10                       # number of lattice columns to be analyzed

# Relative x position difference between each row and the first row.
# It must hold exactly row - 1 elements, so the default is derived from `row` rather than typed by hand.
# Replace individual entries with the measured sliding values when needed.
sliding = [0] * (row - 1)


# Operation code -----------------------------------------------------------------------------------------------

# sum_image is the 2D array of the average unit cell image. lattices holds the lattice points with the
# shape (n, 2), where n is the number of lattice points and 2 is for the x and y coordinates.
# lattice_num is the number of lattice points: n.

sum_image, lattices, lattice_num = tools.lattice_gen_check(x, y, x_off, y_off, len_pix, row, col, sliding, im_analyzed)

## Positions in unit cell generator

In [None]:
# Atomic positions [x, y] inside one unit cell, measured from the lattice point and expressed
# as a ratio of the lattice constant. One entry per atom type; one [x, y] pair per site of that type.

Atom_positions_dic = {
    "atom1": [[0, 0], [0, 0], [0, 0], [0, 0]],
    "atom2": [[0, 0], [0, 0], [0, 0], [0, 0]],
    "atom3": [[0, 0], [0, 0]],
    "atom4": [[0, 0], [0, 0]],
}


# Operation code -----------------------------------------------------------------------------------------------

# posit_pix is the list of relative atomic positions in pixels.
# Its length is the number of atom types (the length of Atom_positions_dic).
# Each element is an array of shape (n, 2), where n is the number of coordinates of that atom type
# and 2 is for the x and y coordinates.
# atom_num_list is a list whose elements are the numbers of coordinates of each atom type.

posit_pix, atom_num_list = tools.atom_positions_iu_gen_check(
    Atom_positions_dic,
    sum_image,
    x,
    y,
    len_pix,
)

## Positions generator

In [None]:
# Run this cell to check visually that the positions are generated correctly.

# positions is an array of shape (2, n), where 2 is for the x and y coordinates and n is the total
# number of atoms analyzed.
# atom_resolved_positions is a list whose elements are the positions of each atom type.

positions = tools.position_gen(lattices, posit_pix)
atom_resolved_positions = tools.unpack_atom_type(positions, lattice_num, atom_num_list)

plt.imshow(im_analyzed, cmap="gray")
plt.title("Positions look vaild?")

# One scatter call per atom type: row 0 is used as x and row 1 as y.
for atom, _ in enumerate(atom_num_list):
    type_positions = atom_resolved_positions[atom]
    plt.scatter(type_positions[0, :], type_positions[1, :], s=15)


## Parameter initialization for lattice optimization

In [None]:
# The element order of the lists is the same as that of Atom_positions_dic (positions in unit cell generator).
# amplitude_pri and width_pri are the initial guesses for the amplitude and width of the Gaussian function.
# Use the intensity investigator to determine the initial Gaussian parameters.

# Lattice constant steps should be the minimum value that produces a pixel change once accumulated. For example,
# if the number of columns analyzed is 10 and the pixel size is 0.1 angstrom, the update step should be slightly
# larger than 0.1/10 = 0.01.

pad_size_list = [10, 10, 10, 10]     # Gaussian fitting range for atoms, in pixels. Should be large enough to cover the atomic radius.

# NOTE(review): the zeros below look like template placeholders - set real initial guesses before running.
amplitude_pri = [0, 0, 0, 0]         # Initial guess of the Gaussian amplitude
width_pri = [0, 0, 0, 0]             # Initial guess of the Gaussian width, in pixels

x_update_step = 5e-3                 # Update step of the x lattice constant, in angstrom.
y_update_step = 0.01                 # Update step of the y lattice constant, in angstrom.

loss_array = []                      # Loss is the RMS error between the TEM image and the simulation. History is accumulated and saved.

# Operation code -----------------------------------------------------------------------------------------------

x_l, y_l, x_off_l, y_off_l, sliding_l, posit_pix_l, tf_amplitude_init, tf_width_init, positions, positions_pri = \
tools.lattice_optimization_init(x, y, x_off, y_off, sliding, posit_pix, amplitude_pri, width_pri, row, col, len_pix)

## Lattice optimization

In [None]:
# Optimized here: lattice constants (x and y), offsets (x_off and y_off), sliding, positions in the unit
# cell (posit_pix), and the Gaussian amplitude / width of each atom type (amplitude_pri / width_pri).

# The optimization produces new values of x_l, y_l, x_off_l, y_off_l, sliding_l, posit_pix_l,
# tf_amplitude_init and tf_width_init.

# Output file names (saved as .npy) --------------------------------------------------------------------------

lattice_atom_positions_file_name = "lattice_atom_positions"   # Optimized x_l, y_l, x_off_l, y_off_l, sliding_l, posit_pix_l
params_init_file_name = "params_init"                         # Optimized Gaussian parameters

every_positions_init_file_name = "every_positions_init"       # Optimized positions for all atoms
every_params_init_file_name = "every_params_init"             # Optimized Gaussian parameters for all atoms

loss_file_name = "loss"                                       # Loss history

# Schedule ---------------------------------------------------------------------------------------------------

num_epoch = 1          # Number of sets of optimization
num_amp = 1            # Number of amplitude and width optimizations in a set
num_posit = 1          # Number of atomic position optimizations in a set

print_num = 1          # Number of sets to pass between loss print-outs

learning_rate = 0.01   # Higher: rough but fast. Lower: precise but slow.

# lattices_update updates the lattice information: lattice constants and sliding.
# positions_update updates the atomic positions with respect to the lattice point: posit_pix.
# First optimize the lattice with positions_update False, then vice versa. Finally run with both True.
lattices_update = True
positions_update = False


# Operation code -----------------------------------------------------------------------------------------------

optimization.lattice_optimization(
    num_epoch, print_num, num_amp, num_posit, learning_rate,
    lattices_update, positions_update,
    positions, im_analyzed, tf_im_analyzed, loss_array,
    tf_amplitude_init, tf_width_init,
    x_l, y_l, x_off_l, y_off_l, sliding_l, posit_pix_l,
    atom_num_list, pad_size_list,
    x_update_step, y_update_step,
    row, col, len_pix, lattice_num,
    base_dir,
    lattice_atom_positions_file_name,
    params_init_file_name,
    every_positions_init_file_name,
    every_params_init_file_name,
    loss_file_name,
)

    


## Lattice optimization results

In [None]:
# Run this cell to load the optimized atomic positions and the optimized lattice information.

# Files to load. By default these are the names used by the most recent lattice optimization
# (the self-assignments are no-ops); replace a right-hand side to load a different file.

lattice_atom_positions_file_name = lattice_atom_positions_file_name
params_init_file_name = params_init_file_name
every_positions_init_file_name = every_positions_init_file_name
every_params_init_file_name = every_params_init_file_name


# Operation code -----------------------------------------------------------------------------------------------

# positions is the optimized positions. tfv_params is an array of shape (2, n),
# where 2 is for amplitude and width, and n is the total number of atoms.
# posit_pix_l is the optimized posit_pix. params is the optimized amplitude_pri and width_pri
# with the shape (2, number of atom types).
# atom_resolved_positions is the optimized atom_resolved_positions.

positions, tfv_params, posit_pix_l, params, atom_resolved_positions = tools.Lattice_optimization_results(
    base_dir,
    lattice_atom_positions_file_name,
    params_init_file_name,
    every_positions_init_file_name,
    every_params_init_file_name,
    lattice_num,
    atom_num_list,
    row,
    col,
    len_pix,
    im_analyzed,
    pad_size_list,
)

## Lattice optimization unit cell simulation

In [None]:
# Run this cell to simulate the average unit cell from the initial guess and from the optimized
# lattice parameters, side by side.

# unit_cell_pri is the 2D array of the unit cell simulated from the initial guess.
# unit_cell_init is the 2D array of the unit cell simulated from the optimized lattice parameters.

max_pad_size = np.max(pad_size_list)   # the largest fitting range is used for both simulations

unit_cell_pri = gf.Gaussian_draw_init(
    posit_pix, np.array([amplitude_pri, width_pri]), max_pad_size, sum_image.shape
)
unit_cell_init = gf.Gaussian_draw_init(posit_pix_l, params, max_pad_size, sum_image.shape)

fig, ax = plt.subplots(1, 2)

# (axis, simulated image, panel title, per-atom-type positions) for the left and right panels.
panels = (
    (ax[0], unit_cell_pri, "Initial guess", posit_pix),
    (ax[1], unit_cell_init, "Lattice optimization", posit_pix_l),
)

for axis, cell_image, panel_title, type_positions in panels:
    axis.imshow(cell_image, cmap="gray")
    axis.set_title(panel_title, pad=10)

    # One scatter call per atom type: column 0 is used as x and column 1 as y.
    for coords in type_positions:
        axis.scatter(coords[:, 0], coords[:, 1])

# NOTE(review): the two guide-line heights are hard-coded pixel rows, presumably left over from a
# specific data set - adjust or remove them for other images.
ax[1].axhline(y=26, linestyle="--")
ax[1].axhline(y=79, linestyle="--")


## Parameter initialization for free atom optimization

In [None]:
# Files holding the per-atom results of the lattice optimization, used as the starting point here.
# NOTE(review): params_init_file_name is reassigned below, so after this cell it no longer holds the
# name set in the lattice-optimization cell. Re-run that cell before re-running the
# "Lattice optimization results" cell.

positions_init_file_name = every_positions_init_file_name
params_init_file_name = every_params_init_file_name


# Operation code -----------------------------------------------------------------------------------------------

# positions: array loaded from the saved positions file; positions_init keeps an independent copy.
positions = np.load(f"{base_dir}{positions_init_file_name}.npy")
positions_init = copy.deepcopy(positions)

# tfv_params: the saved Gaussian parameters wrapped in a float32 tf.Variable;
# tfv_params_init keeps an independent copy.
tfv_params = tf.Variable(np.load(f"{base_dir}{params_init_file_name}.npy"), dtype=tf.float32)
tfv_params_init = copy.deepcopy(tfv_params)

# loss_array is deliberately not reset here (the reset below is left disabled), so the list created in
# the parameter-initialization cell is reused.
#loss_array = []


## Free atom optimization

In [None]:
# Output file names for the optimized parameters ---------------------------------------------------------------

positions_file_name = "positions_file"   # Optimized positions
params_file_name = "params_file"         # Optimized Gaussian parameters
loss_file_name = "loss"                  # Loss file

# Schedule -----------------------------------------------------------------------------------------------------

num_epoch = 1          # Number of sets of optimization
num_amp = 1            # Number of amplitude and width optimizations in a set
num_posit = 1          # Number of atomic position optimizations in a set

print_num = 1          # Number of sets to pass between loss print-outs

learning_rate = 0.01   # Higher: rough but fast. Lower: precise but slow. 0.01 should be okay.

# gamma can be chosen from 0 to 1.
# If 1, the optimization tends to escape from asymmetric peaks with a certain probability.
# If 0, the optimization fits any peak with a Gaussian function regardless of its asymmetry.
# Setting it to 0 is recommended unless asymmetric peaks must be strongly avoided.
gamma = 0

# Regularization: these keep the Gaussian parameters / positions close to their initial values for stability.
# Gradually decrease reg_params and reg_posits to update the parameters in a stable way.
reg_params = 0.01
reg_posits = 0.01


# Operation code -----------------------------------------------------------------------------------------------

optimization.free_atom_optimization(
    num_epoch, print_num, num_amp, num_posit, learning_rate,
    gamma, reg_params, reg_posits,
    positions, tfv_params, positions_init, tfv_params_init,
    atom_num_list, pad_size_list,
    im_analyzed, tf_im_analyzed,
    base_dir, positions_file_name, params_file_name, loss_file_name,
    loss_array,
)


## Free atom optimization results

In [None]:
# Run this cell to see the change in the lattice information and to compare the TEM image with the fitted image.

# Files to load. By default these are the names used by the most recent free atom optimization
# (the self-assignments are no-ops); replace a right-hand side to load a different file.

positions_file_name = positions_file_name
params_file_name = params_file_name

# positions and params are the arrays of optimized positions and Gaussian parameters with the shape (2, n).
# tfv_params is the variable version of params.
# atom_resolved_positions is a list with one element per atom type, holding the optimized positions of that type.

positions, params, tfv_params, atom_resolved_positions = tools.free_atom_opmization_results(
    base_dir,
    positions_file_name,
    params_file_name,
    lattice_num,
    atom_num_list,
    pad_size_list,
    im_analyzed,
)


## Cutting boundary row or col

In [None]:
# The boundary regions are probably not well fitted, so they can be cut away.
# cut_row (cut_col) = n cuts n rows (cols) from each of the upper (left) and lower (right) boundaries,
# so 2*n rows (cols) are removed in total.

cut_row = 0
cut_col = 0


# Operation code -----------------------------------------------------------------------------------------------


atom_resolved_params = tools.unpack_atom_type(params, lattice_num, atom_num_list)

# cut_atom_positions and cut_atom_params are the cut versions of the positions and params.

cut_atom_positions, cut_atom_params = tools.boundary_cut(atom_resolved_positions, atom_resolved_params, row, col, cut_col = cut_col, cut_row = cut_row)

# Number of pixels removed on each side: (lattice constant in angstrom * number of cells) / pixel size.
row_cut_pix = int(y*cut_row/len_pix)
col_cut_pix = int(x*cut_col/len_pix)

row_cut_start = row_cut_pix
row_cut_end = im_analyzed.shape[0] - row_cut_pix
col_cut_start = col_cut_pix
col_cut_end = im_analyzed.shape[1] - col_cut_pix

# Crop the experimental image and the simulated image once and reuse them for every plot below.
# The simulation is evaluated a single time on the full image shape and then cropped.
cut_image = im_analyzed[row_cut_start:row_cut_end, col_cut_start:col_cut_end]
cut_simul = gf.Gaussian_position(positions, params, atom_num_list, pad_size_list, im_analyzed.shape)[
    row_cut_start:row_cut_end, col_cut_start:col_cut_end
]

plt.imshow(cut_image, cmap = "gray")
plt.title("Positions with optimized every single atom information")

# Positions are in full-image pixel coordinates, so shift them by the crop origin before plotting.
for atom, _ in enumerate(atom_num_list):

    plt.scatter(cut_atom_positions[atom][0,:] - col_cut_start, cut_atom_positions[atom][1,:] - row_cut_start, s = 10)

fig, ax = plt.subplots(1, 2, figsize = (10, 5))

ax[0].imshow(cut_image, cmap = "gray")
ax[0].set_title("Experimental image")
ax[1].imshow(cut_simul, cmap = "gray")
ax[1].set_title("Fitted image")

## Average and standard deviation of Gaussian parameters

In [None]:
# The report is written to base_dir + the file name below. The ratio is std/avg.

average_std_gaussain_params_file_name = "average_std_gaussain_params.txt"


# Operation code -----------------------------------------------------------------------------------------------


atom_names = list(Atom_positions_dic.keys())

average_amplitude = ""
average_width = ""

for i, atom_name in enumerate(atom_names):

    # Row 0 of each per-type parameter array is reported as the amplitude, row 1 as the width.
    amplitudes = cut_atom_params[i][0]
    widths = cut_atom_params[i][1]

    amp_mean, amp_std = np.mean(amplitudes), np.std(amplitudes)
    width_mean, width_std = np.mean(widths), np.std(widths)

    average_amplitude += (
        f"{atom_name}: avg amp = {amp_mean:.4g} "
        f"std = {amp_std:.4g} (ratio : {amp_std/amp_mean:.4g})\n"
    )

    average_width += (
        f"{atom_name}: avg width = {width_mean:.4g} "
        f"std = {width_std:.4g} (ratio : {width_std/width_mean:.4g})\n"
    )

average_std_gaussain_params = average_amplitude + "\n\n" + average_width

print(average_std_gaussain_params)

# The handle is deliberately not named `file`: that global holds the TEM image file name used by the
# image-loading cell, and rebinding it here would break a later re-run of that cell.
with open(base_dir + average_std_gaussain_params_file_name, "w") as report_file:
    report_file.write(average_std_gaussain_params)

    

## Average atomic positions and weighted average atomic positions in unit cell

In [None]:
# The report is written to base_dir + the file name below.

average_atom_file_name = "average_w_average_atom_posit_in.txt"

# This cell generates the average atomic positions in the unit cell. The weighted average is
# (sum(amplitude*position)/sum(amplitude)), which moves the position toward the high-intensity positions.


# Operation code -----------------------------------------------------------------------------------------------

atom_names = list(Atom_positions_dic.keys())

# deep_unpacked_positions and deep_unpacked_params are lists whose elements are the positions and params of
# each atom with different coordinates in the unit cell.
# a_posit_pix and a_params have the same type and shape as posit_pix, but hold the average of the optimized parameters.
# wa_posit_pix and wa_params have the same type and shape as posit_pix, but hold the weighted average.
# The reported positions are shifted so that the first position of a_posit_pix / wa_posit_pix coincides with
# the first position of posit_pix_l.

deep_unpacked_positions, deep_unpacked_params = tools.deep_unpack(cut_atom_positions,cut_atom_params, atom_num_list)
a_posit_pix, a_params = tools.weight_average_poist_pix(deep_unpacked_positions, deep_unpacked_params, atom_num_list, weight = False)
wa_posit_pix, wa_params = tools.weight_average_poist_pix(deep_unpacked_positions, deep_unpacked_params, atom_num_list)

average_atom = ""

# The two position sections are labelled as positions (they were previously labelled "amplitude").
for i, atom_name in enumerate(atom_names):

    average_atom += f"{atom_name}: avg position = \n{a_posit_pix[i] + posit_pix_l[0][0] - a_posit_pix[0][0]}\n\n"

for i, atom_name in enumerate(atom_names):

    average_atom += f"{atom_name}: weighted avg position = \n{wa_posit_pix[i] + posit_pix_l[0][0] - wa_posit_pix[0][0]}\n\n"

for i, atom_name in enumerate(atom_names):

    average_atom += f"{atom_name}: avg Gaussian parameters = \n{a_params[i]}\n\n"


print(average_atom)

# The handle is deliberately not named `file`: that global holds the TEM image file name used by the
# image-loading cell, and rebinding it here would break a later re-run of that cell.
with open(base_dir + average_atom_file_name, "w") as report_file:
    report_file.write(average_atom)
    

## Average unit cell and weighted average unit cell from free atom optimization

In [None]:
resolution = 10      # image quality increases with this parameter

# NOTE(review): this file name is assigned here but nothing in this cell writes to it.

average_atom_file_name = "average_w_average_atom_posit_in.txt"


# Operation code -----------------------------------------------------------------------------------------------

# w_atom_im is the 2D array of the unit cell image with weighted average positions. w_atom_positions is the weighted average positions.
# atom_im is the 2D array of the unit cell image with average positions. atom_positions is the average positions.
# The inputs are wa_posit_pix / a_posit_pix, as produced by the average-positions cell.

max_pad_size = np.max(pad_size_list)

w_atom_im, w_atom_positions = gf.Gaussian_draw_high_resol(wa_posit_pix, wa_params, sum_image.shape, posit_pix_l[0][0], resolution, max_pad_size)
atom_im, atom_positions = gf.Gaussian_draw_high_resol(a_posit_pix, a_params, sum_image.shape, posit_pix_l[0][0], resolution, max_pad_size)

# w_atom_resolved_positions is a list with one element per atom type, holding the weighted average positions of that type.
# a_atom_resolved_positions is a list with one element per atom type, holding the average positions of that type.

w_atom_resolved_positions = tools.unpack_atom_type(w_atom_positions, 1, atom_num_list)
a_atom_resolved_positions = tools.unpack_atom_type(atom_positions, 1, atom_num_list)

fig, ax = plt.subplots(1, 2)

plt.suptitle("Free atom optimization average unit cell", y = 1.05)

# NOTE(review): the extent [0, 22, 106, 0] is hard-coded, presumably the unit-cell size in pixels of a
# specific data set - confirm it matches the current unit cell (sum_image.shape) before trusting the axes.
ax[0].imshow(atom_im, extent = [0, 22, 106, 0], cmap = "gray")
ax[0].set_title("Avearge positions \n w/o amplitude weight", pad = 15)

for type_positions in a_atom_resolved_positions:
    ax[0].scatter(type_positions[0], type_positions[1])

ax[1].imshow(w_atom_im, extent = [0, 22, 106, 0], cmap = "gray")
ax[1].set_title("Avearge positions \n w amplitude weight", pad = 15)

# Iterate the weighted list itself rather than borrowing the length of the unweighted one.
for type_positions in w_atom_resolved_positions:
    ax[1].scatter(type_positions[0], type_positions[1])

