# Background:
- Reading TIFF files from the Allen network drive takes too long,
- even when reading only the header.
# Purpose:
- Compare reading times from the local HDD and from the Allen drive.
- Code copied from 240517_online_motion_correction_analysis.ipynb.

In [1]:
import numpy as np
from pystackreg import StackReg
from pathlib import Path
from glob import glob
import tifffile
import h5py
import matplotlib.pyplot as plt
from dask.distributed import Client
from dask import delayed, compute
import napari
from ScanImageTiffReader import ScanImageTiffReader
import json

def rolling_average_zstack(zstack, rolling_window_flank=2):
    """Smooth a z-stack with a centered moving average along the first axis.

    Output plane ``i`` is the mean of input planes
    ``i - rolling_window_flank`` through ``i + rolling_window_flank``
    (inclusive), truncated at the first and last planes of the stack.

    Parameters
    ----------
    zstack : np.ndarray
        Array whose first axis indexes planes, shape (n_planes, ...).
    rolling_window_flank : int
        Number of neighbouring planes averaged on each side of the center
        plane. 0 returns a float64 copy of the stack.

    Returns
    -------
    np.ndarray
        float64 array with the same shape as ``zstack``.

    Raises
    ------
    ValueError
        If ``rolling_window_flank`` is negative.
    """
    if rolling_window_flank < 0:
        raise ValueError(
            f'rolling_window_flank must be >= 0, got {rolling_window_flank}')
    num_planes = zstack.shape[0]
    new_zstack = np.zeros(zstack.shape)
    for i in range(num_planes):
        start = max(0, i - rolling_window_flank)
        # Slice ends are exclusive: +1 includes the plane `flank` above i so
        # the window is symmetric around i (and never empty when flank == 0).
        stop = min(num_planes, i + rolling_window_flank + 1)
        new_zstack[i] = np.mean(zstack[start:stop], axis=0)
    return new_zstack


def get_matched_zstack(emf_fn, ops_fn, zstack_dir, num_planes_around=40):
    """Find the local z-stack that best matches the first frame of ``emf_fn``.

    For each ``ophys_experiment_*_local_z_stack.tiff`` in ``zstack_dir``:
    the stack is smoothed with ``rolling_average_zstack``, the planes around
    its center are kept, and all images are cropped to exclude the border
    rows/columns affected by the registration shifts stored in ``ops``.
    The first frame of ``emf_fn`` is affine-registered (pystackreg) to every
    center plane. The transform with the highest correlation is applied,
    and the registered frame is correlated against every center plane.
    The stack with the highest mean correlation is the match.

    Parameters
    ----------
    emf_fn : str or Path
        TIFF whose index-0 frame is used as the reference image.
    ops_fn : str or Path
        ``.npy`` file holding a pickled dict. ``ops['reg_result'][4][0]`` and
        ``[4][1]`` are read as y and x shifts (presumably per-frame
        suite2p registration offsets -- TODO confirm).
    zstack_dir : str or Path
        Directory containing the z-stack TIFFs.
    num_planes_around : int
        Number of planes kept around the stack center (the slice keeps
        ``num_planes_around // 2`` on each side plus the center plane).

    Returns
    -------
    matched_ind : int
        Index into ``zstack_fn_list`` of the best-matching stack.
    zstack_fn_list : list of str
        Sorted z-stack file paths that were compared.
    corrcoef : np.ndarray
        Correlations, shape (n_stacks, n_center_planes).

    Raises
    ------
    FileNotFoundError
        If no matching z-stack files exist in ``zstack_dir``.
    ValueError
        If the cropped first frame has non-positive pixels.

    Notes
    -----
    - Rolling average of z-stacks was not enough.
    """
    # allow_pickle is required for ops dicts; only load trusted ops files.
    ops = np.load(ops_fn, allow_pickle=True).item()
    y_shifts = ops['reg_result'][4][0]
    x_shifts = ops['reg_result'][4][1]
    # Crop convention: the largest shift trims the start of the axis and the
    # most negative shift trims the end. Clamp the start at 0 (a negative
    # start would wrap around), and use None as the end when no shift is
    # negative so no valid row/column is dropped.
    y_start = max(0, int(np.ceil(np.max(y_shifts))))
    y_min = int(np.floor(np.min(y_shifts)))
    y_stop = y_min if y_min < 0 else None
    x_start = max(0, int(np.ceil(np.max(x_shifts))))
    x_min = int(np.floor(np.min(x_shifts)))
    x_stop = x_min if x_min < 0 else None

    # Sorted so that matched_ind maps to a deterministic file order.
    zstack_fn_list = sorted(
        glob(str(Path(zstack_dir) / 'ophys_experiment_*_local_z_stack.tiff')))
    if not zstack_fn_list:
        raise FileNotFoundError(
            f'No ophys_experiment_*_local_z_stack.tiff files in {zstack_dir}')

    half = num_planes_around // 2
    center_zstacks = []
    for zstack_fn in zstack_fn_list:
        smoothed = rolling_average_zstack(tifffile.imread(zstack_fn))
        center_ind = smoothed.shape[0] // 2
        # Clamp so shallow stacks don't yield a negative (wrapping) start.
        first_plane = max(0, center_ind - half)
        center_zstacks.append(
            smoothed[first_plane:center_ind + half + 1,
                     y_start:y_stop, x_start:x_stop])

    first_emf = tifffile.imread(emf_fn)[0, y_start:y_stop, x_start:x_stop]
    if first_emf.min() <= 0:
        raise ValueError(
            'First frame must be strictly positive after cropping; '
            f'min is {first_emf.min()}')
    valid_pix_threshold = first_emf.min() / 10
    num_pix_threshold = first_emf.shape[0] * first_emf.shape[1] / 3

    sr = StackReg(StackReg.AFFINE)
    corrcoef = np.zeros((len(center_zstacks), center_zstacks[0].shape[0]))

    for i, zstack in enumerate(center_zstacks):
        plane_cc = []
        plane_tmats = []
        for zstack_plane in zstack:
            tmat = sr.register(zstack_plane, first_emf)
            emf_reg = sr.transform(first_emf, tmat=tmat)
            # Transforms that leave too few pixels above threshold score 0 and
            # fall back to identity (presumably to reject degenerate fits
            # that push most of the frame out of view).
            if np.count_nonzero(emf_reg > valid_pix_threshold) > num_pix_threshold:
                plane_cc.append(
                    np.corrcoef(zstack_plane.ravel(), emf_reg.ravel())[0, 1])
                plane_tmats.append(tmat)
            else:
                plane_cc.append(0)
                plane_tmats.append(np.eye(3))
        best_tmat = plane_tmats[int(np.argmax(plane_cc))]
        emf_reg = sr.transform(first_emf, tmat=best_tmat)
        for j, zstack_plane in enumerate(zstack):
            corrcoef[i, j] = np.corrcoef(
                zstack_plane.ravel(), emf_reg.ravel())[0, 1]
    matched_ind = np.argmax(np.mean(corrcoef, axis=1))
    return matched_ind, zstack_fn_list, corrcoef


