In [1]:
import numpy as np
from matplotlib import pyplot as plt
from sklearn.model_selection import train_test_split
import torch
from torch import nn, optim
import time
from torch.utils.data import DataLoader, TensorDataset
import torch.nn.functional as F

from libworm.torch.beta_net import BetaNeuronNet, from_connectome
from libworm.data import connectomes, traces
from libworm import preprocess
from libworm.functions import set_neurons, tcalc_s_inf, set_trace

In [2]:
def new_s(V_m):
    a_r = 1
    a_d = 5
    beta = 0.125
    V_th = -15 #??
    sig = 1 / (1 + np.exp(-beta * (V_m - V_th)))

    return (a_r * sig) / (a_r * sig + a_d)

In [3]:

V = torch.from_numpy(np.array([40.0, -40.0]))
s = torch.from_numpy(np.array([new_s(40.0), new_s(-40.0)]))
G_leak = np.array([10.0 for V_m in V])
E_leak = np.array([-35.0 for V_m in V])
G_syn = np.array([[0.0, 50.0], [80.0, 0.0]])
E_syn = np.array([0.0 for V_m in V])
G_gap = np.array([[0.0, 100.0], [100.0, 0.0]])

net = BetaNeuronNet(G_leak, E_leak, G_syn, E_syn, G_gap)

In [4]:
def train(model, criterion, optimizer, points, labels,
          data_labels, label2index, epoches=5,
          batch=64, timestep=0.001, data_timestep=0.60156673, do_print=True):

    dataset = TensorDataset(points, labels)
    dataloader = DataLoader(dataset, batch_size=batch, shuffle=True)
    
    total_start_time = time.time()
    loss = -1
    for i in range(1,epoches+1):
        for points_batch, labels_batch in dataloader:
            start_time = time.time()
            optimizer.zero_grad()
            
            sim_time = 0.0
            next_timestamp = 0.0

            voltage = torch.ones((batch, len(data_labels)))
            gates = torch.ones((batch, len(data_labels)))

            #Prepare
            for i in range(points_batch.shape[2]):
                while True:
                    if sim_time >= next_timestamp:
                        inter = F.pad(points_batch[:, :, i], (0, 200), "constant", 0)
                        
                        gates = tcalc_s_inf(voltage)
                        next_timestamp += data_timestep
                        break;
                    voltage, gates,_,_ = model(voltage, gates, timestep)
                    sim_time += timestep

            final_output = torch.zeros(labels_batch.shape)
            
            for i in range(labels_batch.shape[2]):
                while True:
                    if sim_time >= next_timestamp:
                        set_trace(voltage, final_output, i, data_labels, label2index)
                        next_timestamp += data_timestep
                        break;
                        
                    voltage, gates,_,_ = model(voltage, gates, timestep)
                    sim_time += timestep
            
            #Compare
            loss = criterion(final_output, labels_batch)
            print(loss)
            loss.backward()

            for param in model.parameters():
                print(param.grad)
                break
            
            optimizer.step()
            
            end_time = time.time()
            print(f"Batch Complete: Time {start_time - end_time} Loss: {loss.item()}")
        
        

            
    total_end_time = time.time()
    total_time_taken = total_end_time - total_start_time
    if(do_print):
        print(f"Total Time {total_time_taken}")
        
    final_loss = loss.item()
    
    return (total_time_taken, final_loss)

In [5]:
torch.manual_seed(4687)

trace, _, _, label2index, timestamps = traces.load_trace()
timestamps = timestamps - timestamps[0]

voltage = preprocess.trace2volt(trace)

points, labels = preprocess.window_split(voltage)
points = torch.from_numpy(points)
labels = torch.from_numpy(labels)

train_x, test_x, train_y, test_y = train_test_split(points, labels, train_size=0.1)

chemical, gapjn = connectomes.load_cook_connectome()
neurons = connectomes.get_main_neurons(chemical, gapjn)
model = from_connectome(chemical, gapjn, neurons)

optimiser = optim.Adam(model.parameters(), lr=0.000001)
crit = nn.MSELoss()

In [6]:
inter = F.pad(points[:, :, 0], (0, 200), "constant", 0)

In [7]:
inter.shape

torch.Size([1570, 352])

In [None]:
results = train(model, crit, optimiser,
                train_x, train_y, neurons,
                label2index, epoches=1, batch=6, timestep=0.005)

tensor(664.1306, grad_fn=<MseLossBackward0>)
