In [1]:
%matplotlib widget

In [2]:
%load_ext autoreload

In [3]:
import numba
import numpy as np
import matplotlib.pyplot as plt
import jax.numpy as jnp
import jax
jax.config.update("jax_enable_x64", True)
import optax

In [51]:
%autoreload
from microscope_calibration.common.model import (
    Parameters4DSTEM, Model4DSTEM, Result4DSTEM, PixelYX, CoordXY, identity, rotate, scale, flip_y,
    DescanError
)
from microscope_calibration.util.stem_overfocus_sim import smiley, project
from microscope_calibration.common.stem_overfocus import (
    get_backward_transformation_matrix, get_detector_correction_matrix, correct_frame, project_frame_backwards
)
from microscope_calibration.util.optimize import _solve

In [236]:
scan_pixel_pitch = 0.1
detector_pixel_pitch = 0.2
overfocus = 0.01
camera_length = 1.234
propagation_distance = overfocus + camera_length
obj_half_size = 16
angle = np.arctan2(obj_half_size*detector_pixel_pitch/2 + 0.00314157, propagation_distance)
params = Parameters4DSTEM(
    overfocus=overfocus,
    scan_pixel_pitch=scan_pixel_pitch,
    camera_length=camera_length,
    detector_pixel_pitch=detector_pixel_pitch,
    semiconv=angle,
    scan_center=PixelYX(x=1.1*obj_half_size, y=0.9*obj_half_size),
    scan_rotation=np.pi/23,
    flip_y=True,
    detector_center=PixelYX(x=2.3*obj_half_size, y=1.9*obj_half_size),
    # descan_error=DescanError(
    #     offpxi=detector_pixel_pitch,
    #     offpyi=detector_pixel_pitch * 2,
    #     offsxi=-1 * detector_pixel_pitch/camera_length,
    #     offsyi=-2 * detector_pixel_pitch/camera_length,
    #     pxo_pxi=2 * detector_pixel_pitch/scan_pixel_pitch,
    #     pyo_pyi=3 * detector_pixel_pitch/scan_pixel_pitch,
    #     sxo_pxi=-3 * detector_pixel_pitch/scan_pixel_pitch/camera_length,
    #     syo_pyi=-4 * detector_pixel_pitch/scan_pixel_pitch/camera_length,
    # ),
    descan_error=DescanError(*np.random.normal(scale=1,size=len(DescanError()))),
    detector_rotation=np.pi/42,
)

target_params = params.derive(
    descan_error=DescanError(),
    #scan_rotation=np.pi/7,
    #scan_pixel_pitch=scan_pixel_pitch*1.321,
    #detector_center=PixelYX(x=2.12*obj_half_size, y=1.87*obj_half_size),
    #flip_y=True,
)

In [208]:
%autoreload
def check_descan_equivalence(params, target_params):
    distances = []
    for scan_y in (0, 1):
        for scan_x in (0, 1):
            for cl in (0, 1):
                ref_params = params.derive(
                    camera_length=cl
                )
                ref_model = Model4DSTEM.build(params=ref_params, scan_pos=PixelYX(y=scan_y, x=scan_x))
                ref_ray = ref_model.make_source_ray(source_dy=0., source_dx=0.).ray
                ref = ref_model.trace(ref_ray)
                opt_params = target_params.derive(
                    camera_length=cl,
                )
                opt_model = Model4DSTEM.build(params=opt_params, scan_pos=PixelYX(y=scan_y, x=scan_x))
                opt_ray = opt_model.make_source_ray(source_dy=0., source_dx=0.).ray
                opt = opt_model.trace(opt_ray)
                distances.append((
                    opt['detector'].sampling['detector_px'].y - ref['detector'].sampling['detector_px'].y,
                    opt['detector'].sampling['detector_px'].x - ref['detector'].sampling['detector_px'].x,
                ))
    return jnp.linalg.norm(jnp.array(distances))



for angle in (0., np.pi/2, np.pi, -np.pi/2):
    print(check_descan_equivalence(
        params,
        Model4DSTEM.rotate_scan(params, angle)
    ))



