In [1]:
RunningInCOLAB = 'google.colab' in str(get_ipython())
if RunningInCOLAB:
    !git clone https://github.com/MJC598/Neuron_Burst_Analysis.git

In [2]:
%matplotlib notebook
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import scipy.io
import random
import time
import pandas as pds
from sklearn.metrics import r2_score

torch.manual_seed(0)

<torch._C.Generator at 0x7fec10bd3af0>

In [3]:
class FCN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(FCN,self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.tanh = nn.Tanh()
        self.fc2 = nn.Linear(hidden_size, output_size)
    def forward(self, x):
        x = self.tanh(self.fc1(x))
        out = self.fc2(x)
        return out

In [4]:
TIMESTEPS = 85
FRONT_TIME = -50
BACK_TIME = 40
T_START = 50+FRONT_TIME
T_END = 50+BACK_TIME
MODEL = FCN
OUTPUT = 'FR_LFP'
LOSS_FILE = ('losses/bursts/losses_' + str(MODEL) + 
             '_' + OUTPUT + '_' + str(FRONT_TIME) + 
             '_' + str(T_END) + '_10ms_lag_full.csv')
PATH = ('models/bursts/' + str(MODEL) + '_' + OUTPUT + 
        '_' + str(FRONT_TIME) + '_' + str(T_END) + 
        '_10ms_lag_full.pth')
DATA_PATH = 'data/bursts/burst_separatePNITNv2.mat'
COLAB_PRE = 'Neuron_Burst_Analysis/'
if RunningInCOLAB:
    LOSS_FILE = COLAB_PRE + LOSS_FILE
    PATH = COLAB_PRE + PATH
    DATA_PATH = COLAB_PRE + DATA_PATH

# Specific Model Parameters
input_features = 9
previous_time = 10
input_size = input_features * previous_time
hidden_size = 95
output_size = 3
batch_size = 32
num_layers = 1
batch_first = True
dropout = 0.0
epochs = 50

In [5]:
def pad_arr(l, max_dur=0):
    temp = []
    for ar in l:
        npad = ((0,max_dur-ar.shape[0]), (0,0))
        arr = np.pad(ar, pad_width=npad, mode='constant', constant_values=-1)
        temp.append(arr)
    return temp

In [6]:
def get_truncated_data_from_mat(file_path, output_index=None, type='pre_pn'):
    data = scipy.io.loadmat(file_path)
    duration = []
    amp = []
    pb_fr_pn = []
    pb_fr_itn = []
    pb_aff_pn = []
    pb_aff_itn = []
    pb_exc_pn = []
    pb_inh_pn = []
    pb_exc_itn = []
    pb_inh_itn = []
    pb_lfp = []
    wb_fr_pn = []
    wb_fr_itn = []
    wb_aff_pn = []
    wb_aff_itn = []
    wb_exc_pn = []
    wb_inh_pn = []
    wb_exc_itn = []
    wb_inh_itn = []
    wb_lfp = []

#     print(data['info_collect'][0])
#     print(data['info_collect'].shape[0])
    for i in range(1, data['info_collect'].shape[0]):
        arr = data['info_collect'][i]
        
#         print(arr.shape)
        
        duration.append(arr[0])
        amp.append(arr[1])
        
        pb_fr_pn.append(arr[2])
        pb_fr_itn.append(arr[3])
        pb_aff_pn.append(arr[4])
        pb_aff_itn.append(arr[5])
        pb_exc_pn.append(arr[6])
        pb_inh_pn.append(arr[7])
        pb_exc_itn.append(arr[8])
        pb_inh_itn.append(arr[9])
        pb_lfp.append(arr[10])
        
        wb_fr_pn.append(arr[11])
        wb_fr_itn.append(arr[12])
        wb_aff_pn.append(arr[13])
        wb_aff_itn.append(arr[14])
        wb_exc_pn.append(arr[15])
        wb_inh_pn.append(arr[16])
        wb_exc_itn.append(arr[17])
        wb_inh_itn.append(arr[18])
        wb_lfp.append(arr[19])

    min_dur = 1000
    max_dur = 0
    for i, ar in enumerate(wb_fr_pn):
        if ar.shape[0] < min_dur:
            min_dur = ar.shape[0]
            mi = i
        if ar.shape[0] > max_dur:
            max_dur = ar.shape[0]
            ma = i
            
#     print(ma)
#     print(max_dur)
    
#     wb_fr_pn = pad_arr(wb_fr_pn, max_dur)
    
    wb_fr_pn = [ar[:min_dur,:] for ar in wb_fr_pn]

#     wb_fr_itn = pad_arr(wb_fr_itn, max_dur)
    
    wb_fr_itn = [ar[:min_dur,:] for ar in wb_fr_itn]
    
#     wb_aff_pn = pad_arr(wb_aff_pn, max_dur)
    
    wb_aff_pn = [ar[:min_dur,:] for ar in wb_aff_pn]
    
#     wb_aff_itn = pad_arr(wb_aff_itn, max_dur)
    
    wb_aff_itn = [ar[:min_dur,:] for ar in wb_aff_itn]
    
#     wb_exc_pn = pad_arr(wb_exc_pn, max_dur)
    
    wb_exc_pn = [ar[:min_dur,:] for ar in wb_exc_pn]
    
#     wb_inh_pn = pad_arr(wb_inh_pn, max_dur)
    
    wb_inh_pn = [ar[:min_dur,:] for ar in wb_inh_pn]
    
#     wb_exc_itn = pad_arr(wb_exc_itn, max_dur)
    
    wb_exc_itn = [ar[:min_dur,:] for ar in wb_exc_itn]
    
#     wb_inh_itn = pad_arr(wb_inh_itn, max_dur)
    
    wb_inh_itn = [ar[:min_dur,:] for ar in wb_inh_itn]

#     wb_lfp = pad_arr(wb_lfp, max_dur)
        
    wb_lfp = [ar[:min_dur,:] for ar in wb_lfp]
    
    t1 = np.concatenate((pb_fr_pn, wb_fr_pn), axis=1)
    t2 = np.concatenate((pb_fr_itn, wb_fr_itn), axis=1)
    t3 = np.concatenate((pb_aff_pn, wb_aff_pn), axis=1)
    t4 = np.concatenate((pb_aff_itn, wb_aff_itn), axis=1)
    t5 = np.concatenate((pb_exc_pn, wb_exc_pn), axis=1)
    t6 = np.concatenate((pb_inh_pn, wb_inh_pn), axis=1)
    t7 = np.concatenate((pb_exc_itn, wb_exc_itn), axis=1)
    t8 = np.concatenate((pb_inh_itn, wb_inh_itn), axis=1)
    t9 = np.concatenate((pb_lfp, wb_lfp), axis=1)
    
#     full_labels = np.concatenate((pb_fr_pn, pb_fr_itn, pb_lfp), axis=2)
    
#     front_data = np.concatenate((pb_fr_pn, pb_fr_itn, pb_aff_pn, pb_aff_itn,pb_exc_pn, 
#                                  pb_inh_pn, pb_exc_itn, pb_inh_itn, pb_lfp), axis=2)
    
#     rear_data = np.concatenate((wb_fr_pn, wb_fr_itn, wb_aff_pn, wb_aff_itn,
#                                 wb_exc_pn, wb_inh_pn, wb_exc_itn, wb_inh_itn, wb_lfp), axis=2)
    
#     full_data = rear_data

    full_labels = np.concatenate((t1,t2,t9),axis=2)
    full_data = np.concatenate((t1,t2,t3,t4,t5,t6,t7,t8,t9), axis=2)
    print(full_data.shape)
    
#     for j in range(3):
#         x = full_labels[:,:,j]
#         full_labels[:,:,j] = (x - np.min(x))/(np.max(x)-np.min(x))
    
#     for i in range(full_data.shape[0]):
#         for j in range(input_size):
#             x = full_data[i,:,j]
#             full_data[i,:,j] = (x - np.min(x))/(np.max(x)-np.min(x))
    
    
    random.seed(10)
    data_samples = 5472 #5498 
    k = 4352
    full = np.arange(data_samples)
    training_indices = np.random.choice(full, size=k, replace=False)
    validation_indices = np.delete(full,training_indices)
    
    max_index = min_dur+50
    lag = 1
    front_offset = 11
    
    training_data = full_data[:k,0:max_index-lag,:] 
    validation_data = full_data[k:data_samples,0:max_index-lag,:]
    
    if output_index is None:
        training_labels = full_labels[:k,front_offset:max_index,:] 
    else:
        training_labels = full_labels[:k,front_offset:max_index,output_index]
    
    if output_index is None:
        validation_labels = full_labels[k:data_samples,front_offset:max_index,:]
    else:
        validation_labels = full_labels[k:data_samples,front_offset:max_index,output_index]
    
    training_data = np.transpose(training_data, (0,2,1))
    training_labels = np.transpose(training_labels, (0,2,1))
    validation_data = np.transpose(validation_data, (0,2,1))
    validation_labels = np.transpose(validation_labels, (0,2,1))
    
    td = [] 
    tl = []
    vd = []
    vl = []
    for i in range(training_data.shape[0]):
        for j in range(training_data.shape[2]-10):
            t = training_data[i,:,j:j+10]
            td.append(t.reshape((1,-1)))
            t2 = training_labels[i,:,j]
            tl.append(t2.reshape((1,-1)))
    td = np.vstack(td)
    tl = np.vstack(tl)
    print(td.shape)
    print(tl.shape)
    
    
    for i in range(validation_data.shape[0]):
        for j in range(validation_data.shape[2]-10):
            t = training_data[i,:,j:j+10]
            vd.append(t.reshape((1,-1)))
            t2 = training_labels[i,:,j]
            vl.append(t2.reshape((1,-1)))
    vd = np.vstack(vd)
    print(vd.shape)
    vl = np.vstack(vl)
    print(vl.shape)
    
#     print(training_data.shape)
#     print(training_labels.shape)
#     print(validation_data.shape)
#     print(validation_labels.shape)

    training_dataset = TensorDataset(torch.Tensor(td), torch.Tensor(tl))
    validation_dataset = TensorDataset(torch.Tensor(vd), torch.Tensor(vl))

    return training_dataset, validation_dataset
# get_data_from_mat(DATA_PATH)

In [7]:
def get_data_from_mat(file_path, output_index=None, type='pre_pn'):
    data = scipy.io.loadmat(file_path)

    full_labels = []
    full_data = []

    for i in range(1, data['info_collect'].shape[0]):
        arr = data['info_collect'][i]
        
#         duration.append(arr[0])
#         amp.append(arr[1])
        
#         fr_pn.append(np.row_stack((arr[2], arr[11])))
#         fr_itn.append(np.row_stack((arr[3], arr[12])))
#         aff_pn.append(np.row_stack((arr[4], arr[13])))
#         aff_itn.append(np.row_stack((arr[5], arr[14])))
#         exc_pn.append(np.row_stack((arr[6], arr[15])))
#         inh_pn.append(np.row_stack((arr[7], arr[16])))
#         exc_itn.append(np.row_stack((arr[8], arr[17])))
#         inh_itn.append(np.row_stack((arr[9], arr[18])))
#         lfp.append(np.row_stack((arr[10], arr[19])))
        
        full_labels.append(np.column_stack((np.row_stack((arr[2], arr[11])), 
                                            np.row_stack((arr[3], arr[12])), 
                                            np.row_stack((arr[10], arr[19])))))
        
        full_data.append(np.column_stack((np.row_stack((arr[2], arr[11])), 
                                          np.row_stack((arr[3], arr[12])), 
                                          np.row_stack((arr[4], arr[13])), 
                                          np.row_stack((arr[5], arr[14])), 
                                          np.row_stack((arr[6], arr[15])), 
                                          np.row_stack((arr[7], arr[16])), 
                                          np.row_stack((arr[8], arr[17])), 
                                          np.row_stack((arr[9], arr[18])), 
                                          np.row_stack((arr[10], arr[19])))))
    
    full_data = np.asarray(full_data)
    full_labels = np.asarray(full_labels)
    
    random.seed(10)
    data_samples = 5472 #5498 
    k = 4352
    
    lag = 1
    front_offset = 11
    
    training_data = full_data[:k] 
    validation_data = full_data[k:data_samples]
    training_labels = full_labels[:k]
    validation_labels = full_data[k:data_samples]
    
#     if output_index is None:
#         training_labels = full_labels[:k,front_offset:max_index,:] 
#     else:
#         training_labels = full_labels[:k,front_offset:max_index,output_index]
    
#     if output_index is None:
#         validation_labels = full_labels[k:data_samples,front_offset:max_index,:]
#     else:
#         validation_labels = full_labels[k:data_samples,front_offset:max_index,output_index]
    
    td = [] 
    tl = []
    vd = []
    vl = []
    for i, sample in enumerate(training_data):
#         print(sample.shape)
        label = training_labels[i]
#         print(label.shape)
        for j in range(sample.shape[0]-10):
            t = sample[j:j+10,:]
#             print(t[-1])
#             print(t.shape)
            td.append(t.reshape((1,-1)))
            t2 = label[j+10,:]
#             print(t2)
            tl.append(t2.reshape((1,-1)))
    td = np.vstack(td)
    tl = np.vstack(tl)
    print(td.shape)
    print(tl.shape)
    
    
    for i, sample in enumerate(validation_data):
        label = validation_labels[i]
        for j in range(sample.shape[0]-10):
            t = sample[j:j+10,:]
            vd.append(t.reshape((1,-1)))
            t2 = label[j+10,:]
            vl.append(t2.reshape((1,-1)))
    vd = np.vstack(vd)
    print(vd.shape)
    vl = np.vstack(vl)
    print(vl.shape)
    
#     print(training_data.shape)
#     print(training_labels.shape)
#     print(validation_data.shape)
#     print(validation_labels.shape)

    training_dataset = TensorDataset(torch.Tensor(td), torch.Tensor(tl))
    validation_dataset = TensorDataset(torch.Tensor(vd), torch.Tensor(vl))

    return training_dataset, validation_dataset
# get_data_from_mat(DATA_PATH)

In [8]:
def train_model(model,save_filepath,training_loader,validation_loader,epochs):
    
    epochs_list = []
    train_loss_list = []
    val_loss_list = []
    training_len = len(training_loader.dataset)
    validation_len = len(validation_loader.dataset)

    #splitting the dataloaders to generalize code
    data_loaders = {"train": training_loader, "val": validation_loader}
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    loss_func = nn.MSELoss()
#     loss_func = nn.L1Loss()
    decay_rate = 0.93 #decay the lr each step to 93% of previous lr
    lr_sch = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=decay_rate)

    total_start = time.time()

    """
    You can easily adjust the number of epochs trained here by changing the number in the range
    """
    for epoch in tqdm(range(epochs), position=0, leave=True):
        start = time.time()
        train_loss = 0.0
        val_loss = 0.0
        temp_loss = 100000000000000.0
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train(True)
            else:
                model.train(False)

            running_loss = 0.0
            for i, (x, y) in enumerate(data_loaders[phase]):
                output = model(x) 
