# Imports

In [3]:
import os
import importlib
import rippl_AI
import aux_fcn
import matplotlib.pyplot as plt
import numpy as np
importlib.reload(aux_fcn)
%matplotlib qt

# Section A
In this section, a use example of the predict and detection functions are provided

### Download data

In [None]:
import os
from figshare.figshare.figshare import Figshare
fshare = Figshare()

# Array with the data IDs from figshare.
#     Remove the session IDs and number if you dont want to download al the data
#             Amigo2_1  Som2     Dlx1_1   Thy7_1
article_ids = [14959449] # This is the ID of the data repository: [16847521,16856137,14959449,14960085]
sess=['Dlx1']                                    # [-2,-1,0,1]   
for id,s in zip(article_ids,sess):
    datapath = os.path.join('Downloaded_data', f'{s}')
    if os.path.isdir(datapath):
        print("Data already exists. Moving on.")
    else:
        print("Downloading data... Please wait, this might take up some time")        # Can take up to 10 minutes
        fshare.retrieve_files_from_article(id,directory=datapath)
        print("Data downloaded!")

### Data loading

In [3]:
path=os.path.join('Downloaded_data','Dlx1','figshare_14959449')

sf, expName, ref_channels, dead_channels = aux_fcn.load_info(path)
channels_map = aux_fcn.load_channels_map(path)

# Reformat channels into correct values
channels, shanks, ref_channels = aux_fcn.reformat_channels(channels_map, ref_channels)
# Read .dat
print('Channel map value: ',channels)
data = aux_fcn.load_raw_data(path, expName, channels, verbose=True)
print('Sampling frequency: ', sf)
print('Shape of the original data',data.shape)
print(channels_map)


Channel map value:  [0, 1, 2, 3, 4, 5, 6, 7]
Downloaded_data\Dlx1\figshare_14959449/lfp_Dlx1-2021-02-12_12-46-54.dat
fileStart  0
fileStop  490242048
nSamples  245121024
nSamplesPerChannel  30640128
nSamplesPerChunk  10000
size data  30640128
Sampling frequency:  30000
Shape of the original data (30640128, 8)
[[1 1 2]
 [2 1 2]
 [3 1 2]
 [4 1 2]
 [5 1 2]
 [6 1 2]
 [7 1 2]
 [8 1 2]]


In [11]:
# The predict function takes care of normalizing and subsampling of your data
importlib.reload(aux_fcn)
importlib.reload(rippl_AI)

channels=[0,1,2,3,4,5,6,7]
# If no architecture or model is specified, the best CNN1D will be used
prob,LFP_norm=rippl_AI.predict(data,sf) 

[0 1 2 3 4 5 6 7]
Original LFP shape:  (30640128, 8)
Downsampling data at 1250 Hz...
Shape of downsampled data: (1276672, 8)
Normalizing data...
CNN1D_1_Ch8_W60_Ts16_OGmodel12
8 8


In [16]:
det_ind=rippl_AI.get_intervals(prob,LFP_norm=LFP_norm)  # Hay un problema: debe ejecutarse 2 veces para elegir el umbral, no tengo tiemppo de resolverlo
                                      # 1: se ejecuta una vez, se escoge el umbral, se pulsa save
                                      # 2: se ejecuta una vez, se escoge un umbral cualquiera
print(f"{det_ind.shape[0]} events where detected")

248 events where detected


# Section B
Every model predict, get_intervals is used automatically and the performance metric is ploted

### Data download

In [17]:
import os
from figshare.figshare.figshare import Figshare
fshare = Figshare()

# Array with the data IDs from figshare.
#     Remove the session IDs and number if you dont want to download al the data
#             Amigo2_1  Som2     Dlx1_1   Thy7_1
article_ids = [14959449] # This is the ID of the data repository: [16847521,16856137,14959449,14960085]
sess=['Dlx1']                                    # [-2,-1,0,1]   
for id,s in zip(article_ids,sess):
    datapath = os.path.join('Downloaded_data', f'{s}')
    if os.path.isdir(datapath):
        print("Data already exists. Moving on.")
    else:
        print("Downloading data... Please wait, this might take up some time")        # Can take up to 10 minutes
        fshare.retrieve_files_from_article(id,directory=datapath)
        print("Data downloaded!")

Data already exists. Moving on.


### Data loading

In [18]:
path=os.path.join('Downloaded_data','Dlx1','figshare_14959449')

sf, expName, ref_channels, dead_channels = aux_fcn.load_info(path)

channels_map = aux_fcn.load_channels_map(path)
# Now the ground truth (GT) tagged events is loaded 
ripples=aux_fcn.load_ripples(path)/sf
# Reformat channels into correct values
channels, shanks, ref_channels = aux_fcn.reformat_channels(channels_map, ref_channels)
# Read .dat
print('Channel map value: ',channels)
data = aux_fcn.load_raw_data(path, expName, channels, verbose=True)
print('Sampling frequency: ', sf)
print('Shape of the original data',data.shape)

