In [1]:
from mp_api.client import MPRester
from pymatgen.io.ase import AseAtomsAdaptor
from ase.visualize import view
import ase
from ase.io.trajectory import Trajectory
import math
import abtem
import numpy as np
from tqdm import tqdm
import os
import abtem
import pandas as pd
import plotly.express as px
import matplotlib

# Creating the structure in ASE

In [2]:
def get_structure_from_MPR(matarial_ids, api_key = None):
    """
    Fetch a structure from the Materials Project database and return it as an
    ``ase.Atoms`` object that abTEM can use.

    Parameters
    ----------
    matarial_ids : str or list of str
        Materials Project id(s), e.g. ``"mp-48"``. The parameter keeps its
        historical spelling so existing keyword calls keep working.
    api_key : str, optional
        Materials Project API key. When omitted, the ``MP_API_KEY`` environment
        variable is used if it is set and non-empty. For further information visit
        https://next-gen.materialsproject.org/api

    Returns
    -------
    atoms : ase.Atoms or None
        Structure of the FIRST document returned by the search. Only one structure
        is returned even when several ids are given. ``None`` when no API key is
        available; a message is printed in that case.

    Raises
    ------
    ValueError
        If the search returns no documents for the given id(s).
    """
    # Prefer an explicit key; otherwise fall back to the environment so keys need
    # not be hard-coded in the notebook. Treat an empty variable as "not set".
    if api_key is None:
        api_key = os.environ.get("MP_API_KEY") or None

    if api_key is None:
        print('No api-key given - Please get api-key by logging in to Materials Project Database.')
        return None

    with MPRester(api_key) as mpr:
        docs = mpr.summary.search(material_ids=matarial_ids, fields=["structure"])

    if not docs:
        raise ValueError(f"No Materials Project entry found for {matarial_ids!r}")
    return AseAtomsAdaptor().get_atoms(docs[0].structure)

In [3]:
def center_rotation(atoms, theta):
    """
    Rotate ``atoms`` in place, anticlockwise by ``theta`` degrees about the z-axis,
    around the geometric centre of its cell.

    The centre is half the sum of the three cell vectors, i.e. the middle of the
    cell parallelepiped. For an orthogonal cell this equals half of each diagonal
    entry. For skewed (e.g. hexagonal) cells it also includes the off-diagonal
    components, which a diagonal-only centre would miss.

    Parameters
    ----------
    atoms : ase.Atoms
        Structure to rotate; modified in place.
    theta : int or float
        Rotation angle in degrees.

    Returns
    -------
    None
    """
    rot_center = np.asarray(atoms.cell, dtype=float).sum(axis=0) * 0.5
    atoms.rotate(theta, 'z', center=rot_center)

In [4]:
def sampling_parameters(ase_structure, energy, resolution = 1):
    """
    Derive TEM sampling parameters from a structure's cell, the electron energy
    and the desired image resolution.

    The pixel size starts at half the requested resolution (Nyquist sampling).
    The pixel count is rounded up to the next power of two (FFT-friendly) and is
    never less than 1. The pixel size is then recomputed so that
    ``gpts * pixel size == extent``.

    Parameters
    ----------
    ase_structure : ase.Atoms
        Structure whose cell defines the image extent. The extent is the largest
        component of the first two cell vectors, i.e. the larger side length for a
        cell that is orthogonal in the xy plane.
    energy : float or int
        Energy of the electron wave in eV. Typical values are 80,000 to 300,000 eV.
    resolution : float or int
        Desired resolution of the TEM image in Å. Must be positive.

    Returns
    -------
    dict
        "energy" : electron energy in eV (as given).
        "resolution" : resolution in Å (as given).
        "extent" : image side length in Å, same in x and y.
        "Pixel size" : real-space pixel size in Å; at most resolution/2 because of
            Nyquist sampling.
        "gpts" : number of pixels N per side of the N x N grid; a power of two >= 1.
        "Wavelength" : electron wavelength in Å, as reported by abtem.PlaneWave.
        "Reciprocal pixel size" : 1/extent in Å^-1, rounded to 6 decimals.
        "k_max nyquist" : Nyquist frequency 1/(2 * pixel size) in Å^-1, rounded.
        "k_max antialiasing" : 2/3 of the Nyquist frequency in Å^-1, rounded;
            angular limit used to avoid aliasing.
        "Angular limitied resolution" : antialiasing k_max times the wavelength,
            expressed in mrad.

    Raises
    ------
    ValueError
        If ``resolution`` is not positive, or the cell has no positive xy extent.
    """
    if resolution <= 0:
        raise ValueError(f"resolution must be positive, got {resolution!r}")

    f_nyquist = 0.5  # sample at twice the frequency of the finest feature
    Picture_size = np.max(ase_structure.cell[0:2])  # Å, largest xy cell component
    if not Picture_size > 0:
        raise ValueError("structure cell has no positive extent in the xy plane")

    delta_x = f_nyquist * resolution  # initial pixel size in Å
    # Round the pixel count up to a power of two. Clamping the exponent at 0 keeps
    # gpts an integer >= 1 even when the cell is smaller than one pixel
    # (a negative exponent would otherwise give a fractional gpts).
    exponent = max(math.ceil(math.log2(Picture_size / delta_x)), 0)
    N = 2 ** exponent
    delta_x = Picture_size / N  # pixel size consistent with the rounded gpts

    plane_wave = abtem.PlaneWave(gpts=N, extent=Picture_size, energy=energy)
    wavelength = plane_wave.wavelength  # Å

    reciprocal_P = round(1 / (N * delta_x), 6)
    k_max = round(f_nyquist / delta_x, 6)  # Å^-1
    k_max_antialiasing = round(2 / 3 * k_max, 6)
    alpha = k_max_antialiasing * wavelength * 10**3  # mrad

    return {"energy": energy, "resolution": resolution, "extent": Picture_size,
            "Pixel size": delta_x, "gpts": N, "Wavelength": wavelength,
            "Reciprocal pixel size": reciprocal_P, "k_max nyquist": k_max,
            "k_max antialiasing": k_max_antialiasing, "Angular limitied resolution": alpha}


