# ROICaT Tracking Result Visualization Colab Notebook

Welcome to the visualization notebook! This Colab notebook allows you to easily visualize tracking results **without installing ROICaT**. \



**More information**

An FAQ on how to use ROICaT is available [here](https://roicat.readthedocs.io/en/dev/).

If you have any questions not covered in the FAQ, please don't hesitate to open a new issue on the [ROICaT GitHub page](https://github.com/RichieHakim/ROICaT/issues).

**The notebook proceeds as follows:**
1. **Import** libraries
2. Define **paths** to ROICaT results
3. Load **ROICaT tracking** data.
4. (Optional) Load **ROICaT classification** data to **discard bad ROIs**
5. (Optional) **Discard** ROIs that are **not tracked** very well across sessions
6. **Visualize** results

# SETUP

In [None]:
#@title Load modules

# Load basic modules
import os
from pathlib import Path
import zipfile
import pickle
import copy

import numpy as np
import scipy.sparse
import natsort
import re
import hashlib

import torch

# Load plotting modules
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
import colorsys
from ipywidgets import widgets, Button, Output, interact
from IPython.display import display

In [None]:
#@title Helper Functions
def simple_hash(array):
    """
    Return the SHA-256 hex digest of an array's raw bytes.

    Args:
        array: Object exposing a ``tobytes()`` method (e.g. a numpy array).

    Returns:
        str: Hexadecimal SHA-256 digest of ``array.tobytes()``.
    """
    return hashlib.sha256(array.tobytes()).hexdigest()


def find_ref_session(ref_array, target_array):
    """
    Hardcoded for ROICaT results. Locate the session whose data matches a reference.

    The reference is hashed from the 'data' entry of its instance ``__dict__``;
    every candidate is hashed from its ``['data']`` item. Only these 'data'
    payloads are compared.

    Args:
        ref_array: Object whose instance ``__dict__`` holds a 'data' array.
        target_array: Iterable of items supporting ``item['data']``.

    Returns:
        int or None: Index of the first candidate with an identical hash,
        or None when no candidate matches.
    """
    wanted_hash = simple_hash(ref_array.__dict__['data'])
    matching_indices = (
        position
        for position, candidate in enumerate(target_array)
        if simple_hash(candidate['data']) == wanted_hash
    )
    # Lazily stops at the first match; None signals "no matching session".
    return next(matching_indices, None)


def broadcast_label(selected_ROIs_indices, ucids, ref_session=None):
    """
    Relabel UCIDs of ROIs that are not selected as good ROIs to -1.

    With ``ref_session=None``, one selection mask per session is expected and
    each session is masked independently. With an integer ``ref_session``, the
    single mask is applied to that session and only the UCIDs that survive
    there are kept in every session.

    Args:
        selected_ROIs_indices: Per-session selection masks, or a single mask
            (optionally wrapped in a one-element list) for the reference session.
        ucids: Per-session UCID labels.
        ref_session (int or None): Index of the classified reference session.

    Returns:
        list: Per-session labels with discarded ROIs set to -1.
    """
    def _mask(keep, session_ucids):
        # Selected ROIs keep their UCID; every other ROI becomes -1.
        return keep * session_ucids - (np.logical_not(keep))

    if ref_session is None:
        assert len(selected_ROIs_indices) == len(ucids), f'!!!Missing classification results!!!'
        return [_mask(keep, session_ucids) for keep, session_ucids in zip(selected_ROIs_indices, ucids)]

    assert isinstance(ref_session, int), f'!!!ref_session should be int!!!'
    if isinstance(selected_ROIs_indices, list):
        assert len(selected_ROIs_indices) == 1
        selected_ROIs_indices = selected_ROIs_indices[0]

    # UCIDs that remain in the reference session after masking (may include -1).
    surviving_ucids = set(np.unique(_mask(selected_ROIs_indices, ucids[ref_session])))

    labels = copy.deepcopy(ucids)
    for session_labels in labels:
        session_labels[:] = [ucid if ucid in surviving_ucids else -1 for ucid in session_labels]
    return labels


def clear_n_sesh_thresh(n_sesh_thresh, max_sesh):
    """
    Sanitize the session-count threshold.

    The value is cast to int, raised to at least 2, and then capped at
    ``max_sesh`` (the cap wins when ``max_sesh`` is below 2).

    Args:
        n_sesh_thresh: Requested threshold (anything ``int()`` accepts).
        max_sesh: Upper bound, typically the number of tracked sessions.

    Returns:
        The sanitized threshold.
    """
    at_least_two = max(int(n_sesh_thresh), 2)
    return min(at_least_two, max_sesh)


def discard_UCIDs_with_fewer_matches(
    ucids, 
    n_sesh_thresh='all',
    ref_session=0,
    verbose=True,
):
    """
    Discards UCIDs that do not appear in at least n_sesh_thresh sessions.
    If n_sesh_thresh='all', then only UCIDs that appear in all sessions are kept.
    Discarded entries are replaced by -1. The input is not modified.

    Args:
        ucids (list): One 1-D sequence of integer UCIDs per session.
        n_sesh_thresh (int or 'all'): Minimum number of sessions a UCID must
            occur in to be kept.
        ref_session (int or None): Session used only for the printed summary.
            None falls back to session 0.
        verbose (bool): If True, print the percentage of the reference
            session's non-negative UCIDs that pass the threshold.

    Returns:
        list: One list of labels per session, same lengths as the input.

    RH 2023
    """
    n_sesh = len(ucids)
    n_sesh_thresh = n_sesh if n_sesh_thresh == 'all' else n_sesh_thresh
    assert isinstance(n_sesh_thresh, int), "n_sesh_thresh must be an int or 'all'"
    if n_sesh == 0:
        return []

    sessions = [np.asarray(u) for u in ucids]

    ## Count in how many sessions each UCID occurs (at most once per session).
    ucids_bySesh_unique = [np.unique(s) for s in sessions]
    ucids_unique, n_sesh_present = np.unique(np.concatenate(ucids_bySesh_unique), return_counts=True)
    ucids_kept = ucids_unique[n_sesh_present >= n_sesh_thresh]

    if verbose:
        ## ref_session may legitimately be None (e.g. no classification used)
        ref = 0 if ref_session is None else ref_session
        ucids_ref = ucids_bySesh_unique[ref]
        ucids_ref = ucids_ref[ucids_ref >= 0]  ## ignore negative (-1) labels
        n_ref = len(ucids_ref)
        ## Only UCIDs that are in the reference session enter the numerator
        n_ref_kept = int(np.isin(ucids_ref, ucids_kept).sum())
        fraction = (n_ref_kept / n_ref) if n_ref > 0 else 0.0
        print(f'Reference session: {ref}')
        print(f'INFO: {fraction*100:.2f}% of UCIDs in reference session appear in at least {n_sesh_thresh} sessions.')

    ## Keep passing UCIDs, replace everything else with -1
    ucids_out = [list(np.where(np.isin(s, ucids_kept), s, -1)) for s in sessions]

    return ucids_out


def find_paths(
    dir_outer, 
    reMatch='filename', 
    find_files=True, 
    find_folders=False, 
    depth=0, 
    natsorted=True, 
):
    """
    Search for files and/or folders recursively in a directory.

    Args:
        dir_outer (str): Directory where the search starts.
        reMatch (str): Regular expression searched (``re.search``) in the
            full joined path of every entry.
        find_files (bool): Collect matching non-directory entries.
        find_folders (bool): Collect matching directories.
        depth (int): Number of sub-directory levels to descend into.
            0 searches ``dir_outer`` only.
        natsorted (bool): Natural-sort the result before returning it.

    Returns:
        list of str: Matching paths.

    RH 2022
    """
    def _collect(current_dir, level):
        found = []
        for entry_name in os.listdir(current_dir):
            entry_path = os.path.join(current_dir, entry_name)
            entry_is_dir = os.path.isdir(entry_path)
            wanted_kind = find_folders if entry_is_dir else find_files
            if wanted_kind and re.search(reMatch, entry_path) is not None:
                found.append(entry_path)
            # Descend only into directories and only while within the depth budget.
            if entry_is_dir and level < depth:
                found.extend(_collect(entry_path, level + 1))
        return found

    matches = _collect(dir_outer, 0)
    return natsort.natsorted(matches) if natsorted else matches


def pickle_load(
    filename, 
    zipCompressed=False,
    mode='rb'
):
    """
    Load and return the object stored in a pickle file.

    Args:
        filename (str): Path to the pickle file or zip archive.
        zipCompressed (bool): If True, ``filename`` is a zip archive and the
            pickle is read from its member named 'data'.
        mode (str): Mode used to open a plain (non-zip) pickle file.

    Returns:
        The unpickled object.

    Note: unpickling can execute arbitrary code; only load trusted files.
    RH 2022
    """
    if not zipCompressed:
        with open(filename, mode) as handle:
            return pickle.load(handle)

    with zipfile.ZipFile(filename, 'r') as archive:
        return pickle.loads(archive.read('data'))


def squeeze_integers(intVec):
    """
    Make integers in an array consecutive numbers
     starting from the smallest value. 
    ie. [7,2,7,4,-1,0] -> [3,1,3,2,-1,0].
    Useful for removing unused class IDs.
    This is v3.

    Args:
        intVec (list, tuple, np.ndarray or torch.Tensor): Integer values.
            Lists and tuples are converted to an int64 numpy array.

    Returns:
        np.ndarray or torch.Tensor: Same number of elements as the input;
        each value is replaced by (smallest value + its rank among the
        sorted unique values).

    Raises:
        TypeError: If intVec is not one of the supported types.

    RH 2023
    """
    if isinstance(intVec, (list, tuple)):
        intVec = np.array(intVec, dtype=np.int64)

    if isinstance(intVec, np.ndarray):
        unique, arange = np.unique, np.arange
        n_elements = intVec.size
    elif isinstance(intVec, torch.Tensor):
        unique, arange = torch.unique, torch.arange
        n_elements = intVec.numel()
    else:
        raise TypeError(f'intVec must be a list, tuple, np.ndarray or torch.Tensor, got {type(intVec).__name__}')

    ## Nothing to squeeze; also avoids calling .min() on an empty array
    if n_elements == 0:
        return intVec.copy() if isinstance(intVec, np.ndarray) else intVec.clone()

    u, inv = unique(intVec, return_inverse=True)  ## get unique values and their indices
    u_min = u.min()  ## get the smallest value
    u_s = arange(u_min, u_min + u.shape[0], dtype=u.dtype)  ## make consecutive numbers starting from the smallest value
    return u_s[inv]  ## return the indexed consecutive unique values


def make_session_bool(n_roi):
    """
    Build a boolean ROI-by-session membership matrix from per-session ROI counts.

    Args:
        n_roi: Number of ROIs in each session, in session order.

    Returns:
        np.ndarray: Boolean array of shape (sum(n_roi), len(n_roi)).
        Entry [i, s] is True when concatenated ROI index i falls into session s.
    """
    session_edges = np.concatenate([[0], np.cumsum(n_roi)])
    roi_index = np.arange(np.sum(n_roi), dtype=np.int64)
    # One boolean column per session: start <= index < stop
    membership_columns = [
        (roi_index >= start) & (roi_index < stop)
        for start, stop in zip(session_edges[:-1], session_edges[1:])
    ]
    return np.stack(membership_columns, axis=1)

In [None]:
#@title Visualization Functions
def compute_colored_FOV(
    spatialFootprints,
    FOV_height,
    FOV_width,
    labels,
    cmap='random',
    alphas_labels=None,
    alphas_sf=None,
):
    """
    Computes a set of images of FOVs of spatial footprints, colored
     by the predicted class.

    Args:
        spatialFootprints (list of scipy.sparse matrices, or a single one):
            One matrix per session. Each row is one ROI with
            FOV_height * FOV_width columns. A single sparse matrix is
            treated as one session.
        FOV_height (int): Height of the output images in pixels.
        FOV_width (int): Width of the output images in pixels.
        labels (list of arrays/lists, or one flat array/list):
            Integer label of every ROI. A flat input is split into sessions
            using the number of rows of each spatial footprint matrix.
            If the label -1 is present, its color is set to all zeros.
        cmap ('random' or callable): 'random' draws colors with rand_cmap;
            otherwise a callable that maps an array of values in [0, 1] to
            rows of RGBA colors.
        alphas_labels (array-like or None): One scaling factor per unique
            label, in sorted label order. Clipped to [0, 1]. Defaults to ones.
        alphas_sf (array-like, list of arrays, or None): One scaling factor
            per ROI. Clipped to [0, 1]. Defaults to ones.

    Returns:
        list of np.ndarray: One (FOV_height, FOV_width, 3) image per session.

    Note: when there is only a single unique label, all colors are zeros.
    """
    ## A single matrix (dense or sparse) is wrapped into a one-session list.
    ## Without the sparse check, a lone sparse matrix would be iterated row
    ##  by row and every ROI would be treated as its own session.
    if isinstance(spatialFootprints, np.ndarray) or scipy.sparse.issparse(spatialFootprints):
        spatialFootprints = [spatialFootprints]

    ## Check inputs
    assert all([scipy.sparse.issparse(sf) for sf in spatialFootprints]), "spatialFootprints must be a list of scipy.sparse.csr_matrix"

    n_roi = np.array([sf.shape[0] for sf in spatialFootprints], dtype=np.int64)
    n_roi_cumsum = np.concatenate([[0], np.cumsum(n_roi)]).astype(np.int64)

    def _fix_list_of_arrays(v):
        ## Split a flat array / flat list into one chunk per session
        if isinstance(v, np.ndarray) or (isinstance(v, list) and isinstance(v[0], (np.ndarray, list)) is False):
            v = [v[b_l: b_u] for b_l, b_u in zip(n_roi_cumsum[:-1], n_roi_cumsum[1:])]
        assert (isinstance(v, list) and isinstance(v[0], (np.ndarray, list))), "input must be a list of arrays or a single array of integers"
        return v
    
    labels = _fix_list_of_arrays(labels)
    alphas_sf = _fix_list_of_arrays(alphas_sf) if alphas_sf is not None else None

    labels_cat = np.concatenate(labels)
    u = np.unique(labels_cat)
    n_c = len(u)

    if alphas_labels is None:
        alphas_labels = np.ones(n_c)
    alphas_labels = np.clip(alphas_labels, a_min=0, a_max=1)
    assert len(alphas_labels) == n_c, f"len(alphas_labels)={len(alphas_labels)} != n_c={n_c}"

    if alphas_sf is None:
        alphas_sf = np.ones(len(labels_cat))
    if isinstance(alphas_sf, list):
        alphas_sf = np.concatenate(alphas_sf)
    alphas_sf = np.clip(alphas_sf, a_min=0, a_max=1)
    assert len(alphas_sf) == len(labels_cat), f"len(alphas_sf)={len(alphas_sf)} != len(labels_cat)={len(labels_cat)}"
    
    h, w = FOV_height, FOV_width

    ## Stack all sessions and scale every ROI (row) by its maximum value
    rois = scipy.sparse.vstack(spatialFootprints)
    rois = rois.multiply(1.0/rois.max(1).toarray()).power(1)

    if n_c > 1:
        colors = rand_cmap(nlabels=n_c, verbose=False)(np.linspace(0.,1.,n_c, endpoint=True)) if cmap=='random' else cmap(np.linspace(0.,1.,n_c, endpoint=True))
        colors = colors / colors.max(1, keepdims=True)
    else:
        colors = np.array([[0,0,0,0]])

    ## -1 is the smallest label, so it maps to the first color
    if np.isin(-1, labels_cat):
        colors[0] = [0,0,0,0]

    ## Map labels to 0..n_c-1 so they can index into colors / alphas_labels
    labels_squeezed = squeeze_integers(labels_cat)
    labels_squeezed -= labels_squeezed.min()

    ## One horizontally stacked copy of the ROIs per RGBA channel
    rois_c = scipy.sparse.hstack([rois.multiply(colors[labels_squeezed, ii][:,None]) for ii in range(4)]).tocsr()
    rois_c.data = np.minimum(rois_c.data, 1)

    ## apply alpha
    rois_c = rois_c.multiply(alphas_labels[labels_squeezed][:,None] * alphas_sf[:,None]).tocsr()

    ## make session_bool
    session_bool = make_session_bool(n_roi)

    rois_c_bySessions = [rois_c[idx] for idx in session_bool.T]

    ## Max-project the ROIs of each session, unstack the channels and drop alpha
    rois_c_bySessions_FOV = [r.max(0).toarray().reshape(4, h, w).transpose(1,2,0)[:,:,:3] for r in rois_c_bySessions]

    return rois_c_bySessions_FOV

def crop_cluster_ims(ims):
    """
    Crops a stack of images to the smallest rectangle containing all
     positive pixels (over the whole stack), plus a margin of one pixel
     where the image bounds allow it. The outermost rows and columns of
     the result are set to 1 to draw a frame. The input is not modified.

    Args:
        ims (np.ndarray): Image stack of shape (n_images, height, width).

    Returns:
        np.ndarray: Cropped and framed copy of the stack.

    Raises:
        ValueError: If no pixel in ims is greater than 0.

    RH 2022
    """
    rows, cols = np.where(np.max(ims, axis=0) > 0)
    if rows.size == 0:
        raise ValueError('crop_cluster_ims: ims contains no pixels > 0, nothing to crop.')

    ## Slice ends are exclusive: +2 keeps the last positive row/column AND
    ##  leaves one margin pixel for the frame, mirroring the -1 at the start.
    row_start = max(rows.min() - 1, 0)
    row_stop = min(rows.max() + 2, ims.shape[1])
    col_start = max(cols.min() - 1, 0)
    col_stop = min(cols.max() + 2, ims.shape[2])

    ## Copy only the cropped region so the frame is not drawn into the input
    im_out = ims[:, row_start:row_stop, col_start:col_stop].copy()
    im_out[:,(0,-1),:] = 1
    im_out[:,:,(0,-1)] = 1
    return im_out

def rand_cmap(
    nlabels, 
    first_color_black=False, 
    last_color_black=False,
    verbose=True,
    under=[0,0,0],
    over=[0.5,0.5,0.5],
    bad=[0.9,0.9,0.9],
    ):
    """
    Build a random matplotlib colormap with ``nlabels`` entries.

    Every random RGB triplet is scaled so that its largest channel equals 1.

    Args:
        nlabels (int): Number of colors. Must be greater than 0.
        first_color_black (bool): Force the first color to black.
        last_color_black (bool): Force the last color to black.
        verbose (bool): Print the number of labels and draw a colorbar.
        under: Color passed to ``set_under``.
        over: Color passed to ``set_over``.
        bad: Color passed to ``set_bad``.

    Returns:
        LinearSegmentedColormap: The generated colormap.
    """
    assert nlabels > 0, 'Number of labels must be greater than 0'

    if verbose:
        print('Number of labels: ' + str(nlabels))

    # Random colors, each row normalized by its brightest channel
    rgb = np.random.rand(nlabels, 3)
    rgb = rgb / rgb.max(axis=1, keepdims=True)

    black = [0, 0, 0]
    if first_color_black:
        rgb[0] = black
    if last_color_black:
        rgb[-1] = black

    random_colormap = LinearSegmentedColormap.from_list('new_map', rgb, N=nlabels)

    # Display colorbar
    if verbose:
        from matplotlib import colors, colorbar
        _, colorbar_ax = plt.subplots(1, 1, figsize=(6, 0.5))
        bounds = np.linspace(0, nlabels, nlabels + 1)
        colorbar.ColorbarBase(
            colorbar_ax,
            cmap=random_colormap,
            norm=colors.BoundaryNorm(bounds, nlabels),
            spacing='proportional',
            ticks=None,
            boundaries=bounds,
            format='%1i',
            orientation=u'horizontal',
        )

    random_colormap.set_bad(bad)
    random_colormap.set_over(over)
    random_colormap.set_under(under)

    return random_colormap

In [None]:
#@title Interactive Widget Functions

# Checkbox widget
def checkbox_widget(input_list):
    """
    Create one unchecked checkbox per entry of ``input_list``.

    Args:
        input_list: Iterable of descriptions, one per checkbox.

    Returns:
        tuple: (list of ``widgets.Checkbox``, new empty list meant to hold
        the selected indices).
    """
    checkboxes = [widgets.Checkbox(description=label, value=False) for label in input_list]
    return checkboxes, []

def get_selected_indices(checkboxes):
    """
    Return the positions of all checkboxes whose ``value`` is truthy.

    Args:
        checkboxes: Iterable of objects with a ``value`` attribute.

    Returns:
        list of int: Indices of the checked entries, in order.
    """
    checked_positions = []
    for position, box in enumerate(checkboxes):
        if box.value:
            checked_positions.append(position)
    return checked_positions

def make_checkbox_button_clicked(checkboxes, selected_indices):
    """
    Build a button click handler that refreshes ``selected_indices`` in place.

    Args:
        checkboxes: Checkboxes whose state is read on every click.
        selected_indices (list): List that is emptied and refilled with the
            indices of the currently checked boxes. It is mutated in place,
            so every reference to it sees the update.

    Returns:
        callable: Handler taking one (ignored) argument, the click event.
    """
    def _on_click(_event):
        # Drop the previous selection first, then store the current one.
        selected_indices.clear()
        selected_indices.extend(get_selected_indices(checkboxes))

    return _on_click


# Slider widget
def create_image_widget_function(images, dpi = 96, cmap = 'gray'):
    """
    Build a callback that plots ``images[session]`` at pixel-matched size.

    Args:
        images: Indexable collection of image arrays.
        dpi: Pixels per inch used to convert the image shape into a figure
            size in inches.
        cmap: Colormap passed to ``plt.imshow``.

    Returns:
        callable: Function of ``session`` that shows the selected image.
    """
    def image_widget(session):
        n_rows = images[session].shape[0]
        n_cols = images[session].shape[1]
        # figsize is (width, height) in inches
        plt.figure(figsize=(n_cols / dpi, n_rows / dpi))
        plt.imshow(images[session], cmap=cmap)
        plt.show()

    return image_widget

def display_image_stack(images, dpi = 96, cmap = 'gray'):
    """
    Show an image stack with an integer slider that selects the image.

    Args:
        images: Indexable collection of image arrays; the slider runs from
            0 to ``len(images) - 1``.
        dpi: Pixels per inch, forwarded to the plotting callback.
        cmap: Colormap, forwarded to the plotting callback.
    """
    plot_session = create_image_widget_function(images, dpi = dpi, cmap = cmap)
    session_slider = widgets.IntSlider(min=0, max=len(images)-1, continuous_update=False)
    widgets.interact(plot_session, session = session_slider)

# Import paths

### Mount google drive (OPTION 1: RECOMMENDED)

In [None]:
#@markdown You can upload your data onto Google Drive and mount the drive to access.

#@markdown This process allows you access to your data directory on Google Drive.

# Mount Google Drive in the Colab runtime so ROICaT result files stored there can be read.
from google.colab import drive
# Mount point; the default dir_ROICaT of the next cell lives below this path.
path_gdrive = '/content/gdrive'
drive.mount(path_gdrive, force_remount=True)

In [None]:
#@markdown ### Enter your ROICaT tracking result files directory:
# Root directory that is searched (recursively) for ROICaT result files in the cells below.
dir_ROICaT = '/content/gdrive/MyDrive' #@param {type:"string"}

### Upload file from local (OPTION 2)

In [None]:
#@markdown This cell allows you to upload files (**not directories**) from your local machine. Please note that uploading files from a local machine to Colab is slow. Also, you have to **UPLOAD THE FILES AGAIN** if the runtime terminates.


#@markdown You can upload **individual ROICaT result files** or **a zip file** that contains those files. This cell automatically detects and extracts an uploaded zip file.


from google.colab import files
# Open the upload dialog; the keys of the returned mapping are used as file names below.
uploaded = files.upload()

# Look for every uploaded file in the current working directory and
# extract it there if it is a zip archive.
for file_name in list(uploaded.keys()):
  file_path = Path.cwd() / file_name
  if file_path.exists():
    if zipfile.is_zipfile(file_path):
      print("Zip file detected: ", file_path)
      with zipfile.ZipFile(file_path, 'r') as zip_ref:
        zip_ref.extractall(Path.cwd())
        print("Zip file extracted to: ", Path.cwd())
    else:
      # Non-zip uploads are left untouched.
      print("Uploaded file is not zip file: ", file_path)
  else:
    print("Failed to detect uploaded file path: ", file_path)

# From here on, ROICaT result files are searched for in the current working directory.
dir_ROICaT = os.getcwd()

### Select data paths to visualize

In [None]:
#@markdown **Find ROICaT result files to visualize:**

tracking_result = 'tracking.results.pkl' #@param {type:"string"}
tracking_rundata = 'tracking.rundata.pkl' #@param {type:"string"}
#@markdown Automatically searches for matching file names.
#@markdown
#@markdown **Default: tracking.results.pkl, tracking.rundata.pkl**, _type: str_

#@markdown \
#@markdown 
#@markdown \


#@markdown 1. By default, this notebook visualizes all ROIs in tracked sessions.
#@markdown 2. If you have a **single** ROICaT classification file for **all tracked sessions**, you can visualize **good ROIs per each session** only.
#@markdown 3. If you have a **single** ROICaT classification file for **one of the tracked sessions**, you can visualize **good ROIs for that session and track those ROIs over different sessions**.
good_ROIs_only = True #@param {type:"boolean"}
#@markdown If True, visualize _good ROIs_ selectively. You should have a **ROICaT classification result file**.
#@markdown
#@markdown **Default: False**
classify_result =  'classification_drawn.results.pkl' #@param {type:"string"}
#@markdown Automatically searches for matching file names. Only matters if _good_ROIs_only_ is True
#@markdown
#@markdown **Default: classification_drawn.results.pkl**, _type: str_

# The names above are used as regular expressions that are searched in the
# full path of every file, up to 8 directory levels below dir_ROICaT.
paths_trackingResult = find_paths(dir_outer=dir_ROICaT, reMatch=tracking_result, depth=8)
paths_trackingRundata = find_paths(dir_outer=dir_ROICaT, reMatch=tracking_rundata, depth=8)
print('paths_trackingResult')
display(paths_trackingResult)
print('paths_trackingRundata')
display(paths_trackingRundata)

# Default ROICaT results
paths_alltrack = paths_trackingResult
paths_allrun = paths_trackingRundata
# With exactly one tracking result there is nothing to choose from.
if len(paths_trackingResult) == 1:
  selected_indices = [0]

# Always define the classification path lists, so that later cells can
# reference them even when good_ROIs_only is False.
paths_classifyResult = []
paths_allclassify = []
if good_ROIs_only:
  paths_classifyResult = find_paths(dir_outer=dir_ROICaT, reMatch=classify_result, depth=8)
  print('paths_classifyResult')
  display(paths_classifyResult)
  paths_allclassify = paths_classifyResult
  if len(paths_classifyResult) == 0:
    # No classification file available: fall back to visualizing all ROIs.
    good_ROIs_only = False
    print("Failed to find classification result file. Please double-check classification file path.")

### Load paths for tracking results

In [None]:
#@markdown If you have more than one ROICaT tracking result file, you can **check and select** the result file to visualize.
# Only offer a selection when there is not exactly one tracking result
# (with exactly one, selected_indices was already set to [0]).
if len(paths_trackingResult) != 1:
  # Show only the last two path components to keep the checkbox labels short.
  tracking_list = ['/'.join(Path(tracking).parts[-2:]) for tracking in paths_trackingResult]
  tracking_checkboxes, selected_indices = checkbox_widget(tracking_list)
  checkbox_display = widgets.VBox(tracking_checkboxes)
  display(checkbox_display)
  button = Button(description="Get Tracking Result")
  # selected_indices stays empty until the button is clicked; the click
  # handler fills it in place with the indices of the checked boxes.
  button.on_click(make_checkbox_button_clicked(tracking_checkboxes, selected_indices))

  display(button)

In [None]:
#@markdown Selected tracking result file to visualize:

# Both names are bound to the same list object as selected_indices, so the
# rundata file at the same position as the tracking result file is used.
tracking_selected_indices = rundata_selected_indices = selected_indices

paths_alltrack = [paths_trackingResult[i] for i in tracking_selected_indices]
print("Tracking result file to visualize")
display(paths_alltrack)
# NOTE(review): assumes paths_trackingRundata is ordered like paths_trackingResult — confirm.
paths_allrun = [paths_trackingRundata[i] for i in rundata_selected_indices]
print("Tracking rundata file to visualize")
display(paths_allrun)

In [None]:
#@markdown Optional: If selected rundata file does not correspond with the tracking result file, you can manually curate the error:
# Show only the last two path components to keep the checkbox labels short.
rundata_list = ['/'.join(Path(rundata).parts[-2:]) for rundata in paths_trackingRundata]
rundata_checkboxes, _ = checkbox_widget(rundata_list)

# Keep the current rundata selection as the default, so that simply running
# this optional cell does not discard it. A copy is used because the list may
# be the very same object as tracking_selected_indices, which must not change
# when the button handler updates the rundata selection in place.
try:
  rundata_selected_indices = list(rundata_selected_indices)
except NameError:
  rundata_selected_indices = []
# Pre-check the boxes of the current selection.
for i in rundata_selected_indices:
  if 0 <= i < len(rundata_checkboxes):
    rundata_checkboxes[i].value = True

checkbox_display = widgets.VBox(rundata_checkboxes)
display(checkbox_display)
button = Button(description="Get Rundata Result")
button.on_click(make_checkbox_button_clicked(rundata_checkboxes, rundata_selected_indices))

display(button)


### (Optional) Load path for classification result

In [None]:
#@markdown Please **check and select** matching classification result file.
#@markdown
#@markdown Only matters if _good_ROIs_only_ is True **and** you have multiple classification result files.
# `and` short-circuits: paths_classifyResult is only inspected when
# good_ROIs_only is True (it may be undefined otherwise).
if good_ROIs_only and len(paths_classifyResult) > 1:
  # Show only the last two path components to keep the checkbox labels short.
  classify_list = ['/'.join(Path(classify).parts[-2:]) for classify in paths_classifyResult]
  classify_checkboxes, classify_selected_indices = checkbox_widget(classify_list)
  checkbox_display = widgets.VBox(classify_checkboxes)
  display(checkbox_display)
  button = Button(description="Get Classification Result")
  # classify_selected_indices stays empty until the button is clicked.
  button.on_click(make_checkbox_button_clicked(classify_checkboxes, classify_selected_indices))

  display(button)
elif good_ROIs_only and len(paths_classifyResult) == 1:
  print(f"You have only one classification result file: {paths_classifyResult[0]}")
  classify_selected_indices = [0]
else:
  print("good_ROIs_only is False")
  # Nothing to select; keep the name defined for later cells.
  classify_selected_indices = []

In [None]:
#@markdown Selected classification result file to visualize:

# Classification paths only exist when good ROIs are requested; without the
# guard this cell fails on undefined names when good_ROIs_only is False.
if good_ROIs_only:
  paths_allclassify = [paths_classifyResult[i] for i in classify_selected_indices]
  print("Classification result file to visualize")
  display(paths_allclassify)
else:
  paths_allclassify = []
  print("No classification result file")

### (Optional) Define ROICaT file path by yourself

In [None]:
#@markdown If the code above does not work well for you, you can explicitly define the path of each file.

Define_file_path = False #@param {type:"boolean"}
#@markdown If True, load ROICaT result files from paths below. **Default: False**

tracking_result_file = '/content/gdrive/MyDrive/20230612-005313_ROICaT.tracking.results.pkl' #@param {type:"string"}
tracking_rundata_file = '/content/gdrive/MyDrive/20230612-005313_ROICaT.tracking.rundata.pkl' #@param {type:"string"}
classify_result_file =  '/content/gdrive/MyDrive/20230612-144045_ROICaT.classification_drawn.results.pkl' #@param {type:"string"}

# Replace the automatically found paths by the manual ones; every list then
# holds exactly one entry, so index 0 is selected for all of them.
# Note that the three *_selected_indices names share one list object.
if Define_file_path:
  paths_trackingResult = [tracking_result_file]
  paths_trackingRundata = [tracking_rundata_file]
  paths_classifyResult = [classify_result_file]
  tracking_selected_indices = rundata_selected_indices = classify_selected_indices = [0]

### Finalize files to visualize

In [None]:
#@markdown Files to visualize, finalized:

# Resolve the final file lists from the (possibly updated) selections.
# Only the first entry of each list is loaded later on.
paths_alltrack = [paths_trackingResult[i] for i in tracking_selected_indices]
print("Tracking result file to visualize")
display(paths_alltrack)
paths_allrun = [paths_trackingRundata[i] for i in rundata_selected_indices]
print("Tracking rundata file to visualize")
display(paths_allrun)
# Classification paths are only resolved when good ROIs are requested.
if good_ROIs_only:
  paths_allclassify = [paths_classifyResult[i] for i in classify_selected_indices]
  print("Classification result file to visualize")
  display(paths_allclassify)
else:
  print("No classification result file")

# Here's your ROICaT tracking results!

Look at some of the distributions of the quality metrics.

Silhouette score is a particularly useful one for this type of clustering. [Learn more here](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.silhouette_score.html)
We also define a handy 'confidence' variable which is a nice heuristic you can use for thresholding for inclusion criteria.
Note that the sample_silhouette score is a per-sample (per-ROI) score. So it can actually be used to remove / subselect ROIs from clusters.

In [None]:
#@markdown Load data...
use_classificationResults = good_ROIs_only

# NOTE: unpickling can execute arbitrary code; only load files you trust.
# Only the first selected file of each kind is loaded.
results_tracking = pickle_load(paths_alltrack[0])
rundata_tracking = pickle_load(paths_allrun[0])

# Defaults for "no usable classification"; overwritten below when possible.
results_classification = None
ref_session = None

if use_classificationResults:
  results_classification = pickle_load(paths_allclassify[0])
  n_classified = len(results_classification['spatialFootprints'])
  n_tracked = len(rundata_tracking['data']['spatialFootprints'])
  if n_classified == n_tracked:
    # One classification per tracked session: no reference session needed.
    print(f"Classification is done for all session")
  elif n_classified == 1:
    # A single classified session: find which tracked session it belongs to.
    print(f"Classification is done for one session...")
    ref_session = find_ref_session(results_classification['spatialFootprints'][0], rundata_tracking['data']['spatialFootprints'])
    if ref_session is not None:
      print(f"Reference session: {ref_session}")
    else:
      print(f"!!!Failed to find matching reference session!!!")
      use_classificationResults = False
  else:
    print(f"!!!Number of classified sessions and tracked sessions do not match!!!")
    use_classificationResults = False

  # Unusable classification results would make later cells fail;
  # fall back to visualizing all tracked ROIs instead.
  if not use_classificationResults:
    results_classification = None
    ref_session = None
    print("Classification results are ignored; all tracked ROIs will be visualized.")

In [None]:
%matplotlib inline

#@markdown Visualize aligned FOV...

dpi = 96 #@param {type:"number"}
#@markdown Pixel per inch. Controls image size. Smaller dpi, larger image. **Default: 96**, _type: int_

cmap = 'gray' #@param ['gray', 'viridis'] {type:"raw"}
#@markdown Image colormap. **Default: 'gray'**

# Browse the images stored under rundata['aligner']['ims_registered_nonrigid'] with a slider.
# NOTE(review): presumably one registered FOV image per session — confirm against the ROICaT rundata format.
display_image_stack(rundata_tracking['aligner']['ims_registered_nonrigid'], dpi = dpi, cmap = cmap)

In [None]:
#@markdown Distributions of the quality metrics...
# Heuristic confidence: silhouette rescaled by (x + 1) / 2, multiplied by the
# intra-cluster mean. Both inputs are 'cluster_*' quality metrics.
confidence = (((results_tracking['quality_metrics']['cluster_silhouette'] + 1) / 2) * results_tracking['quality_metrics']['cluster_intra_means'])

# 2x2 grid of histograms with 50 bins each.
fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(15,7))

