# Imports

In [1]:
%load_ext autoreload
%autoreload 2

import sys
import time

import numpy as np
import torch
from bnpm import file_helpers, optimization
import sklearn.utils.class_weight
from torch import nn, optim
from tqdm import tqdm
import sklearn.linear_model
import multiprocessing as mp

import roicat.classification.classifier_util as cu
import scipy.sparse
import roicat
import bnpm.h5_handling
from pathlib import Path
import shutil
import warnings
import umap

import matplotlib.pyplot as plt
import matplotlib.offsetbox
import json
import os
import natsort
from roicat import helpers
# from kymatio.torch import Scattering2D
import gc
import functools

2023-06-22 15:44:49.738676: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


# Initialize Parameters

In [2]:
# Output directory for this notebook: the training tracker results and the fitted model are saved here.
# NOTE(review): absolute, machine-specific path — change it before running on another machine.
directory_save = Path(r'/Users/josh/analysis/outputs/ROICaT/classification/classifier_train')


In [3]:
# Make sure the output directory exists (intermediate folders are created as
# needed; an already-existing directory is not an error).
directory_save.mkdir(parents=True, exist_ok=True)

# Largest value representable by a signed 64-bit integer. Passed as `n_train`
# to the data splitter further down (presumably meaning "do not cap the number
# of training samples" — confirm against cu.Datasplit).
INTEGER_MAX = np.iinfo(np.int64).max

# Data Importing

## Option 1: Use Data Results from Labeling Interactive Outputs

### Option 1.A: Labeling Interactive — Specify Directory / Filenames

In [4]:
# Resolved path (as a string) to the pickled results of the interactive labeling step.
# NOTE(review): absolute, machine-specific path — change it before running on another machine.
filepath_labellingInteractive = str(Path('/Users/josh/analysis/outputs/ROICaT/classification/labeling_interactive/test.ROICaT.labeling.results.pkl').resolve())

### Option 1.B: Labeling Interactive — Find and Load Data

In [5]:
# Load the pickled labeling results.
# NOTE(review): unpickling can execute arbitrary code — only load files you trust.
labelingInteractive = roicat.helpers.pickle_load(filepath_labellingInteractive)

# Select the images addressed by the label index so that images and labels
# are passed to the data object in matching order.
labels_interactive = labelingInteractive['labels']
images_labeled = labelingInteractive['images'][labels_interactive['index']]

# Both setters receive one-element lists, i.e. a single session.
data = roicat.data_importing.Data_roicat()
data.set_ROI_images([images_labeled])
data.set_class_labels([labels_interactive['label']]);

Starting: Importing ROI images
Completed: Imported 1 sessions. Each session has [502] ROIs. Total number of ROIs is 502. The um_per_pixel is 1.0 um per pixel.
Starting: Importing class labels
Labels and ROI Images match in shapes: Class labels and ROI images have the same number of sessions and the same number of ROIs in each session.
Completed: Imported labels for 1 sessions. Each session has [502] class labels. Total number of class labels is 502.




## Option 2: Suite2p Data

### Option 2.A: Suite2p — Specify Directory / Filenames

In [None]:
# Outer directory that is searched for Suite2p outputs.
# NOTE(review): absolute, machine-specific path — change it before running on another machine.
dir_allOuterFolders = str(Path('/Users/josh/analysis/data/ROICaT/classification/stat_s2p').resolve())
# File names to look for; all three are passed together to `helpers.find_paths_requireAll`.
pathSuffixToStat = 'stat.npy'
pathSuffixToOps = 'ops.npy'
pathSuffixToLabels = 'labels_round2_sesh2.npy'

### Option 2.B: Suite2p — Find and Load Data

In [None]:
# Search the outer directory for the stat / ops / label files.
paths_allStat, paths_allOps, paths_allLabels = helpers.find_paths_requireAll(
    dir_outer=dir_allOuterFolders,
    filenames=[pathSuffixToStat, pathSuffixToOps, pathSuffixToLabels],
    depth=4,
)

# One label array per session.
# NOTE(review): allow_pickle=True can execute arbitrary code on load — only
# use label files you trust.
class_labels = [np.load(path_label, allow_pickle=True) for path_label in paths_allLabels]

# Build the Suite2p data object from the discovered files and labels.
data = roicat.data_importing.Data_suite2p(
    paths_statFiles=paths_allStat,
    paths_opsFiles=paths_allOps,
    class_labels=class_labels,
    um_per_pixel=2.0,
    new_or_old_suite2p='new',
    out_height_width=[36, 36],
    type_meanImg='meanImgE',
    verbose=True,
)

## Option 3: Caiman Data

### Option 3.A: Caiman — Specify Directory / Filenames

