# ROICaT Classification by Drawing Colab Notebook

Welcome to the interactive ROI classification notebook! This Colab notebook allows you to easily classify ROIs in your calcium imaging data.

This notebook currently works with **Suite2p output files** (`stat.npy` and `ops.npy`).

We recommend running this notebook on a GPU. Check your runtime settings:
_Runtime -> Change runtime type -> Hardware accelerator -> GPU_

**The notebook proceeds as follows:**
1. **Import** libraries
2. Define **paths** to data
3. Run data through the **pipeline**
4. Draw to select **good ROIs**
5. **Visualize** results
6. **Save** results

An FAQ on how to use ROICaT is [here](https://roicat.readthedocs.io/en/dev/).
If you have any questions not covered in the FAQ, please open a new issue on the [ROICaT GitHub page](https://github.com/RichieHakim/ROICaT/issues).

# SETUP

In [None]:
#@title Install ROICaT
#@markdown Please execute this cell by pressing the _Play_ button on the left. This process will take less than 5 min.

#@markdown After running this cell, the kernel will **automatically be killed**. This is on purpose: the runtime must be restarted to remove tensorflow.

# Install ROICaT
!pip uninstall -y tensorflow # Uninstall default tensorflow to avoid any potential conflict
!pip install --user "roicat[classification] @ git+https://github.com/RichieHakim/ROICaT.git@dev"

display("Restart runtime!")
import os
os._exit(0)

In [None]:
#@title Import libraries

## standard libraries
import os
import zipfile
from pathlib import Path

## other libraries
import numpy as np
from umap import UMAP
from ipywidgets import widgets, Button, Output
from IPython.display import display

## roicat
import roicat


In [None]:
#@title Load helper functions

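# Checkbox widget
def checkbox_widget(input_list):
  # Build one unchecked checkbox per entry in input_list
  checkboxes = []
  for label in input_list:
    checkbox = widgets.Checkbox(description=label, value=False)
    checkboxes.append(checkbox)

  # Mutable list that the button callback fills with the checked indices
  selected_indices = []
  return checkboxes, selected_indices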

def get_selected_indices(checkboxes):
  indices = [i for i, checkbox in enumerate(checkboxes) if checkbox.value]
  return indices

def make_checkbox_button_clicked(checkboxes, selected_indices):
  def checkbox_button_clicked(_):
    # _ is a placeholder argument to handle button click event
    # Clear the list
    selected_indices.clear()
    # Add the newly selected indices
    selected_indices.extend(get_selected_indices(checkboxes))
  return checkbox_button_clicked

DEVICE = roicat.helpers.set_device(use_GPU=True, verbose=True)

# Import paths

#### OPTION 1: Mount google drive (RECOMMENDED)

In [None]:
#@markdown Upload your data to Google Drive, then mount the drive and access the cloud directory here.

#@markdown You can use the sidebar to the left to browse your google drive directories.

from google.colab import drive
path_gdrive = '/content/gdrive'
drive.mount(path_gdrive, force_remount=True)

In [None]:
#@markdown ### Enter your google drive directory containing suite2p files:
dir_s2p = '/content/gdrive/MyDrive/Colab_Notebooks/ROICaT_notebooks/statFiles/' #@param {type:"string"}

#### OPTION 2: Upload files from your local computer (slower)


In [None]:
#@markdown Upload files from your local computer.

#@markdown a) For a **single session**: Upload a single stat.npy and ops.npy file.

#@markdown b) For a **nested** folder structure from multiple sessions: Create and upload a **zip file** containing all the folders. The stat.npy and ops.npy files from the same session should be in the same folder, separate from other sessions.


from google.colab import files
uploaded = files.upload()

for file_name in list(uploaded.keys()):
  file_path = Path.cwd() / file_name
  if file_path.exists():
    if zipfile.is_zipfile(file_path):
      print("Zip file detected: ", file_path)
      with zipfile.ZipFile(file_path, 'r') as zip_ref:
        zip_ref.extractall(Path.cwd())
        print("Zip file extracted to: ", Path.cwd())
    else:
      print("Uploaded file is not a zip file: ", file_path)
  else:
    print("Failed to detect uploaded file path: ", file_path)

dir_s2p = os.getcwd()

### Select sessions to classify

In [None]:
#@markdown **Automatically find files by name:**

filename_statFile = 'stat.npy' #@param {type:"string"}
filename_opsFile = 'ops.npy' #@param {type:"string"}

paths_allStat = roicat.helpers.find_paths(dir_outer=dir_s2p, reMatch=filename_statFile, depth=8)
paths_allOps = roicat.helpers.find_paths(dir_outer=dir_s2p, reMatch=filename_opsFile, depth=8)
print('paths of stat files')
display(paths_allStat)
print('paths of ops files')
display(paths_allOps)

stat_list = ['/'.join(Path(statfile).parts[-3:]) for statfile in paths_allStat]
stat_checkboxes, selected_indices = checkbox_widget(stat_list)
checkbox_display = widgets.VBox(stat_checkboxes)
display(checkbox_display)
button = Button(description="Get Stat Files")
button.on_click(make_checkbox_button_clicked(stat_checkboxes, selected_indices))

display(button)

In [None]:
#@markdown Selected sessions to classify (tick the checkboxes and click the **Get Stat Files** button above before running this cell):

stat_selected_indices = ops_selected_indices = selected_indices

paths_stat = [paths_allStat[i] for i in stat_selected_indices]
print("Stat files")
display(paths_stat)
paths_ops = [paths_allOps[i] for i in ops_selected_indices]
print("Ops files")
display(paths_ops)
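
The cell above applies the same selected indices to both the stat and ops path lists, so it relies on the two lists being in the same order. A minimal sketch of a sanity check, assuming each session's `stat.npy` and `ops.npy` sit in the same folder as described earlier (the helper `check_same_session` and the example paths are hypothetical, not part of ROICaT):

```python
from pathlib import Path

def check_same_session(paths_stat, paths_ops):
    """Return True if every stat/ops pair shares a parent folder (hypothetical helper)."""
    if len(paths_stat) != len(paths_ops):
        return False
    return all(Path(s).parent == Path(o).parent for s, o in zip(paths_stat, paths_ops))

# Hypothetical example paths, for illustration only
example_stat = ['/data/mouse1/day1/stat.npy', '/data/mouse1/day2/stat.npy']
example_ops = ['/data/mouse1/day1/ops.npy', '/data/mouse1/day2/ops.npy']

print(check_same_session(example_stat, example_ops))        # pairs line up
print(check_same_session(example_stat, example_ops[::-1]))  # pairs are swapped
```

In the notebook, the same check could be run as `check_same_session(paths_stat, paths_ops)` before loading the data.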

# Import Data

In [None]:
#@markdown ### Enter micrometers per pixel of the imaging FOV (a rough estimate is okay, just make sure the resized images below fit the frame well):
um_per_pixel = 2.5 #@param {type:"number"}
#@markdown **Default: 2.5**, _type: float_
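
If the physical width of the field of view is known, micrometers per pixel is that width divided by the image width in pixels. A small sketch of the arithmetic; the numbers are hypothetical and should be replaced with values from your own microscope:

```python
# Hypothetical numbers: a 1280 µm wide field of view imaged at 512 pixels across
fov_width_um = 1280.0
fov_width_pixels = 512

um_per_pixel_example = fov_width_um / fov_width_pixels
print(um_per_pixel_example)  # 2.5
```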


In [None]:
#@markdown Load ROIs...

data = roicat.data_importing.Data_suite2p(
    paths_statFiles=paths_stat,
    paths_opsFiles=paths_ops,
    um_per_pixel=um_per_pixel,
    new_or_old_suite2p='new',
    type_meanImg='meanImgE',
    verbose=True,
)

assert data.check_completeness(verbose=False)['classification_inference'], "Data object is missing attributes necessary for classification inference."

# ROInet embedding

In [None]:
#@markdown Initialize the ROInet object. The ROInet_embedder class will automatically download and load a pretrained ROInet model for classification.

#@markdown If you have a GPU, this step will be much faster.
roinet = roicat.ROInet.ROInet_embedder(
    device=DEVICE,  ## Which torch device to use ('cpu', 'cuda', etc.)
    dir_networkFiles=os.getcwd(),  ## Directory to download the pretrained network to
    download_method='check_local_first',  ## Check to see if a model has already been downloaded to the location (will skip if hash matches)
    download_url='https://osf.io/c8m3b/download',  ## URL of the model
    download_hash='357a8d9b630ec79f3e015d0056a4c2d5',  ## Hash of the model file
    forward_pass_version='latent',  ## How the data is passed through the network
    verbose=True,  ## Whether to print updates
)

roinet.generate_dataloader(
    ROI_images=data.ROI_images,  ## Input images of ROIs
    um_per_pixel=data.um_per_pixel,  ## Resolution of FOV
    pref_plot=False,  ## Whether or not to plot the ROI sizes
);

In [None]:
#@markdown Visualize ROI images
%matplotlib notebook
roicat.visualization.display_toggle_image_stack(roinet.ROI_images_rs[:1000], image_size=(200,200))

In [None]:
#@markdown Pass data through ROInet
roinet.generate_latents();

# Draw Selection

In [None]:
#@markdown Prepare UMAP
umap = UMAP(
    n_neighbors=2,
    n_components=2,
    n_epochs=200,
    verbose=True,
    densmap=True,
)
emb = umap.fit_transform(roinet.latents)

#### DRAW
Now we can use an interactive plot (built with the holoviews library) to circle a region of the scatterplot.\
This plot works as follows:
- Use the **LASSO TOOL** to circle a region on the plot containing the images of ROIs that you'd like to keep/extract/mark.
    - You can circle multiple times, but only the last selection is saved
- The selected indices are saved in a temporary file and can be recovered with the `fn_get_indices` function returned by the cell below. Call `fn_get_indices()` to get a list of the integer indices.
- If it is difficult to see the images, do the following:
    - adjust the number of images in the function below (`roicat.visualization.get_spread_out_points`) using the `n_ims` argument
    - adjust the overlap of the images in the below function (`roicat.visualization.select_region_scatterPlot`) using the `frac_overlap_allowed` argument
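
The indices returned by `fn_get_indices()` refer to the ROIs of all sessions concatenated together, and the classification cell later in this notebook splits them back into one label array per session. A minimal NumPy-only sketch of that bookkeeping, using hypothetical session sizes and a hypothetical selection in place of real data:

```python
import numpy as np

# Hypothetical stand-ins: 3 ROIs in session 0 and 2 ROIs in session 1
n_rois_per_session = [3, 2]
# Hypothetical lasso result: flat indices into the concatenated ROI list
idx_selected = [0, 2, 4]

# Session ID of every ROI in the concatenated list: [0, 0, 0, 1, 1]
idx_session_cat = np.concatenate(
    [np.full(n, ii) for ii, n in enumerate(n_rois_per_session)]
)

# Boolean mask over all ROIs: True where the ROI was circled
bool_selected = np.zeros(len(idx_session_cat), dtype=bool)
bool_selected[idx_selected] = True

# Split the flat mask back into one 0/1 label array per session
labels_per_session = [
    bool_selected[idx_session_cat == ii].astype(np.int64)
    for ii in range(len(n_rois_per_session))
]
print(labels_per_session)
```

Here a label of 1 marks a circled ROI and 0 marks everything else, which matches the 'good' / 'bad' convention used when the results are visualized.
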

In [None]:
#@markdown Drawing GUI
plot_size = (1200, 1200) #@param {type:"raw"}
#@markdown Plot size in pixels. **Default: (1200, 1200)**, _type: tuple_
n_images = 2000 #@param {type:"raw"}
#@markdown Number of images. **Default: 2000**, _type: int_
overlap = 0.6  #@param {type:"raw"}
#@markdown Overlap fraction between images on the scatterplot. Larger numbers mean bigger images.  **Default: 0.6**, _type: float_

idx_images_overlay = roicat.visualization.get_spread_out_points(
    emb,
    n_ims=min(emb.shape[0], n_images),
    dist_im_to_point=0.3,
    border_frac=0.05,
    device='cpu',
)

images_overlay = roinet.ROI_images_rs[idx_images_overlay]

# UMAP drawing plus ROI image
fn_get_indices, layout, path_tempFile = roicat.visualization.select_region_scatterPlot(
    data=emb,
    idx_images_overlay=idx_images_overlay,
    images_overlay=images_overlay,
    size_images_overlay=None,
    frac_overlap_allowed=overlap,
    figsize=plot_size,
);

In [None]:
#@markdown Visualize Classification results...
# Index good ROIs
n_sessions = len(data.ROI_images)
idx_session_cat = np.concatenate([[ii]*data.ROI_images[ii].shape[0] for ii in range(n_sessions)])
bool_good_cat = roicat.helpers.idx2bool(fn_get_indices(), length=len(idx_session_cat))
preds_good_sessions = [np.int64((bool_good_cat * (idx_session_cat==ii))[idx_session_cat==ii]) for ii in range(n_sessions)]

classification_output = {
    'preds': preds_good_sessions,
    'spatialFootprints': data.spatialFootprints,
    'FOV_height': data.FOV_height,
    'FOV_width': data.FOV_width,
}

print("Number of 'good' and 'bad' ROIs from each session:")
print([f"good: {p.sum()} / bad: {(p!=1).sum()}" for p in preds_good_sessions])

# Visualize
%matplotlib inline
FOVs_colored = roicat.visualization.compute_colored_FOV(
    spatialFootprints=data.spatialFootprints,
    FOV_height=data.FOV_height,
    FOV_width=data.FOV_width,
    labels=preds_good_sessions,
    cmap=roicat.helpers.simple_cmap([[1,0,0],[0,1,0]]),
)

roicat.visualization.display_toggle_image_stack(FOVs_colored)

# Save results

The results file can be opened using any of the following methods:
1. `roicat.helpers.pickle_load(path_save)`
2. `np.load(path_save, allow_pickle=True)`
3. The standard `pickle` module:
    ```
    import pickle
    with open(path_save, mode='rb') as f:
        results = pickle.load(f)
    ```
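
Once loaded, the results are a dictionary whose `'preds'` entry holds one array of 0/1 labels per session. The sketch below shows one way to pull out the indices of the 'good' ROIs. It writes and reads a small hypothetical dictionary with the standard `pickle` module so that it runs on its own; with real data, load the saved results file instead.

```python
import pickle
import tempfile
from pathlib import Path

import numpy as np

# Hypothetical stand-in for the saved results: two sessions with 0/1 labels per ROI
example_output = {'preds': [np.array([1, 0, 1]), np.array([0, 1])]}

# Write and re-read the file with the standard-library pickle module
path_example = Path(tempfile.mkdtemp()) / 'example.results.pkl'
with open(path_example, mode='wb') as f:
    pickle.dump(example_output, f)
with open(path_example, mode='rb') as f:
    results = pickle.load(f)

# Indices of the ROIs labeled 1 ('good') in each session
idx_good = [np.flatnonzero(p == 1) for p in results['preds']]
print([idx.tolist() for idx in idx_good])
```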

In [None]:
#@markdown Save result in temporary Colab directory...
import datetime
save_filename = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + '_ROICaT.classification_drawn.results' + '.pkl'
path_save = Path.cwd() / save_filename
print(f'Classification result on Colab cloud: {path_save}')

roicat.helpers.pickle_save(classification_output, path_save)

In [None]:
#@markdown Copy the saved results file to your Google Drive.
import shutil

# Mount the drive only if it has not been mounted earlier in this notebook
if ('path_gdrive' not in locals()) and ('path_gdrive' not in globals()):
  from google.colab import drive
  path_gdrive = '/content/gdrive'
  drive.mount(path_gdrive, force_remount=True)

# Copy the results file whether or not the drive was already mounted
copy_path = Path(path_gdrive) / 'MyDrive' / save_filename
shutil.copyfile(path_save, copy_path)
print(f'Classification result on your gdrive: {copy_path}')

In [None]:
#@markdown Download the results file to your local computer. This process might take several minutes.

from google.colab import files
files.download(str(path_save))

# Thank you
If you encountered any difficulties, please let us know at the issues page: https://github.com/RichieHakim/ROICaT/issues