In [1]:
RunningInCOLAB = 'google.colab' in str(get_ipython())
if RunningInCOLAB:
    !git clone https://github.com/MJC598/Neuron_Burst_Analysis.git

In [2]:
%matplotlib widget
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
import numpy as np
from numpy.random import MT19937
from numpy.random import RandomState, SeedSequence
from numpy.random import default_rng
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import scipy.io
from scipy import signal
import random
import time
import pandas as pds
from sklearn.metrics import r2_score
import copy

s = 67

rs = RandomState(MT19937(SeedSequence(s)))
rng = default_rng(seed=s)
torch.manual_seed(s)

plt.rcParams.update({'font.size': 32})

![Fully Connected Network](graphs/fullyConnectedNetwork.png)

In [3]:
class FCN(nn.Module):
    def __init__(self, in_size, h_size, out_size):
        super(FCN,self).__init__()
        self.fc1 = nn.Linear(in_size, h_size)
        self.tanh = nn.Tanh()
#         self.fc2 = nn.Linear(h_size, 40)
#         self.fc3 = nn.Linear(40, 15)
        self.fc2 = nn.Linear(h_size, out_size)
    def forward(self, x):
        x = self.tanh(self.fc1(x))
#         x = self.tanh(self.fc2(x))
#         x = self.tanh(self.fc3(x))
        out = self.fc2(x)
        return out

In [4]:
#Formerly Conv1dFCN
class LFPNet1C(nn.Module):
    def __init__(self, in_size, h_size, out_size):
        super(LFPNet1C, self).__init__()
        self.cn1 = nn.Conv1d(1, 1, kernel_size=5,padding=2)
        self.cn2 = nn.Conv1d(1, 1, kernel_size=3,padding=1)
        self.bn = nn.BatchNorm1d(1)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool1d(2)
        self.fc1 = nn.Linear(int(in_size), h_size)
        self.tanh = nn.Tanh()
        self.fc2 = nn.Linear(h_size, out_size)
    
    def forward(self,x):
        residual = x
        x = self.relu(self.cn1(x))
        x = self.relu(self.cn2(x))
        x += residual
        
        x = self.relu(self.fc1(x))
        out = self.fc2(x)
        return out

In [5]:
class LFPNetMC(nn.Module):
    def __init__(self, in_size, h_size, out_size):
        super(LFPNetMC, self).__init__()
        self.cn1 = nn.Conv1d(2, 2, kernel_size=5,padding=2)
        self.cn2 = nn.Conv1d(2, 2, kernel_size=3,padding=1)
        self.cn3 = nn.Conv1d(2, 1, kernel_size=1)
        self.bn = nn.BatchNorm1d(1)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool1d(2)
        self.fc1 = nn.Linear(int(in_size), h_size)
        self.tanh = nn.Tanh()
        self.fc2 = nn.Linear(h_size, out_size)
    
    def forward(self,x):
        residual = x
        x = self.relu(self.cn1(x))
        x = self.relu(self.cn2(x))
        x += residual
        x = self.relu(self.cn3(x))
        
        x = self.relu(self.fc1(x))
        out = self.fc2(x)
        return out

In [6]:
# Specific Model Parameters
INPUT_FEATURES = 1
PREVIOUS_TIME = 50
LOOK_AHEAD = 20
INPUT_SIZE = INPUT_FEATURES * PREVIOUS_TIME
HIDDEN_SIZE = 95
OUTPUT_SIZE = 1
BATCH_SIZE = 512
BATCH_FIRST = True
DROPOUT = 0.0
EPOCHS = 20

MODEL = LFPNetMC
MODEL_NAME = 'LFPNetMC'
OUTPUT = 'FilteredLFP'
INPUT = 'ITNFRRawLFP'
LOSS_FILE = ('losses/bursts/losses_' + str(MODEL_NAME) + 
             '_' + INPUT + str(PREVIOUS_TIME) + '_' + OUTPUT + str(LOOK_AHEAD) + '.csv')
PATH = ('models/LFPNet/' + str(MODEL_NAME) + '_' + INPUT + str(PREVIOUS_TIME) +
        '_' + OUTPUT + str(LOOK_AHEAD) + '.pt')
# PATH = 'models/<class \'__main__.FCN\'>_FR_LFP_-50_90_full.pth'

DATA_PATH = 'data/bursts/burst_separatePNITNv3.mat'
COLAB_PRE = 'Neuron_Burst_Analysis/'
if RunningInCOLAB:
    LOSS_FILE = COLAB_PRE + LOSS_FILE
    PATH = COLAB_PRE + PATH
    DATA_PATH = COLAB_PRE + DATA_PATH

### Build Data From Matlab File

Read in Matlab info and Average over sliding window of size 3 PN Afferent and PN Firing Rate. 
Variables:
$$
x = \text{PN Firing Rate}\\
y = \text{ITN Firing Rate}\\
z = \text{Local Field Potential}\\
e_{1} = \text{PN Afferent}\\
e_{2} = \text{ITN Afferent}\\
e_{3} = \text{PN Excitatory Point Conductance}\\
e_{4} = \text{ITN Excitatory Point Conductance}\\
e_{5} = \text{PN Inhibitory Point Conductance}\\
e_{6} = \text{ITN Inhibitory Point Conductance}\\
t = \text{Timestep}\\
m = \text{Number of Previous Timesteps}\\
N = \text{Number of Samples}\\
n = \text{Individual Sample}\\
f = \text{Number of Features}\\
\omega(n) = \text{Time Sequence of Sample n}\\
$$

Sliding Window:
$$
x_{t} = 
\begin{cases} 
    \frac{1}{3}\sum_{i=0}^{2}x_{t+i} & \text{if } t \geq 2\\
    x_{t} & otherwise\\
\end{cases}\\
e_{1,t} = 
\begin{cases}
    \frac{1}{3}\sum_{i=0}^{2}e_{1,t+i}& \text{if } t \geq 2\\
    e_{1,t} & otherwise\\
\end{cases}\\
$$

