# Imports

In [1]:
import sys
import time

import numpy as np
import torch
from bnpm import file_helpers, optimization
import sklearn.utils.class_weight
from torch import nn, optim
from tqdm import tqdm
import sklearn.linear_model
import multiprocessing as mp

import roicat.classification.classifier_util as cu
import scipy.sparse
import roicat
import bnpm.h5_handling
from pathlib import Path
import sys
from pathlib import Path
import shutil
import warnings

import matplotlib.pyplot as plt
import json


# Specify Initial Parameters

In [2]:
path_params = None # Path(r"")
# filepath_ROIs = r'/Users/josh/analysis/outputs/ROICaT/classification/01_labels/arr_ROIs.npy'
filepath_ROIs = None
filepath_latents = r'/Users/josh/analysis/outputs/ROICaT/classification/01_labels/arr_latents.npy'
filepath_model = r'/Users/josh/analysis/outputs/ROICaT/classification/02_classifier_train/model.npy'

assert (filepath_ROIs is None) != (filepath_latents is None), 'Exactly one of filepath_ROIs or filepath_latents should be set'
assert Path(filepath_model).exists(), 'File located at filepath_model does not exist'

directory_save = '/Users/josh/analysis/outputs/ROICaT/classification/03_classifier_inference'
testing = True

In [3]:
filepath_ROIs = Path(filepath_ROIs) if filepath_ROIs else None
filepath_latents = Path(filepath_latents) if filepath_latents else None
filepath_model = Path(filepath_model)
directory_save = Path(directory_save)
directory_save.mkdir(exist_ok=True, parents=True)

classifierInferenceRun_interim = {}

if path_params is not None:
    try:
        Path(str((directory_save).resolve())).mkdir(exist_ok=True, parents=True)
        shutil.copy2(path_params, str(Path(directory_save) / Path(path_params).name));
    except Exception as e:
        print(f'JZ: Error copying params to {directory_save}')
        print(e)

In [4]:
tic = time.time()
tictoc = {}
tictoc['start'] = time.time() - tic

params = file_helpers.json_load(str(Path(path_params).resolve())) if path_params is not None else None
model = np.load(filepath_model, allow_pickle=True)

In [5]:
if params is None:
    params = {
        "method": "simclr",
        "device": "cuda:0",
        "datatype": "raw_images",
        "hyperparameters_split": {
            "n_train": 50000,
            "test_size": 0.3
        },
        "paths": {
            "directory_github": "/Users/josh/analysis/github_repos/",
            "directory_simclrModel": "/Users/josh/analysis//models",
            "filepath_umapModel": None,
        },
        "hyperparameters_training_classifier": {
            "num_transform_copies": 80,
            "solver": "lbfgs",
            "fit_intercept": True,
            "max_iter": 20000,
            "C": 0.01,
            "tol": 0.001,
            "simclrModel_download_url": "https://osf.io/xwzhp/download",
            "simclrModel_download_hash": "134b170242141c26b0adbd9e0fd80d0e"
        },
        "run_umap": True,
    }

roicat.util.helpers.set_device(params['device'])

devices available: []
no GPU available. Using CPU.


'cpu'

# Import Datasets (and Pass Through Network)

In [6]:
directory_model = str(Path(params['paths']['directory_model']).resolve()) if 'directory_model' in params['paths'] else None

if params['datatype'] == "stat_s2p":
    assert 'filename_stat' in params['paths'] and 'filename_ops' in params['paths'], 'JZ: The suite2p params.json file must include paths.filename_stat and paths.filename_ops for stat_s2p datatype'
    filepath_data_stat = str((Path(params['paths']['directory_data']) / params['paths']['filename_stat']).resolve())
    filepath_data_ops = str((Path(params['paths']['directory_data']) / params['paths']['filename_ops']).resolve())

    # Create data importing object to import suite2p data
    data = roicat.data_importing.Data_suite2p(
        paths_statFiles=[filepath_data_stat],
        paths_opsFiles=[filepath_data_ops],
        um_per_pixel=params['hyperparameters_data']['um_per_pixel'],
        new_or_old_suite2p=params['hyperparameters_data']['new_or_old_suite2p'],
        out_height_width=params['hyperparameters_data']['out_height_width'],
        type_meanImg=params['hyperparameters_data']['type_meanImg'],
        FOV_images=params['hyperparameters_data']['FOV_images'],
        verbose=params['hyperparameters_data']['verbose'],
    )
elif params['datatype'] == "raw_images":
    assert 'filename_rawImages' in params['paths'], 'JZ: The suite2p params.json file must include paths.filename_rawImages for raw_images datatype'
    filepath_data_rawImages = str((Path(params['paths']['directory_data']) / params['paths']['filename_rawImages']).resolve())

    sf = scipy.sparse.load_npz(filepath_data_rawImages)

    data = roicat.data_importing.Data_roicat(verbose=True)
    data.set_ROI_images(ROI_images=[sf.A.reshape(sf.shape[0], 36, 36)], um_per_pixel=params['hyperparameters_data']['um_per_pixel'])