#                 print(output.size())
#                 print(y.size())
                loss = loss_func(torch.squeeze(output), torch.squeeze(y)) 
                #backprop             
                optimizer.zero_grad()           
                if phase == 'train':
                    loss.backward()
                    optimizer.step()                                      

                #calculating total loss
                running_loss += loss.item()
            
            if phase == 'train':
                train_loss = running_loss
                lr_sch.step()
            else:
                val_loss = running_loss

        end = time.time()
        # shows total loss
        if epoch%5 == 0:
            print('[%d, %5d] train loss: %.6f val loss: %.6f' % (epoch + 1, i + 1, train_loss, val_loss))
#         print(end - start)
        
        #saving best model
        if val_loss < temp_loss:
            torch.save(model, save_filepath)
            temp_loss = val_loss
        epochs_list.append(epoch)
        train_loss_list.append(train_loss)
        val_loss_list.append(val_loss)
    total_end = time.time()
#     print(total_end - total_start)
    #Creating loss csv
    loss_df = pds.DataFrame(
        {
            'epoch': epochs_list,
            'training loss': train_loss_list,
            'validation loss': val_loss_list
        }
    )
    # Writing loss csv, change path to whatever you want to name it
    lf = ('losses/losses_' + str(MODEL) + '_' 
          + OUTPUT + '_' + str(FRONT_TIME) + '_' 
          + str(T_END) + '_fullin.csv')
    loss_df.to_csv(lf, index=None)
    return train_loss_list, val_loss_list

