# Imports

In [1]:
%load_ext autoreload
%autoreload 2

import sys
import time

import numpy as np
import torch
from bnpm import file_helpers, optimization
import sklearn.utils.class_weight
from torch import nn, optim
from tqdm import tqdm
import sklearn.linear_model
import multiprocessing as mp

import roicat.classification.classifier_util as cu
import scipy.sparse
import roicat
import bnpm.h5_handling
from pathlib import Path
import shutil
import warnings
import umap

import matplotlib.pyplot as plt
import matplotlib.offsetbox
import json
import os
import natsort
from roicat import helpers
# from kymatio.torch import Scattering2D
import gc
import functools

2023-06-23 12:27:38.358554: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


# Initialize Parameters

In [2]:
# Folder that receives this notebook's prediction outputs (the next cell creates it if missing).
directory_save = Path('/Users/josh/analysis/outputs/ROICaT/classification') / 'classifier_inference'


In [3]:
# Largest value representable by a signed 64-bit integer.
INTEGER_MAX = np.iinfo(np.int64).max
# Create the output folder, including any missing parents; no error if it already exists.
directory_save.mkdir(parents=True, exist_ok=True)

# Data Importing

## Option 1: Use Results from the Interactive Labeling Step

### Option 1.A: Labeling Interactive — Specify Directory / Filenames

In [4]:
# Absolute path (as a str) to the results pickle written by the labeling-interactive step.
filepath_labellingInteractive = str(
    Path('/Users/josh/analysis/outputs/ROICaT/classification/labeling_interactive')
    .joinpath('test.ROICaT.labeling.results.pkl')
    .resolve()
)

### Option 1.B: Labeling Interactive — Find and Load Data

In [5]:
# Load the labeling results. NOTE(review): pickle_load presumably unpickles the file, so only
# point it at files you produced or otherwise trust.
labelingInteractive = roicat.helpers.pickle_load(filepath_labellingInteractive)
labels_table = labelingInteractive['labels']
# Keep only the images that received a label, ordered like the label table's 'index' column.
images_labeled = labelingInteractive['images'][labels_table['index']]

# Wrap the labeled ROIs in a generic ROICaT data object; each outer list holds a single session.
data = roicat.data_importing.Data_roicat();
data.set_ROI_images([images_labeled]);
data.set_class_labels([labels_table['label']]);

Starting: Importing ROI images
Completed: Imported 1 sessions. Each session has [502] ROIs. Total number of ROIs is 502. The um_per_pixel is 1.0 um per pixel.
Starting: Importing class labels
Labels and ROI Images match in shapes: Class labels and ROI images have the same number of sessions and the same number of ROIs in each session.
Completed: Imported labels for 1 sessions. Each session has [502] class labels. Total number of class labels is 502.




## Option 2: Suite2p Data

### Option 2.A: Suite2p — Specify Directory / Filenames

In [None]:
# Root folder that is searched for Suite2p outputs, and the filenames every match must contain.
dir_allOuterFolders = str(Path('/Users/josh/analysis/data/ROICaT/classification', 'stat_s2p').resolve())
pathSuffixToStat, pathSuffixToOps = 'stat.npy', 'ops.npy'

### Option 2.B: Suite2p — Find and Load Data

In [None]:
# Collect paired stat/ops file paths found up to 4 directory levels below dir_allOuterFolders.
# Presumably only folders containing *all* listed filenames are returned (per the helper's name).
filenames_required = [pathSuffixToStat, pathSuffixToOps]
paths_allStat, paths_allOps = helpers.find_paths_requireAll(
    dir_outer=dir_allOuterFolders,
    filenames=filenames_required,
    depth=4,
)

# Import the Suite2p ROIs. The acquisition-specific settings below are hard-coded for this dataset.
data = roicat.data_importing.Data_suite2p(
    paths_statFiles=paths_allStat,
    paths_opsFiles=paths_allOps,
    um_per_pixel=2.0,  # NOTE(review): imaging resolution for this dataset; confirm before reuse
    new_or_old_suite2p='new',
    out_height_width=[36, 36],  # size of each ROI image crop
    type_meanImg='meanImgE',  # which mean image from ops to use
    verbose=True,
);

# Pass ROI Images Through ROInet to Generate Latents

In [6]:
# Neural network embedding distances
roinet = roicat.ROInet.ROInet_embedder(
    device=roicat.util.helpers.set_device('cuda:0'),
    dir_networkFiles=r"/Users/josh/analysis/models",
    download_method="check_local_first",
    download_url="https://osf.io/xwzhp/download",
    download_hash="134b170242141c26b0adbd9e0fd80d0e",
    forward_pass_version="head",
    verbose=True,
)

