# Imports

In [23]:
import os
import rippl_AI
import importlib
importlib.reload(rippl_AI)
import aux_fcn

import matplotlib.pyplot as plt
import numpy as np
%matplotlib qt

# Section A
This section provides a usage example of the prediction and detection functions.

### Download data

In [3]:
from figshare.figshare.figshare import Figshare
# Client used to fetch the recording session from Figshare.
fshare = Figshare()

# Figshare article ids and the session name each one is stored under (paired by position).
article_ids = [14959449]
sess=['Dlx1']
# The loop variable is `article_id`, not `id`, so the builtin id() is not shadowed
# for the rest of the notebook.
for article_id,session in zip(article_ids,sess):
    datapath = os.path.join('Downloaded_data', session)
    if os.path.isdir(datapath):
        # The session folder is already on disk: skip the download so the cell can be re-run.
        print("Data already exists. Moving on.")
    else:
        print("Downloading data... Please wait, this might take up some time")        # Can take up to 10 minutes
        fshare.retrieve_files_from_article(article_id,directory=datapath)
        print("Data downloaded!")

Downloading data... Please wait, this might take up some time
Data downloaded!


### Data loading

In [4]:
# Folder where the download cell stored the Dlx1 session
path=os.path.join('Downloaded_data','Dlx1','figshare_14959449')

# Session info, unpacked as: sampling frequency, experiment name, reference channels, dead channels
sf, expName, ref_channels, dead_channels = aux_fcn.load_info(path)
channels_map = aux_fcn.load_channels_map(path)

# Reformat channels into correct values
channels, shanks, ref_channels = aux_fcn.reformat_channels(channels_map, ref_channels)
print('Channel map value: ',channels)
# Read .dat (verbose=True also prints the file details shown in the cell output)
LFP = aux_fcn.load_raw_data(path, expName, channels, verbose=True)
print('Sampling frequency: ', sf)
print('Shape of the original data',LFP.shape)
# Raw channel map, before reformatting
print(channels_map)


Channel map value:  [0, 1, 2, 3, 4, 5, 6, 7]
Downloaded_data\Dlx1\figshare_14959449/lfp_Dlx1-2021-02-12_12-46-54.dat
fileStart  0
fileStop  490242048
nSamples  245121024
nSamplesPerChannel  30640128
nSamplesPerChunk  10000
size data  30640128
Sampling frequency:  30000
Shape of the original data (30640128, 8)
[[1 1 2]
 [2 1 2]
 [3 1 2]
 [4 1 2]
 [5 1 2]
 [6 1 2]
 [7 1 2]
 [8 1 2]]


In [None]:
# The predict function takes care of normalizing and subsampling your data.
# If no architecture or model is specified, the best CNN1D will be used.
# NOTE(review): LFP_norm is presumably the normalized, subsampled LFP that prob refers to — confirm in rippl_AI.predict
prob,LFP_norm=rippl_AI.predict(LFP,sf) 

In [None]:
# Problem: this cell must be executed twice to choose the threshold
#   1: launch, choose the threshold, save
#   2: launch again with any threshold. The previously chosen threshold will be saved
det_ind=rippl_AI.get_intervals(prob,LFP_norm=LFP_norm)
# One row of det_ind per detected event
print(f"{det_ind.shape[0]} events where detected")

# Section B
In this section every available model is used for prediction, get_intervals is applied automatically over a range of thresholds, and the performance metric (F1) is plotted.

### Data download

In [None]:
from figshare.figshare.figshare import Figshare
# Client used to fetch the recording session from Figshare.
fshare = Figshare()

# Figshare article ids and the session name each one is stored under (paired by position).
article_ids = [14959449]
sess=['Dlx1']
# The loop variable is `article_id`, not `id`, so the builtin id() is not shadowed
# for the rest of the notebook.
for article_id,session in zip(article_ids,sess):
    datapath = os.path.join('Downloaded_data', session)
    if os.path.isdir(datapath):
        # The session folder is already on disk: skip the download so the cell can be re-run.
        print("Data already exists. Moving on.")
    else:
        print("Downloading data... Please wait, this might take up some time")        # Can take up to 10 minutes
        fshare.retrieve_files_from_article(article_id,directory=datapath)
        print("Data downloaded!")

### Data loading

In [2]:
# Folder where the download cell stored the Dlx1 session
path=os.path.join('Downloaded_data','Dlx1','figshare_14959449')

# Session info, unpacked as: sampling frequency, experiment name, reference channels, dead channels
sf, expName, ref_channels, dead_channels = aux_fcn.load_info(path)

channels_map = aux_fcn.load_channels_map(path)
# Now the ground truth (GT) tagged events are loaded and divided by the sampling frequency
# NOTE(review): presumably this converts sample indices into seconds — confirm against aux_fcn.load_ripples
ripples=aux_fcn.load_ripples(path)/sf
# Reformat channels into correct values
channels, shanks, ref_channels = aux_fcn.reformat_channels(channels_map, ref_channels)
print('Channel map value: ',channels)
# Read .dat (verbose=True also prints the file details shown in the cell output)
LFP = aux_fcn.load_raw_data(path, expName, channels, verbose=True)
print('Sampling frequency: ', sf)
print('Shape of the original data',LFP.shape)

