# Imports

In [1]:
import os
import importlib
import rippl_AI
import aux_fcn
import matplotlib.pyplot as plt
import numpy as np
# Re-import aux_fcn so that edits made to the module on disk are picked up
# without restarting the kernel.
importlib.reload(aux_fcn)
# IPython magic: select matplotlib's Qt backend, so figures open in separate
# interactive windows instead of being rendered inline.
%matplotlib qt

# Section A
In this section, a usage example of the predict and detection functions is provided.

### Download data

In [None]:
import os
from figshare.figshare.figshare import Figshare
# Download the example session(s) from Figshare unless they are already on disk.
fshare = Figshare()

article_ids = [14959449]   # Figshare article id of each session
sess = ['Dlx1']            # Local folder name of each session (same order as article_ids)

# The loop variable is called `article_id` rather than `id` so the Python
# builtin `id()` is not shadowed in the kernel namespace after this cell runs.
for article_id, s in zip(article_ids, sess):
    datapath = os.path.join('Downloaded_data', s)
    # Only the existence of the session folder is checked: an interrupted
    # download leaves the folder behind and will be reported as already present.
    if os.path.isdir(datapath):
        print("Data already exists. Moving on.")
    else:
        print("Downloading data... Please wait, this might take up some time")        # Can take up to 10 minutes
        fshare.retrieve_files_from_article(article_id, directory=datapath)
        print("Data downloaded!")

### Data loading

In [2]:
# Folder that holds the downloaded Dlx1 session
path = os.path.join('Downloaded_data', 'Dlx1', 'figshare_14959449')

# Session metadata and channel map, both read from the session folder
sf, expName, ref_channels, dead_channels = aux_fcn.load_info(path)
channels_map = aux_fcn.load_channels_map(path)

# Reformat the channel map and reference channels into the values used for loading
channels, shanks, ref_channels = aux_fcn.reformat_channels(channels_map, ref_channels)
print('Channel map value: ', channels)

# Read the .dat recording for the selected channels (verbose prints file details)
LFP = aux_fcn.load_raw_data(path, expName, channels, verbose=True)
print('Sampling frequency: ', sf)
print('Shape of the original data', LFP.shape)
print(channels_map)


Channel map value:  [0, 1, 2, 3, 4, 5, 6, 7]
Downloaded_data\Dlx1\figshare_14959449/lfp_Dlx1-2021-02-12_12-46-54.dat
fileStart  0
fileStop  490242048
nSamples  245121024
nSamplesPerChannel  30640128
nSamplesPerChunk  10000
size data  30640128
Sampling frequency:  30000
Shape of the original data (30640128, 8)
[[1 1 2]
 [2 1 2]
 [3 1 2]
 [4 1 2]
 [5 1 2]
 [6 1 2]
 [7 1 2]
 [8 1 2]]


In [3]:
# Run the default model on the loaded recording.
# The predict function takes care of normalizing and subsampling your data
# If no architecture or model is specified, the best CNN1D will be used
# Outputs: `prob` and `LFP_norm` — presumably the model's ripple probability and
# the normalized LFP it was computed from (TODO confirm against rippl_AI.predict).
prob,LFP_norm=rippl_AI.predict(LFP,sf) 

[0 1 2 3 4 5 6 7]
Original LFP shape:  (30640128, 8)
CNN1D_1_Ch8_W60_Ts16_OGmodel12
8 8


In [16]:
# Turn the model output into detected events. No threshold is passed here, so
# the threshold is chosen at run time (th = threshold).
# Known problem (noted by the author): this cell must be executed twice to choose the threshold
#   1: launch, choose th, save
#   2: launch again, any th. The th chosen in step 1 will be saved
det_ind=rippl_AI.get_intervals(prob,LFP_norm=LFP_norm)
# det_ind holds one row per detected event, so shape[0] is the number of events
print(f"{det_ind.shape[0]} events where detected")

248 events where detected


# Section B
In this section every available model is used to predict, get_intervals is applied automatically to each output, and the performance metric (F1) is plotted.

### Data download

In [17]:
import os
from figshare.figshare.figshare import Figshare
# Download the example session(s) from Figshare unless they are already on disk.
fshare = Figshare()

article_ids = [14959449]   # Figshare article id of each session
sess = ['Dlx1']            # Local folder name of each session (same order as article_ids)