axs[0,0].hist(results_tracking['quality_metrics']['cluster_silhouette'], 50);
axs[0,0].set_xlabel('cluster_silhouette');
axs[0,0].set_ylabel('cluster counts');

axs[0,1].hist(results_tracking['quality_metrics']['cluster_intra_means'], 50);
axs[0,1].set_xlabel('cluster_intra_means');
axs[0,1].set_ylabel('cluster counts');

axs[1,0].hist(confidence, 50);
axs[1,0].set_xlabel('confidence');
axs[1,0].set_ylabel('cluster counts');

# The only per-ROI (per-sample) metric in this figure.
axs[1,1].hist(results_tracking['quality_metrics']['sample_silhouette'], 50);
axs[1,1].set_xlabel('sample_silhouette score');
axs[1,1].set_ylabel('roi sample counts');

In [None]:
%matplotlib inline
#@markdown Look at a color visualization of the results. ROIs of the same color are considered a part of the same cluster. The colors are assigned randomly.

#@markdown \

n_sesh_thresh = 3 #@param {type:"number"}
#@markdown Number of sessions that a ROI must appear in to be kept. In other words, ROIs that are **not tracked for at least this number of sessions** are discarded.
#@markdown 
#@markdown If this number is **bigger** than the number of tracked sessions, then only ROIs that appear in **all sessions** are kept.
#@markdown
#@markdown **Default: 2**, _type: int >= 2_

