In [1]:
import numpy as np
from matplotlib import pyplot as plt
from sklearn.model_selection import train_test_split
import torch
from torch import nn, optim
import time
from torch.utils.data import DataLoader, TensorDataset
import torch.nn.functional as F

from libworm.torch.beta_net import BetaNeuronNet, from_connectome
from libworm.data import connectomes, traces
from libworm import preprocess
from libworm.functions import set_neurons, tcalc_s_inf, set_trace
#from libworm.training import basic_train

In [2]:
def new_s(V_m):
    a_r = 1
    a_d = 5
    beta = 0.125
    V_th = -15 #??
    sig = 1 / (1 + np.exp(-beta * (V_m - V_th)))

    return (a_r * sig) / (a_r * sig + a_d)

In [3]:
def basic_train(model, criterion, optimizer, points, labels,
          cell_labels, epoches=5, batch=64,
          timestep=0.001, data_timestep=0.60156673, do_print=True):

    dataset = TensorDataset(points, labels)
    dataloader = DataLoader(dataset, batch_size=batch, shuffle=True)
    
    total_start_time = time.time()
    loss = -1
    for i in range(1,epoches+1):
        for points_batch, labels_batch in dataloader:
            start_time = time.time()
            optimizer.zero_grad()
            
            sim_time = 0.0
            next_timestamp = 0.0

            voltage = torch.ones((points_batch.shape[0], len(cell_labels))) * -20.0
            gates = torch.ones((points_batch.shape[0], len(cell_labels))) * -20.0

            with torch.no_grad():
                #Prepare
                for i in range(points_batch.shape[2]):
                    while True:
                        if sim_time >= next_timestamp:
                            inter = F.pad(points_batch[:, :, i], (0, voltage.shape[1] - points_batch.shape[1]), "constant", 0.0)
                            voltage = voltage.where(inter == 0.0, inter)
                            gates = tcalc_s_inf(voltage)
                            next_timestamp += data_timestep
                            break;
                        voltage, gates,_,_ = model(voltage, gates, timestep)
                        sim_time += timestep
                    break

            while sim_time < next_timestamp:
                voltage, gates,_,_ = model(voltage, gates, timestep)
                sim_time += timestep


            final_output = voltage[:, :labels_batch.shape[1]]
            
            #Compare
            loss = criterion(final_output, labels_batch)
            loss.backward()
            
            optimizer.step()
            
            end_time = time.time()
            print(f"Batch Complete: Time {start_time - end_time} Loss: {loss.item()}")
        
        

            
    total_end_time = time.time()
    total_time_taken = total_end_time - total_start_time
    if(do_print):
        print(f"Total Time {total_time_taken}")
        
    final_loss = loss.item()
    
    return (total_time_taken, final_loss)

In [4]:
"""
This process is likely currently wrong
"""

torch.manual_seed(4687)

trace, _, trace_labels, label2index, timestamps = traces.load_trace()
timestamps = timestamps - timestamps[0]

chemical, gapjn = connectomes.load_cook_connectome()
neurons = connectomes.get_main_neurons(chemical, gapjn)
neurons.sort(key=lambda item: f"AAA{label2index[item]:04d}{item}" if item in label2index else item)
model = from_connectome(chemical, gapjn, neurons)

cell = "SMBVR"

first_removal = [label2index[key] for key in label2index if key not in neurons]
trace = np.delete(trace, first_removal, axis=0)

del_index = 0
size = trace.shape[0]

for i in range(size):
    if i not in label2index.values():
        trace = np.delete(trace, (del_index), axis=0)
    else:
        del_index += 1

voltage = preprocess.trace2volt(trace)

points, labels = preprocess.window_split(voltage, window_size = 16, points_size = 15)
points = torch.from_numpy(np.squeeze(points))
labels = torch.from_numpy(np.squeeze(labels))

train_x, test_x, train_y, test_y = train_test_split(points, labels, train_size=0.1)

optimiser = optim.Adam(model.parameters(), lr=0.001)
crit = nn.MSELoss()

In [5]:
labels.shape

torch.Size([1584, 101])

In [6]:
results = basic_train(model, crit, optimiser,
                train_x, train_y, neurons,
                epoches=2, batch=6, timestep=0.01)

Batch Complete: Time -0.23248744010925293 Loss: 653.1237154894314
Batch Complete: Time -0.17736506462097168 Loss: 649.5425761881489
Batch Complete: Time -0.17402935028076172 Loss: 643.2438723141395
Batch Complete: Time -0.16890311241149902 Loss: 642.9759388654593
Batch Complete: Time -0.1834876537322998 Loss: 643.6777355323189
Batch Complete: Time -0.18149089813232422 Loss: 643.2411870793468
Batch Complete: Time -0.18052148818969727 Loss: 640.8189611675426
Batch Complete: Time -0.16945862770080566 Loss: 649.7647942528965
Batch Complete: Time -0.18130064010620117 Loss: 648.4245876210554
Batch Complete: Time -0.17880678176879883 Loss: 638.1279731361715
Batch Complete: Time -0.1853346824645996 Loss: 638.0032292845351
Batch Complete: Time -0.174424409866333 Loss: 642.9300634256706
Batch Complete: Time -0.18352150917053223 Loss: 647.1594580609232
Batch Complete: Time -0.17026662826538086 Loss: 641.6702697839244
Batch Complete: Time -0.19524168968200684 Loss: 643.5961697616516
Batch Complete

In [7]:
for param in model.parameters():
    print(param)

Parameter containing:
tensor([10.0543, 10.0545, 10.0543, 10.0060, 10.0543, 10.0543, 10.0541, 10.0542,
        10.0541, 10.0542, 10.0060,  9.9940, 10.0542,  9.9940, 10.0541, 10.0542,
         9.9940,  9.9940, 10.0547, 10.0543, 10.0194, 10.0541, 10.0544, 10.0547,
        10.0541, 10.0060, 10.0541, 10.0060, 10.0541, 10.0547,  9.9940, 10.0528,
         9.9940,  9.9940, 10.0542, 10.0539, 10.0060, 10.0541, 10.0528, 10.0540,
        10.0060, 10.0541, 10.0541, 10.0544, 10.0060, 10.0060, 10.0060, 10.0543,
        10.0540, 10.0546, 10.0529, 10.0541, 10.0540, 10.0543, 10.0539, 10.0540,
        10.0542, 10.0530, 10.0543, 10.0529, 10.0542, 10.0529, 10.0542, 10.0542,
        10.0060, 10.0529, 10.0541, 10.0544, 10.0541, 10.0541, 10.0541, 10.0543,
        10.0544, 10.0543, 10.0542, 10.0060, 10.0544, 10.0060, 10.0540, 10.0543,
        10.0543, 10.0060, 10.0060, 10.0546, 10.0544, 10.0547, 10.0547,  9.9940,
        10.0543,  9.9940, 10.0541, 10.0544, 10.0543, 10.0545,  9.9940, 10.0541,
        10.0527, 1

In [72]:
#torch.save({
#            'model_state_dict': model.state_dict(),
#            'optimizer_state_dict': optimiser.state_dict(),
#            }, "checkpoints/ruby/modellr001dt01.pt")