In [1]:
import numpy as np
from matplotlib import pyplot as plt
from sklearn.model_selection import train_test_split
import torch
from torch import nn, optim
import time
from torch.utils.data import DataLoader, TensorDataset
import torch.nn.functional as F

from libworm.torch.gamma_net import GammaNeuronNet, from_connectome
from libworm.data import connectomes, traces
from libworm import preprocess
from libworm.functions import set_neurons, tcalc_s_inf, set_trace
#from libworm.training import basic_train

In [2]:
def basic_train(model, criterion, optimizer, points, labels,
          cell_labels, epoches=5, batch=64,
          timestep=0.001, data_timestep=0.60156673, do_print=True):

    dataset = TensorDataset(points, labels)
    dataloader = DataLoader(dataset, batch_size=1, shuffle=True)
    
    total_start_time = time.time()
    loss = -1
    
    for i in range(1,epoches+1):
        for points_batch, labels_batch in dataloader:

            points_batch = torch.squeeze(points_batch)
            labels_batch = torch.squeeze(labels_batch)
            
            start_time = time.time()
            optimizer.zero_grad()

            results = model(points_batch, timestep, data_timestep)

            padded_labels = F.pad(labels_batch,
                                  (0, results.shape[0] - labels_batch.shape[0]),
                                  "constant",
                                  0)
            #Compare
            loss = criterion(results, padded_labels)
            loss.backward()
            
            optimizer.step()
            
            end_time = time.time()
            print(f"Batch Complete: Time {start_time - end_time} Loss: {loss.item()}")
 
    total_end_time = time.time()
    total_time_taken = total_end_time - total_start_time
    if(do_print):
        print(f"Total Time {total_time_taken}")
        
    final_loss = loss.item()
    
    return (total_time_taken, final_loss)

In [3]:
torch.manual_seed(4687)
_, trace, trace_labels, label2index, timestamps = traces.load_trace()
timestamps = timestamps - timestamps[0]

chemical, gapjn = connectomes.load_cook_connectome()
neurons = connectomes.get_main_neurons(chemical, gapjn)
neurons.sort(key=lambda item: f"AAA{label2index[item]:04d}{item}" if item in label2index else item)
model = from_connectome(chemical, gapjn, neurons)

not_in_main_section = [label2index[key] for key in label2index if key not in neurons]
not_labelled = [i for i, _ in enumerate(trace[:, 0]) if i not in label2index.values()]

removal = list(set(not_in_main_section).union(not_labelled))

trace = np.delete(trace, removal, axis=0)

points, labels = preprocess.window_split(trace, window_size = 2, points_size = 1)

points = torch.from_numpy(points.squeeze())
labels = torch.from_numpy(labels.squeeze())

train_x, test_x, train_y, test_y = train_test_split(points, labels, train_size=0.1)

optimiser = optim.Adam(model.parameters(), lr=0.001)
crit = nn.MSELoss()

In [6]:
model(points[0], 0.01, 10.0)

tensor([-2.4510, -4.1009, -2.2121, -2.8197, -3.5807, -3.4202, -4.2683, -2.8373,
        -2.5259, -2.4097, -2.1104, -2.5873, -1.9682, -3.0283, -2.6589, -2.2539,
        -2.7347, -2.7286, -5.6825, -3.9146, -3.7411, -2.9963, -1.9164, -3.4260,
        -3.4760, -3.2163, -2.7737, -2.5176, -2.6184, -5.1174, -3.3332, -3.1200,
        -3.4981, -2.8818, -2.6268, -3.8619, -3.6338, -3.9215, -4.7184, -2.7087,
        -5.0857, -2.8634, -3.0128, -2.2239, -3.1470, -2.3011, -2.7978, -1.7358,
        -3.7177, -2.8540, -2.3609, -3.3702, -2.1525, -3.3201, -3.6280, -2.7761,
        -2.7870, -6.2958, -2.2453, -2.7324, -1.8643, -3.7711, -2.2164, -2.5862,
        -2.4835, -3.3272, -2.2202, -4.0398, -3.3971, -2.5025, -3.3704, -3.2671,
        -5.6311, -5.9084, -3.0096, -2.7314, -2.8248, -2.9190, -2.4125, -3.0189,
        -2.7221, -3.1844, -3.3156, -4.3993, -6.3268, -5.6502, -5.8799, -2.7060,
        -3.8920, -2.5007, -2.0515, -4.6782, -2.4053, -3.1120, -4.3757, -3.3757,
        -3.3083, -2.3781, -0.0000, -0.00

In [4]:
results = basic_train(model, crit, optimiser,
                train_x, train_y, neurons,
                epoches=2, batch=6, timestep=0.01)

Batch Complete: Time -6.1863157749176025 Loss: 5.127374925191802
Batch Complete: Time -6.338157653808594 Loss: 3.2899625125327816
Batch Complete: Time -5.987959146499634 Loss: 3.432591177443225
Batch Complete: Time -6.43522834777832 Loss: 2.7273574983410125


KeyboardInterrupt: 

In [5]:
for param in model.parameters():
    print(param)

Parameter containing:
tensor([ 9.9960,  9.9960,  9.9960,  9.9960,  9.9960,  9.9961,  9.9960,  9.9961,
         9.9961,  9.9960,  9.9960,  9.9962,  9.9960,  9.9961,  9.9960,  9.9960,
         9.9961,  9.9960,  9.9960,  9.9960,  9.9960,  9.9961,  9.9963,  9.9962,
         9.9961,  9.9960,  9.9960,  9.9961,  9.9961,  9.9960,  9.9962,  9.9961,
         9.9961,  9.9961,  9.9961,  9.9961,  9.9960,  9.9960,  9.9961,  9.9961,
         9.9960,  9.9961,  9.9961,  9.9961,  9.9961,  9.9963,  9.9962,  9.9963,
         9.9961,  9.9960,  9.9961,  9.9961,  9.9961,  9.9960,  9.9961,  9.9962,
         9.9960,  9.9960,  9.9960,  9.9960,  9.9963,  9.9960,  9.9960,  9.9960,
         9.9962,  9.9962,  9.9960,  9.9961,  9.9960,  9.9960,  9.9961,  9.9960,
         9.9960,  9.9960,  9.9961,  9.9961,  9.9960,  9.9960,  9.9961,  9.9962,
         9.9961,  9.9960,  9.9960,  9.9961,  9.9960,  9.9961,  9.9960,  9.9960,
         9.9962,  9.9961,  9.9962,  9.9961,  9.9962,  9.9962,  9.9960,  9.9960,
         9.9962,  