In [None]:
import os
import sys

# Make the TemGymBasic checkout importable (presumably where `tem_gym_proto`
# lives). A sys.path entry must be a directory (or zip archive) -- pointing it
# at a notebook file makes nothing importable. Set the TEMGYMBASIC_DIR
# environment variable to override the default location.
TEMGYMBASIC_DIR = os.environ.get(
    "TEMGYMBASIC_DIR", R"C:\Users\mb265392\Workspace\TemGymBasic"
)
# Guard so re-running this cell does not stack duplicate entries.
if TEMGYMBASIC_DIR not in sys.path:
    sys.path.insert(0, TEMGYMBASIC_DIR)

In [None]:
from typing import TypedDict, Tuple

import numpy as np
from temgymbasic import components as comp
from temgymbasic.model import Model
from temgymbasic.functions import get_pixel_coords
from tem_gym_proto import STEMModel, DoubleDeflector
import numba
%matplotlib widget

In [None]:
# Geometry/calibration parameters consumed by the model builders and the
# translation-matrix fits below. Lengths are in metres and the
# semi-convergence angle in radians (as annotated per field); units for
# cy/cx and scan_rotation are not stated here.
OverfocusParams = TypedDict(
    "OverfocusParams",
    {
        "overfocus": float,  # m
        "scan_pixel_size": float,  # m
        "camera_length": float,  # m
        "detector_pixel_size": float,  # m
        "semiconv": float,  # rad
        "cy": float,
        "cx": float,
        "scan_rotation": float,
        "scan_shape": Tuple[int, int],
        "flip_y": bool,
    },
)



In [None]:
def make_model(params: OverfocusParams, dataset_shape):
    """Build the legacy TemGymBasic 4D-STEM model for an overfocus geometry.

    The column is, top to bottom: beam at z=0.4, scan coils (0.3 / 0.25),
    lens at 0.20, sample at ``params['camera_length']``, descan coils
    (0.1 / 0.05). The objective focal length is then set from
    ``params['overfocus']``.

    Parameters
    ----------
    params : OverfocusParams
        Reads ``camera_length``, ``scan_pixel_size``, ``detector_pixel_size``
        and ``overfocus``.
    dataset_shape : sequence of int
        ``dataset_shape[0]`` is the number of scan pixels per side (also the
        side of the square sample grid); ``dataset_shape[1]`` is the number of
        detector pixels per side.

    Returns
    -------
    Model
        The configured legacy model with ``scan_pixels`` set.
    """
    scan_pixels = dataset_shape[0]
    detector_pixels = dataset_shape[1]

    # The sample grid has to be square, so both sides use the scan extent.
    sample_grid = np.ones((scan_pixels, scan_pixels))
    sample_width = sample_grid.shape[0] * params['scan_pixel_size']

    scan_coils = comp.DoubleDeflector(name='Scan Coils', z_up=0.3, z_low=0.25)
    objective = comp.Lens(name='Lens', z=0.20)
    sample = comp.Sample(
        name='Sample',
        sample=sample_grid,
        z=params['camera_length'],
        width=sample_width,
    )
    descan_coils = comp.DoubleDeflector(
        name='Descan Coils',
        z_up=0.1,
        z_low=0.05,
        scan_rotation=0.,
    )

    # Start from a parallel circular beam leaving the "gun".
    # NOTE(review): 'paralell' is kept verbatim -- presumably the exact string
    # TemGymBasic's Model expects; confirm before "fixing" the spelling.
    model = Model(
        [scan_coils, objective, sample, descan_coils],
        beam_z=0.4,
        beam_type='paralell',
        num_rays=7,  # original note: "somehow the minimum" -- unverified
        experiment='4DSTEM',
        detector_pixels=detector_pixels,
        detector_size=detector_pixels * params['detector_pixel_size'],
    )
    model.set_obj_lens_f_from_overfocus(params['overfocus'])
    model.scan_pixels = scan_pixels
    return model

