# Setup

In [38]:
# Widen the notebook container to 95% of the browser window.
from IPython.display import display, HTML
display(HTML("<style>.container {width:95% !important; }</style>"))

# Report the active conda environment and Python version for reproducibility.
# `.get` avoids a KeyError when the kernel is not running inside a conda environment.
import os
print(f"Conda Environment: {os.environ.get('CONDA_DEFAULT_ENV', '<not set>')}")

from platform import python_version
print(f'python version: {python_version()}')

Conda Environment: testroicat
python version: 3.11.3


In [2]:
import os
from pathlib import Path
import copy

import matplotlib.pyplot as plt
import numpy as np
import natsort

import torch

import gc
import time
import functools
import multiprocessing as mp


In [3]:
# Reference time for the stage timings; every later `toc[...]` entry is seconds elapsed since `tic`.
tic = time.time()
toc = {'start': time.time() - tic}

In [4]:
# %load_ext autoreload
# %autoreload 2
import roicat

# Import paths

In [None]:
# Root folder that is searched for suite2p stat/ops files.
dir_allOuterFolders = str(Path(r"/media/rich/bigSSD/downloads_tmp/tmp_data/mouse_0322R/statFiles/").resolve())

pathSuffixToStat = 'stat.npy'
pathSuffixToOps = 'ops.npy'

# Find the stat files below the root (search depth=4).
# NOTE(review): `reMatch` is presumably a regular expression, so the '.' in 'stat.npy' matches any character.
paths_allStat = roicat.helpers.find_paths(
    dir_outer=dir_allOuterFolders,
    reMatch=pathSuffixToStat,
    depth=4,
)
# Fail here, with a clear message, rather than deep inside the data importer.
assert len(paths_allStat) > 0, f"No files matching '{pathSuffixToStat}' found under {dir_allOuterFolders}"

# Each ops file is expected to sit in the same folder as its stat file.
paths_allOps = np.array([Path(path).resolve().parent / pathSuffixToOps for path in paths_allStat])
paths_missingOps = [str(path) for path in paths_allOps if not path.exists()]
assert len(paths_missingOps) == 0, f"ops files not found next to their stat files: {paths_missingOps}"

print('paths to all stat files:')
for path in paths_allStat:
    print(path)
print('')
print('paths to all ops files:')
for path in paths_allOps:
    print(path)


In [6]:
toc.update({'import_paths': time.time() - tic})  # cumulative seconds since `tic`

# Import data

In [7]:
toc.update({'import_data': time.time() - tic})  # cumulative seconds since `tic`

In [None]:
# Load the suite2p stat/ops files for all sessions into a ROICaT data object.
data = roicat.data_importing.Data_suite2p(
    paths_statFiles=paths_allStat[:],
    paths_opsFiles=paths_allOps[:],
    um_per_pixel=2.5,
    new_or_old_suite2p='new',

    out_height_width=[36, 36],

    type_meanImg='meanImgE',
    # FOV_images=FOVs_mixed,

    verbose=True,
)

assert data.check_completeness(verbose=False)['tracking'], "Data object is missing attributes necessary for tracking."

# Record the stage time only after loading, so 'import_data' actually includes the data import.
toc['import_data'] = time.time() - tic

In [None]:
# Visual sanity checks of the imported data.
# 1) Per-session FOV images.
roicat.visualization.display_toggle_image_stack(data.FOV_images)
# 2) Per-session max projection of the spatial footprints, reshaped to the FOV
#    (`.toarray()` implies the footprints are sparse matrices).
roicat.visualization.display_toggle_image_stack(
    [
        footprints.max(0).reshape(data.FOV_height, data.FOV_width).toarray()
        for footprints in data.spatialFootprints
    ],
    clim=[0, 0.05],
)
# 3) Up to the first 5000 ROI images, pooled across sessions.
roicat.visualization.display_toggle_image_stack(
    np.concatenate(data.ROI_images, axis=0)[:5000],
    image_size=(200, 200),
)

# Alignment

In [223]:
# Aligner used by all registration cells below (geometric fit, non-rigid fit, ROI transform).
aligner = roicat.tracking.alignment.Aligner(
    verbose=True,
)

In [249]:
# Build the images used for registration: the FOV images are mixed with the ROI footprints
# (mixing factor 0.5) and contrast-equalized with CLAHE.
FOV_images = aligner.augment_FOV_images(
    ims=data.FOV_images,
    spatialFootprints=data.spatialFootprints,
    roi_FOV_mixing_factor=0.5,
    # CLAHE settings
    use_CLAHE=True,
    CLAHE_grid_size=1,
    CLAHE_clipLimit=1,
    CLAHE_normalize=True,
)

