In [1]:
RunningInCOLAB = 'google.colab' in str(get_ipython())
if RunningInCOLAB:
    !git clone https://github.com/MJC598/Neuron_Burst_Analysis.git

In [2]:
%matplotlib widget
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import scipy.io
import random
import time
import pandas as pds
from sklearn.metrics import r2_score

torch.manual_seed(0)

<torch._C.Generator at 0x7fad1e8286d0>

![Fully Connected Network](graphs/fullyConnectedNetwork.png)

In [3]:
class FCN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(FCN,self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.tanh = nn.Tanh()
        self.fc2 = nn.Linear(hidden_size, output_size)
    def forward(self, x):
        x = self.fc1(x)
        x = self.tanh(x)
        out = self.fc2(x)
        return out

In [4]:
FRONT_TIME = -50
BACK_TIME = 40
T_START = 50+FRONT_TIME
T_END = 'Variable'
MODEL = FCN
OUTPUT = 'FR_LFP'
LOSS_FILE = ('losses/bursts/losses_' + str(MODEL) + 
             '_' + OUTPUT + '_' + str(FRONT_TIME) + 
             '_' + str(T_END) + '_30ms_avgaff_and_fr.csv')
PATH = ('models/bursts/' + str(MODEL) + '_' + OUTPUT + 
        '_' + str(FRONT_TIME) + '_' + str(T_END) + 
        '_30ms_avgaff_and_fr.pth')
# PATH = 'models/<class \'__main__.FCN\'>_FR_LFP_-50_90_full.pth'

DATA_PATH = 'data/bursts/burst_separatePNITNv2.mat'
COLAB_PRE = 'Neuron_Burst_Analysis/'
if RunningInCOLAB:
    LOSS_FILE = COLAB_PRE + LOSS_FILE
    PATH = COLAB_PRE + PATH
    DATA_PATH = COLAB_PRE + DATA_PATH

# Specific Model Parameters
input_features = 9
previous_time = 10
input_size = input_features * previous_time
hidden_size = 95
output_size = 3
batch_size = 32
num_layers = 1
batch_first = True
dropout = 0.0
epochs = 50

### Build Data From Matlab File

Read in Matlab info and Average over sliding window of size 3 PN Afferent and PN Firing Rate. 
Variables:
$$
x = \text{PN Firing Rate}\\
y = \text{ITN Firing Rate}\\
z = \text{Local Field Potential}\\
e_{1} = \text{PN Afferent}\\
e_{2} = \text{ITN Afferent}\\
e_{3} = \text{PN Excitatory Point Conductance}\\
e_{4} = \text{ITN Excitatory Point Conductance}\\
e_{5} = \text{PN Inhibitory Point Conductance}\\
e_{6} = \text{ITN Inhibitory Point Conductance}\\
t = \text{Timestep}\\
m = \text{Number of Previous Timesteps}\\
N = \text{Number of Samples}\\
n = \text{Individual Sample}\\
f = \text{Number of Features}\\
\omega(n) = \text{Time Sequence of Sample n}\\
$$

Sliding Window:
$$
x_{t} = 
\begin{cases} 
    \frac{1}{3}\sum_{i=0}^{2}x_{t+i} & \text{if } t \geq 2\\
    x_{t} & otherwise\\
\end{cases}\\
e_{1,t} = 
\begin{cases}
    \frac{1}{3}\sum_{i=0}^{2}e_{1,t+i}& \text{if } t \geq 2\\
    e_{1,t} & otherwise\\
\end{cases}\\
$$

Data $\text{size} = (N \times 1) , n = (\omega(n) \times f)$:
$$
\begin{bmatrix}
x_{t} & y_{t} & e_{1, t} & e_{2, t} & e_{3, t} & e_{4, t} & e_{5, t} & e_{6, t} & z_{t}\\
x_{t+1} & y_{t+1} & e_{1, t+1} & e_{2, t+1} & e_{3, t+1} & e_{4, t+1} & e_{5, t+1} & e_{6, t+1} & z_{t+1}\\
...\\
x_{t+\omega(n)} & y_{t+\omega(n)} & e_{1, t+\omega(n)} & e_{2, t+\omega(n)} & e_{3, t+\omega(n)} & e_{4, t+\omega(n)} & e_{5, t+\omega(n)} & e_{6, t+\omega(n)} & z_{t+\omega(n)}
\end{bmatrix}
$$

Label $\text{size} = (N \times 1) , n = (\omega(n) \times 3)$:
$$
\begin{bmatrix}
x_{t} & y_{t} & z_{t}\\
x_{t+1} & y_{t+1} & z_{t+1}\\
...\\
x_{t+\omega(n)} & y_{t+\omega(n)} & z_{t+\omega(n)}
\end{bmatrix}
$$

Building Inputs $\text{size} = ((\sum_{n=0}^{N}\omega(n) - m - 1) \times (f*m))$:
$$
\begin{bmatrix}
x_{0,t} & y_{0,t} & e_{1,0,t} & e_{2,0,t} & e_{3,0,t} & e_{4,0,t} & e_{5,0,t} & e_{6,0,t} & z_{0,t} & x_{0,t+1} & ... & z_{0,t+m}\\
x_{0,t+1} & y_{0,t+1} & e_{1,0,t+1} & e_{2,0,t+1} & e_{3,0,t+1} & e_{4,0,t+1} & e_{5,0,t+1} & e_{6,0,t+1} & z_{0,t+1} & x_{0,t+2} & ... & z_{0,t+m+1}\\
...\\
x_{0,\omega(0)-m-1} & y_{0,\omega(0)-m-1} & e_{1,0,\omega(0)-m-1} & e_{2,0,\omega(0)-m-1} & e_{3,0,\omega(0)-m-1} & e_{4,0,\omega(0)-m-1} & e_{5,0,\omega(0)-m-1} & e_{6,0,\omega(0)-m-1} & z_{0,\omega(0)-m-1} & x_{0,\omega(0)-m} & ... & z_{0,\omega(0)-1}\\
...\\
x_{N,t} & y_{N,t} & e_{1,N,t} & e_{2,N,t} & e_{3,N,t} & e_{4,N,t} & e_{5,N,t} & e_{6,N,t} & z_{N,t} & x_{N,t+1} & ... & z_{N,t+m}\\
x_{N,t+1} & y_{N,t+1} & e_{1,N,t+1} & e_{2,N,t+1} & e_{3,N,t+1} & e_{4,N,t+1} & e_{5,N,t+1} & e_{6,N,t+1} & z_{N,t+1} & x_{N,t+2} & ... & z_{N,t+m+1}\\
...\\
x_{N,\omega(N)-m-1} & y_{N,\omega(N)-m-1} & e_{1,N,\omega(N)-m-1} & e_{2,N,\omega(N)-m-1} & e_{3,N,\omega(N)-m-1} & e_{4,N,\omega(N)-m-1} & e_{5,N,\omega(N)-m-1} & e_{6,N,\omega(N)-m-1} & z_{N,\omega(N)-m-1} & x_{N,\omega(N)-m} & ... & z_{N,\omega(N)-1}\\
\end{bmatrix}
$$

Building Labels $\text{size} = ((\sum_{n=0}^{N}\omega(n) - m - 1) \times 3)$:
$$
\begin{bmatrix}
x_{0,t+m+1} & y_{0,t+m+1} & z_{0,t+m+1}\\
x_{0,t+m+2} & y_{0,t+m+2} & z_{0,t+m+2}\\
...\\
x_{0,\omega(0)} & y_{0,\omega(0)} & z_{0,\omega(0)}\\
...\\
x_{N,t+m+1} & y_{N,t+m+1} & z_{N,t+m+1}\\
x_{N,t+m+2} & y_{N,t+m+2} & z_{N,t+m+2}\\
...\\
x_{N,\omega(N)} & y_{N,\omega(N)} & z_{N,\omega(N)}\\
\end{bmatrix}
$$

In [5]:
def get_full_data_from_mat(file_path, output_index=None, type='pre_pn'):
    data = scipy.io.loadmat(file_path)

    full_labels = []
    full_data = []
    
#     print(data['info_collect'][0])
    for i in range(1, data['info_collect'].shape[0]):
        arr = data['info_collect'][i]
        
        pnfr = np.row_stack((arr[2], arr[11]))
        pnaff = np.row_stack((arr[4], arr[13]))
        
        for j in range(pnfr.shape[0]):
            if j != 0 and j != 1:
                pnfr[j,:] = (pnfr[j-2,:] + pnfr[j-1,:] + pnfr[j,:])/3
                pnaff[j,:] = (pnaff[j-2,:] + pnaff[j-1,:] + pnaff[j,:])/3
            else:
                pnfr[j,:] = pnfr[j,:]
                pnaff[j,:] = pnaff[j,:]
        
        full_labels.append(np.column_stack((pnfr, 
                                            np.row_stack((arr[3], arr[12])), 
                                            np.row_stack((arr[10], arr[19])))))
        
        full_data.append(np.column_stack((pnfr, 
                                          np.row_stack((arr[3], arr[12])), 
                                          pnaff, 
                                          np.row_stack((arr[5], arr[14])), 
                                          np.row_stack((arr[6], arr[15])), 
                                          np.row_stack((arr[7], arr[16])), 
                                          np.row_stack((arr[8], arr[17])), 
                                          np.row_stack((arr[9], arr[18])), 
                                          np.row_stack((arr[10], arr[19])))))

#     print(full_data[0].shape)
    full_data = np.asarray(full_data,dtype=object)
    full_labels = np.asarray(full_labels,dtype=object)
    
    random.seed(10)
    data_samples = 5472 #5498 
    k = 4352
    
    lag = 1
    front_offset = 11
#     print(full_data[0].shape)
    training_data = full_data[:k] 
    validation_data = full_data[k:data_samples]
    training_labels = full_labels[:k]
    validation_labels = full_labels[k:data_samples]
    
    td = [] 
    tl = []
    vd = []
    vl = []
    for i, d in enumerate(training_data):
        sample = d[:-1,:]
        label = training_labels[i][1:,:]
        for j in range(sample.shape[0]-previous_time):
            t = sample[j:j+previous_time,:]
            td.append(t.reshape((1,-1)))
            t2 = label[j+previous_time,:]
            tl.append(t2.reshape((1,-1)))
    td = np.vstack(td)
    tl = np.vstack(tl)
    
    
    for i, d in enumerate(validation_data):
        sample = d[:-1,:]
        label = validation_labels[i][1:,:]
        for j in range(sample.shape[0]-previous_time):
            t = sample[j:j+previous_time,:]
            vd.append(t.reshape((1,-1)))
            t2 = label[j+previous_time,:]
            vl.append(t2.reshape((1,-1)))
    vd = np.vstack(vd)
    vl = np.vstack(vl)

    training_dataset = TensorDataset(torch.Tensor(td), torch.Tensor(tl))
    validation_dataset = TensorDataset(torch.Tensor(vd), torch.Tensor(vl))

    return training_dataset, validation_dataset
# get_full_data_from_mat(DATA_PATH)

(<torch.utils.data.dataset.TensorDataset at 0x7face44c1130>,
 <torch.utils.data.dataset.TensorDataset at 0x7face44c1fd0>)

In [7]:
def train_model(model,save_filepath,training_loader,validation_loader,epochs,device):
    epochs_list = []
    train_loss_list = []
    val_loss_list = []
    training_len = len(training_loader.dataset)
    validation_len = len(validation_loader.dataset)
    
#     feedback_arr = torch.zeros(batch_size, 90)
    
    #splitting the dataloaders to generalize code
    data_loaders = {"train": training_loader, "val": validation_loader}
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    loss_func = nn.MSELoss()
#     loss_func = nn.L1Loss()
    decay_rate = 0.93 #decay the lr each step to 93% of previous lr
    lr_sch = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=decay_rate)

    total_start = time.time()

    """
    You can easily adjust the number of epochs trained here by changing the number in the range
    """
    for epoch in tqdm(range(epochs), position=0, leave=True):
        start = time.time()
        train_loss = 0.0
        val_loss = 0.0
        temp_loss = 100000000000000.0
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train(True)
            else:
                model.train(False)

            running_loss = 0.0
            for i, (x, y) in enumerate(data_loaders[phase]):
                x = x.to(device)
                output = model(x)
                y = y.to(device)
                loss = loss_func(torch.squeeze(output), torch.squeeze(y)) 
                #backprop             
                optimizer.zero_grad()           
                if phase == 'train':
                    loss.backward()
                    optimizer.step()                                      

                #calculating total loss
                running_loss += loss.item()
            
            if phase == 'train':
                train_loss = running_loss
                lr_sch.step()
            else:
                val_loss = running_loss

        end = time.time()
        # shows total loss
        if epoch%5 == 0:
            print('[%d, %5d] train loss: %.6f val loss: %.6f' % (epoch + 1, i + 1, train_loss, val_loss))
