# Imports

In [1]:
import sys
import time

import numpy as np
import torch
from bnpm import file_helpers, optimization
import sklearn.utils.class_weight
from torch import nn, optim
from tqdm import tqdm
import sklearn.linear_model
import multiprocessing as mp

import roicat.classification.classifier_util as cu
import scipy.sparse
import roicat
import bnpm.h5_handling
from pathlib import Path
import sys
from pathlib import Path
import shutil
import warnings

import matplotlib.pyplot as plt
import json


# Specify Initial Parameters

In [None]:
# --- User-configurable inputs ---
# Optional JSON file with run parameters. When left as None, the notebook
# falls back to the default `params` dict defined in a later cell.
path_params = None # Path(r"")
# Exactly one of filepath_ROIs / filepath_latents may be set (asserted below).
# filepath_ROIs = r'/Users/josh/analysis/outputs/ROICaT/classification/01_labels/arr_ROIs.npy'
filepath_ROIs = None
filepath_latents = r'/Users/josh/analysis/outputs/ROICaT/classification/01_labels/arr_latents.npy'
filepath_labels = r'/Users/josh/analysis/outputs/ROICaT/classification/01_labels/arr_labels.npy'
directory_save = r'/Users/josh/analysis/outputs/ROICaT/classification/02_classifier_train'
# Directory containing 'classification_data.pkl', which the "Import Datasets"
# cell loads via `Path(directory_data) / 'classification_data.pkl'`. It was
# previously never assigned, so that cell failed with a NameError on a fresh
# kernel. The value matches the location reported in that cell's saved output.
directory_data = r'/Users/josh/analysis/outputs/ROICaT/classification/00_data_ingestion'
# When True, the ROI-image branch keeps only the first 100 ROIs for a quick run.
testing = True

assert (filepath_ROIs is None) != (filepath_latents is None), 'Exactly one of filepath_ROIs or filepath_latents should be set'
assert Path(filepath_labels).exists(), 'File located at filepath_labels does not exist'
# The latents file is read with np.load later on; fail here with a clear message instead.
assert filepath_latents is None or Path(filepath_latents).exists(), 'File located at filepath_latents does not exist'

In [None]:
# Convert the configured locations to Path objects. The two optional inputs
# become None when they are unset (falsy).
filepath_ROIs, filepath_latents = (
    Path(p) if p else None for p in (filepath_ROIs, filepath_latents)
)
filepath_labels = Path(filepath_labels)
directory_save = Path(directory_save)
directory_save.mkdir(parents=True, exist_ok=True)

# Collects intermediate results; written to disk in the final cell.
classifierTrainingRun_interim = dict()

# Best-effort: keep a copy of the parameter file next to the outputs.
# A failure here is reported but deliberately does not stop the run.
if path_params is not None:
    try:
        directory_save.resolve().mkdir(parents=True, exist_ok=True)
        shutil.copy2(path_params, str(directory_save / Path(path_params).name))
    except Exception as e:
        print(f'JZ: Error copying params to {directory_save}')
        print(e)

In [4]:
# Reference timestamp: every `tictoc` entry stores the seconds elapsed since `tic`.
tic = time.time()
tictoc = {'start': time.time() - tic}

# Read the run parameters from JSON when a file was supplied; otherwise leave
# `params` as None so the next cell installs the defaults.
if path_params is None:
    params = None
else:
    params = file_helpers.json_load(str(Path(path_params).resolve()))

In [5]:
if params is None:
    # Default parameters, used only when no JSON parameter file was supplied.

    # Logistic-regression settings plus the SimCLR model download details.
    hyperparameters_training = {
        "num_transform_copies": 80,
        "solver": "lbfgs",
        "fit_intercept": True,
        "max_iter": 20000,
        "C": 0.01,
        "tol": 0.001,
        "simclrModel_download_url": "https://osf.io/xwzhp/download",
        "simclrModel_download_hash": "134b170242141c26b0adbd9e0fd80d0e"
    }

    params = {
        "method": "simclr",
        "device": "cuda:0",
        "datatype": "raw_images",
        "hyperparameters_split": {
            "n_train": 50000,
            # The splitting cell reads 'val_size'; it was missing from the
            # defaults, which raised a KeyError.
            # NOTE(review): 0.2 is an assumed default — TODO confirm the intended fraction.
            "val_size": 0.2,
            "test_size": 0.3
        },
        "paths": {
            "directory_github": "/Users/josh/analysis/github_repos/",
            "directory_simclrModel": "/Users/josh/analysis//models",
            "filepath_umapModel": None,
        },
        # The later cells read the solver/C/max_iter settings and the model
        # download details from 'hyperparameters_training_simclr', so the
        # defaults must provide that key. The original
        # 'hyperparameters_training_classifier' key is kept (as an independent
        # copy) for anything that still reads it.
        "hyperparameters_training_simclr": hyperparameters_training,
        "hyperparameters_training_classifier": dict(hyperparameters_training),
        "run_umap": True,
    }
    # NOTE(review): the ROI-image branch additionally reads
    # params['hyperparameters_data'] and params['hyperparameters_augmentations_all'],
    # which these defaults do not define; supply a parameter file when
    # filepath_ROIs is set.