#@markdown \
dpi = 96 #@param {type:"number"}
#@markdown Pixel per inch. Controls image size. Smaller dpi, larger image. **Default: 96**, _type: int_

# Build the per-session labels. With classification results, labels of ROIs
# that are not selected are set to -1 by broadcast_label.
if use_classificationResults:
  labels = broadcast_label(results_classification['preds'], results_tracking["clusters"]["labels_bySession"], ref_session)
  # NOTE(review): this counts unique labels (a -1 label would be included), not individual ROIs — confirm intent.
  print(f"Number of good ROIs: {len(np.unique(np.concatenate(labels)))}")
  rois_good = "good "
else:  
  labels = results_tracking["clusters"]["labels_bySession"]
  # NOTE(review): confidence is built from per-cluster metrics, so this presumably counts clusters, not ROIs — confirm.
  print(f"Number of ROIs: {len(confidence)}")
  rois_good = ""

# Clamp the threshold to [2, number of tracked sessions].
n_sesh_thresh = clear_n_sesh_thresh(n_sesh_thresh, len(results_tracking["clusters"]["labels_bySession"]))

# NOTE(review): ref_session is None when classification is unused or covers all
# sessions — confirm discard_UCIDs_with_fewer_matches accepts None with verbose=True.
labels_out = discard_UCIDs_with_fewer_matches(labels,
                                              n_sesh_thresh=n_sesh_thresh,
                                              ref_session=ref_session,
                                              verbose=True,
                                              )