# The loop variable is called `article_id` rather than `id` so the Python
# builtin `id()` is not shadowed in the kernel namespace after this cell runs.
for article_id, s in zip(article_ids, sess):
    datapath = os.path.join('Downloaded_data', s)
    # Only the existence of the session folder is checked: an interrupted
    # download leaves the folder behind and will be reported as already present.
    if os.path.isdir(datapath):
        print("Data already exists. Moving on.")
    else:
        print("Downloading data... Please wait, this might take up some time")        # Can take up to 10 minutes
        fshare.retrieve_files_from_article(article_id, directory=datapath)
        print("Data downloaded!")

Data already exists. Moving on.


### Data loading

In [3]:
# Folder that holds the downloaded Dlx1 session
path = os.path.join('Downloaded_data', 'Dlx1', 'figshare_14959449')

# Session metadata and channel map, both read from the session folder
sf, expName, ref_channels, dead_channels = aux_fcn.load_info(path)
channels_map = aux_fcn.load_channels_map(path)

# Ground truth (GT) tagged events, divided by the sampling frequency
# (presumably converting sample indices into seconds — TODO confirm)
ripples = aux_fcn.load_ripples(path) / sf

# Reformat the channel map and reference channels into the values used for loading
channels, shanks, ref_channels = aux_fcn.reformat_channels(channels_map, ref_channels)
print('Channel map value: ', channels)

# Read the .dat recording for the selected channels (verbose prints file details)
LFP = aux_fcn.load_raw_data(path, expName, channels, verbose=True)
print('Sampling frequency: ', sf)
print('Shape of the original data', LFP.shape)

Channel map value:  [0, 1, 2, 3, 4, 5, 6, 7]
Downloaded_data\Dlx1\figshare_14959449/lfp_Dlx1-2021-02-12_12-46-54.dat
fileStart  0
fileStop  490242048
nSamples  245121024
nSamplesPerChannel  30640128
nSamplesPerChunk  10000
size data  30640128
Sampling frequency:  30000
Shape of the original data (30640128, 8)


In [5]:
# Two loops going over every possible model: each architecture, and each of
# its numbered models, is run on the same recording.
importlib.reload(rippl_AI)
importlib.reload(aux_fcn)
architectures = ['XGBOOST', 'SVM', 'LSTM', 'CNN1D', 'CNN2D']
n_models = 5   # models per architecture; model_number runs from 1 to n_models

# Build one independent row per architecture. `[[None]*5]*5` would instead
# repeat a reference to a single inner list, so every architecture would write
# into the same five slots and only the last architecture's outputs would survive.
SWR_prob = [[None] * n_models for _ in architectures]
for i, architecture in enumerate(architectures):
    print(i, architecture)
    for n in range(1, n_models + 1):
        # Make sure the number of channels the selected model expects is the same
        # as the length of the channels array passed to the predict fcn.
        # Here CNN2D models 3 to 5 are given 3 channels; every other model gets all 8.
        # A dedicated name is used so the `channels` loaded earlier is not overwritten.
        if architecture == 'CNN2D' and n >= 3:
            model_channels = [0, 3, 7]
        else:
            model_channels = [0, 1, 2, 3, 4, 5, 6, 7]
        SWR_prob[i][n-1], _ = rippl_AI.predict(LFP, sf, arch=architecture, model_number=n, channels=model_channels)

# SWR_prob[i][n-1] contains the output of model number n of architectures[i]


0 XGBOOST
[0, 1, 2, 3, 4, 5, 6, 7]
Original LFP shape:  (30640128, 8)
Downsampling data at 1250 Hz...
Shape of downsampled data: (1276672, 8)
Normalizing data...
XGBOOST_1_Ch8_W60_Ts016_D7_Lr0.10_G0.25_L10_SCALE1
8 8
[0, 1, 2, 3, 4, 5, 6, 7]
Original LFP shape:  (30640128, 8)
Downsampling data at 1250 Hz...
Shape of downsampled data: (1276672, 8)
Normalizing data...
XGBOOST_2_Ch8_W60_Ts016_D7_Lr0.10_G0.00_L10_SCALE5
8 8
[0, 1, 2, 3, 4, 5, 6, 7]
Original LFP shape:  (30640128, 8)
Downsampling data at 1250 Hz...
Shape of downsampled data: (1276672, 8)
Normalizing data...
XGBOOST_3_Ch8_W60_Ts016_D7_Lr0.10_G0.25_L10_SCALE3
8 8
[0, 1, 2, 3, 4, 5, 6, 7]
Original LFP shape:  (30640128, 8)
Downsampling data at 1250 Hz...
Shape of downsampled data: (1276672, 8)
Normalizing data...
XGBOOST_4_Ch8_W60_Ts016_D7_Lr0.10_G0.25_L10_SCALE5
8 8
[0, 1, 2, 3, 4, 5, 6, 7]
Original LFP shape:  (30640128, 8)
Downsampling data at 1250 Hz...
Shape of downsampled data: (1276672, 8)
Normalizing data...
XGBOOST_5_