#         print(end - start)
        
        #saving best model
        if val_loss < temp_loss:
            torch.save(model, save_filepath)
            temp_loss = val_loss
        epochs_list.append(epoch)
        train_loss_list.append(train_loss)
        val_loss_list.append(val_loss)
    total_end = time.time()
#     print(total_end - total_start)
    #Creating loss csv
    loss_df = pds.DataFrame(
        {
            'epoch': epochs_list,
            'training loss': train_loss_list,
            'validation loss': val_loss_list
        }
    )
    # Writing loss csv, change path to whatever you want to name it
    
    loss_df.to_csv(LOSS_FILE, index=None)
    return train_loss_list, val_loss_list

In [8]:
model1 = MODEL(input_size,hidden_size,output_size)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model1.to(device)

f_tr, f_va = get_full_data_from_mat(DATA_PATH)

# Turn datasets into iterable dataloaders
f_training_loader = DataLoader(dataset=f_tr,batch_size=batch_size)
f_validation_loader = DataLoader(dataset=f_va,batch_size=batch_size)

pnfr_training_loss, pnfr_validation_loss = train_model(model1,PATH,f_training_loader,
                                                       f_validation_loader,epochs,device)

  0%|          | 0/50 [00:00<?, ?it/s]

[1,  6125] train loss: 3048557.341652 val loss: 737220.934910
[6,  6125] train loss: 2558882.646463 val loss: 661238.938334
[11,  6125] train loss: 2446869.373968 val loss: 644653.442577
[16,  6125] train loss: 2402438.409304 val loss: 634241.124636
[21,  6125] train loss: 2359269.151180 val loss: 615347.841049
[26,  6125] train loss: 2333556.103411 val loss: 605548.642773
[31,  6125] train loss: 2300620.510614 val loss: 596472.856260
[36,  6125] train loss: 2290563.177565 val loss: 599856.876688
[41,  6125] train loss: 2274732.108633 val loss: 593338.082266
[46,  6125] train loss: 2267168.187887 val loss: 593289.669682


