In [1]:
RunningInCOLAB = 'google.colab' in str(get_ipython())
COLAB_PRE = 'Neuron_Burst_Analysis/'
if RunningInCOLAB:
    !git clone https://github.com/MJC598/Neuron_Burst_Analysis.git
    paths.LOSS_FILE = COLAB_PRE + paths.LOSS_FILE
    paths.PATH = COLAB_PRE + paths.PATH

In [2]:
%matplotlib widget
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
import numpy as np
from numpy.random import MT19937
from numpy.random import RandomState, SeedSequence
from numpy.random import default_rng
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import scipy.io
from scipy import signal
import random
import time
import pandas as pds
from sklearn.metrics import r2_score, mean_squared_error
import copy

from utils import preprocess, metrics
from config import params, paths
from models import LFPNet

# Single seed shared by every random number generator used in the notebook.
s = 67

rs = RandomState(MT19937(SeedSequence(s)))  # explicit legacy-API NumPy generator
rng = default_rng(seed=s)                   # explicit Generator-API NumPy generator
# Global RNGs: Python's `random` and NumPy's legacy global state are imported
# above but were never seeded, so anything drawing from them was not reproducible.
random.seed(s)
np.random.seed(s)
torch.manual_seed(s)  # torch's default generator (weight init, DataLoader shuffling)

# Large default font for every figure in this notebook.
plt.rcParams['font.size'] = 32

In [3]:
# Close every open figure (widget backend keeps them alive across re-runs).
plt.close(fig='all')

In [4]:
def _run_phase(model, loader, loss_func, optimizer, device, train):
    """Run one pass of `model` over `loader` and return the summed batch loss.

    Args:
        model: network being trained or evaluated.
        loader: DataLoader yielding ``(x, y)`` batches.
        loss_func: criterion applied to the squeezed output and squeezed target.
        optimizer: zeroed and stepped after every batch when ``train`` is True;
            unused otherwise.
        device: device the batches are moved to before the forward pass.
        train: True for a training pass (model in train mode, gradients,
            optimizer steps); False for an evaluation pass with autograd disabled.

    Returns:
        (running_loss, n_batches): the *sum* of the per-batch losses and the
        number of batches seen.
    """
    model.train(train)
    running_loss = 0.0
    n_batches = 0
    # Only build the autograd graph when we actually backpropagate.
    with torch.set_grad_enabled(train):
        for x, y in loader:
            if params.RECURRENT_NET:
                # recurrent models are fed the two trailing input axes swapped
                x = torch.transpose(x, 2, 1)
            x = x.to(device)
            y = y.to(device)
            output = model(x)
            loss = loss_func(torch.squeeze(output), torch.squeeze(y))
            if train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            running_loss += loss.item()
            n_batches += 1
    return running_loss, n_batches


