# Imports

In [1]:
%load_ext autoreload
%autoreload 2

import sys
import time

import numpy as np
import torch
from bnpm import file_helpers, optimization
import sklearn.utils.class_weight
from torch import nn, optim
from tqdm import tqdm
import sklearn.linear_model
import multiprocessing as mp

import roicat.classification.classifier_util as cu
import scipy.sparse
import roicat
import bnpm.h5_handling
from pathlib import Path
import shutil
import warnings
import umap

import matplotlib.pyplot as plt
import matplotlib.offsetbox
import json
import os
import natsort
from roicat import helpers
# from kymatio.torch import Scattering2D
import gc
import functools

2023-06-22 15:46:49.743423: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


# Initialize Parameters

In [2]:
directory_save = Path(r'/Users/josh/analysis/outputs/ROICaT/classification/classifier_inference')


In [3]:
INTEGER_MAX = np.iinfo(np.int64(0).dtype).max
directory_save.mkdir(exist_ok=True, parents=True)

# Data Importing

## Option 1: Use Data Results from Labeling Interactive Outputs

### Option 1.A: Labeling Interactive — Specify Directory / Filenames

In [7]:
filepath_labellingInteractive = str(Path('/Users/josh/analysis/outputs/ROICaT/classification/labeling_interactive/test.ROICaT.labeling.results.pkl').resolve())

### Option 1.B: Labeling Interactive — Find and Load Data

In [8]:
labelingInteractive = roicat.helpers.pickle_load(filepath_labellingInteractive)
data = roicat.data_importing.Data_roicat();
data.set_ROI_images([labelingInteractive['images']]);

Starting: Importing ROI images
Completed: Imported 1 sessions. Each session has [4898] ROIs. Total number of ROIs is 4898. The um_per_pixel is 1.0 um per pixel.




  4%|▍         | 25/613 [00:50<12:02,  1.23s/it]

## Option 2: Suite2p Data

### Option 2.A: Suite2p — Specify Directory / Filenames

In [None]:
dir_allOuterFolders = str(Path('/Users/josh/analysis/data/ROICaT/classification/stat_s2p').resolve())
pathSuffixToStat = 'stat.npy'
pathSuffixToOps = 'ops.npy'

### Option 2.B: Suite2p — Find and Load Data

In [None]:
paths_allStat, paths_allOps = helpers.find_paths_requireAll(
    dir_outer=dir_allOuterFolders,
    filenames=[
        pathSuffixToStat,
        pathSuffixToOps,
    ],
depth=4,)

data = roicat.data_importing.Data_suite2p(
    paths_statFiles=paths_allStat,
    paths_opsFiles=paths_allOps,
    um_per_pixel=2.0,
    new_or_old_suite2p='new',
    out_height_width=[36, 36],
    type_meanImg='meanImgE',
    verbose=True,
);



Starting: Importing FOV images from ops files
Completed: Set FOV_height and FOV_width successfully.
Completed: Imported 1 FOV images.
Completed: Set FOV_images for 1 sessions successfully.
Importing spatial footprints from stat files.


100%|██████████| 1/1 [00:02<00:00,  2.26s/it]


Imported 1 sessions of spatial footprints into sparse arrays.
Completed: Set spatialFootprints for 1 sessions successfully.
Completed: Created session_bool.
Completed: Created centroids.
Staring: Creating centered ROI images from spatial footprints...
Completed: Created ROI images.
Starting: Importing class labels
Labels and ROI Images match in shapes: Class labels and ROI images have the same number of sessions and the same number of ROIs in each session.
Completed: Imported labels for 1 sessions. Each session has [4898] class labels. Total number of class labels is 4898.


## Option 3: Caiman Data

### Option 3.A: Caiman — Specify Directory / Filenames

In [None]:
## TODO: Implement Caiman Data Importing

### Option 3.B: Caiman — Find and Load Data

In [None]:
## TODO: Implement Caiman Data Importing

## Option 4: Sparse Raw Image Data

### Option 4.A: Raw Images — Specify Directory / Filenames

In [None]:
dir_allOuterFolders = str(Path('/Users/josh/analysis/data/ROICaT/classification/raw_images').resolve())
pathSuffixToROIImages = 'sf_concat_rs_sparse.npz'

### Option 4.B: Raw Images — Find and Load Data

