# Demo for Using ENS<sup>2</sup>
This demo shows partial examples of using ENS<sup>2</sup> to infer spikes, as described in:
<br>Zhou et al. <b>"Effective and Efficient Neural Networks for Spike Inference from In Vivo Calcium Imaging"</b>


# Requirements:

<b>Inference with pre-trained model:</b>
<br>A pre-trained ENS<sup>2</sup> model is provided in this package.
<br>Inference runs on any regular PC, on a <b>CPU</b> alone or on a GPU if one is available, with the following packages installed:
- python == 3.6
- torch  >= 1.7.1
- numpy  >= 1.19.2
- scipy  >= 1.5.2
- tqdm   >= 4.59.0
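
The minimum versions listed above can be checked before running the demo. The following is a minimal sketch and not part of the ENS<sup>2</sup> package: `REQUIRED`, `version_tuple`, and `check_requirements` are hypothetical names, and the check assumes each package exposes a `__version__` attribute.

```python
import importlib
import re

# Minimum versions copied from the requirements list above.
REQUIRED = {"torch": "1.7.1", "numpy": "1.19.2", "scipy": "1.5.2", "tqdm": "4.59.0"}


def version_tuple(version):
    """Turn a version string such as '1.7.1+cu110' into (1, 7, 1)."""
    parts = []
    for piece in version.split("+")[0].split(".")[:3]:
        match = re.match(r"\d+", piece)
        if match is None:  # stop at non-numeric pieces such as 'dev'
            break
        parts.append(int(match.group()))
    return tuple(parts)


def check_requirements(required=REQUIRED):
    """Report 'ok', 'too old (...)', or 'missing' for each required package."""
    report = {}
    for name, minimum in required.items():
        try:
            module = importlib.import_module(name)
        except ImportError:
            report[name] = "missing"
            continue
        # Assumption: the package exposes __version__; fall back to "0" if not.
        installed = getattr(module, "__version__", "0")
        if version_tuple(installed) >= version_tuple(minimum):
            report[name] = "ok"
        else:
            report[name] = "too old ({} < {})".format(installed, minimum)
    return report


if __name__ == "__main__":
    for name, status in check_requirements().items():
        print("{:6s} {}".format(name, status))
```

The sketch only covers the four libraries; the Python version itself would need a separate check.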


<b>To train a new model from scratch, the following are additionally required:</b>
- CUDA-enabled <b>GPU</b>, together with proper CUDA and cuDNN toolkits
- 24 GB of system RAM (recommended)
- Training database (please refer to <i>Benchmark_demo.ipynb</i> for further instructions)

In [1]:
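import numpy as np          # used as np.float32 below
import scipy.io as scio
import torch                # used for device selection and model loading
from tqdm import trange     # progress bar for the inference loop

from ENS2 import *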

In [2]:
ens2 = ENS2()
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'  # works with either CPU or GPU
ens2.DEVICE = DEVICE
print(f'Using {DEVICE}')

Using cuda


In [3]:
load_state_dict = True   # Whether to load pre-trained model
                          # True: use pre-trained model
                          # False: re-train model from scratch

neuron_type = 'Exc'       # Inference on 'Exc' or 'Inh' neurons

if load_state_dict:
    if neuron_type == 'Exc':
        state_dict_raw  = torch.load('./saved_model/C_220812172434_dsets0_60.0Hz_Raw_UNet_MSE_Epoch1515.pt',
                                 map_location=torch.device(DEVICE)).state_dict()
    elif neuron_type == 'Inh':
        state_dict_raw  = torch.load('./saved_model/C_220812172805_dsets0_60.0Hz_Raw_UNet_MSE_Epoch1357.pt',
                                 map_location=torch.device(DEVICE)).state_dict()
else:
    ens2.train(neuron=neuron_type)
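
In the cell above, `state_dict_raw` stays undefined if `neuron_type` is set to anything other than 'Exc' or 'Inh'. One way to fail early is a small lookup, sketched below under the assumption that the two checkpoint paths are the only ones shipped; `CHECKPOINTS` and `checkpoint_for` are hypothetical names, not part of ENS<sup>2</sup>.

```python
# Checkpoint paths copied from the cell above.
CHECKPOINTS = {
    "Exc": "./saved_model/C_220812172434_dsets0_60.0Hz_Raw_UNet_MSE_Epoch1515.pt",
    "Inh": "./saved_model/C_220812172805_dsets0_60.0Hz_Raw_UNet_MSE_Epoch1357.pt",
}


def checkpoint_for(neuron_type):
    """Return the checkpoint path for 'Exc' or 'Inh'; reject anything else."""
    if neuron_type not in CHECKPOINTS:
        raise ValueError(
            "neuron_type must be one of {}, got {!r}".format(sorted(CHECKPOINTS), neuron_type)
        )
    return CHECKPOINTS[neuron_type]
```

The returned path could then be passed to `torch.load` with `map_location=torch.device(DEVICE)`, as in the cell above.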

In [4]:
# Load sample data from file

test_data = scio.loadmat('./sample_data/sample_data.mat')  # (configure your file path here)

print("Variable(s) are:\n" + "\n".join(str(i) for i in list(test_data.keys()) if '__' not in i))

Variable(s) are:
dff
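
The dict returned by `scio.loadmat` also holds metadata entries (`__header__`, `__version__`, `__globals__`), which is why the cell above filters the keys. That filter drops any name containing a double underscore anywhere. A slightly stricter variant is sketched here, with a plain dict standing in for the loaded file; `user_variables` and the `dff__raw` entry are hypothetical.

```python
def user_variables(mat_dict):
    """List variable names, skipping loadmat's '__name__' metadata keys."""
    return [key for key in mat_dict
            if not (key.startswith("__") and key.endswith("__"))]


# Stand-in for the dict returned by scio.loadmat (values are placeholders).
example = {"__header__": b"", "__version__": "1.0", "__globals__": [],
           "dff": None, "dff__raw": None}
print(user_variables(example))  # ['dff', 'dff__raw']
```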


In [5]:
dff = test_data['dff']    # input shape: (trials x frames) 2-d numpy array
print(dff.shape)

(6, 2000)
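
Since the input is expected as a 2-d (trials x frames) array, a single recording stored as a 1-d trace would need an extra axis first. The following sketch shows one way to normalize the shape; `as_trials_by_frames` is a hypothetical helper, and whether ENS<sup>2</sup> performs a similar check internally is not assumed here.

```python
import numpy as np


def as_trials_by_frames(dff):
    """Return dff as a 2-d array shaped (trials, frames)."""
    arr = np.asarray(dff)
    if arr.ndim == 1:
        arr = arr[np.newaxis, :]  # a single trace becomes one trial
    if arr.ndim != 2:
        raise ValueError("expected a 1-d or 2-d array, got {} dimensions".format(arr.ndim))
    return arr


print(as_trials_by_frames(np.zeros(2000)).shape)       # (1, 2000)
print(as_trials_by_frames(np.zeros((6, 2000))).shape)  # (6, 2000)
```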


In [6]:
def predict_test_data(test_data, trial_time):
    
    # test_data       - input data in the required format: 2-d array (trials x frames)
    # trial_time      - total duration, in seconds, covered by one row of samples (frames)
    
    # Pre-process data and segmentation
    print('Process data...')
    test_data = compile_test_data(test_data, trial_time)
    
    # Prepare model parameters
    if load_state_dict:
        state_dict = state_dict_raw
    else:
        state_dict = None
        
    print('Inferring...')
    for trial in trange(len(test_data)):
        calcium, pd_rate, pd_spike, pd_event = ens2.predict(test_data[trial]['dff_resampled_segment'], 
                                                            state_dict=state_dict)
        test_data[trial]['calcium'] = np.float32(calcium)
        test_data[trial]['pd_rate'] = np.float32(pd_rate)
        test_data[trial]['pd_spike'] = np.float32(pd_spike)
        test_data[trial]['pd_event'] = np.float32(pd_event)
        
        test_data[trial]['dff_resampled_segment'] = [] # remove segments to reduce storage space
    return test_data

In [7]:
# Inference with ENS2; trial_time is the duration of one row of dff in seconds (20 * 10 = 200 s here)

trial_num = 20      # trials
trial_duration = 10 # seconds per trial

dff_ENS2 = predict_test_data(dff, trial_time=trial_duration*trial_num)

Process data...
Test data has 6 trials.
Recording duration is 200s, equaling 10.0Hz frame rate.
Compile data done.
Inferring...


  0%|          | 0/6 [00:00<?, ?it/s]
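
The frame rate reported in the log is consistent with the array width and `trial_time`: 2000 frames over 200 s gives 10 Hz. A minimal sketch of that arithmetic follows, assuming the rate is simply frames divided by duration (the internals of `compile_test_data` are not shown in this demo); `frame_rate_hz` is a hypothetical helper.

```python
def frame_rate_hz(n_frames, trial_time_s):
    """Frame rate implied by n_frames samples spanning trial_time_s seconds."""
    if trial_time_s <= 0:
        raise ValueError("trial_time_s must be positive")
    return n_frames / trial_time_s


trial_num = 20       # as in the cell above
trial_duration = 10  # seconds per trial
print(frame_rate_hz(2000, trial_num * trial_duration))  # 10.0
```

If the reported rate does not match the acquisition rate of your recording, `trial_time` is the value to revisit.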

In [8]:
# Save results for analyses in MATLAB 

scio.savemat('./results/sample_data_ENS2.mat', {'dff':dff, 'dff_ENS2':dff_ENS2})
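
The save call above writes into `./results`, and `scio.savemat` is not expected to create a missing directory, so the folder should exist beforehand. A minimal sketch for creating it; `ensure_parent_dir` is a hypothetical helper, not part of ENS<sup>2</sup> or SciPy.

```python
import os


def ensure_parent_dir(path):
    """Create the directory that will contain `path` if it does not exist yet."""
    parent = os.path.dirname(path)
    if parent:  # an empty string means the current directory
        os.makedirs(parent, exist_ok=True)
    return parent
```

Calling `ensure_parent_dir('./results/sample_data_ENS2.mat')` before `scio.savemat` would create `./results` when needed and do nothing otherwise.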

Please use <i>visualizer.m</i> in MATLAB to visualize the results.
<br>Note that re-training the model introduces minor randomness, so results may differ slightly across GPU and toolkit setups.