def make_model_proto(params: OverfocusParams):
    """Build the prototype ``STEMModel`` with a fixed component layout.

    Components are moved to z = 0.1 ... 0.3 (scan coils, objective, sample,
    descan coils), then the STEM parameters from ``params`` are applied.

    Parameters
    ----------
    params : OverfocusParams
        Reads ``camera_length``, ``semiconv``, ``scan_pixel_size`` (used for
        both scan step axes), ``scan_shape`` and ``overfocus``.

    Returns
    -------
    Whatever ``model.set_stem_params`` returns (presumably the configured
    model -- confirm against tem_gym_proto).
    """
    model = STEMModel()

    # (attribute path on the model, target z). Each path is resolved right
    # before its move, exactly as a sequence of direct calls would.
    layout = (
        ('scan_coils.first', 0.1),
        ('scan_coils.second', 0.15),
        ('objective', 0.2),
        ('sample', 0.225),
        ('descan_coils.first', 0.25),
        ('descan_coils.second', 0.3),
    )
    for path, z in layout:
        component = model
        for attr in path.split('.'):
            component = getattr(component, attr)
        model.move_component(component, z)

    step = params['scan_pixel_size']
    return model.set_stem_params(
        camera_length=params['camera_length'],
        semiconv_angle=params['semiconv'],
        scan_step_yx=(step, step),
        scan_shape=params['scan_shape'],
        overfocus=params['overfocus'],
    )

In [None]:
def get_translation_matrix(params: OverfocusParams, model):
    """Least-squares fit of detector pixel coordinates for the legacy model.

    The model is stepped at the four scan corners, in row-major order
    (0, 0), (0, N-1), (N-1, 0), (N-1, N-1) with ``N = model.scan_pixels``.
    Every ray contributes one design row
    ``(sample_y, sample_x, scan_y, scan_x, 1)`` and one target row
    ``(detector_y, detector_x)``.

    Parameters
    ----------
    params : OverfocusParams
        Only ``flip_y`` and ``scan_rotation`` are read; both are forwarded to
        ``get_pixel_coords`` for the detector plane.
    model
        Legacy TemGymBasic ``Model``. Its ``scan_pixel_x``/``scan_pixel_y``
        are overwritten and left at the last corner visited.

    Returns
    -------
    tuple
        ``(solution, a, b)``: the ``np.linalg.lstsq`` solution of
        ``a @ solution ~= b`` plus the row lists it was fitted from.
    """
    design_rows = []
    target_rows = []
    model.scan_pixel_x = 0
    model.scan_pixel_y = 0
    for scan_y in (0, model.scan_pixels - 1):
        for scan_x in (0, model.scan_pixels - 1):
            model.scan_pixel_y = scan_y
            model.scan_pixel_x = scan_x
            model.update_scan_coil_ratio()
            model.step()

            # Index 0 / 2 on the second axis of model.r are used as the ray
            # x / y positions; index -1 on the first axis is the detector plane.
            sample_x, sample_y = get_pixel_coords(
                rays_x=model.r[model.sample_r_idx, 0, :],
                rays_y=model.r[model.sample_r_idx, 2, :],
                size=model.components[model.sample_idx].sample_size,
                pixels=model.components[model.sample_idx].sample_pixels,
            )
            det_x, det_y = get_pixel_coords(
                rays_x=model.r[-1, 0, :],
                rays_y=model.r[-1, 2, :],
                size=model.detector_size,
                pixels=model.detector_pixels,
                flip_y=params['flip_y'],
                scan_rotation=params['scan_rotation'],
            )

            # Pair rays by position in the coordinate arrays.
            for i, sx in enumerate(sample_x):
                design_rows.append(
                    (sample_y[i], sx, model.scan_pixel_y, model.scan_pixel_x, 1)
                )
                target_rows.append((det_y[i], det_x[i]))

    solution, *_ = np.linalg.lstsq(design_rows, target_rows, rcond=None)
    return solution, design_rows, target_rows