In [9]:
def r2_full_score_eval(model, testing_dataloader, k=None):
    output_list = []
    labels_list = []
    temp_list = []
    for i, (x, y) in enumerate(testing_dataloader):
        output = model(x)         
        output_list.append(output.detach().cpu().numpy())
        labels_list.append(y.detach().cpu().numpy())
        if k != None and i == k-1:
            break
#     print("Output list size: {}".format(len(output_list)))
#     print(output_list[0].shape)
    output_list = np.concatenate(output_list, axis=0)
    labels_list = np.concatenate(labels_list, axis=0)
#     print(output_list.shape)
#     print(labels_list.shape)
    print(r2_score(labels_list, output_list))
    return output_list, labels_list

### Mask
Variables:
$$
x = \text{PN Firing Rate}\\
y = \text{ITN Firing Rate}\\
z = \text{Local Field Potential}\\
t = \text{Timestep}\\
m = \text{Number of Previous Timesteps}\\
e_{x} = \text{External Input}\\
$$


Output Array:
$$
\begin{bmatrix}
x_{n,t+m} & y_{n,t+m} & z_{n,t+m}\\
\end{bmatrix}
$$
Feedback Array: 
$$
\begin{bmatrix}
x_{n,t-1} & y_{n,t-1} & z_{n,t-1} & x_{n,t} & ... & z_{n,t+m-1}\\
\end{bmatrix}
$$
Replace index 0, 1, 2 with Output Array and Roll

