In [1]:
RunningInCOLAB = 'google.colab' in str(get_ipython())
if RunningInCOLAB:
    !git clone https://github.com/MJC598/Neuron_Burst_Analysis.git

In [2]:
%matplotlib notebook
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import scipy.io
import random
import time
import pandas as pds
from sklearn.metrics import r2_score

torch.manual_seed(0)

<torch._C.Generator at 0x7fe9453a5af0>

In [3]:
class FCN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(FCN,self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.tanh = nn.Tanh()
        self.fc2 = nn.Linear(hidden_size, output_size)
    def forward(self, x):
        x = self.tanh(self.fc1(x))
        out = self.fc2(x)
        return out

In [4]:
TIMESTEPS = 85
FRONT_TIME = -50
BACK_TIME = 40
T_START = 50+FRONT_TIME
T_END = 50+BACK_TIME
MODEL = FCN
OUTPUT = 'FR_LFP'
LOSS_FILE = ('losses/bursts/losses_' + str(MODEL) + 
             '_' + OUTPUT + '_' + str(FRONT_TIME) + 
             '_' + str(T_END) + '_10ms_lag_aff_temp.csv')
# PATH = ('models/bursts/' + str(MODEL) + '_' + OUTPUT + 
#         '_' + str(FRONT_TIME) + '_' + str(T_END) + 
#         '_10ms_lag_aff_temp.pth')
PATH = 'models/<class \'__main__.FCN\'>_FR_LFP_-50_90_full.pth'

DATA_PATH = 'data/bursts/burst_separatePNITNv2.mat'
COLAB_PRE = 'Neuron_Burst_Analysis/'
if RunningInCOLAB:
    LOSS_FILE = COLAB_PRE + LOSS_FILE
    PATH = COLAB_PRE + PATH
    DATA_PATH = COLAB_PRE + DATA_PATH

# Specific Model Parameters
input_features = 9
previous_time = 10
input_size = input_features * previous_time
hidden_size = 95
output_size = 3
batch_size = 1
num_layers = 1
batch_first = True
dropout = 0.0
epochs = 50

In [5]:
def get_full_data_from_mat(file_path, output_index=None, type='pre_pn'):
    data = scipy.io.loadmat(file_path)

    full_labels = []
    full_data = []

    for i in range(1, data['info_collect'].shape[0]):
        arr = data['info_collect'][i]
        
        full_labels.append(np.column_stack((np.row_stack((arr[2], arr[11])), 
                                            np.row_stack((arr[3], arr[12])), 
                                            np.row_stack((arr[10], arr[19])))))
        
        full_data.append(np.column_stack((np.row_stack((arr[2], arr[11])), 
                                          np.row_stack((arr[3], arr[12])), 
                                          np.row_stack((arr[4], arr[13])), 
                                          np.row_stack((arr[5], arr[14])), 
                                          np.row_stack((arr[6], arr[15])), 
                                          np.row_stack((arr[7], arr[16])), 
                                          np.row_stack((arr[8], arr[17])), 
                                          np.row_stack((arr[9], arr[18])), 
                                          np.row_stack((arr[10], arr[19])))))
    
    full_data = np.asarray(full_data,dtype=object)
    full_labels = np.asarray(full_labels,dtype=object)
    
    random.seed(10)
    data_samples = 5472 #5498 
    k = 4352
    
    lag = 1
    front_offset = 11
    
    training_data = full_data[:k] 
    validation_data = full_data[k:data_samples]
    training_labels = full_labels[:k]
    validation_labels = full_labels[k:data_samples]
    
    td = [] 
    tl = []
    vd = []
    vl = []
    for i, sample in enumerate(training_data):
        label = training_labels[i]
        for j in range(sample.shape[0]-10):
            t = sample[j:j+10,:]
            td.append(t.reshape((1,-1)))
            t2 = label[j+10,:]
            tl.append(t2.reshape((1,-1)))
    td = np.vstack(td)
    tl = np.vstack(tl)
    print(td.shape)
    print(tl.shape)
    
    
    for i, sample in enumerate(validation_data):
        label = validation_labels[i]
        for j in range(sample.shape[0]-10):
            t = sample[j:j+10,:]
            vd.append(t.reshape((1,-1)))
            t2 = label[j+10,:]
            vl.append(t2.reshape((1,-1)))
    vd = np.vstack(vd)
    vl = np.vstack(vl)
    print(vd.shape)
    print(vl.shape)

    training_dataset = TensorDataset(torch.Tensor(td), torch.Tensor(tl))
    validation_dataset = TensorDataset(torch.Tensor(vd), torch.Tensor(vl))

    return training_dataset, validation_dataset
# get_data_from_mat(DATA_PATH)

In [6]:
def get_data_from_mat(file_path, output_index=None, type='pre_pn'):
    data = scipy.io.loadmat(file_path)

    full_labels = []
    full_data = []

    for i in range(1, data['info_collect'].shape[0]):
        arr = data['info_collect'][i]
        
        full_labels.append(np.column_stack((np.row_stack((arr[2], arr[11])), 
                                            np.row_stack((arr[3], arr[12])), 
                                            np.row_stack((arr[10], arr[19])))))
        
        full_data.append(np.column_stack((np.row_stack((arr[2], arr[11])), 
                                          np.row_stack((arr[3], arr[12])), 
                                          np.row_stack((arr[4], arr[13])), 
                                          np.row_stack((arr[5], arr[14])), 
                                          np.row_stack((arr[6], arr[15])), 
                                          np.row_stack((arr[7], arr[16])), 
                                          np.row_stack((arr[8], arr[17])), 
                                          np.row_stack((arr[9], arr[18])), 
                                          np.row_stack((arr[10], arr[19])))))
    
    full_data = np.asarray(full_data,dtype=object)
    full_labels = np.asarray(full_labels,dtype=object)
    
    random.seed(10)
    data_samples = 5472 #5498 
    k = 4352
    
    lag = 1
    front_offset = 11
    
    training_data = full_data[:k] 
    validation_data = full_data[k:data_samples]
    training_labels = full_labels[:k]
    validation_labels = full_labels[k:data_samples]
    
    td = [] 
    tl = []
    vd = []
    vl = []
    for i, sample in enumerate(training_data):
        label = training_labels[i]
        for j in range(sample.shape[0]-10):
            t = sample[j:j+10,:]
            if j != 0:
                t[:,0] = 0
                t[:,1] = 0
                t[:,8] = 0
            td.append(t.reshape((1,-1)))
            t2 = label[j+10,:]
            tl.append(t2.reshape((1,-1)))
    td = np.vstack(td)
    tl = np.vstack(tl)
    print(td.shape)
    print(tl.shape)
    
    
    for i, sample in enumerate(validation_data):
        label = validation_labels[i]
        for j in range(sample.shape[0]-10):
            t = sample[j:j+10,:]
            if j != 0:
                t[:,0] = 0
                t[:,1] = 0
                t[:,8] = 0
            vd.append(t.reshape((1,-1)))
            t2 = label[j+10,:]
            vl.append(t2.reshape((1,-1)))
    vd = np.vstack(vd)
    vl = np.vstack(vl)
    print(vd.shape)
    print(vl.shape)

    training_dataset = TensorDataset(torch.Tensor(td), torch.Tensor(tl))
    validation_dataset = TensorDataset(torch.Tensor(vd), torch.Tensor(vl))

    return training_dataset, validation_dataset
# get_data_from_mat(DATA_PATH)

In [7]:
def train_model(model,save_filepath,training_loader,validation_loader,epochs):
    
    epochs_list = []
    train_loss_list = []
    val_loss_list = []
    training_len = len(training_loader.dataset)
    validation_len = len(validation_loader.dataset)
    
    feedback_arr = torch.zeros(batch_size, 90)
    
    #splitting the dataloaders to generalize code
    data_loaders = {"train": training_loader, "val": validation_loader}
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    loss_func = nn.MSELoss()
#     loss_func = nn.L1Loss()
    decay_rate = 0.93 #decay the lr each step to 93% of previous lr
    lr_sch = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=decay_rate)

    total_start = time.time()

    """
    You can easily adjust the number of epochs trained here by changing the number in the range
    """
    for epoch in tqdm(range(epochs), position=0, leave=True):
        start = time.time()
        train_loss = 0.0
        val_loss = 0.0
        temp_loss = 100000000000000.0
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train(True)
            else:
                model.train(False)

            running_loss = 0.0
            for i, (x, y) in enumerate(data_loaders[phase]):
