In [1]:
%matplotlib widget

In [2]:
%load_ext autoreload

In [3]:
import numba
import numpy as np
import matplotlib.pyplot as plt
import jax.numpy as jnp
import jax
jax.config.update("jax_enable_x64", True)
import optax

In [4]:
%autoreload
from microscope_calibration.common.model import (
    Parameters4DSTEM, Model4DSTEM, Result4DSTEM, PixelYX, CoordXY, identity, rotate, scale, flip_y,
    DescanError
)
from microscope_calibration.util.stem_overfocus_sim import smiley, project
from microscope_calibration.common.stem_overfocus import (
    get_backward_transformation_matrix, get_detector_correction_matrix, correct_frame, project_frame_backwards
)

In [5]:
%autoreload
scan_pixel_pitch = 0.1
detector_pixel_pitch = 0.2
overfocus = 1.
camera_length = 1.
propagation_distance = overfocus + camera_length
obj_half_size = 16
angle = np.arctan2(obj_half_size*detector_pixel_pitch/2 + 0.00314157, propagation_distance)

params = Parameters4DSTEM(
    overfocus=overfocus,
    scan_pixel_pitch=scan_pixel_pitch,
    camera_length=camera_length,
    detector_pixel_pitch=detector_pixel_pitch,
    semiconv=angle,
    scan_center=PixelYX(x=0., y=0.),
    scan_rotation=0.,
    flip_y=False,
    detector_center=PixelYX(x=2* obj_half_size, y=2* obj_half_size),
    detector_rotation=0.,
    descan_error=DescanError(
        sxo_pyi=1 * detector_pixel_pitch/scan_pixel_pitch,
        syo_pxi=1 * detector_pixel_pitch/scan_pixel_pitch,
        sxo_pxi=-2 * detector_pixel_pitch/scan_pixel_pitch/camera_length,
        syo_pyi=-1 * detector_pixel_pitch/scan_pixel_pitch/camera_length,
    )
)

In [29]:
%autoreload
test_positions = jnp.array((
    (0, 0),
    (100, 0),
    (0, 100)
))

target_px = []
for scan_y, scan_x in test_positions:
    target_model = Model4DSTEM.build(params=params, scan_pos=PixelYX(y=scan_y, x=scan_x))
    target_ray = target_model.make_source_ray(source_dx=0, source_dy=0).ray
    target_res = target_model.trace(target_ray)
    target_px.append((
        target_res['detector'].sampling['detector_px'].x, 
        target_res['detector'].sampling['detector_px'].y, 
    ))

target_px = jnp.array(target_px)

@jax.jit
def loss(args):
    sxo_pyi, syo_pxi, sxo_pxi, syo_pyi = args
    opt_params = params.derive(descan_error=DescanError(
        sxo_pyi=sxo_pyi,
        syo_pxi=syo_pxi,
        sxo_pxi=sxo_pxi,
        syo_pyi=syo_pyi,
    ))
    res = []
    for scan_y, scan_x in test_positions:
        opt_model = Model4DSTEM.build(params=opt_params, scan_pos=PixelYX(y=scan_y, x=scan_x))
        opt_ray = opt_model.make_source_ray(source_dx=0, source_dy=0).ray
        opt_res = opt_model.trace(opt_ray)
        res.append((
            opt_res['detector'].sampling['detector_px'].x, 
            opt_res['detector'].sampling['detector_px'].y, 
        ))
    return jnp.linalg.norm(jnp.array(res) - target_px)

In [30]:
%autoreload
start = jnp.zeros(4)
correct = jnp.array((
    1 * detector_pixel_pitch/scan_pixel_pitch,
    1 * detector_pixel_pitch/scan_pixel_pitch,
    -2 * detector_pixel_pitch/scan_pixel_pitch/camera_length,
    -1 * detector_pixel_pitch/scan_pixel_pitch/camera_length,
))
loss(start), loss(correct)

(Array(264.57513111, dtype=float64), Array(0., dtype=float64))

In [31]:
%%time
%autoreload
solver = optax.lbfgs()
optargs = start.copy()
opt_state = solver.init(optargs)
value_and_grad = optax.value_and_grad_from_state(loss)

@jax.jit
def optstep(optargs, opt_state):

    value, grad = value_and_grad(optargs, state=opt_state)
    updates, opt_state = solver.update(
        grad, opt_state, optargs, value=value, grad=grad, value_fn=loss
    )
    optargs = optax.apply_updates(optargs, updates)
    return optargs, opt_state

