In [31]:
from pathlib import Path
from brain_observatory_qc.pipeline_dev import paired_plane_registration as ppr
import h5py
import numpy as np
import brain_observatory_qc.data_access.from_lims as from_lims
from brain_observatory_qc.pipeline_dev import decrosstalk_fov_ica as dc_fov
from suite2p.registration import nonrigid
from importlib import reload

# Root of the decrosstalk validation workspace (a UNC path on a network share)
decrosstalk_dir = Path(r'\\allen\programs\mindscope\workgroups\learning\pipeline_validation\decrosstalk')
# Sub-directory under which this notebook writes its outputs
data_dir = decrosstalk_dir.joinpath('paired_registration_data')


In [27]:
# Total 30 min for oeid 1167237079
## To get alpha and beta for the experiment:
oeid = 1167237079
alpha_list, beta_list, signal_mean_list, paired_mean_list, recon_signal_list, recon_paired_list = \
    dc_fov.decrosstalk_movie(oeid)
# Collapse the returned coefficient lists to scalar alpha/beta by taking their mean
alpha = np.mean(alpha_list)
beta = np.mean(beta_list)

## To reduce RAM usage, decrosstalk the movie in chunks of frames:
chunk_size = 5000  # num of frames in each chunk

signal_fn = from_lims.get_motion_corrected_movie_filepath(oeid)
with h5py.File(signal_fn, 'r') as f:
    data_shape = f['data'].shape
data_length = data_shape[0]
num_chunks = int(np.ceil(data_length / chunk_size))
start_frames = np.arange(0, data_length, chunk_size)
end_frames = np.append(start_frames[1:], data_length)  # last chunk may be shorter

dc_fov_dir = data_dir / 'decrosstalk_fov_ica'
# parents=True: do not fail if data_dir itself does not exist yet
dc_fov_dir.mkdir(parents=True, exist_ok=True)
decrosstalk_fn = dc_fov_dir / f'{oeid}_decrosstalk_test.h5'

# Get the paired plane raw movie fn (two directories above its motion x/y offset file)
paired_oeid = from_lims.get_paired_plane_id(oeid)
paired_oeid_path = from_lims.get_motion_xy_offset_filepath(paired_oeid).parent.parent
paired_raw_movie_h5 = paired_oeid_path / f'{paired_oeid}.h5'
if not paired_raw_movie_h5.exists():
    raise FileNotFoundError(f'Paired raw movie not found at {paired_raw_movie_h5}')

# Get the registration info from the signal plane; it is re-applied to the paired plane below
shifts_df = ppr.get_s2p_motion_transform(oeid)
if_nonrigid = 'nonrigid_x' in shifts_df.columns
if if_nonrigid:
    # Loop-invariant work, done once: stacking the per-frame block shifts of the whole
    # experiment inside the chunk loop repeated this O(n_frames) step for every chunk.
    nonrigid_y_all = np.vstack(shifts_df['nonrigid_y'].values)
    nonrigid_x_all = np.vstack(shifts_df['nonrigid_x'].values)
    # The block grid must match the frame size of the registered movie (was hard-coded
    # to 512 x 512). block_size follows the registration default parameters.
    # TODO: read block_size from a file
    Ly1, Lx1 = data_shape[1], data_shape[2]
    block_size = (128, 128)
    blocks = nonrigid.make_blocks(Ly=Ly1, Lx=Lx1, block_size=block_size)

# Open every file once for the whole loop instead of once per chunk
with h5py.File(signal_fn, 'r') as signal_h5, \
        h5py.File(paired_raw_movie_h5, 'r') as paired_h5, \
        h5py.File(decrosstalk_fn, 'w') as out_h5:
    paired_frame_shape = paired_h5['data'].shape[1:]
    if paired_frame_shape != data_shape[1:]:
        raise ValueError(f'Paired movie frame shape {paired_frame_shape} does not match '
                         f'signal movie frame shape {data_shape[1:]}')
    # Pre-allocate the output with the signal movie's dtype (what np.zeros_like(signal_data)
    # produced per chunk); maxshape keeps the dataset resizable as before.
    recon_dset = out_h5.create_dataset('data', shape=data_shape, dtype=signal_h5['data'].dtype,
                                       maxshape=(None, data_shape[1], data_shape[2]))
    out_h5.create_dataset('alpha_list', data=alpha_list)
    out_h5.create_dataset('beta_list', data=beta_list)

    for start_frame, end_frame in zip(start_frames, end_frames):
        signal_data = signal_h5['data'][start_frame:end_frame]
        # Slicing an h5py dataset returns a new in-memory array, so it can be shifted
        # in place without an extra copy.
        paired_data = paired_h5['data'][start_frame:end_frame]

        # Apply the registration to raw movie of the paired plane
        y = shifts_df['y'].values[start_frame:end_frame]
        x = shifts_df['x'].values[start_frame:end_frame]
        for frame, dy, dx in zip(paired_data, y, x):
            frame[:] = ppr.shift_frame(frame=frame, dy=dy, dx=dx)
        if if_nonrigid:
            paired_data = nonrigid.transform_data(
                paired_data, yblock=blocks[0], xblock=blocks[1], nblocks=blocks[2],
                ymax1=nonrigid_y_all[start_frame:end_frame],
                xmax1=nonrigid_x_all[start_frame:end_frame], bilinear=True)

        # NOTE(review): recon_signal_data has the signal movie's dtype; if
        # apply_mixing_matrix returns floats they are truncated on assignment -- confirm intended.
        recon_signal_data = np.zeros_like(signal_data)
        for j in range(signal_data.shape[0]):
            recon_signal_data[j, :, :] = dc_fov.apply_mixing_matrix(
                alpha, beta, signal_data[j, :, :], paired_data[j, :, :])[0]

        recon_dset[start_frame:end_frame] = recon_signal_data

<class 'numpy.int16'>
<class 'numpy.int16'>
<class 'numpy.int16'>
<class 'numpy.int16'>
<class 'numpy.int16'>
<class 'numpy.int16'>
<class 'numpy.int16'>
<class 'numpy.int16'>


# Estimating optimal sigma for high-pass filtering

In [60]:
# Pick up any edits to decrosstalk_fov_ica without restarting the kernel
reload(dc_fov)

# Slow step (about 90 min): estimate the optimal high-pass filter sigma for this experiment
sigma_optimal = dc_fov.find_optimal_sigma_exp(oeid)

# Faster step (about 5 min): re-run FOV decrosstalk using that sigma
(alpha_list, beta_list,
 signal_mean_list, paired_mean_list,
 recon_signal_list, recon_paired_list) = dc_fov.decrosstalk_movie(
    oeid, filter_sigma_um=sigma_optimal)
