In [None]:
import torch
import torch.nn.functional as F

from tqdm.notebook import tqdm

import zetta_utils

from zetta_utils.layer.volumetric.cloudvol import build_cv_layer
from zetta_utils.viz.widgets import visualize_list
from zetta_utils import tensor_ops
from zetta_utils.alignment.online_finetuner import align_with_online_finetuner

In [None]:
def norm(img):
    """Z-score the nonzero pixels of ``img``; zero pixels stay zero.

    Zero is treated as "no data": the mean and standard deviation are computed
    over the nonzero pixels only. The input is not modified; a float copy is
    returned.

    Args:
        img: tensor of any shape.

    Returns:
        Float tensor with the same shape as ``img``.
    """
    result = img.clone().float()
    # Compute the mask ONCE. After centering, pixels equal to the mean become 0;
    # recomputing the mask would wrongly drop them from the std and the scaling.
    mask = result != 0
    if not mask.any():
        return result
    values = result[mask]
    values = values - values.mean()
    std = values.std()
    # std is 0 for a constant region and NaN for a single pixel (unbiased
    # estimator); skip the scaling instead of producing inf/NaN.
    if std > 0:
        values = values / std
    result[mask] = values
    return result

In [None]:

unaligned_path = 'gs://neuroglancer/kluther/tomography/jun23/Section_1_rawtilts'

# Read-only handle on the unaligned tilt series. Resolution (1, 1, 1) is used
# both for indexing and for reads; `cache=True` is forwarded via cv_kwargs
# (presumably enabling the CloudVolume local cache — confirm).
unaligned_raw = build_cv_layer(
    path=unaligned_path,
    index_resolution=(1, 1, 1),
    default_desired_resolution=(1, 1, 1),
    readonly=True,
    cv_kwargs=dict(cache=True),
)

# Per-section stores keyed by z: raw cutouts and their normalized versions.
sections_raw, sections_norm = {}, {}

In [None]:
xy_size = 256
downs_factor = 2

# Read raw sections 54..74 (inclusive), each an xy_size x xy_size cutout.
# Sections are always re-read, so changing xy_size above takes effect on re-run.
for z in tqdm(range(54, 75), desc="loading sections"):
    sections_raw[z] = unaligned_raw[0:xy_size, 0:xy_size, z:z + 1]

# Rescale by (1/downs_factor, 1/downs_factor, 1) (downsample x/y, keep z),
# normalize the nonzero pixels, then add a leading batch dim and drop the
# trailing size-1 dim (presumably z, from the z:z+1 slice — confirm the
# layer's axis order).
for k, v in sections_raw.items():
    downsampled = tensor_ops.common.interpolate(
        v,
        scale_factor=(1 / downs_factor, 1 / downs_factor, 1),
        mode="img",
    )
    sections_norm[k] = norm(downsampled).unsqueeze(0).squeeze(-1)

In [None]:
def compute_affine_alignment_loss(src, tgt, theta, reduction="sum"):
    """Masked L1 discrepancy between ``src`` warped by ``theta`` and ``tgt``.

    Zeros in either image are treated as "no tissue" and excluded.

    Args:
        src: 4-D image batch (N, C, H, W), as required by ``F.affine_grid``
            with a 2x3 theta.
        tgt: image with the same shape as ``src``.
        theta: affine matrices of shape (N, 2, 3) for ``F.affine_grid``.
        reduction: ``"sum"`` (default) or ``"mean"`` over the overlap pixels,
            i.e. pixels where both the warped src tissue mask and ``tgt`` are
            nonzero.

    Returns:
        0-dim tensor, differentiable w.r.t. ``theta`` through the warp. An
        empty overlap gives 0 for both reductions (never NaN).

    Raises:
        ValueError: if ``reduction`` is not ``"sum"`` or ``"mean"``.
    """
    if reduction not in ("sum", "mean"):
        raise ValueError(f"reduction must be 'sum' or 'mean', got {reduction!r}")

    # grid_sample needs grid and input in the same dtype; theta is often
    # float64, so build the grid in src's dtype instead of forcing float32.
    grid = F.affine_grid(theta, src.size()).to(src.dtype)
    src_warped = F.grid_sample(src, grid)
    # Warp the tissue mask with the same grid; any nonzero interpolated value
    # (including partially covered border pixels) counts as tissue.
    src_tissue_warped = F.grid_sample((src != 0).to(src.dtype), grid) != 0

    # Absolute (L1) difference — not a squared error.
    abs_diff = (src_warped - tgt).abs()
    overlap = src_tissue_warped & (tgt != 0)
    masked = abs_diff[overlap]
    if reduction == "mean" and masked.numel() > 0:
        return masked.mean()
    # Sum of an empty selection is 0, which also covers "mean" with no overlap.
    return masked.sum()