# Pick the confidence value of every remaining unique label.
# NOTE(review): the +1 offset assumes confidence[0] belongs to label -1 and that
# labels are consecutive indices into confidence — confirm against the ROICaT results format.
confidence_mask = np.unique(np.concatenate(labels_out)) + 1
confidence_labels = confidence[confidence_mask]
print(f"Number of {rois_good}ROIs tracked across at least {n_sesh_thresh} sessions: {len(confidence_labels)}")

FOVs_colored = compute_colored_FOV(
    spatialFootprints=[r.power(0.7) for r in results_tracking['ROIs']['ROIs_aligned']], 
    FOV_height=results_tracking['ROIs']['frame_height'], 
    FOV_width=results_tracking['ROIs']['frame_width'], 
    labels=labels_out,  ## cluster labels
    alphas_labels=confidence_labels*1.5,  ## Set brightness of each cluster based on some 1-D array
    )

# The colored FOVs are RGB images, one per session, browsed with a slider.
display_image_stack(FOVs_colored, dpi = dpi, cmap = 'viridis')

In [None]:
#@markdown Visualize the images of ROIs from the same cluster

# Flat array of cluster labels (UCIDs); negative labels are excluded below.
ucids = np.array(results_tracking['clusters']['labels'])
# ucids = np.concatenate(ucid_gt)
ucids_unique = np.unique(ucids[ucids>=0])