def get_translation_matrix_proto(
    params: OverfocusParams,
    model: STEMModel,
    num_rays: int = 7,
):
    """Least-squares fit of detector pixel coordinates for the prototype model.

    Rays are traced at the four corner scan positions, in row-major order
    (0, 0), (0, nx-1), (ny-1, 0), (ny-1, nx-1). These are the same corners,
    in the same order, as ``get_translation_matrix``, so the two results can
    be compared row by row.

    Each ray contributes:
    - one design row ``(sample_y, sample_x, scan_y, scan_x, 1)``, collected
      where ``rays.location is model.sample``;
    - one target row ``(detector_y, detector_x)``, collected where
      ``rays.location is model.detector``.

    Parameters
    ----------
    params : OverfocusParams
        Only ``scan_rotation`` and ``flip_y`` are read; they are applied when
        mapping detector-plane rays onto the detector grid.
    model : STEMModel
        Configured prototype model (e.g. from ``make_model_proto``).
    num_rays : int, optional
        Rays traced per scan position. Defaults to 7, the value the legacy
        model is built with.

    Returns
    -------
    res : numpy.ndarray
        ``np.linalg.lstsq`` solution of ``concat(a) @ res ~= concat(b)``.
    a, b : list of numpy.ndarray
        The per-location blocks of design rows and detector targets.

    Raises
    ------
    ValueError
        If no sample or no detector rays were collected, or if the sample and
        detector rows do not pair up one-to-one.
    """
    scan_ny, scan_nx = model.sample.scan_shape
    # The last valid scan index along an axis is ``shape - 1``; using
    # ``shape`` would sample a point one step outside the scan.
    corners = tuple((y, x) for y in (0, scan_ny - 1) for x in (0, scan_nx - 1))

    a = []
    b = []
    for yx in corners:
        for rays in model.scan_point_iter(num_rays=num_rays, yx=yx):
            if rays.location is model.sample:
                # Sample-plane ray positions on the scan grid, not rounded
                # (as_int=False).
                yyxx = np.stack(
                    rays.on_grid(
                        shape=model.sample.scan_shape,
                        pixel_size=model.sample.scan_step_yx[0],
                        as_int=False,
                    ),
                    axis=-1,
                )
                # Scan position plus a constant 1 (affine offset term),
                # repeated once per ray.
                coordinates = np.tile((*yx, 1), (rays.num, 1))
                a.append(np.concatenate((yyxx, coordinates), axis=-1))
            elif rays.location is model.detector:
                yy, xx = rays.on_grid(
                    shape=model.detector.shape,
                    pixel_size=model.detector.pixel_size,
                    rotation=params['scan_rotation'],
                    flip_y=params['flip_y'],
                    as_int=False,
                )
                b.append(np.stack((yy, xx), axis=-1))

    if not a or not b:
        raise ValueError(
            "scan_point_iter yielded no rays at the sample and/or detector"
        )
    design = np.concatenate(a, axis=0)
    targets = np.concatenate(b, axis=0)
    if design.shape[0] != targets.shape[0]:
        raise ValueError(
            f"sample rows ({design.shape[0]}) and detector rows "
            f"({targets.shape[0]}) do not pair up"
        )
    res, *_ = np.linalg.lstsq(design, targets, rcond=None)
    return res, a, b