In [None]:
# Inspect the augmented images that are passed to the registration below as `ims_moving`.
roicat.visualization.display_toggle_image_stack(
    FOV_images,
)

In [None]:
# Geometric (homography) registration of every augmented FOV image.
# NOTE(review): `template=0.5` is presumably a fractional index into the image stack;
# an explicit template image (e.g. FOV_images[4]) can be passed instead.
aligner.fit_geometric(
    ims_moving=FOV_images,
    template=0.5,
    template_method='sequential',
    mode_transform='homography',
    mask_borders=(50, 50, 50, 50),
    n_iter=50,
    termination_eps=1e-09,
    gaussFiltSize=31,
    auto_fix_gaussFilt_step=10,
)

# Apply the fitted transforms; the non-rigid step below reads `aligner.ims_registered_geo`.
aligner.transform_images_geometric(FOV_images);

In [None]:
# Non-rigid (DeepFlow optical flow) refinement, starting from the geometric registration.
# NOTE(review): `template=0.5` is presumably a fractional index; e.g. FOV_images[1] could be passed instead.
aligner.fit_nonrigid(
    ims_moving=aligner.ims_registered_geo,
    remappingIdx_init=aligner.remappingIdx_geo,
    template=0.5,
    template_method='image',
    mode_transform='createOptFlow_DeepFlow',
    kwargs_mode_transform=None,
)

# Apply the non-rigid remapping to the augmented FOV images.
aligner.transform_images_nonrigid(FOV_images);

In [None]:
# Warp the spatial footprints with the non-rigid remapping; later cells read the result as `aligner.ROIs_aligned`.
aligner.transform_ROIs(
    ROIs=data.spatialFootprints,
    remappingIdx=aligner.remappingIdx_nonrigid,
    normalize=True,
);

In [None]:
# Compare the registration stages side by side.
viz = roicat.visualization
viz.display_toggle_image_stack(data.FOV_images)                    # raw FOV images
viz.display_toggle_image_stack(aligner.ims_registered_geo)         # after geometric registration
viz.display_toggle_image_stack(aligner.ims_registered_nonrigid)    # after non-rigid registration
viz.display_toggle_image_stack(                                    # aligned-ROI max projections
    aligner.get_ROIsAligned_maxIntensityProjection(),
    clim=(0, 0.05),
)

In [16]:
toc.update({'alignment': time.time() - tic})  # cumulative seconds since `tic`

## Blur ROIs (optional)

In [None]:
# Optional spatial blur of the aligned ROIs (kernel half-width 2).
# `blurrer.ROIs_blurred` is what the similarity step uses as spatial footprints.
blurrer = roicat.tracking.blurring.ROI_Blurrer(
    frame_shape=(data.FOV_height, data.FOV_width),
    kernel_halfWidth=2,
    plot_kernel=False,
)

blurrer.blur_ROIs(spatialFootprints=aligner.ROIs_aligned[:])

In [None]:
# Max-intensity projection of the blurred ROIs.
roicat.visualization.display_toggle_image_stack(
    blurrer.get_ROIsBlurred_maxIntensityProjection()
)

In [19]:
toc.update({'blur': time.time() - tic})  # cumulative seconds since `tic`

## Neural network (ROInet) embeddings

In [None]:
# ROInet embedder used to compute per-ROI latents.
# With download_method='check_local_first' the network files are presumably loaded from
# `dir_networkFiles` when present, otherwise downloaded and checked against `download_hash`.
# NOTE(review): `dir_networkFiles` is a hard-coded absolute local path — adjust per machine.
roinet = roicat.ROInet.ROInet_embedder(
    # Use the first CUDA device when available; fall back to CPU instead of failing on CPU-only machines.
    device='cuda:0' if torch.cuda.is_available() else 'cpu',
    dir_networkFiles=r'/home/rich/Desktop/tmp_data/',
    download_method='check_local_first',
    download_url='https://osf.io/x3fd2/download',
    download_hash='7a5fb8ad94b110037785a46b9463ea94',
    forward_pass_version='latent',
    verbose=True,
)

In [None]:
# Build the dataloader that feeds the ROI images to ROInet.
# Single-process debugging settings: batchSize_dataloader=1, pinMemory_dataloader=False,
# numWorkers_dataloader=0, persistentWorkers_dataloader=False, prefetchFactor_dataloader=2.
roinet.generate_dataloader(
    ROI_images=data.ROI_images,
    um_per_pixel=data.um_per_pixel,
    pref_plot=False,
    jit_script_transforms=False,
    # loader settings: one worker per CPU core
    batchSize_dataloader=8,
    pinMemory_dataloader=True,
    numWorkers_dataloader=mp.cpu_count(),
    persistentWorkers_dataloader=True,
    prefetchFactor_dataloader=2,
);