# NOTE(review): the return value is discarded and later cells pass
# params['device'] to the model directly — confirm the requested device is
# actually available on this machine.
roicat.util.helpers.set_device(params['device'])

# Import Datasets

In [None]:
# Restore the previously saved Data_roicat object from its pickle file.
# NOTE(review): `directory_data` has to be assigned before this cell runs; it
# is the directory that contains 'classification_data.pkl'.
data = roicat.data_importing.Data_roicat(verbose=True)
filepath_data = Path(directory_data) / 'classification_data.pkl'
data.load(str(filepath_data))



Loaded Data_roicat object from /Users/josh/analysis/outputs/ROICaT/classification/00_data_ingestion/classification_data.pkl.


# Pass Through Network

In [7]:
if filepath_ROIs:
    # ROI-image branch: clean the ROI images held in `data`, then pass them
    # through the ROInet model to obtain the features used by the classifier.
    tictoc['imported_data'] = time.time() - tic

    ROI_images_rescaled = [
        roicat.ROInet.ROInet_embedder.resize_ROIs(rois, params['hyperparameters_data']['um_per_pixel'])
        for rois in data.ROI_images
    ]

    # Initialize concatenated data.
    # (The concatenation of the un-rescaled images was never used and has been
    # dropped; it only cost memory.)
    ROI_images_init_rescaled = np.concatenate(ROI_images_rescaled, axis=0).astype(np.float32)
    _labels_init = np.concatenate(data.class_labels, axis=0).astype(int).copy()

    # Perform data cleaning: flag ROIs whose image sums to NaN or to zero.
    # NOTE(review): the labels were already cast to int above, so the NaN check
    # on them can never trigger; it is kept to preserve the original logic.
    image_sums = ROI_images_init_rescaled.sum(axis=(1, 2))
    idx_violations = np.isnan(image_sums) | (image_sums == 0) | np.isnan(_labels_init)
    print('Number of idx_violations: ', idx_violations.sum(), ' out of ', len(idx_violations), ' total ROIs.')
    print('Located at: ', np.where(idx_violations)[0])
    print('Discarding these ROIs...')

    ROI_images_filt = ROI_images_init_rescaled[~idx_violations]
    _labels_filt = _labels_init[~idx_violations]

    if testing:
        # Quick-run mode: keep only the first 100 ROIs.
        ROI_images_filt = ROI_images_filt[:100]
        _labels_filt = _labels_filt[:100]

    classifierTrainingRun_interim['ROI_images_filt'] = ROI_images_filt

    ## No remapping for generate preproc
    # labels_remapped = cu.remap_labels(labels_filt, params['label_remapping'])

    tictoc['cleaned_data'] = time.time() - tic

    print(f'Shape of ROI_images_filt: {ROI_images_filt.shape}, shape of labels_remapped: {_labels_filt.shape}')

    transforms_final_all = cu.get_transforms(params['hyperparameters_augmentations_all'], scripted=True)
    dataset_all = roicat.ROInet.dataset_simCLR(
            X=torch.as_tensor(ROI_images_filt, device='cpu', dtype=torch.float32),
            y=torch.as_tensor(np.zeros((ROI_images_filt.shape[0])), device='cpu', dtype=torch.float32),
            n_transforms=1,
            class_weights=np.array([1]),
            transform=transforms_final_all, # *Use WarpPoints
            DEVICE='cpu',
            dtype_X=torch.float32,
        )
    # `prefetch_factor` is intentionally not passed: with num_workers=0 (no
    # worker processes) PyTorch 2.x raises a ValueError when it is set, and
    # leaving it out is accepted by both older and newer versions.
    dataloader_all = torch.utils.data.DataLoader(
            dataset_all,
            batch_size=64,
            shuffle=False,  # keep the ROI order so features stay aligned with the inputs
            drop_last=False,
            pin_memory=False,
            num_workers=0,#mp.cpu_count(),
            persistent_workers=False,
    )

    roinet = roicat.ROInet.ROInet_embedder(
        device=params['device'],
        dir_networkFiles=params['paths']['directory_simclrModel'],
        download_method='check_local_first',
        forward_pass_version='head',
        download_url=params['hyperparameters_training_simclr']['simclrModel_download_url'],
        download_hash=params['hyperparameters_training_simclr']['simclrModel_download_hash'],
        verbose=True,
    )

    print(f'Extracting transformed images from dataloaders, passing through roinet model, and saving to {directory_save}...')

    features_all, _labels_all, _idx_all, _sample_all = cu.extract_with_dataloader(
        dataloader_all,
        model=roinet.net,
        num_copies=1,
        device=params['device'],
    )

    classifierTrainingRun_interim['features_all'] = features_all
    print(f'Unaugmented run completed.')