In [5]:

def create_supercell(unit_cell, extent):
    """
    Repeat ``unit_cell`` in-plane so the result can hold a square of side
    ``extent`` centred on the supercell centre, at any rotation angle.

    The required coverage is a circle of radius half the square's diagonal
    (``extent * sqrt(2) / 2``). The spacing between successive lattice lines
    parallel to ``u`` is ``area / |u|``, and parallel to ``v`` is ``area / |v|``,
    where ``area`` is the in-plane unit-cell area. The repetition counts follow
    from these spacings and are rounded up to even integers, so the centre sits
    on a lattice point. Any in-plane orientation of ``u`` and ``v`` is supported.

    Parameters
    ----------
    unit_cell : ase.Atoms
        The unit cell to be replicated.
    extent : float
        Side length of the square to cover, in Å.

    Returns
    -------
    ase.Atoms
        The supercell ``unit_cell * (n_u, n_v, 1)``.

    Raises
    ------
    ValueError
        If the first two cell vectors are parallel (zero in-plane area).
    """
    # Farthest distance from the square's centre to any of its corners.
    min_dist = extent * np.sqrt(2) / 2

    u = np.asarray(unit_cell.cell[0][0:2], dtype=float)  # first in-plane lattice vector
    v = np.asarray(unit_cell.cell[1][0:2], dtype=float)  # second in-plane lattice vector
    area = abs(u[0] * v[1] - u[1] * v[0])  # in-plane area of one unit cell
    if area == 0:
        raise ValueError("the first two cell vectors are parallel in the XY plane")

    # Repeats needed to span 2*min_dist across the lines parallel to u (repeating v)
    # and across the lines parallel to v (repeating u).
    min_rep_v = 2 * min_dist * np.linalg.norm(u) / area
    min_rep_u = 2 * min_dist * np.linalg.norm(v) / area

    # Round up to the nearest even integer so the supercell is symmetric about its centre.
    min_rep_v = np.ceil(min_rep_v / 2) * 2
    min_rep_u = np.ceil(min_rep_u / 2) * 2

    return unit_cell.copy() * (int(min_rep_u), int(min_rep_v), 1)

In [6]:
def cut_supercell(supercell, extent, theta=0, X_translation = 0, Y_translation = 0, rounding_error_limit=0.001):
    """
    Cut a square window of side ``extent`` out of a rotated and translated copy
    of ``supercell``.

    Steps:
    1. The copy is rotated anticlockwise by ``theta`` degrees about the z-axis
       through the centre of the supercell's in-plane cell.
    2. It is shifted so that this centre lands at
       ``(extent/2 + X_translation, extent/2 + Y_translation)``.
    3. The first two cell vectors are replaced by ``[extent, 0, 0]`` and
       ``[0, extent, 0]``.
    4. Every atom whose x or y coordinate lies outside ``[0, extent]`` (with
       ``rounding_error_limit`` of slack) is removed.

    Parameters
    ----------
    supercell : ase.Atoms
        The supercell to cut; not modified.
    extent : float
        Side length of the square window in Å.
    theta : float
        Rotation angle in degrees.
    X_translation, Y_translation : float
        Extra in-plane shift of the rotated supercell in Å, applied before cutting.
    rounding_error_limit : float
        Tolerance in Å. Atoms with min(x, y) > -limit and max(x, y) <= extent + limit
        are kept.

    Returns
    -------
    ase.Atoms
        The cut structure.
    """
    half_extent = extent / 2  # centre of the target square in x and y

    # Centre of the supercell: middle of the in-plane cell, half of the z cell length.
    cell = supercell.cell
    center = np.array([
        (cell[0][0] + cell[1][0]) / 2,
        (cell[0][1] + cell[1][1]) / 2,
        cell[2][2] / 2,
    ])

    cut = supercell.copy()
    cut.rotate(theta, 'z', center=center)
    cut.translate(np.array([half_extent - center[0] + X_translation,
                            half_extent - center[1] + Y_translation,
                            0]))

    # Redraw the cell as an extent x extent square in XY; the z vector is kept.
    cut.cell[0] = [extent, 0, 0]
    cut.cell[1] = [0, extent, 0]

    # Vectorised filter (one deletion) instead of popping atoms one by one,
    # which was quadratic in the number of atoms.
    xy = cut.positions[:, 0:2]
    keep = (xy.min(axis=1) > -rounding_error_limit) & (xy.max(axis=1) <= extent + rounding_error_limit)
    del cut[~keep]
    return cut


