# Purpose:
- Fix registration to z-stack issue.
- From 240517_online_motion_correction_analysis.ipynb

In [5]:
import numpy as np
from pystackreg import StackReg
from pathlib import Path
from glob import glob
import tifffile
import h5py
import matplotlib.pyplot as plt
from dask.distributed import Client
from dask import delayed, compute
import napari
import cv2
from tifffile import TiffFile, imread, imsave, read_scanimage_metadata
from pprint import pprint

In [3]:
def rolling_average_zstack(zstack, rolling_window_flank=2):
    new_zstack = np.zeros(zstack.shape)
    for i in range(zstack.shape[0]):
        new_zstack[i] = np.mean(zstack[max(0, i-rolling_window_flank) : min(zstack.shape[0], i+rolling_window_flank), :, :],
                                axis=0)
    return new_zstack


def get_matched_zstack(emf_fn, ops_fn, zstack_dir, num_planes_around=40):
    ''' 
    
    
    Notes
    - Rolling average of z-stacks was not enough.
    '''
    ops = np.load(ops_fn, allow_pickle=True).item()
    y_roll_bottom = np.min(ops['reg_result'][4][0])
    y_roll_top = np.max(ops['reg_result'][4][0])
    x_roll_left = np.max(ops['reg_result'][4][1])
    x_roll_right = np.min(ops['reg_result'][4][1])
    if y_roll_bottom >= 0:
        y_roll_bottom = -1
    if x_roll_right >= 0:
        x_roll_right = -1

    zstack_fn_list = glob(str(zstack_dir /'ophys_experiment_*_local_z_stack.tiff'))
    center_zstacks = []
    for zstack_fn in zstack_fn_list:
        zstack = tifffile.imread(zstack_fn)
        zstack = med_filt_z_stack(zstack)
        new_zstack = rolling_average_zstack(zstack)
        center_ind = int(np.floor(new_zstack.shape[0]/2))
        center_zstack = new_zstack[center_ind - num_planes_around//2 : center_ind + num_planes_around//2+1]
        center_zstack = center_zstack[:, y_roll_top:y_roll_bottom, x_roll_left:x_roll_right]

        center_zstacks.append(center_zstack)
    first_emf = tifffile.imread(emf_fn)[0, y_roll_top:y_roll_bottom, x_roll_left:x_roll_right]
    # first_emf = tifffile.imread(emf_fn)[0]
    # first_emf = tifffile.imread(emf_fn)[:, y_roll_top:y_roll_bottom, x_roll_left:x_roll_right].mean(axis=0)

    first_emf_clahe = image_normalization_uint16(first_emf)
    assert first_emf.min() > 0
    valid_pix_threshold = first_emf.min()/10
    num_pix_threshold = first_emf.shape[0] * first_emf.shape[1] / 2

    sr = StackReg(StackReg.AFFINE)
    corrcoef = np.zeros((len(center_zstacks), center_zstacks[0].shape[0]))
    
    best_tmat_array = []
    emf_reg_array = []
    for i, zstack in enumerate(center_zstacks):
        temp_cc = []
        tmat_list = []
        for j, zstack_plane in enumerate(zstack):
            zstack_plane_clahe = image_normalization_uint16(zstack_plane)            
            tmat = sr.register(zstack_plane_clahe, first_emf_clahe)
            emf_reg = sr.transform(first_emf, tmat=tmat)            
            valid_y, valid_x = np.where(emf_reg > valid_pix_threshold)
            if len(valid_y) > num_pix_threshold:
                temp_cc.append(np.corrcoef(zstack_plane.flatten(), emf_reg.flatten())[0,1])
                tmat_list.append(tmat)
            else:
                temp_cc.append(0)
                tmat_list.append(np.eye(3))
        temp_ind = np.argmax(temp_cc)
        best_tmat = tmat_list[temp_ind]
        emf_reg = sr.transform(first_emf, tmat=best_tmat)
        for j, zstack_plane in enumerate(zstack):
            corrcoef[i,j] = np.corrcoef(zstack_plane.flatten(), emf_reg.flatten())[0,1]
        best_tmat_array.append(best_tmat)
        emf_reg_array.append(emf_reg)

    matched_ind = np.argmax(np.mean(corrcoef, axis=1))
    best_tmat_array = np.array(best_tmat_array)
    emf_reg_array = np.array(emf_reg_array)
    return matched_ind, zstack_fn_list, corrcoef, best_tmat_array, emf_reg_array


def med_filt_z_stack(zstack, kernel_size=5):
    """Get z-stack with each plane median-filtered

    Parameters
    ----------
    zstack : np.ndarray
        z-stack to apply median filtering
    kernel_size : int, optional
        kernel size for median filtering, by default 5
        It seems only certain odd numbers work, e.g., 3, 5, 11, ...

    Returns
    -------
    np.ndarray
        median-filtered z-stack
    """
    filtered_z_stack = []
    for image in zstack:
        filtered_z_stack.append(cv2.medianBlur(
            image.astype(np.uint16), kernel_size))
    return np.array(filtered_z_stack)


def image_normalization_uint16(image, im_thresh=0):
    """Normalize 2D image and convert to uint16
    Prevent saturation.

    Args:
        image (np.ndarray): input image (2D)
                            Just works with 3D data as well.
        im_thresh (float, optional): threshold when calculating pixel intensity percentile.
                            0 by default
    Return:
        norm_image (np.ndarray)
    """
    clip_image = np.clip(image, np.percentile(
        image[image > im_thresh], 0.2), np.percentile(image[image > im_thresh], 99.8))
    norm_image = (clip_image - np.amin(clip_image)) / \
        (np.amax(clip_image) - np.amin(clip_image)) * 0.9
    uint16_image = ((norm_image + 0.05) *
                    np.iinfo(np.uint16).max * 0.9).astype(np.uint16)
    return uint16_image


def metadata_from_scanimage_tif(stack_path):
    """Extract metadata from ScanImage tiff stack

    Dev notes:
    Seems awkward to parse this way
    Depends on ScanImageTiffReader

    Parameters
    ----------
    stack_path : str
        Path to tiff stack

    Returns
    -------
    dict
        stack_metadata: important metadata extracted from scanimage tiff header
    dict
        si_metadata: all scanimge metadata. Each value still a string, so convert if needed.
    dict
        roi_groups_dict: 
    """
    with open(stack_path, 'rb') as fh:
        metadata = read_scanimage_metadata(fh)

    stack_metadata = {}
    stack_metadata['num_slices'] = int(metadata[0]['SI.hStackManager.actualNumSlices'])
    stack_metadata['num_volumes'] = int(metadata[0]['SI.hStackManager.actualNumVolumes'])
    stack_metadata['frames_per_slice'] = int(metadata[0]['SI.hStackManager.framesPerSlice'])
    stack_metadata['z_steps'] = metadata[0]['SI.hStackManager.zs']
    stack_metadata['actuator'] = metadata[0]['SI.hStackManager.stackActuator']
    stack_metadata['num_channels'] = sum((metadata[0]['SI.hPmts.powersOn']))
    stack_metadata['z_step_size'] = int(metadata[0]['SI.hStackManager.actualStackZStepSize'])  

    roi_groups_dict = metadata[1]

    si_metadata = metadata[0]

    return stack_metadata, si_metadata, roi_groups_dict

In [119]:
# Results from the previous registration using StackReg
data_dir = Path(r'\\allen\programs\mindscope\workgroups\learning\pilots\online_motion_correction\mouse_726433\test_240531')
emf_fn_list = glob(str(data_dir / '*_timeseries_omc_test_00002_*_emf.tif'))

omc_parent_opid = []
for fn in emf_fn_list:
    h5fn = fn.split('.')[0] + '_zdrift.h5'
    with h5py.File(h5fn, 'r') as h:
        mc_matched_zstack_fn = h['matched_zstack_fn'][()]
        parent_opid = Path(mc_matched_zstack_fn.decode('utf-8', errors='replace')).name.split('_')[2]
    omc_parent_opid.append(parent_opid)

pprint(omc_parent_opid)

['1369640314',
 '1369640314',
 '1369640317',
 '1369640317',
 '1369640321',
 '1369640321',
 '1369640324',
 '1369640324']


In [120]:
max_cc_zstack_finding = []
for fn in emf_fn_list:
    h5fn = fn.split('.')[0] + '_zdrift.h5'
    with h5py.File(h5fn, 'r') as h:
        # print(h.keys())
        # corrcoef = h['corrcoef'][:]
        corrcoef_zstack_finding = h['corrcoef_zstack_finding'][:]
    max_cc_zstack_finding.append(corrcoef_zstack_finding.max())
pprint(max_cc_zstack_finding)

[0.9155909943170593,
 0.8679927856831672,
 0.9146133190171375,
 0.8434830594148753,
 0.8953753884488609,
 0.8104135336372471,
 0.8195972298432874,
 0.6863223911348383]


In [72]:
test_fn = emf_fn_list[-1]
test_h5fn = test_fn.split('.')[0] + '_zdrift.h5'
fn_base = Path(test_fn).name.split('.')[0]
ops_fn = Path(test_fn).parent / f'{fn_base[:-3]}ops.npy'
ops = np.load(ops_fn, allow_pickle=True).item()
y_roll_top = np.max(ops['reg_result'][4][0])
y_roll_bottom = np.min(ops['reg_result'][4][0])    
x_roll_left = np.max(ops['reg_result'][4][1])
x_roll_right = np.min(ops['reg_result'][4][1])
if y_roll_bottom >= 0:
    y_roll_bottom = -1
if x_roll_right >= 0:
    x_roll_right = -1

with h5py.File(test_h5fn, 'r') as h:
    print(h.keys())
    tmat = h['tmat'][:]
    corrcoef_zstack_finding = h['corrcoef_zstack_finding'][:]
    emf_registered = h['emf_registered'][:]
    matched_inds = h['matched_inds'][:]
    matched_zstack_fn = Path('/' + h['matched_zstack_fn'][()].decode('utf-8', errors='replace'))
zstack_ref = tifffile.imread(matched_zstack_fn)[:, y_roll_top:y_roll_bottom, x_roll_left:x_roll_right]


<KeysViewHDF5 ['corrcoef', 'corrcoef_zstack_finding', 'emf_registered', 'matched_inds', 'matched_zstack_fn', 'tmat', 'tmat_zstack_finding', 'zstack_fn_list']>


In [73]:
viewer = napari.Viewer()
for emf in emf_registered:
    viewer.add_image(emf)
viewer.add_image(zstack_ref)

<Image layer 'zstack_ref' at 0x1c39dddab20>

In [74]:
np.array([(tm == np.eye(3)).all() for tm in tmat]).all()

False

# StackReg does not work well
- For some cases, (ch2 with the same fastZ values; they had low power)
-  Fixed slurm sbatch code.
- There is x-y difference between channels.
    - How about fixing this first using phase correlation?

In [79]:
viewer = napari.Viewer()
for fn in emf_fn_list:
    emf = tifffile.imread(fn)[0]
    viewer.add_image(emf)

In [88]:
viewer = napari.Viewer()
for fn in emf_fn_list[2:4]:
    emf = tifffile.imread(fn)
    viewer.add_image(emf)

In [77]:
emf_registered_all = []
matched_inds_all = []
matched_zstack_fn_all = []
corrcoef_zstack_finding_all = []
for test_fn in emf_fn_list:
    test_h5fn = test_fn.split('.')[0] + '_zdrift.h5'
    with h5py.File(test_h5fn, 'r') as h:
        corrcoef_zstack_finding = h['corrcoef_zstack_finding'][:]
        emf_registered = h['emf_registered'][:]
        matched_inds = h['matched_inds'][:]
        matched_zstack_fn = Path('/' + h['matched_zstack_fn'][()].decode('utf-8', errors='replace'))
    emf_registered_all.append(emf_registered)
    matched_inds_all.append(matched_inds)
    matched_zstack_fn_all.append(matched_zstack_fn)
    corrcoef_zstack_finding_all.append(corrcoef_zstack_finding)


In [78]:
viewer = napari.Viewer()
for emf_registered in emf_registered_all:
    viewer.add_image(emf_registered[0])
# tmats are different between similar planes

In [108]:
emf_fn = emf_fn_list[-1]
emf = tifffile.imread(emf_fn)[0]
matched_zstack_fn = Path('//allen/programs/mindscope/workgroups/learning/pilots/online_motion_correction/mouse_726433/test_240531/ophys_session_1369518919/ophys_experiment_1369640314_local_z_stack.tiff')
zstack = tifffile.imread(matched_zstack_fn)

viewer = napari.Viewer()
viewer.add_image(emf)
viewer.add_image(zstack)

<Image layer 'zstack' at 0x1c4d2414400>

In [110]:
# manual matching
new_zstack = med_filt_z_stack(zstack)
new_zstack = rolling_average_zstack(new_zstack)
matched_zplane = new_zstack[46]
zplane_clahe = image_normalization_uint16(matched_zplane)
emf_clahe = image_normalization_uint16(emf)

In [111]:
sr = StackReg(StackReg.AFFINE)
tmat = sr.register(zplane_clahe, emf_clahe)
emf_clahe_reg = sr.transform(emf_clahe, tmat=tmat)
viewer = napari.Viewer()
viewer.add_image(emf_clahe_reg)
viewer.add_image(zplane_clahe)

<Image layer 'zplane_clahe' at 0x1c4dbcb4880>

In [100]:
emf_clahe.shape

(512, 512)

In [None]:
emf_reg = sr.transform(emf, tmat=sr._tmats)

In [102]:
def get_correlation_after_reg(fov, zstack, use_clahe=True, sr_method='affine', tmat=None):
    if use_clahe:
        fov_for_reg = image_normalization_uint16(fov)
        zstack_for_reg = np.zeros(zstack.shape)
        for zi in range(zstack.shape[0]):
            temp_zplane = image_normalization_uint16(zstack[zi])
            zstack_for_reg[zi] = temp_zplane
    else:
        fov_for_reg = fov.copy()
        zstack_for_reg = zstack.copy()
    
    if sr_method == 'affine':
        sr = StackReg(StackReg.AFFINE)
    elif sr_method == 'rigid_body':
        sr = StackReg(StackReg.RIGID_BODY)
    else:
        raise ValueError('"sr_method" should be either "affine" or "rigid_body"')
    
    assert fov.min() > 0
    valid_pix_threshold = fov.min()/10
    num_pix_threshold = fov.shape[0] * fov.shape[1] / 2
    
    corrcoef = np.zeros(zstack.shape[0])
    
    if tmat is None:
        temp_cc = []
        tmat_list = []
        for zi in range(zstack_for_reg.shape[0]):
            zstack_plane_clahe = zstack_for_reg[zi]
            zstack_plane = zstack[zi]
            tmat = sr.register(zstack_plane_clahe, fov_for_reg)
            fov_reg = sr.transform(fov, tmat=tmat)            
            valid_y, valid_x = np.where(fov_reg > valid_pix_threshold)
            if len(valid_y) > num_pix_threshold:
                temp_cc.append(np.corrcoef(zstack_plane[valid_y, valid_x].flatten(),
                                           fov_reg[valid_y, valid_x].flatten())[0,1])
                tmat_list.append(tmat)
            else:
                temp_cc.append(0)
                tmat_list.append(np.eye(3))
        temp_ind = np.argmax(temp_cc)
        best_tmat = tmat_list[temp_ind]
    else:
        best_tmat = tmat
    fov_reg = sr.transform(fov, tmat=best_tmat)
    for zi, zstack_plane in enumerate(zstack):
        corrcoef[zi] = np.corrcoef(zstack_plane[valid_y, valid_x].flatten(),
                                   fov_reg[valid_y, valid_x].flatten())[0,1]
    matched_ind = np.argmax(corrcoef)

    return matched_ind, corrcoef, fov_reg, best_tmat, tmat_list

In [112]:
matched_ind, corrcoef, fov_reg, best_tmat, tmat_list = get_correlation_after_reg(emf, new_zstack)

In [114]:
viewer = napari.Viewer()
viewer.add_image(fov_reg)

<Image layer 'fov_reg' at 0x1c4dbcb41c0>

In [115]:
corrcoef

array([0.15799288, 0.15658885, 0.15544704, 0.15278767, 0.14936238,
       0.14548563, 0.14116875, 0.13594433, 0.13198478, 0.12948122,
       0.12750878, 0.12702488, 0.12759204, 0.12879196, 0.13159325,
       0.13551517, 0.14073303, 0.14710477, 0.15349493, 0.15957989,
       0.16604109, 0.1727442 , 0.18013914, 0.18756444, 0.19488358,
       0.20102375, 0.20716125, 0.2119622 , 0.21487481, 0.21665308,
       0.21657319, 0.21518891, 0.21013952, 0.20329907, 0.19665226,
       0.18921109, 0.18115776, 0.17303211, 0.16623403, 0.15940293,
       0.15390292, 0.14910624, 0.14466624, 0.14061606, 0.13833304,
       0.13653966, 0.13590848, 0.13662614, 0.13779613, 0.13967729,
       0.14230996, 0.14467322, 0.14621411, 0.14723888, 0.14794339,
       0.1484995 , 0.1502469 , 0.15131686, 0.15230892, 0.15347177,
       0.15512494, 0.15830219, 0.1604132 , 0.16348256, 0.16588682,
       0.16683656, 0.16917295, 0.17174088, 0.1740844 , 0.17716151,
       0.17903563, 0.18058055, 0.1813198 , 0.1821405 , 0.18400

In [106]:
corrcoef

array([0.1875511 , 0.15010333, 0.14921082, 0.14948493, 0.14991655,
       0.15114204, 0.15246722, 0.15559095, 0.15398175, 0.15777677,
       0.1583372 , 0.15482468, 0.16218258, 0.16274028, 0.16773384,
       0.17163666, 0.17208537, 0.17324004, 0.17766018, 0.17908206,
       0.18104921, 0.19198186, 0.1928775 , 0.19204059, 0.20287014,
       0.20869923, 0.21037484, 0.21732722, 0.22060106, 0.22894865,
       0.23459593, 0.24572491, 0.24041272, 0.25796574, 0.26398675,
       0.26606035, 0.27398805, 0.28119802, 0.29167247, 0.29265458,
       0.27932194, 0.29322207, 0.28901348, 0.29597984, 0.28951687,
       0.29554355, 0.29165697, 0.28271473, 0.28776744, 0.26372394,
       0.27134012, 0.26488738, 0.25330275, 0.24628301, 0.24397952,
       0.23902716, 0.22604561, 0.21860348, 0.22486392, 0.21794637,
       0.21057846, 0.20508086, 0.19157065, 0.18910964, 0.18224201,
       0.17820477, 0.17072989, 0.16626039, 0.16861135, 0.15717158,
       0.15629845, 0.15294491, 0.15037264, 0.14418139, 0.14293

## StackReg DOES work well
- Fixed after calculating correlation in the valid pixels only.
- Valid pixel calculated as those with positive value after StackReg.
- In the future, if something looks bad, then try initializing with phase correlation first (rough translational registration)