0.0
1.0658141036401503e-14
1.0658141036401503e-14
1.0658141036401503e-14


In [209]:
%autoreload
for angle in (0., np.pi/2, np.pi, -np.pi/2):
    print(check_descan_equivalence(
        params,
        Model4DSTEM.rotate_detector(params, angle)
    ))

0.0
3.552713678800501e-15
1.0658141036401503e-14
1.2809491335957507e-14


In [210]:
%autoreload
print(check_descan_equivalence(
    Model4DSTEM.flip_detector_y(params),
    Model4DSTEM.flip_detector_y(Model4DSTEM.flip_detector_y(params))
))

1.1234667099445444e-14


In [217]:
%autoreload
print(check_descan_equivalence(
    params,
    Model4DSTEM.shift_detector(params, PixelYX(x=3, y=5))
))

1.2809491335957507e-14


In [240]:
%autoreload
distances = []
cl_factor = 2.3

for scan_y in (0, 1):
    for scan_x in (0, 1):
        ref_params = params.derive()
        ref_model = Model4DSTEM.build(params=ref_params, scan_pos=PixelYX(y=scan_y, x=scan_x))
        ref_ray = ref_model.make_source_ray(source_dy=0., source_dx=0.).ray
        ref = ref_model.trace(ref_ray)
        
        opt_params = Model4DSTEM.set_camera_length(
            params,
            params.camera_length * cl_factor,
        )
        opt_model = Model4DSTEM.build(params=opt_params, scan_pos=PixelYX(y=scan_y, x=scan_x))
        opt_ray = opt_model.make_source_ray(source_dy=0., source_dx=0.).ray
        opt = opt_model.trace(opt_ray)
        distances.append((
            opt['detector'].sampling['detector_px'].y - ref['detector'].sampling['detector_px'].y,
            opt['detector'].sampling['detector_px'].x - ref['detector'].sampling['detector_px'].x,
        ))
print(jnp.linalg.norm(jnp.array(distances)))

7.105427357601002e-15


In [171]:
def shift_detector(de: DescanError, shift: PixelYX):
    return DescanError(
        pxo_pyi=de.pxo_pyi,
        pyo_pyi=de.pyo_pyi,
        pxo_pxi=de.pxo_pxi,
        pyo_pxi=de.pyo_pxi,
        sxo_pyi=de.sxo_pyi,
        syo_pyi=de.syo_pyi,
        sxo_pxi=de.sxo_pxi,
        syo_pxi=de.syo_pxi,
        offpxi=de.offpxi - shift.x*detector_pixel_pitch,
        offpyi=de.offpyi - shift.y*detector_pixel_pitch,
        offsxi=de.offsxi,
        offsyi=de.offsyi,
    )

for y_shift in (0., 1):
    for x_shift in (0., 1):
        print(check_descan_equivalence(
            params,
            params.derive(
                detector_center=PixelYX(
                    x=params.detector_center.x + x_shift,
                    y=params.detector_center.y + y_shift,
                ),
                descan_error=shift_detector(params.descan_error, shift=PixelYX(x=x_shift, y=y_shift), detector_pixel_pitch=params.detector_pixel_pitch)
            )
        ))

0.0
0.21151653623562722
0.2115165362356262
0.29912955421061055


In [128]:
_solve(start=np.random.normal(scale=100,size=len(DescanError())), loss=loss, debug=True)

Args: [ -22.28901175  -78.78353979  -21.09382325 -122.65020354   -1.1763234
  158.68672912  134.98010966  -41.51362779   66.33315814 -163.99003483
  -75.78764447  -30.64502824], objective function: 2788.783856155217
Change: 63.99999999999999
Args: [  -1.61551785  -51.54320918  -19.67399147 -120.62436421  -26.32383438
  125.48368564  121.68095419  -58.9324628    50.73644899 -165.08887858
  -56.79537687  -20.63716561], objective function: 1870.3984710226177
Change: 106.57148867200867
Args: [  29.80377626  -10.15113372  -12.54870404 -110.94266489  -71.15713155
   66.28568521  102.06368474  -84.58080711   27.03449727 -170.51698847
  -22.93462716   -5.88584896], objective function: 713.8927318153708