Channel map value:  [0, 1, 2, 3, 4, 5, 6, 7]
Downloaded_data\Dlx1\figshare_14959449/lfp_Dlx1-2021-02-12_12-46-54.dat
fileStart  0
fileStop  490242048
nSamples  245121024
nSamplesPerChannel  30640128
nSamplesPerChunk  10000
size data  30640128
Sampling frequency:  30000
Shape of the original data (30640128, 8)


In [19]:
# Two loops going over every possible model
importlib.reload(rippl_AI)
architectures=['XGBOOST','SVM','LSTM','CNN1D','CNN2D']
SWR_prob=[[None]*5]*5
for i,architecture in enumerate(architectures):
    for n in range(1,6):
        # Make sure the selected model expected number of channels is the same as the channels array passed to the predict fcn
        # In this case, we are manually setting the channel array to 3 
        if architecture=='CNN2D' and n>=3:
            channels=[0,3,7]
        else:
            channels=[0,1,2,3,4,5,6,7]
        SWR_prob[i][n-1],_=rippl_AI.predict(data,sf,arch=architecture,model_number=n,channels=channels)

# SWR_prob contains the output of each model


[0, 1, 2, 3, 4, 5, 6, 7]
Original LFP shape:  (30640128, 8)
Downsampling data at 1250 Hz...
Shape of downsampled data: (1276672, 8)
Normalizing data...
XGBOOST_1_Ch8_W60_Ts016_D7_Lr0.10_G0.25_L10_SCALE1
8 8
[0, 1, 2, 3, 4, 5, 6, 7]
Original LFP shape:  (30640128, 8)
Downsampling data at 1250 Hz...
Shape of downsampled data: (1276672, 8)
Normalizing data...
XGBOOST_2_Ch8_W60_Ts016_D7_Lr0.10_G0.00_L10_SCALE5
8 8
[0, 1, 2, 3, 4, 5, 6, 7]
Original LFP shape:  (30640128, 8)
Downsampling data at 1250 Hz...
Shape of downsampled data: (1276672, 8)
Normalizing data...
XGBOOST_3_Ch8_W60_Ts016_D7_Lr0.10_G0.25_L10_SCALE3
8 8
[0, 1, 2, 3, 4, 5, 6, 7]
Original LFP shape:  (30640128, 8)
Downsampling data at 1250 Hz...
Shape of downsampled data: (1276672, 8)
Normalizing data...
XGBOOST_4_Ch8_W60_Ts016_D7_Lr0.10_G0.25_L10_SCALE5
8 8
[0, 1, 2, 3, 4, 5, 6, 7]
Original LFP shape:  (30640128, 8)
Downsampling data at 1250 Hz...
Shape of downsampled data: (1276672, 8)
Normalizing data...
XGBOOST_5_Ch8_W60_Ts

(array([1263992,    1320,     640,     520,     680,     840,    1000,
            800,    1440,    5440], dtype=int64),
 array([0.        , 0.09999534, 0.19999069, 0.29998603, 0.39998138,
        0.49997672, 0.59997207, 0.69996741, 0.79996276, 0.8999581 ,
        0.99995345]))

In [23]:
th_arr=np.linspace(0.1,1,10)
fig,axs=plt.subplots(5,5,figsize=(10,5),sharex='all',sharey='all')
for i in range(5):
    for j in range(5):
        F1_arr=np.zeros(shape=(len(th_arr)))
        for k,th in enumerate(th_arr):
            det_ind=rippl_AI.get_intervals(SWR_prob[i][j],threshold=th)
            #print(det_ind.shape)
            #print(ripples)
            _,_,F1_arr[k],_,_,_=aux_fcn.get_performance(det_ind,ripples)
        axs[i,j].plot(th_arr,F1_arr)

            

        

[3.41276585e-09 3.41276585e-09 3.41276585e-09 ... 0.00000000e+00
 0.00000000e+00 0.00000000e+00] 0.1
precision = 0.6796536796536796
recall = 0.7488151658767772
F1 = 0.7125601436265709
[3.41276585e-09 3.41276585e-09 3.41276585e-09 ... 0.00000000e+00
 0.00000000e+00 0.00000000e+00] 0.2
precision = 0.7272727272727273
recall = 0.7203791469194313
F1 = 0.7238095238095238
[3.41276585e-09 3.41276585e-09 3.41276585e-09 ... 0.00000000e+00
 0.00000000e+00 0.00000000e+00] 0.30000000000000004
precision = 0.7326732673267327
recall = 0.7014218009478673
F1 = 0.7167070217917677
[3.41276585e-09 3.41276585e-09 3.41276585e-09 ... 0.00000000e+00
 0.00000000e+00 0.00000000e+00] 0.4
precision = 0.745
recall = 0.6966824644549763
F1 = 0.7200315587041206
[3.41276585e-09 3.41276585e-09 3.41276585e-09 ... 0.00000000e+00
 0.00000000e+00 0.00000000e+00] 0.5
precision = 0.7807486631016043
recall = 0.6872037914691943
F1 = 0.7309956665112826
[3.41276585e-09 3.41276585e-09 3.41276585e-09 ... 0.00000000e+00
 0.00000000e