In [None]:
# Preview up to 5000 of the (presumably resized) ROI images prepared for ROInet.
roicat.visualization.display_toggle_image_stack(
    roinet.ROI_images_rs[:5000],
    image_size=(200, 200),
)

In [None]:
# ROInet forward pass (presumably over the dataloader built above); the similarity step reads `roinet.latents`.
# The trailing ';' suppresses the cell's output display.
roinet.generate_latents();

In [25]:
toc.update({'NN': time.time() - tic})  # cumulative seconds since `tic`

## Scattering wavelet transform (SWT) embeddings

In [None]:
# Scattering-wavelet-transform features for every ROI image; the similarity step reads `swt.latents`.
swt = roicat.tracking.scatteringWaveletTransformer.SWT(
    kwargs_Scattering2D={'J': 2, 'L': 12},
    # assumes each session's ROI_images is (n_ROIs, height, width) — TODO confirm
    image_shape=data.ROI_images[0].shape[1:3],
    # Use the first CUDA device when available; fall back to CPU instead of failing on CPU-only machines.
    device='cuda:0' if torch.cuda.is_available() else 'cpu',
)

# ROI images from all sessions are pooled into one array before transforming.
swt.transform(
    ROI_images=np.concatenate(data.ROI_images, axis=0),
    batch_size=100,
);

In [28]:
toc.update({'SWT': time.time() - tic})  # cumulative seconds since `tic`

## Compute similarities

In [None]:
# Blockwise similarity graph over the FOV, using 128 x 128 blocks.
# Set block_height/block_width to data.FOV_height/data.FOV_width to use a single block.
sim = roicat.tracking.similarity_graph.ROI_graph(
    n_workers=-1,
    frame_height=data.FOV_height,
    frame_width=data.FOV_width,
    block_height=128,
    block_width=128,
    # keyword spelling kept exactly as in the original call — check the roicat API before "fixing" it
    algorithm_nearestNeigbors_spatialFootprints='brute',
    verbose=True,
)

sim.visualize_blocks()

# Pairwise similarities from spatial footprints, ROInet latents, SWT latents and session membership.
# To skip the optional blur, pass `aligner.ROIs_aligned` instead of `blurrer.ROIs_blurred`.
s_sf, s_NN, s_SWT, s_sesh = sim.compute_similarity_blockwise(
    spatialFootprints=blurrer.ROIs_blurred,
    features_NN=roinet.latents,
    features_SWT=swt.latents,
    ROI_session_bool=data.session_bool,
    spatialFootprint_maskPower=1.0,
);

In [30]:
# Normalize the NN and SWT similarity scores using each ROI's center-of-mass neighborhood
# (k between k_min and k_max neighbors). The clusterer later reads `sim.s_NN_z` / `sim.s_SWT_z`.
# Pass features_SWT=None to skip the SWT normalization.
sim.make_normalized_similarities(
    centers_of_mass=data.centroids,
    features_NN=roinet.latents,
    features_SWT=swt.latents,
    k_max=4000,
    k_min=100,
    algo_NN='kd_tree',
    # Use the first CUDA device when available; fall back to CPU instead of failing on CPU-only machines.
    device='cuda:0' if torch.cuda.is_available() else 'cpu',
)

toc['sim'] = time.time() - tic

Finding k-range of center of mass distance neighbors for each ROI...
Normalizing Neural Network similarity scores...


100%|██████████████████████████████████| 44419/44419 [00:03<00:00, 11777.29it/s]


Normalizing SWT similarity scores...


 46%|████████████████▎                  | 20628/44419 [00:19<00:22, 1068.97it/s]

# Clustering

In [None]:
# Clusterer built from the spatial, normalized NN/SWT and session similarity matrices.
clusterer = roicat.tracking.clustering.Clusterer(
    s_sf=sim.s_sf,
    s_NN_z=sim.s_NN_z,
    s_SWT_z=sim.s_SWT_z,
    s_sesh=sim.s_sesh,
)