Change: 36.29987320705729
Args: [  37.15151947   -0.48266482   -6.07957814 -102.26371478  -88.04890005
   43.97088534   98.04299839  -89.76828122   21.49420214 -175.42541298
  -10.17324009   -2.87919552], objective function: 449.3437427839336
Change: 12.021881824075802
Args: [  32.54578055   -6.

Array([-1.86244646, -0.82308442,  0.29601854, -2.59974192, -0.89806828,
       -0.44858726, -1.78711243, -0.18349141,  1.20953499,  0.70046954,
        0.35914747, -0.21368643], dtype=float64)

In [137]:
%autoreload
opt_res = []
for i in range(10):
    res, residual = _solve(start=np.random.normal(scale=100,size=len(DescanError())), loss=loss) 
    print(residual)
    opt_res.append(res)
np.allclose(opt_res, opt_res[0])

1.0575790737097468e-11
4.5739764057665255e-12
8.247118051745801e-13
4.469240380813033e-12
5.468294608828751e-13
6.862901256958579e-13
4.1798880531533696e-12
6.782241143442136e-13
8.108447709576468e-12
2.7610734406217615e-12


True

In [135]:
%autoreload
for i in range(10):
    opt_res = _solve(start=np.random.normal(scale=100,size=len(DescanError())), loss=loss, debug=False)
    print(np.allclose(params.descan_error, DescanError(*opt_res)))

True
True
True
True
True
True
True
True
True
True


In [114]:
np.allclose(params.descan_error, DescanError(*opt_res))

True

In [46]:
start = jnp.array((1., ))
correct = jnp.array((scan_pixel_pitch, ))
loss(start), loss(correct)

(Array(8.29759001, dtype=float64), Array(0., dtype=float64))

In [50]:
solver = optax.lbfgs()
optargs = start.copy()
opt_state = solver.init(optargs)
value_and_grad = optax.value_and_grad_from_state(loss)

def optstep(optargs, opt_state):
    value, grad = value_and_grad(optargs, state=opt_state)
    updates, opt_state = solver.update(
        grad, opt_state, optargs, value=value, grad=grad, value_fn=loss
    )
    print(jnp.linalg.norm(updates))
    optargs = optax.apply_updates(optargs, updates)
    return optargs, opt_state, jnp.linalg.norm(updates)

while True:
    print(f'Optargs: {optargs}, Objective function: {loss(optargs)}, distance {optargs - correct}')
    optargs, opt_state, change = optstep(optargs, opt_state)
    if change < 1e-12:
        break

Optargs: [1.], Objective function: 8.297590011563598, distance [0.9]
0.89999974585797
Optargs: [0.10000025], Objective function: 2.3430737435825577e-06, distance [2.5414203e-07]
1420.4332635862895
Optargs: [-1420.33326333], Objective function: 13093.903711016888, distance [-1420.43326333]
1420.4332635862897
Optargs: [0.10000025], Objective function: 2.3430761162401836e-06, distance [2.54142287e-07]
2.677778115931931e-07
Optargs: [0.09999999], Objective function: 1.257133224186191e-07, distance [-1.36355243e-08]
1.3635527488239522e-08
Optargs: [0.1], Objective function: 2.942091015256665e-14, distance [3.1918912e-15]
3.2502398488620793e-15


In [12]:
%autoreload
scan_pixel_pitch = 0.1
detector_pixel_pitch = 0.2
overfocus = 1.
camera_length = 1.
propagation_distance = overfocus + camera_length
obj_half_size = 16
angle = np.arctan2(obj_half_size*detector_pixel_pitch/2 + 0.00314157, propagation_distance)

params = Parameters4DSTEM(
    overfocus=overfocus,
    scan_pixel_pitch=scan_pixel_pitch,
    camera_length=camera_length,
    detector_pixel_pitch=detector_pixel_pitch,
    semiconv=angle,
    scan_center=PixelYX(x=obj_half_size, y=obj_half_size),
    scan_rotation=0.,
    flip_y=False,
    detector_center=PixelYX(x=2*obj_half_size, y=2*obj_half_size),
    detector_rotation=0.,
    descan_error=DescanError(
        sxo_pyi=1 * detector_pixel_pitch/scan_pixel_pitch,
        syo_pxi=1 * detector_pixel_pitch/scan_pixel_pitch,
        sxo_pxi=-2 * detector_pixel_pitch/scan_pixel_pitch/camera_length,
        syo_pyi=-1 * detector_pixel_pitch/scan_pixel_pitch/camera_length,
    )
)

In [13]:
obj = smiley(2* obj_half_size)
sim = project(
    image=obj,
    scan_shape=PixelYX(x=2*obj_half_size, y=2*obj_half_size),
    detector_shape=PixelYX(x=4*obj_half_size, y=4*obj_half_size),
    sim_params=params
)

In [14]:
#np.save('with_descan.npy', sim)

In [15]:
%autoreload
test_positions = jnp.array((
    (0, 0),
    (100, 0),
    (0, 100)
))

target_px = []
for scan_y, scan_x in test_positions:
    target_model = Model4DSTEM.build(params=params, scan_pos=PixelYX(y=scan_y, x=scan_x))
    target_ray = target_model.make_source_ray(source_dx=0, source_dy=0).ray
    target_res = target_model.trace(target_ray)
    target_px.append((
        target_res['detector'].sampling['detector_px'].x, 
        target_res['detector'].sampling['detector_px'].y, 
    ))

target_px = jnp.array(target_px)

@jax.jit
def loss(args):
    sxo_pyi, syo_pxi, sxo_pxi, syo_pyi = args
    opt_params = params.derive(descan_error=DescanError(
        sxo_pyi=sxo_pyi,
        syo_pxi=syo_pxi,
        sxo_pxi=sxo_pxi,
        syo_pyi=syo_pyi,
    ))
    res = []
    for scan_y, scan_x in test_positions:
        opt_model = Model4DSTEM.build(params=opt_params, scan_pos=PixelYX(y=scan_y, x=scan_x))
        opt_ray = opt_model.make_source_ray(source_dx=0, source_dy=0).ray
        opt_res = opt_model.trace(opt_ray)
        res.append((
            opt_res['detector'].sampling['detector_px'].x, 
            opt_res['detector'].sampling['detector_px'].y, 
        ))
    return jnp.linalg.norm(jnp.array(res) - target_px)

In [16]:
%autoreload
start = jnp.zeros(4)
correct = jnp.array((
    1 * detector_pixel_pitch/scan_pixel_pitch,
    1 * detector_pixel_pitch/scan_pixel_pitch,
    -2 * detector_pixel_pitch/scan_pixel_pitch/camera_length,
    -1 * detector_pixel_pitch/scan_pixel_pitch/camera_length,
))
loss(start), loss(correct)

(Array(259.93845425, dtype=float64), Array(0., dtype=float64))

In [17]:
%%time
%autoreload
solver = optax.lbfgs()
optargs = start.copy()
opt_state = solver.init(optargs)
value_and_grad = optax.value_and_grad_from_state(loss)

@jax.jit
def optstep(optargs, opt_state):

    value, grad = value_and_grad(optargs, state=opt_state)
    updates, opt_state = solver.update(
        grad, opt_state, optargs, value=value, grad=grad, value_fn=loss
    )
    optargs = optax.apply_updates(optargs, updates)
    return optargs, opt_state

for i in range(10):
    print(f'Objective function: {loss(optargs)}, distance {optargs - correct}')
    optargs, opt_state = optstep(optargs, opt_state)
print(f'Objective function: {loss(optargs)}, distance {optargs - correct}')

Objective function: 259.93845425407915, distance [-2. -2.  4.  2.]
Objective function: 26.66086140682235, distance [ 0.61221921  0.10120593  0.30860143 -0.10120593]
Objective function: 6.92307995973211, distance [-0.00929074 -0.04819074  0.13528148  0.04819074]
Objective function: 5.000034294667415, distance [-0.10706057 -0.0309365  -0.01425106  0.0309365 ]
Objective function: 2.333560982376479, distance [-0.02197011  0.01050028 -0.05347095 -0.01050028]
Objective function: 1.0353947452645216, distance [ 0.0161431   0.00794695 -0.00769774 -0.00794695]
Objective function: 0.844680495950539, distance [ 0.00983851 -0.00318521  0.01939414  0.00318521]
Objective function: 0.4793930619284352, distance [-0.01083868 -0.00132091 -0.00687594  0.00132091]
Objective function: 0.12648295699271592, distance [-0.0006318   0.00072134 -0.00279582 -0.00072134]
Objective function: 0.07459031810058668, distance [ 0.00136696  0.00054215 -0.00025949 -0.00054215]
Objective function: 0.04170785608595267, dista

In [18]:
%autoreload
obj = smiley(obj_half_size * 2) # np.ones((obj_half_size * 2, obj_half_size * 2))

projected = project(
    image=obj,
    detector_shape=(obj_half_size * 4, obj_half_size * 4),
    scan_shape=(obj_half_size * 2, obj_half_size * 2),
    sim_params=params,
)


In [19]:
%autoreload

out = np.zeros_like(projected)
for scan_y in range(out.shape[0]):
    for scan_x in range(out.shape[1]):
        correct_frame(
            frame=projected[scan_y, scan_x],
            mat=mat,
            scan_y=scan_y,
            scan_x=scan_x,
            detector_out=out[scan_y, scan_x],
        )

projected_ref = project(
    image=obj,
    detector_shape=(obj_half_size * 4, obj_half_size * 4),
    scan_shape=(obj_half_size * 2, obj_half_size * 2),
    sim_params=params_ref,
    specimen_to_image=map_coord,
)


scan_y = 16
scan_x = 16

                    
fig, axes = plt.subplots(2, 4, squeeze=False)

axes[0, 0].imshow(projected[scan_y, scan_x])
axes[0, 1].imshow(projected[:, :, obj_half_size * 2, obj_half_size * 2])
axes[0, 2].imshow(out[scan_y, scan_x])
axes[0, 3].imshow(out[:, :, obj_half_size* 2, obj_half_size * 2])

axes[1, 0].imshow(projected_ref[scan_y, scan_x])
axes[1, 1].imshow(projected_ref[:, :, obj_half_size * 2, obj_half_size * 2])
axes[1, 2].imshow(out[scan_y, scan_x])
axes[1, 3].imshow(out[:, :, obj_half_size * 2, obj_half_size * 2])


NameError: name 'mat' is not defined

In [None]:
clip = 1
clip2 = 1
np.allclose(projected_ref, out)

In [None]:

np.allclose(projected_ref[scan_y, scan_x, clip2:-clip2, clip2:-clip2], out[scan_y, scan_x, clip2:-clip2, clip2:-clip2])

In [None]:
projected_ref[scan_y, scan_x] - out[scan_y, scan_x]

In [None]:
detector_rotation = 0.
scan_rotation = 0.

params = Parameters4DSTEM(
    overfocus=1,
    scan_pixel_pitch=2,
    camera_length=1,
    detector_pixel_pitch=2,
    semiconv=np.pi/2,
    scan_center=PixelYX(x=16., y=16.),
    scan_rotation=0.,
    flip_y=False,
    detector_center=PixelYX(x=32, y=32.),
    detector_rotation=0.,
    descan_error=DescanError()
)
obj = smiley(64)
res = np.zeros((32, 32))

# def map_coord(inp):
#     cy = obj.shape[0] / 2
#     cx = obj.shape[1] / 2
#     inp_vec = jnp.array((inp.y, inp.x))
#     y, x = scale(1) @ inp_vec
#     return PixelYX(y=y+cy, x=x+cx)

map_coord = None

projected = project(
    image=obj,
    scan_shape=((32, 32)),
    detector_shape=((64, 64)),
    sim_params=params,
    specimen_to_image=map_coord,
)

print(projected.shape)

fig, axes = plt.subplots(1, 2, squeeze=False)
axes[0, 0].imshow(projected[:, :, 32, 32])
axes[0, 1].imshow(projected[16, 16])

mat = get_backward_transformation_matrix(
    rec_params=params,
    specimen_to_image=map_coord
)
project_frame_backwards(
    frame=projected[16, 16],
    source_semiconv=np.pi/2,
    mat=mat,
    scan_y=16,
    scan_x=16,
    image_out=res,
)




fig, axes = plt.subplots(1, 2, squeeze=False)

axes[0, 0].imshow(obj)
axes[0, 1].imshow(res)


In [None]:
params = Parameters4DSTEM(
    overfocus=1,
    scan_pixel_pitch=1,
    camera_length=1,
    detector_pixel_pitch=2,
    semiconv=np.pi/2,
    scan_center=PixelYX(x=16, y=16.),
    scan_rotation=0.,
    flip_y=False,
    detector_center=PixelYX(x=16, y=16.),
    detector_rotation=np.pi/2,
    descan_error=DescanError()
)
obj = smiley(32)
res = project(
    image=obj,
    detector_shape=(32, 32),
    scan_shape=(32, 32),
    sim_params=params,
)
fig, axes = plt.subplots(1, 2)
axes[0].imshow(obj)
axes[1].imshow(np.rot90(res[16, 15], k=-1)) 
#assert_allclose(obj, np.rot90(res[15, 16], k=-1))

In [None]:
trans = rotate(np.pi/1.23) @ scale(0.23) @ flip_y()
cis = jnp.linalg.inv(trans)

In [None]:
trans, cis

In [None]:
cis2 = flip_y() @ scale(1/0.23) @ rotate(-np.pi/1.23)
cis2

In [None]:
I = np.eye(2, dtype=trans.dtype)
cis3 = jnp.linalg.solve(trans, I)
cis3

In [None]:
cis3

In [None]:
trans.dtype

In [None]:
p = OverfocusParams(
    overfocus=1,
    scan_pixel_size=1,
    camera_length=1,
    detector_pixel_size=1,
    cy=0,
    cx=0,
)
p	

In [None]:
detector_px_to_specimen_px(
    y_px=1.,
    x_px=0.,
    fov_size_y=0,
    fov_size_x=0,
    transformation_matrix=np.array(((0., 1.), (1., 0.))),
    **p
)

In [None]:
%autoreload
size = 32
params = OverfocusParams(
    overfocus=0.0001,
    scan_pixel_size=0.00000001,
    camera_length=1,
    detector_pixel_size=0.0001,
    semiconv=np.pi,
    cy=size/2,
    cx=size/2,
    scan_rotation=0,
    flip_y=False
)
obj = smiley(size)
projected = project(
    image=obj,
    scan_shape=(size, size),
    detector_shape=(size, size),
    sim_params=params,
)

In [None]:
fig, axes = plt.subplots(1, 3)
axes[0].imshow(obj)
axes[1].imshow(projected[:, :, size//2, size//2])
axes[2].imshow(projected[:, :, size//2, size//2] - obj)

fig, axes = plt.subplots(1, 3)
axes[0].imshow(obj)
axes[1].imshow(projected[size//2, size//2, :])
axes[2].imshow(projected[size//2, size//2, :] - obj)

In [None]:
%autoreload
size = 32
params = OverfocusParams(
    overfocus=0.0001,
    scan_pixel_size=0.00000001,
    camera_length=1,
    detector_pixel_size=0.0001,
    semiconv=np.pi,
    cy=size/2,
    cx=size/2,
    scan_rotation=0,
    flip_y=False
)
obj = np.zeros((size, size))
obj[size//2, size//2] = 1
sim = project(obj, scan_shape=(size, size), detector_shape=(size, size), sim_params=params)
assert sim[size//2, size//2, size//2, size//2] == 1

udf = OverfocusUDF(params)
ctx = Context.make_with('inline')
ds = ctx.load('memory', data=sim)

res = ctx.run_udf(dataset=ds, udf=udf, plots=True)

res['point']

In [None]:
fig, axes = plt.subplots()
axes.imshow(sim[size//2, size//2])

In [None]:
fig, axes = plt.subplots()
axes.imshow(sim[:, :, size//2 + 1, size//2 + 1])

In [None]:
def get_translation_matrix(params: OverfocusParams, nav_shape):
    a = []
    b = []

    for det_y in (-10, 10):
        for det_x in (-10, 10):
            spec_y, spec_x = detector_px_to_specimen_px(
                y_px=float(det_y),
                x_px=float(det_x),
                fov_size_y=float(nav_shape[0]),
                fov_size_x=float(nav_shape[1]),
                transformation_matrix=get_transformation_matrix(params),
                cy=params['cy'],
                cx=params['cx'],
                detector_pixel_size=float(params['detector_pixel_size']),
                scan_pixel_size=float(params['scan_pixel_size']),
                camera_length=float(params['camera_length']),
                overfocus=float(params['overfocus']),
            )
            for scan_y in (-10, 10):
                for scan_x in (-10, 10):                    
                    offset_y = scan_y - nav_shape[0] / 2
                    offset_x = scan_x - nav_shape[1] / 2
                    image_px_y = spec_y + offset_y
                    image_px_x = spec_x + offset_x
                    a.append((
                        image_px_y,
                        image_px_x,
                        scan_y,
                        scan_x,
                        1
                    ))
                    b.append((det_y, det_x))
    #print(a)
    #print(b)
    res = np.linalg.lstsq(a, b, rcond=None)
    return res[0]

In [None]:
params = OverfocusParams(
    overfocus=0.0001,
    scan_pixel_size=0.00000001,
    camera_length=1,
    detector_pixel_size=0.0001,
    semiconv=np.pi,
    cy=size/2,
    cx=size/2,
    scan_rotation=0,
    flip_y=False
)

get_translation_matrix(params, nav_shape=(32, 32))

In [None]:
class RefOverfocusUDF(OverfocusUDF):
    def get_task_data(self):
        overfocus_params = self.params.overfocus_params
        translation_matrix = get_translation_matrix(
            params=overfocus_params,
            nav_shape=self._get_fov()
        )
        select_roi = np.zeros(self.meta.dataset_shape.nav, dtype=bool)
        nav_y, nav_x = self.meta.dataset_shape.nav
        select_roi[nav_y//2, nav_x//2] = True
        return {
            'translation_matrix': translation_matrix,
            'select_roi': select_roi
        }

In [None]:
%autoreload
size = 16
params = OverfocusParams(
    overfocus=0.0001,
    scan_pixel_size=0.00000001,
    camera_length=1,
    detector_pixel_size=0.0001,
    semiconv=np.pi,
    cy=size/2,
    cx=size/2,
    scan_rotation=0,
    flip_y=False
)
obj = np.zeros((size, size))
obj[size//2, size//2] = 1
sim = project(obj, scan_shape=(size, size), detector_shape=(size, size), sim_params=params)
assert sim[size//2, size//2, size//2, size//2] == 1

ref_udf = RefOverfocusUDF(params)
res_udf = OverfocusUDF(params)
ctx = Context.make_with('inline')
ds = ctx.load('memory', data=sim)

res = ctx.run_udf(dataset=ds, udf=(ref_udf, res_udf), plots=True)

In [None]:
params = OverfocusParams(
    overfocus=0.0001,
    scan_pixel_size=0.00000001,
    camera_length=1,
    detector_pixel_size=0.0001,
    semiconv=np.pi,
    cy=4.,
    cx=4.,
    scan_rotation=0,
    flip_y=False
)
obj = np.zeros((8, 8))
obj[4, 4] = 1
sim = project(obj, scan_shape=(8, 8), detector_shape=(8, 8), sim_params=params)
assert sim[4, 4, 4, 4] == 1

ctx = Context.make_with('inline')
ds = ctx.load('memory', data=sim)

ref_udf = RefOverfocusUDF(params)
res_udf = OverfocusUDF(params)

res = ctx.run_udf(dataset=ds, udf=(ref_udf, res_udf), plots=True)

fig, axes = plt.subplots(1, 2)
axes[0].imshow(res[0]['shifted_sum'].data.astype(bool))
axes[1].imshow(obj.astype(bool))