Rolled Feedback Array: 
$$
\begin{bmatrix}
x_{n,t} & y_{n,t} & z_{n,t} & x_{n,t+1} & ... & z_{n,t+m}\\
\end{bmatrix}
$$
Input Array:
$$
\begin{bmatrix}
x_{n,t} & y_{n,t} & e_{1,n,t} & e_{2,n,t} & e_{3,n,t} & e_{4,n,t} & e_{5,n,t} & e_{6,n,t} & z_{n,t} & x_{n,t+1} & ... & z_{n,t+m}\\
\end{bmatrix}
$$
New Output:
$$
\begin{bmatrix}
x_{n,t+m+1} & y_{n,t+m+1} & z_{n,t+m+1}\\
\end{bmatrix}
$$

In [10]:
def r2_score_eval(model, testing_dataloader, start=10, k=None):
    output_list = []
    labels_list = []
    temp_list = []
    mask = [True, True, False, False, False, False, False, False, True]
    mask = list(np.tile(mask, previous_time))
    feedback_arr = torch.zeros(1, 3*previous_time)
    for i, (x, y) in enumerate(testing_dataloader):
        x2 = x.numpy().copy()
        if i >= start:
            x2[:,mask] = feedback_arr.numpy()
        output = model(torch.Tensor(x2))     
        
        feedback_arr = torch.roll(feedback_arr, -3, 1)
        feedback_arr[:,0] = output[:,0].detach()
        feedback_arr[:,1] = output[:,1].detach()
        feedback_arr[:,2] = output[:,2].detach()
        feedback_arr.detach()
        
        output_list.append(output.detach().cpu().numpy())
        labels_list.append(y.detach().cpu().numpy())
        if k != None and i == k-1:
            break