In [7]:
def combine_layers(layer1, layer2, interlayer_distance=1.0, xy_padding = 3.0, z_padding = 1.0):
    """
    Stack ``layer2`` on top of ``layer1`` and return the merged structure.

    ``layer2`` is lifted along z by the height of ``layer1``'s cell
    (``layer1.cell[2][2]``) plus ``interlayer_distance``. The two are then joined,
    with ``layer2``'s atoms first. Finally the structure is re-centred with
    ``xy_padding`` of vacuum in x and y, then ``z_padding`` of vacuum in z.
    Neither input is modified.

    Parameters
    ----------
    layer1 : ase.Atoms
        Bottom layer.
    layer2 : ase.Atoms
        Top layer.
    interlayer_distance : float
        Gap in Å added above the top of ``layer1``'s cell.
    xy_padding : float
        Vacuum in Å passed to ``center`` for the x and y axes.
    z_padding : float
        Vacuum in Å passed to ``center`` for the z axis.

    Returns
    -------
    ase.Atoms
        The combined structure.
    """
    lift = layer1.cell[2][2] + interlayer_distance

    stacked = layer2.copy()
    stacked.translate(np.array([0, 0, lift]))
    stacked = stacked + layer1  # layer2 atoms first, layer1 atoms appended

    # NOTE(review): the z re-centring below sets the z cell length again, so this
    # value presumably does not survive; kept for identical behaviour — confirm.
    stacked.cell[2][2] = layer1.cell[2][2] + layer2.cell[2][2] + interlayer_distance

    stacked.center(vacuum=xy_padding, axis=(0, 1))
    stacked.center(vacuum=z_padding, axis=2)
    return stacked

In [8]:
def visualize_snippet(atoms, start= (0.0, 0.0), end=None, plane: tuple[float, float] | str = "xy",
    ax: matplotlib.axes.Axes = None,
    scale: float = 0.75,
    title: str = None,
    numbering: bool = False,
    show_periodic: bool = False,
    figsize: tuple[float, float] = [10, 10],
    legend: bool = False,
    merge: float = 1e-2,
    tight_limits: bool = True,
    show_cell: bool = None):
    """
    Display 2D projection of atoms as a matplotlib plot.

    Parameters
    ----------
    atoms : ase.Atoms
        The atoms to be shown.
    start : two float or Atom, optional
        Start corner of the scan [Å]. May be given as fractional coordinate if `fractional=True`. Default is (0., 0.).
    end : two float or Atom, optional
        End corner of the scan [Å]. May be given as fractional coordinate if `fractional=True`.
        Default is None, the scan end point will match the extent of the potential.
    plane : str, two float
        The projection plane given as a concatenation of 'x' 'y' and 'z', e.g. 'xy', or as two floats representing the
        azimuth and elevation angles of the viewing direction [degrees], e.g. (45, 45).
    ax : matplotlib.axes.Axes, optional
        If given the plots are added to the axes.
    scale : float
        Factor scaling their covalent radii for the atom display sizes (default is 0.75).
    title : str
        Title of the displayed image. Default is None.
    numbering : bool
        Display the index of the Atoms as a number. Default is False.
    show_periodic : bool
        If True, show the periodic images of the atoms at the cell boundary.
    figsize : two int, optional
        The figure size given as width and height in inches, passed to `matplotlib.pyplot.figure`.
    legend : bool
        If True, add a legend indicating the color of the atomic species.
    merge: float
        To speed up plotting large numbers of atoms, those closer than the given value [Å] are merged.
    tight_limits : bool
        If True the limits of the plot are adjusted
    kwargs : Keyword arguments for matplotlib.collections.PatchCollection.

    Returns
    -------
    matplotlib.figure.Figure, matplotlib.axes.Axes
    """
    #start by moving the structure to have start at the origin
    X = atoms.cell[0][0] + atoms.cell[1][0]
    Y = atoms.cell[0][1] + atoms.cell[1][1]
    translation = np.array([-X*start[0], -Y*start[1], 0])
    vis_atoms = atoms.copy()
    vis_atoms.translate(translation) #translate the atoms to have start at the origin

    x_dim  =end[0]-start[0]
    y_dim = end[1]-start[1]

    #get the new X and Y cell coordinates:
    vis_atoms.cell[0][0] = atoms.cell[0][0]*(x_dim)
    vis_atoms.cell[1][1] = atoms.cell[1][1]*(y_dim)

    return abtem.show_atoms(atoms=vis_atoms, plane=plane,
    ax=ax,
    scale =scale,
    title = title,
    numbering = numbering,
    show_periodic= show_periodic,
    figsize = figsize,
    legend = legend,
    merge = merge,
    tight_limits = tight_limits,
    show_cell = show_cell)

# Point plot scatter simulation

In [9]:
#functions
import pandas as pd
import numpy as np
import plotly.express as px
import array
def get_covalent_radii(ase_structre):
    """Look up the covalent radius of every atom in a structure.

    Parameters
    ----------
    ase_structre : ase.Atoms
        Structure whose atoms are looked up.

    Returns
    -------
    list
        One entry per atom, in atom order, taken from ``ase.data.covalent_radii``
        indexed by atomic number.
    """
    radii_table = ase.data.covalent_radii
    return [radii_table[z] for z in ase_structre.get_atomic_numbers()]


In [10]:

def get_atom_2d_positions(ase_structure, round = 3, bias = (0, 0)):
    """Return the XY positions of the atoms, rounded and shifted by ``bias``.

    Parameters
    ----------
    ase_structure : ase.Atoms
        Structure whose positions are read with ``get_positions()``.
    round : int
        Number of decimal places the coordinates are rounded to (default 3).
        Rounding happens before the bias is subtracted.
    bias : sequence of two floats
        Offset SUBTRACTED from the x and y coordinates (default (0, 0)). For
        example, pass the 2D cell centre to centre the layer on the origin.

    Returns
    -------
    xy_pos : np.ndarray
        Array of shape (n_atoms, 2) with the rounded, shifted x and y coordinates.
    """
    # `round` shadows the builtin here; it is the decimals count for np.round.
    xy_pos = np.round(ase_structure.get_positions()[:, 0:2], decimals=round)
    return xy_pos - np.asarray(bias, dtype=float)[:2]