In [None]:
paths_allROIImages, paths_allLabels = helpers.find_paths_requireAll(
    dir_outer=dir_allOuterFolders,
    filenames=[
        pathSuffixToROIImages,
    ],
depth=4,)

list_ROIImages_sparse = [scipy.sparse.load_npz(filepath_data_rawImages) for filepath_data_rawImages in paths_allROIImages]
class_labels = [np.load(filepath_data_labels).astype(int) for filepath_data_labels in paths_allLabels]

data = roicat.data_importing.Data_roicat(verbose=True)
data.set_ROI_images(
    ROI_images=[sf.A.reshape(sf.shape[0], 36, 36) for sf in list_ROIImages_sparse],
    um_per_pixel=2.0,
)
data.set_class_labels(class_labels)



Starting: Importing ROI images
Completed: Imported 1 sessions. Each session has [68382] ROIs. Total number of ROIs is 68382. The um_per_pixel is 2.0 um per pixel.
Starting: Importing class labels
Labels and ROI Images match in shapes: Class labels and ROI images have the same number of sessions and the same number of ROIs in each session.
Completed: Imported labels for 1 sessions. Each session has [68382] class labels. Total number of class labels is 68382.


# Pass Through Network to Generate Image Latents

In [9]:
# Neural network embedding distances
roinet = roicat.ROInet.ROInet_embedder(
    device=roicat.util.helpers.set_device('cuda:0'),
    dir_networkFiles=r"/Users/josh/analysis/models",
    download_method="check_local_first",
    download_url="https://osf.io/xwzhp/download",
    download_hash="134b170242141c26b0adbd9e0fd80d0e",
    forward_pass_version="head",
    verbose=True,
)

roinet.generate_dataloader(
    ROI_images=data.ROI_images,
    um_per_pixel=data.um_per_pixel,
    pref_plot=False,
    batchSize_dataloader=8,
    pinMemory_dataloader=True,
    numWorkers_dataloader=mp.cpu_count(),
    persistentWorkers_dataloader=True,
    prefetchFactor_dataloader=2,    
);
roinet.generate_latents();

gc.collect()
torch.cuda.empty_cache()
gc.collect()
torch.cuda.empty_cache()

devices available: []
no GPU available. Using CPU.
File already exists locally: /Users/josh/analysis/models/ROInet.zip




Hash of local file matches provided hash_hex.
Extracting /Users/josh/analysis/models/ROInet.zip to /Users/josh/analysis/models.
Completed zip extraction.
['/Users/josh/analysis/models/ROInet_classification_20220902', '/Users/josh/analysis/models/ROInet_classification_20220902/ConvNext_tiny__1_0_best__simCLR_wPCA.pth', '/Users/josh/analysis/models/ROInet_classification_20220902/model.py', '/Users/josh/analysis/models/ROInet_classification_20220902/classifier.pkl', '/Users/josh/analysis/models/ROInet_classification_20220902/params.json', '/Users/josh/analysis/models/ROInet_classification_20220902/__pycache__', '/Users/josh/analysis/models/ROInet_classification_20220902/__pycache__/model.cpython-39.pyc']
Imported model from /Users/josh/analysis/models/ROInet_classification_20220902/model.py
Loaded params_model from /Users/josh/analysis/models/ROInet_classification_20220902/params.json




Generated network using params_model
Loaded state_dict into network from /Users/josh/analysis/models/ROInet_classification_20220902/ConvNext_tiny__1_0_best__simCLR_wPCA.pth
Loaded network onto device cpu
Starting: resizing ROIs
Completed: resizing ROIs
Defined image transformations: Sequential(
  (0): ScaleDynamicRange(scaler_bounds=(0, 1))
  (1): Resize(size=(224, 224), interpolation=bilinear, max_size=None, antialias=True)
  (2): TileChannels(dim=0)
)
Defined dataset
Defined dataloader
starting: running data through network




# Load Previously Fit Model

In [None]:
model = cu.Classifier(
    sklearn.linear_model.LogisticRegression,
    path_load='/Users/josh/analysis/outputs/ROICaT/classification/02_classifier_train/model.pkl'
)
predictions = model.predict(features_all)
predictionProbas = model.predict_proba(features_all)

# Save Outputs

In [None]:
np.save(str((Path(directory_save) / 'labels_predicted.npy').resolve()), predictions, allow_pickle=True)
np.save(str((Path(directory_save) / 'labels_predictedProbas.npy').resolve()), predictionProbas, allow_pickle=True)

print(f'Saved model prediction results.')