for i in range(10):
    print(f'Objective function: {loss(optargs)}, distance {optargs - correct}')
    optargs, opt_state = optstep(optargs, opt_state)
print(f'Objective function: {loss(optargs)}, distance {optargs - correct}')

Objective function: 264.5751311064591, distance [-2. -2.  4.  2.]
Objective function: 1.8962321570080587e-05, distance [-1.43341678e-07 -1.43341678e-07  2.86683356e-07  1.43341678e-07]
Objective function: 1.0412134715599073e-06, distance [-7.87083376e-09 -7.87083376e-09  1.57416675e-08  7.87083376e-09]
Objective function: 0.05797493496415593, distance [ 0.00043825  0.00043825 -0.0008765  -0.00043825]
Objective function: 4.710848622298637e-09, distance [ 3.41993101e-11  3.41993101e-11 -7.23199278e-11 -3.61599639e-11]
Objective function: 3.763268221359316e-09, distance [-2.95872216e-11 -2.95872216e-11  5.59583491e-11  2.79791745e-11]
Objective function: 3.762100024202355e-09, distance [-2.84379187e-11 -2.84379187e-11  5.68780578e-11  2.84390289e-11]
Objective function: 5.7820468276725964e-11, distance [ 8.17124146e-13  8.17124146e-13 -2.75335310e-14 -1.37667655e-14]
Objective function: 2.440981848036492e-11, distance [ 1.72750703e-13  1.72750703e-13 -3.78364007e-13 -1.89182003e-13]
Objec

In [9]:
%autoreload
obj = smiley(obj_half_size * 2) # np.ones((obj_half_size * 2, obj_half_size * 2))

projected = project(
    image=obj,
    detector_shape=(obj_half_size * 4, obj_half_size * 4),
    scan_shape=(obj_half_size * 2, obj_half_size * 2),
    sim_params=params,
)


In [10]:
%autoreload

out = np.zeros_like(projected)
for scan_y in range(out.shape[0]):
    for scan_x in range(out.shape[1]):
        correct_frame(
            frame=projected[scan_y, scan_x],
            mat=mat,
            scan_y=scan_y,
            scan_x=scan_x,
            detector_out=out[scan_y, scan_x],
        )

projected_ref = project(
    image=obj,
    detector_shape=(obj_half_size * 4, obj_half_size * 4),
    scan_shape=(obj_half_size * 2, obj_half_size * 2),
    sim_params=params_ref,
    specimen_to_image=map_coord,
)


scan_y = 16
scan_x = 16

                    
fig, axes = plt.subplots(2, 4, squeeze=False)

axes[0, 0].imshow(projected[scan_y, scan_x])
axes[0, 1].imshow(projected[:, :, obj_half_size * 2, obj_half_size * 2])
axes[0, 2].imshow(out[scan_y, scan_x])
axes[0, 3].imshow(out[:, :, obj_half_size* 2, obj_half_size * 2])

axes[1, 0].imshow(projected_ref[scan_y, scan_x])
axes[1, 1].imshow(projected_ref[:, :, obj_half_size * 2, obj_half_size * 2])
axes[1, 2].imshow(out[scan_y, scan_x])
axes[1, 3].imshow(out[:, :, obj_half_size * 2, obj_half_size * 2])


NameError: name 'mat' is not defined

In [None]:
clip = 1
clip2 = 1
np.allclose(projected_ref, out)

In [None]:

np.allclose(projected_ref[scan_y, scan_x, clip2:-clip2, clip2:-clip2], out[scan_y, scan_x, clip2:-clip2, clip2:-clip2])

In [None]:
projected_ref[scan_y, scan_x] - out[scan_y, scan_x]

In [None]:
detector_rotation = 0.
scan_rotation = 0.

params = Parameters4DSTEM(
    overfocus=1,
    scan_pixel_pitch=2,
    camera_length=1,
    detector_pixel_pitch=2,
    semiconv=np.pi/2,
    scan_center=PixelYX(x=16., y=16.),
    scan_rotation=0.,
    flip_y=False,
    detector_center=PixelYX(x=32, y=32.),
    detector_rotation=0.,
    descan_error=DescanError()
)
obj = smiley(64)
res = np.zeros((32, 32))