In [None]:
## TODO: Implement Caiman Data Importing

### Option 3.B: Caiman — Find and Load Data

In [None]:
## TODO: Implement Caiman Data Importing

## Option 4: Sparse Raw Image Data

### Option 4.A: Raw Images — Specify Directory / Filenames

In [None]:
# Outer directory that is searched for sparse ROI-image files and their labels.
# NOTE(review): absolute, machine-specific path — change it before running on another machine.
dir_allOuterFolders = str(Path('/Users/josh/analysis/data/ROICaT/classification/raw_images').resolve())
# File names to look for; both are passed together to `helpers.find_paths_requireAll`.
pathSuffixToROIImages = 'sf_concat_rs_sparse.npz'
pathSuffixToLabels = 'labels.npy'

### Option 4.B: Raw Images — Find and Load Data

In [None]:
# Search the outer directory for the sparse ROI-image files and label files.
paths_allROIImages, paths_allLabels = helpers.find_paths_requireAll(
    dir_outer=dir_allOuterFolders,
    filenames=[
        pathSuffixToROIImages,
        pathSuffixToLabels,
    ],
    depth=4,
)

# One sparse matrix of flattened ROI images and one integer label array per file.
list_ROIImages_sparse = [scipy.sparse.load_npz(filepath_data_rawImages) for filepath_data_rawImages in paths_allROIImages]
class_labels = [np.load(filepath_data_labels).astype(int) for filepath_data_labels in paths_allLabels]

data = roicat.data_importing.Data_roicat(verbose=True)
data.set_ROI_images(
    # Densify with `.toarray()` instead of the `.A` shorthand: `.A` is
    # deprecated in SciPy and is not available on every sparse container in
    # recent releases, whereas `.toarray()` works on all of them.
    # Each row is reshaped into a 36 x 36 image.
    ROI_images=[sf.toarray().reshape(sf.shape[0], 36, 36) for sf in list_ROIImages_sparse],
    um_per_pixel=2.0,
)
data.set_class_labels(class_labels)

# Pass Through Network to Generate Image Latents

In [6]:
# Embed every ROI image with the pretrained ROInet network.
# 'cuda:0' is requested, but the recorded run shows `set_device` reporting
# "no GPU available. Using CPU." when no GPU is present.
# NOTE(review): `dir_networkFiles` is an absolute, machine-specific path.
roinet = roicat.ROInet.ROInet_embedder(
    device=roicat.util.helpers.set_device('cuda:0'),
    dir_networkFiles=r"/Users/josh/analysis/models",
    download_method="check_local_first",
    download_url="https://osf.io/xwzhp/download",
    download_hash="134b170242141c26b0adbd9e0fd80d0e",
    forward_pass_version="head",
    verbose=True,
)

# Build the dataloader over the imported ROI images. One worker per CPU core
# is requested via mp.cpu_count().
roinet.generate_dataloader(
    ROI_images=data.ROI_images,
    um_per_pixel=data.um_per_pixel,
    pref_plot=False,
    batchSize_dataloader=8,
    pinMemory_dataloader=True,
    numWorkers_dataloader=mp.cpu_count(),
    persistentWorkers_dataloader=True,
    prefetchFactor_dataloader=2,
)

# Run the images through the network; the result is read from
# `roinet.latents` later in the notebook.
roinet.generate_latents()

# Release Python garbage and cached CUDA memory after the forward pass.
# A single collect + empty_cache pass is used; the original repeated the
# identical pair a second time.
gc.collect()
torch.cuda.empty_cache()

devices available: []
no GPU available. Using CPU.
File already exists locally: /Users/josh/analysis/models/ROInet.zip




Hash of local file matches provided hash_hex.
Extracting /Users/josh/analysis/models/ROInet.zip to /Users/josh/analysis/models.
Completed zip extraction.
['/Users/josh/analysis/models/ROInet_classification_20220902', '/Users/josh/analysis/models/ROInet_classification_20220902/ConvNext_tiny__1_0_best__simCLR_wPCA.pth', '/Users/josh/analysis/models/ROInet_classification_20220902/model.py', '/Users/josh/analysis/models/ROInet_classification_20220902/classifier.pkl', '/Users/josh/analysis/models/ROInet_classification_20220902/params.json', '/Users/josh/analysis/models/ROInet_classification_20220902/__pycache__', '/Users/josh/analysis/models/ROInet_classification_20220902/__pycache__/model.cpython-39.pyc']
Imported model from /Users/josh/analysis/models/ROInet_classification_20220902/model.py
Loaded params_model from /Users/josh/analysis/models/ROInet_classification_20220902/params.json