#     print("Output list size: {}".format(len(output_list)))
#     print(output_list[0].shape)
    output_list = np.concatenate(output_list, axis=0)
    labels_list = np.concatenate(labels_list, axis=0)
    print(r2_score(labels_list, output_list))
    return output_list, labels_list

In [11]:
model1 = torch.load(PATH)
model1.eval()

# training_dataset, validation_dataset = get_data_from_mat(DATA_PATH)
# training_loader = DataLoader(dataset=training_dataset,batch_size=1)
# validation_loader = DataLoader(dataset=validation_dataset,batch_size=1)

training_loader = DataLoader(dataset=f_tr,batch_size=1)
validation_loader = DataLoader(dataset=f_va,batch_size=1)

start = 200
k = 5
end= start + k

model1.to('cpu')

ft_output_list, ft_labels_list = r2_full_score_eval(model1, f_training_loader,end)
fv_output_list, fv_labels_list = r2_full_score_eval(model1, f_validation_loader,end)
print(ft_output_list.shape)
print(ft_labels_list.shape)

t_output_list, t_labels_list = r2_score_eval(model1, training_loader,start,end)
v_output_list, v_labels_list = r2_score_eval(model1, validation_loader,start,end)
print(t_output_list.shape)
print(t_labels_list.shape)
# print(t_labels_list[0,:,:])