In [9]:
def r2_score_eval(model, testing_dataloader):
    output_list = []
    labels_list = []
    temp_list = []
    for i, (x, y) in enumerate(testing_dataloader):
        output = model(x)         
        output_list.append(output.detach().cpu().numpy())
        labels_list.append(y.detach().cpu().numpy())
#     print("Output list size: {}".format(len(output_list)))
#     print(output_list[0].shape)
    output_list = np.concatenate(output_list, axis=0)
    labels_list = np.concatenate(labels_list, axis=0)
#     print(output_list.shape)
#     print(labels_list.shape)
    print(r2_score(labels_list, output_list))
    return output_list, labels_list

In [10]:
model1 = MODEL(input_size,hidden_size,output_size)

training_dataset, validation_dataset = get_data_from_mat(DATA_PATH)

# Turn datasets into iterable dataloaders
training_loader = DataLoader(dataset=training_dataset,batch_size=batch_size,shuffle=True)
validation_loader = DataLoader(dataset=validation_dataset,batch_size=batch_size)

p1 = 'models/' + str(MODEL) + '_' + OUTPUT + '_' + str(FRONT_TIME) + '_' + str(T_END) + '_full.pth'
pnfr_training_loss, pnfr_validation_loss = train_model(model1,p1,training_loader,
                                                       validation_loader,epochs)

  return array(a, dtype, copy=False, order=order)