else:
    # Latents branch: the features were computed earlier and saved to disk.
    features_all = np.load(filepath_latents)

labels_all = np.load(filepath_labels)

# The labels always come from file, while the features may have been filtered
# above (violations removed, `testing` truncation). Fail early with a clear
# message if the two no longer line up, rather than inside the split/fit.
if len(labels_all) != len(features_all):
    raise ValueError(
        f'Number of labels ({len(labels_all)}) does not match number of feature rows ({len(features_all)}).'
    )

# Train / Validation / Test Split Data, Hyperparameter Tune on Validation Set, and Fit Model

In [8]:
INTEGER_MAX = np.iinfo(np.int64(0).dtype).max

# TODO: JZ, IMPLEMENT AS LOOP FOR OPTUNA FOR HYPERPARAMETER TUNING
print('Splitting data...')
# Two-stage split. Stage 1 holds out the validation set from all samples.
# n_train is set to the largest int64 — presumably so that Datasplit never
# downsamples the training portion (confirm against cu.Datasplit).
data_split_val = cu.Datasplit(
    features=features_all,
    labels=labels_all,
    n_train=INTEGER_MAX,
    test_size=params['hyperparameters_split']['val_size'],
)
# Stage 2 holds out the test set from what remains after stage 1. test_size is
# divided by (1 - val_size) so that the configured test fraction refers to the
# full dataset rather than to the remainder.
data_split_test = cu.Datasplit(
    features=data_split_val.features_train,
    labels=data_split_val.labels_train,
    n_train=INTEGER_MAX,
    test_size=params['hyperparameters_split']['test_size'] / (1 - params['hyperparameters_split']['val_size']),
)

print('Creating X and y matrices for training data...')
X_train = data_split_test.features_train
y_train = data_split_test.labels_train

X_val = data_split_val.features_val
y_val = data_split_val.labels_val

# The held-out part of the second split is the test set. These were previously
# read from `data_split_val`, which made the test set identical to the
# validation set.
X_test = data_split_test.features_val
y_test = data_split_test.labels_val

y_train = y_train.astype(int)
y_val = y_val.astype(int)
y_test = y_test.astype(int)

tictoc['loaded_data'] = time.time() - tic
print('Calculating class weights...')
# `labels` was an undefined name here; the class inventory comes from `labels_all`.
classes_all = np.unique(labels_all)
num_classes = len(classes_all)
# Balanced weights, ordered like `classes_all`.
class_weights = sklearn.utils.class_weight.compute_class_weight(class_weight='balanced', classes=classes_all, y=labels_all)

# Flatten to 1-D labels and 2-D (samples, features) matrices for scikit-learn.
labels_train = y_train.reshape(-1)
features_train = X_train.reshape(-1, X_train.shape[-1])

labels_val = y_val.reshape(-1)
features_val = X_val.reshape(-1, X_val.shape[-1])

labels_test = y_test.reshape(-1)
features_test = X_test.reshape(-1, X_test.shape[-1])

n_train_actual = X_train.shape[0]
n_val_actual = X_val.shape[0]
n_test_actual = X_test.shape[0]

