In [1]:
import numpy as np
from matplotlib import pyplot as plt
from sklearn.model_selection import train_test_split
import torch
from torch import nn, optim
import time
from torch.utils.data import DataLoader, TensorDataset
import torch.nn.functional as F

from libworm.torch.beta_net import BetaNeuronNet, from_connectome
from libworm.data import connectomes, traces
from libworm import preprocess
from libworm.functions import set_neurons, tcalc_s_inf, set_trace

In [2]:
def new_s(V_m):
    a_r = 1
    a_d = 5
    beta = 0.125
    V_th = -15 #??
    sig = 1 / (1 + np.exp(-beta * (V_m - V_th)))

    return (a_r * sig) / (a_r * sig + a_d)

In [3]:

V = torch.from_numpy(np.array([40.0, -40.0]))
s = torch.from_numpy(np.array([new_s(40.0), new_s(-40.0)]))
G_leak = np.array([10.0 for V_m in V])
E_leak = np.array([-35.0 for V_m in V])
G_syn = np.array([[0.0, 50.0], [80.0, 0.0]])
E_syn = np.array([0.0 for V_m in V])
G_gap = np.array([[0.0, 100.0], [100.0, 0.0]])

net = BetaNeuronNet(G_leak, E_leak, G_syn, E_syn, G_gap)

In [4]:
def train(model, criterion, optimizer, points, labels,
          cell_labels, epoches=5, batch=64,
          timestep=0.001, data_timestep=0.60156673, do_print=True):

    dataset = TensorDataset(points, labels)
    dataloader = DataLoader(dataset, batch_size=batch, shuffle=True)
    
    total_start_time = time.time()
    loss = -1
    for i in range(1,epoches+1):
        for points_batch, labels_batch in dataloader:
            start_time = time.time()
            optimizer.zero_grad()
            
            sim_time = 0.0
            next_timestamp = 0.0

            voltage = torch.ones((batch, len(cell_labels))) * -20.0
            gates = torch.ones((batch, len(cell_labels))) * -20.0

            
            #Prepare
            for i in range(points_batch.shape[2]):
                while True:
                    if sim_time >= next_timestamp:
                        inter = F.pad(points_batch[:, :, i], (0, voltage.shape[1] - points_batch.shape[1]), "constant", 0.0)
                        voltage = voltage.where(inter == 0.0, inter)
                        gates = tcalc_s_inf(voltage)
                        next_timestamp += data_timestep
                        break;
                    voltage, gates,_,_ = model(voltage, gates, timestep)
                    sim_time += timestep

            print(sim_time)
            print(next_timestamp)

            while sim_time < next_timestamp:
                voltage, gates,_,_ = model(voltage, gates, timestep)
                sim_time += timestep


            print(voltage)
            final_output = voltage[:, :labels_batch.shape[1]]
            print(final_output)
            
            #Compare
            loss = criterion(final_output, labels_batch)
            loss.backward()

            for param in model.parameters():
                print(param.grad)
            
            optimizer.step()
            
            end_time = time.time()
            print(f"Batch Complete: Time {start_time - end_time} Loss: {loss.item()}")
        
        

            
    total_end_time = time.time()
    total_time_taken = total_end_time - total_start_time
    if(do_print):
        print(f"Total Time {total_time_taken}")
        
    final_loss = loss.item()
    
    return (total_time_taken, final_loss)

In [5]:
torch.manual_seed(4687)

trace, _, trace_labels, label2index, timestamps = traces.load_trace()
timestamps = timestamps - timestamps[0]

chemical, gapjn = connectomes.load_cook_connectome()
neurons = connectomes.get_main_neurons(chemical, gapjn)
neurons.sort(key=lambda item: f"AAA{label2index[item]:04d}{item}" if item in label2index else item)
model = from_connectome(chemical, gapjn, neurons)

cell = "SMBVR"

first_removal = [label2index[key] for key in label2index if key not in neurons]
trace = np.delete(trace, first_removal, axis=0)

del_index = 0
size = trace.shape[0]

for i in range(size):
    if i not in label2index.values():
        trace = np.delete(trace, (del_index), axis=0)
    else:
        del_index += 1

voltage = preprocess.trace2volt(trace)

points, labels = preprocess.window_split(voltage, window_size = 16, points_size = 15)
points = torch.from_numpy(np.squeeze(points))
labels = torch.from_numpy(np.squeeze(labels))

train_x, test_x, train_y, test_y = train_test_split(points, labels, train_size=0.1)

optimiser = optim.Adam(model.parameters(), lr=0.0001)
crit = nn.MSELoss()

In [None]:
results = train(model, crit, optimiser,
                train_x, train_y, neurons,
                epoches=1, batch=6, timestep=0.005)

8.424999999999917
9.02350095
tensor([[-25.9700, -27.5865, -27.1629,  ...,  -0.3846,  -0.2893,  -0.2318],
        [-24.7836, -28.6360, -26.6926,  ...,  -0.3846,  -0.2893,  -0.2318],
        [-26.1626, -27.5442, -26.5992,  ...,  -0.3846,  -0.2893,  -0.2318],
        [-24.9665, -28.3307, -26.1691,  ...,  -0.3846,  -0.2893,  -0.2318],
        [-25.4628, -28.5630, -25.6076,  ...,  -0.3846,  -0.2893,  -0.2318],
        [-24.6177, -27.9238, -26.3622,  ...,  -0.3846,  -0.2893,  -0.2318]],
       dtype=torch.float64, grad_fn=<WhereBackward0>)
tensor([[-25.9700, -27.5865, -27.1629, -28.0318, -27.9089, -27.6074, -27.9274,
         -26.1901, -26.5756, -27.6298, -24.4038, -26.2039, -26.5368, -27.8549,
         -26.6593, -28.4778, -26.5457, -25.4878, -29.6140, -28.7832, -29.4057,
         -25.6314, -28.5284, -27.5018, -28.5653, -25.5438, -26.7207, -28.7954,
         -24.9171, -27.4728, -27.6707, -25.6833, -23.3747, -26.1090, -27.8100,
         -28.8184, -29.1437, -28.6304, -24.5208, -22.8106, -24.21