KeyboardInterrupt: 

In [5]:
# Sweep the detection threshold for every model and plot F1 against it.
# Grid layout: one row per architecture, one column per model number.
th_arr = np.linspace(0.1, 1, 10)
fig, axs = plt.subplots(5, 5, figsize=(10, 10), sharex='all', sharey='all')
for i in range(5):
    for j in range(5):
        # F1 score of model (i, j) at each threshold in th_arr
        F1_arr = np.zeros(len(th_arr))
        for k, th in enumerate(th_arr):
            # Events detected at this threshold, scored against the ground-truth ripples
            det_ind = rippl_AI.get_intervals(SWR_prob[i][j], threshold=th)
            # Only the third returned value is kept; it is stored as the F1 score
            _, _, F1_arr[k], _, _, _ = aux_fcn.get_performance(det_ind, ripples)
        axs[i, j].plot(th_arr, F1_arr)
    # The architecture name is written above the first panel of its row
    axs[i, 0].set_title(architectures[i])

# Axes are shared across the grid, so the labels are set on one panel only
axs[0, 0].set_xlabel('Threshold')
axs[0, 0].set_ylabel('F1')

[3.41276585e-09 3.41276585e-09 3.41276585e-09 ... 0.00000000e+00
 0.00000000e+00 0.00000000e+00] 0.1
precision = 0.6796536796536796
recall = 0.7488151658767772
F1 = 0.7125601436265709
[3.41276585e-09 3.41276585e-09 3.41276585e-09 ... 0.00000000e+00
 0.00000000e+00 0.00000000e+00] 0.2
precision = 0.7272727272727273
recall = 0.7203791469194313
F1 = 0.7238095238095238
[3.41276585e-09 3.41276585e-09 3.41276585e-09 ... 0.00000000e+00
 0.00000000e+00 0.00000000e+00] 0.30000000000000004
precision = 0.7326732673267327
recall = 0.7014218009478673
F1 = 0.7167070217917677
[3.41276585e-09 3.41276585e-09 3.41276585e-09 ... 0.00000000e+00
 0.00000000e+00 0.00000000e+00] 0.4
precision = 0.745
recall = 0.6966824644549763
F1 = 0.7200315587041206
[3.41276585e-09 3.41276585e-09 3.41276585e-09 ... 0.00000000e+00
 0.00000000e+00 0.00000000e+00] 0.5