Data $\text{size} = (N \times 1) , n = (\omega(n) \times f)$:
$$
\begin{bmatrix}
x_{t} & y_{t} & e_{1, t} & e_{2, t} & e_{3, t} & e_{4, t} & e_{5, t} & e_{6, t} & z_{t}\\
x_{t+1} & y_{t+1} & e_{1, t+1} & e_{2, t+1} & e_{3, t+1} & e_{4, t+1} & e_{5, t+1} & e_{6, t+1} & z_{t+1}\\
...\\
x_{t+\omega(n)} & y_{t+\omega(n)} & e_{1, t+\omega(n)} & e_{2, t+\omega(n)} & e_{3, t+\omega(n)} & e_{4, t+\omega(n)} & e_{5, t+\omega(n)} & e_{6, t+\omega(n)} & z_{t+\omega(n)}
\end{bmatrix}
$$

Label $\text{size} = (N \times 1) , n = (\omega(n) \times 3)$:
$$
\begin{bmatrix}
x_{t} & y_{t} & z_{t}\\
x_{t+1} & y_{t+1} & z_{t+1}\\
...\\
x_{t+\omega(n)} & y_{t+\omega(n)} & z_{t+\omega(n)}
\end{bmatrix}
$$

Building Inputs $\text{size} = ((\sum_{n=0}^{N}\omega(n) - m - 1) \times (f*m))$:
$$
\begin{bmatrix}
x_{0,t} & y_{0,t} & e_{1,0,t} & e_{2,0,t} & e_{3,0,t} & e_{4,0,t} & e_{5,0,t} & e_{6,0,t} & z_{0,t} & x_{0,t+1} & ... & z_{0,t+m}\\
x_{0,t+1} & y_{0,t+1} & e_{1,0,t+1} & e_{2,0,t+1} & e_{3,0,t+1} & e_{4,0,t+1} & e_{5,0,t+1} & e_{6,0,t+1} & z_{0,t+1} & x_{0,t+2} & ... & z_{0,t+m+1}\\
...\\
x_{0,\omega(0)-m-1} & y_{0,\omega(0)-m-1} & e_{1,0,\omega(0)-m-1} & e_{2,0,\omega(0)-m-1} & e_{3,0,\omega(0)-m-1} & e_{4,0,\omega(0)-m-1} & e_{5,0,\omega(0)-m-1} & e_{6,0,\omega(0)-m-1} & z_{0,\omega(0)-m-1} & x_{0,\omega(0)-m} & ... & z_{0,\omega(0)-1}\\
...\\
x_{N,t} & y_{N,t} & e_{1,N,t} & e_{2,N,t} & e_{3,N,t} & e_{4,N,t} & e_{5,N,t} & e_{6,N,t} & z_{N,t} & x_{N,t+1} & ... & z_{N,t+m}\\
x_{N,t+1} & y_{N,t+1} & e_{1,N,t+1} & e_{2,N,t+1} & e_{3,N,t+1} & e_{4,N,t+1} & e_{5,N,t+1} & e_{6,N,t+1} & z_{N,t+1} & x_{N,t+2} & ... & z_{N,t+m+1}\\
...\\
x_{N,\omega(N)-m-1} & y_{N,\omega(N)-m-1} & e_{1,N,\omega(N)-m-1} & e_{2,N,\omega(N)-m-1} & e_{3,N,\omega(N)-m-1} & e_{4,N,\omega(N)-m-1} & e_{5,N,\omega(N)-m-1} & e_{6,N,\omega(N)-m-1} & z_{N,\omega(N)-m-1} & x_{N,\omega(N)-m} & ... & z_{N,\omega(N)-1}\\
\end{bmatrix}
$$

Building Labels $\text{size} = ((\sum_{n=0}^{N}\omega(n) - m - 1) \times 3)$:
$$
\begin{bmatrix}
x_{0,t+m+1} & y_{0,t+m+1} & z_{0,t+m+1}\\
x_{0,t+m+2} & y_{0,t+m+2} & z_{0,t+m+2}\\
...\\
x_{0,\omega(0)} & y_{0,\omega(0)} & z_{0,\omega(0)}\\
...\\
x_{N,t+m+1} & y_{N,t+m+1} & z_{N,t+m+1}\\
x_{N,t+m+2} & y_{N,t+m+2} & z_{N,t+m+2}\\
...\\
x_{N,\omega(N)} & y_{N,\omega(N)} & z_{N,\omega(N)}\\
\end{bmatrix}
$$

In [7]:
def get_filteredLFP():
    lfp_input_file = 'data/raw_data/LFP_filt.txt'
    lfp_labels_file = 'data/raw_data/LFP_filt.txt'
    fir_file = 'data/raw_data/FR_PN_ITN.txt'
    aff_file = 'data/raw_data/AFF_PN_ITN.txt'
    with open(lfp_input_file) as f:
        lfp_in = f.read().splitlines()
    lfp_in = np.array([float(x) for x in lfp_in]).reshape((-1, 1))
    
    with open(lfp_labels_file) as f:
        lfp_out = f.read().splitlines()
    lfp_out = np.array([float(x) for x in lfp_out]).reshape((-1, 1))
#     print(lfp_out)
        
#     with open(fir_file) as f:
#         fr = f.read().splitlines()
#     fr = np.array([(float(x.split(',')[0]), float(x.split(',')[1])) for x in fr])
        
    with open(aff_file) as f:
        aff = f.read().splitlines()
    aff = np.array([(float(x.split(',')[0]), float(x.split(',')[1])) for x in aff])
        