(757509, 90)
(757509, 3)
(197107, 90)
(197107, 9)


  0%|          | 0/50 [00:00<?, ?it/s]

  Variable._execution_engine.run_backward(


KeyboardInterrupt: 

In [None]:
model1 = torch.load(p1)
model1.eval()
t_output_list, t_labels_list = r2_score_eval(model1, training_loader)
v_output_list, v_labels_list = r2_score_eval(model1, validation_loader)
print(t_output_list.shape)
print(t_labels_list.shape)
# print(t_labels_list[0,:,:])

In [None]:
fig, ax = plt.subplots(nrows=4, ncols=2)
fig.tight_layout()
ax[0,0].plot(range(epochs), pnfr_training_loss)
ax[0,0].set_title('Validation Loss')
ax[0,0].set_ylabel('Loss')
ax[0,0].set_xlabel('Epoch')

ax[0,1].plot(range(epochs), pnfr_validation_loss)
ax[0,1].set_title('Training Loss')
ax[0,1].set_ylabel('Loss')
ax[0,1].set_xlabel('Epoch')


ax[1,0].plot(np.arange(v_labels_list.shape[0]), v_labels_list[:,0], color='blue')
ax[1,0].plot(np.arange(v_labels_list.shape[0]), v_output_list[:,0], color='red')
ax[1,0].set_title('Validation PN FR')
ax[1,0].set_ylabel('PN FR')
ax[1,0].set_xlabel('Time and Sample')

ax[1,1].plot(np.arange(t_labels_list.shape[0]), t_labels_list[:,0], color='blue')
ax[1,1].plot(np.arange(t_labels_list.shape[0]), t_output_list[:,0], color='red')
ax[1,1].set_title('Training PN FR')
ax[1,1].set_ylabel('PN FR')
ax[1,1].set_xlabel('Time')

ax[2,0].plot(np.arange(v_labels_list.shape[0]), v_labels_list[:,1], color='blue')
ax[2,0].plot(np.arange(v_labels_list.shape[0]), v_output_list[:,1], color='red')
ax[2,0].set_title('Validation ITN FR')
ax[2,0].set_ylabel('ITN FR')
ax[2,0].set_xlabel('Time')

ax[2,1].plot(np.arange(t_labels_list.shape[0]), t_labels_list[:,1], color='blue')
ax[2,1].plot(np.arange(t_labels_list.shape[0]), t_output_list[:,1], color='red')
ax[2,1].set_title('Training ITN FR')
ax[2,1].set_ylabel('ITN FR')
ax[2,1].set_xlabel('Time')

ax[3,0].plot(np.arange(v_labels_list.shape[0]), v_labels_list[:,2], color='blue')
ax[3,0].plot(np.arange(v_labels_list.shape[0]), v_output_list[:,2], color='red')
ax[3,0].set_title('Validation LFP')
ax[3,0].set_ylabel('LFP')
ax[3,0].set_xlabel('Time')

ax[3,1].plot(np.arange(t_labels_list.shape[0]), t_labels_list[:,2], color='blue')
ax[3,1].plot(np.arange(t_labels_list.shape[0]), t_output_list[:,2], color='red')
ax[3,1].set_title('Training LFP')
ax[3,1].set_ylabel('LFP')
ax[3,1].set_xlabel('Time')

plt.show()