In [11]:

def rotate_points(points, angle_degrees, center= (0, 0)):
    """
    Rotate 2D points anticlockwise by ``angle_degrees`` around ``center``.

    The standard rotation matrix R = [[cos, -sin], [sin, cos]] acts on column
    vectors (https://en.wikipedia.org/wiki/Rotation_matrix). Here the points are
    rows, so R is applied through its transpose (``p @ R.T``). Multiplying rows by
    R directly would rotate clockwise. The anticlockwise sense matches the
    positive-angle z rotations applied via ``atoms.rotate(theta, 'z', ...)``
    elsewhere in this notebook.

    Parameters
    ----------
    points : array_like, shape (n, 2)
        Points to rotate, one per row.
    angle_degrees : float
        Rotation angle in degrees; positive is anticlockwise.
    center : sequence of two floats
        Centre of rotation (default the origin).

    Returns
    -------
    np.ndarray, shape (n, 2)
        The rotated points.
    """
    angle = np.radians(angle_degrees)
    cos_a, sin_a = np.cos(angle), np.sin(angle)
    rotation_matrix = np.array([[cos_a, -sin_a],
                                [sin_a, cos_a]])

    offset = np.asarray(center, dtype=float)
    shifted = np.asarray(points, dtype=float) - offset
    return shifted @ rotation_matrix.T + offset


In [12]:
def center_2D(ase_structure):
    """Return the in-plane centre of a structure's cell.

    Parameters
    ----------
    ase_structure : ase.Atoms
        Structure whose first two cell vectors define the in-plane cell.

    Returns
    -------
    list
        ``[X, Y]``: half the sum of the first two cell vectors' x and y components.
    """
    first, second = ase_structure.cell[0], ase_structure.cell[1]
    return [(first[0] + second[0]) / 2, (first[1] + second[1]) / 2]

In [13]:

def generate_plotdata(supercell_1, supercell_2, theta):
    """Build a long-format DataFrame for animating a twisted bilayer with plotly.

    Layer 1 ("fixed layer") is held still while layer 2 ("twist layer") is rotated
    by each angle in ``theta``. Every angle contributes one row per atom of both
    layers, layer 1 first.

    Parameters
    ----------
    supercell_1 : ase.Atoms
        The fixed layer.
    supercell_2 : ase.Atoms
        The layer that is rotated.
    theta : iterable of float
        Rotation angles in degrees, passed to ``rotate_points``.

    Returns
    -------
    pandas.DataFrame
        Columns: angle, x, y, species, radius, layer, atom_index.
        - Coordinates are relative to each layer's own 2D cell centre.
        - atom_index is 1-based and continuous across both layers, so plotly can
          track each atom from frame to frame.
    """
    # Positions relative to each layer's in-plane centre, so rotating about the
    # origin rotates layer 2 about its own centre.
    fixed_xy = get_atom_2d_positions(supercell_1, bias=center_2D(supercell_1))
    twist_xy = get_atom_2d_positions(supercell_2, bias=center_2D(supercell_2))
    n_fixed = len(fixed_xy)
    n_twist = len(twist_xy)

    # Per-atom values that do not depend on the angle.
    species = np.append(np.array(supercell_1.symbols), np.array(supercell_2.symbols))
    radius = get_covalent_radii(supercell_1) + get_covalent_radii(supercell_2)
    fixed_labels = ['fixed layer'] * n_fixed
    twist_labels = ['twist layer'] * n_twist
    fixed_ids = np.arange(1, n_fixed + 1)
    twist_ids = np.arange(n_fixed + 1, n_fixed + n_twist + 1)

    columns = {name: [] for name in
               ('angle', 'x', 'y', 'species', 'radius', 'layer', 'atom_index')}

    for angle in theta:
        rotated_xy = rotate_points(twist_xy, angle)
        columns['angle'].extend([angle] * (n_fixed + n_twist))
        columns['x'].extend(fixed_xy[:, 0])
        columns['x'].extend(rotated_xy[:, 0])
        columns['y'].extend(fixed_xy[:, 1])
        columns['y'].extend(rotated_xy[:, 1])
        columns['species'].extend(species)
        columns['radius'].extend(radius)
        columns['layer'].extend(fixed_labels)
        columns['layer'].extend(twist_labels)
        columns['atom_index'].extend(fixed_ids)
        columns['atom_index'].extend(twist_ids)

    return pd.DataFrame(columns)
     
    