#     full_data = np.hstack((lfp_in, aff))
    full_data = lfp_in
    full_labels = lfp_out
    
    training_samples = 900000
    indices = rng.integers(low=0, high=full_labels.shape[0]-(PREVIOUS_TIME+LOOK_AHEAD), size=training_samples)
    validation_samples = 100000
    v_indices = rng.integers(low=0, high=full_labels.shape[0]-(PREVIOUS_TIME+LOOK_AHEAD), size=validation_samples)
    training_data = []
    training_labels = []
    validation_data = []
    validation_labels = []
    f_data = []
    f_labels = []
    
    for idx in indices:
        training_data.append(full_data[idx:idx+PREVIOUS_TIME,:].reshape((-1,PREVIOUS_TIME*INPUT_FEATURES)))
        training_labels.append(full_labels[idx+PREVIOUS_TIME+LOOK_AHEAD,:].reshape((-1,OUTPUT_SIZE)))
    training_data = np.stack(training_data, axis=0)
    training_labels = np.stack(training_labels, axis=0)
    
    for idx in v_indices:
        validation_data.append(full_data[idx:idx+PREVIOUS_TIME,:].reshape((-1,PREVIOUS_TIME*INPUT_FEATURES)))
        validation_labels.append(full_labels[idx+PREVIOUS_TIME+LOOK_AHEAD,:].reshape((-1,OUTPUT_SIZE)))
    validation_data = np.stack(validation_data, axis=0)
    validation_labels = np.stack(validation_labels, axis=0)
    
    for i in range(full_data.shape[0]-(PREVIOUS_TIME+LOOK_AHEAD)):
        f_data.append(full_data[i:i+PREVIOUS_TIME,:].reshape((-1,PREVIOUS_TIME*INPUT_FEATURES)))
        f_labels.append(full_labels[i+PREVIOUS_TIME+LOOK_AHEAD,:].reshape((-1,OUTPUT_SIZE)))
    f_data = np.stack(f_data, axis=0)
    f_labels = np.stack(f_labels, axis=0)
    
    print('Training Data: {}'.format(training_data.shape))
    print('Training Labels: {}'.format(training_labels.shape))
    print('Validation Data: {}'.format(validation_data.shape))
    print('Validation Labels: {}'.format(validation_labels.shape))
    print('Full Data: {}'.format(f_data.shape))
    print('Full Labels: {}'.format(f_labels.shape))
    
    training_dataset = TensorDataset(torch.Tensor(training_data), torch.Tensor(training_labels))
    validation_dataset = TensorDataset(torch.Tensor(validation_data), torch.Tensor(validation_labels))
    f_dataset = TensorDataset(torch.Tensor(f_data), torch.Tensor(f_labels))

    return training_dataset, validation_dataset, f_dataset
# get_filteredLFP()

In [8]:
def get_rawLFP():
    lfp_input_file = 'data/raw_data/LFP_elec_combine.txt'
    lfp_labels_file = 'data/raw_data/LFP_filt.txt'
    fir_file = 'data/raw_data/FR_PN_ITN.txt'
    aff_file = 'data/raw_data/AFF_PN_ITN.txt'
    with open(lfp_input_file) as f:
        lfp_in = f.read().splitlines()
    lfp_in = np.array([float(x) for x in lfp_in]).reshape((-1, 1))
    
    with open(lfp_labels_file) as f:
        lfp_out = f.read().splitlines()
    lfp_out = np.array([float(x) for x in lfp_out]).reshape((-1, 1))
#     print(lfp_out)
        
    with open(fir_file) as f:
        fr = f.read().splitlines()
#     fr = np.array([(float(x.split(',')[0]), float(x.split(',')[1])) for x in fr])
    fr = np.array([float(x.split(',')[0]) for x in fr]).reshape((-1,1))
        
    with open(aff_file) as f:
        aff = f.read().splitlines()
    aff = np.array([(float(x.split(',')[0]), float(x.split(',')[1])) for x in aff])
        
#     full_data = np.hstack((lfp_in, fr))
#     print(full_data.shape)
    full_data = lfp_in
    full_labels = lfp_out
    
    training_samples = 900000
    indices = rng.integers(low=0, high=full_labels.shape[0]-(PREVIOUS_TIME+LOOK_AHEAD), size=training_samples)
    validation_samples = 100000
    v_indices = rng.integers(low=0, high=full_labels.shape[0]-(PREVIOUS_TIME+LOOK_AHEAD), size=validation_samples)
    training_data = []
    training_labels = []
    validation_data = []
    validation_labels = []
    f_data = []
    f_labels = []
    
    for idx in indices:
        training_data.append(full_data[idx:idx+PREVIOUS_TIME,:].reshape((-1,PREVIOUS_TIME*INPUT_FEATURES)))
        training_labels.append(full_labels[idx+PREVIOUS_TIME+LOOK_AHEAD,:].reshape((-1,OUTPUT_SIZE)))
    training_data = np.stack(training_data, axis=0)
    training_labels = np.stack(training_labels, axis=0)
    
    for idx in v_indices:
        validation_data.append(full_data[idx:idx+PREVIOUS_TIME,:].reshape((-1,PREVIOUS_TIME*INPUT_FEATURES)))
        validation_labels.append(full_labels[idx+PREVIOUS_TIME+LOOK_AHEAD,:].reshape((-1,OUTPUT_SIZE)))
    validation_data = np.stack(validation_data, axis=0)
    validation_labels = np.stack(validation_labels, axis=0)
    
    for i in range(full_data.shape[0]-(PREVIOUS_TIME+LOOK_AHEAD)):
        f_data.append(full_data[i:i+PREVIOUS_TIME,:].reshape((-1,PREVIOUS_TIME*INPUT_FEATURES)))
        f_labels.append(full_labels[i+PREVIOUS_TIME+LOOK_AHEAD,:].reshape((-1,OUTPUT_SIZE)))
    f_data = np.stack(f_data, axis=0)
    f_labels = np.stack(f_labels, axis=0)
    
    print('Training Data: {}'.format(training_data.shape))
    print('Training Labels: {}'.format(training_labels.shape))
    print('Validation Data: {}'.format(validation_data.shape))
    print('Validation Labels: {}'.format(validation_labels.shape))
    print('Full Data: {}'.format(f_data.shape))
    print('Full Labels: {}'.format(f_labels.shape))
    
    training_dataset = TensorDataset(torch.Tensor(training_data), torch.Tensor(training_labels))
    validation_dataset = TensorDataset(torch.Tensor(validation_data), torch.Tensor(validation_labels))
    f_dataset = TensorDataset(torch.Tensor(f_data), torch.Tensor(f_labels))

    return training_dataset, validation_dataset, f_dataset
