# Intro

Welcome to the classifier training notebook!

Before running this notebook, prepare ROI images and their labels. You can generate them with the `labeling_interactive.ipynb` notebook or provide them manually.

The notebook proceeds as follows:
1. **Import** libraries
2. **Import data** (ROI images and labels)
3. **Embed** ROI images with ROInet
4. **Train** a classifier
5. **Visualize** results
6. **Save** results

# Imports

In [None]:
# widen jupyter notebook window
from IPython.display import display, HTML
display(HTML("<style>.container {width:95% !important; }</style>"))

In [16]:
from pathlib import Path
import tempfile
import functools

import numpy as np
import sklearn
import pandas as pd

In [1]:
import roicat

# Data Importing

## Option 1: Use Data Results from Labeling Interactive Outputs

##### 1. Specify filepath

In [67]:
filepath_labeling_results = r'/media/rich/bigSSD/data_tmp/test_data/mouse_1.labeling.results.richfile'

##### 2. Import data

Be sure to set `um_per_pixel` (the resolution of the FOV, in microns per pixel) to match your data.

In [None]:
labelingInteractive = roicat.util.RichFile_ROICaT(path=filepath_labeling_results).load()

data = roicat.data_importing.Data_roicat();
data.set_ROI_images(
    ROI_images=[labelingInteractive['images'][labelingInteractive['labels']['df'].index.to_numpy()]],
    um_per_pixel=1.0,
);
data.set_class_labels(
    [labelingInteractive['labels']['df']['label'].to_numpy()],
);

print('')
data.check_completeness()

Notice that `.check_completeness()` reports `classification_training` as `True`.
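
In a script, the same check can gate the rest of the pipeline. The sketch below assumes, as the later cells in this notebook do, that `check_completeness(verbose=False)` returns a dict of booleans keyed by task name. `require_complete` and the example dict (including its `tracking` key) are hypothetical stand-ins, not part of ROICaT.

```python
def require_complete(completeness, task):
    """Raise a descriptive error if `task` is not marked complete.

    `completeness` is assumed to be a dict of booleans keyed by task name,
    like the one returned by `data.check_completeness(verbose=False)`.
    """
    if not completeness.get(task, False):
        missing = sorted(k for k, v in completeness.items() if not v)
        raise ValueError(
            f"Data object is not ready for '{task}'. Incomplete tasks: {missing}"
        )
    return True


# Hypothetical result, standing in for data.check_completeness(verbose=False)
example = {
    'classification_inference': True,
    'classification_training': True,
    'tracking': False,
}
require_complete(example, 'classification_training')
```

Raising an error with the list of incomplete tasks tends to be easier to debug than a bare `assert`, which is also skipped when Python runs with `-O`.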

## Option 2: Make a new data object

See the [demo_data_importing](https://github.com/RichieHakim/ROICaT/blob/dev/notebooks/jupyter/other/demo_data_importing.ipynb) notebook to build a custom data object using any kind of data (Suite2p, CaImAn, etc.). It's really easy!

For example:
```
data = roicat.data_importing.Data_suite2p(
    paths_statFiles=['/path/to/stat.npy'],
    paths_opsFiles=['/path/to/ops.npy'],
    um_per_pixel=2.5,  
    new_or_old_suite2p='new',
    type_meanImg='meanImgE',
    verbose=True,
)
data.set_class_labels(path_labels=['/path/to/labels.npy'])
assert data.check_completeness(verbose=False)['classification_training'], "Data object is missing attributes necessary for classification training."
```

In [None]:
data = roicat.data_importing.Data_suite2p(
    paths_statFiles=['/media/rich/bigSSD/for_Josh/SimCLR-Label-Data/mouse2_6__20210409/stat.npy'],
    paths_opsFiles=['/media/rich/bigSSD/for_Josh/SimCLR-Label-Data/mouse2_6__20210409/ops.npy'],
    um_per_pixel=2.5,  
    new_or_old_suite2p='new',
    type_meanImg='meanImgE',
    verbose=True,
)

data.set_class_labels(path_labels=['/media/rich/bigSSD/for_Josh/SimCLR-Label-Data/mouse2_6__20210409/labels_round2_sesh1.npy'])

assert data.check_completeness(verbose=False)['classification_training'], "Data object is missing attributes necessary for classification training."

## Option 3: Use Data Results from Classify by Drawing Selection output


##### 1. Specify filepath

In [52]:
filepath_labelingDrawing = r'/media/rich/bigSSD/data_tmp/test_data/mouse_1.classification_drawn.run_data.richfile'

##### 2. Import data

In [None]:
labelingDrawing = roicat.util.RichFile_ROICaT(path=filepath_labelingDrawing).load()

data = roicat.data_importing.Data_roicat();

data.set_spatialFootprints(
    spatialFootprints=labelingDrawing['data']['spatialFootprints'], 
    um_per_pixel=labelingDrawing['data']['um_per_pixel'],
)

data.set_FOVHeightWidth(
    FOV_height=labelingDrawing['data']['FOV_height'],
    FOV_width=labelingDrawing['data']['FOV_width'],
)

data.transform_spatialFootprints_to_ROIImages(out_height_width=(36, 36));

data.set_class_labels(labels=labelingDrawing['preds'])

assert data.check_completeness(verbose=False)['classification_training'], "Data object is missing attributes necessary for classification training."

# ROInet embedding

This step passes the image of each ROI through the ROInet neural network. The inputs are the images; the output is an array describing the visual properties of each ROI.

##### 1. Initialize ROInet

Initialize the ROInet object. The `ROInet_embedder` class will automatically download and load a pretrained ROInet model. If you have a GPU, this step will be much faster.

In [None]:
DEVICE = roicat.helpers.set_device(use_GPU=True, verbose=True)
dir_temp = tempfile.gettempdir()

roinet = roicat.ROInet.ROInet_embedder(
    device=DEVICE,  ## Which torch device to use ('cpu', 'cuda', etc.)
    dir_networkFiles=dir_temp,  ## Directory to download the pretrained network to
    download_method='check_local_first',  ## Check to see if a model has already been downloaded to the location (will skip if hash matches)
    download_url='https://osf.io/c8m3b/download',  ## URL of the model
    download_hash='357a8d9b630ec79f3e015d0056a4c2d5',  ## Hash of the model file
    forward_pass_version='head',  ## How the data is passed through the network
    verbose=True,  ## Whether to print updates
)

roinet.generate_dataloader(
    ROI_images=data.ROI_images,  ## Input images of ROIs
    um_per_pixel=data.um_per_pixel,  ## Resolution of FOV
    pref_plot=False,  ## Whether or not to plot the ROI sizes
);

##### 2. Check ROI_images sizes
In general, you want to see that a neuron fills roughly 25-50% of the area of the image. \
**Adjust `um_per_pixel` above to rescale image size**
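
One rough way to put a number on that guideline is to measure what fraction of each image's pixels belong to the ROI. The sketch below is an illustration only, not a ROICaT function: `roi_fill_fraction` is a hypothetical helper, and the rule "a pixel belongs to the ROI if it exceeds 10% of the image's peak value" is an assumed threshold. It runs on a synthetic stack; in the notebook you could try passing `roinet.ROI_images_rs`, assuming it is array-like with shape (n_rois, height, width).

```python
import numpy as np

def roi_fill_fraction(images, rel_threshold=0.1):
    """Fraction of pixels in each image that exceed rel_threshold * that image's max."""
    images = np.asarray(images, dtype=np.float64)
    peaks = images.max(axis=(1, 2), keepdims=True)
    # Pixels count as "inside the ROI" if they exceed a fraction of the peak value
    mask = images > (rel_threshold * peaks)
    return mask.mean(axis=(1, 2))

# Synthetic example: a 36x36 image containing an 18x18 block of ones
stack = np.zeros((1, 36, 36))
stack[0, 9:27, 9:27] = 1.0
fill = roi_fill_fraction(stack)
print(fill)  # 18*18 / (36*36) = 0.25
```

If most values fall well below 0.25 or well above 0.5, that suggests revisiting `um_per_pixel`.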

In [None]:
roicat.visualization.display_toggle_image_stack(roinet.ROI_images_rs[:1000], image_size=(200,200))

##### 3. Pass data through network

Pass the data through the network. For large datasets (~40,000 ROIs), expect this to take around 15 minutes on CPU or 1 minute on GPU.

In [None]:
roinet.generate_latents();

# Train / Validation / Test Split Data, Hyperparameter Tune on Validation Set, and Fit Model

Prepare input data

In [75]:
X = np.array(roinet.latents).astype(np.float32)
y = np.concatenate(data.class_labels_index).astype(np.int64)
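
Before fitting, it can help to confirm that `X` and `y` line up and to look at the class balance. The helper below is a hypothetical sketch, not part of ROICaT, shown on synthetic arrays that stand in for the ROInet latents and the concatenated class labels.

```python
import numpy as np

def summarize_training_data(X, y):
    """Check that X and y line up, and return the number of samples per class label."""
    X = np.asarray(X)
    y = np.asarray(y)
    if X.ndim != 2:
        raise ValueError(f"Expected X with shape (n_samples, n_features), got {X.shape}")
    if X.shape[0] != y.shape[0]:
        raise ValueError(f"X has {X.shape[0]} rows but y has {y.shape[0]} labels")
    if not np.all(np.isfinite(X)):
        raise ValueError("X contains NaN or infinite values")
    labels, counts = np.unique(y, return_counts=True)
    return dict(zip(labels.tolist(), counts.tolist()))

# Synthetic stand-ins for the latents and the class labels
X_demo = np.random.default_rng(0).normal(size=(6, 4)).astype(np.float32)
y_demo = np.array([0, 0, 1, 1, 1, 2], dtype=np.int64)
print(summarize_training_data(X_demo, y_demo))  # {0: 2, 1: 3, 2: 1}
```

A class with very few samples is worth noticing here, since it can make validation and test accuracy noisy.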

##### Initialize the **AutoClassifier**

This class is meant to be easy to use, but advanced users will find all the parameters they might want to adjust available as arguments. See the detailed [**DOCUMENTATION**](https://roicat.readthedocs.io/en/latest/roicat.html#roicat.classification.classifier.Auto_LogisticRegression) on this class for the full list. Here's a brief tutorial:
- This class performs classification by fitting a logistic regression model.
- There is one critical parameter in this model: `C`, the inverse of the regularization strength. Lowering `C` means more regularization.
- This class will **automatically tune** any parameter that is specified in the `params_LogisticRegression` argument as a list of values. See the sklearn documentation on [LogisticRegression](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html) for the full list of arguments that can be specified in `params_LogisticRegression`.
    - Examples: 
    ```
    ## Initialize with NO TUNING. All parameters are fixed.
    autoclassifier = Auto_LogisticRegression(
        X,
        y,
        params_LogisticRegression={
            'C': 1e-14,
            'penalty': 'l2',
            'solver': 'lbfgs',
        },
    )

    ## Initialize with TUNING 'C', 'penalty', and 'l1_ratio'. 'solver' is fixed.
    ## 'saga' is used because sklearn's 'lbfgs' solver does not support 'l1' or 'elasticnet'.
    autoclassifier = Auto_LogisticRegression(
        X,
        y,
        params_LogisticRegression={
            'C': [1e-14, 1e3],
            'penalty': ['l1', 'l2', 'elasticnet'],
            'l1_ratio': [0.0, 1.0],
            'solver': 'saga',
        },
    )
    )
    ```
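
To see the effect of `C` directly, the sketch below fits sklearn's `LogisticRegression` twice on synthetic data and compares the size of the learned coefficients. It is only an illustration of the regularization behaviour: the dataset comes from `make_classification` rather than ROInet, and `coef_norm` is a hypothetical helper. It does not use `Auto_LogisticRegression`.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

# Synthetic two-class data standing in for the ROInet latents
X_demo, y_demo = make_classification(
    n_samples=200, n_features=10, n_informative=5, random_state=0
)

def coef_norm(C):
    """Fit a logistic regression (sklearn's default L2 penalty) and return its coefficient norm."""
    model = LogisticRegression(C=C, max_iter=1000)
    model.fit(X_demo, y_demo)
    return float(np.linalg.norm(model.coef_))

norm_strong_reg = coef_norm(1e-3)  # small C: strong regularization
norm_weak_reg = coef_norm(1e3)     # large C: weak regularization
print(norm_strong_reg, norm_weak_reg)
```

The small-`C` fit should show a much smaller coefficient norm, which is the shrinkage that regularization is meant to produce.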

In [82]:
autoclassifier = roicat.classification.classifier.Auto_LogisticRegression(
    X=X,
    y=y,
    params_LogisticRegression={
        'C': [1e-13, 1e3],
    },
    verbose=True,
)

##### Run the AutoClassifier

In [None]:
autoclassifier.fit()

# Visualize results

In [None]:
autoclassifier.plot_C_curve()

In [None]:
accuracy, confusion_matrix = autoclassifier.evaluate_model()

print(f"Accuracy: {accuracy}")

roicat.visualization.plot_confusion_matrix(
    confusion_matrix=confusion_matrix,
    class_names=None,
)
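
For reference, the sketch below shows how an accuracy and a confusion matrix relate to each other, using plain NumPy on made-up labels. It is not the implementation behind `evaluate_model()`; the orientation used here (rows are true classes, columns are predicted classes, raw counts) is an assumption of this example, and `confusion_and_accuracy` is a hypothetical helper.

```python
import numpy as np

def confusion_and_accuracy(y_true, y_pred, n_classes):
    """Return (accuracy, confusion matrix) with rows = true class, columns = predicted class."""
    y_true = np.asarray(y_true, dtype=np.int64)
    y_pred = np.asarray(y_pred, dtype=np.int64)
    cm = np.zeros((n_classes, n_classes), dtype=np.int64)
    # Count each (true, predicted) pair; np.add.at handles repeated index pairs correctly
    np.add.at(cm, (y_true, y_pred), 1)
    # Correct predictions lie on the diagonal
    accuracy = np.trace(cm) / cm.sum()
    return float(accuracy), cm

acc_demo, cm_demo = confusion_and_accuracy(
    y_true=[0, 0, 1, 1, 2, 2],
    y_pred=[0, 1, 1, 1, 2, 0],
    n_classes=3,
)
print(acc_demo)  # 4 of 6 predictions are correct
print(cm_demo)
```

Off-diagonal entries show which classes are confused with each other, which a single accuracy number hides.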

# Save Outputs

In [88]:
directory_save = '/media/rich/bigSSD/data_tmp/test_data/'
filename_prefix_model = 'mouse_1'

paths_save = {
    'model':    str(Path(directory_save) / (filename_prefix_model + '.classification_training.autoclassifier' + '.onnx')),
    'run_data': str(Path(directory_save) / (filename_prefix_model + '.classification_training.run_data.richfile')),
}
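
If the save directory might not exist yet, the paths can be built by a small helper that creates it first. This is a sketch with a hypothetical function name (`build_save_paths`); the file name suffixes are copied from the cell above, and the demo writes only to a temporary directory.

```python
from pathlib import Path
import tempfile

def build_save_paths(directory_save, filename_prefix):
    """Create the output directory if needed and return the two output paths as strings."""
    directory = Path(directory_save)
    directory.mkdir(parents=True, exist_ok=True)
    return {
        'model': str(directory / f"{filename_prefix}.classification_training.autoclassifier.onnx"),
        'run_data': str(directory / f"{filename_prefix}.classification_training.run_data.richfile"),
    }

# Demonstrated on a temporary directory so nothing is written to a real data folder
with tempfile.TemporaryDirectory() as tmp:
    paths_demo = build_save_paths(Path(tmp) / 'results', 'mouse_1')
    dir_was_created = (Path(tmp) / 'results').is_dir()
print(paths_demo['model'])
```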

In [None]:
autoclassifier.save_model(
    filepath=paths_save['model'],
    allow_overwrite=True,
)

roicat.util.RichFile_ROICaT(path=paths_save['run_data']).save({
    'data': data.__dict__,
    'roinet': roinet.__dict__,
    'accuracy': accuracy,
    'confusion_matrix': confusion_matrix,
}, overwrite=True)