precision = 0.7807486631016043
recall = 0.6872037914691943
F1 = 0.7309956665112826
[3.41276585e-09 3.41276585e-09 3.41276585e-09 ... 0.00000000e+00
 0.00000000e

  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


precision = nan
recall = 0.0
F1 = nan
[0.00457716 0.00457716 0.00457716 ... 0.00087281 0.00087281 0.00087281] 0.1
precision = 0.17195767195767195
recall = 0.9146919431279621
F1 = 0.2894922059790698
[0.00457716 0.00457716 0.00457716 ... 0.00087281 0.00087281 0.00087281] 0.2
precision = 0.3049907578558225
recall = 0.7677725118483412
F1 = 0.43656140522795756
[0.00457716 0.00457716 0.00457716 ... 0.00087281 0.00087281 0.00087281] 0.30000000000000004
precision = 0.44155844155844154
recall = 0.6398104265402844
F1 = 0.5225112413910866
[0.00457716 0.00457716 0.00457716 ... 0.00087281 0.00087281 0.00087281] 0.4
precision = 0.5766871165644172
recall = 0.4454976303317536
F1 = 0.5026737967914439
[0.00457716 0.00457716 0.00457716 ... 0.00087281 0.00087281 0.00087281] 0.5
precision = 0.7358490566037735
recall = 0.1800947867298578
F1 = 0.28936834911646975
[0.00457716 0.00457716 0.00457716 ... 0.00087281 0.00087281 0.00087281] 0.6
precision = 0.9166666666666666
recall = 0.05213270142180093
F1 = 0.0986

Text(0, 0.5, 'F1')

# Section C
This section shows how to use the interpolation function. It is generally called internally from LFP_predict(), which could cause some confusion.

### Data download

In [53]:
import os
from figshare.figshare.figshare import Figshare
# Download the example session(s) from Figshare unless they are already on disk.
fshare = Figshare()

article_ids = [14959449]   # Figshare article id of each session
sess = ['Dlx1']            # Local folder name of each session (same order as article_ids)

# The loop variable is called `article_id` rather than `id` so the Python
# builtin `id()` is not shadowed in the kernel namespace after this cell runs.
for article_id, s in zip(article_ids, sess):
    datapath = os.path.join('Downloaded_data', s)
    # Only the existence of the session folder is checked: an interrupted
    # download leaves the folder behind and will be reported as already present.
    if os.path.isdir(datapath):
        print("Data already exists. Moving on.")
    else:
        print("Downloading data... Please wait, this might take up some time")        # Can take up to 10 minutes
        fshare.retrieve_files_from_article(article_id, directory=datapath)
        print("Data downloaded!")

Data already exists. Moving on.


### Data load
To illustrate how 'interpolate_channels' can be used to obtain the desired number of channels, we will simulate two cases using the Dlx1 session:
1. We are using a recording probe that provides only 4 channels, when we need 8.
2. Some channels are dead or have too much noise.

In [64]:
# Load the Dlx1 session and derive the two simulated inputs used to
# demonstrate channel interpolation.
path = os.path.join('Downloaded_data', 'Dlx1', 'figshare_14959449')

sf, expName, ref_channels, dead_channels = aux_fcn.load_info(path)

channels_map = aux_fcn.load_channels_map(path)
channels, shanks, ref_channels = aux_fcn.reformat_channels(channels_map, ref_channels)
LFP = aux_fcn.load_raw_data(path, expName, channels, verbose=False)
print('Sampling frequency: ', sf)
print('Shape of the original data',LFP.shape)

# Case 1: a probe with only 4 channels. Indexing with a list of columns
# returns a new array, so LFP itself is not affected.
LFP_linear = LFP[:, [0, 2, 4, 6]]
print('Shape of the 4 channels simulated data: ',LFP_linear.shape)

# Case 2: channels 2 and 5 are dead. The zeros are written into a copy so the
# original recording in `LFP` stays intact; assigning `LFP_dead = LFP` would
# only create a second name for the same (already modified) array.
# NOTE(review): if load_raw_data returns a writable memory-mapped array,
# writing into LFP directly could also modify the file on disk — confirm.
LFP_dead = np.array(LFP)
LFP_dead[:, [2, 5]] = 0
print('Sample of the simulated dead LFP: ',LFP_dead[0])

Sampling frequency:  30000
Shape of the original data (30640128, 8)
Shape of the 4 channels simulated data:  (30640128, 4)
Sample of the simulated dead LFP:  [-111 -113    0 -138 -152    0 -138 -151]


After interpolation, the data is ready to use in prediction

In [61]:
importlib.reload(aux_fcn)

# Target layout of the 8 output channels. Non-negative entries presumably give
# the column of LFP_linear to place at that position, and -1 marks a position
# to be filled by interpolation (TODO confirm against aux_fcn.interpolate_channels).
channels_interpolation = [0, -1, 1, -1, 2, -1, -1, 3]

# Expand the simulated 4-channel recording to 8 channels
LFP_interpolated = aux_fcn.interpolate_channels(LFP_linear, channels_interpolation)
print('Shape of the interpolated LFP: ', LFP_interpolated.shape)

Shape of the interpolated LFP:  (30640128, 8)


In [66]:
 # Define channels
channels_interpolation = [0,1,-1,3,4,-1,6,7]

# Make interpolation
LFP_interpolated = aux_fcn.interpolate_channels(LFP_dead, channels_interpolation)
print('Value of the 1st sample of the interpolated LFP: ',LFP_interpolated[0])

[-111.  -113.  -125.5 -138.  -152.  -145.  -138.  -151. ]