In [14]:
def create_custom_scatter_plot(
    plotting_data,
    extent,
    size=0.25,
    color_column="species",
    hover_column="layer",
    color_sequence=['white'],
    picture_size=800,
    drop_animation_buttons=True,
    plot_bgcolor='lightgrey',
    paper_bgcolor='lightgrey',
    showgrid=False,
    zeroline=False,
    title=None,
):
    """
    Create an animated plotly scatter plot of atom positions, one frame per angle.

    Parameters
    ----------
    plotting_data : pandas.DataFrame
        Data with columns x, y, radius, angle and atom_index, plus the columns named
        by ``color_column`` and ``hover_column`` (e.g. from generate_plotdata).
    extent : float
        Side length of the square view; both axes span [-extent/2, extent/2].
    size : float
        Marker scale factor. The largest radius is drawn with
        ``size_max = radius_max / extent * picture_size * size``.
    color_column : str
        Column used for colouring points.
    hover_column : str
        Column shown as the hover name.
    color_sequence : list of str
        Discrete colours to use (not modified).
    picture_size : int
        Width and height of the figure in pixels.
    drop_animation_buttons : bool
        If True (default), remove the animation control buttons
        (``layout.updatemenus``); if False, keep them.
    plot_bgcolor : str
        Background colour of the plot area.
    paper_bgcolor : str
        Background colour of the surrounding paper.
    showgrid : bool
        Whether to show grid lines.
    zeroline : bool
        Whether to show zero lines.
    title : str, optional
        Figure title.

    Returns
    -------
    fig : plotly.graph_objects.Figure
        The resulting plotly figure.
    """
    half = extent / 2
    radius_max = plotting_data.radius.max()

    # Marker size in pixels, scaled from the largest atomic radius relative to the view.
    Marker_size = radius_max / (extent) * (picture_size * size)

    fig = px.scatter(plotting_data, x="x", y="y", animation_frame="angle", animation_group="atom_index",
                     size="radius", size_max=Marker_size, color=color_column, hover_name=hover_column,
                     range_x=[-half, half], range_y=[-half, half],
                     color_discrete_sequence=color_sequence,
                     width=picture_size, height=picture_size,
                     title=title)
    fig.update_yaxes(
        scaleanchor="x",  # equal aspect ratio so lattices are not distorted
        scaleratio=1,
        showgrid=showgrid,
        zeroline=zeroline,
    )
    fig.update_xaxes(
        showgrid=showgrid,
        zeroline=zeroline,
    )
    fig.update_layout(plot_bgcolor=plot_bgcolor, paper_bgcolor=paper_bgcolor)
    if drop_animation_buttons:
        fig["layout"].pop("updatemenus", None)
    return fig
    
        

# TEM simulation functions

In [15]:
def potential_build(ase_structure, energy, 
                    slice_thickness = 1, parametrization="lobato", 
                    projection="finite", resolution=1, gpts=None):
    """Build an abtem.Potential for a structure. Units are Å and eV.

    Parameters
    ----------
    ase_structure : ase.Atoms (or another object accepted by abtem.Potential)
        Structure to build the potential from.
    energy : float or int
        Electron energy in eV (typically 80,000 to 300,000 eV). Only used to
        derive ``gpts`` when ``gpts`` is None.
    slice_thickness : float
        Thickness of each multislice slice in Å.
    parametrization : str
        Potential parametrization passed to abtem.Potential (default "lobato").
    projection : str
        Either "finite" or "infinite", passed to abtem.Potential.
    resolution : float or int
        Desired image resolution in Å. Only used to derive ``gpts`` when ``gpts``
        is None.
    gpts : int, optional
        Grid points per side. When None, it is taken from
        ``sampling_parameters(ase_structure, energy, resolution)``.

    Returns
    -------
    potential : abtem.Potential
    """
    # Derive the grid only when it was not given (previously the sampling
    # parameters were computed twice, and even when gpts was supplied).
    if gpts is None:
        gpts = sampling_parameters(ase_structure=ase_structure, energy=energy,
                                   resolution=resolution)['gpts']

    return abtem.Potential(ase_structure, gpts=gpts, parametrization=parametrization,
                           slice_thickness=slice_thickness, projection=projection)

In [16]:
def TEM_exit_wave(potential, input_wave, compute = True):
    """Return the HRTEM exit wave of ``input_wave`` after passing through
    ``potential``, using the multislice approximation.

    Parameters
    ----------
    potential : abtem.Potential
        Potential object with slice information.
    input_wave : abtem.PlaneWave or abtem.Waves
        Incoming wave; its ``multislice`` method is called with ``potential``.
    compute : bool
        If True (default), call ``compute()`` on the multislice result and return
        what it returns. Presumably this materialises a lazy (dask-backed) result.
        If False, the multislice result is returned as-is.

    Returns
    -------
    exit_wave : abtem.Waves
        Exit electron wave.
    """
    exit_wave = input_wave.multislice(potential)
    if compute:
        # Keep the returned object instead of relying on compute() to update in place.
        exit_wave = exit_wave.compute()
    return exit_wave
    

In [17]:

def generate_ctf(Cs = -20e-6*1e10, energy=300e3, defocus ="scherzer" ):
    """Generate a contrast transfer function (CTF) to apply to an exit wave.

    A first abtem.CTF is built only to resolve ``defocus``, e.g. the "scherzer"
    keyword, to a number. That value is printed, and the returned CTF is built
    from explicit aberration coefficients: C10 = -defocus and C30 = Cs.

    Parameters
    ----------
    Cs : float
        Spherical aberration in Å. The default is -20 µm written in Å
        (-20e-6 m * 1e10 Å/m).
    energy : float
        Energy of the electron wave in eV.
    defocus : str or float
        Either "scherzer" (resolved by abtem) or a defocus value in Å.

    Returns
    -------
    ctf : abtem.CTF
        Contrast transfer function for the given parameters.
    """
    reference_ctf = abtem.CTF(Cs=Cs, energy=energy, defocus=defocus)
    resolved_defocus = reference_ctf.defocus

    print(f"defocus = {resolved_defocus:.2f} Å")

    return abtem.CTF(
        aberration_coefficients={"C10": -resolved_defocus, "C30": Cs},
        energy=reference_ctf.energy,
    )
   