#                 if phase == 'train':
                x = torch.add(x, feedback_arr)
#                     print(x)
#                     print(feedback_arr)
#                 print(y.size())
                output = model(x)
#                 print(output)
                
#                 print(output.size())
#                 print(y.size())
                loss = loss_func(torch.squeeze(output), torch.squeeze(y)) 
                #backprop             
                optimizer.zero_grad()           
                if phase == 'train':
                    loss.backward()
                    optimizer.step()                                      

                #calculating total loss
                running_loss += loss.item()
#                 if phase == 'train':
                feedback_arr = torch.roll(feedback_arr, 9, 1)
                feedback_arr[:,0] = output[:,0].detach()
                feedback_arr[:,1] = output[:,1].detach()
                feedback_arr[:,8] = output[:,2].detach()
                feedback_arr.detach()
#                     print(feedback_arr)
            
            if phase == 'train':
                train_loss = running_loss
                lr_sch.step()
            else:
                val_loss = running_loss

        end = time.time()
        # shows total loss
        if epoch%5 == 0:
            print('[%d, %5d] train loss: %.6f val loss: %.6f' % (epoch + 1, i + 1, train_loss, val_loss))
#         print(end - start)
        
        #saving best model
        if val_loss < temp_loss:
            torch.save(model, save_filepath)
            temp_loss = val_loss
        epochs_list.append(epoch)
        train_loss_list.append(train_loss)
        val_loss_list.append(val_loss)
    total_end = time.time()