# get_rawLFP()

In [9]:
def get_WN(time_s=300000, channels=1):
    lfp_input_file = 'data/raw_data/LFP_elec_combine.txt'
    with open(lfp_input_file) as f:
        lfp_in = f.read().splitlines()
    lfp_in = np.array([float(x) for x in lfp_in]).reshape((-1, 1))
    
#     print(np.std(lfp_in))
    
    noise = rng.normal(0, np.std(lfp_in), (time_s, channels))
    
    oscBand = np.array([0.08,0.14])
    b, a = signal.butter(4,oscBand,btype='bandpass')
    noise_filt = signal.lfilter(b, a, noise, axis=0)
    
#     print(noise_filt)
    
    full_data = noise
    full_labels = noise_filt
    f_data = []
    f_labels = []
    for i in range(full_data.shape[0]-(PREVIOUS_TIME+LOOK_AHEAD)):
        f_data.append(full_data[i:i+PREVIOUS_TIME,:].reshape((-1,PREVIOUS_TIME*INPUT_FEATURES)))
        f_labels.append(full_labels[i+PREVIOUS_TIME+LOOK_AHEAD,0].reshape((-1,OUTPUT_SIZE)))
    f_data = np.stack(f_data, axis=0)
    f_labels = np.stack(f_labels, axis=0)
    print('Noise Data: {}'.format(f_data.shape))
    print('Noise Labels: {}'.format(f_labels.shape))
    noise_dataset = TensorDataset(torch.Tensor(f_data), torch.Tensor(f_labels))
    return noise_dataset
# get_WN()

In [10]:
import math
def get_sin(time_s=3000, channels=1):
    A = .06 #Randomly chosen to be close to LFP magnitude
    data = []
    t_li = np.arange(0,time_s,0.001)
    for t in t_li:
        data.append(A*np.sin(t*(314.1516))) #A*np.sin((50*2*np.pi*t))
    data = np.array(data).reshape((-1,1))
    
#     oscBand = np.array([0.08,0.14])
#     b, a = signal.butter(4,oscBand,btype='bandpass')
#     data_filt = signal.lfilter(b, a, data, axis=0)
    
    full_data = data
    full_labels = data
    f_data = []
    f_labels = []
#     print(full_data)
    for i in range(full_data.shape[0]-(PREVIOUS_TIME+LOOK_AHEAD)):
        f_data.append(full_data[i:i+PREVIOUS_TIME,:].reshape((-1,PREVIOUS_TIME*INPUT_FEATURES)))
        f_labels.append(full_labels[i+PREVIOUS_TIME+LOOK_AHEAD,:].reshape((-1,OUTPUT_SIZE)))
    f_data = np.stack(f_data, axis=0)
    f_labels = np.stack(f_labels, axis=0)
    print('Sin Data: {}'.format(f_data.shape))
    print('Sin Labels: {}'.format(f_labels.shape))
    sin_dataset = TensorDataset(torch.Tensor(f_data), torch.Tensor(f_labels))
    return sin_dataset
# get_sin()

In [11]:
DEFAULT_RAW_FILE = 'data/raw_data/LFP_elec_combine.txt'
# DEFAULT_RAW_FILE = 'data/raw_data/LFP_filt.txt'
DEFAULT_LABELS_FILE = 'data/raw_data/LFP_filt.txt'

from scipy import stats, signal

def get_burstLFP(in_file=DEFAULT_RAW_FILE, out_file=DEFAULT_LABELS_FILE):
    
    fir_file = fir_file = 'data/raw_data/FR_PN_ITN.txt'
    
    with open(in_file) as f:
        lfp_in = f.read().splitlines()
    lfp_in = np.array([float(x) for x in lfp_in]).reshape((-1, 1))
    
    with open(out_file) as f:
        lfp_out = f.read().splitlines()
    lfp_out = np.array([float(x) for x in lfp_out]).reshape((-1, 1))
    
    with open(fir_file) as f:
        fr = f.read().splitlines()
#     fr = np.array([(float(x.split(',')[0]), float(x.split(',')[1])) for x in fr])
    fr = np.array([float(x.split(',')[0]) for x in fr]).reshape((-1,1))
    
    hilb = np.abs(signal.hilbert(lfp_in)) #hilbert transform of raw data

    z_score = stats.zscore(hilb)
    thresh = np.mean(np.squeeze(hilb)) + 2*np.std(np.squeeze(hilb)) # 2*z_score
#     print(np.mean(np.squeeze(hilb)))
#     print(2*np.std(np.squeeze(hilb)))
#     print(thresh)

#     print(hilb.shape)
    indices = np.nonzero(np.squeeze(hilb)>thresh)[0]
#     print(indices.shape)

    idx_count = PREVIOUS_TIME + LOOK_AHEAD + 1

    burst_indices = []
    temp_idx = []
    #start at the second index and compare to first one
    for i, idx in enumerate(indices[1:], 0):
        #if the index is not next start a new sample
        if idx - indices[i] != 1 and temp_idx:
            if len(temp_idx) < idx_count:
                padding = idx_count - len(temp_idx)
                pad = np.arange(temp_idx[0]-padding+1,temp_idx[0],1)
                temp_idx[:0] = pad
                if temp_idx[0] < 0:
                    temp_idx = []
                    continue
#                 print(len(temp_idx))
#             print(np.array(temp_idx).reshape((-1,1)).shape)
            burst_indices.append(np.array(temp_idx).reshape((-1,1)))
            temp_idx = []
            temp_idx.append(idx)
        #otherwise add the index to the sample
        else:
            temp_idx.append(idx)

    # print(np.squeeze(indices))
    burst_in = []
    burst_out = []
    filt_out = []
    
    oscBand = np.array([0.08,0.14])
    b, a = signal.butter(4,oscBand,btype='bandpass')
    
    for sample in burst_indices:
#         print(sample.shape)
        if sample.shape[0] >= 10:
            inp = np.take(lfp_in, np.squeeze(sample[:PREVIOUS_TIME])).reshape((-1,1))
            fr_i = np.take(fr, np.squeeze(sample[:PREVIOUS_TIME])).reshape((-1,1))
            inp = np.concatenate((inp, fr_i), axis=1)
#             print(inp.shape)
            filt = np.take(lfp_out, np.squeeze(sample[-1])).reshape((-1,1))
            lab = np.take(lfp_in, np.squeeze(sample[PREVIOUS_TIME:,:])).reshape((-1,1))
            lab = signal.lfilter(b, a, lab.reshape((-1,OUTPUT_SIZE)), axis=0)
#             print(lab.shape)
            burst_in.append(inp)
            burst_out.append(lab[-1,:])
            filt_out.append(filt)
#     print(np.stack(burst_out).reshape((-1, 1, 1)))
    burst_in = np.transpose(np.stack(burst_in), (0,2,1))
    burst_out = np.transpose(np.stack(burst_out).reshape((-1,1,1)), (0,2,1))
    filt_out = np.transpose(np.stack(filt_out), (0,2,1))
    print('Burst Data: {}'.format(burst_in.shape))
    print('Burst Labels: {}'.format(burst_out.shape))
    print('Filter Labels: {}'.format(filt_out.shape))
    burst_dataset = TensorDataset(torch.Tensor(burst_in), torch.Tensor(burst_out))
    filt_dataset = TensorDataset(torch.Tensor(burst_in), torch.Tensor(filt_out))
    return burst_dataset, filt_dataset

# get_burstLFP()

In [12]:
def get_end1D(training_samples=900000, validation_samples=100000):
    lfp_input_file = 'data/raw_data/LFP_elec_combine.txt'
    lfp_filt_file = 'data/raw_data/LFP_filt.txt'
    fir_file = fir_file = 'data/raw_data/FR_PN_ITN.txt'
    
    with open(lfp_input_file) as f:
        lfp_in = f.read().splitlines()
    lfp_in = np.array([float(x) for x in lfp_in]).reshape((-1, 1))
    
    with open(lfp_filt_file) as f:
        lfp_filt = f.read().splitlines()
    lfp_filt = np.array([float(x) for x in lfp_filt]).reshape((-1, 1))
    
    with open(fir_file) as f:
        fr = f.read().splitlines()
#     fr = np.array([(float(x.split(',')[0]), float(x.split(',')[1])) for x in fr])
    fr = np.array([float(x.split(',')[0]) for x in fr]).reshape((-1,1))
    
    oscBand = np.array([0.08,0.14])
    b, a = signal.butter(4,oscBand,btype='bandpass')


    t_indices = rng.integers(low=PREVIOUS_TIME, high=lfp_in.shape[0]-LOOK_AHEAD, size=training_samples)
    v_indices = rng.integers(low=PREVIOUS_TIME, high=lfp_in.shape[0]-LOOK_AHEAD, size=validation_samples)

    training_data = []
    training_labels = []
    training_filt = []
    validation_data = []
    validation_labels = []
    validation_filt = []
    f_data = []
    f_labels = []
    f_filt = []
    
    for idx in t_indices:
        tlfp = lfp_in[idx-PREVIOUS_TIME:idx,:].reshape((-1,PREVIOUS_TIME*INPUT_FEATURES))
        t_fr = fr[idx-PREVIOUS_TIME:idx,:].reshape((-1,PREVIOUS_TIME*INPUT_FEATURES))
        training_data.append(np.concatenate((tlfp, t_fr), axis=0))
        filter_1d = signal.lfilter(b, a, lfp_in[idx:idx+LOOK_AHEAD,:].reshape((-1,OUTPUT_SIZE)), axis=0)
        training_filt.append(lfp_filt[idx+LOOK_AHEAD,:].reshape((-1,OUTPUT_SIZE)))
        training_labels.append(filter_1d[-1,:])
    training_data = np.stack(training_data, axis=0)
    training_labels = np.stack(training_labels, axis=0)
    training_filt = np.stack(training_filt, axis=0).reshape((-1,1))
    
    for idx in v_indices:
        vlfp = lfp_in[idx-PREVIOUS_TIME:idx,:].reshape((-1,PREVIOUS_TIME*INPUT_FEATURES))
        v_fr = fr[idx-PREVIOUS_TIME:idx,:].reshape((-1,PREVIOUS_TIME*INPUT_FEATURES))
        validation_data.append(np.concatenate((vlfp, v_fr), axis=0))
        filter_1d = signal.lfilter(b, a, lfp_in[idx:idx+LOOK_AHEAD,:].reshape((-1,OUTPUT_SIZE)), axis=0)
        validation_filt.append(lfp_filt[idx+LOOK_AHEAD,:].reshape((-1,OUTPUT_SIZE)))
        validation_labels.append(filter_1d[-1,:])
    validation_data = np.stack(validation_data, axis=0)
    validation_labels = np.stack(validation_labels, axis=0)
    validation_filt = np.stack(validation_filt, axis=0).reshape((-1,1))
    
#     for i in range(PREVIOUS_TIME, lfp_in.shape[0]-LOOK_AHEAD, 1):
#         flfp = lfp_in[i-PREVIOUS_TIME:i,:].reshape((-1,PREVIOUS_TIME*INPUT_FEATURES))
#         f_fr = fr[i-PREVIOUS_TIME:i,:].reshape((-1,PREVIOUS_TIME*INPUT_FEATURES))
#         f_data.append(np.concatenate((flfp, f_fr), axis=1))
#         filter_1d = signal.lfilter(b, a, lfp_in[i:i+LOOK_AHEAD,:].reshape((-1,OUTPUT_SIZE)), axis=0)
#         f_filt.append(lfp_filt[i+LOOK_AHEAD,:].reshape((-1,OUTPUT_SIZE)))
#         f_labels.append(filter_1d[-1,:])
#     f_data = np.stack(f_data, axis=0)
#     f_labels = np.stack(f_labels, axis=0)
#     f_filt = np.stack(f_filt, axis=0).reshape((-1,1))
    
    print('Training Data: {}'.format(training_data.shape))
    print('Training Labels: {}'.format(training_labels.shape))
    print('Training Filter: {}'.format(training_filt.shape))
    print('Validation Data: {}'.format(validation_data.shape))
    print('Validation Labels: {}'.format(validation_labels.shape))
    print('Validation Filter: {}'.format(validation_filt.shape))
