# Define functions

Starting point: neuron detections

1. Match features (general, not just neurons)
2. Loop over neurons and build local affine transformations
3. Combine local transformations into a global vector field

In [1]:
# Loosely inspired from caiman, but they only do translation, not rotation
import pandas as pd
import matplotlib.pyplot as plt
from skimage import transform
from skimage.transform import warp
import numpy as np
import zarr
import napari
from DLC_for_WBFM.utils.postprocessing.base_cropping_utils import get_crop_coords3d
from DLC_for_WBFM.utils.feature_detection.utils_rigid_alignment import calc_warp_ECC
from tqdm.auto import tqdm
from DLC_for_WBFM.utils.video_and_data_conversion.import_video_as_array import get_single_volume
from DLC_for_WBFM.utils.projects.utils_project import load_config
from pathlib import Path
from DLC_for_WBFM.utils.projects.utils_project import safe_cd
import cv2
from DLC_for_WBFM.utils.feature_detection.visualization_tracks import visualize_tracks
from DLC_for_WBFM.utils.preprocessing.utils_tif import PreprocessingSettings
from DLC_for_WBFM.utils.preprocessing.utils_tif import perform_preprocessing
from DLC_for_WBFM.utils.feature_detection.utils_features import build_features_and_match_2volumes, extract_map1to2_from_matches
from scipy.ndimage import distance_transform_edt
%load_ext autoreload
import itertools
import pickle
%autoreload 2

In [2]:

project_path = r"Y:\shared_projects\wbfm\dlc_stacks\Charlie-worm3-long\project_config.yaml"
cfg = load_config(project_path)

red_btf = cfg['red_bigtiff_fname']
num_z = cfg['dataset_params']['num_slices']

project_dir = Path(project_path).parent

with safe_cd(project_dir):
    train_fname = Path(cfg['subfolder_configs']['training_data'])
    train_cfg = dict(load_config(train_fname))
    
    p_fname = train_cfg['preprocessing_config']
    p = PreprocessingSettings.load_from_yaml(p_fname)

In [3]:
# Get data
i_ref = 300
ref_frame_raw = get_single_volume(red_btf, i_ref, num_z, 0.15)

i_test = 301
test_frame_raw = get_single_volume(red_btf, i_test, num_z, 0.15)

In [4]:
# Do z-alignment (preprocessing)

ref_frame = perform_preprocessing(ref_frame_raw, p)
test_frame = perform_preprocessing(test_frame_raw, p)

In [5]:
# Get segmentation
# fname = r"Y:\shared_projects\wbfm\dlc_stacks\Charlie-worm3-long\4-traces\reindexed_masks.zarr"
fname = r"Y:\shared_projects\wbfm\dlc_stacks\Charlie-worm3-long\1-segmentation\masks_1500.zarr"
z = zarr.open(fname)

ref_seg = np.array(z[i_ref,...])
test_seg = np.array(z[i_test,...])

# And metadata
fname = r"Y:\shared_projects\wbfm\dlc_stacks\Charlie-worm3-long\1-segmentation\metadata_1500.pickle"
with open(fname, 'rb') as f:
    seg_metadata = pickle.load(f)

# Step 0: Apply a pre-rotation (global)

In [17]:
h_2d, rotated_frame, zxy0, zxy1 = get_warp_via_features_from_imgs(ref_frame, test_frame, apply_to_slices=True)

  0%|          | 0/32 [00:00<?, ?it/s]

# Steps 1 and 2:  match features and build local flow fields

In [10]:
# Build a flow field from the rotation matrix
def flow_field_from_matrix(shape, A):
    if A is None or len(A)==0:
        return None
    x, y = shape
    xx, yy = np.arange(x), np.arange(y)
    A = cv2.invertAffineTransform(A)
    
    all_xy = np.array([[yy[i_y], xx[i_x], 1] for i_x, i_y in itertools.product(range(x), range(y))])
    flow_long = A@all_xy.T
    flow0 = np.reshape(flow_long[0,:], (x, y), order='A')
    flow1 = np.reshape(flow_long[1,:], (x, y), order='A')
    flow = np.stack([flow0, flow1], axis=2).astype('float32')
    return flow

# Redo feature matching on the crop
def get_warp_via_features_from_imgs(ref_frame, test_frame, apply_to_slices=False):
    # 3d
    all_locs0, all_locs1, all_kp0, all_kp1, all_matches = build_features_and_match_2volumes(
        ref_frame, test_frame, use_GMS=True, verbose=0, start_plane=5, matches_to_keep=0.5
    )
    if len(all_locs0) == 0:
#         print("No matches found")
        return all_kp0, all_kp1, None, None
    zxy0 = np.array(all_locs0[:,1:], dtype='float64')
    zxy1 = np.array(all_locs1[:,1:], dtype='float64')
    # Try skimage
    trans = transform.estimate_transform('euclidean', zxy1, zxy0)
    h_2d = trans.params[:2,:]
    
    rotated_frame = None
    if h_2d is not None:
        if apply_to_slices:
            rotated_frame = np.zeros_like(test_frame)
            for i, f in enumerate(test_frame):
                sz = (f.shape[1], f.shape[0])
                out = cv2.warpAffine(f, h_2d, dsize=sz)
                rotated_frame[i,...] = out
    else:
        print("No rotation found")
    return h_2d, rotated_frame, zxy0, zxy1