def _extract_dict_from_si_string(string):
    """Parse the 'SI' variables from a scanimage metadata string"""

    lines = string.split('\n')
    data_dict = {}
    for line in lines:
        if line.strip():  # Check if the line is not empty
            key, value = line.split(' = ')
            key = key.strip()
            if value.strip() == 'true':
                value = True
            elif value.strip() == 'false':
                value = False
            else:
                value = value.strip().strip("'")  # Remove leading/trailing whitespace and single quotes
            data_dict[key] = value

    json_data = json.dumps(data_dict, indent=2)
    loaded_data_dict = json.loads(json_data)
    return loaded_data_dict


def _str_to_int_list(string):
    return [int(s) for s in string.strip('[]').split()]


def _str_to_bool_list(string):
    return [bool(s) for s in string.strip('[]').split()]

def metadata_from_scanimage_tif(stack_path):
    """Extract metadata from a ScanImage TIFF stack.

    Dev notes:
    Seems awkward to parse this way.
    Depends on ScanImageTiffReader.
    NOTE(review): this notebook times this function (17.5 min over the
    network drive vs 19 s locally). The cost is presumably dominated by
    ``reader.metadata()``. ``tifffile.read_scanimage_metadata`` may read
    only the file header, but it returns already-parsed values rather than
    strings -- unverified; compare outputs before switching.

    Parameters
    ----------
    stack_path : str or Path
        Path to tiff stack.

    Returns
    -------
    stack_metadata : dict
        Selected values: num_slices, num_volumes, frames_per_slice,
        z_steps, actuator, num_channels (count of 'SI.hPmts.powersOn'
        entries that are on), z_step_size.
    si_metadata : dict
        All ScanImage 'SI.*' metadata. Values are strings (true/false become
        bools), so convert if needed.
    roi_groups_dict : dict
        ROI-group definitions parsed from the JSON part of the header.

    Raises
    ------
    ValueError
        If the metadata contains no JSON ROI-group section.
    """
    with ScanImageTiffReader(str(stack_path)) as reader:
        md_string = reader.metadata()

    # The header is 'SI.* = value' lines followed by a JSON object that
    # starts on its own line. Split at the first "\n{" only, so any later
    # occurrence stays inside the JSON.
    si_str, sep, rg_rest = md_string.partition("\n{")
    if not sep:
        raise ValueError(
            f'No ROI-group JSON section found in ScanImage metadata of {stack_path}')
    rg_str = "{" + rg_rest

    # parse 1: 'key = value' lines
    si_metadata = _extract_dict_from_si_string(si_str)
    # parse 2: the ROI groups are valid JSON
    roi_groups_dict = json.loads(rg_str)

    stack_metadata = {}
    stack_metadata['num_slices'] = int(si_metadata['SI.hStackManager.actualNumSlices'])
    stack_metadata['num_volumes'] = int(si_metadata['SI.hStackManager.actualNumVolumes'])
    stack_metadata['frames_per_slice'] = int(si_metadata['SI.hStackManager.framesPerSlice'])
    stack_metadata['z_steps'] = _str_to_int_list(si_metadata['SI.hStackManager.zs'])
    stack_metadata['actuator'] = si_metadata['SI.hStackManager.stackActuator']
    stack_metadata['num_channels'] = sum(_str_to_bool_list(si_metadata['SI.hPmts.powersOn']))
    stack_metadata['z_step_size'] = int(si_metadata['SI.hStackManager.actualStackZStepSize'])

    return stack_metadata, si_metadata, roi_groups_dict

In [None]:
# Reading the header from the Allen network drive took 17.5 min
# data_dir = Path(r'\\allen\programs\mindscope\workgroups\learning\pilots\online_motion_correction\mouse_721291\test_240515_721291')

# si_fn = '240515_721291_global_30min_1366658085_timeseries_00006.tif'
# si_fn = data_dir / si_fn
# stack_metadata, si_metadata, roi_groups_dict = metadata_from_scanimage_tif(si_fn)

In [2]:
# Reading the same header from the local HDD takes ~19 s.
data_dir = Path(r'D:\online motion correction')
si_fn = data_dir / '240515_721291_global_30min_1366658085_timeseries_00006.tif'
stack_metadata, si_metadata, roi_groups_dict = metadata_from_scanimage_tif(si_fn)