# def map_coord(inp):
#     cy = obj.shape[0] / 2
#     cx = obj.shape[1] / 2
#     inp_vec = jnp.array((inp.y, inp.x))
#     y, x = scale(1) @ inp_vec
#     return PixelYX(y=y+cy, x=x+cx)

map_coord = None

projected = project(
    image=obj,
    scan_shape=((32, 32)),
    detector_shape=((64, 64)),
    sim_params=params,
    specimen_to_image=map_coord,
)

print(projected.shape)

fig, axes = plt.subplots(1, 2, squeeze=False)
axes[0, 0].imshow(projected[:, :, 32, 32])
axes[0, 1].imshow(projected[16, 16])

mat = get_backward_transformation_matrix(
    rec_params=params,
    specimen_to_image=map_coord
)
project_frame_backwards(
    frame=projected[16, 16],
    source_semiconv=np.pi/2,
    mat=mat,
    scan_y=16,
    scan_x=16,
    image_out=res,
)




fig, axes = plt.subplots(1, 2, squeeze=False)

axes[0, 0].imshow(obj)
axes[0, 1].imshow(res)


In [None]:
params = Parameters4DSTEM(
    overfocus=1,
    scan_pixel_pitch=1,
    camera_length=1,
    detector_pixel_pitch=2,
    semiconv=np.pi/2,
    scan_center=PixelYX(x=16, y=16.),
    scan_rotation=0.,
    flip_y=False,
    detector_center=PixelYX(x=16, y=16.),
    detector_rotation=np.pi/2,
    descan_error=DescanError()
)
obj = smiley(32)
res = project(
    image=obj,
    detector_shape=(32, 32),
    scan_shape=(32, 32),
    sim_params=params,
)
fig, axes = plt.subplots(1, 2)
axes[0].imshow(obj)
axes[1].imshow(np.rot90(res[16, 15], k=-1)) 
#assert_allclose(obj, np.rot90(res[15, 16], k=-1))

In [None]:
trans = rotate(np.pi/1.23) @ scale(0.23) @ flip_y()
cis = jnp.linalg.inv(trans)

In [None]:
trans, cis

In [None]:
cis2 = flip_y() @ scale(1/0.23) @ rotate(-np.pi/1.23)
cis2

In [None]:
I = np.eye(2, dtype=trans.dtype)
cis3 = jnp.linalg.solve(trans, I)
cis3

In [None]:
cis3

In [None]:
trans.dtype

In [None]:
p = OverfocusParams(
    overfocus=1,
    scan_pixel_size=1,
    camera_length=1,
    detector_pixel_size=1,
    cy=0,
    cx=0,
)
p	

In [None]:
detector_px_to_specimen_px(
    y_px=1.,
    x_px=0.,
    fov_size_y=0,
    fov_size_x=0,
    transformation_matrix=np.array(((0., 1.), (1., 0.))),
    **p
)

In [None]:
%autoreload
size = 32
params = OverfocusParams(
    overfocus=0.0001,
    scan_pixel_size=0.00000001,
    camera_length=1,
    detector_pixel_size=0.0001,
    semiconv=np.pi,
    cy=size/2,
    cx=size/2,
    scan_rotation=0,
    flip_y=False
)
obj = smiley(size)
projected = project(
    image=obj,
    scan_shape=(size, size),
    detector_shape=(size, size),
    sim_params=params,
)

