In [None]:
# Widen the Jupyter notebook window by injecting two CSS rules. `.container`
# presumably targets the classic Notebook UI and `--jp-notebook-max-width` the
# JupyterLab-based UIs; both are injected so either frontend widens.
from IPython.display import display, HTML

for style_html in (
    "<style>.container {width:95% !important; }</style>",
    "<style>:root { --jp-notebook-max-width: 100% !important; }</style>",
):
    display(HTML(style_html))

In [2]:
import numpy as np
import scipy.sparse
import matplotlib.pyplot as plt

In [3]:
import roicat

This notebook visualizes the results of an ROICaT run. You can experiment with different inclusion criteria here. You can also combine classification and tracking results to view their intersection by setting `use_classificationResults = True` in the next cell.

In [25]:
## Paths to the ROICaT outputs to visualize -- edit these to point at your own data.
path_tracking_results = '/media/rich/bigSSD/data_tmp/test_data/mouse_1.tracking.results_all.richfile'
path_classification_results = '/media/rich/bigSSD/data_tmp/test_data/mouse_1.classification_drawn.run_data.richfile'

## Set to True to also load a classification run so it can be intersected with tracking.
use_classificationResults = False

results = roicat.util.RichFile_ROICaT(path=path_tracking_results).load()
results_classification = (
    roicat.util.RichFile_ROICaT(path=path_classification_results).load()
    if use_classificationResults
    else None
)

In [None]:
## Show the names of the quality metrics stored with the tracked clusters.
quality_metrics = results['clusters']['quality_metrics']
print('Available quality metrics:')
display(quality_metrics.keys())

In [None]:
## Histograms (50 bins each) of the cluster-level and ROI-sample-level quality metrics.
## `confidence` maps cluster_silhouette from [-1, 1] onto [0, 1] and weights it
## by `cluster_intra_means`.
quality_metrics = results['clusters']['quality_metrics']
confidence = (((np.array(quality_metrics['cluster_silhouette']) + 1) / 2) * np.array(quality_metrics['cluster_intra_means']))

fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(15,7))

## (subplot position, values, x label, y label) for each panel.
panels = [
    ((0, 0), quality_metrics['cluster_silhouette'], 'cluster_silhouette', 'cluster counts'),
    ((0, 1), quality_metrics['cluster_intra_means'], 'cluster_intra_means', 'cluster counts'),
    ((1, 0), confidence, 'confidence', 'cluster counts'),
    ((1, 1), quality_metrics['sample_silhouette'], 'sample_silhouette score', 'roi sample counts'),
]
for position, values, xlabel, ylabel in panels:
    ax = axs[position]
    ax.hist(values, 50)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)

In [None]:
## Render each session's FOV with ROIs colored by their tracked cluster label.
## Edit the thresholds below to change the inclusion criteria.
thresh_sample_silhouette = 0.0   ## per ROI sample: keep if sample_silhouette > this
thresh_cluster_silhouette = 0.0  ## per cluster/label: keep if cluster_silhouette > this

labels_bySession_tracking = [np.asarray(t) for t in results['clusters']['labels_bySession']]

if use_classificationResults:
    ## NOTE(review): assumes `preds` is one array per session, aligned with
    ## `labels_bySession`, and that a nonzero prediction means "keep this ROI".
    ## For 0/1 preds this matches the former `c*t - (not c)` formula; unlike that
    ## formula it does not multiply labels by the class index when preds > 1.
    keep_bySession = [np.asarray(c).astype(bool) for c in results_classification['preds']]
else:
    keep_bySession = [np.ones(len(t), dtype=bool) for t in labels_bySession_tracking]

## Intersection of classification and tracking: the tracked label where the ROI
## is kept, -1 otherwise (one array per session).
labels = [np.where(keep, t, -1) for keep, t in zip(keep_bySession, labels_bySession_tracking)]

## The classification decision is per ROI, so it is folded into the per-sample
## alpha mask. The cluster labels passed below therefore stay the full tracking
## labels, and the per-cluster `alphas_labels` vector is unaffected.
## NOTE(review): assumes the flat `labels` / `sample_silhouette` arrays are the
## per-session arrays concatenated in session order -- confirm.
alphas_sf = (
    (np.array(results['clusters']['quality_metrics']['sample_silhouette']) > thresh_sample_silhouette)
    & np.concatenate(keep_bySession)
)
alphas_labels = np.array(results['clusters']['quality_metrics']['cluster_silhouette']) > thresh_cluster_silhouette

FOVs_colored = roicat.visualization.compute_colored_FOV(
    labels=np.array(results['clusters']['labels']),
    spatialFootprints=results['ROIs']['ROIs_aligned'],
    FOV_height=results['ROIs']['frame_height'],
    FOV_width=results['ROIs']['frame_width'],
    alphas_sf=alphas_sf,          ## INCLUSION CRITERIA FOR ROI SAMPLES
    alphas_labels=alphas_labels,  ## INCLUSION CRITERIA FOR CLUSTERS/LABELS
)

roicat.visualization.display_toggle_image_stack(FOVs_colored, image_size=2)

In [None]:
## Show the ROI images of each tracked cluster side by side, one figure per cluster.
ucids = np.array(results['clusters']['labels'])
ucids_unique = np.unique(ucids[ucids>=0])  ## non-negative labels only

## Stack every session's aligned ROIs into one sparse matrix (one flattened ROI
## image per row). Then scale each row by the reciprocal of its maximum so each
## ROI's peak becomes 1.
ROI_ims_sparse = scipy.sparse.vstack(results['ROIs']['ROIs_aligned'])
roi_row_max = ROI_ims_sparse.max(1)
ROI_ims_sparse = ROI_ims_sparse.multiply(roi_row_max.power(-1)).tocsr()

frame_shape = (results['ROIs']['frame_height'], results['ROIs']['frame_width'])

def _montage_for_cluster(ucid):
    """Horizontally concatenate the images returned by `crop_cluster_ims` for cluster `ucid`."""
    rows = np.flatnonzero(ucids == ucid)
    ims = ROI_ims_sparse[rows].toarray().reshape(len(rows), *frame_shape)
    return np.concatenate(list(roicat.visualization.crop_cluster_ims(ims)), axis=1)

ucid_sfCat = [_montage_for_cluster(ucid) for ucid in ucids_unique]

## Display only the first 10 clusters.
for montage in ucid_sfCat[:10]:
    plt.figure(figsize=(40,1))
    plt.imshow(montage, cmap='gray')
    plt.axis('off')

In [None]:
## Histogram of the number of distinct sessions each tracked cluster appears in.
## For an interactive figure, run `%matplotlib widget` (requires ipympl) first.
labels_bySession_hist = [np.asarray(lab) for lab in results['clusters']['labels_bySession']]
n_sessions = len(labels_bySession_hist)

## Negative labels are not clusters (elsewhere this notebook also keeps only
## labels >= 0). Counting them would add a spurious "cluster" made of every
## unassigned ROI.
ucids_present_bySession = [np.unique(lab[lab >= 0]) for lab in labels_bySession_hist]

## counts[k]: number of sessions in which the k-th cluster has at least one ROI.
## Counting sessions rather than ROIs means a cluster with two ROIs in one
## session still counts that session once.
_, counts = np.unique(np.concatenate(ucids_present_bySession), return_counts=True)

plt.figure()
## One unit-wide bin centered on each integer 1..n_sessions.
plt.hist(counts, bins=np.arange(0.5, n_sessions + 1.5, 1));
plt.xlabel('number of sessions a cluster is present in');
plt.ylabel('cluster counts');