In [18]:
def gen_stack_from_unitcell(unit_cell_layer1,  Theta, extent, unit_cell_layer2=None, interlayer_dist=1.5, xy_padding = 2, z_padding = 1,
            X_translation = [0], Y_translation = [0], sliceThickness = 0.5,
            resolution = 0.5, save_potentials = False, save_structures = False,
            energy = 300e3):
    """Simulate HRTEM exit waves of a twisted bilayer for every combination of
    rotation angle and in-plane translation of the second layer.

    Both layers are built from unit cells: each is expanded with create_supercell
    and cut to an ``extent`` x ``extent`` square with cut_supercell. Layer 1 is cut
    once, without rotation or translation. For every (angle, x shift, y shift),
    layer 2 is cut with that rotation and shift, stacked onto layer 1 with
    combine_layers, turned into a potential, and propagated with the multislice
    method.

    Parameters
    ----------
    unit_cell_layer1 : ase.Atoms
        Unit cell of the first (fixed) layer.
    Theta : list of floats
        Rotation angles of the second layer in degrees.
    extent : float or int
        Side length in Å of the square cut from each layer.
    unit_cell_layer2 : ase.Atoms, optional
        Unit cell of the second layer. If not given, a copy of the first is used.
    interlayer_dist : float or int
        Distance between the two layers in Å (passed to combine_layers).
    xy_padding : float or int
        Vacuum padding in the XY plane in Å.
    z_padding : float or int
        Vacuum padding in the Z direction in Å.
    X_translation : list of floats
        Translations of the second layer in the X direction in Å.
    Y_translation : list of floats
        Translations of the second layer in the Y direction in Å.
    sliceThickness : float or int
        Thickness of each multislice slice in Å.
    resolution : float or int
        Desired size of distinguishable features in the final TEM image in Å.
        This is not the pixel size, because of Nyquist sampling.
    save_potentials : bool
        If True, every potential is collected and returned in the metadata.
    save_structures : bool
        If True, every combined structure is collected and returned in the metadata.
    energy : float or int
        Energy of the electron wave in eV (typically 80,000 to 300,000 eV).

    Returns
    -------
    exit_wave_stack : result of nested abtem.stack calls
        Exit waves stacked over Y translations, then X translations, then angles,
        each level labelled with human-readable offsets.
    metadata : dict
        'Rotation' : Theta as given.
        'x_translation' : X_translation as given.
        'y_translation' : Y_translation as given.
        'ASE structure' : combined structures (empty unless save_structures).
        'Potential' : potentials (empty unless save_potentials).
    """
    # Materialise the sweep values once, so labels and loops see the same items.
    angles = list(Theta)
    x_shifts = list(X_translation)
    y_shifts = list(Y_translation)

    # If layer 2 is not given, build a homo-bilayer from layer 1.
    if unit_cell_layer2 is None:
        unit_cell_layer2 = unit_cell_layer1.copy()

    # Supercells large enough to cover the square window at any rotation.
    supercell_1 = create_supercell(unit_cell_layer1, extent)
    supercell_2 = create_supercell(unit_cell_layer2, extent)

    # Layer 1 never moves, so it is cut once.
    layer_1 = cut_supercell(supercell_1, extent)

    # Derive sampling from layer 1 with the same xy vacuum that combine_layers adds.
    sampling_layer = layer_1.copy()
    sampling_layer.center(vacuum=xy_padding, axis=(0, 1))
    sampling_pam = sampling_parameters(sampling_layer, energy=energy, resolution=resolution)

    # NOTE(review): the plane-wave extent comes from layer 1 alone. If
    # combine_layers produces a slightly different xy cell (e.g. layer 2 atoms
    # reaching further), the wave and potential grids may disagree — confirm
    # how abtem handles that.
    plane_wave = abtem.PlaneWave(energy=sampling_pam['energy'],
                                 gpts=sampling_pam['gpts'],
                                 extent=sampling_pam['extent'])

    # Labels for the three stacked axes.
    Theta_meta_list = ["rotation offset = " + str(t) + "°" for t in angles]
    X_meta_list = ["x-axis offset = " + str(x) + "Å" for x in x_shifts]
    Y_meta_list = ["y-axis offset = " + str(y) + "Å" for y in y_shifts]

    ase_structures = []
    potentials = []

    T_s = []  # one stack per angle
    for phi in tqdm(angles, desc='Angle number:', leave=False):
        X_s = []  # one stack per x translation
        for x in tqdm(x_shifts, desc='X translation:', leave=False):
            Y_s = []  # one exit wave per y translation
            for y in tqdm(y_shifts, desc='Y translation:', leave=False):
                # Rotate and shift layer 2, cut it to the window, stack on layer 1.
                layer_2 = cut_supercell(supercell_2, extent, theta=phi,
                                        X_translation=x, Y_translation=y)
                combined_layers = combine_layers(layer_1, layer_2,
                                                 interlayer_distance=interlayer_dist,
                                                 xy_padding=xy_padding, z_padding=z_padding)
                if save_structures:
                    ase_structures.append(combined_layers)

                potential = potential_build(combined_layers,
                                            energy=sampling_pam['energy'],
                                            slice_thickness=sliceThickness,
                                            projection="finite",
                                            resolution=sampling_pam['resolution'],
                                            gpts=sampling_pam['gpts'])
                if save_potentials:
                    potentials.append(potential)

                Y_s.append(TEM_exit_wave(potential, plane_wave))

            X_s.append(abtem.stack(Y_s, Y_meta_list))
        T_s.append(abtem.stack(X_s, X_meta_list))

    print("Creating the exit wave stack...")
    exit_wave_stack = abtem.stack(T_s, Theta_meta_list)
    print("Done!")

    metadata = {'Rotation': Theta, 'x_translation': X_translation, 'y_translation': Y_translation,
                'ASE structure': ase_structures, 'Potential': potentials}

    return exit_wave_stack, metadata

