In [None]:
import ase
import numpy as np
import abtem
import matplotlib.pyplot as plt
from typing import Tuple, NamedTuple

%config InlineBackend.rc = {"figure.dpi": 72, "figure.figsize": (6.0, 4.0)}
%matplotlib ipympl

# atoms = ase.Atoms(
#     "Si2", positions=[(1.0, 2.0, 1.0), (3.0, 2.0, 1.0)], cell=[4, 4, 2]
# )

# Minimal test structure: one Si atom placed in the middle of a 4 x 4 x 2 cell.
atoms = ase.Atoms("Si1", positions=[(2.0, 2.0, 1.0)], cell=[4, 4, 2])

# Beam parameter; the same name is handed to the ParallelBeam component
# (as `phi_0=phi_0`) in the ray-tracing cell further down.
phi_0 = 100e3

# Two panels side by side, titled "Beam view" and "Side view" (xz plane).
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(12, 4))
abtem.show_atoms(atoms, ax=ax1, title="Beam view", numbering=True, merge=False)
abtem.show_atoms(atoms, ax=ax2, plane="xz", title="Side view", legend=True)

# Potential of the test structure on a 0.04 sampling grid.
potential = abtem.Potential(atoms, sampling=0.04, projection="infinite")

# Build, project and evaluate the potential, then keep a copy of the raw
# array with its first axis reversed.
potential_array = potential.build().project().compute()
pot_array = np.flipud(potential_array.array)


In [None]:
class PlotParams(NamedTuple):
    """Immutable bundle of plot-styling defaults.

    Every field carries a default, so ``PlotParams()`` is valid and single
    settings can be overridden by keyword or via ``_replace``.
    NOTE(review): no consumer of this class is visible in this notebook;
    the field names suggest it styles ray-diagram figures.
    """

    # Ray count
    num_rays: int = 10

    # Colours
    ray_color: str = "dimgray"
    fill_color: str = "aquamarine"
    fill_color_pair: Tuple[str, str] = ("khaki", "deepskyblue")

    # Opacities
    fill_alpha: float = 0.0
    ray_alpha: float = 1.0

    # Line widths
    component_lw: float = 1.0
    edge_lw: float = 1.0
    ray_lw: float = 1.0

    # Text and figure geometry
    label_fontsize: int = 12
    figsize: Tuple[int, int] = (6, 6)
    extent_scale: float = 1.1


In [None]:
# Show the projected potential of the test structure in a wide figure with a
# colour bar and a colour scale common to all panels.
visualization = potential.show(
    figsize=(16, 5), project=True, explode=True, common_color_scale=True, cbar=True
)

In [None]:
# Geometry of the larger simulation cell: a square 2*x_centre footprint, 60 deep.
x_centre = 10.0
z = 30

# Offsets for the optional two-atom variant kept commented out below;
# x0 and x2 are only referenced there.
x1 = 1.0
atom_spacing = 3.0
x2 = x1 + atom_spacing
x0 = x_centre - x1 - atom_spacing / 2
# atoms = ase.Atoms(
#     "Si2", positions=[(x0+x1, x_centre, z), (x0+x2, x_centre, z)], cell=[x_centre*2, x_centre*2, 60]
# )

# Single Si atom at the centre of the footprint, at height z.
atoms = ase.Atoms(
    "Si1",
    positions=[(x_centre, x_centre, z)],
    cell=[2 * x_centre, 2 * x_centre, 60],
)

phi_0 = 100e3

# Two panels side by side, titled "Beam view" and "Side view" (xz plane).
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(12, 4))
abtem.show_atoms(atoms, ax=ax1, title="Beam view", numbering=True, merge=False)
abtem.show_atoms(atoms, ax=ax2, plane="xz", title="Side view", legend=True)

# Finer grid (sampling=0.01) than the first test cell, which used 0.04.
potential = abtem.Potential(atoms, sampling=0.01, projection="infinite")

potential_array = potential.build().project().compute()


In [None]:
# Same viewer call as for the small test cell, now applied to the potential of
# the larger simulation cell.
visualization = potential.show(
    project=True, explode=True, cbar=True, common_color_scale=True, figsize=(16, 5)
)

In [None]:
# Quick-look image of the projected potential array, transposed for display.
plt.figure()
plt.imshow(np.transpose(potential_array.array))

In [None]:
from scipy.interpolate import RegularGridInterpolator as RGI

from temgymbasic.model import (
    Model,
)
import temgymbasic.components as comp
from temgymbasic.rays import Rays

# Number of rays traced for the side-view ray diagram drawn at the end of this cell.
num_rays = 100

# Lower corner of the potential grid. The 1e-10 factor matches the scaling
# applied to the abTEM extent just below (presumably Å -> m).
# NOTE(review): this rebinds `x0`, which an earlier cell defined without the
# 1e-10 factor.
x0 = -10 * 1e-10
y0 = -10 * 1e-10

# Grid axes spanning the potential's extent, one point per potential grid point.
extent = np.array(potential.extent) * 1e-10
gpts = potential.gpts
x = np.linspace(x0, x0 + extent[0], gpts[0], endpoint=True)
y = np.linspace(y0, y0 + extent[1], gpts[1], endpoint=True)
# NOTE(review): xx / yy are not used anywhere else in this cell.
xx, yy = np.meshgrid(x, y, indexing='ij')

phi_r = potential_array.array.T  # Be very careful with this - transpose only works because we have symmetry!