# Automatic search for the parameters that combine the similarities into one conjunctive distance.
# NOTE(review): max_duration=60*10 presumably means a 10-minute budget (seconds) — confirm in the roicat docs.
kwargs_makeConjunctiveDistanceMatrix_best = clusterer.find_optimal_parameters_for_pruning(
    n_bins=None,
    smoothing_window_bins=None,
    find_parameters_automatically=True,
    kwargs_findParameters={'n_patience': 300, 'tol_frac': 0.001, 'max_trials': 1200, 'max_duration': 60 * 10, 'verbose': False},
    bounds_findParameters={
        'power_NN': (0., 5.),          # exponents on the NN / SWT similarities
        'power_SWT': (0., 5.),
        'p_norm': (-5, 0),
        'sig_NN_kwargs_mu': (0., 1.0),  # sigmoid parameters for the NN similarity
        'sig_NN_kwargs_b': (0.00, 1.5),
        'sig_SWT_kwargs_mu': (0., 1.0), # sigmoid parameters for the SWT similarity
        'sig_SWT_kwargs_b': (0.00, 1.5),
    },
    n_jobs_findParameters=-1,
)

toc['separate_diffSame'] = time.time() - tic

In [None]:
# Parameters used to combine the similarity metrics in the diagnostic plots below.
# By default these are the automatically optimized ones; to experiment, replace them with a manual dict such as:
#
# kwargs_mcdm_tmp = {
#     'power_SF': 1.0,
#     'power_NN': 1.0,
#     'power_SWT': 0.0,
#     'p_norm': -4.0,
#     'sig_SF_kwargs': None,          # or {'mu':0.5, 'b':1.0}
#     'sig_NN_kwargs': {'mu':1.0, 'b':0.5},
#     'sig_SWT_kwargs': None,         # or {'mu':0.5, 'b':1.0}
# }
kwargs_mcdm_tmp = kwargs_makeConjunctiveDistanceMatrix_best

clusterer.plot_distSame(kwargs_makeConjunctiveDistanceMatrix=kwargs_mcdm_tmp)

clusterer.plot_similarity_relationships(
    kwargs_makeConjunctiveDistanceMatrix=kwargs_mcdm_tmp,
    plots_to_show=[1, 2, 3],
    max_samples=100000,
    kwargs_scatter={'s':1, 'alpha':0.2},
);

In [None]:
# Prune the similarity graph using the same conjunctive-distance parameters shown in the
# diagnostic plots (`kwargs_mcdm_tmp`). A manual override of those parameters then also
# applies to pruning, instead of being silently ignored.
clusterer.make_pruned_similarity_graphs(
    d_cutoff=None,
    kwargs_makeConjunctiveDistanceMatrix=kwargs_mcdm_tmp,
    stringency=1.0,
    convert_to_probability=False,
)

In [None]:
# Pick the clustering strategy from the number of sessions:
#   fewer than 8 sessions -> sequential Hungarian matching (cost threshold 0.6),
#   8 or more sessions    -> `clusterer.fit` with violation correction and intra-session splitting.
# Both operate on the pruned conjunctive distance matrix.
if data.n_sessions < 8:
    labels = clusterer.fit_sequentialHungarian(
        d_conj=clusterer.dConj_pruned,
        session_bool=data.session_bool,
        thresh_cost=0.6,
    )
else:
    labels = clusterer.fit(
        d_conj=clusterer.dConj_pruned,
        session_bool=data.session_bool,
        min_cluster_size=2,
        n_iter_violationCorrection=6,
        cluster_selection_method='leaf',
        d_clusterMerge=None,
        alpha=0.999,
        split_intraSession_clusters=True,
        discard_failed_pruning=True,
        n_steps_clusterSplit=100,
    )

In [None]:
clusterer.compute_cluster_quality_metrics()

# One label array per session: each column of `session_bool` is used as a boolean ROI mask.
labels_bySession = [labels[mask_session] for mask_session in data.session_bool.T]

In [None]:
## results_clustering
# Label -1 marks discarded ROIs; exclude it so it is not counted as a cluster.
labels_arr = np.asarray(labels)
print(f'Number of clusters: {len(np.unique(labels_arr[labels_arr != -1]))}')
print(f'Number of discarded ROIs: {(labels_arr == -1).sum()}')

In [None]:
toc.update({'clustering': time.time() - tic})  # cumulative seconds since `tic`

# Visualize results

In [None]:
# Per-cluster confidence score:
#   `cs_sil` rescaled with (x + 1) / 2 (maps [-1, 1] onto [0, 1]; presumably a silhouette score),
#   multiplied by `cs_intra_means` (presumably the mean intra-cluster similarity).
# Alternatives tried:
#   cs_sil * cs_intra_means * cs_intra_mins, cs_sil * cs_intra_means, cs_intra_means, cs_intra_mins, cs_sil.
confidence = (((clusterer.cluster_quality_metrics['cs_sil'] + 1) / 2) * clusterer.cluster_quality_metrics['cs_intra_means'])