def train_model(model,save_filepath,training_loader,validation_loader,epochs,device):
    """Train `model` with Adam and exponential LR decay, logging per-epoch losses.

    Each epoch runs one training pass over `training_loader`, steps the LR
    scheduler, then runs one no-grad pass over `validation_loader`. Reported
    losses are the *sum* of per-batch MSE losses. Train and val values are
    therefore not directly comparable when the loaders have different numbers
    of batches.

    The whole module is saved with ``torch.save`` to `save_filepath` whenever the
    epoch's training loss is the lowest seen so far. The loss history is written
    as CSV to ``paths.LOSS_FILE`` at the end.

    Args:
        model: network to train (already moved to `device`).
        save_filepath: checkpoint path for the best model.
        training_loader: DataLoader of ``(x, y)`` training batches.
        validation_loader: DataLoader of ``(x, y)`` validation batches.
        epochs: number of epochs to run.
        device: device batches are moved to.

    Returns:
        (train_loss_list, val_loss_list): one summed loss per epoch each.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
    loss_func = nn.MSELoss()
    decay_rate = 0.98  # lr is multiplied by this once per epoch
    lr_sch = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=decay_rate)

    epochs_list = []
    train_loss_list = []
    val_loss_list = []
    # Lowest training loss seen so far. It must be initialised outside the epoch
    # loop: resetting it each epoch made the "best model" test always true, so the
    # checkpoint was simply overwritten by the last epoch.
    best_train_loss = float('inf')

    for epoch in tqdm(range(epochs), position=0, leave=True):
        train_loss, _ = _run_phase(model, training_loader, loss_func, optimizer, device, train=True)
        lr_sch.step()  # one scheduler step per epoch, after the training pass
        val_loss, n_val_batches = _run_phase(model, validation_loader, loss_func, optimizer, device, train=False)

        if epoch % 5 == 0:
            print('[%d, %5d] train loss: %.6f val loss: %.6f' % (epoch + 1, n_val_batches, train_loss, val_loss))

        # Save the best model so far.
        # NOTE(review): selection uses the training loss; validation loss is the
        # usual criterion if the goal is generalisation.
        if train_loss < best_train_loss:
            torch.save(model, save_filepath)
            best_train_loss = train_loss

        epochs_list.append(epoch)
        train_loss_list.append(train_loss)
        val_loss_list.append(val_loss)

    # Write the loss history; change paths.LOSS_FILE to rename the CSV.
    loss_df = pds.DataFrame(
        {
            'epoch': epochs_list,
            'training loss': train_loss_list,
            'validation loss': val_loss_list,
        }
    )
    loss_df.to_csv(paths.LOSS_FILE, index=None)
    return train_loss_list, val_loss_list

In [None]:
# Load the in-vivo LFP data as a (training, validation) pair of datasets.
(f_tr, f_va) = preprocess.get_inVivo_LFP()

In [None]:
# Wrap the in-vivo datasets from the previous cell in iterable dataloaders.
# Only the training loader shuffles; the validation loader keeps dataset order.
# NOTE: the alternative `preprocess` pipelines (get_filteredLFP, get_rawLFP,
# get_end1D, get_burstLFP) used to be listed here as commented-out code. That code
# no longer ran if uncommented (e.g. `params.BATCH_SIZE=params.BATCH_SIZE` is not a
# valid keyword argument), so it was removed; restore it from version control.
train_loader = DataLoader(dataset=f_tr, batch_size=params.BATCH_SIZE, shuffle=True)
val_loader = DataLoader(dataset=f_va, batch_size=params.BATCH_SIZE, shuffle=False)

In [None]:
# Use the first GPU when available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Fresh network from the hyper-parameters in config.params. An untouched copy is
# kept in `model_initial` before the network is moved and trained.
# (To resume from a checkpoint instead: model1 = torch.load(paths.PATH).)
model1 = params.MODEL(params.INPUT_SIZE, params.HIDDEN_SIZE, params.OUTPUT_SIZE)
model_initial = copy.deepcopy(model1)
model1.to(device)

# Train, checkpointing to paths.PATH; yields per-epoch training/validation losses.
pnfr_training_loss, pnfr_validation_loss = train_model(
    model1,
    paths.PATH,
    train_loader,
    val_loader,
    params.EPOCHS,
    device,
)

In [None]:
# Reload the checkpoint written by train_model and evaluate it on the CPU.
# map_location='cpu' remaps tensors saved during a CUDA run, so the checkpoint
# also loads on a machine without a GPU.
# NOTE(review): the checkpoint is a whole pickled module. PyTorch >= 2.6 defaults
# torch.load to weights_only=True, which may reject it; pass weights_only=False
# there (only for checkpoints you trust).
model1 = torch.load(paths.PATH, map_location='cpu')
model1.eval()

start = 10
k = 10  # forwarded to metrics.r2_eval as k=end; set to None to pass None instead
end = (start + k) if k is not None else None

model1.to('cpu')

# train_loader reshuffles its samples every time it is iterated. That would scramble
# the (presumably time-ordered) training labels/predictions plotted against "Time"
# below, so evaluate on an unshuffled loader over the same dataset and batch size.
train_eval_loader = DataLoader(dataset=train_loader.dataset, batch_size=train_loader.batch_size)
t_pred, t_real = metrics.r2_eval(model1, train_eval_loader, k=end)
v_pred, v_real = metrics.r2_eval(model1, val_loader, k=end)

In [None]:
# Mean-squared error of the predictions on each split, followed by the array shapes.
print("Train MSE: {:f}".format(mean_squared_error(y_true=t_real, y_pred=t_pred)))
print("Val MSE: {:f}".format(mean_squared_error(y_true=v_real, y_pred=v_pred)))
print(t_real.shape, t_pred.shape, sep='\n')

In [None]:
# Keep references to the full evaluation arrays: the next cell rebinds
# t_pred / v_pred / t_real / v_real to their first column only.
# Re-running this cell after that one would back up the already-sliced arrays.
tp, tr = t_pred, t_real
vp, vr = v_pred, v_real

In [None]:
# print(next(iter(train_loader))[0][:,2,0])
# print(next(iter(burst_loader))[0][:,2,0])
# Keep only column 0 of each array (plotted below as the "t+1" prediction).
t_pred, v_pred, t_real, v_real = (arr[:, 0] for arr in (tp, vp, tr, vr))
print(t_pred.shape, t_real.shape, sep='\n')

In [None]:
# Per-epoch loss curves returned by train_model (summed batch losses).
fig1, ax1 = plt.subplots(nrows=1, ncols=2)
for loss_ax, loss_hist, loss_title in (
    (ax1[0], pnfr_training_loss, 'Training Loss'),
    (ax1[1], pnfr_validation_loss, 'Validation Loss'),
):
    # x derived from the history itself, so the plot stays valid even if
    # params.EPOCHS was edited after training.
    loss_ax.plot(range(len(loss_hist)), loss_hist)
    loss_ax.set_title(loss_title)
    loss_ax.set_ylabel('Loss')
    loss_ax.set_xlabel('Epoch')
# tight_layout must run after titles/labels exist, otherwise it cannot make room for them.
fig1.tight_layout()

In [None]:
def _plot_pred_panel(axis, real, pred, n_points, title, pred_label):
    """Draw labels (line) and predictions (scatter) for the first samples on `axis`.

    `n_points` caps the number of samples shown. None means "all". It is also
    clipped to the array lengths so the x values and the slices always match.
    """
    n = len(real) if n_points is None else min(n_points, len(real), len(pred))
    xs = np.arange(0, n)
    axis.plot(xs, real[:n], color='blue', label='Labels')
    axis.scatter(xs, pred[:n], color='slateblue', label=pred_label)
    axis.set_title(title)
    axis.set_ylabel('LFP')
    axis.set_xlabel('Time')
    axis.legend(loc=2, prop={'size': 10})


# Labels vs. predictions for the first `end` samples of each split
# (`end` is set in the evaluation cell; None plots every sample).
fig, ax = plt.subplots(nrows=2, ncols=1)
_plot_pred_panel(ax[0], v_real, v_pred, end, 'Validation LFPNet', 'Predicted t+1')
_plot_pred_panel(ax[1], t_real, t_pred, end, 'Training LFPNet', 'Training t+1')
# Lay out after titles/labels/legends exist so tight_layout can account for them.
fig.tight_layout()
plt.show()

In [None]:
# Full-recording labels vs. predictions over samples [start, end).
# f_real / f_pred come from a full-dataset r2_eval that this notebook does not
# currently run, so this cell raised NameError and stopped "Run All".
# Plot only when they exist.
if 'f_real' in globals() and 'f_pred' in globals():
    fig, ax = plt.subplots(nrows=1, ncols=1)
    ax.plot(np.arange(start, end), f_real[start:end], color='blue', label='Labels')
    ax.scatter(np.arange(start, end), f_pred[start:end], color='red', label='Training t+10')
    ax.set_title('Full LFP vs Time')
    ax.set_ylabel('Signal')
    ax.set_xlabel('Time')
    fig.tight_layout()
    plt.show()
else:
    print('Skipping full-LFP plot: f_real / f_pred are not defined in this session.')

In [None]:
import loss_landscapes
import loss_landscapes.metrics
from mpl_toolkits.mplot3d import axes3d, Axes3D

STEPS = 100       # grid resolution along each random direction
DISTANCE = 10000  # passed as `distance` to loss_landscapes.random_plane (unchanged value)
model_final = copy.deepcopy(model1)

# Batch the evaluator uses when computing the loss at each grid point.
# The previously used `noise_loader` is never created in this notebook (its
# construction is commented out), so use the first validation batch instead.
# NOTE(review): assumes model1 is on the CPU (the evaluation cell moves it there).
x, y = next(iter(val_loader))
if params.RECURRENT_NET:
    # same input layout train_model feeds to recurrent models
    x = torch.transpose(x, 2, 1)

_landscape_mse = nn.MSELoss()


def _squeezed_mse(output, target):
    """MSE on squeezed tensors, matching the loss computed in train_model."""
    return _landscape_mse(torch.squeeze(output), torch.squeeze(target))


metric = loss_landscapes.metrics.Loss(_squeezed_mse, x, y)

loss_data_fin = loss_landscapes.random_plane(model_final, metric, DISTANCE, STEPS, normalization='model', deepcopy_model=True)

In [None]:
# 3-D surface of the STEPS x STEPS loss grid computed by random_plane.
fig = plt.figure()
# Axes3D(fig) no longer attaches itself to the figure (deprecated in Matplotlib 3.4,
# removed in 3.6), which left the figure blank; add a 3-D subplot explicitly.
ax = fig.add_subplot(projection='3d')
# X[i, j] = j and Y[i, j] = i, as the original nested comprehensions built them.
X, Y = np.meshgrid(np.arange(STEPS), np.arange(STEPS))
ax.plot_surface(X, Y, loss_data_fin, rstride=1, cstride=1, cmap='viridis', edgecolor='none')
ax.set_title('Surface Plot of Loss Landscape')
plt.show()

In [None]:
# Print each layer's learned weight and bias tensors, in layer order.
for _param_key in (
    'conv_block.0.weight',
    'conv_block.0.bias',
    'conv_block.2.weight',
    'conv_block.2.bias',
    'ck1s1.weight',
    'ck1s1.bias',
    'fc1.weight',
    'fc1.bias',
    'fc2.weight',
    'fc2.bias',
):
    print(model1.state_dict()[_param_key])

In [None]:
# Dump the complete state dict (every parameter/buffer tensor) as text.
print(str(model1.state_dict()))