Channel map value:  [0, 1, 2, 3, 4, 5, 6, 7]
Downloaded_data\Dlx1\figshare_14959449/lfp_Dlx1-2021-02-12_12-46-54.dat
fileStart  0
fileStop  490242048
nSamples  245121024
nSamplesPerChannel  30640128
nSamplesPerChunk  10000
size data  30640128
Sampling frequency:  30000
Shape of the original data (30640128, 8)


In [None]:
# Two loops going over every possible model: 5 architectures x 5 models per architecture
architectures=['XGBOOST','SVM','LSTM','CNN1D','CNN2D']
# Build one independent row per architecture.
# `[[None]*5]*5` would repeat the SAME inner list five times, so writing SWR_prob[i][n-1]
# would overwrite that slot for every architecture and only the last one would survive.
SWR_prob=[[None]*5 for _ in range(len(architectures))]
for i,architecture in enumerate(architectures):
    print(i,architecture)
    for n in range(1,6):
        # Make sure the selected model expected number of channels is the same as the channels array passed to the predict fcn
        # In this case, we are manually setting the channel array to 3 channels for CNN2D models 3 to 5
        if architecture=='CNN2D' and n>=3:
            channels=[0,3,7]
        else:
            channels=[0,1,2,3,4,5,6,7]
        # Only the first value returned by predict is kept; the second is discarded
        SWR_prob[i][n-1],_=rippl_AI.predict(LFP,sf,arch=architecture,model_number=n,channels=channels)

# SWR_prob[i][n-1] contains the output of model number n of architectures[i]


In [None]:
# F1 score of every model as a function of the detection threshold.
# Grid layout: one row per architecture, one column per model number.
th_arr=np.linspace(0.1,1,10)
fig,axs=plt.subplots(5,5,figsize=(10,10),sharex='all',sharey='all')
for i,row_axes in enumerate(axs):
    for j,ax in enumerate(row_axes):
        F1_arr=np.zeros(len(th_arr))
        for k,th in enumerate(th_arr):
            det_ind=rippl_AI.get_intervals(SWR_prob[i][j],threshold=th)
            # F1 is the third of the six values returned by get_performance
            F1_arr[k]=aux_fcn.get_performance(det_ind,ripples)[2]
        ax.plot(th_arr,F1_arr)
    # The leftmost panel of each row carries the architecture name
    row_axes[0].set_title(architectures[i])

axs[0,0].set_xlabel('Threshold')
axs[0,0].set_ylabel('F1')

# Section C
This section shows how to use the interpolation function. It is generally called internally from LFP_predict(), which could cause some confusion.

### Data download

In [3]:
from figshare.figshare.figshare import Figshare
# Client used to fetch the recording session from Figshare.
fshare = Figshare()

# Figshare article ids and the session name each one is stored under (paired by position).
article_ids = [14959449]
sess=['Dlx1']
# The loop variable is `article_id`, not `id`, so the builtin id() is not shadowed
# for the rest of the notebook.
for article_id,session in zip(article_ids,sess):
    datapath = os.path.join('Downloaded_data', session)
    if os.path.isdir(datapath):
        # The session folder is already on disk: skip the download so the cell can be re-run.
        print("Data already exists. Moving on.")
    else:
        print("Downloading data... Please wait, this might take up some time")        # Can take up to 10 minutes
        fshare.retrieve_files_from_article(article_id,directory=datapath)
        print("Data downloaded!")

Data already exists. Moving on.


### Data load
To illustrate how 'interpolate_channels' can be used to extract the desired number of channels, we will simulate two cases using the Dlx1 session:
1. We are using a recording probe that extracts 4 channels, when we need 8.
2. Some channels are dead or have too much noise.

In [None]:
# Folder where the download cell stored the Dlx1 session
path=os.path.join('Downloaded_data','Dlx1','figshare_14959449')

# Session info, unpacked as: sampling frequency, experiment name, reference channels, dead channels
sf, expName, ref_channels, dead_channels = aux_fcn.load_info(path)

channels_map = aux_fcn.load_channels_map(path)
channels, shanks, ref_channels = aux_fcn.reformat_channels(channels_map, ref_channels)
LFP = aux_fcn.load_raw_data(path, expName, channels, verbose=False)
print('Sampling frequency: ', sf)
print('Shape of the original data',LFP.shape)

# Case 1: a 4-channel probe, simulated by keeping every other channel.
# Indexing with a list of columns returns a new array, so LFP itself is untouched.
LFP_linear=LFP[:,[0,2,4,6]]
print('Shape of the 4 channels simulated data: ',LFP_linear.shape)