In [None]:
fig, axes = plt.subplots(1, 3)
axes[0].imshow(obj)
axes[1].imshow(projected[:, :, size//2, size//2])
axes[2].imshow(projected[:, :, size//2, size//2] - obj)

fig, axes = plt.subplots(1, 3)
axes[0].imshow(obj)
axes[1].imshow(projected[size//2, size//2, :])
axes[2].imshow(projected[size//2, size//2, :] - obj)

In [None]:
%autoreload
size = 32
params = OverfocusParams(
    overfocus=0.0001,
    scan_pixel_size=0.00000001,
    camera_length=1,
    detector_pixel_size=0.0001,
    semiconv=np.pi,
    cy=size/2,
    cx=size/2,
    scan_rotation=0,
    flip_y=False
)
obj = np.zeros((size, size))
obj[size//2, size//2] = 1
sim = project(obj, scan_shape=(size, size), detector_shape=(size, size), sim_params=params)
assert sim[size//2, size//2, size//2, size//2] == 1

udf = OverfocusUDF(params)
ctx = Context.make_with('inline')
ds = ctx.load('memory', data=sim)

res = ctx.run_udf(dataset=ds, udf=udf, plots=True)

res['point']

In [None]:
fig, axes = plt.subplots()
axes.imshow(sim[size//2, size//2])

In [None]:
fig, axes = plt.subplots()
axes.imshow(sim[:, :, size//2 + 1, size//2 + 1])

In [None]:
def get_translation_matrix(params: OverfocusParams, nav_shape):
    a = []
    b = []

    for det_y in (-10, 10):
        for det_x in (-10, 10):
            spec_y, spec_x = detector_px_to_specimen_px(
                y_px=float(det_y),
                x_px=float(det_x),
                fov_size_y=float(nav_shape[0]),
                fov_size_x=float(nav_shape[1]),
                transformation_matrix=get_transformation_matrix(params),
                cy=params['cy'],
                cx=params['cx'],
                detector_pixel_size=float(params['detector_pixel_size']),
                scan_pixel_size=float(params['scan_pixel_size']),
                camera_length=float(params['camera_length']),
                overfocus=float(params['overfocus']),
            )
            for scan_y in (-10, 10):
                for scan_x in (-10, 10):                    
                    offset_y = scan_y - nav_shape[0] / 2
                    offset_x = scan_x - nav_shape[1] / 2
                    image_px_y = spec_y + offset_y
                    image_px_x = spec_x + offset_x
                    a.append((
                        image_px_y,
                        image_px_x,
                        scan_y,
                        scan_x,
                        1
                    ))
                    b.append((det_y, det_x))
    #print(a)
    #print(b)
    res = np.linalg.lstsq(a, b, rcond=None)
    return res[0]

In [None]:
params = OverfocusParams(
    overfocus=0.0001,
    scan_pixel_size=0.00000001,
    camera_length=1,
    detector_pixel_size=0.0001,
    semiconv=np.pi,
    cy=size/2,
    cx=size/2,
    scan_rotation=0,
    flip_y=False
)

get_translation_matrix(params, nav_shape=(32, 32))

In [None]:
class RefOverfocusUDF(OverfocusUDF):
    def get_task_data(self):
        overfocus_params = self.params.overfocus_params
        translation_matrix = get_translation_matrix(
            params=overfocus_params,
            nav_shape=self._get_fov()
        )
        select_roi = np.zeros(self.meta.dataset_shape.nav, dtype=bool)
        nav_y, nav_x = self.meta.dataset_shape.nav
        select_roi[nav_y//2, nav_x//2] = True
        return {
            'translation_matrix': translation_matrix,
            'select_roi': select_roi
        }

In [None]:
%autoreload
size = 16
params = OverfocusParams(
    overfocus=0.0001,
    scan_pixel_size=0.00000001,
    camera_length=1,
    detector_pixel_size=0.0001,
    semiconv=np.pi,
    cy=size/2,
    cx=size/2,
    scan_rotation=0,
    flip_y=False
)
obj = np.zeros((size, size))
obj[size//2, size//2] = 1
sim = project(obj, scan_shape=(size, size), detector_shape=(size, size), sim_params=params)
assert sim[size//2, size//2, size//2, size//2] == 1

ref_udf = RefOverfocusUDF(params)
res_udf = OverfocusUDF(params)
ctx = Context.make_with('inline')
ds = ctx.load('memory', data=sim)

res = ctx.run_udf(dataset=ds, udf=(ref_udf, res_udf), plots=True)

In [None]:
params = OverfocusParams(
    overfocus=0.0001,
    scan_pixel_size=0.00000001,
    camera_length=1,
    detector_pixel_size=0.0001,
    semiconv=np.pi,
    cy=4.,
    cx=4.,
    scan_rotation=0,
    flip_y=False
)
obj = np.zeros((8, 8))
obj[4, 4] = 1
sim = project(obj, scan_shape=(8, 8), detector_shape=(8, 8), sim_params=params)
assert sim[4, 4, 4, 4] == 1

ctx = Context.make_with('inline')
ds = ctx.load('memory', data=sim)

ref_udf = RefOverfocusUDF(params)
res_udf = OverfocusUDF(params)

res = ctx.run_udf(dataset=ds, udf=(ref_udf, res_udf), plots=True)

fig, axes = plt.subplots(1, 2)
axes[0].imshow(res[0]['shifted_sum'].data.astype(bool))
axes[1].imshow(obj.astype(bool))