#     print(total_end - total_start)
    #Creating loss csv
    loss_df = pds.DataFrame(
        {
            'epoch': epochs_list,
            'training loss': train_loss_list,
            'validation loss': val_loss_list
        }
    )
    # Writing loss csv, change path to whatever you want to name it
    
    loss_df.to_csv(LOSS_FILE, index=None)
    return train_loss_list, val_loss_list

In [8]:
def r2_full_score_eval(model, testing_dataloader):
    output_list = []
    labels_list = []
    temp_list = []
    for i, (x, y) in enumerate(testing_dataloader):
        output = model(x)         
        output_list.append(output.detach().cpu().numpy())
        labels_list.append(y.detach().cpu().numpy())
#     print("Output list size: {}".format(len(output_list)))
#     print(output_list[0].shape)
    output_list = np.concatenate(output_list, axis=0)
    labels_list = np.concatenate(labels_list, axis=0)
#     print(output_list.shape)
#     print(labels_list.shape)
    print(r2_score(labels_list, output_list))
    return output_list, labels_list

In [9]:
def r2_score_eval(model, testing_dataloader):
    output_list = []
    labels_list = []
    temp_list = []
    feedback_arr = torch.zeros(batch_size, 90)
    for i, (x, y) in enumerate(testing_dataloader):
        x = torch.add(x, feedback_arr)
        output = model(x)     
        
        feedback_arr = torch.roll(feedback_arr, 9, 1)
        feedback_arr[:,0] = output[:,0].detach()
        feedback_arr[:,1] = output[:,1].detach()
        feedback_arr[:,8] = output[:,2].detach()
        feedback_arr.detach()
        
        output_list.append(output.detach().cpu().numpy())
        labels_list.append(y.detach().cpu().numpy())
#     print("Output list size: {}".format(len(output_list)))
#     print(output_list[0].shape)
    output_list = np.concatenate(output_list, axis=0)
    labels_list = np.concatenate(labels_list, axis=0)
#     print(output_list.shape)
#     print(labels_list.shape)
    print(r2_score(labels_list, output_list))
    return output_list, labels_list

In [10]:
model1 = MODEL(input_size,hidden_size,output_size)

f_tr, f_va = get_full_data_from_mat(DATA_PATH)
training_dataset, validation_dataset = get_data_from_mat(DATA_PATH)

# Turn datasets into iterable dataloaders
f_training_loader = DataLoader(dataset=f_tr,batch_size=batch_size)
f_validation_loader = DataLoader(dataset=f_va,batch_size=batch_size)

training_loader = DataLoader(dataset=training_dataset,batch_size=batch_size)
validation_loader = DataLoader(dataset=validation_dataset,batch_size=batch_size)

# pnfr_training_loss, pnfr_validation_loss = train_model(model1,PATH,training_loader,
#                                                        validation_loader,epochs)

(757509, 90)
(757509, 3)
(197107, 90)
(197107, 3)
(757509, 90)
(757509, 3)
(197107, 90)
(197107, 3)


In [11]:
model1 = torch.load(PATH)
model1.eval()

ft_output_list, ft_labels_list = r2_full_score_eval(model1, f_training_loader)
fv_output_list, fv_labels_list = r2_full_score_eval(model1, f_validation_loader)
print(ft_output_list.shape)
print(ft_labels_list.shape)