def align_tilts(src, tgt, lr=1e-3, num_iter=100, theta=None, tol=0.005, log_every=20):
    """Fit an affine ``theta`` that warps ``src`` onto ``tgt`` with Adam.

    Minimizes ``compute_affine_alignment_loss(src, tgt, theta)``.

    Args:
        src, tgt: images passed straight to ``compute_affine_alignment_loss``.
        lr: Adam learning rate.
        num_iter: maximum number of optimization steps.
        theta: initial (1, 2, 3) affine matrix; identity (float64) if None.
            The caller's tensor is never modified.
        tol: stop early once the loss drops below this value.
        log_every: print the loss every ``log_every`` iterations; 0 disables.

    Returns:
        The optimized theta, detached from the autograd graph.
    """
    if theta is None:
        theta = torch.tensor([[[1, 0, 0], [0, 1, 0]]], dtype=torch.float64)
    # detach() + clone() always yields a fresh leaf tensor. Setting
    # requires_grad on a non-leaf (e.g. a theta with autograd history) raises,
    # and Adam refuses non-leaf parameters.
    theta = theta.detach().clone().requires_grad_(True)

    optimizer = torch.optim.Adam([theta], lr=lr)

    for i in range(num_iter):
        loss = compute_affine_alignment_loss(src=src, tgt=tgt, theta=theta)
        if loss < tol:
            break
        if log_every and i % log_every == 0:
            print(f"iter {i:4d}  loss {loss.item():.6g}")
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    return theta.detach()

In [None]:
# Per-section initial guesses (float64, shape 1 x 2 x 3) for the pairwise
# affine fit: identity everywhere, with hand-tuned translations for a few z.
# NOTE(review): indices 10 and 117 lie outside the sections 54..74 loaded
# above — confirm they are still needed.
inital_thetas = [
    torch.tensor([[[1, 0, 0], [0, 1, 0]]], dtype=torch.float64)
    for _ in range(128)
]

# Offsets below are divided by the last-axis size of a normalized
# (downsampled) section times downs_factor (presumably the full-resolution
# width xy_size — confirm).
factor = downs_factor * next(iter(sections_norm.values())).shape[-1]

inital_thetas[10] = torch.tensor(
    [[[1, 0, 0], [0, 1, -255 / factor]]], dtype=torch.float64
)
inital_thetas[117] = torch.tensor(
    [[[1, 0, -30 / factor], [0, 1, 850 / factor]]], dtype=torch.float64
)

In [None]:
# Preview the initial (pre-optimization) transform for one section pair
# z -> z-1: warped src, raw src, target.
# z must be a loaded section whose predecessor is also loaded; a hard-coded
# z = 119 is outside the sections 54..74 loaded above and raises KeyError.
z = max(sections_norm)
assert z - 1 in sections_norm, f"section {z - 1} is not loaded"
src = sections_norm[z]
tgt = sections_norm[z - 1]
grid = F.affine_grid(inital_thetas[z], src.size()).float()
x = F.grid_sample(src, grid).float()
visualize_list([x, src, tgt])

In [None]:
# z -> affine theta that warps section z onto section z - 1 (filled below).
pairwise_thetas = dict()

In [None]:
# Fit the affine transform warping each section z onto its predecessor z-1
# for every consecutive pair of the loaded sections 54..74. The chaining cell
# further down reads pairwise_thetas[55..74]; range(73, 74) produced only key
# 73 and left that cell failing with KeyError.
for z in tqdm(range(55, 75), desc="pairwise alignment"):
    src = sections_norm[z]
    tgt = sections_norm[z - 1]
    pairwise_thetas[z] = align_tilts(src, tgt, theta=inital_thetas[z], num_iter=400, lr=1e-3)

# Later exploratory cells read the leaked `z`/`src`/`tgt`; pin them to the
# 73 -> 72 pair this cell previously ended on.
z = 73
src = sections_norm[z]
tgt = sections_norm[z - 1]


In [None]:
# Inspect the batched section shape passed to the finetuner below.
src.size()

In [None]:
# Dense refinement: displacement field from align_with_online_finetuner for
# src -> tgt. Inputs drop the leading batch dim and gain a trailing size-1 dim;
# the result is reshaped back the opposite way.
# sm=10 is passed through unchanged (presumably a smoothness weight — confirm).
field = (
    align_with_online_finetuner(
        src.squeeze(0).unsqueeze(-1),
        tgt.squeeze(0).unsqueeze(-1),
        sm=10,
    )
    .unsqueeze(0)
    .squeeze(-1)
)

In [None]:
# The field should line up with the section it warps.
(src.shape, field.shape)

