# Imports

In [None]:
# Reload imported modules automatically before each cell runs (`%autoreload 2` reloads all
# modules), so edits to imported library code take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

import sys
import time

import numpy as np
import torch
from bnpm import file_helpers, optimization
import sklearn.utils.class_weight
from torch import nn, optim
from tqdm import tqdm
import sklearn.linear_model
import multiprocessing as mp

import roicat.classification.classifier_util as cu
import scipy.sparse
import roicat
import bnpm.h5_handling
from pathlib import Path
import shutil
import warnings
import umap

import matplotlib.pyplot as plt
import matplotlib.offsetbox
import json
import os
import natsort
from roicat import helpers
# from kymatio.torch import Scattering2D
import gc
import functools

# Initialize Parameters

In [None]:
# Output folder for the predicted labels / probabilities saved at the end of the notebook.
directory_save = Path('/Users/josh/analysis/outputs/ROICaT/classification') / 'classifier_inference'


In [None]:
# Largest value representable by a signed 64-bit integer.
INTEGER_MAX = np.iinfo(np.int64).max
# Create the output folder (and any missing parents); no error if it already exists.
directory_save.mkdir(parents=True, exist_ok=True)

# Data Importing (run either Option 1 or Option 2 — both assign `data`)

## Option 1: Use Results from the Interactive Labeling Outputs

### Option 1.A: Labeling Interactive — Specify Directory / Filenames

In [None]:
# Absolute path (as a string) to the pickled results of the interactive labeling step.
_dir_labelingInteractive = Path('/Users/josh/analysis/outputs/ROICaT/classification/labeling_interactive')
filepath_labellingInteractive = str((_dir_labelingInteractive / 'test.ROICaT.labeling.results.pkl').resolve())

### Option 1.B: Labeling Interactive — Find and Load Data

In [None]:
# Load the interactive-labeling results and wrap only the labeled ROIs in a Data_roicat object.
# The images are the subset of labelingInteractive['images'] selected by the labeled indices;
# images and labels are each passed as a one-element list.
# NOTE(review): um_per_pixel is never set here, yet the ROInet cell reads data.um_per_pixel —
# confirm Data_roicat's default matches the resolution of these images.
labelingInteractive = roicat.helpers.pickle_load(filepath_labellingInteractive)
labels_interactive = labelingInteractive['labels']

data = roicat.data_importing.Data_roicat()
images_labeled = labelingInteractive['images'][labels_interactive['index']]
data.set_ROI_images([images_labeled])
data.set_class_labels([labels_interactive['label']]);

## Option 2: Suite2p Data

### Option 2.A: Suite2p — Specify Directory / Filenames

In [None]:
# Root folder searched for suite2p outputs, and the stat/ops filenames looked for inside it.
_path_s2pRoot = Path('/Users/josh/analysis/data/ROICaT/classification') / 'stat_s2p'
dir_allOuterFolders = str(_path_s2pRoot.resolve())
pathSuffixToStat, pathSuffixToOps = 'stat.npy', 'ops.npy'

### Option 2.B: Suite2p — Find and Load Data

In [None]:
# Find folders under dir_allOuterFolders (search depth 4) that contain the stat and ops files
# (presumably both are required, per the helper's name), then load them into a Data_suite2p object.
filenames_required = [pathSuffixToStat, pathSuffixToOps]
paths_allStat, paths_allOps = helpers.find_paths_requireAll(
    dir_outer=dir_allOuterFolders,
    filenames=filenames_required,
    depth=4,
)

data = roicat.data_importing.Data_suite2p(
    paths_statFiles=paths_allStat,
    paths_opsFiles=paths_allOps,
    new_or_old_suite2p='new',
    um_per_pixel=2.0,
    type_meanImg='meanImgE',
    out_height_width=[36, 36],  # presumably the (height, width) of each ROI image
    verbose=True,
);

# Pass Through Network to Generate Image Latents

In [None]:
# Embed every ROI image with the pretrained ROInet. The embeddings are stored on
# `roinet.latents`, which the prediction cell passes to the classifier.
roinet = roicat.ROInet.ROInet_embedder(
    # Use `roicat.helpers` directly, the same module used by the data-loading cells.
    # NOTE(review): confirm set_device accepts a device string such as 'cuda:0' in the
    # installed roicat version, and check which device it picks when CUDA is unavailable.
    device=roicat.helpers.set_device('cuda:0'),
    dir_networkFiles=r"/Users/josh/analysis/models",
    # Presumably: reuse network files found in dir_networkFiles, otherwise download them
    # from download_url and verify them against download_hash.
    download_method="check_local_first",
    download_url="https://osf.io/xwzhp/download",
    download_hash="134b170242141c26b0adbd9e0fd80d0e",
    forward_pass_version="head",
    verbose=True,
)

roinet.generate_dataloader(
    ROI_images=data.ROI_images,
    um_per_pixel=data.um_per_pixel,
    pref_plot=False,
    batchSize_dataloader=8,
    pinMemory_dataloader=True,
    numWorkers_dataloader=mp.cpu_count(),  # one loader worker per logical CPU
    persistentWorkers_dataloader=True,
    prefetchFactor_dataloader=2,
);
roinet.generate_latents();

# Free Python garbage and PyTorch's cached CUDA memory after the forward pass
# (the collect / empty_cache pair is run twice).
gc.collect()
torch.cuda.empty_cache()
gc.collect()
torch.cuda.empty_cache()

# Load Previously Fit Model and Predict

In [None]:
# Restore the logistic-regression classifier from the classifier_train output folder and
# apply it to the ROInet latents: hard class predictions plus predict_proba estimates.
path_model = '/Users/josh/analysis/outputs/ROICaT/classification/classifier_train/model.pkl'
model = cu.LogisticRegression(path_load=path_model)

latents = roinet.latents
predictions = model.predict(latents)
predictionProbas = model.predict_proba(latents)

# Save Outputs

In [None]:
# Persist the classifier outputs into directory_save:
#   labels_predicted.npy       <- predictions      (model.predict)
#   labels_predictedProbas.npy <- predictionProbas (model.predict_proba)
# np.save accepts path-like objects directly, so no str()/resolve() round-trip is needed.
dir_predictions = Path(directory_save)
np.save(dir_predictions / 'labels_predicted.npy', predictions, allow_pickle=True)
np.save(dir_predictions / 'labels_predictedProbas.npy', predictionProbas, allow_pickle=True)

print('Saved model prediction results.')