# Setting Up

Cross-registration allows the user to register the outcomes of the MiniAn pipeline across multiple experimental sessions.
It is a useful add-on to deal with longitudinal experiments. 

## specify directories and dataset patterns

For cross-registration to work, we need to have existing datasets with proper metadata.
At a minimum, a `session` dimension should exist on all the datasets.
Each dataset can either be a directory of `zarr` arrays (the default output format of `save_minian`), or a single file saved by users.
Each dataset should reside in its own directory.

Details on the parameters:

* `minian_path` is the path to the minian package, which by default is the current folder.
* `dpath` is the path containing all the datasets.
    It will be traversed recursively to search for datasets.
* `f_pattern` is the directory/file name pattern of each dataset.
    The program will attempt to load all directories/files matching `f_pattern` under `dpath`.
    Note that here our demo data are `netcdf` files that are manually saved.
    For the default minian dataset format (directory of `zarr` arrays), `f_pattern = r"minian$"` should suffice.
* `id_dims` lists the names of the dimensions that uniquely identify each dataset.
    It should at least contain a `"session"` dimension.
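
As a rough illustration of how such a pattern selects datasets, the sketch below applies regular expressions to a few made-up names. It assumes the pattern is applied with search semantics (a match anywhere in the name); the candidate names and the `matching` helper are hypothetical. Note that an unescaped `.` matches any character, so `minian\.nc$` is the stricter form of a netCDF pattern.

```python
import re

# Hypothetical directory/file names a recursive search under `dpath` might find.
candidates = ["minian", "minian.nc", "minian_backup", "minianXnc"]


def matching(pattern, names):
    """Return the names in which `pattern` matches somewhere (re.search)."""
    return [name for name in names if re.search(pattern, name)]


zarr_style = matching(r"minian$", candidates)       # names ending in "minian"
netcdf_loose = matching(r"minian.nc$", candidates)  # "." matches any character
netcdf_strict = matching(r"minian\.nc$", candidates)  # "." matches a literal dot

print(zarr_style)
print(netcdf_loose)
print(netcdf_strict)
```

The loose pattern also accepts `minianXnc`, which is usually harmless but worth knowing when a directory contains similarly named files.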

In [37]:
minian_path = "."
dpath = r"D:\Desktop\ZJU\bme\bme_com\minian-master\minian-master_v0\demo_data"  # raw string: backslashes are kept literally
f_pattern = r"minian.nc$"
id_dims = ["session"]

## specify parameters

`param_dist` defines the maximal distance between cell centroids (in pixel units) on different sessions to consider them as the same cell.
`output_size` controls the scale of visualizations.

In [38]:
param_dist = 5
output_size = 100

## load modules

In [None]:
%%capture
%load_ext autoreload
%autoreload 2
import os
import sys
import warnings
import itertools as itt
import numpy as np
import xarray as xr
import holoviews as hv
import pandas as pd
from holoviews.operation.datashader import datashade, regrid
from dask.diagnostics import ProgressBar
sys.path.append(minian_path)
from minian.cross_registration import (calculate_centroids, calculate_centroid_distance, calculate_mapping,
                                       group_by_session, resolve_mapping, fill_mapping)
from minian.motion_correction import estimate_motion, apply_transform
from minian.utilities import open_minian, open_minian_mf
from minian.visualization import AlignViewer
import ipywidgets as widgets
from IPython.display import display, clear_output
import matplotlib.pyplot as plt

## module initialization

In [None]:
hv.notebook_extension('bokeh', width=100)
pbar = ProgressBar(minimum=2)
pbar.register()

# Align Videos

## open datasets

All the metadata defined in `id_dims` will be printed out for each dataset.
It is important to make sure all the metadata are correct, otherwise you may get unexpected results.
If the metadata were not saved correctly, consider arranging the datasets in a hierarchical directory structure that reflects the metadata and using the `post_process` argument of `open_minian_mf` to correct them.
See the main `pipeline.ipynb` and [API reference](https://minian.readthedocs.io/page/api/minian.utilities.html#minian-utilities-open_minian_mf) for more detail.
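
To make the idea of recovering metadata from a directory hierarchy concrete, here is a minimal sketch. The helper `metadata_from_path`, the field names, and the example path are all hypothetical and not part of MiniAn; the exact callback signature expected by `open_minian_mf` should be taken from the API reference rather than from this example.

```python
from pathlib import PurePosixPath


def metadata_from_path(dataset_path, levels=("animal", "session")):
    """Map the trailing directory names of `dataset_path` to metadata fields.

    `levels` names the directory levels from outermost to innermost.
    Both the function and the field names are illustrative assumptions.
    """
    parts = PurePosixPath(dataset_path).parts
    return dict(zip(levels, parts[-len(levels):]))


# Hypothetical layout: <root>/<animal>/<session>
meta = metadata_from_path("demo_data/mouse1/OFT")
print(meta)
```

A function like this could then be used to overwrite or add coordinates on each dataset as it is loaded.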

In [None]:
dpath_1 = os.path.join(dpath, "OFT")
dpath_2 = os.path.join(dpath, "TST")
folder_path_1 = os.path.join(dpath, "sum", "OFT")
folder_path_2 = os.path.join(dpath, "sum", "TST")

In [None]:
# minian_ds_1 = open_minian(os.path.join(dpath_1, 'minian'))
# minian_ds_2 = open_minian(os.path.join(dpath_2, 'minian'))
# print(minian_ds_1)

In [None]:
# # Cast the `session` coordinate to plain strings before saving
# minian_ds_1.coords['session'] = minian_ds_1.coords['session'].values.astype(str)
# minian_ds_2.coords['session'] = minian_ds_2.coords['session'].values.astype(str)

# # Create the output folders (exist_ok=True: no error if they already exist)
# os.makedirs(folder_path_1, exist_ok=True)
# os.makedirs(folder_path_2, exist_ok=True)

# # Save each dataset as a single netCDF file
# minian_ds_1.to_netcdf(os.path.join(folder_path_1, "minian.nc"))
# minian_ds_2.to_netcdf(os.path.join(folder_path_2, "minian.nc"))

In [None]:
minian_ds = open_minian_mf(
    dpath, id_dims, pattern=f_pattern)

In [None]:
minian_ds

## estimate shifts

Here we estimate a translational shift along the `session` dimension using the max projection for each dataset.
We combine the `shifts`, original templates `temps`, and shifted templates `temps_sh` into a single dataset `shiftds` to use later.
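
For intuition about what estimating a translational shift involves, the following sketch recovers an integer shift between two images from the peak of their FFT-based circular cross-correlation. This is a generic technique shown under simplifying assumptions (integer shifts, wrap-around borders); it is not claimed to be how `estimate_motion` works internally, and `estimate_shift` is a hypothetical helper.

```python
import numpy as np


def estimate_shift(template, image):
    """Return the integer (row, col) shift that aligns `image` to `template`.

    Applying the result with np.roll(image, shift, axis=(0, 1)) reproduces
    `template` when `image` is a circularly shifted copy of it.
    """
    # Circular cross-correlation computed in the Fourier domain.
    corr = np.fft.ifft2(np.fft.fft2(template) * np.conj(np.fft.fft2(image))).real
    peak = np.unravel_index(np.argmax(corr), corr.shape)
    # Peaks past half the image size correspond to negative shifts.
    return tuple(int(p) if p <= s // 2 else int(p) - s
                 for p, s in zip(peak, corr.shape))


rng = np.random.default_rng(0)
template = rng.random((32, 32))
image = np.roll(template, (3, -5), axis=(0, 1))  # simulate a displaced session

shift = estimate_shift(template, image)
aligned = np.roll(image, shift, axis=(0, 1))
print(shift)
```

Real max projections differ between sessions in more than position, so the correlation peak is broader and noisier than in this toy case.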

In [None]:
%%time
temps = minian_ds['max_proj'].rename('temps')
shifts = estimate_motion(temps, dim='session').compute().rename('shifts')
temps_sh = apply_transform(temps, shifts).compute().rename('temps_shifted')
shiftds = xr.merge([temps, shifts, temps_sh])

## visualize alignment

We visualize alignment of sessions by plotting the templates before and after the shift for each session.

In [None]:
hv.output(size=int(output_size * 0.6))
opts_im = {
    'aspect': shiftds.sizes['width'] / shiftds.sizes['height'],
    'frame_width': 500, 'cmap': 'viridis'}
hv_temps = (hv.Dataset(temps).to(hv.Image, kdims=['width', 'height'])
            .opts(**opts_im).layout('session').cols(1))
hv_temps_sh = (hv.Dataset(temps_sh).to(hv.Image, kdims=['width', 'height'])
            .opts(**opts_im).layout('session').cols(1))
display(hv_temps + hv_temps_sh)

## visualize overlap of field of view across all sessions

Since only pixels that are common across all sessions are considered, it is important to sanity-check that this overlap window captures most of our cells.

In [None]:
hv.output(size=int(output_size * 0.6))
opts_im = {
    'aspect': shiftds.sizes['width'] / shiftds.sizes['height'],
    'frame_width': 500, 'cmap': 'viridis'}
window = shiftds['temps_shifted'].isnull().sum('session')
window, _ = xr.broadcast(window, shiftds['temps_shifted'])
hv_wnd = hv.Dataset(window).to(hv.Image, ['width', 'height'])
hv_temps = hv.Dataset(temps_sh).to(hv.Image, ['width', 'height'])
hv_wnd.opts(**opts_im).relabel("Window") + hv_temps.opts(**opts_im).relabel("Shifted Templates")

## apply shifts and set window

If the shifts and overlaps all look good, we commit by applying them to the spatial footprints of each session.

In [None]:
A_shifted = apply_transform(minian_ds['A'].chunk(dict(height=-1, width=-1)), shiftds['shifts'])

In [None]:
def set_window(wnd):
    return wnd == wnd.min()
window = xr.apply_ufunc(
    set_window,
    window,
    input_core_dims=[['height', 'width']],
    output_core_dims=[['height', 'width']],
    vectorize=True)
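
The window logic above can be illustrated on a toy example with plain NumPy arrays. The sketch assumes that pixels pushed out of view by a shift are filled with NaN, counts those NaNs per pixel across sessions (mirroring `isnull().sum('session')`), and keeps the pixels with the smallest count (mirroring `wnd == wnd.min()`). The arrays are made up.

```python
import numpy as np

nan = np.nan
# Two toy 3x4 shifted templates; NaN marks pixels that left the field of view.
session_1 = np.array([[nan, 1.0, 1.0, 1.0],
                      [nan, 1.0, 1.0, 1.0],
                      [nan, 1.0, 1.0, 1.0]])
session_2 = np.array([[1.0, 1.0, 1.0, 1.0],
                      [1.0, 1.0, 1.0, 1.0],
                      [nan, nan, nan, nan]])

stack = np.stack([session_1, session_2])   # shape: (session, height, width)
nan_count = np.isnan(stack).sum(axis=0)    # sessions in which each pixel is missing
window = nan_count == nan_count.min()      # pixels present in the most sessions

print(nan_count)
print(window)
```

Here the first column and the last row drop out, leaving a 2x3 window shared by both sessions.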

# Cross-session registration

## calculate centroids

We start by calculating the centroid of the spatial footprint of each cell.
The centroid location is the only source of information used to register cells across sessions.
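
As a minimal sketch of what a footprint centroid is, the example below computes an intensity-weighted center of mass of a small 2-D array. It assumes a center-of-mass definition; `calculate_centroids` may differ in detail (for example in how the window is applied), and the `centroid` helper is hypothetical.

```python
import numpy as np


def centroid(footprint):
    """Return the weighted center of mass (height, width) of a 2-D footprint."""
    h_idx, w_idx = np.indices(footprint.shape)
    total = footprint.sum()
    return (float((h_idx * footprint).sum() / total),
            float((w_idx * footprint).sum() / total))


# Toy footprint: two pixels in the same column with weights 1 and 3.
A = np.zeros((5, 5))
A[1, 2] = 1.0
A[3, 2] = 3.0

cent = centroid(A)
print(cent)
```

The brighter pixel pulls the centroid towards row 3, which is why the height coordinate lands at 2.5 rather than at the midpoint 2.0.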

In [None]:
%%time
cents = calculate_centroids(A_shifted, window)

## calculate centroid distance

We then calculate pairwise distance between cells in all pairs of sessions.
Note that at this stage, since we are computing something along the `session` dimension, it is no longer considered as a metadata dimension, so we remove it.
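
The pairwise distance between the cells of two sessions can be sketched with NumPy broadcasting. The centroid coordinates below are invented, and the sketch assumes Euclidean distance in pixel units; the real function returns a dataframe covering every pair of sessions rather than a single matrix.

```python
import numpy as np

# Hypothetical (height, width) centroids for two sessions.
cents_a = np.array([[10.0, 10.0],
                    [30.0, 40.0]])
cents_b = np.array([[13.0, 14.0],
                    [30.0, 41.0],
                    [80.0, 5.0]])

# Broadcast to shape (cells in A, cells in B, 2), then reduce over coordinates.
diff = cents_a[:, None, :] - cents_b[None, :, :]
dist_ab = np.sqrt((diff ** 2).sum(axis=-1))

print(dist_ab.round(2))
```

Entry `[i, j]` is the distance between cell `i` of the first session and cell `j` of the second.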

In [None]:
%%time
id_dims.remove("session")
dist = calculate_centroid_distance(cents, index_dim=id_dims)

## threshold centroid distances

We threshold the centroid distances and keep only cell pairs with distance less than `param_dist`.
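
The thresholding relies on boolean-mask indexing with a two-level column key. The toy table below imitates that layout with made-up values (the session labels `s1` and `s2` are hypothetical) to show that the comparison is strict: a pair at exactly `param_dist` pixels is dropped.

```python
import pandas as pd

param_dist = 5

# Toy table with two-level column labels, imitating the distance dataframe.
columns = pd.MultiIndex.from_tuples(
    [("session", "s1"), ("session", "s2"), ("variable", "distance")])
dist = pd.DataFrame([[0, 0, 5.0],
                     [1, 1, 1.0],
                     [0, 1, 36.9]], columns=columns)

# Keep only rows whose distance is strictly below the threshold.
dist_ft = dist[dist["variable", "distance"] < param_dist].copy()
print(dist_ft)
```

Only the pair at distance 1.0 survives; raise `param_dist` slightly if borderline pairs should be kept.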

In [None]:
dist_ft = dist[dist['variable', 'distance'] < param_dist].copy()
dist_ft = group_by_session(dist_ft)

## generate mappings

Finally, we generate the mapping of cells across sessions in three steps:

1. We filter the pairwise distances into pairwise mappings by applying a mutual nearest-neighbour criterion, using `calculate_mapping`.
1. We extend/merge pairwise mappings into multi-session mappings and drop any conflicting mappings, using `resolve_mapping`.
1. We fill in "mappings" that represent cells that appear in only a single session, using `fill_mapping`.

Please see the [API reference](https://minian.readthedocs.io/page/api/minian.cross_registration.html) for more detail on the output dataframe format.
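
The first of the three steps, the mutual nearest-neighbour criterion, can be sketched on a small distance matrix. This is an illustration of the criterion only, under the assumption that a pair is kept when each cell is the other's closest candidate and the distance is below the threshold; `mutual_nearest_pairs` is a hypothetical helper, not the `calculate_mapping` implementation.

```python
import numpy as np


def mutual_nearest_pairs(dist, max_dist):
    """Return (i, j) pairs that are each other's nearest neighbour.

    `dist[i, j]` is the distance between cell i of one session and cell j of
    another; pairs at or above `max_dist` are discarded.
    """
    nearest_col = dist.argmin(axis=1)  # closest column for every row
    nearest_row = dist.argmin(axis=0)  # closest row for every column
    return [(i, int(j)) for i, j in enumerate(nearest_col)
            if nearest_row[j] == i and dist[i, j] < max_dist]


dist = np.array([[1.0, 4.0, 9.0],
                 [2.0, 3.0, 8.0]])
pairs = mutual_nearest_pairs(dist, max_dist=5)
print(pairs)
```

Row 1 stays unmatched even though it has candidates within the threshold, because its nearest column is already closer to row 0. That conservative behaviour is the point of requiring mutuality.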

In [None]:
%%time
mappings = calculate_mapping(dist_ft)
mappings_meta = resolve_mapping(mappings)
mappings_meta_fill = fill_mapping(mappings_meta, cents)
mappings_meta_fill.head()

## visualize mappings

We visualize the matching of cells by color-mapping the cells from 3 arbitrary sessions into the RGB channels and plotting the overlay image.
Please see [API reference](https://minian.readthedocs.io/page/api/minian.visualization.html#minian-visualization-AlignViewer) for more details on the tools available in this visualization.
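
The idea of an RGB overlay can be sketched independently of the viewer: each session's image becomes one color channel, so cells present in all three sessions appear white and cells present in one session appear in a pure color. The `rgb_overlay` helper and the tiny images are hypothetical, and this is not a description of how `AlignViewer` is implemented.

```python
import numpy as np


def rgb_overlay(img_r, img_g, img_b):
    """Stack three same-shaped 2-D images into a (height, width, 3) RGB array.

    Each channel is scaled by its own maximum so that values lie in [0, 1].
    """
    def scale(img):
        peak = img.max()
        return img / peak if peak > 0 else img

    return np.stack([scale(img_r), scale(img_g), scale(img_b)], axis=-1)


# Toy footprint images for three sessions.
a = np.zeros((2, 2))
a[0, 0] = 2.0
b = np.zeros((2, 2))
b[0, 0] = 4.0
b[1, 1] = 4.0          # this cell is only visible in the second session
c = np.zeros((2, 2))
c[0, 0] = 1.0

rgb = rgb_overlay(a, b, c)
print(rgb[0, 0], rgb[1, 1])
```

The pixel shared by all three sessions comes out white, while the pixel seen only in the second session comes out green.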

In [None]:
hv.output(size=int(output_size * 0.7))
alnviewer = AlignViewer(minian_ds, cents, mappings_meta_fill, shiftds)
alnviewer.show()

## save results

If everything looks good, we commit by saving the mappings into a `pickle` file.
Optionally we also save centroids `cents` and `shiftds` in case they come in handy in down-stream analysis.
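
One practical reason to prefer `pickle` here is that it preserves the dataframe layout exactly. Assuming the mapping table keeps the two-level column labels seen earlier in `dist['variable', 'distance']`, a CSV export flattens them into two header rows, and reading the file back needs `header=[0, 1]` to rebuild them. The toy table below is made up to demonstrate the round trip.

```python
import io

import pandas as pd

# Toy table with two-level column labels (values are invented).
columns = pd.MultiIndex.from_tuples(
    [("session", "s1"), ("session", "s2"), ("variable", "distance")])
table = pd.DataFrame([[0, 1, 2.5],
                      [3, 4, 1.0]], columns=columns)

# Write to an in-memory CSV, then read it back with both header rows.
buffer = io.StringIO()
table.to_csv(buffer, index=False)
buffer.seek(0)
restored = pd.read_csv(buffer, header=[0, 1])

print(restored.columns.tolist())
```

Reading the same file with the default single header row would instead treat the second label row as data.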

In [None]:
mappings_meta_fill.to_pickle(os.path.join(dpath, "mappings.pkl"))
cents.to_pickle(os.path.join(dpath, "cents.pkl"))
shiftds.to_netcdf(os.path.join(dpath, "shiftds.nc"))

### save them as csv files

In [None]:
# 1. Save mappings_meta_fill as CSV
if isinstance(mappings_meta_fill, pd.DataFrame):
    mappings_meta_fill.to_csv(os.path.join(dpath, "mappings.csv"), index=False)
    print(f"mappings_meta_fill saved to {os.path.join(dpath, 'mappings.csv')}")
else:
    print("mappings_meta_fill is not a pandas DataFrame and cannot be saved directly as CSV. Please check its type.")

# 2. Save cents as CSV
# Make sure cents is a DataFrame or Series
if isinstance(cents, (pd.DataFrame, pd.Series)):
    cents.to_csv(os.path.join(dpath, "cents.csv"), index=False)
    print(f"cents saved to {os.path.join(dpath, 'cents.csv')}")
else:
    print("cents is not a pandas DataFrame or Series and cannot be saved directly as CSV. Please check its type.")

# 3. Save shiftds as CSV
# shiftds is an xarray Dataset, which does not map directly onto CSV.
# To store its data as CSV, it first has to be converted to a pandas DataFrame.
# How to convert it depends on the structure of shiftds and on which data the CSV should keep.
# For example, to save all variables as one table:
if isinstance(shiftds, xr.Dataset):
    try:
        shiftds_df = shiftds.to_dataframe()
        shiftds_df.to_csv(os.path.join(dpath, "shiftds.csv"), index=True)  # index=True keeps the coordinates, which form the index
        print(f"shiftds converted to a DataFrame and saved to {os.path.join(dpath, 'shiftds.csv')}")
    except Exception as e:
        print(f"Could not convert shiftds (xarray Dataset) to a DataFrame and save it as CSV: {e}")
        print("Please check the structure of shiftds to decide how to flatten it into a table.")
else:
    print("shiftds is not an xarray Dataset and cannot be handled this way.")

In [None]:
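# Note: this overwrites the mappings.csv written above from `mappings_meta_fill`
# with the pairwise `mappings` table.
mappings.to_csv(os.path.join(dpath, 'mappings.csv'))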

# further code for analysis

## load data

In [None]:
# # Open mappings.csv
# mappings_df = pd.read_csv(os.path.join(dpath,'Pre_sum\\mappings.csv'))
# print("First 5 rows of mappings.csv:")
# print(mappings_df.head())

# print("-" * 30)

# Open cents.csv
cents_df = pd.read_csv(os.path.join(dpath,'cents.csv'))
print("First 5 rows of cents.csv:")
print(cents_df.head())

# print("-" * 30)

# # Open shiftds.csv
# shiftds_df = pd.read_csv(os.path.join(dpath,'shiftds.csv'))
# print("First 5 rows of shiftds.csv:")
# print(shiftds_df.head())

In [None]:
# print("Current columns of mappings_df:")
# print(mappings_df.columns)

# print("\nFirst 5 rows of mappings_df:")
# print(mappings_df.head())

In [None]:
from minian_plot_method.yra_c import calculate_yrc_std_and_save_separately

yra_csv_path1 = os.path.join(dpath, 'OFT', 'save_result','YrA.csv')
yrc_csv_path1 = os.path.join(dpath, 'OFT', 'save_result','C.csv')
output_csv_path1 = os.path.join(dpath, 'OFT', 'save_result','YrA-C_Std.csv')

yra_csv_path2 = os.path.join(dpath, 'TST', 'save_result','YrA.csv')
yrc_csv_path2 = os.path.join(dpath, 'TST', 'save_result','C.csv')
output_csv_path2 = os.path.join(dpath, 'TST', 'save_result','YrA-C_Std.csv')

calculate_yrc_std_and_save_separately(
    yra_csv_path1=yra_csv_path1,
    c_csv_path1=yrc_csv_path1,
    output_std_csv_path1=output_csv_path1,

    yra_csv_path2=yra_csv_path2,
    c_csv_path2=yrc_csv_path2,
    output_std_csv_path2=output_csv_path2,

    plot_unit_id=3
)

In [None]:
# std1 = pd.read_csv(output_csv_path1)
# std1


In [None]:
# 1. Make sure your 'minian_plotting_utils.py' file is in the current working directory or on the Python path.
#    If it is not, move it to the right location or add its folder to sys.path.

from minian_plot_method.plot_for_cross_regis import interactive_minian_analyzer
import pandas as pd
# 2. Define the paths to your data files.
#    Adjust these paths to match where your files actually are.
mappings_csv_path = os.path.join(dpath, 'mappings.csv')
cents_csv_path = os.path.join(dpath, 'cents.csv')
session1_nc_path = os.path.join(dpath, 'sum', 'session1', 'minian.nc')
session2_nc_path = os.path.join(dpath, 'sum', 'session2', 'minian.nc')
std_csv1_path = os.path.join(dpath, 'OFT', 'save_result', 'YrA-C_Std.csv')
std_csv2_path = os.path.join(dpath, 'TST', 'save_result', 'YrA-C_Std.csv')
peak_csv_path = os.path.join(dpath,'peak.csv')

picture_path = os.path.join(dpath, 'pictures')
os.makedirs(picture_path,exist_ok=True)
param={
        'peak_csv_path': peak_csv_path,
        'mapping_class_csv_path': dpath,
        'save_classification_prefix': 'c_spike_rate_classification',
        'prominence_C': 4.0,
        'distance_C': 5,
        'histogram_bins': 50,
        'max_classification_heatmaps': 100,
        'enable_normalization': False, # Normalization is turned off here
        'plot_peak_stats_comparison': True, # Plot the peak statistics comparison
        'plot_classification_heatmaps_flag': True, # Plot the classification heatmaps
        'plot_overall_peak_histograms': True, # Plot the overall peak histograms
        'save_peak_stats_csv': True, # Save the peak statistics as CSV
        'save_classification_csv': True, # Save the classification as CSV
        'show_selected_c_traces': True, # Show the C traces of the selected unit
        'save_individual_mapping_plots': True, # Whether to batch-save a figure for each mapping pair
        'individual_plots_output_base_dir': os.path.join(picture_path, 'individual_mapping_plots') # Root directory for the batch-saved figures
    }


In [None]:

interactive_minian_analyzer(
    mappings_csv_path=mappings_csv_path,
    cents_csv_path=cents_csv_path,
    session1_nc_path=session1_nc_path,
    session2_nc_path=session2_nc_path,
    std_csv_path1=std_csv1_path, # Pass this in if an STD file is available
    std_csv_path2=std_csv2_path, # Pass this in if an STD file is available
    c_analysis_params=param,

)