# CNN example

Some code to show how the CNN finds ripples.

In [None]:
# Imports
import sys
import numpy as np

# Make the modules stored in ../cnn/ importable from this notebook.
# NOTE(review): the later `load_data` / `format_predictions` imports presumably
# resolve to that folder - confirm.
sys.path.insert(1, '../cnn/')

## Download data

Downloads data from a Figshare repository. Example data can be found under the following article IDs:
- [Thy7_2020-11-11_16-05-00](https://figshare.com/articles/dataset/Thy7_2020-11-11_16-05-00/14960085): 14960085
- [Dlx1_2021-02-12_12-46-54](https://figshare.com/articles/dataset/Dlx1_2021-02-12_12-46-54/14959449): 14959449

*If you have your own data you can skip this step.*

In [None]:
import os
sys.path.insert(1, '../../figshare')
from figshare import Figshare

# Client used to fetch the files of a Figshare article
fshare = Figshare()

# ID of the data repository (Figshare article) to use
article_id = 14960085

# Folder whose presence decides whether the download can be skipped
datapath = "figshare_%d" % article_id

if not os.path.isdir(datapath):
    # Nothing on disk yet: fetch every file of the article
    print("Downloading data... Please wait")
    fshare.retrieve_files_from_article(article_id)
    print("Data downloaded!")
else:
    print("Data already exists. Moving on.")


## Load data

Loads data from the path specified in *datapath*. 

If you are using **your own data** please replace this code with your own methods to load it. By the end of this cell, loaded data must comply with the following conditions:
- It has to be a 2D numpy array (numpy.ndarray) with dimensions **(Number of samples x Number of channels)**.
- Number of channels has to be 8 (corresponding to a shank).
- The variable containing the loaded data must be named **loaded_data**.
- Data sampling rate must be saved in a variable named **fs** (in Hz).

In [None]:
'''''''''''''''''
Load data from the figshare files
'''''''''''''''''

from load_data import load_data

# Shank to load; the value is passed straight to load_data
shank = 1

# Load data
# load_data returns the recording together with its sampling rate (Hz),
# which the following cells expect in `loaded_data` and `fs`
print("Loading data...", end=" ")
loaded_data, fs = load_data(path=datapath, shank=shank)
print("Done!")

# Sanity check: expected layout is (number of samples x number of channels)
print("Shape of loaded data: ", np.shape(loaded_data))

In [None]:
'''''''''''''''''
If you have your 
own data use this 
cell to load it
'''''''''''''''''

Loaded data will be downsampled to **1250 Hz** and then normalized using **z-score** by channels. Afterwards, it will be separated into **12.8 ms** windows that will be the input for the CNN. By default, consecutive windows **overlap by 6.4 ms**. Overlapping can be avoided by setting the *overlapping* variable to *False*.

In [None]:
from load_data import z_score_normalization, downsample_data
# Downsample data
# Target sampling rate in Hz; every later cell works at this rate
downsampled_fs = 1250
print("Downsampling data from %d Hz to %d Hz..."%(fs, downsampled_fs), end=" ")
# Always starts from `loaded_data`, so re-running this cell does not
# downsample or normalize the same array twice
data = downsample_data(loaded_data, fs, downsampled_fs)
print("Done!")

# Normalize it with z-score
print("Normalizing data...", end=" ")
data = z_score_normalization(data)
print("Done!")

# Sanity check of the array that is windowed and plotted later
print("Shape of loaded data after downsampling and z-score: ", np.shape(data))

In [None]:
# Set to False to feed the CNN without overlapping windows
overlapping = True

# Window length in seconds (12.8 ms)
window_size = 0.0128
# Step between consecutive windows in seconds: 6.4 ms when overlapping,
# otherwise a full window
stride = 0.0064 if overlapping else window_size

print("Generating windows...", end=" ")
if not overlapping:
    # Only add a leading axis: the whole recording is passed as a single item
    X = np.expand_dims(data, 0)
else:
    from load_data import generate_overlapping_windows

    # Separate the data into 12.8ms windows with 6.4ms overlapping
    X = generate_overlapping_windows(data, window_size, stride, downsampled_fs)
print("Done!")

## Load trained CNN model

In [None]:
import tensorflow.keras.backend as K
import tensorflow.keras as kr

print("Loading CNN model...", end=" ")

# Optimizer settings, spelled out explicitly
adam_settings = dict(
    learning_rate=0.001,
    beta_1=0.9,
    beta_2=0.999,
    epsilon=1e-07,
    amsgrad=False,
)
optimizer = kr.optimizers.Adam(**adam_settings)

# Load the saved model without its stored training configuration,
# then compile it here with the optimizer defined above
model = kr.models.load_model("../../model", compile=False)
model.compile(optimizer=optimizer, loss="binary_crossentropy")

print("Done!")

## Detecting ripples with CNN model

The CNN will make a prediction for every window. Each prediction is a **number between 0 and 1**, representing the **probability** that the window contains a ripple according to the CNN.

In [None]:
# Run the loaded model on the windows in X. Per the notebook text, the
# output is one value between 0 and 1 per window (ripple probability).
print("Detecting ripples...", end=" ")
predictions = model.predict(X, verbose=True)
print("Done!")

### Get detected ripples and times

Gets the detected ripples times, both in seconds and in indexes of the downsampled data array. The resulting times are a 2D numpy array with dimensions **(Number of detections x 2)**, having the starting and ending times for each detection.

Every window whose associated probability is **over a given threshold** is considered a ripple detection. This threshold can be changed by the user and may vary from one session to another. With a high threshold, only those events that the CNN strongly believes to be ripples are kept, so some true ripples may be missed. A lower threshold will capture more events, but some of them may not be real ripples.

In [None]:
from format_predictions import get_predictions_indexes

# This threshold can be changed
# Windows whose predicted probability is over this value count as detections
threshold = 0.7

print("Getting detected ripples indexes and times...", end=" ")
# Start/end of every detection as sample indexes of the downsampled data
# (per the notebook text: number of detections x 2)
pred_indexes = get_predictions_indexes(data, predictions, window_size=window_size, stride=stride, fs=downsampled_fs, threshold=threshold)

# Same detections expressed in seconds
pred_times = pred_indexes / downsampled_fs
print("Done!")

### Plot ripple detections

This is an interactive plot of the loaded data, where detected ripples are shown in blue. Data is displayed in chunks of 1 second and you can **move forward, backwards or jump to a specific second** using the control bar at the bottom.

In [None]:
%matplotlib widget
from ipywidgets import interact, Layout
import ipywidgets as widgets
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker

k = 0
data_size = data.shape[0]
data_dur = data_size / downsampled_fs
ini_idx = int(k * downsampled_fs)
end_idx = np.minimum(int((k+1) * downsampled_fs), data_size-1)
times = np.arange(data_size) / downsampled_fs

pos_mat = list(range(data.shape[1]-1, -1, -1)) * np.ones((end_idx-ini_idx, data.shape[1]))

fig = plt.figure(figsize=(9.75,5))
ax = fig.add_subplot(1, 1, 1)
ax.set_ylim(-3, 9)
ax.margins(x=0)
plt.tight_layout()

lines = ax.plot(times[ini_idx:end_idx], data[ini_idx:end_idx, :]*1/np.max(data[ini_idx:end_idx, :], axis=0) + pos_mat, color='k', linewidth=1)

fills = []
for pred in pred_indexes:
    if (pred[0] >= ini_idx and pred[0] <= end_idx) or (pred[1] >= ini_idx and pred[1] <= end_idx):
        rip_ini = (pred[0] - ini_idx) / downsampled_fs
        rip_end = (pred[1] - ini_idx) / downsampled_fs
        fill = ax.fill_between([rip_ini, rip_end], [-3, -3], [9, 9], color="tab:blue", alpha=0.3)
        fills.append(fill)


def update(k=0):
    ini_idx = int(k * downsampled_fs)
    end_idx = np.minimum(int((k+1) * downsampled_fs), data_size-1)

    label_format = '{:,.1f}'
    ax.xaxis.set_major_locator(mticker.MaxNLocator(6))
    x_ticks = ax.get_xticks().tolist()
    ax.xaxis.set_major_locator(mticker.FixedLocator(x_ticks))
    ax.set_xticklabels([label_format.format(x+k) for x in x_ticks])
    
    for i in range(len(lines)):
        #lines[i].set_xdata(times[ini_idx:end_idx])
        lines[i].set_ydata(data[ini_idx:end_idx, i]*1/np.max(data[ini_idx:end_idx, i], axis=0) + pos_mat[:,i])

    for fill in fills:
        fill.remove()
    
    fills.clear()
    for pred in pred_indexes:
        if (pred[0] >= ini_idx and pred[0] <= end_idx) or (pred[1] >= ini_idx and pred[1] <= end_idx):
            rip_ini = (pred[0] - ini_idx) / downsampled_fs
            rip_end = (pred[1] - ini_idx) / downsampled_fs
            fill = ax.fill_between([rip_ini, rip_end], [-3, -3], [9, 9], color="tab:blue", alpha=0.3)
            fills.append(fill)
    
    fig.canvas.draw_idle()

interact(update, k = widgets.BoundedIntText(value=17, min=0, max=data_dur, step=1, description="Second:", layout=Layout(width='900px')));