# Case 2: dead channels, simulated by zeroing channels 2 and 5.
# The zeros are written into a COPY: assigning into LFP directly (and then aliasing it as
# LFP_dead) would silently corrupt the original recording used by the cells that follow.
# This costs one extra array of the same size as LFP in memory.
LFP_dead=LFP.copy()
LFP_dead[:,[2,5]]=0
print('Sample of the simulated dead LFP: ',LFP_dead[0])

After interpolation, the data is ready to use in prediction

In [None]:
# Define channels: one entry per output channel (8 in total)
# NOTE(review): entries 0-3 presumably refer to columns of LFP_linear and -1 marks a
# channel that has to be interpolated — confirm against aux_fcn.interpolate_channels
channels_interpolation = [0,-1,1,-1,2,-1,-1,3]

# Make interpolation
LFP_interpolated = aux_fcn.interpolate_channels(LFP_linear, channels_interpolation)
print('Shape of the interpolated LFP: ',LFP_interpolated.shape)

In [None]:
 # Define channels
channels_interpolation = [0,1,-1,3,4,-1,6,7]
# NOTE(review): the -1 entries sit at positions 2 and 5, the channels that were zeroed in
# LFP_dead; presumably -1 marks a channel to be interpolated — confirm against aux_fcn.interpolate_channels

# Make interpolation
LFP_interpolated = aux_fcn.interpolate_channels(LFP_dead, channels_interpolation)
print('Value of the 1st sample of the interpolated LFP: ',LFP_interpolated[0])

# Section D
In this section an ensemble model that combines the output of the other 5 previous models will be shown. In this case, only the best ensemble model will be provided.

First, the output of the 5 selected models needs to be reshaped

In [16]:
# Generate one output per architecture (model number 1 of each), always with the 8 channels
architectures=['XGBOOST','SVM','LSTM','CNN1D','CNN2D']
channels=list(range(8))
output=[]
for architecture in architectures:
    # Only the first value returned by predict is kept
    SWR_prob=rippl_AI.predict(LFP,sf,arch=architecture,model_number=1,channels=channels)[0]
    output+=[SWR_prob]
# Transpose the stacked outputs so that the architecture axis comes last
ens_input=np.array(output).T


Downsampling data at 1250 Hz...
Shape of downsampled data: (1276672, 8)
Normalizing data...
XGBOOST_1_Ch8_W60_Ts016_D7_Lr0.10_G0.25_L10_SCALE1
Downsampling data at 1250 Hz...
Shape of downsampled data: (1276672, 8)
Normalizing data...
SVM_1_Ch8_W60_Ts001_Us0.05
Downsampling data at 1250 Hz...
Shape of downsampled data: (1276672, 8)
Normalizing data...
LSTM_1_Ch8_W60_Ts32_Bi0_L4_U11_E10_TB256
Downsampling data at 1250 Hz...
Shape of downsampled data: (1276672, 8)
Normalizing data...
CNN1D_1_Ch8_W60_Ts16_OGmodel12
Downsampling data at 1250 Hz...
Shape of downsampled data: (1276672, 8)
Normalizing data...
CNN2D_1_Ch8_W60_Ts40_OgModel


Generating ensemble model output

In [18]:

# Feed the stacked outputs of the 5 models (ens_input) to the ensemble model
prob_ens=rippl_AI.predict_ens(ens_input)



Plot performance

In [26]:
# F1 score of the ensemble output as a function of the detection threshold.
# `ripples` is not defined in this section: it must already be in the kernel namespace.
fig,ax=plt.subplots()
th_arr=np.linspace(0.1,1,10)
F1_arr=np.zeros(len(th_arr))
for k,th in enumerate(th_arr):
    det_ind=rippl_AI.get_intervals(prob_ens,threshold=th)
    # F1 is the third of the six values returned by get_performance
    F1_arr[k]=aux_fcn.get_performance(det_ind,ripples)[2]
ax.plot(th_arr,F1_arr)
ax.set_ylim(-0.05,0.8)
ax.set_title('Ensemble model')
ax.set_xlabel('Threshold')
ax.set_ylabel('F1')


precision = 0.6297376093294461
recall = 0.8293838862559242
F1 = 0.7159023115311409
precision = 0.7204301075268817
recall = 0.7819905213270142
F1 = 0.7499491214978631
precision = 0.7258687258687259
recall = 0.7488151658767772
F1 = 0.7371634197791289
precision = 0.7386363636363636
recall = 0.7298578199052133
F1 = 0.7342208530458063
precision = 0.746268656716418
recall = 0.7061611374407584
F1 = 0.725661130862514
precision = 0.7942386831275721
recall = 0.6729857819905214
F1 = 0.7286020018875701
precision = 0.8401639344262295
recall = 0.6635071090047393
F1 = 0.7414583737001873
precision = 0.8518518518518519
recall = 0.6255924170616114
F1 = 0.7213971723892124
precision = 0.8875502008032129
recall = 0.5971563981042654
F1 = 0.7139542337029676
x is empty. Cant perform IoU
precision = nan
recall = 0.0
F1 = nan


  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


Text(0, 0.5, 'F1')