#     print('Full Data: {}'.format(f_data.shape))
#     print('Full Labels: {}'.format(f_labels.shape))
#     print('Full Filter: {}'.format(f_filt.shape))
#     print(f_filt)
    
    training_dataset = TensorDataset(torch.Tensor(training_data), torch.Tensor(training_labels))
    training_filt = TensorDataset(torch.Tensor(training_data), torch.Tensor(training_filt))
    validation_dataset = TensorDataset(torch.Tensor(validation_data), torch.Tensor(validation_labels))
    validation_filt = TensorDataset(torch.Tensor(validation_data), torch.Tensor(validation_filt))
#     f_dataset = TensorDataset(torch.Tensor(f_data), torch.Tensor(f_labels))
#     f_filt = TensorDataset(torch.Tensor(f_data), torch.Tensor(f_filt))

    return training_dataset, validation_dataset, training_filt, validation_filt#f_dataset, training_filt, validation_filt, f_filt
# get_end1D()

In [13]:
def train_model(model,save_filepath,training_loader,validation_loader,epochs,device):
    epochs_list = []
    train_loss_list = []
    val_loss_list = []
    training_len = len(training_loader.dataset)
    validation_len = len(validation_loader.dataset)
    
#     feedback_arr = torch.zeros(BATCH_SIZE, 90)
    
    #splitting the dataloaders to generalize code
    data_loaders = {"train": training_loader, "val": validation_loader}
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    loss_func = nn.MSELoss()
#     loss_func = nn.L1Loss()
    decay_rate = 0.95 #decay the lr each step to 98% of previous lr
    lr_sch = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=decay_rate)

    total_start = time.time()

    """
    You can easily adjust the number of EPOCHS trained here by changing the number in the range
    """
    for epoch in tqdm(range(epochs), position=0, leave=True):
        start = time.time()
        train_loss = 0.0
        val_loss = 0.0
        temp_loss = 100000000000000.0
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train(True)
            else:
                model.train(False)

            running_loss = 0.0
            for i, (x, y) in enumerate(data_loaders[phase]):
                x = x.to(device)
                output = model(x)
                y = y.to(device)
#                 if i%100000 == 0 and epoch%5 == 0:
#                     print(output)
#                     print(y)
                loss = loss_func(torch.squeeze(output), torch.squeeze(y)) 
                #backprop             
                optimizer.zero_grad()
                if phase == 'train':
                    loss.backward()
                    optimizer.step()
#                     if i%100000 == 0 and epoch%5 == 0:
#                         print(model.cn1.weight.grad)
#                         print(model.cn2.weight.grad)
#                         print(model.fc1.weight.grad)
#                         print(model.fc2.weight.grad)

                #calculating total loss
                running_loss += loss.item()
            
            if phase == 'train':
                train_loss = running_loss
                lr_sch.step()
            else:
                val_loss = running_loss

        end = time.time()
        # shows total loss
        if epoch%5 == 0:
            print('[%d, %5d] train loss: %.6f val loss: %.6f' % (epoch + 1, i + 1, train_loss, val_loss))
#         print(end - start)
        
        #saving best model
        if train_loss < temp_loss:
            torch.save(model, save_filepath)
            temp_loss = train_loss
        epochs_list.append(epoch)
        train_loss_list.append(train_loss)
        val_loss_list.append(val_loss)
    total_end = time.time()
#     print(total_end - total_start)
    #Creating loss csv
    loss_df = pds.DataFrame(
        {
            'epoch': epochs_list,
            'training loss': train_loss_list,
            'validation loss': val_loss_list
        }
    )
    # Writing loss csv, change path to whatever you want to name it
    
    loss_df.to_csv(LOSS_FILE, index=None)
    return train_loss_list, val_loss_list

In [14]:
# f_tr, f_va, f_data = get_filteredLFP()
# f_tr, f_va, f_data = get_rawLFP()
f_tr, f_va, t_filt, v_filt = get_end1D()#f_data, t_filt, v_filt, f_filt = get_end1D()
noise = get_WN(channels=2)
# sin = get_sin()
burst, fburst = get_burstLFP()

# Turn datasets into iterable dataloaders
train_loader = DataLoader(dataset=f_tr,batch_size=BATCH_SIZE)
tfilt_loader = DataLoader(dataset=t_filt,batch_size=BATCH_SIZE)
val_loader = DataLoader(dataset=f_va,batch_size=BATCH_SIZE)
vfilt_loader = DataLoader(dataset=v_filt,batch_size=BATCH_SIZE)
# full_loader = DataLoader(dataset=f_data,batch_size=BATCH_SIZE)
# ffull_loader = DataLoader(dataset=f_filt,batch_size=BATCH_SIZE)
noise_loader = DataLoader(dataset=noise,batch_size=BATCH_SIZE)
# sine_loader = DataLoader(dataset=sin,batch_size=BATCH_SIZE)
burst_loader = DataLoader(dataset=burst,batch_size=BATCH_SIZE)
fburst_loader = DataLoader(dataset=fburst,batch_size=BATCH_SIZE)

Training Data: (900000, 2, 50)
Training Labels: (900000, 1)
Training Filter: (900000, 1)
Validation Data: (100000, 2, 50)
Validation Labels: (100000, 1)
Validation Filter: (100000, 1)
Noise Data: (299930, 2, 50)
Noise Labels: (299930, 1, 1)
Burst Data: (55198, 2, 50)
Burst Labels: (55198, 1, 1)
Filter Labels: (55198, 1, 1)