roinet.generate_dataloader(
    ROI_images=data.ROI_images,
    um_per_pixel=data.um_per_pixel,
    pref_plot=False,
    batchSize_dataloader=8,
    pinMemory_dataloader=True,
    numWorkers_dataloader=mp.cpu_count(),
    persistentWorkers_dataloader=True,
    prefetchFactor_dataloader=2,    
);
roinet.generate_latents();

gc.collect()
torch.cuda.empty_cache()
gc.collect()
torch.cuda.empty_cache()

devices available: []
no GPU available. Using CPU.
File already exists locally: /Users/josh/analysis/models/ROInet.zip




Hash of local file matches provided hash_hex.
Extracting /Users/josh/analysis/models/ROInet.zip to /Users/josh/analysis/models.
Completed zip extraction.
['/Users/josh/analysis/models/ROInet_classification_20220902', '/Users/josh/analysis/models/ROInet_classification_20220902/ConvNext_tiny__1_0_best__simCLR_wPCA.pth', '/Users/josh/analysis/models/ROInet_classification_20220902/model.py', '/Users/josh/analysis/models/ROInet_classification_20220902/classifier.pkl', '/Users/josh/analysis/models/ROInet_classification_20220902/params.json', '/Users/josh/analysis/models/ROInet_classification_20220902/__pycache__', '/Users/josh/analysis/models/ROInet_classification_20220902/__pycache__/model.cpython-39.pyc']
Imported model from /Users/josh/analysis/models/ROInet_classification_20220902/model.py
Loaded params_model from /Users/josh/analysis/models/ROInet_classification_20220902/params.json




Generated network using params_model
Loaded state_dict into network from /Users/josh/analysis/models/ROInet_classification_20220902/ConvNext_tiny__1_0_best__simCLR_wPCA.pth
Loaded network onto device cpu
Starting: resizing ROIs
Completed: resizing ROIs
Defined image transformations: Sequential(
  (0): ScaleDynamicRange(scaler_bounds=(0, 1))
  (1): Resize(size=(224, 224), interpolation=bilinear, max_size=None, antialias=True)
  (2): TileChannels(dim=0)
)
Defined dataset
Defined dataloader
starting: running data through network


100%|██████████| 63/63 [01:14<00:00,  1.19s/it]


completed: running data through network


# Load Previously Fit Model and Predict

In [11]:
# Load the logistic-regression classifier saved by the training step and apply it to the ROInet latents.
model = cu.LogisticRegression(
    path_load='/Users/josh/analysis/outputs/ROICaT/classification/classifier_train/model.pkl',
)
# Hard label predictions and (presumably per-class) probabilities for each ROI latent.
predictions, predictionProbas = (
    model.predict(roinet.latents),
    model.predict_proba(roinet.latents),
)

/Users/josh/analysis/outputs/ROICaT/classification/classifier_train/model.pkl
{'model_dict': {}, '_coef': array([[-7.66807350e-02, -8.86173794e-03,  7.47228536e-02,
        -4.74597642e-02, -3.82585067e-02, -3.33099461e-02,
         4.56421634e-02,  3.20266647e-02,  4.39130199e-02,
        -8.58321703e-02,  1.17918204e-01,  1.11814840e-01,
        -5.38246436e-02, -3.09600362e-02,  5.01042749e-02,
        -8.88464003e-02,  1.15717043e-02, -1.42116612e-02,
         4.69384194e-02, -7.90472031e-02, -5.17846670e-02,
         2.01918755e-02, -9.30674802e-03, -3.32344797e-02,
         8.16221535e-03, -5.24645768e-02,  1.43414412e-02,
         2.76007513e-02, -1.90099800e-02,  3.46021829e-02,
        -2.67180506e-02,  1.86024347e-02, -2.34790578e-02,
        -1.84597610e-02, -5.60442429e-02,  4.75524539e-02,
         1.34783250e-02,  2.33763647e-02,  3.02218511e-02,
        -8.46563465e-02, -5.63652591e-02,  7.26132236e-02,
         7.13524931e-02,  3.09893099e-02, -2.47231707e-02,
         

# Save Outputs

In [12]:
# Write each prediction array to directory_save as a .npy file.
# A mapping of filename -> array replaces the copy-pasted np.save calls, so the path-building
# logic lives in one place and adding another output is a one-line change.
outputs_toSave = {
    'labels_predicted.npy': predictions,
    'labels_predictedProbas.npy': predictionProbas,
}
for filename_out, array_out in outputs_toSave.items():
    np.save(str((Path(directory_save) / filename_out).resolve()), array_out, allow_pickle=True)

# Plain string: the original f-string had no placeholders.
print('Saved model prediction results.')

Saved model prediction results.