In [None]:
def plot_rays(model, num_rays=3, yx=(0, 128)):
    """Plot x-z ray paths through the prototype model at one scan position.

    Component planes are drawn as horizontal lines: deflector pairs dashed,
    other components solid, and the objective's ``ffp`` dotted. The z axis is
    inverted so the beam travels down the figure.

    Parameters
    ----------
    model : STEMModel
        Prototype model (e.g. from ``make_model_proto``).
    num_rays : int, optional
        Number of rays requested from ``model.scan_point_iter``. Default 3.
    yx : tuple of int, optional
        Scan position to trace. Default ``(0, 128)``.
    """
    import matplotlib.pyplot as plt

    all_rays = tuple(model.scan_point_iter(num_rays=num_rays, yx=yx))

    fig, ax = plt.subplots()
    # One row per yielded ray set (one z each), one column per ray; matplotlib
    # draws each column of xvals as its own line against zvals.
    xvals = np.stack(tuple(r.x for r in all_rays), axis=0)
    zvals = np.asarray(tuple(r.z for r in all_rays))
    ax.plot(xvals, zvals)

    # Mark the component positions with lines spanning 1.5x the widest ray.
    extent = 1.5 * np.abs(xvals).max()
    for component in model.components:
        if isinstance(component, DoubleDeflector):
            for deflector in (component.first, component.second):
                ax.hlines(deflector.z, -extent, extent, linestyle='--')
                ax.text(-extent, deflector.z, repr(deflector), va='bottom')
        else:
            ax.hlines(component.z, -extent, extent, label=repr(component))
            ax.text(-extent, component.z, repr(component), va='bottom')

    # Objective `ffp` (presumably its front focal plane).
    ax.hlines(model.objective.ffp, -extent, extent, linestyle=':')

    # x = 0 reference line and the scan position on the sample plane.
    ax.axvline(color='black', linestyle=":", alpha=0.3)
    _, scan_pos_x = model.sample.scan_position(yx)
    ax.plot([scan_pos_x], [model.sample.z], 'ko')

    ax.set_xlabel('x position')
    ax.set_ylabel('z position')
    ax.invert_yaxis()
    ax.set_title(f'Ray paths for {num_rays} rays at position {yx}')
    plt.show()


In [None]:
# Shared geometry: dataset_shape[0] is the scan size per side, dataset_shape[1]
# the detector size per side (see make_model).
dataset_shape = [128, 128]
overfocus_params = OverfocusParams(
    overfocus=0.01,  # m
    scan_pixel_size=0.01,  # m
    camera_length=0.15,  # m
    detector_pixel_size=0.050,  # m
    # NOTE(review): 5 rad is far outside any physical semi-convergence angle.
    # make_model ignores this field, but make_model_proto passes it to
    # set_stem_params -- confirm the intended value/unit.
    semiconv=5,  # rad
    scan_rotation=0,
    flip_y=False,
    scan_shape=tuple(dataset_shape),
    # Offset to avoid subchip gap
    # (cy/cx are not read by any helper defined in this notebook.)
    cy=128,
    cx=128,
)

# Legacy TemGymBasic model: least-squares fit of the pixel mapping.
model = make_model(overfocus_params, dataset_shape)
res, a, b = get_translation_matrix(overfocus_params, model)

# Prototype model: same fit, then plot the rays for one scan position.
model_proto = make_model_proto(overfocus_params)
res_proto, a_p, b_p = get_translation_matrix_proto(overfocus_params, model_proto)
plot_rays(model_proto)

In [None]:
# Prototype design rows without the trailing constant-1 column:
# sample pixel (y, x) followed by scan position (y, x).
with np.printoptions(threshold=1000, suppress=True, precision=2):
    print(np.concatenate(a_p)[:, :-1])

In [None]:
# Legacy design rows, printed without the trailing constant-1 column so they
# line up column-for-column with the prototype rows (which drop it too).
with np.printoptions(precision=2, suppress=True, threshold=1000):
    print(np.asarray(a)[:, :-1])

In [None]:
# Prototype detector targets: (detector_y, detector_x) per ray.
with np.printoptions(threshold=1000, suppress=True, precision=2):
    print(np.concatenate(b_p))

In [None]:
# Legacy detector targets: (detector_y, detector_x) per ray.
with np.printoptions(threshold=1000, suppress=True, precision=2):
    print(np.array(b))