# Imports

In [1]:
import sys
import time

import numpy as np
import torch
from bnpm import file_helpers, optimization
import sklearn.utils.class_weight
from torch import nn, optim
from tqdm import tqdm
import sklearn.linear_model
import multiprocessing as mp

import roicat.classification.classifier_util as cu
import scipy.sparse
import roicat
import bnpm.h5_handling
from pathlib import Path
import shutil
import warnings
import umap

import matplotlib.pyplot as plt
import matplotlib.offsetbox
import json
import os
import natsort
from roicat import helpers
# from kymatio.torch import Scattering2D
import gc
import functools

# Specify Initial Parameters

In [None]:
# ---- Run configuration: edit these values before running the notebook ----

# Behaviour flags.
testing = True
save_ROIs = True
save_latents = True

# Optional parameters file; left unset here.
path_params = None # Path(r"")

# Input location. NOTE(review): despite the name this currently points at a single
# .npy file rather than a directory, and it is not read by the cells visible in this
# notebook -- confirm whether it is still needed.
# Previous value: r'/Users/josh/analysis/outputs/ROICaT/classification/00_data_ingestion'
directory_data = r'/Users/josh/analysis/outputs/ROICaT/classification/labeling_interactive/feature_label_combo.npy'

# Output directory for the training tracker results and the fitted model.
directory_save = r'/Users/josh/analysis/outputs/ROICaT/classification/02_classifier_train'

In [3]:
# Root folder that is searched (to a depth of 4) for the suite2p outputs and label files.
dir_allOuterFolders = str(Path('/Users/josh/analysis/data/ROICaT/classification/stat_s2p').resolve())

# File-name patterns handed to helpers.find_paths as `reMatch`.
pathSuffixToStat = 'stat.npy'
pathSuffixToOps = 'ops.npy'
pathSuffixToLabels = 'labels_round2_sesh2.npy'

# One search per pattern, in the order stat -> ops -> labels.
paths_allStat, paths_allOps, paths_allLabels = [
    helpers.find_paths(
        dir_outer=dir_allOuterFolders,
        reMatch=suffix,
        depth=4,
    )
    for suffix in (pathSuffixToStat, pathSuffixToOps, pathSuffixToLabels)
]

# Show what was found so that stat / ops files can be checked by eye.
display(paths_allStat)
display(paths_allOps)

# Build the suite2p-backed data object from the discovered files.
# NOTE(review): um_per_pixel=2.0 and the 36x36 output size are hard-coded here --
# confirm they match the recordings being imported.
data = roicat.data_importing.Data_suite2p(
    paths_statFiles=paths_allStat,
    paths_opsFiles=paths_allOps,
    class_labels=paths_allLabels,
    um_per_pixel=2.0,
    new_or_old_suite2p='new',
    out_height_width=[36, 36],
    type_meanImg='meanImgE',
    verbose=True,
);

# Data Importing — Hand-Labeled Inputs

In [None]:
# Load the results file written by the interactive labeling session.
# NOTE(review): pickle_load presumably unpickles this file -- only point it at files you trust.
filepath_labellingInteractive = str(
    Path('/Users/josh/analysis/outputs/ROICaT/classification/labeling_interactive/test.ROICaT.labeling.results.pkl').resolve()
)
labelingInteractive = roicat.helpers.pickle_load(filepath_labellingInteractive)

# category_mappings: the sorted unique label values.
# codes_categories: for every labeled entry, the index of its label within category_mappings.
category_mappings, codes_categories = np.unique(
    labelingInteractive['labels']['label'],
    return_inverse=True,
)

In [None]:
# Wrap the hand-labeled ROIs in a ROICaT data object.
# NOTE: this rebinds `data`, replacing the Data_suite2p object created in the suite2p import cell.
data = roicat.data_importing.Data_roicat();

# Keep only the images that were actually labeled, selected via the stored label indices.
labeled_indices = labelingInteractive['labels']['index']

# Both setters are given a one-element list: one set of images and its matching integer codes.
data.set_ROI_images([labelingInteractive['images'][labeled_indices]]);
data.set_class_labels([codes_categories]);

# Pass Through Network

In [5]:
# Embed the ROI images with the ROInet network; the resulting `roinet.latents`
# are the features the classifier is trained on further down.
device_roinet = roicat.util.helpers.set_device('cuda:0')
n_dataloader_workers = mp.cpu_count()  # one dataloader worker per CPU core

roinet = roicat.ROInet.ROInet_embedder(
    device=device_roinet,
    dir_networkFiles=r"/Users/josh/analysis/models",
    download_method="check_local_first",
    download_url="https://osf.io/xwzhp/download",
    download_hash="134b170242141c26b0adbd9e0fd80d0e",
    forward_pass_version="head",
    verbose=True,
)

# Prepare the dataloader over the ROI images held by `data`.
roinet.generate_dataloader(
    ROI_images=data.ROI_images,
    um_per_pixel=data.um_per_pixel,
    pref_plot=False,
    batchSize_dataloader=8,
    pinMemory_dataloader=True,
    numWorkers_dataloader=n_dataloader_workers,
    persistentWorkers_dataloader=True,
    prefetchFactor_dataloader=2,
);

# Optional visual check of the resized images:
# roicat.visualization.display_toggle_image_stack(roinet.ROI_images_rs)

# Run the images through the network to produce the latents.
roinet.generate_latents();