# ROI_ims = np.concatenate(data.ROI_images, axis=0)
# Stack the aligned ROIs of all sessions and scale every ROI (row) by its maximum.
ROI_ims_sparse = scipy.sparse.vstack(results_tracking['ROIs']['ROIs_aligned'])
ROI_ims_sparse = ROI_ims_sparse.multiply( ROI_ims_sparse.max(1).power(-1) ).tocsr()


# For every cluster: reshape its ROIs to (n, frame_height, frame_width), crop
# them to a common bounding box and concatenate them side by side (axis=1).
ucid_sfCat = []
for ucid in ucids_unique:
    idx = np.where(ucids == ucid)[0]
    ucid_sfCat.append( np.concatenate(list(crop_cluster_ims(ROI_ims_sparse[idx].toarray().reshape(len(idx), results_tracking['ROIs']['frame_height'], results_tracking['ROIs']['frame_width']))), axis=1) )
#     ucid_sfCat.append( np.concatenate(list(ROI_ims_sparse[idx].toarray().reshape(len(idx), data.FOV_height, data.FOV_width)), axis=1) )
# data.ROI_images[i_sesh][idx] for 


%matplotlib inline

# Plot at most the first 50 clusters, one wide figure per cluster.
for ii in range(min(len(ucid_sfCat), 50)):
    plt.figure(figsize=(40,1))
    plt.imshow(ucid_sfCat[ii], cmap='gray')
    plt.axis('off')

In [None]:
%matplotlib inline

#@markdown Distribution of the cluster size...

# Number of ROIs per unique cluster label.
# NOTE(review): a -1 label, if present, is counted like a cluster here — confirm whether it should be excluded.
_, counts = np.unique(results_tracking['clusters']['labels'], return_counts=True)

plt.figure()
# Histogram of cluster sizes, restricted to the range [0, n_sessions + 1].
plt.hist(counts, results_tracking['ROIs']['n_sessions']*2 + 1, range=(0, results_tracking['ROIs']['n_sessions']+1));
plt.xlabel('n_sessions'), plt.ylabel('cluster counts');