In [19]:
def gen_stack_from_supercell(supercell_layer1, Theta, extent, supercell_layer2=None, interlayer_dist=1.5, xy_padding = 2, z_padding = 1,
            X_translation = None, Y_translation = None, sliceThickness = 0.5,
            resolution = 0.5, save_potentials = False, save_structures = False,
            energy = 300e3):
    """Simulate TEM exit waves of a bilayer for many rotations and translations.

    Layer 1 is cut to ``extent`` once and kept fixed. For every combination of
    angle in ``Theta``, x-offset in ``X_translation`` and y-offset in
    ``Y_translation``, layer 2 is cut via ``cut_supercell`` with that angle and
    those offsets. It is then combined with layer 1, turned into a multislice
    potential and propagated with a single shared plane wave. The exit waves
    are nested into an abTEM stack with the axis order
    (rotation, x-offset, y-offset).

    parameters
    ---------
    supercell_layer1 : ase.Atoms object
        ase.Atoms object of the first layer
    Theta : list of floats
        list of angles in degrees to rotate the second layer
    extent : float or int
        Size N of desired N x N in XY plane of TEM image in Å
    supercell_layer2 : ase.Atoms object, optional
        ase.Atoms object of the second layer. If not given, a copy of the first layer is used
    interlayer_dist : float or int
        distance between the two layers in Å
    xy_padding : float or int
        padding (vacuum) in the XY plane in Å
    z_padding : float or int
        padding in the Z direction in Å
    X_translation : list of floats, optional
        list of translations in the X direction in Å. Defaults to ``[0]``.
    Y_translation : list of floats, optional
        list of translations in the Y direction in Å. Defaults to ``[0]``.
    sliceThickness : float or int
        thickness of each slice of the multislice potential structure in Å
    resolution : float or int
        desired size of features to be distinguishable in the final TEM image in Å
        (this number is not the pixel size due to Nyquist sampling limitations)
    save_potentials : bool
        if True, every potential is kept and returned in ``metadata['Potential']``
    save_structures : bool
        if True, every combined structure is kept and returned in ``metadata['ASE structure']``
    energy : float or int
        energy of the electron wave in eV. Standard values between 80,000 to 300,000 eV

    returns
    --------
    exit_wave_stack : abtem stack
        nested stack of the exit waves, indexed as [rotation][x-offset][y-offset]
    metadata : dict
        dictionary with the following keys:
            'Rotation' : the angles used (``Theta``)
            'x_translation' : the X translations used
            'y_translation' : the Y translations used
            'ASE structure' : list of combined ase.Atoms (empty unless save_structures)
            'Potential' : list of potentials (empty unless save_potentials)

    raises
    ------
    ValueError
        if ``Theta``, ``X_translation`` or ``Y_translation`` is empty.
    """
    # ``None`` sentinels instead of ``[0]`` defaults: these lists are handed back
    # inside ``metadata``, so a shared mutable default could be modified by a
    # caller and silently change every later call that relies on the default.
    if X_translation is None:
        X_translation = [0]
    if Y_translation is None:
        Y_translation = [0]

    # Fail fast, before any expensive cutting/sampling work, on empty sweeps.
    for arg_name, values in (('Theta', Theta),
                             ('X_translation', X_translation),
                             ('Y_translation', Y_translation)):
        if len(values) == 0:
            raise ValueError(f'{arg_name} must contain at least one value')

    # if layer 2 is not given, use (a copy of) layer 1 for both layers:
    if supercell_layer2 is None:
        supercell_layer2 = supercell_layer1.copy()

    # layer 1 is identical for every simulation, so it is cut only once:
    layer_1 = cut_supercell(supercell_layer1, extent)

    # add vacuum to the xy plane before the sampling parameters are calculated,
    # so the sampling matches the padded structures built by combine_layers
    sampling_layer = layer_1.copy()
    sampling_layer.center(vacuum = xy_padding, axis=(0,1))

    sampling_pam = sampling_parameters(sampling_layer, energy = energy, resolution = resolution)

    # one incoming plane wave, reused for every configuration
    plane_wave = abtem.PlaneWave(energy = sampling_pam['energy'],
                                 gpts = sampling_pam['gpts'],
                                 extent = sampling_pam['extent'])

    # labels for the three stack axes
    Theta_meta_list = ["rotation offset = " + str(angle) + "°" for angle in Theta]
    X_meta_list = ["x-axis offset = " + str(x_off) + "Å" for x_off in X_translation]
    Y_meta_list = ["y-axis offset = " + str(y_off) + "Å" for y_off in Y_translation]

    # optional data saving lists
    ase_structures = []
    potentials = []

    T_s = []  # one stack per rotation angle
    for phi in tqdm(Theta, desc ='Angle number:', leave=False):

        X_s = []  # one stack per x-offset
        for x in tqdm(X_translation, desc ='X translation:', leave=False):

            Y_s = []  # one exit wave per y-offset
            for y in tqdm(Y_translation, desc ='Y translation:', leave=False):

                # layer 2 with the requested rotation and translation, cut to extent
                layer_2 = cut_supercell(supercell_layer2, extent, theta=phi, X_translation = x, Y_translation = y)
                combined_layers = combine_layers(layer_1, layer_2, interlayer_distance=interlayer_dist,
                                                 xy_padding = xy_padding, z_padding = z_padding)

                if save_structures:
                    ase_structures.append(combined_layers)

                # same gpts as the plane wave so the grids match
                potential = potential_build(combined_layers,
                                            energy=sampling_pam['energy'],
                                            slice_thickness = sliceThickness,
                                            projection="finite",
                                            resolution=sampling_pam['resolution'],
                                            gpts=sampling_pam['gpts']
                                            )

                if save_potentials:
                    potentials.append(potential)

                Y_s.append(TEM_exit_wave(potential, plane_wave))

            X_s.append(abtem.stack(Y_s, Y_meta_list))

        T_s.append(abtem.stack(X_s, X_meta_list))

    print("Creating the exit wave stack...")
    exit_wave_stack = abtem.stack(T_s, Theta_meta_list)
    print("Done!")

    metadata = {'Rotation' : Theta, 'x_translation' : X_translation, 'y_translation' : Y_translation,
                'ASE structure' : ase_structures, 'Potential' : potentials}

    return exit_wave_stack, metadata

