# Setting Up

Cross-registration allows the user to register the outcomes of the MiniAn pipeline across multiple experimental sessions.
It is a useful add-on for longitudinal experiments, where the same field of view is imaged over several sessions.

## specify directories and dataset patterns

For cross-registration to work, we need to have existing datasets with proper metadata.
At a minimum, a `session` dimension should exist on all the datasets.
Each dataset can either be a directory of `zarr` arrays (the default output format of `save_minian`), or a single file saved by users.
Each dataset should reside in its own directory.

Details on the parameters:

* `minian_path` points to the path of the minian package, which by default is the current folder.
* `dpath` is the path containing all the datasets.
    It will be traversed recursively to search for datasets.
* `f_pattern` is the directory/file name pattern of each dataset.
    The program will attempt to load all directories/files matching `f_pattern` under `dpath`.
    Note that here our demo data are `netcdf` files that are manually saved.
    For the default minian dataset format (directory of `zarr` arrays), `f_pattern = r"minian$"` should suffice.
* `id_dims` lists the names of the dimensions that uniquely identify each dataset.
    It should at least contain a `"session"` dimension.
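
To make the roles of `dpath` and `f_pattern` concrete, here is a rough sketch of the kind of search they drive. It assumes that every file or directory name found while walking `dpath` is tested with `re.search`, and that a matching directory is treated as one dataset and not searched further. `find_datasets` is a hypothetical helper and is not part of minian. `open_minian_mf` may differ in the details.

```python
import os
import re
import tempfile


def find_datasets(root, pattern):
    """Hypothetical helper: return paths under `root` whose name matches `pattern`.

    Names are tested with re.search, so r"result$" matches any name ending in "result".
    Matching directories are not descended into, since each one is treated as a dataset.
    """
    regex = re.compile(pattern)
    found = []
    for dirpath, dirnames, filenames in os.walk(root):
        for name in sorted(filenames):
            if regex.search(name):
                found.append(os.path.join(dirpath, name))
        for name in sorted(dirnames):
            if regex.search(name):
                found.append(os.path.join(dirpath, name))
        # prune matched directories in place so os.walk does not look inside them
        dirnames[:] = [d for d in dirnames if not regex.search(d)]
    return sorted(found)


# demo layout: one dataset directory per session, plus an unrelated directory
with tempfile.TemporaryDirectory() as root:
    for sess in ["s1", "s2"]:
        os.makedirs(os.path.join(root, sess, "result"))
    os.makedirs(os.path.join(root, "s3", "notes"))
    found = [os.path.relpath(p, root) for p in find_datasets(root, r"result$")]

print(found)
```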

In [None]:
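import os
import re
import sys

minian_path = "."

# default dataset location, used when the ANIMAL_PATH environment variable is not set
default_dpath = r"C:\Users\axelle.piguet\Documents\GitHub\klab_analysis\Axelle\nk50"

dpath = os.getenv("ANIMAL_PATH", "") or default_dpath

# echo to the terminal that launched the kernel (portable replacement for the Windows-only "CON" device)
print(dpath, file=sys.__stdout__)
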

In [None]:
# sampling frequency of the recordings (Hz)
fs = 20
# Convert frame numbers to seconds
in_sec = 1 / fs
# Convert frame numbers to milliseconds
in_msec = 1000 / fs
# Convert frame numbers to minutes
in_min = 1 / (fs * 60)

In [None]:
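# mouse name, e.g. "nk50": the first path component starting with "nk" (works with either path separator)
match = re.search(r"nk[^\\/]*", dpath)
if match is None:
    raise ValueError(f"could not find a mouse name starting with 'nk' in {dpath!r}")
mouse_name = match.group(0)

# f_pattern = fr"^{mouse_name}_s"  # alternative: match what we put in the minian_ds_path in the other notebook
f_pattern = r"result$"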
id_dims = ["session"]
print(dpath)
print(f_pattern)
print(id_dims)



## specify parameters

`param_dist` defines the maximum distance (in pixels) between cell centroids in different sessions for them to be considered the same cell.
`output_size` controls the scale of visualizations.

In [None]:
param_dist = 5
output_size = 100

## load modules

In [None]:
%%capture
%load_ext autoreload
%autoreload 2
import os
import sys
import warnings
import itertools as itt
import numpy as np
import xarray as xr
import holoviews as hv
import pandas as pd
from holoviews.operation.datashader import datashade, regrid
from dask.diagnostics import ProgressBar
sys.path.append(minian_path)
from minian.cross_registration import (calculate_centroids, calculate_centroid_distance, calculate_mapping,
                                       group_by_session, resolve_mapping, fill_mapping)
from minian.motion_correction import estimate_motion, apply_transform
from minian.utilities import open_minian, open_minian_mf
from minian.visualization import AlignViewer

## module initialization

In [None]:
hv.notebook_extension('bokeh', width=100)
pbar = ProgressBar(minimum=2)
pbar.register()

# Align Videos

## open datasets

All the metadata defined in `id_dims` will be printed out for each dataset.
It is important to make sure all the metadata are correct, otherwise you may get unexpected results.
If metadata was not saved correctly, consider putting the datasets into a correct hierarchical directory structure and using the `post_process` argument of `open_minian_mf` to correct the metadata.
See the main `pipeline.ipynb` and [API reference](https://minian.readthedocs.io/page/api/minian.utilities.html#minian-utilities-open_minian_mf) for more detail.
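
One way to picture a "correct hierarchical directory structure": if there is one directory level per metadata dimension, the session label can be recovered from the path alone. The sketch below assumes a `<root>/<animal>/<session>/<dataset>` layout. `meta_from_path` is a hypothetical helper. The dictionary it returns is the kind of information that a `post_process` function could attach to each dataset, but this sketch does not call minian.

```python
import os
from pathlib import Path


def meta_from_path(path, levels=("animal", "session")):
    """Hypothetical helper: read metadata from the directories above a dataset.

    Assumes a layout like <root>/<animal>/<session>/<dataset>, so the last
    len(levels) parent directories map onto `levels` in order.
    """
    parents = Path(path).parts[:-1]  # drop the dataset name itself
    if len(parents) < len(levels):
        raise ValueError(f"{path!r} is too shallow for levels {levels}")
    return dict(zip(levels, parents[-len(levels):]))


example = os.path.join("data", "nk50", "session2", "result")
print(meta_from_path(example))
```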

In [None]:
# echo to the terminal that launched the kernel (portable replacement for the Windows-only "CON" device)
print(dpath, file=sys.__stdout__)
print(f_pattern, file=sys.__stdout__)
print(id_dims, file=sys.__stdout__)


In [None]:
minian_ds = open_minian_mf(
    dpath, id_dims, pattern=f_pattern)

## estimate shifts

Here we estimate a translational shift along the `session` dimension using the max projection for each dataset.
We combine the `shifts`, original templates `temps`, and shifted templates `temps_sh` into a single dataset `shiftds` to use later.
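
For intuition about what a translational shift between two max projections means, here is a generic phase-correlation sketch in NumPy. It is not necessarily the algorithm that `estimate_motion` uses. The sketch assumes integer shifts and treats the image as periodic (wrap-around).

```python
import numpy as np


def phase_corr_shift(moving, template):
    """Estimate integer (dy, dx) such that moving ~= np.roll(template, (dy, dx), axis=(0, 1))."""
    cross = np.fft.fft2(moving) * np.conj(np.fft.fft2(template))
    cross /= np.abs(cross) + 1e-12  # keep only the phase information
    corr = np.abs(np.fft.ifft2(cross))
    peak = np.unravel_index(np.argmax(corr), corr.shape)
    # peaks past the midpoint correspond to negative shifts (circular wrap-around)
    return tuple(int(p - n) if p > n // 2 else int(p) for p, n in zip(peak, corr.shape))


rng = np.random.default_rng(0)
template = rng.random((64, 64))
moving = np.roll(template, (3, -2), axis=(0, 1))  # simulate a (3, -2) pixel shift
shift = phase_corr_shift(moving, template)
print(shift)
```

Rolling `moving` by the negated shift brings it back onto `template`. Applying a real shift to a non-periodic image leaves uncovered border pixels, and those borders are the reason an overlap window is needed.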

In [None]:


temps = minian_ds['max_proj'].rename('temps')
shifts = estimate_motion(temps, dim='session').compute().rename('shifts')
temps_sh = apply_transform(temps, shifts).compute().rename('temps_shifted')
shiftds = xr.merge([temps, shifts, temps_sh])

## visualize alignment

We visualize alignment of sessions by plotting the templates before and after the shift for each session.

In [None]:
hv.output(size=int(output_size * 0.6))
opts_im = {
    'aspect': shiftds.sizes['width'] / shiftds.sizes['height'],
    'frame_width': 500, 'cmap': 'viridis'}
hv_temps = (hv.Dataset(temps).to(hv.Image, kdims=['width', 'height'])
            .opts(**opts_im).layout('session').cols(1))
hv_temps_sh = (hv.Dataset(temps_sh).to(hv.Image, kdims=['width', 'height'])
               .opts(**opts_im).layout('session').cols(1))
display(hv_temps + hv_temps_sh)

## visualize overlap of field of view across all sessions

Since only pixels that are common across all sessions are considered, it is important to sanity-check that this overlap window captures most of our cells.
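
The window logic can be reproduced on a plain NumPy stack. Count, for each pixel, the number of sessions in which it is NaN after shifting, then keep the pixels with the smallest count. This mirrors the `isnull().sum('session')` and `wnd == wnd.min()` steps in this notebook. The `(session, height, width)` array layout is an assumption made for the sketch.

```python
import numpy as np


def common_window(shifted):
    """Boolean mask of pixels that are valid in as many sessions as possible.

    `shifted` has shape (session, height, width), with NaN where a shift moved
    the field of view out of frame.
    """
    nan_count = np.isnan(shifted).sum(axis=0)
    return nan_count == nan_count.min()


stack = np.ones((2, 4, 5))
stack[0, :, 0] = np.nan  # session 0 lost its left column
stack[1, 0, :] = np.nan  # session 1 lost its top row
mask = common_window(stack)
print(mask.sum())
```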

In [None]:
hv.output(size=int(output_size * 0.6))
opts_im = {
    'aspect': shiftds.sizes['width'] / shiftds.sizes['height'],
    'frame_width': 500, 'cmap': 'viridis'}
window = shiftds['temps_shifted'].isnull().sum('session')
window, _ = xr.broadcast(window, shiftds['temps_shifted'])
hv_wnd = hv.Dataset(window).to(hv.Image, ['width', 'height'])
hv_temps = hv.Dataset(temps_sh).to(hv.Image, ['width', 'height'])
hv_wnd.opts(**opts_im).relabel("Window") + hv_temps.opts(**opts_im).relabel("Shifted Templates")

## apply shifts and set window

If the shifts and overlaps all look good, we commit by applying them to the spatial footprints of each session.
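
Shifting a footprint moves some pixels out of frame and leaves others uncovered. The window computation in this notebook relies on those uncovered pixels being NaN. The NumPy sketch below assumes integer shifts and fills uncovered pixels with NaN. Whether `apply_transform` does exactly this internally is an assumption.

```python
import numpy as np


def _overlap(n, d):
    """Source and destination slices along one axis for an integer shift d."""
    d = max(-n, min(n, d))  # shifts beyond the array size leave nothing in frame
    return slice(max(0, -d), n - max(0, d)), slice(max(0, d), n - max(0, -d))


def shift_nan(img, dy, dx):
    """Translate a 2-D array by integer (dy, dx), filling uncovered pixels with NaN."""
    out = np.full(img.shape, np.nan)
    src_y, dst_y = _overlap(img.shape[0], dy)
    src_x, dst_x = _overlap(img.shape[1], dx)
    out[dst_y, dst_x] = img[src_y, src_x]
    return out


footprint = np.arange(12, dtype=float).reshape(3, 4)
shifted = shift_nan(footprint, 1, -1)  # down by one row, left by one column
print(shifted)
```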

In [None]:
minian_ds

In [None]:
A_shifted = apply_transform(minian_ds['A'].chunk(dict(height=-1, width=-1)), shiftds['shifts'])

In [None]:
def set_window(wnd):
    return wnd == wnd.min()
window = xr.apply_ufunc(
    set_window,
    window,
    input_core_dims=[['height', 'width']],
    output_core_dims=[['height', 'width']],
    vectorize=True)

# Cross-session registration

## calculate centroids

We start by calculating the centroid of the spatial footprint of each cell.
The centroid location is the only source of information used to register cells across sessions.
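
As a sketch, a centroid can be computed as the intensity-weighted mean of the pixel coordinates of a footprint. This is an assumption about the definition. `calculate_centroids` may differ, for example in how it applies the window.

```python
import numpy as np


def footprint_centroid(footprint):
    """Intensity-weighted (height, width) centroid of one 2-D spatial footprint."""
    weights = np.nan_to_num(footprint)  # treat NaN padding as zero weight
    total = weights.sum()
    if total == 0:
        return (np.nan, np.nan)
    ys, xs = np.indices(weights.shape)
    return (float((ys * weights).sum() / total), float((xs * weights).sum() / total))


fp = np.zeros((5, 5))
fp[1, 1] = 1.0
fp[1, 3] = 1.0  # two equal-weight pixels in row 1
print(footprint_centroid(fp))
```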

In [None]:
%%time
cents = calculate_centroids(A_shifted, window)

## calculate centroid distance

We then calculate pairwise distance between cells in all pairs of sessions.
Note that at this stage we compute along the `session` dimension, so `session` no longer serves as a metadata dimension and we remove it from `id_dims`.
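
Here is a sketch of the pairwise computation. It assumes that the centroids are stored per session as `(n_cells, 2)` NumPy arrays. The notebook's `cents` is a dataframe, so this dictionary layout is a simplification. Each session pair gets its own distance matrix, computed by broadcasting.

```python
import itertools

import numpy as np


def session_pair_distances(cents):
    """Euclidean distances between every cell of every pair of sessions.

    `cents` maps a session name to an (n_cells, 2) array of (height, width)
    centroids. Returns {(sess_a, sess_b): (n_a, n_b) distance matrix}.
    """
    out = {}
    for sa, sb in itertools.combinations(sorted(cents), 2):
        diff = cents[sa][:, None, :] - cents[sb][None, :, :]  # broadcast to (n_a, n_b, 2)
        out[(sa, sb)] = np.linalg.norm(diff, axis=-1)
    return out


cents_demo = {
    "s1": np.array([[10.0, 10.0], [40.0, 40.0]]),
    "s2": np.array([[13.0, 14.0]]),
}
dists = session_pair_distances(cents_demo)
print(dists[("s1", "s2")])
```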

In [None]:
%%time
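# drop "session" from the identifying dimensions (safe to re-run, unlike list.remove)
id_dims = [d for d in id_dims if d != "session"]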
dist = calculate_centroid_distance(cents, index_dim=id_dims)

## threshold centroid distances

We threshold the centroid distances and keep only cell pairs with distance less than `param_dist`.
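
On a single session-pair distance matrix, the threshold reduces to keeping every entry strictly below the cutoff. The sketch below assumes a dense matrix instead of the notebook's long-format dataframe. `candidate_pairs` is a hypothetical helper.

```python
import numpy as np


def candidate_pairs(dist, max_dist):
    """(i, j, distance) for every cell pair strictly closer than max_dist."""
    rows, cols = np.nonzero(dist < max_dist)
    return [(int(i), int(j), float(dist[i, j])) for i, j in zip(rows, cols)]


dist_demo = np.array([[1.5, 9.0],
                      [6.0, 4.9]])
print(candidate_pairs(dist_demo, 5))  # 5 plays the role of param_dist
```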

In [None]:
dist_ft = dist[dist['variable', 'distance'] < param_dist].copy()
dist_ft = group_by_session(dist_ft)

## generate mappings

Finally, we generate mappings of cells across sessions in three steps:

1. We filter the pairwise distances into pairwise mappings by applying a mutual nearest-neighbour criterion, using `calculate_mapping`.
1. We extend/merge pairwise mappings into multi-session mappings and drop any conflicting mappings, using `resolve_mapping`.
1. We fill in "mappings" that represent cells appearing in only a single session, using `fill_mapping`.

Please see the [API reference](https://minian.readthedocs.io/page/api/minian.cross_registration.html) for more detail on the output dataframe format.
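
Steps 1 and 3 can be sketched for a single pair of sessions. A pair `(i, j)` is kept only if each cell is the other's nearest neighbour and their distance is below the threshold. Every unmatched cell then gets a single-session entry with `None` on the other side. This sketch assumes a dense distance matrix and omits the multi-session merge of step 2. `mutual_nn_mapping` is hypothetical and does not reproduce minian's output format.

```python
import numpy as np


def mutual_nn_mapping(dist, max_dist):
    """Pair cells of two sessions by mutual nearest neighbour under max_dist.

    Returns a list of (unit_a, unit_b) tuples; unmatched cells appear with
    None on the other side, like single-session "mappings".
    """
    n_a, n_b = dist.shape
    best_b = dist.argmin(axis=1)  # nearest session-b cell for each session-a cell
    best_a = dist.argmin(axis=0)  # nearest session-a cell for each session-b cell
    pairs = [(i, int(best_b[i])) for i in range(n_a)
             if best_a[best_b[i]] == i and dist[i, best_b[i]] < max_dist]
    matched_a = {i for i, _ in pairs}
    matched_b = {j for _, j in pairs}
    singles = ([(i, None) for i in range(n_a) if i not in matched_a]
               + [(None, j) for j in range(n_b) if j not in matched_b])
    return pairs + singles


dist_demo = np.array([[1.0, 3.0, 8.0],
                      [2.0, 4.0, 9.0],
                      [9.0, 9.0, 7.0]])
print(mutual_nn_mapping(dist_demo, 5))
```

In the demo, cells 2 and 2 are mutual nearest neighbours, but their distance of 7 exceeds the threshold, so both end up as single-session entries.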

In [None]:
%%time
mappings = calculate_mapping(dist_ft)
mappings_meta = resolve_mapping(mappings)
mappings_meta_fill = fill_mapping(mappings_meta, cents)
mappings_meta_fill.head()

In [None]:
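# keep an independent backup (plain assignment would only create a second reference to the same dataframe)
mappings_meta_fill_save = mappings_meta_fill.copy()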


In [None]:
# restore from the backup, copying again so later in-place edits do not alter the backup
mappings_meta_fill = mappings_meta_fill_save.copy()

In [None]:
mappings_meta_fill["group"] =mappings_meta_fill["group"].astype(str)
mappings_meta_fill = mappings_meta_fill.sort_values(by=("group","group"), ascending = True).reset_index(drop=True)

In [None]:
num_unit = len(mappings_meta_fill)  # one row per cross-registered unit
num_session = len(minian_ds['S'].session.values)
num_frames = len(minian_ds['S'].frame.values)
# (session, new unit, frame) array; units absent from a session stay NaN
data = np.full((num_session, num_unit, num_frames), np.nan)

ends = []
for sess in range(num_session):
    sess_name = minian_ds['S'].session.values[sess]
    unit_col = mappings_meta_fill['session', sess_name].values
    rows = np.flatnonzero(~np.isnan(unit_col))  # new units present in this session
    old_units = unit_col[rows].astype(int)       # their original unit_id in this session
    data[sess, rows, :] = minian_ds['S'].sel(unit_id=old_units, session=sess_name).values
    # sessions differ in length: index of the last non-NaN frame of the first mapped unit
    end_frame = np.flatnonzero(~np.isnan(data[sess, rows[0], :]))[-1]
    ends.append(int(end_frame))

starts = [0] * num_session

data_arrayS = xr.DataArray(
    data,
    dims=["session", "new_unitId", "frame"],
    coords={
        "session": minian_ds['S'].session.values,
        "new_unitId": np.arange(1, num_unit + 1),  # new unit ids start at 1
        "frame": minian_ds['S'].frame.values,
    },
    name=minian_ds['S'].animal.values,  # optional: name the data array
)

data_arrayS.attrs["start_frame"] = starts
data_arrayS.attrs["end_frame"] = ends

In [None]:
# same layout as data_arrayS, filled with the calcium traces C
data = np.full((num_session, num_unit, num_frames), np.nan)

for sess in range(num_session):
    sess_name = minian_ds['C'].session.values[sess]
    unit_col = mappings_meta_fill['session', sess_name].values
    rows = np.flatnonzero(~np.isnan(unit_col))  # new units present in this session
    old_units = unit_col[rows].astype(int)       # their original unit_id in this session
    data[sess, rows, :] = minian_ds['C'].sel(unit_id=old_units, session=sess_name).values

data_arrayC = xr.DataArray(
    data,
    dims=["session", "new_unitId", "frame"],
    coords={
        "session": minian_ds['C'].session.values,
        "new_unitId": np.arange(1, num_unit + 1),  # new unit ids start at 1
        "frame": minian_ds['C'].frame.values,
    },
    name=minian_ds['C'].animal.values,  # optional: name the data array
)
# C and S come from the same recordings, so reuse the session boundaries computed for S
data_arrayC.attrs["start_frame"] = starts
data_arrayC.attrs["end_frame"] = ends

In [None]:
import matplotlib.pyplot as plt
from matplotlib.ticker import FuncFormatter

offset = 15  # vertical spacing between stacked traces
starts_C, ends_C = data_arrayC.attrs["start_frame"], data_arrayC.attrs["end_frame"]
colors = ['palevioletred', 'purple']  # alternate colors between consecutive sessions

plt.figure(figsize=(7, 15))
beginning = starts_C[0]
for sess in range(num_session):
    sess_name = data_arrayC.session.values[sess]
    start, end = starts_C[sess], ends_C[sess]
    x = range(beginning, beginning + end - start)
    for idx in data_arrayC["new_unitId"].values:
        trace = data_arrayC.sel(new_unitId=idx, session=sess_name)[start:end]
        plt.plot(x, trace + offset * idx, colors[sess % 2])
    # place the next session right after this one on a shared time axis
    beginning = beginning + end - start + 1

ax = plt.gca()
ax.set_xlabel('Time (min)')
ax.set_ylabel('Units')
ax.spines['top'].set_visible(False)    # hide the top spine
ax.spines['right'].set_visible(False)  # hide the right spine
# label y ticks with unit ids and x ticks with minutes instead of raw offsets and frames
ax.yaxis.set_major_formatter(FuncFormatter(lambda y, _: f"{int(y / offset)}"))
ax.xaxis.set_major_formatter(FuncFormatter(lambda x, _: f"{int(x * in_min)}"))
plt.savefig(os.path.join(dpath, "C.png"), format="png", dpi=300)

In [None]:
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.ticker import FuncFormatter

# white-to-black colormap; most of its range is spent on dark greys
cmap = LinearSegmentedColormap.from_list(
    "white_to_black", ["white", [.3, .3, .3], [.2, .2, .2], [.1, .1, .1], "black"])
starts_S, ends_S = data_arrayS.attrs["start_frame"], data_arrayS.attrs["end_frame"]

# figsize must be passed here; assigning fig.figsize afterwards has no effect
fig, ax = plt.subplots(figsize=(15, 25))

beginning = starts_S[0]
for sess in range(num_session):
    sess_name = data_arrayS.session.values[sess]
    start, end = starts_S[sess], ends_S[sess]
    # one raster (units x frames) per session, placed side by side on a shared time axis
    ax.imshow(data_arrayS.sel(session=sess_name)[:, start:end], cmap=cmap,
              extent=[beginning, beginning + end - start, 0, num_unit],
              origin='lower', aspect='auto')
    beginning = beginning + end - start + 1

ax.set_xlim(0, beginning)
ax.set_xlabel('Time (min)')
ax.set_ylabel('Units')
ax.spines['top'].set_visible(False)    # hide the top spine
ax.spines['right'].set_visible(False)  # hide the right spine
# label x ticks in minutes instead of frames
ax.xaxis.set_major_formatter(FuncFormatter(lambda x, _: f"{int(x * in_min)}"))
plt.savefig(os.path.join(dpath, "S.png"), format="png", dpi=300)

## visualize mappings

We visualize the matching of cells by color-mapping the cells of 3 arbitrary sessions into RGB channels and plotting the overlay image.
Please see [API reference](https://minian.readthedocs.io/page/api/minian.visualization.html#minian-visualization-AlignViewer) for more details on the tools available in this visualization.
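
The color-mixing idea behind such an overlay works as follows. A cell present in all three sessions shows up white. A cell found in only two sessions shows up as a mixed color (for example, red plus green gives yellow). The NumPy sketch below assumes min-max normalized footprint images and is not AlignViewer's implementation. `rgb_overlay` is a hypothetical helper.

```python
import numpy as np


def rgb_overlay(images):
    """Stack up to three 2-D images into the R, G and B channels of one image.

    NaN pixels are set to 0 first; each image is then min-max scaled to [0, 1].
    """
    if not 1 <= len(images) <= 3:
        raise ValueError("need between 1 and 3 images")
    h, w = images[0].shape
    rgb = np.zeros((h, w, 3))
    for ch, img in enumerate(images):
        img = np.nan_to_num(np.asarray(img, dtype=float))
        span = img.max() - img.min()
        rgb[..., ch] = (img - img.min()) / span if span > 0 else 0.0
    return rgb


a = np.array([[0.0, 1.0], [0.0, 0.0]])
b = np.array([[0.0, 1.0], [0.0, 0.0]])
c = np.array([[0.0, 0.0], [1.0, 0.0]])
rgb = rgb_overlay([a, b, c])
print(rgb[0, 1], rgb[1, 0])  # cell seen in sessions 1 and 2 vs. cell seen only in session 3
```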

In [None]:
hv.output(size=int(output_size * 0.7))
alnviewer = AlignViewer(minian_ds, cents, mappings_meta_fill, shiftds)
alnviewer.show()

## save results

If everything looks good, we commit by saving the mappings into a `pickle` file.
Optionally, we also save the centroids `cents` and `shiftds` in case they come in handy in downstream analysis.

In [None]:
data_arrayS = data_arrayS.assign_coords(session=[str(s) for s in data_arrayS.coords['session'].values])
data_arrayC = data_arrayC.assign_coords(session=[str(s) for s in data_arrayC.coords['session'].values])

In [None]:
mappings_meta_fill.to_pickle(os.path.join(dpath, "mappings.pkl"))
cents.to_pickle(os.path.join(dpath, "cents.pkl"))
shiftds.to_netcdf(os.path.join(dpath, "shiftds.nc"))

In [None]:
data_arrayS.name = "S_all"
data_arrayC.name = "C_all"
data_arrayS.to_netcdf(os.path.join(dpath, "S_allSessions.nc"))
data_arrayC.to_netcdf(os.path.join(dpath, "C_allSessions.nc"))