fig, ax = plt.subplots()
ax.hist(confidence, 50)
ax.set(xlabel='confidence', ylabel='cluster counts', title='Cluster confidence');

In [218]:
# Color the aligned ROIs by cluster label (the result is a sequence of images, indexed below as FOV_clusters[0]).
FOV_clusters = roicat.visualization.compute_colored_FOV(
    # footprints raised to the 0.7 power, presumably to brighten faint pixels for display
    spatialFootprints=[footprints.power(0.7) for footprints in aligner.ROIs_aligned],
    FOV_height=data.FOV_height,
    FOV_width=data.FOV_width,
    session_bool=data.session_bool,
    labels=labels,
    # NOTE(review): this boolean mask is True only for clusters with confidence < 0.2, so it appears to
    # highlight the LOW-confidence clusters — confirm this is the intended view.
    # Other options tried for alphas_labels:
    #   np.clip(confidence / np.percentile(confidence, 90), a_min=0.0, a_max=1.0),
    #   (confidence < 0.1) * (confidence > 0.0), confidence < 0.5, (confidence + 1)/2,
    #   confidence, 1 - confidence
    # and for alphas_sf:
    #   clusterer.hdbs.probabilities_[:-1] < 0.5, clusterer.hdbs.outlier_scores_[:-1]
    alphas_labels=confidence < 0.2,
)

In [None]:
# Show the colored FOVs at 1.5x their native size; only the (height, width) part of the shape is used.
roicat.visualization.display_toggle_image_stack(
    FOV_clusters,
    image_size=(np.array(FOV_clusters[0].shape) * 1.5).astype(int)[:2],
    clim=[0, 1],
)

In [None]:
# Histogram of cluster sizes (ROIs per cluster).
# Label -1 marks discarded ROIs; exclude it so those ROIs are not plotted as one giant "cluster".
labels_arr = np.asarray(labels)
_, counts = np.unique(labels_arr[labels_arr != -1], return_counts=True)

fig, ax = plt.subplots()
ax.hist(counts, data.n_sessions*2 + 1, range=(0, data.n_sessions+1))
ax.set(xlabel='ROIs per cluster', ylabel='number of clusters', title='Cluster sizes');


In [46]:
toc.update({'visualize': time.time() - tic})  # cumulative seconds since `tic`

# Save results

In [None]:
# Output location. The base file name is the input root folder's name (or set it manually, e.g. 'RE257').
# NOTE(review): hard-coded absolute output directory — adjust per machine.
dir_save = Path('/media/rich/bigSSD/analysis_data/ROICaT/ROI_tracking/mouse_0322R/analysis_DAC_20230516/').resolve()
name_save = Path(dir_allOuterFolders).resolve().name
path_save = dir_save / f"{name_save}.ROICaT.tracking.results.pkl"
print(f'path_save: {path_save}')

In [59]:
# Analysis-facing results: cluster labels (UCIDs) plus the ROIs they refer to.
ROIs = {
    "ROIs_aligned": aligner.ROIs_aligned,
    "ROIs_raw": data.spatialFootprints,
    "frame_height": data.FOV_height,
    "frame_width": data.FOV_width,
    # column index of each True entry of session_bool (rows index ROIs, columns index sessions)
    "idx_roi_session": np.where(data.session_bool)[1]
}

results = {
    "UCIDs": labels,
    "UCIDs_bySession": labels_bySession,
    "ROIs": ROIs,
    "input_data": {
        "paths_stat": data.paths_stat,
        "paths_ops": data.paths_ops,
    },
    "cluster_quality_metrics": clusterer.cluster_quality_metrics,
}

roicat.helpers.pickle_save(obj=results, path_save=path_save, mkdir=True)

# Full run state of every pipeline object, deep-copied before pickling.
run_data = copy.deepcopy(
    {
        'data': data.serializable_dict,
        'aligner': aligner.serializable_dict,
        'blurrer': blurrer.serializable_dict,
        'roinet': roinet.serializable_dict,
        'swt': swt.serializable_dict,
        'sim': sim.serializable_dict,
        'clusterer': clusterer.serializable_dict,
    }
)
roicat.helpers.pickle_save(
    obj=run_data,
    path_save=str(dir_save / f"{name_save}.ROICaT.tracking.rundata.pkl"),
    mkdir=True,
)


In [60]:
toc.update({'saving': time.time() - tic})  # cumulative seconds since `tic`

In [None]:
# Display the stage timings (cumulative seconds since `tic`, keyed by stage).
dict(toc)