In [15]:
import copy

model1 = MODEL(INPUT_SIZE,HIDDEN_SIZE,OUTPUT_SIZE)
model_initial = copy.deepcopy(model1)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model1.to(device)

pnfr_training_loss, pnfr_validation_loss = train_model(model1,PATH,train_loader,
                                                       val_loader,EPOCHS,device)

  0%|          | 0/20 [00:00<?, ?it/s]

[1,   196] train loss: 0.104166 val loss: 0.006472
[6,   196] train loss: 0.037616 val loss: 0.004166
[11,   196] train loss: 0.036788 val loss: 0.004078
[16,   196] train loss: 0.036461 val loss: 0.004086


In [16]:
def r2_eval(model, testing_dataloader, filt=None, k=None):
    output_list = []
    labels_list = []
    temp_list = []
    if filt is not None:
        filt = iter(filt)
    for i, (x, y) in enumerate(testing_dataloader):
        output = model(x)         
        if filt is None:
            output_list.append(output.detach().cpu().numpy())
            labels_list.append(y.detach().cpu().numpy())
        else:
            xf, yf = next(filt)
            yf = yf.detach().cpu().numpy().reshape((-1,1))
#             print(yf)
            y = y.detach().cpu().numpy().reshape((-1,1))
            output = output.detach().cpu().numpy().reshape((-1,1))
#             print((yf-y).shape)
            pred = yf-y+output
#             print(pred.shape)
            output_list.append(pred)
            labels_list.append(yf)
        if k != None and i == k-1:
            break
#     print("Output list size: {}".format(len(output_list)))
#     print(output_list[0].shape)
    output_list = np.squeeze(np.concatenate(output_list, axis=0))
#     print(output_list.shape)
    labels_list = np.squeeze(np.concatenate(labels_list, axis=0))
#     print(labels_list.shape)
#     print(output_list.shape)
#     print(labels_list.shape)
    print(r2_score(labels_list, output_list))
    return output_list, labels_list

RandomState(MT19937(SeedSequence(123456789)))### Mask
Variables:
$$
x = \text{PN Firing Rate}\\
y = \text{ITN Firing Rate}\\
z = \text{Local Field Potential}\\
t = \text{Timestep}\\
m = \text{Number of Previous Timesteps}\\
e_{x} = \text{External Input}\\
$$


Output Array:
$$
\begin{bmatrix}
x_{n,t+m} & y_{n,t+m} & z_{n,t+m}\\
\end{bmatrix}
$$
Feedback Array: 
$$
\begin{bmatrix}
x_{n,t-1} & y_{n,t-1} & z_{n,t-1} & x_{n,t} & ... & z_{n,t+m-1}\\
\end{bmatrix}
$$
Replace index 0, 1, 2 with Output Array and Roll

Rolled Feedback Array: 
$$
\begin{bmatrix}
x_{n,t} & y_{n,t} & z_{n,t} & x_{n,t+1} & ... & z_{n,t+m}\\
\end{bmatrix}
$$
Input Array:
$$
\begin{bmatrix}
x_{n,t} & y_{n,t} & e_{1,n,t} & e_{2,n,t} & e_{3,n,t} & e_{4,n,t} & e_{5,n,t} & e_{6,n,t} & z_{n,t} & x_{n,t+1} & ... & z_{n,t+m}\\
\end{bmatrix}
$$
New Output:
$$
\begin{bmatrix}
x_{n,t+m+1} & y_{n,t+m+1} & z_{n,t+m+1}\\
\end{bmatrix}
$$

In [17]:
model1 = torch.load(PATH)
model1.eval()

start = 40
k = 10000
end= (start + k) if k != None else None

model1.to('cpu')

t_pred, t_real = r2_eval(model1, train_loader, filt=tfilt_loader ,k=end)
v_pred, v_real = r2_eval(model1, val_loader, filt=vfilt_loader, k=end)
# f_pred, f_real = r2_eval(model1, full_loader,filt=ffull_loader, k=end)
# print(f_real)
# t_pred, t_real = r2_eval(model1, train_loader, filt=None ,k=end)
# v_pred, v_real = r2_eval(model1, val_loader, filt=None, k=end)
# f_pred, f_real = r2_eval(model1, full_loader,filt=None, k=end)
# n_pred, n_real = r2_eval(model1, noise_loader, end)
# s_pred, s_real = r2_eval(model1, sine_loader, end)
b_pred, b_real = r2_eval(model1, burst_loader, filt=fburst_loader,k=end)
# for i in range(len(s_pred)):
#     print("output: {} label: {}".format(s_pred[i], s_real[i]))

0.9772238651890114
0.977000965633649
0.9804787399051351


In [18]:
from sklearn.metrics import mean_squared_error

print("Train MSE: {:f}".format(mean_squared_error(t_real, t_pred)))
print("Val MSE: {:f}".format(mean_squared_error(v_real, v_pred)))
# print("Full MSE: {:f}".format(mean_squared_error(f_real, f_pred)))
print("Burst MSE: {:f}".format(mean_squared_error(b_real, b_pred)))

Train MSE: 0.000021
Val MSE: 0.000021
Burst MSE: 0.000018


In [19]:
fig1, ax1 = plt.subplots(nrows=1, ncols=2)
fig1.tight_layout()
ax1[0].plot(range(EPOCHS), pnfr_training_loss)
ax1[0].set_title('Training Loss')
ax1[0].set_ylabel('Loss')
ax1[0].set_xlabel('Epoch')

ax1[1].plot(range(EPOCHS), pnfr_validation_loss)
ax1[1].set_title('Validation Loss')
ax1[1].set_ylabel('Loss')
ax1[1].set_xlabel('Epoch')

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Text(0.5, -5.47777777777778, 'Epoch')

In [20]:
fig, ax = plt.subplots(nrows=2, ncols=1)
fig.tight_layout()