# Linear interpolants on the [x, y] grid; queries outside the grid return 0.0.
pot_interp = RGI([x, y], phi_r, method='linear', bounds_error=False, fill_value=0.0)
# np.gradient returns the derivative along axis 0 first, then along axis 1.
# The interpolators are built on the grid [x, y], i.e. axis 0 is treated as x,
# so the array bound to `Ey` is the derivative along the first (x) grid axis
# and `Ex` the one along the second.
# NOTE(review): confirm this naming matches what comp.PotentialSample expects;
# the symmetry argument used for the transpose above does not automatically
# carry over to the individual gradient components.
Ey, Ex = np.gradient(phi_r, x, y)
Ex_interp = RGI([x, y], Ex, method='linear', bounds_error=False, fill_value=0.0)
Ey_interp = RGI([x, y], Ey, method='linear', bounds_error=False, fill_value=0.0)

# Quick-look images of the two gradient arrays.
plt.figure()
plt.imshow(Ex)

plt.figure()
plt.imshow(Ey)
# NOTE(review): `detector_pixels` is not referenced by the Detector below,
# which is given shape=(200, 200).
detector_pixels = 2000
# Optical column in order of increasing z: beam source, potential sample, detector.
components = (
    comp.ParallelBeam(
        z=-60e-10,
        radius=10e-10,
        phi_0=phi_0
    ),
    comp.PotentialSample(
        z=-30e-10,
        potential=pot_interp,
        Ex=Ex_interp,
        Ey=Ey_interp,
    ),
    comp.Detector(
        z=0,
        pixel_size=1e-11,
        shape=(200, 200),
    ),
)


# Trace the rays and keep every ray set produced by run_iter.
model = Model(components)
rays = tuple(model.run_iter(num_rays=num_rays))
# From here on `x`, `y` and `z` hold ray coordinates collected over those ray
# sets, replacing the grid axes (and the earlier scalar `z`) of the same names.
# A later plotting cell reads `x` and `z` in this form.
x = np.stack(tuple(r.x for r in rays), axis=0)
y = np.stack(tuple(r.y for r in rays), axis=0)
z = np.asarray(tuple(r.z for r in rays))

# Side view of the ray paths: x position against z.
fig, ax = plt.subplots()
ax.plot(x, z, '-r')

# Plot limits. The z range covers both the traced rays and every component position.
# NOTE(review): `min_x` is unused; the x limits below are set from -max_x / +max_x.
max_x = 12e-10
min_x = -12e-10
min_z = np.min([np.min(z)] + [c.z for c in model.components])
max_z = np.max([np.max(z)] + [c.z for c in model.components])

ax.set_xlim([-max_x, max_x])
# Limits are passed as (max_z, min_z), so the z axis is drawn inverted.
ax.set_ylim([max_z, min_z])

In [None]:
# Beam energy shared by the multislice plane wave and the ray-tracing beam
# (ParallelBeam is given `phi_0=phi_0`). Feeding the same variable into both
# keeps the two simulations comparable if the value is ever changed.
phi_0 = 100e3
plane_wave = abtem.PlaneWave(energy=phi_0)

# Propagate the plane wave through the potential and evaluate the result.
exit_wave = plane_wave.multislice(potential)
exit_wave.compute()

In [None]:
# Evaluate the exit-wave intensity once and reuse the computed object for both
# the figure and the later line-profile cell, instead of building it twice.
intensity = exit_wave.intensity().compute()
# Trailing semicolon keeps the cell from echoing the returned object.
intensity.show(cbar=True);

In [None]:
# Line profile through the intensity image: centred on (10, 10), angle 0,
# length 20 (same units as the image axes).
lineprofile = intensity.interpolate_line_at_position(
    center=(10, 10),
    angle=0,
    extent=20,
)
lineprofile.show()

In [None]:
# Thinly sliced potential that also records intermediate exit planes, so the
# wave can be inspected at many depths.
# NOTE(review): exit_planes=2 presumably stores a wave every second slice;
# confirm against the abTEM documentation.
potential_series = abtem.Potential(
    atoms,
    sampling=0.04,
    slice_thickness=0.05,
    exit_planes=2,
)

# Run the same plane wave through the sliced potential and evaluate the stack.
exit_wave_series = plane_wave.multislice(potential_series).compute()

In [None]:
print(exit_wave_series.shape)

# The stack is indexed as (exit plane, x, y). Take the section through the
# middle of the last axis rather than a hard-coded pixel index, so the plot
# stays correct if the sampling or cell size is changed.
wave_stack = exit_wave_series.array
y_mid = wave_stack.shape[2] // 2

# Physical size of the plotted section, read from the objects that define it.
x_extent = potential_series.extent[0]
z_extent = atoms.cell.lengths()[2]

plt.figure()
plt.imshow(
    np.abs(wave_stack[:, :, y_mid]) ** 2,
    extent=[0, x_extent, 0, z_extent],
    cmap='viridis',
    aspect='auto',
)
plt.xlabel('x [Å]')
plt.ylabel('z [Å]')
# Overlay the traced rays. `x` and `z` are the ray coordinates left behind by
# the ray-tracing cell, which builds its grid centred on 0 with a 1e-10 scale,
# so rescale by 1e10 and shift by half the cell width to line up with the image.
plt.plot(x*1e10 + x_extent / 2, -1*z*1e10, '-r')
plt.xlim([0, x_extent])


In [None]:
# Re-run the same optical column with 2**20 rays and form the detector image
# from the last ray set produced by the run.
model = Model(components)
rays = tuple(model.run_iter(num_rays=2 ** 20))
wavefront = model.detector.get_image(rays[-1])

# Squared magnitude of the detector image on a 20 x 20 axis range.
plt.figure()
plt.imshow(np.square(np.abs(wavefront)), extent=(0, 20, 0, 20))