# Intro

Welcome to the interactive 'classification by drawing' notebook!
This notebook walks through each step and lets you tune parameters and see how they change the results.

The notebook proceeds as follows:

1. **Import** libraries
2. Define **paths** to data
3. Run data through the **pipeline** (ROInet embedding + UMAP)
4. **Draw** a circle around a region of the UMAP embedding to select the 'good ROIs to keep'
5. **Visualize** results
6. **Save** results


# Import libraries

Widen the notebook

In [1]:
# widen jupyter notebook window
from IPython.display import display, HTML
display(HTML("<style>.container {width:95% !important; }</style>"))
display(HTML("<style>:root { --jp-notebook-max-width: 100% !important; }</style>"))

Import basic libraries

In [2]:
from pathlib import Path
import tempfile

import numpy as np
from umap import UMAP

Import `roicat`

In [3]:
import roicat

# Find paths to data

##### Prepare list of paths to data

In this example we use suite2p output files, but other data types (CaImAn, etc.) can also be used. \
See the notebook on ingesting diverse data: https://github.com/RichieHakim/ROICaT/blob/main/notebooks/jupyter/other/demo_custom_data_importing.ipynb

Make a list containing the paths to all the input files.

In this example we are using suite2p, so the following are defined:
1. `paths_allStat`: a list of paths to all the stat.npy files
2. `paths_allOps`: a list of paths to the ops.npy files, corresponding 1-to-1 with the stat.npy files

In [4]:
from suite2p_paths import load_plane_paths  ## project-specific helper (not part of roicat)

dir_allOuterFolders = "/data"
planes = load_plane_paths(dir_allOuterFolders)  ## {plane_index: [{'plane_name', 'stat_path', 'ops_path', ...}, ...]}
print(planes)

## Flatten the per-plane lists into flat lists of paths (stat and ops stay paired 1-to-1)
paths_allOps = [session['ops_path'] for sessions in planes.values() for session in sessions]
paths_allStat = [session['stat_path'] for sessions in planes.values() for session in sessions]

print('paths to all stat files:')
for path in paths_allStat:
    print(path)
print('')
print('paths to all ops files:')
for path in paths_allOps:
    print(path)


{0: [{'plane_name': '1347888024', 'projection_image_path': '/data/multiplane-ophys_719363_2024-04-25_09-17-21_processed_2024-07-04_13-46-20/1347888024/motion_correction/1347888024_maximum_projection.png', 'stat_path': '/data/multiplane-ophys_719363_2024-04-25_09-17-21_processed_2024-07-04_13-46-20/1347888024/segmentation/suite2p/plane0/stat.npy', 'ops_path': '/data/multiplane-ophys_719363_2024-04-25_09-17-21_processed_2024-07-04_13-46-20/1347888024/segmentation/suite2p/plane0/ops.npy'}, {'plane_name': '1369843881', 'projection_image_path': '/data/multiplane-ophys_719363_2024-05-30_08-57-21_processed_2024-07-05_20-18-01/1369843881/motion_correction/1369843881_maximum_projection.png', 'stat_path': '/data/multiplane-ophys_719363_2024-05-30_08-57-21_processed_2024-07-05_20-18-01/1369843881/segmentation/suite2p/plane0/stat.npy', 'ops_path': '/data/multiplane-ophys_719363_2024-05-30_08-57-21_processed_2024-07-05_20-18-01/1369843881/segmentation/suite2p/plane0/ops.npy'}], 1: [{'plane_name': '
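
If the project-specific `suite2p_paths` helper is not available, the two path lists can be built directly with `pathlib`. The sketch below assumes the standard suite2p layout, where each `.../suite2p/planeN/` folder holds a `stat.npy` and its matching `ops.npy`; `find_suite2p_paths` is a hypothetical helper, not part of roicat.

```python
import tempfile
from pathlib import Path


def find_suite2p_paths(dir_outer):
    """Hypothetical helper: recursively find suite2p stat.npy files under dir_outer
    and pair each one with the ops.npy file in the same folder."""
    paths_stat = sorted(
        p for p in Path(dir_outer).rglob('stat.npy')
        if p.parent.name != 'combined'  # skip suite2p's merged multi-plane output, if present
    )
    paths_ops = [p.parent / 'ops.npy' for p in paths_stat]
    missing = [str(p) for p in paths_ops if not p.exists()]
    if missing:
        raise FileNotFoundError(f'ops.npy missing for: {missing}')
    return [str(p) for p in paths_stat], [str(p) for p in paths_ops]


# Demo on a throwaway folder tree (hypothetical session names)
with tempfile.TemporaryDirectory() as dir_demo:
    for session in ['session_a', 'session_b']:
        dir_plane = Path(dir_demo) / session / 'suite2p' / 'plane0'
        dir_plane.mkdir(parents=True)
        (dir_plane / 'stat.npy').touch()
        (dir_plane / 'ops.npy').touch()
    demo_stat, demo_ops = find_suite2p_paths(dir_demo)
    n_found = len(demo_stat)
    pairs_match = all(Path(s).parent == Path(o).parent for s, o in zip(demo_stat, demo_ops))
    first_is_a = 'session_a' in demo_stat[0]
```

Sorting gives a deterministic order, but it may differ from the order produced by `load_plane_paths`. Since each stat file becomes one session, check the printed order before continuing.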

**Important parameters**:

- `um_per_pixel` (float):
    - Resolution of the imaging field of view, in micrometers per pixel.
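
`um_per_pixel` is the physical width of the field of view divided by its width in pixels. A minimal sketch using a hypothetical helper and made-up FOV numbers (check your own microscope settings for the real values):

```python
def compute_um_per_pixel(fov_width_um: float, fov_width_px: int) -> float:
    """Micrometers per pixel = physical width of the field of view / its width in pixels."""
    if fov_width_px <= 0:
        raise ValueError('fov_width_px must be a positive number of pixels')
    return fov_width_um / fov_width_px


# Hypothetical numbers: a 400 um wide field of view recorded as 512 pixels wide
um_per_pixel_example = compute_um_per_pixel(fov_width_um=400.0, fov_width_px=512)
print(um_per_pixel_example)  # 0.78125
```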

In [5]:
data = roicat.data_importing.Data_suite2p(
    paths_statFiles=paths_allStat,
    paths_opsFiles=paths_allOps,
    um_per_pixel=0.78,  
    new_or_old_suite2p='new',
    type_meanImg='meanImg',
    verbose=True,
)

assert data.check_completeness(verbose=False)['classification_inference'], "Data object is missing attributes necessary for classification."

Starting: Importing FOV images from ops files
Completed: Set FOV_height and FOV_width successfully.
Completed: Imported 4 FOV images.
Completed: Set FOV_images for 4 sessions successfully.
Importing spatial footprints from stat files.


100%|██████████| 4/4 [00:00<00:00, 11.53it/s]


Imported 4 sessions of spatial footprints into sparse arrays.
Completed: Set spatialFootprints for 4 sessions successfully.
Completed: Created session_bool.
Completed: Created centroids.
Staring: Creating centered ROI images from spatial footprints...
Completed: Created ROI images.


# ROInet embedding

This step passes the image of each ROI through the ROInet neural network. The inputs are the ROI images; the output is an array of latent features (one row per ROI) describing the visual properties of each ROI.

##### 1. Initialize ROInet

Initialize the ROInet object. The `ROInet_embedder` class will automatically download and load a pretrained ROInet model. If you have a GPU, this step will be much faster.

In [6]:
DEVICE = roicat.helpers.set_device(use_GPU=True, verbose=True)
dir_temp = tempfile.gettempdir()

roinet = roicat.ROInet.ROInet_embedder(
    device=DEVICE,  ## Which torch device to use ('cpu', 'cuda', etc.)
    dir_networkFiles=dir_temp,  ## Directory to download the pretrained network to
    download_method='check_local_first',  ## Check to see if a model has already been downloaded to the location (will skip if hash matches)
    download_url='https://osf.io/c8m3b/download',  ## URL of the model
    download_hash='357a8d9b630ec79f3e015d0056a4c2d5',  ## Hash of the model file
    forward_pass_version='head',  ## How the data is passed through the network
    verbose=True,  ## Whether to print updates
)

roinet.generate_dataloader(
    ROI_images=data.ROI_images,  ## Input images of ROIs
    um_per_pixel=data.um_per_pixel,  ## Resolution of FOV
    pref_plot=False,  ## Whether or not to plot the ROI sizes
);

Using device: cuda:0
File already exists locally: /tmp/ROInet.zip
Hash of local file matches provided hash_hex.
Extracting /tmp/ROInet.zip to /tmp.
Completed zip extraction.
Imported model from /tmp/ROInet_classification_20220902/model.py
Loaded params_model from /tmp/ROInet_classification_20220902/params.json




Generated network using params_model
Loaded state_dict into network from /tmp/ROInet_classification_20220902/ConvNext_tiny__1_0_unfrozen__simCLR.pth
Loaded network onto device cuda:0
Starting: resizing ROIs
Completed: resizing ROIs
Defined image transformations: Sequential(
  (0): ScaleDynamicRange(scaler_bounds=(0, 1))
  (1): Resize(size=(224, 224), interpolation=bilinear, max_size=None, antialias=True)
  (2): TileChannels(dim=0)
)
Defined dataset
Defined dataloader
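
The log above lists the image transformations applied before ROIs enter the network: `ScaleDynamicRange(scaler_bounds=(0, 1))`, `Resize(size=(224, 224))` and `TileChannels(dim=0)`. The numpy sketch below approximates them for a single image under stated assumptions: per-image min-max scaling, nearest-neighbour resizing in place of the logged bilinear interpolation, and 3 output channels (the channel count is an assumption). It is an illustration, not roicat's code.

```python
import numpy as np


def scale_dynamic_range(img, bounds=(0.0, 1.0)):
    """Min-max rescale one image into [bounds[0], bounds[1]] (assumed per-image)."""
    img = np.asarray(img, dtype=np.float32)
    lo, hi = img.min(), img.max()
    if hi == lo:  # flat image: avoid division by zero
        return np.full_like(img, bounds[0])
    return bounds[0] + (img - lo) / (hi - lo) * (bounds[1] - bounds[0])


def resize_nearest(img, size=(224, 224)):
    """Nearest-neighbour resize of a 2-D image (the logged transform is bilinear)."""
    h, w = img.shape
    rows = (np.arange(size[0]) * h / size[0]).astype(int)
    cols = (np.arange(size[1]) * w / size[1]).astype(int)
    return img[np.ix_(rows, cols)]


def tile_channels(img, n_channels=3):
    """Repeat a single-channel image along a new leading channel axis."""
    return np.repeat(img[None, ...], n_channels, axis=0)


roi_demo = np.random.default_rng(0).random((36, 36))  # stand-in ROI image, size arbitrary
x_demo = tile_channels(resize_nearest(scale_dynamic_range(roi_demo)))
print(x_demo.shape)  # (3, 224, 224)
```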


##### 2. Check ROI_images sizes
In general, you want to see that a neuron fills roughly 25-50% of the area of the image. \
**Adjust `um_per_pixel` above to rescale image size**

In [7]:
roicat.visualization.display_toggle_image_stack(roinet.ROI_images_rs[:1000], image_size=(200,200))
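
The 25-50% guideline can also be checked numerically. The sketch below uses a rough heuristic of our own (not part of roicat): count the pixels brighter than a fraction of each image's peak. The `rel_threshold=0.2` default is an arbitrary assumption.

```python
import numpy as np


def fraction_area_filled(images, rel_threshold=0.2):
    """For each image in an (n, H, W) stack, return the fraction of pixels whose
    value exceeds rel_threshold * that image's maximum."""
    images = np.asarray(images, dtype=np.float64)
    maxima = images.reshape(len(images), -1).max(axis=1)
    mask = images > rel_threshold * maxima[:, None, None]
    return mask.mean(axis=(1, 2))


# Stand-in stack: one 36x36 image with an 18x18 bright square (25% of the area)
demo_stack = np.zeros((1, 36, 36))
demo_stack[0, 9:27, 9:27] = 1.0
print(fraction_area_filled(demo_stack))  # [0.25]
```

Assuming `roinet.ROI_images_rs` is an `(n_ROIs, height, width)` array, `np.median(fraction_area_filled(roinet.ROI_images_rs))` gives one number to compare against the guideline while adjusting `um_per_pixel`.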

##### 3. Pass data through network

Pass the data through the network. For large datasets (~40,000 ROIs), expect this to take around 15 minutes on a CPU or 1 minute on a GPU.

In [8]:
roinet.generate_latents();

starting: running data through network


100%|██████████| 124/124 [00:04<00:00, 30.61it/s]


completed: running data through network


# UMAP embedding

Reduce the dimensionality of the ROInet output (~100 dimensions) to 2 dimensions so that it can be visualized. Feel free to change the UMAP settings to whatever clusters your data well.

In [9]:
umap = UMAP(
    n_neighbors=25,
    n_components=2,
    n_epochs=400,
    verbose=True,
    densmap=False,
)
emb = umap.fit_transform(roinet.latents)

UMAP(n_epochs=400, n_neighbors=25, verbose=True)
Wed Aug 14 18:21:19 2024 Construct fuzzy simplicial set
Wed Aug 14 18:21:20 2024 Finding Nearest Neighbors
Wed Aug 14 18:21:22 2024 Finished Nearest Neighbor Search
Wed Aug 14 18:21:24 2024 Construct embedding


Epochs completed:   0%|            0/400 [00:00]

	completed  0  /  400 epochs
	completed  40  /  400 epochs
	completed  80  /  400 epochs
	completed  120  /  400 epochs
	completed  160  /  400 epochs
	completed  200  /  400 epochs
	completed  240  /  400 epochs
	completed  280  /  400 epochs
	completed  320  /  400 epochs
	completed  360  /  400 epochs
Wed Aug 14 18:21:26 2024 Finished embedding


# Draw selection

To see what kinds of ROIs occupy each region of the plot, we first select a spread-out subset of points onto which ROI images will be overlaid.

In [10]:
idx_images_overlay = roicat.visualization.get_spread_out_points(
    emb,
    n_ims=min(emb.shape[0], 1500),  ## Select number of overlayed images here
    dist_im_to_point=0.8,
#     border_frac=0.05,
)

images_overlay = roinet.ROI_images_rs[idx_images_overlay]
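
`get_spread_out_points` chooses which points receive image overlays; its exact algorithm isn't described here. One simple way to get a similar effect is a greedy pass that keeps a point only if it is at least `min_dist` away from every point kept so far. `pick_spread_out_points` is hypothetical; its `max_points` argument plays a role similar to `n_ims`.

```python
import numpy as np


def pick_spread_out_points(points, min_dist, max_points, seed=0):
    """Hypothetical helper: greedily pick up to max_points indices such that every
    pair of picked points is at least min_dist apart. Candidates are visited in a
    random order so the picks are spread over the whole embedding."""
    points = np.asarray(points, dtype=np.float64)
    order = np.random.default_rng(seed).permutation(len(points))
    chosen = []
    for idx in order:
        if len(chosen) >= max_points:
            break
        if chosen:
            dists = np.linalg.norm(points[chosen] - points[idx], axis=1)
            if dists.min() < min_dist:
                continue  # too close to an already-picked point
        chosen.append(idx)
    return np.array(chosen, dtype=np.int64)


# Usage on a random stand-in for the 2-D UMAP embedding
emb_demo = np.random.default_rng(1).normal(size=(500, 2))
idx_demo = pick_spread_out_points(emb_demo, min_dist=0.5, max_points=50)
```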

Now we can use an interactive plot (built with the holoviews library) to circle a region of the scatter plot.\
This plot works as follows:
- Use the **LASSO TOOL** to circle a region of the plot containing the images of ROIs that you'd like to keep/extract/mark.
    - You can circle multiple times, but only the last selection will be saved.
- The selected indices are saved to a temporary file and can be retrieved with the `fn_get_indices` function returned below: calling `fn_get_indices()` returns a list of integer indices.
- If it is difficult to see the images, try the following:
    - adjust the number of images in the above function (`roicat.visualization.get_spread_out_points`) using the `n_ims` argument
    - adjust the overlap of the images in the function below (`roicat.visualization.select_region_scatterPlot`) using the `frac_overlap_allowed` argument

In [11]:
fn_get_indices, layout, path_tempFile = roicat.visualization.select_region_scatterPlot(
    data=emb,
    idx_images_overlay=idx_images_overlay,
    images_overlay=images_overlay[:, 6:30, 6:30],  ## crop rows and columns 6-29 of each ROI image
    size_images_overlay=0.35,
    frac_overlap_allowed=0.5,
    figsize=(800,800),
    alpha_points=1.0,
    size_points=10,
    color_points='b',
);
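
Conceptually, the lasso selection returns the indices of the embedding points that fall inside the drawn outline. The sketch below illustrates that with an even-odd (ray casting) point-in-polygon test in numpy; `points_in_polygon` is a hypothetical helper, not how holoviews or roicat implement the tool.

```python
import numpy as np


def points_in_polygon(points, polygon):
    """Hypothetical helper: return indices of 2-D points inside a closed polygon,
    using the even-odd rule (count crossings of a horizontal ray to the right)."""
    points = np.asarray(points, dtype=np.float64)
    polygon = np.asarray(polygon, dtype=np.float64)
    x, y = points[:, 0], points[:, 1]
    inside = np.zeros(len(points), dtype=bool)
    n_vertices = len(polygon)
    for i in range(n_vertices):
        x1, y1 = polygon[i]
        x2, y2 = polygon[(i + 1) % n_vertices]  # wrap around to close the polygon
        straddles = (y1 > y) != (y2 > y)  # edge spans the point's y-coordinate
        with np.errstate(divide='ignore', invalid='ignore'):
            x_cross = x1 + (y - y1) * (x2 - x1) / (y2 - y1)
        inside ^= straddles & (x < x_cross)  # toggle on each crossing to the right
    return np.flatnonzero(inside)


# Usage: a unit-square "lasso" and four test points
lasso = [(0, 0), (1, 0), (1, 1), (0, 1)]
pts = np.array([[0.5, 0.5], [2.0, 2.0], [0.1, 0.9], [-0.5, 0.5]])
idx_inside = points_in_polygon(pts, lasso)
```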

In [None]:
## Optional: export the plot as an SVG file (set your own output path first)
# roicat.helpers.export_svg_hv_bokeh(layout, '/home/rich/Desktop/umap_with_labels_dotsOnly.svg')

Collect the results into easier-to-use output variables:

In [27]:
n_sessions = len(data.ROI_images)
## Session index of every ROI, concatenated across sessions
idx_session_cat = np.concatenate([[ii]*data.ROI_images[ii].shape[0] for ii in range(n_sessions)])
## Boolean mask over all ROIs: True for the ROIs circled with the lasso tool
bool_good_cat = roicat.helpers.idx2bool(fn_get_indices(), length=len(idx_session_cat))
## Split the mask back into per-session integer labels (1 = good, 0 = bad)
preds_good_sessions = [bool_good_cat[idx_session_cat == ii].astype(np.int64) for ii in range(n_sessions)]

classification_output = {
    'preds': preds_good_sessions,
    'spatialFootprints': data.spatialFootprints,
    'FOV_height': data.FOV_height,
    'FOV_width': data.FOV_width,
    'input_data': data.paths_stat,
}
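
The cell above turns the flat list of selected indices into one 0/1 array per session. The same bookkeeping can be written with `np.split`; a minimal sketch with a hypothetical helper, assuming the selected indices refer to ROIs concatenated session by session in the order of `data.ROI_images`:

```python
import numpy as np


def split_selection_by_session(selected_indices, n_rois_per_session):
    """Hypothetical helper: turn indices into the concatenated ROI list into
    per-session int64 label arrays (1 = selected, 0 = not selected)."""
    n_total = int(np.sum(n_rois_per_session))
    is_good = np.zeros(n_total, dtype=bool)
    is_good[np.asarray(selected_indices, dtype=np.int64)] = True
    boundaries = np.cumsum(n_rois_per_session)[:-1]  # where each new session starts
    return [chunk.astype(np.int64) for chunk in np.split(is_good, boundaries)]


# Usage: 2 sessions with 3 and 2 ROIs; ROIs 0, 2 and 4 were circled
preds_demo = split_selection_by_session([0, 2, 4], n_rois_per_session=[3, 2])
```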

# Visualize outputs

In [28]:
print("Number of 'good' and 'bad' ROIs from each session:")
print([f"good: {p.sum()} / bad: {(p != 1).sum()}" for p in preds_good_sessions])

Number of 'good' and 'bad' ROIs from each session:
['good: 320 / bad: 14', 'good: 290 / bad: 18', 'good: 181 / bad: 15', 'good: 141 / bad: 8']


In [29]:
FOVs_colored = roicat.visualization.compute_colored_FOV(
    spatialFootprints=data.spatialFootprints,
    FOV_height=data.FOV_height,
    FOV_width=data.FOV_width,
    labels=preds_good_sessions,
    cmap=roicat.helpers.simple_cmap([[1,0,0],[0,1,0]]),  ## red = bad (0), green = good (1)
)

roicat.visualization.display_toggle_image_stack(
    FOVs_colored,
    image_size=(FOVs_colored[0].shape[0]*2, FOVs_colored[0].shape[1]*2)
)
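
`compute_colored_FOV` draws each ROI in the color of its label. A simplified numpy sketch of the idea, not roicat's implementation: it assumes dense footprints of shape `(n_ROIs, height, width)` (the imported footprints are sparse arrays and would need converting first), normalizes each footprint to a peak of 1, and sums the colored footprints into one RGB image.

```python
import numpy as np


def color_fov(footprints, labels, colors=((1, 0, 0), (0, 1, 0))):
    """Hypothetical helper: overlay ROI footprints on one RGB image, colored by label
    (label 0 -> colors[0], label 1 -> colors[1], ...)."""
    footprints = np.asarray(footprints, dtype=np.float64)
    maxima = footprints.reshape(len(footprints), -1).max(axis=1)
    maxima[maxima == 0] = 1.0  # leave empty footprints as zeros
    normed = footprints / maxima[:, None, None]
    roi_colors = np.asarray(colors, dtype=np.float64)[np.asarray(labels)]
    rgb = np.einsum('nhw,nc->hwc', normed, roi_colors)  # sum of colored footprints
    return np.clip(rgb, 0.0, 1.0)


# Usage: two single-pixel ROIs on a 4x4 FOV, one 'bad' (0) and one 'good' (1)
footprints_demo = np.zeros((2, 4, 4))
footprints_demo[0, 0, 0] = 1.0
footprints_demo[1, 3, 3] = 0.5
rgb_demo = color_fov(footprints_demo, labels=[0, 1])
```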

# Save results

The results file can be opened using any of the following methods:
1. `roicat.helpers.pickle_load(path)`
2. `np.load(path, allow_pickle=True)`
3. With `pickle` directly:
   ```
   import pickle
   with open(path, mode='rb') as f:
       results = pickle.load(f)
   ```
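
Once loaded, the `preds` entry holds one array of 0/1 labels per session. A minimal sketch of pulling out the indices of the 'good' ROIs, using a hypothetical helper and a small stand-in file rather than real results:

```python
import pickle
import tempfile
from pathlib import Path

import numpy as np


def load_kept_roi_indices(path):
    """Hypothetical helper: load a saved results file and return, for each session,
    the indices of ROIs labeled 'good' (pred == 1)."""
    with open(path, mode='rb') as f:
        results = pickle.load(f)
    return [np.flatnonzero(np.asarray(preds) == 1) for preds in results['preds']]


# Demo with a small stand-in file (made-up labels, not real results)
with tempfile.TemporaryDirectory() as dir_demo:
    path_demo = Path(dir_demo) / 'demo.ROICaT.classification_drawn.results.pkl'
    with open(path_demo, mode='wb') as f:
        pickle.dump({'preds': [np.array([1, 0, 1]), np.array([0, 0, 1, 1])]}, f)
    kept_demo = load_kept_roi_indices(path_demo)
```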

In [30]:
mouse = '719363'

In [31]:
dir_save = f'/scratch/{mouse}/'
filename_save = mouse

path_save = str(Path(dir_save).resolve() / (filename_save + '.ROICaT.classification_drawn.results' + '.pkl'))
print(f'path_save: {path_save}')

roicat.helpers.pickle_save(classification_output, path_save, mkdir=True)

path_save: /scratch/719363/719363.ROICaT.classification_drawn.results.pkl


# Thank you
If you encountered any difficulties, please let us know on the issues page: https://github.com/RichieHakim/ROICaT/issues