In [20]:
def gen_stack_from_trajectory(ase_traj, sliceThickness = 0.5, resolution = 0.5, save_potentials = False, save_structures = False, energy = 300e3):
    """Simulate one TEM exit wave per frame of an ASE trajectory.

    The sampling parameters and the incoming plane wave are derived from the
    first frame and reused for every frame. Each frame is converted into a
    multislice potential and propagated. The exit waves are stacked with the
    labels ``"trajectory # 1"``, ``"trajectory # 2"``, ...

    parameters
    ---------
    ase_traj : ase.Atoms object or sequence of ase.Atoms objects
        a single structure, or an indexable sequence of frames supporting
        ``len()`` (e.g. a list of ase.Atoms or an ase.io.trajectory.Trajectory)
    sliceThickness : float or int
        thickness of each slice of the multislice potential structure in Å
    resolution : float or int
        desired size of features to be distinguishable in the final TEM image in Å
        (this number is not the pixel size due to Nyquist sampling limitations)
    save_potentials : bool
        if True, every potential is kept and returned in ``metadata['Potential']``
    save_structures : bool
        if True, every frame is kept and returned in ``metadata['ASE structure']``
    energy : float or int
        energy of the electron wave in eV. Standard values between 80,000 to 300,000 eV

    returns
    --------
    exit_wave_stack : abtem stack
        stack of the exit waves, one per trajectory frame
    metadata : dict
        dictionary with the following keys:
            'ASE structure' : list of frames (empty unless save_structures)
            'Potential' : list of potentials (empty unless save_potentials)
    """
    # A single ase.Atoms is itself indexable and iterable over its *atoms*:
    # ``ase_traj[0]`` would be one ase.Atom and the loop would visit individual
    # atoms. Wrap it so it is treated as a one-frame trajectory.
    if isinstance(ase_traj, ase.Atoms):
        ase_traj = [ase_traj]

    # sampling (and hence the plane wave grid) is taken from the first frame
    sampling_pam = sampling_parameters(ase_traj[0], energy = energy, resolution = resolution)

    # one incoming plane wave, reused for every frame
    plane_wave = abtem.PlaneWave(energy = sampling_pam['energy'],
                                 gpts = sampling_pam['gpts'],
                                 extent = sampling_pam['extent'])

    # stack labels: "trajectory # 1", "trajectory # 2", ...
    labels = [f'trajectory # {frame_no}' for frame_no in range(1, len(ase_traj) + 1)]

    # optional data saving lists
    ase_structures = []
    potentials = []
    exit_waves = []

    for traj in tqdm(ase_traj, desc ='Trajectory file:', leave=False):

        if save_structures:
            ase_structures.append(traj)

        # NOTE(review): unlike gen_stack_from_supercell, gpts is not passed to
        # potential_build here; frames whose cell differs from the first frame
        # may then not match the plane-wave grid — confirm this is intended.
        potential = potential_build(traj,
                                    energy=sampling_pam['energy'],
                                    slice_thickness = sliceThickness,
                                    projection="finite",
                                    resolution=sampling_pam['resolution'],
                                    )
        if save_potentials:
            potentials.append(potential)

        exit_waves.append(TEM_exit_wave(potential, plane_wave))

    exit_wave_stack = abtem.stack(exit_waves, labels)

    metadata = {'ASE structure' : ase_structures, 'Potential' : potentials}

    return exit_wave_stack, metadata

In [21]:
def save_outcome(abtem_stack, filename: str):
    """Persist a simulation outcome as a zarr store in ``<cwd>/data``.

    The ``data`` folder inside the current working directory is created first
    if it does not exist yet. The full destination path is printed afterwards.

    parameters
    ---------
    abtem_stack : abtem.waves or abtem.measurements object
        any object exposing ``to_zarr(path)``; typically a stack of exit waves
        or of measured intensities
    filename : str
        name of the zarr store to create inside the ``data`` folder
    """
    data_dir = os.path.join(os.getcwd(), 'data')
    os.makedirs(data_dir, exist_ok=True)

    target = os.path.join(data_dir, filename)
    abtem_stack.to_zarr(target)
    print(f'File has been saved to disk at directory:\n {target}')

In [22]:
def load_outcome_file(file_directory):
    """Read a zarr store back into an abTEM array object.

    Thin wrapper around ``abtem.array.from_zarr``. The store can be one written
    by ``save_outcome`` (which uses ``to_zarr``).

    parameters
    ---------
    file_directory : str
        path of the zarr store to load

    returns
    --------
    abtem_stack : abTEM array object
        whatever ``abtem.array.from_zarr`` reconstructs from the store, e.g.
        a stack of exit waves or of measured intensities
    """
    loaded_stack = abtem.array.from_zarr(file_directory)
    return loaded_stack