ax[0].plot(np.arange(start-OUTPUT_SIZE,end), v_real[start-OUTPUT_SIZE:end], color='blue',label='Labels')
# ax[2,0].plot(np.arange(start-10,end), v_output_list[start-10:end,2], color='red',label='Internal Loop')
ax[0].scatter(np.arange(start-OUTPUT_SIZE,end), v_pred[start-1:end], color='slateblue',label='Training t+10')
# ax[0].scatter(np.arange(start-OUTPUT_SIZE,end), v_pred[start-2:end+8,1], color='lightsteelblue',label='Training t+2')
# ax[0].scatter(np.arange(start-OUTPUT_SIZE,end), v_pred[start-3:end+7,2], color='gray',label='Training t+3')
# ax[0].scatter(np.arange(start-OUTPUT_SIZE,end), v_pred[start-4:end+6,3], color='sienna',label='Training t+4')
# ax[0].scatter(np.arange(start-OUTPUT_SIZE,end), v_pred[start-5:end+5,4], color='magenta',label='Training t+5')
# ax[0].scatter(np.arange(start-10,end), v_pred[start-6:end+4,5], color='aquamarine',label='Training t+6')
# ax[0].scatter(np.arange(start-10,end), v_pred[start-7:end+3,6], color='darkorange',label='Training t+7')
# ax[0].scatter(np.arange(start-10,end), v_pred[start-8:end+2,7], color='brown',label='Training t+8')
# ax[0].scatter(np.arange(start-10,end), v_pred[start-9:end+1,8], color='purple',label='Training t+9')
# ax[0].plot(np.arange(start-10,end), v_pred[start-10:end], color='green',label='Training t+10')


ax[0].set_title('Validation LFP')
ax[0].set_ylabel('LFP')
ax[0].set_xlabel('Time')
# ax[2,0].legend()

ax[1].plot(np.arange(start-OUTPUT_SIZE,end), t_real[start-OUTPUT_SIZE:end], color='blue',label='Labels')
# a[2,1].plot(np.arange(start-10,end), t_output_list[start-10:end,2], color='red',label='Internal Loop')
ax[1].scatter(np.arange(start-OUTPUT_SIZE,end), t_pred[start-1:end], color='slateblue',label='Training t+10')
# ax[1].scatter(np.arange(start-OUTPUT_SIZE,end), t_pred[start-2:end+8,1], color='lightsteelblue',label='Training t+2')
# ax[1].scatter(np.arange(start-OUTPUT_SIZE,end), t_pred[start-3:end+7,2], color='gray',label='Training t+3')
# ax[1].scatter(np.arange(start-OUTPUT_SIZE,end), t_pred[start-4:end+6,3], color='sienna',label='Training t+4')
# ax[1].scatter(np.arange(start-OUTPUT_SIZE,end), t_pred[start-5:end+5,4], color='magenta',label='Training t+5')
# ax[1].scatter(np.arange(start-10,end), t_pred[start-6:end+4,5], color='aquamarine',label='Training t+6')
# ax[1].scatter(np.arange(start-10,end), t_pred[start-7:end+3,6], color='darkorange',label='Training t+7')
# ax[1].scatter(np.arange(start-10,end), t_pred[start-8:end+2,7], color='brown',label='Training t+8')
# ax[1].scatter(np.arange(start-10,end), t_pred[start-9:end+1,8], color='purple',label='Training t+9')
# ax[1].plot(np.arange(start-10,end), t_pred[start-10:end], color='green',label='Training t+10')

ax[1].set_title('Training LFP')
ax[1].set_ylabel('LFP')
ax[1].set_xlabel('Time')
ax[1].legend()

# import plotly.tools as tls
# plotly_fig = tls.mpl_to_plotly(fig)
# plotly_fig.write_html("testfile.html")
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [21]:
fig, ax = plt.subplots(nrows=1, ncols=1)
fig.tight_layout()

ax.plot(np.arange(start,end), f_real[start:end], color='blue',label='Labels')
ax.scatter(np.arange(start,end), f_pred[start:end], color='red',label='Training t+10')
ax.set_title('Full LFP vs Time')
ax.set_ylabel('Signal')
ax.set_xlabel('Time')

# ax[1].plot(np.arange(start,end), n_real[start:end], color='blue',label='Labels')
# ax[1].scatter(np.arange(start,end), n_pred[start:end], color='red',label='Training t+10')
# ax[1].set_title('Noise')
# ax[1].set_ylabel('Signal')
# ax[1].set_xlabel('Time')z

# ax.plot(np.arange(start,end), s_real[start:end], color='blue',label='Labels')
# ax.scatter(np.arange(start,end), s_pred[start:end], color='red',label='Training t+10')
# ax.set_title('Sine')
# ax.set_ylabel('LFP')
# ax.set_xlabel('Time')

plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

NameError: name 'f_real' is not defined

In [None]:
plt.close('all')

In [None]:
import loss_landscapes
import loss_landscapes.metrics
import copy
from mpl_toolkits.mplot3d import axes3d, Axes3D 

STEPS = 100
# model_initial = MODEL(INPUT_SIZE,HIDDEN_SIZE,OUTPUT_SIZE)
model_final = copy.deepcopy(model1)


# data that the evaluator will use when evaluating loss
x, y = iter(noise_loader).__next__()
metric = loss_landscapes.metrics.Loss(nn.MSELoss(), x, y)


loss_data_fin = loss_landscapes.random_plane(model_final, metric, 10000, STEPS, normalization='model', deepcopy_model=True)
# plt.contour(loss_data_fin, levels=50)
# plt.title('Loss Contours around Trained Model')
# plt.show()

In [None]:
fig = plt.figure()
ax = Axes3D(fig)
X = np.array([[j for j in range(STEPS)] for i in range(STEPS)])
Y = np.array([[i for _ in range(STEPS)] for i in range(STEPS)])
ax.plot_surface(X, Y, loss_data_fin, rstride=1, cstride=1, cmap='viridis', edgecolor='none')
ax.set_title('Surface Plot of Loss Landscape')
fig.show()