tictoc['splitted_data'] = time.time() - tic

print(f'Fitting model to data of dimensions: X: {X_train.shape}, y: {y_train.shape}...')

# Logistic-regression settings. Parameter files keep them under
# 'hyperparameters_training_simclr' (the key this cell has always read); fall
# back to 'hyperparameters_training_classifier', the key used by the
# notebook's built-in defaults, instead of failing with a KeyError.
if 'hyperparameters_training_simclr' in params:
    hyperparameters_fit = params['hyperparameters_training_simclr']
else:
    hyperparameters_fit = params['hyperparameters_training_classifier']

# Map each class LABEL to its weight. `class_weights` was computed with
# classes=np.unique(labels_all), so zipping against the same sorted labels
# keeps them aligned. Keying by position (enumerate) was only correct when the
# labels happened to be exactly 0..K-1. Labels are cast to int to match the
# integer training labels.
class_weight_by_label = {
    int(label): weight
    for label, weight in zip(np.unique(labels_all), class_weights)
}

# Create the logistic-regression classifier and fit it on the training split.
model = sklearn.linear_model.LogisticRegression(
   solver=hyperparameters_fit['solver'],
   fit_intercept=hyperparameters_fit['fit_intercept'],
   max_iter=hyperparameters_fit['max_iter'],
   C=hyperparameters_fit['C'],
   class_weight=class_weight_by_label,
)
model.fit(features_train, labels_train)

print(f'Calculating tracker outputs and saving to {directory_save}...')
# The tracker receives the run metadata and the fitted coefficients/intercept.
training_tracker = cu.TrainingTracker(
    directory_save=directory_save,
    class_weights=class_weights,
    tictoc=tictoc,
    n_train_actual=n_train_actual,
    model=dict(coef=model.coef_, intercept=model.intercept_),
)

# Predictions on the training and validation splits.
y_train_preds = model.predict(features_train).astype(int)
y_train_true = labels_train
y_val_preds = model.predict(features_val).astype(int)
y_val_true = labels_val

# Record accuracy and confusion matrix for each split under "epoch" 0:
# training first, then validation.
for key_accuracy, key_confusion, y_true, y_preds in (
    ('accuracy_training', 'confusionMatrix_training', y_train_true, y_train_preds),
    ('accuracy_val', 'confusionMatrix_val', y_val_true, y_val_preds),
):
    training_tracker.add_accuracy(0, key_accuracy, y_true, y_preds)
    training_tracker.add_confusion_matrix(0, key_confusion, y_true, y_preds)

tictoc[f'completed_training_in_{0}'] = time.time() - tic

# TODO: JZ, ADJUST RESULTS SAVING TO SAVE CONFUSION MATRICES AS NOT A DATAFRAME CSV
training_tracker.save_results()
training_tracker.print_results()

# Fitted parameters needed to rebuild the classifier for inference.
# This dict contained unresolved merge-conflict markers (a SyntaxError). The
# two sides used different key names, and the reader of 'model.npy' is not
# visible from this notebook, so both naming conventions are kept:
#   - 'intercept_' / 'coef_' / 'classes_' (scikit-learn attribute names)
#   - 'intercept' / 'coefs'
# 'classes_' records which label each row of the coefficients refers to.
# NOTE(review): drop whichever key set the downstream loader does not use.
model_save = {
    'intercept_': model.intercept_,
    'coef_': model.coef_,
    'classes_': model.classes_,
    'intercept': model.intercept_,
    'coefs': model.coef_,
}

# Save Outputs

In [None]:
# Write the fitted-model dict. np.save pickles non-array objects, so the file
# has to be read back with np.load(..., allow_pickle=True).
directory_out = Path(directory_save)
filepath_model = (directory_out / 'model.npy').resolve()
np.save(str(filepath_model), model_save, allow_pickle=True)

# Attach the parameters used for this run, then write the interim results.
# Despite the '.pkl' extension, this file is also written with np.save.
classifierTrainingRun_interim['params_prespecified'] = params
filepath_interim = (directory_out / 'classifierTrainingRun_interim.pkl').resolve()
with open(str(filepath_interim), 'wb') as f:
    np.save(file=f, arr=classifierTrainingRun_interim, allow_pickle=True)

print(f'Saved model fit results.')

Saved model fit results.