Generated network using params_model
Loaded state_dict into network from /Users/josh/analysis/models/ROInet_classification_20220902/ConvNext_tiny__1_0_best__simCLR_wPCA.pth
Loaded network onto device cpu
Starting: resizing ROIs
Completed: resizing ROIs
Defined image transformations: Sequential(
  (0): ScaleDynamicRange(scaler_bounds=(0, 1))
  (1): Resize(size=(224, 224), interpolation=bilinear, max_size=None, antialias=True)
  (2): TileChannels(dim=0)
)
Defined dataset
Defined dataloader
starting: running data through network


100%|██████████| 63/63 [01:13<00:00,  1.16s/it]


completed: running data through network


# Split Data into Train / Validation / Test Sets, Tune Hyperparameters on the Validation Set, and Fit the Model

In [7]:
# TODO: JZ, IMPLEMENT AS LOOP FOR OPTUNA FOR HYPERPARAMETER TUNING

# Inputs to the split: the network latents as a NumPy array, and the
# per-session class labels concatenated into one array.
features_all = roinet.latents.cpu().numpy()
labels_all = np.concatenate(data.class_labels, axis=0)

# Data splitting object: 20% validation, 20% test, and `n_train` set to
# INTEGER_MAX (presumably "use all remaining samples for training" — confirm
# against cu.Datasplit).
data_splitter = cu.Datasplit(
    features=features_all,
    labels=labels_all,
    n_train=INTEGER_MAX,
    val_size=0.2,
    test_size=0.2,
)

# Logistic-regression classifier (lbfgs solver, with intercept), using the
# class weights computed by the data splitter.
# NOTE(review): the recorded run of this cell stopped at the iteration limit
# ("TOTAL NO. of ITERATIONS REACHED LIMIT") despite max_iter=10000. C=1e5 means
# very weak regularization; consider scaling the features or lowering C.
model = cu.Classifier(
    sklearn.linear_model.LogisticRegression,
    solver='lbfgs',
    fit_intercept=True,
    max_iter=10000,
    C=1e5,
    class_weight=data_splitter.dict_class_weights,
)
model.fit(data_splitter.features_train, data_splitter.labels_train);

STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(


In [8]:
print(f'Calculating tracker outputs and saving to {directory_save}...')

# Fitted parameters handed to the tracker as a plain dict.
model_parameters = {'coef': model.coef_, 'intercept': model.intercept_}

# The tracker receives the save directory, the class weights, the actual
# number of training samples and the fitted parameters.
training_tracker = cu.TrainingTracker(
    directory_save=directory_save,
    class_weights=data_splitter.class_weights,
    n_train_actual=data_splitter.n_train_actual,
    model=model_parameters,
)

# Evaluate the model on the splits; per the recorded output this writes
# results_training.csv and results_timing.json into `directory_save`.
model.save_eval(data_splitter, training_tracker)

Calculating tracker outputs and saving to /Users/josh/analysis/outputs/ROICaT/classification/classifier_train...
Saving results:  /Users/josh/analysis/outputs/ROICaT/classification/classifier_train/results_training.csv /Users/josh/analysis/outputs/ROICaT/classification/classifier_train/results_timing.json
self.tictoc={}
self.model={'coef': array([[-0.00031918, -0.0204018 , -0.00530634,  0.05814678,  0.02840021,
        -0.00309238,  0.01647533,  0.00807982,  0.0152512 , -0.04126018,
         0.00864298,  0.10024013, -0.02810845, -0.0082666 ,  0.01360647,
        -0.00253014,  0.0069552 ,  0.0040184 ,  0.04236059, -0.03728926,
         0.00942566,  0.02635264,  0.03377191, -0.00702802, -0.01558957,
        -0.00984542,  0.01580523,  0.05130067,  0.00331547, -0.00275494,
         0.03161863, -0.01197764,  0.03965136,  0.04142987,  0.01000913,
         0.04203891, -0.00431738, -0.00539592,  0.00861832,  0.00460757,
        -0.00074163,  0.07118501, -0.01004113, -0.04882371, -0.00381383,
 

<roicat.classification.classifier_util.TrainingTracker at 0x7fc8499b5070>

# Save Outputs

In [9]:
# NOTE(review): sets a private attribute on the classifier directly; presumably
# this silences its verbose output before saving — confirm against cu.Classifier.
model._verbose = False

In [10]:
# Absolute path of the saved model file inside the output directory.
filepath_model = str((Path(directory_save) / 'model.pkl').resolve())

# NOTE(review): with allow_overwrite=False a re-run will presumably refuse to
# replace an existing model.pkl — confirm, and clear the file before re-running.
model.save(
    filepath_model,
    save_as_serializable_dict=True,
    allow_overwrite=False,
)

print('Saved model fit results.')

Saved model fit results.