# Two garbage-collection / CUDA-cache sweeps, as in the original cell.
for _sweep in range(2):
    gc.collect()
    torch.cuda.empty_cache()

# Train / Validation / Test Split Data, Hyperparameter Tune on Validation Set, and Fit Model

In [None]:
# Largest value representable by int64; used as `n_train` below.
# NOTE(review): presumably this means "do not downsample the training set" -- confirm against cu.Datasplit.
INTEGER_MAX = np.iinfo(np.int64).max

# TODO: JZ, IMPLEMENT AS LOOP FOR OPTUNA FOR HYPERPARAMETER TUNING
print('Splitting data...')
# Split the latents and labels into train / validation / test sets (20% validation, 20% test).
data_splitter = cu.Datasplit(
    features=roinet.latents,
    labels=data.class_labels,
    n_train=INTEGER_MAX,
    val_size=0.2,
    test_size=0.2,
)

print('Calculating class weights...')
# The sorted unique labels define both the class count and the order of `class_weights`.
labels_unique = np.unique(data.class_labels)
num_classes = len(labels_unique)
class_weights = sklearn.utils.class_weight.compute_class_weight(
    class_weight='balanced',
    classes=labels_unique,
    y=data.class_labels,
)

# Report the *shapes* of the training matrices (not their contents) before fitting.
print(f'Fitting model to data of dimensions: X: {np.shape(data_splitter.features_train)}, y: {np.shape(data_splitter.labels_train)}...')

# Multiclass logistic regression on the ROInet latents.
# C=1e5 makes the L2 penalty very weak; max_iter is raised so that lbfgs can converge.
model = sklearn.linear_model.LogisticRegression(
   solver='lbfgs',
   fit_intercept=True,
   max_iter=10000,
   C=1e5,
   # sklearn expects a {class_label: weight} mapping. `class_weights` is ordered like
   # np.unique(data.class_labels), so key by the actual label values rather than by
   # position; the two only coincide when the labels are exactly 0..n-1.
   class_weight={
       classLabel: classWeight
       for classLabel, classWeight in zip(np.unique(data.class_labels), class_weights)
   },
)
model.fit(data_splitter.features_train, data_splitter.labels_train)

print(f'Calculating tracker outputs and saving to {directory_save}...')
# The tracker stores the fitted parameters together with the class weights and training-set size.
model_params = {'coef':model.coef_, 'intercept':model.intercept_}
training_tracker = cu.TrainingTracker(
    directory_save=directory_save,
    class_weights=class_weights,
    n_train_actual=data_splitter.n_train_actual,
    model=model_params,
)

# Predicted vs. true labels for the training and validation splits.
y_train_preds = model.predict(data_splitter.features_train).astype(int)
y_train_true = data_splitter.labels_train
y_val_preds = model.predict(data_splitter.features_val).astype(int)
y_val_true = data_splitter.labels_val

# Record accuracy and confusion matrix at step 0, first for training, then for validation.
for key_accuracy, key_confusion, y_true, y_pred in (
    ('accuracy_training', 'confusionMatrix_training', y_train_true, y_train_preds),
    ('accuracy_val', 'confusionMatrix_val', y_val_true, y_val_preds),
):
    training_tracker.add_accuracy(0, key_accuracy, y_true, y_pred)
    training_tracker.add_confusion_matrix(0, key_confusion, y_true, y_pred)

training_tracker.save_results() # TODO: JZ, ADJUST RESULTS SAVING TO SAVE CONFUSION MATRICES AS NOT A DATAFRAME CSV
training_tracker.print_results()

# Fitted attributes of the classifier that are written to disk in the save cell.
model_save = {
    attribute: getattr(model, attribute)
    for attribute in ('intercept_', 'coef_', 'classes_')
}

Splitting data...
Creating X and y matrices for training data...
Calculating class weights...
Fitting model to data of dimensions: X: (60, 100), y: (60,)...
Calculating tracker outputs and saving to /Users/josh/analysis/outputs/ROICaT/classification/02_classifier_train...
Saving results:  /Users/josh/analysis/outputs/ROICaT/classification/02_classifier_train/results_training.csv /Users/josh/analysis/outputs/ROICaT/classification/02_classifier_train/results_timing.json
self.tictoc={'start': 4.887580871582031e-05, 'loaded_data': 0.6734569072723389, 'splitted_data': 0.6762540340423584, 'completed_training_in_0': 0.7153449058532715}
self.model={'coef': array([[-5.85583578e-03, -2.56761415e-03, -2.85717620e-04,
         1.39980724e-03, -4.69398613e-03,  1.39073784e-03,
        -4.09502665e-03, -6.19066812e-03, -2.53216414e-03,
        -5.05807930e-04, -9.05063450e-04,  1.13785403e-03,
        -7.10335497e-03, -3.48018141e-04, -4.56430364e-03,
        -4.09421531e-03, -9.48537968e-04,  2.061

# Save Outputs

In [None]:
# Persist the fitted classifier parameters (the `model_save` dict) as a pickled .npy file.
dir_save = Path(directory_save)
# np.save does not create missing parent directories, so make sure the target exists first.
dir_save.mkdir(parents=True, exist_ok=True)
# The payload is a dict, so it is stored as a pickled object array; read it back with
# np.load(path, allow_pickle=True).item().
np.save(str((dir_save / 'model.npy').resolve()), model_save, allow_pickle=True)

print('Saved model fit results.')