# Import Dependencies

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import abtem
from ase import io
import ase
from ase.visualize import view
from scipy.ndimage import gaussian_filter
import py4DSTEM
from pymatgen.io.ase import AseAtomsAdaptor
from pathlib import Path
import pandas as pd
from mp_api import client


# Define Functions

In [None]:
# Method to create symmetrically non-equivalent [uvw] zone-axis vectors

def create_zone_axes(
    struct,
    num_zones:int = 10,
    angle_step_zone_axis:float = 0.5,
    k_max:float = 1,
    max_hkl:int = 5,
    seed:int = 42,
    tol_den:int = 4,
    random_zones:bool = False,
    **kwargs,
):
    """
    Generate low-index, symmetrically non-equivalent zone axes for a structure.

    The candidate directions come from the py4DSTEM orientation plan
    (``zone_axis_range='auto'``). They are rationalised to integer indices,
    de-duplicated and filtered to ``|index| <= max_hkl``.

    Parameters
    ----------
    struct: pymatgen structure
        Structure to generate zone axes for.
    num_zones: int
        Maximum number of zones to return. If fewer zones survive the
        filtering, all of them are returned and a message is printed.
    angle_step_zone_axis: float
        Step size passed to ``Crystal.orientation_plan`` (degrees per the
        py4DSTEM docs — confirm). A smaller step gives more zones, but the
        extra ones tend to be high index.
    k_max: float
        Passed to ``Crystal.calculate_structure_factors``. It is probably
        irrelevant for the orientation plan, so keep it small.
    max_hkl: int
        Largest absolute index allowed in a returned zone (inclusive).
    seed: int
        Seed for the random generator used when ``random_zones`` is True.
    tol_den: int
        Tolerance denominator passed to ``Crystal.rational_ind`` when
        rationalising the orientation vectors to integer indices.
    random_zones: bool
        If False, return the lowest-index zones (smallest Euclidean norm).
        If True, return a reproducible random selection (order shuffled).
    **kwargs:
        Accepted and ignored.

    Returns
    -------
    np.ndarray
        One zone axis per row.
    """

    # create a py4DSTEM crystal and compute its structure factors
    crystal = py4DSTEM.process.diffraction.Crystal.from_pymatgen_structure(struct)
    crystal.calculate_structure_factors(k_max)

    # calculate the orientation plan (only the zone axes are needed, so skip
    # the expensive correlation array)
    crystal.orientation_plan(
        angle_step_zone_axis=angle_step_zone_axis,
        zone_axis_range='auto',
        calculate_correlation_array=False,
    )

    # rationalise every orientation vector to integer indices and de-duplicate
    zones = np.unique(
        np.apply_along_axis(
            crystal.rational_ind,
            axis=1,
            arr=crystal.orientation_vecs,
            tol_den=tol_den,
        ),
        axis=0,
    )

    # keep only the low-index zones
    filtered_zones = zones[(np.abs(zones) <= max_hkl).all(axis=1)]

    if random_zones:
        # reproducible random selection; the permutation also shuffles the
        # result when there are fewer zones than requested
        rng = np.random.default_rng(seed=seed)
        selected = rng.permutation(filtered_zones)[:num_zones]
    else:
        # np.unique sorts rows lexicographically, so e.g. [-5, -5, -5] would
        # come first. Re-order by index magnitude so the slice keeps the
        # lowest-index zones. The stable sort keeps the lexicographic order
        # among zones of equal magnitude, so the result is deterministic.
        magnitude = np.linalg.norm(filtered_zones, axis=1)
        order = np.argsort(magnitude, kind='stable')
        selected = filtered_zones[order][:num_zones]

    if filtered_zones.shape[0] < num_zones:
        print(f"less zones ({filtered_zones.shape[0]}) returned than requested ({num_zones}), try increasing max_hkl to ensure sufficent zones")

    return selected