In [None]:
# Side-by-side for pair 73 -> 72: field-warped src, affine-warped src,
# raw src, target, and the field itself.
z = 73
src, tgt = sections_norm[z], sections_norm[z - 1]
grid = F.affine_grid(pairwise_thetas[z], src.size()).float()
x = F.grid_sample(src, grid).float()
visualize_list([
    field.from_pixels()(src),
    x,
    src,
    tgt,
    field,
])

In [None]:
def add_last_row(x):
    """Append the homogeneous row [0, 0, 1] to a batch of 2x3 affine matrices.

    Args:
        x: tensor of shape (B, 2, 3).

    Returns:
        Tensor of shape (B, 3, 3). The appended row is created as float64 on
        ``x``'s device, one copy per batch element (previously batch size 1
        and CPU were hard-coded).
    """
    last_row = torch.tensor([0, 0, 1], dtype=torch.float64, device=x.device)
    last_row = last_row.expand(x.shape[0], 1, 3)
    return torch.cat([x, last_row], dim=1)

def compose_thetas(x, y):
    """Compose two batches of 2x3 affine matrices.

    Returns the top 2x3 part of ``X @ Y``, where X and Y are the 3x3
    homogeneous forms of ``x`` and ``y``. Applied to a coordinate, ``y`` acts
    first and ``x`` second.
    """
    return torch.matmul(add_last_row(x), add_last_row(y))[:, :-1]

def invert_theta(x):
    """Invert a batch of 2x3 affine matrices via their 3x3 homogeneous form.

    ``torch.inverse`` raises if a matrix is singular.
    """
    return torch.inverse(add_last_row(x))[:, :-1]

In [None]:
mid_section = 64
start_section = 54
end_section = 74

# Chain the pairwise transforms into one transform per section, expressed in
# the frame of mid_section (which keeps the identity).
# F.affine_grid's theta maps output coordinates to input sampling coordinates,
# and pairwise_thetas[z] (P_z) warps section z onto z - 1, so:
#   forward  (z > mid): F_z = P_z @ F_{z-1}
#   backward (z < mid): F_z = inv(P_{z+1}) @ F_{z+1}
final_thetas = {
    mid_section: torch.tensor([[[1, 0, 0], [0, 1, 0]]], dtype=float),
}
aligned_imgs_norm = {mid_section: sections_norm[mid_section]}

for z in range(mid_section + 1, end_section + 1):
    final_thetas[z] = compose_thetas(pairwise_thetas[z], final_thetas[z - 1])
    section = sections_norm[z]
    # Size the grid from the section being warped, not from a leftover `src`.
    grid = F.affine_grid(final_thetas[z], section.size()).float()
    aligned_imgs_norm[z] = F.grid_sample(section, grid).float()

for z in range(mid_section - 1, start_section - 1, -1):
    final_thetas[z] = compose_thetas(
        invert_theta(pairwise_thetas[z + 1]),
        final_thetas[z + 1],
    )
    section = sections_norm[z]
    grid = F.affine_grid(final_thetas[z], section.size()).float()
    aligned_imgs_norm[z] = F.grid_sample(section, grid).float()



In [None]:

# Show every aligned section in z order. range(54, 74) omitted section 74,
# which the chaining cell does align.
visualize_list([aligned_imgs_norm[z] for z in sorted(aligned_imgs_norm)])

In [None]:
import pickle

# Persist the chained per-section transforms (dict keyed by section z).
with open("final_thetas_x1.pkl", mode="wb") as thetas_file:
    pickle.dump(final_thetas, thetas_file)

In [None]:
aligned_path = 'gs://tmp_2w/tomography/jun23/Section_1_rawtilts_aligned_x4'

# Output layer written by the export cell below. Its info is derived from the
# unaligned layer, with a 512 x 512 x 1 chunk size and data_type float32;
# on_info_exists='override' presumably replaces any existing info at
# aligned_path — confirm.
aligned = build_cv_layer(
    path=aligned_path,
    info_reference_path=unaligned_path,
    info_field_overrides={'data_type': 'float32'},
    info_chunk_size=(512, 512, 1),
    index_resolution=(1, 1, 1),
    default_desired_resolution=(1, 1, 1),
    on_info_exists='override',
)

In [None]:
# Warp each full-size raw section with its chained transform and write it to
# the aligned layer at the same x/y window and z.
for z in tqdm(range(start_section, end_section + 1)):
    img_raw = sections_raw[z].unsqueeze(0).squeeze(-1).float()
    grid = F.affine_grid(final_thetas[z], img_raw.size()).float()
    warped = F.grid_sample(img_raw, grid).float()
    # Reverse the reshape: drop the batch dim, restore the trailing size-1 dim.
    aligned[0:xy_size, 0:xy_size, z:z + 1] = warped.squeeze(0).unsqueeze(-1)
    
        

In [None]:
# Sanity check (replaces a stray scratch cell): every section in
# [start_section, end_section] received a chained transform.
assert sorted(final_thetas) == list(range(start_section, end_section + 1))