t_output_list, t_labels_list = r2_score_eval(model1, training_loader)
v_output_list, v_labels_list = r2_score_eval(model1, validation_loader)
print(t_output_list.shape)
print(t_labels_list.shape)
# print(t_labels_list[0,:,:])

0.4226264914101965
0.4235686638348424
(757509, 3)
(757509, 3)
-0.0360098629043916
-0.03644091423125088
(757509, 3)
(757509, 3)


In [15]:
fig, ax = plt.subplots(nrows=3, ncols=2)
fig.tight_layout()
# ax[0,0].plot(range(epochs), pnfr_training_loss)
# ax[0,0].set_title('Validation Loss')
# ax[0,0].set_ylabel('Loss')
# ax[0,0].set_xlabel('Epoch')

# ax[0,1].plot(range(epochs), pnfr_validation_loss)
# ax[0,1].set_title('Training Loss')
# ax[0,1].set_ylabel('Loss')
# ax[0,1].set_xlabel('Epoch')


ax[0,0].plot(np.arange(v_labels_list.shape[0]), v_labels_list[:,0], color='blue', label='Labels')
ax[0,0].plot(np.arange(v_labels_list.shape[0]), v_output_list[:,0], color='red',label='Internal Loop')
ax[0,0].plot(np.arange(v_labels_list.shape[0]), fv_output_list[:,0], color='green',label='Training')
ax[0,0].set_title('Validation PN FR')
ax[0,0].set_ylabel('PN FR')
ax[0,0].set_xlabel('Time and Sample')
# ax[0,0].legend()

ax[0,1].plot(np.arange(t_labels_list.shape[0]), t_labels_list[:,0], color='blue',label='Labels')
ax[0,1].plot(np.arange(t_labels_list.shape[0]), t_output_list[:,0], color='red',label='Internal Loop')
ax[0,1].plot(np.arange(t_labels_list.shape[0]), ft_output_list[:,0], color='green',label='Training')
ax[0,1].set_title('Training PN FR')
ax[0,1].set_ylabel('PN FR')
ax[0,1].set_xlabel('Time')
# ax[0,1].legend()

ax[1,0].plot(np.arange(v_labels_list.shape[0]), v_labels_list[:,1], color='blue',label='Labels')
ax[1,0].plot(np.arange(v_labels_list.shape[0]), v_output_list[:,1], color='red',label='Internal Loop')
ax[1,0].plot(np.arange(v_labels_list.shape[0]), fv_output_list[:,1], color='green',label='Training')
ax[1,0].set_title('Validation ITN FR')
ax[1,0].set_ylabel('ITN FR')
ax[1,0].set_xlabel('Time')
# ax[1,0].legend()

ax[1,1].plot(np.arange(t_labels_list.shape[0]), t_labels_list[:,1], color='blue',label='Labels')
ax[1,1].plot(np.arange(t_labels_list.shape[0]), t_output_list[:,1], color='red',label='Internal Loop')
ax[1,1].plot(np.arange(t_labels_list.shape[0]), ft_output_list[:,1], color='green',label='Training')
ax[1,1].set_title('Training ITN FR')
ax[1,1].set_ylabel('ITN FR')
ax[1,1].set_xlabel('Time')
# ax[1,1].legend()

ax[2,0].plot(np.arange(v_labels_list.shape[0]), v_labels_list[:,2], color='blue',label='Labels')
ax[2,0].plot(np.arange(v_labels_list.shape[0]), v_output_list[:,2], color='red',label='Internal Loop')
ax[2,0].plot(np.arange(v_labels_list.shape[0]), fv_output_list[:,2], color='green',label='Training')
ax[2,0].set_title('Validation LFP')
ax[2,0].set_ylabel('LFP')
ax[2,0].set_xlabel('Time')
# ax[2,0].legend()

ax[2,1].plot(np.arange(t_labels_list.shape[0]), t_labels_list[:,2], color='blue',label='Labels')
ax[2,1].plot(np.arange(t_labels_list.shape[0]), t_output_list[:,2], color='red',label='Internal Loop')
ax[2,1].plot(np.arange(t_labels_list.shape[0]), ft_output_list[:,2], color='green',label='Training')
ax[2,1].set_title('Training LFP')
ax[2,1].set_ylabel('LFP')
ax[2,1].set_xlabel('Time')
ax[2,1].legend()

plt.show()

<IPython.core.display.Javascript object>