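For this demonstration we simulate a SrTiO$_3$/CeO$_2$ interface, downloaded from https://www.materialscloud.org/explore/stoceriaitf/grid/calculations. Note that, as currently written, the code below loads a silicon structure (`Structures/Si100.xyz`). To simulate the interface instead, uncomment the interface line and comment out the silicon line. We first import the multislice library, read in and plot the crystal, and set up some basic parameters for our simulation.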

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pyms

%matplotlib inline

#Get crystal
# crystal = pyms.crystal('Demos/Structures/SrTiO3_CeO2_interface.xyz')
crystal = pyms.crystal('Structures/Si100.xyz')

#Quick plot of crystal
crystal.quickplot(atomscale=1e-2)

#Subslicing of crystal for multislice
subslices = [1.35/5.43,2.7/5.43,1.0]
nsubslices = len(subslices)

#Grid size in pixels
gridshape = [512,512]
#gridshape = [1024,1024]
tiling = [4,4]

#Size of real space grid
rsize = np.zeros((3,))
rsize[:3]  = crystal.unitcell[:3]
rsize[:2] *= np.asarray(tiling)

#Number of transmission functions
nT = 4

#Probe accelerating voltage in eV
eV = 3e5

#Objective aperture in mrad
app = 15

#Number of frozen phonon passes
nfph = 5

Set up the transmission functions and inspect them slice by slice.

In [None]:
import torch
from pyms.utils import cx_to_numpy

# Run on the first CUDA GPU when one is present, otherwise on the CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Split the unit cell into equal subslices, using enough of them that each is
# at most 2 Angstrom thick
nsubslices = int(np.ceil(crystal.unitcell[2] / 2))
subslices = np.linspace(1.0 / nsubslices, 1.0, num=nsubslices)

# Preallocate storage for nT sets of subslice transmission functions.
# The trailing axis of length 2 presumably holds real/imaginary parts — TODO confirm
T = torch.zeros(nT, nsubslices, *gridshape, 2, device=device)

# Fill each set with its own call to the crystal's transmission-function builder
for i in range(nT):
    T[i] = crystal.make_transmission_functions(
        gridshape, eV, subslices, tiling, fftout=True, device=device
    )

%matplotlib notebook
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets

# from PIL import Image
# Image.fromarray(np.angle(cx_to_numpy(T[0,0,...]))).save('Potential.tif')

#Now plot transmission function with slider widget to inspect individual slices

fig,ax = plt.subplots(figsize=(4,4))
img = np.angle(cx_to_numpy(T[0,0,...]))
p = ax.imshow(img,vmax=0.6*np.amax(img),vmin=np.amin(img))
ax.set_axis_off()

def plot_img(i, img, p):
    """Replace the data of image artist ``p`` with the phase of subslice ``i``.

    Parameters
    ----------
    i : int
        Subslice index (the slider value).
    img : torch.Tensor
        Transmission functions indexed as ``img[set, subslice, ...]``; only
        set 0 is displayed.
    p : matplotlib.image.AxesImage
        Image artist whose data is replaced.
    """
    p.set_data(np.angle(cx_to_numpy(img[0, i, ...])))
    # Redraw the figure that owns ``p``. The global ``fig`` is rebound by later
    # cells, so using it would refresh the wrong canvas.
    p.figure.canvas.draw_idle()

# Slider over the subslice index of the displayed transmission function
widg = widgets.IntSlider(
    value=0,
    min=0,
    max=T.shape[1] - 1,
    step=1,
    description='Slice:',
    disabled=False,
    continuous_update=True,
    orientation='horizontal',
    readout=True,
    readout_format='d',
)
interact(plot_img, i=widg, img=fixed(T), p=fixed(p))

# Check that the grid is of sufficient resolution for our planned 150 mrad
# outer angle HAADF detector, and warn explicitly if it is not
haadf_outer_mrad = 150
max_res = pyms.max_grid_resolution(gridshape, rsize, eV=eV)
print('Max resolution of grid is {0:.1f} mrad, '.format(max_res))
if max_res < haadf_outer_mrad:
    print('Warning: grid resolution ({0:.1f} mrad) is below the {1} mrad '
          'HAADF outer angle; increase gridshape.'.format(max_res, haadf_outer_mrad))

Define the annular STEM detectors (inner and outer collection angles in mrad) and plot them.

In [None]:
# Detector inner (betamin) and outer (betamax) collection angles in mrad
betamin = [0, 7.5, 70]
betamax = [7.5, 15, 150]

ndet = len(betamax)
if len(betamin) != ndet:
    raise ValueError('betamin and betamax must have the same length')

# np.float was removed in NumPy 1.24; the builtin float is the same float64 dtype
detectors = np.zeros((ndet, *gridshape), dtype=float)

# One panel per detector; squeeze=False keeps ax indexable when ndet == 1
fig, ax = plt.subplots(ncols=ndet, figsize=(4 * ndet, 4), squeeze=False)

for i, (bmin, bmax) in enumerate(zip(betamin, betamax)):
    detectors[i, ...] = pyms.make_detector(gridshape, rsize, eV, bmax, bmin)
    # fftshift moves the zero-frequency pixel to the centre for display
    ax[0, i].imshow(np.fft.fftshift(detectors[i, ...]))
    ax[0, i].set_axis_off()

plt.show()

Now calculate STEM images for a range of specimen thicknesses, averaging over frozen phonon passes.

In [None]:
# Make Fresnel free-space propagators for multislice algorithm
propagators = pyms.make_propagators(gridshape, rsize, eV, subslices)

# Thickness series (same length unit as crystal.unitcell)
thicknesses = np.asarray([50, 100, 200])
nt = thicknesses.shape[0]

# Number of unit cells needed to reach each thickness. Computed once, outside
# the loop. np.int was removed in NumPy 1.24; builtin int is the equivalent.
ncells = np.ceil(thicknesses / crystal.unitcell[2]).astype(int)

# NOTE(review): this overrides nfph = 5 from the setup cell — confirm which
# value is intended.
nfph = 2

images = []
# Iteration over frozen phonon configurations
for ifph in range(nfph):
    # Make probe
    probe = pyms.focused_probe(gridshape, rsize, eV, app)
    # Run multislice
    print("Frozen phonon iteration: {0:2d}/{1:2d}".format(ifph+1, nfph))
    images.append(
        pyms.STEM(
            rsize[:2],
            probe,
            propagators,
            T,
            ncells,
            eV,
            app,
            batch_size=50,
            detectors=detectors,
            tiling=tiling,
        )
    )
# Average the images over the frozen phonon passes
images = np.average(images, axis=0)
    

Plot the resulting images: one row per detector and one column per thickness.

In [None]:
# Grid of panels: rows index detectors, columns index thicknesses, matching
# the images[row, col, ...] indexing below
nrows, ncols = ndet, nt
# squeeze=False keeps ax two-dimensional even with one detector or thickness
fig, ax = plt.subplots(
    nrows=nrows, ncols=ncols, figsize=(2 * ncols, 2 * nrows), squeeze=False
)

for col in range(ncols):
    ax[0, col].set_title('t = {0} '.format(thicknesses[col]))
    for row in range(nrows):
        ax[row, col].imshow(images[row, col, ...])
        ax[row, col].set_xticks([])
        ax[row, col].set_yticks([])

# Label each row with its detector number on the first column only
for row in range(nrows):
    ax[row, 0].set_ylabel('Detector {0}'.format(row + 1))