In [1]:
import numpy as np
from matplotlib import pyplot as plt
from sklearn.model_selection import train_test_split
import torch
from torch import nn, optim
import time
from torch.utils.data import DataLoader, TensorDataset
import torch.nn.functional as F

from libworm.torch.beta_net import BetaNeuronNet, from_connectome
from libworm.data import connectomes, traces
from libworm import preprocess
from libworm.functions import set_neurons, tcalc_s_inf, set_trace
from libworm.training import basic_train

In [2]:
def new_s(V_m):
    a_r = 1
    a_d = 5
    beta = 0.125
    V_th = -15 #??
    sig = 1 / (1 + np.exp(-beta * (V_m - V_th)))

    return (a_r * sig) / (a_r * sig + a_d)

In [3]:

V = torch.from_numpy(np.array([40.0, -40.0]))
s = torch.from_numpy(np.array([new_s(40.0), new_s(-40.0)]))
G_leak = np.array([10.0 for V_m in V])
E_leak = np.array([-35.0 for V_m in V])
G_syn = np.array([[0.0, 50.0], [80.0, 0.0]])
E_syn = np.array([0.0 for V_m in V])
G_gap = np.array([[0.0, 100.0], [100.0, 0.0]])

net = BetaNeuronNet(G_leak, E_leak, G_syn, E_syn, G_gap)

In [4]:
def train(model, criterion, optimizer, points, labels,
          cell_labels, epoches=5, batch=64,
          timestep=0.001, data_timestep=0.60156673, do_print=True):

    dataset = TensorDataset(points, labels)
    dataloader = DataLoader(dataset, batch_size=batch, shuffle=True)
    
    total_start_time = time.time()
    loss = -1
    for i in range(1,epoches+1):
        for points_batch, labels_batch in dataloader:
            start_time = time.time()
            optimizer.zero_grad()
            
            sim_time = 0.0
            next_timestamp = 0.0

            voltage = torch.ones((points_batch.shape[0], len(cell_labels))) * -20.0
            gates = torch.ones((points_batch.shape[0], len(cell_labels))) * -20.0

            
            #Prepare
            for i in range(points_batch.shape[2]):
                while True:
                    if sim_time >= next_timestamp:
                        inter = F.pad(points_batch[:, :, i], (0, voltage.shape[1] - points_batch.shape[1]), "constant", 0.0)
                        voltage = voltage.where(inter == 0.0, inter)
                        gates = tcalc_s_inf(voltage)
                        next_timestamp += data_timestep
                        break;
                    voltage, gates,_,_ = model(voltage, gates, timestep)
                    sim_time += timestep

            while sim_time < next_timestamp:
                voltage, gates,_,_ = model(voltage, gates, timestep)
                sim_time += timestep


            final_output = voltage[:, :labels_batch.shape[1]]
            
            #Compare
            loss = criterion(final_output, labels_batch)
            loss.backward()
            
            optimizer.step()
            
            end_time = time.time()
            print(f"Batch Complete: Time {start_time - end_time} Loss: {loss.item()}")
        
        

            
    total_end_time = time.time()
    total_time_taken = total_end_time - total_start_time
    if(do_print):
        print(f"Total Time {total_time_taken}")
        
    final_loss = loss.item()
    
    return (total_time_taken, final_loss)

In [11]:
torch.manual_seed(4687)

trace, _, trace_labels, label2index, timestamps = traces.load_trace()
timestamps = timestamps - timestamps[0]

chemical, gapjn = connectomes.load_cook_connectome()
neurons = connectomes.get_main_neurons(chemical, gapjn)
neurons.sort(key=lambda item: f"AAA{label2index[item]:04d}{item}" if item in label2index else item)
model = from_connectome(chemical, gapjn, neurons)

cell = "SMBVR"

first_removal = [label2index[key] for key in label2index if key not in neurons]
trace = np.delete(trace, first_removal, axis=0)

del_index = 0
size = trace.shape[0]

for i in range(size):
    if i not in label2index.values():
        trace = np.delete(trace, (del_index), axis=0)
    else:
        del_index += 1

voltage = preprocess.trace2volt(trace)

points, labels = preprocess.window_split(voltage, window_size = 16, points_size = 15)
points = torch.from_numpy(np.squeeze(points))
labels = torch.from_numpy(np.squeeze(labels))

train_x, test_x, train_y, test_y = train_test_split(points, labels, train_size=0.1)

optimiser = optim.Adam(model.parameters(), lr=0.001)
crit = nn.MSELoss()

In [13]:
results = basic_train(model, crit, optimiser,
                train_x, train_y, neurons,
                epoches=30, batch=6, timestep=0.01)

Batch Complete: Time -2.6662378311157227 Loss: 587.1778851406265
Batch Complete: Time -2.8960859775543213 Loss: 585.3597129320218
Batch Complete: Time -2.7724034786224365 Loss: 589.639864671114
Batch Complete: Time -2.9743220806121826 Loss: 596.219682378934
Batch Complete: Time -2.7819461822509766 Loss: 590.2347914037812
Batch Complete: Time -3.4701781272888184 Loss: 582.4690958055418
Batch Complete: Time -2.8219363689422607 Loss: 587.8384309591423
Batch Complete: Time -2.6139931678771973 Loss: 593.5033284401806
Batch Complete: Time -2.57891583442688 Loss: 585.8708737254155
Batch Complete: Time -5.152312755584717 Loss: 585.2333247967224
Batch Complete: Time -4.216818809509277 Loss: 590.4881075092669
Batch Complete: Time -3.6616885662078857 Loss: 590.217066067216
Batch Complete: Time -3.30703067779541 Loss: 585.8152324164896
Batch Complete: Time -4.025714159011841 Loss: 590.8566488355168
Batch Complete: Time -3.6525044441223145 Loss: 591.6664037124865
Batch Complete: Time -3.72701096534

In [10]:
torch.save({
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimiser.state_dict(),
            }, "checkpoints/ruby/modellr001dt01.pt")