0.4841068861722142
0.48792639523401365
(6560, 3)
(6560, 3)
0.4060221308260064
0.49388899349136595
(205, 3)
(205, 3)


In [14]:
fig, ax = plt.subplots(nrows=3, ncols=2)
fig.tight_layout()
# ax[0,0].plot(range(epochs), pnfr_training_loss)
# ax[0,0].set_title('Validation Loss')
# ax[0,0].set_ylabel('Loss')
# ax[0,0].set_xlabel('Epoch')

# ax[0,1].plot(range(epochs), pnfr_validation_loss)
# ax[0,1].set_title('Training Loss')
# ax[0,1].set_ylabel('Loss')
# ax[0,1].set_xlabel('Epoch')


ax[0,0].plot(np.arange(start-10,end), v_labels_list[start-10:end,0], color='blue', label='Labels')
ax[0,0].plot(np.arange(start-10,end), v_output_list[start-10:end,0], color='red',label='Internal Loop')
ax[0,0].plot(np.arange(start-10,end), fv_output_list[start-10:end,0], color='green',label='Training')
ax[0,0].set_title('Validation PN FR')
ax[0,0].set_ylabel('PN FR')
ax[0,0].set_xlabel('Time and Sample')
# ax[0,0].legend()

ax[0,1].plot(np.arange(start-10,end), t_labels_list[start-10:end,0], color='blue',label='Labels')
ax[0,1].plot(np.arange(start-10,end), t_output_list[start-10:end,0], color='red',label='Internal Loop')
ax[0,1].plot(np.arange(start-10,end), ft_output_list[start-10:end,0], color='green',label='Training')
ax[0,1].set_title('Training PN FR')
ax[0,1].set_ylabel('PN FR')
ax[0,1].set_xlabel('Time')
# ax[0,1].legend()

ax[1,0].plot(np.arange(start-10,end), v_labels_list[start-10:end,1], color='blue',label='Labels')
ax[1,0].plot(np.arange(start-10,end), v_output_list[start-10:end,1], color='red',label='Internal Loop')
ax[1,0].plot(np.arange(start-10,end), fv_output_list[start-10:end,1], color='green',label='Training')
ax[1,0].set_title('Validation ITN FR')
ax[1,0].set_ylabel('ITN FR')
ax[1,0].set_xlabel('Time')
# ax[1,0].legend()

ax[1,1].plot(np.arange(start-10,end), t_labels_list[start-10:end,1], color='blue',label='Labels')
ax[1,1].plot(np.arange(start-10,end), t_output_list[start-10:end,1], color='red',label='Internal Loop')
ax[1,1].plot(np.arange(start-10,end), ft_output_list[start-10:end,1], color='green',label='Training')
ax[1,1].set_title('Training ITN FR')
ax[1,1].set_ylabel('ITN FR')
ax[1,1].set_xlabel('Time')
# ax[1,1].legend()

ax[2,0].plot(np.arange(start-10,end), v_labels_list[start-10:end,2], color='blue',label='Labels')
ax[2,0].plot(np.arange(start-10,end), v_output_list[start-10:end,2], color='red',label='Internal Loop')
ax[2,0].plot(np.arange(start-10,end), fv_output_list[start-10:end,2], color='green',label='Training')
ax[2,0].set_title('Validation LFP')
ax[2,0].set_ylabel('LFP')
ax[2,0].set_xlabel('Time')
# ax[2,0].legend()

ax[2,1].plot(np.arange(start-10,end), t_labels_list[start-10:end,2], color='blue',label='Labels')
ax[2,1].plot(np.arange(start-10,end), t_output_list[start-10:end,2], color='red',label='Internal Loop')
ax[2,1].plot(np.arange(start-10,end), ft_output_list[start-10:end,2], color='green',label='Training')
ax[2,1].set_title('Training LFP')
ax[2,1].set_ylabel('LFP')
ax[2,1].set_xlabel('Time')
ax[2,1].legend()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

<matplotlib.legend.Legend at 0x7face432c070>

In [13]:
plt.close('all')