In [None]:
# rotate and tile the unit cell with cropping
def rotate_and_tile_atoms(
    struct,
    proj_dir = (0,0,1),
    proj_dir_cartesian = None,
    cell_size = (50,50,50),
    return_cell_params: bool = False,
    **kwargs,
    ):
    """
    Rotate and tile a unit cell to fill a larger orthogonal cell.

    The crystal is rotated so that the projection direction lies along +z.
    It is then tiled over every lattice translation that can reach the target
    box. Atoms outside ``[0, cell_size)`` along each axis are discarded.

    Parameters
    ----------
    struct: pymatgen.Structure
        pymatgen structure to be rotated, tiled and returned as an ase object.
    proj_dir: array-like
        (3, ) projection direction in terms of the u,v,w lattice vectors.
        Ignored when ``proj_dir_cartesian`` is given.
    proj_dir_cartesian: array-like, optional
        (3, ) projection direction in terms of (x,y,z) coordinates.
    cell_size: array-like
        (3, ) size of the orthogonal output cell (Angstroms).
    return_cell_params: bool
        If True, also return the raw positions and atomic numbers.
    **kwargs:
        Accepted and ignored.

    Returns
    -------
    structure: ase.Atoms
        Tiled and cropped structure with ``cell=cell_size``.
    xyz_tile: np.ndarray
        (M, 3) Cartesian positions. Only returned if ``return_cell_params``.
    num_tile: np.ndarray
        (M, ) atomic numbers. Only returned if ``return_cell_params``.

    Raises
    ------
    ValueError
        If the projection direction is the zero vector.
    """

    atoms = AseAtomsAdaptor.get_atoms(struct)
    # rows are the lattice vectors (Angstroms)
    uvw = np.array(atoms.get_cell(), dtype=float)
    pos = atoms.get_scaled_positions()   # fractional coordinates
    num = atoms.numbers

    # projection direction in Cartesian coordinates
    if proj_dir_cartesian is None:
        # combination of the lattice-vector ROWS: p @ uvw (not uvw @ p)
        w_proj = np.asarray(proj_dir, dtype=float) @ uvw
    else:
        w_proj = np.array(proj_dir_cartesian, dtype=float)
    w_norm = np.linalg.norm(w_proj)
    if w_norm == 0:
        raise ValueError("projection direction must be a non-zero vector")
    w_proj = w_proj / w_norm

    # Build an orthonormal frame (u, v, w) with w along the beam. The helper
    # axis must not be (anti-)parallel to w, hence abs(): without it, e.g.
    # w = (-1, 0, 0) picked x and the cross product below was zero (NaNs).
    if abs(w_proj[0]) < 1e-3:
        u_proj = np.array((1.0, 0.0, 0.0))
    else:
        u_proj = np.array((0.0, 1.0, 0.0))
    v_proj = np.cross(w_proj, u_proj)
    v_proj /= np.linalg.norm(v_proj)
    u_proj = np.cross(v_proj, w_proj)

    # lattice vectors expressed in the rotated (u, v, w) frame
    proj = np.linalg.inv(np.vstack((u_proj, v_proj, w_proj)))
    uvw_proj = uvw @ proj

    # Determine the tiling range from the fractional coordinates of the 8
    # corners of the target box.
    pos_corner = np.array(
        [
            (i * cell_size[0], j * cell_size[1], k * cell_size[2])
            for k in (0, 1)
            for j in (0, 1)
            for i in (0, 1)
        ],
        dtype=float,
    )
    abc = pos_corner @ np.linalg.inv(uvw_proj)
    lo = np.floor(abc.min(axis=0)).astype('int')
    hi = np.ceil(abc.max(axis=0)).astype('int')

    # tiling indices: every (a, b, c) translation paired with every basis site
    a, b, c, ind = np.meshgrid(
        np.arange(lo[0], hi[0]),
        np.arange(lo[1], hi[1]),
        np.arange(lo[2], hi[2]),
        np.arange(pos.shape[0]),
        indexing='ij',
    )
    abc_ind_tile = np.vstack((a.ravel(), b.ravel(), c.ravel(), ind.ravel()))

    # Cartesian coordinates in the rotated frame
    abc_tile = abc_ind_tile[:3] + pos[abc_ind_tile[3, :], :].T
    xyz_tile = abc_tile.T @ uvw_proj

    # atomic identities
    num_tile = num[abc_ind_tile[3, :]]

    # delete atoms outside of the cell boundaries
    size = np.asarray(cell_size, dtype=float)
    keep = np.all((xyz_tile >= 0.0) & (xyz_tile < size), axis=1)
    xyz_tile = xyz_tile[keep, :]
    num_tile = num_tile[keep]

    # convert to an ASE structure
    structure = ase.Atoms(
        positions=xyz_tile,
        numbers=num_tile,
        cell=cell_size,
    )

    if not return_cell_params:
        return structure
    return structure, xyz_tile, num_tile

# Example

In [None]:
# Materials Project ID of the structure to simulate (picked randomly;
# 'mp-54' was also tried)
material_id = 'mp-19'

# connect to Materials Project and fetch the structure; the context manager
# closes the client's session once the download is done
with client.MPRester() as mpr:
    struct = mpr.get_structure_by_material_id(material_id)


In [None]:
# select the zone axes to simulate, using create_zone_axes' default settings
zones = create_zone_axes(struct=struct)

In [None]:
# inspect the selected zone axes (notebook displays the value of this expression)
zones

In [None]:
# generate one rotated, tiled and cropped structure per zone axis
atoms_list = [
    rotate_and_tile_atoms(
        struct=struct,
        proj_dir=zone,
        cell_size=(50, 50, 100),
    )
    for zone in zones
]

In [None]:
# preview every tiled structure
for supercell in atoms_list:
    abtem.show_atoms(supercell)

In [None]:
# ABTEM SIMULATION
# Multislice simulation of a 300 keV plane wave through the first structure.

potential = abtem.Potential(
    atoms_list[0],
    sampling=.08,
    slice_thickness=2,
    parametrization='kirkland',
    projection='finite',
)

wave = abtem.PlaneWave(energy=300e3)  # acceleration voltage in eV

exit_wave = wave.multislice(potential)

exit_wave.intensity().show(figsize=(6,6))


In [None]:
# Contrast transfer function at the same energy as the incident wave.
# NOTE(review): semiangle_cutoff=4500 mrad (4.5 rad) is far beyond any
# physical aperture; presumably meant to effectively disable the cutoff —
# confirm this is intended.
ctf = abtem.CTF(
    energy=wave.energy,
    semiangle_cutoff=4500,   # mrad
    focal_spread=40,         # Å
    defocus=-45.46,          # Å
    Cs=-7e-6 * 1e10,         # Å (-7e-6 m converted to Å)
)

ctf.profiles(max_angle=50).show(legend=True)

In [None]:
# apply the CTF to the exit wave and plot the resulting image intensity
image_wave = exit_wave.apply_ctf(ctf)
image_intensity = image_wave.intensity()
image_intensity.show()