else:
    raise ValueError(f"Invalid datatype for simclr: {params['datatype']}")

AssertionError: JZ: The suite2p params.json file must include paths.filename_rawImages for raw_images datatype

In [None]:
if filepath_ROIs:
    tictoc['imported_data'] = time.time() - tic

    ROI_images_rescaled = [roicat.ROInet.ROInet_embedder.resize_ROIs(rois, params['hyperparameters_data']['um_per_pixel']) for rois in data.ROI_images]

    # Initialize concatendated data
    ROI_images_init = np.concatenate(data.ROI_images, axis=0).astype(np.float32)
    ROI_images_init_rescaled = np.concatenate(ROI_images_rescaled, axis=0).astype(np.float32)

    # Perform data cleaning
    idx_violations = (np.isnan(ROI_images_init_rescaled.sum(axis=(1,2)))*1 + (np.sum(ROI_images_init_rescaled, axis=(1,2))==0)*1 + np.isnan(_labels_init)) != 0
    print('Number of idx_violations: ', idx_violations.sum(), ' out of ', len(idx_violations), ' total ROIs.')
    print('Located at: ', np.where(idx_violations)[0])
    print('Discarding these ROIs...')

    ROI_images_filt = ROI_images_init_rescaled[~idx_violations]

    if testing:
        ROI_images_filt = ROI_images_filt[:100]

    classifierInferenceRun_interim['ROI_images_filt'] = ROI_images_filt

    tictoc['cleaned_data'] = time.time() - tic


    transforms_final_all = cu.get_transforms(params['hyperparameters_augmentations_all'], scripted=True)
    dataset_all = roicat.ROInet.dataset_simCLR(
            X=torch.as_tensor(ROI_images_filt, device='cpu', dtype=torch.float32),
            y=torch.as_tensor(np.zeros((ROI_images_filt.shape[0])), device='cpu', dtype=torch.float32),
            n_transforms=1,
            class_weights=np.array([1]),
            transform=transforms_final_all, # *Use WarpPoints
            DEVICE='cpu',
            dtype_X=torch.float32,
        )
    dataloader_all = torch.utils.data.DataLoader( 
            dataset_all,
            batch_size=64,
            shuffle=False,
            drop_last=False,
            pin_memory=False,
            num_workers=0,#mp.cpu_count(),
            persistent_workers=False,
            prefetch_factor=2,
    )

    roinet = roicat.ROInet.ROInet_embedder(
        device=params['device'],
        dir_networkFiles=params['paths']['directory_simclrModel'],
        download_method='check_local_first',
        forward_pass_version='head',
        download_url=params['hyperparameters_training_simclr']['simclrModel_download_url'],
        download_hash=params['hyperparameters_training_simclr']['simclrModel_download_hash'],
        verbose=True,
    )

    print(f'Extracting transformed images from dataloaders, passing through roinet model, and saving to {directory_save}...')

    features_all, _labels_all, _idx_all, _sample_all = cu.extract_with_dataloader(
        dataloader_all,
        model=roinet.net,
        num_copies=1,
        device=params['device'],
    )

    classifierInferenceRun_interim['features_all'] = features_all
    print(f'Unaugmented run completed.')

else:
    features_all = np.load(filepath_latents)

# Load Previously Fit Logistic Regression Model

In [None]:
tictoc['loaded_data'] = time.time() - tic
print('Calculating class weights...')

tictoc['splitted_data'] = time.time() - tic

# Create lenet model, associated optimizer, loss function, and training tracker
model = sklearn.linear_model.LogisticRegression(
   solver=params['hyperparameters_training_simclr']['solver'],
   fit_intercept=params['hyperparameters_training_simclr']['fit_intercept'],
   max_iter=params['hyperparameters_training_simclr']['max_iter'],
   C=params['hyperparameters_training_simclr']['C'],

)

dct_model = np.load(filepath_model, allow_pickle=True)
model.coef_ = dct_model[()]['coef_']
model.intercept_ = dct_model[()]['intercept_']
model.classes_ = dct_model[()]['classes_']
predictions = model.predict(features_all)
predictionProbas = model.predict_proba(features_all)

tictoc[f'completed_training_in_{0}'] = time.time() - tic

# Save Outputs

In [None]:
np.save(str((Path(directory_save) / 'labels_predicted.npy').resolve()), predictions, allow_pickle=True)
np.save(str((Path(directory_save) / 'labels_predictedProbas.npy').resolve()), predictionProbas, allow_pickle=True)

classifierInferenceRun_interim['params_prespecified'] = params
with open(str((Path(directory_save) / 'classifierInferenceRun_interim.pkl').resolve()), 'wb') as f:
    np.save(
        file=f,
        arr=classifierInferenceRun_interim,
        allow_pickle=True
    )

print(f'Saved model fit results.')