def vector_field_to_flow(vf):
    h, w = vf.shape[:2]
    flow = vf.copy()
    flow = -flow
    flow[:,:,0] += np.arange(w)
    flow[:,:,1] += np.arange(h)[:,np.newaxis]
    return flow

def flow_to_vector_field(flow):
    h, w = flow.shape[:2]
    vf = flow.copy()
    vf[:,:,0] -= np.arange(w)
    vf[:,:,1] -= np.arange(h)[:,np.newaxis]
    vf = -vf
    return vf

def warp_flow(img, flow, apply_directly=False):
    # From:
    # https://github.com/opencv/opencv/blob/master/samples/python/opt_flow.py#L50-L56
    # apply_directly should be False if the output is from 
    if not apply_directly:
        flow = vector_field_to_flow(flow)
    res = cv2.remap(img, flow, None, cv2.INTER_LINEAR)
    return res

def get_cropped_image(centroid, test_frame):
    # Get a cropped cube in both volumes
    z, x, y = get_crop_coords3d(centroid, crop_sz=(9, 128, 128), clip_sz=None)

    test_crop_full_size = np.zeros_like(test_frame)
#     test_crop_full_size[z[0]:z[-1], x[0]:x[-1], y[0]:y[-1]] = test_frame[z[0]:z[-1], x[0]:x[-1], y[0]:y[-1]]
    test_crop_full_size[:, x[0]:x[-1], y[0]:y[-1]] = test_frame[:, x[0]:x[-1], y[0]:y[-1]]
    return test_crop_full_size


def get_flow_field_from_centroid(centroid, ref_frame, test_frame):
    cropped_test_frame = get_cropped_image(centroid, test_frame)
    h_2d, _, _, _ = get_warp_via_features_from_imgs(ref_frame, cropped_test_frame)
    if h_2d is None:
        return None, None
    flow = flow_field_from_matrix(cropped_test_frame.shape[1:], h_2d)
    weights = cropped_test_frame > 0
    return flow, weights

In [18]:
all_flows, all_weights = [], []
num_neurons = len(seg_metadata[i_ref])
for i_seg in tqdm(range(num_neurons), total=num_neurons):
    # Get a centroid
    centroid = seg_metadata[i_ref].iloc[i_seg]['centroids']
    flow, weights = get_flow_field_from_centroid(centroid, ref_frame, rotated_frame)
    all_flows.append(flow)
    all_weights.append(weights)
    
    if i_seg > 10:
        break

  0%|          | 0/163 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

# Step 3: Combine the transforms into a global field

In [12]:
def dist_func(x, thresh):
    if x == 0 or x > thresh:
        return 0
    else:
        return 1.0/x


def combine_flow_fields(all_flows, all_weights, outliers, err_thresh=30):
    # Weight each flow field by how accurate the overlap is
    flow = np.zeros_like(all_flows[0])
    overlapping_weights = np.ones(flow.shape[:-1])
    for i, (w, f) in tqdm(enumerate(zip(all_weights, all_flows))):
        if i in outliers or f is None or len(f)==0:
            continue
        mask = np.max(w,axis=0)
        # Get both raw patches
        centroid = seg_metadata[i_ref].iloc[i]['centroids']
        i_slice = int(centroid[0])
        this_ref = np.where(mask, ref_frame[i_slice,...], 0)
        this_test = np.where(mask, test_frame[i_slice,...], 0)
        

        w_dist = mask
#         w_dist = distance_transform_edt(mask)
        overlapping_weights += w_dist
#         flow[...,0] += np.where(w_binary, f[...,0], 0)
#         flow[...,1] += np.where(w_binary, f[...,1], 0)
        flow[...,0] += np.multiply(f[...,0], w_dist)
        flow[...,1] += np.multiply(f[...,1], w_dist)
    tmp = np.stack([overlapping_weights, overlapping_weights], axis=-1)
    flow = np.divide(flow, tmp)
    return flow, overlapping_weights

In [19]:
outliers = []
# outliers = [0, 88]
# outliers = [13, 17, 32, 74, 88, 89, 93, 94, 96, 106, 108, 111, 121, 123, 129, 149, 153, 155, 157, 159, 162]
flow, overlapping_weights = combine_flow_fields(all_flows, all_weights, outliers)

0it [00:00, ?it/s]

# Step 4: Check!

In [22]:
# CHECK
i_slice = 18

from ipywidgets import interact

def f(i, z):
    plt.figure(figsize=(45,25))
    centroid = seg_metadata[i_ref].iloc[i]['centroids']
    
    i_slice = int(centroid[0])
    
    plt.imshow(warp_flow(rotated_frame[z,...], all_flows[i], apply_directly=True), alpha=0.5, cmap="Reds")
    
    tmp = get_cropped_image(centroid, ref_frame)
    plt.imshow(tmp[z,...], alpha=0.5, cmap="Greens")
    plt.title(f"Overlay after global and local affine transformations (slice {i_slice})")

interact(f, i=(0, len(all_flows)), z=(0,rotated_frame.shape[0]-1))

interactive(children=(IntSlider(value=6, description='i', max=12), IntSlider(value=15, description='z', max=31…

<function __main__.f(i, z)>

In [16]:
# CHECK
i_slice = 17

plt.figure(figsize=(25,5))
plt.imshow(rotated_frame[i_slice,...], alpha=0.5, cmap="Reds")
plt.imshow(ref_frame[i_slice,...], alpha=0.5, cmap="Greens")
plt.title("Overlay after affine transformation")


NameError: name 'rotated_frame' is not defined

<Figure size 1800x360 with 0 Axes>