In [1]:
import numpy as np
from matplotlib import pyplot as plt
from sklearn.model_selection import train_test_split
import time

from libworm.model.beta_neuron import NeuronNetwork, from_connectome
from libworm.data import connectomes, traces
from libworm import preprocess
from libworm.functions import set_neurons, tcalc_s_inf, set_trace

import json

In [7]:
_, trace, trace_labels, label2index, timestamps = traces.load_trace()
timestamps = timestamps - timestamps[0]

chemical, gapjn = connectomes.load_cook_connectome()
neurons = connectomes.get_main_neurons(chemical, gapjn)
neurons.sort(key=lambda item: f"AAA{label2index[item]:04d}{item}" if item in label2index else item)
model = from_connectome(chemical, gapjn, neurons)

cell = "SMBVR"

first_removal = [label2index[key] for key in label2index if key not in neurons]
trace = np.delete(trace, first_removal, axis=0)

del_index = 0
size = trace.shape[0]

for i in range(size):
    if i not in label2index.values():
        trace = np.delete(trace, (del_index), axis=0)
    else:
        del_index += 1

trace_pairs = [(time, trace[:, i]) for i, time in enumerate(timestamps)]

In [8]:
model.store_data = False

In [9]:
model.simple_run(0.001, 1.0, limiter=True, direct_s=False)

##########

In [10]:
rust_data = [-2.641593587539065, -3.463425039260168, -2.4703011821665433, -3.1720655044941912, -4.032053275577738, -2.3288620547375225, -15.754297082844465, -3.318435453196809, -3.9292183661154056, -2.7301553563714283, -2.673331947731194, -2.8759793471782795, -2.4454152636158066, -4.206067117408041, -4.446945922154035, -2.5631981819485796, -2.8535401338767348, -3.038297713118824, -5.890049196297366, -8.424305983231722, -3.519958672290442, -2.8010975842785975, -2.573504167109626, -1.9902593233890324, -3.4556477793907634, -5.082453204542678, -3.7830587871940686, -3.2881861726779587, -3.141828161711719, -5.315876186128618, -6.304471469867526, -2.9184743350064863, -3.818963463311209, -2.4879885626354454, -3.199925871134787, -4.294327800295138, -4.285308540272236, -4.418507126219828, -14.145860213055526, -3.0785387644946742, -16.70435211203899, -3.5200788453974767, -5.3657996985326, -2.649859237231871, -5.6496113859338175, -2.8143366410836865, -3.12208478885013, -2.0038935154901534, -4.204736726717702, -3.052383742103656, -2.8621308057803216, -2.8090500041860476, -2.5227093213718157, -3.4400583341892115, -4.331319435339919, -3.5067378395168833, -3.132181070434509, -3.971467197262812, -2.5360206962972693, -3.144690041665768, -2.1191397674345134, -18.836425521469423, -2.4585549815018406, -3.355214822527155, -2.8728319405510274, -4.682368628323502, -2.551299903047112, -5.33254623382012, -4.744710668706843, -2.8319446572071527, -2.8868599781815627, -6.118070133866784, -7.784000053904412, -5.378124597402811, -3.9933551889761523, -2.2153722559126594, -26.096366497978043, -3.187819811276852, -2.949935749950914, -3.693732624137969, -3.016062435029361, -5.8661992916981625, -5.082561129185513, -2.2534086318218702, -12.546535252815563, -3.311733476117176, -7.051268193746005, -3.7372105524832757, -2.949196693517657, -3.325201830308552, -2.0383510411591796, -2.5432707561293753, -2.6962786058332546, -4.137933934701366, -7.711679002548325, -4.255869975200972, -3.8776169390516046, -2.7203629415658357, -2.8248868948737567, -4.419666101665773, -3.1920520710913762, -2.373180674971791, -2.8319441479216154, -2.3711474541425153, -3.3684840098184217, -2.821438687086958, -2.734008021888524, -2.556342450849199, -3.413686998173962, -2.8083383517989007, -4.872771713201246, -3.343423998888959, -4.474450895218446, -4.21904611456199, -3.7980469979540197, -4.7639637174620155, -4.077734267083452, -4.256205384697719, -4.908711123956558, -3.3071641914786225, -3.7057032448022094, -4.227164472826841, -3.824564319875932, -3.863101350617692, -5.514985619735745, -4.611234398113965, -5.724585124149614, -2.4081013783918084, -2.9386597504613157, -3.3463176731578894, -3.0910183733339136, -4.276085044703539, -4.724502316594131, -4.061816349079984, -4.115867350465093, -3.986125816999108, -5.919099206147226, -4.934382412650125, -6.14980598097428, -5.914504669110993, -5.260609255656881, -4.303770557477695, -4.5752200894064945, -3.0507228656855263, -2.364402456403062, -4.285359164080943, -4.763213761046742, -4.310193776895637, -3.7986664937386814, -4.175185289868765, -4.346376960927397, -3.6830560722759387, -4.039914180801088, -4.648291333940467, -5.829542768174671, -6.171446011104282, -4.362929613879324, -4.572506348302602, -7.8332528660443055, -8.233637229963302, -5.065749691200527, -5.045653216278518, -4.703971377754473, -23.951834778721977, -36.250877288181485, -30.651867919619622, -28.474993395323043, -28.144605393695127, -15.041617702803688, -4.829265873813812, -9.747087522338735, -7.9721346351744495, -4.3106336110063435, -5.9997078433547255, -6.903703531535303, -4.830556349616099, -5.381458606503264, -8.391235015941596, -7.694318676161048, -8.25868121567472, -7.36362856285149, -6.183839380374595, -6.362009235322208, -5.07749372816445, -5.996102144224814, -5.763213231961706, -5.409543314332033, -5.523148643018384, -6.422422236839719, -5.609040234165169, -6.687387632325228, -4.936088591933886, -5.279344707921924, -4.699324574260612, -4.77928559974474, -7.660434480514826, -8.213694279564624, -4.701200663708312, -5.26646062839259, -6.309577017543431, -6.865981236456344, -4.6043070269881055, -5.371460532892396, -5.503980563228987, -7.649326025356575, -6.178171912930247, -6.38799689321454, -3.139517778172116, -5.770268949961997, -6.9762706468257, -5.57522978917413, -4.545881809629281, -3.6018583563898834, -3.78713266963663, -6.062865548975718, -1.8916153338156088, -12.608371812710903, -2.873954415414715, -2.9510588504143116, -2.3887198371861404, -3.4168948169879836, -2.559299837340089, -2.458798618983731, -4.348176464529082, -3.670876211215935, -8.024337305622218, -4.538815287513011, -3.9406892859809663, -2.988224822674377, -3.794991181655444, -3.3283844138073393, -3.0228071338706575, -3.75951367686857, -3.120009254557389, -3.767578413941444, -3.1785915665369977, -3.0879212113912327, -5.257073885120739, -3.635235213765076, -3.8253901394738454, -7.270540394004563, -7.227742525179206, -3.9011098147203835, -3.725309236170525, -4.700057953602986, -4.34256067772114, -4.50774186035037, -4.439018803612847, -4.559291258237158, -4.944653628804407, -6.679628376193498, -3.79314281572579, -4.121743638892777, -4.188966019819524, -4.60207247131913, -4.992516120023794, -4.830872835919496, -4.735313469282626, -6.047087088347695, -8.114969225776965, -6.021480235132905, -13.808189806413818, -14.026205704459299, -13.258693080737974, -7.922442384627366, -6.921992310005356, -7.924505886575328, -16.45625325064588, -26.964563386533218, -31.345330438151564, -38.71984396634861, -37.64849260314483, -33.112843348086905, -34.74133993794044, -33.14943826719487, -24.829274073928428, -30.181250541288957, -19.19358595768863, -14.953442085339613, -12.041738663950525]

In [11]:
model.big_V - rust_data

array([-0.01362222, -0.01434486, -0.01319629, -0.01362675, -0.01574997,
       -0.01331144,  0.02704071, -0.01094617, -0.01188219, -0.01350337,
       -0.01317782, -0.01406301, -0.01288543, -0.01651611, -0.01122452,
       -0.01328891, -0.01376212, -0.01372888, -0.03430451, -0.00967155,
       -0.01545334, -0.01325295, -0.01266738, -0.01073713, -0.01136353,
       -0.01424081, -0.01508927, -0.01580634, -0.01349095, -0.03359261,
       -0.01088171, -0.0132662 , -0.01203701, -0.01315239, -0.01363737,
       -0.01501846, -0.0160954 , -0.01006718,  0.02229325, -0.01411601,
        0.02628981, -0.01282592, -0.00950353, -0.0122724 , -0.00871322,
       -0.01459583, -0.01359438, -0.01124531, -0.0161322 , -0.01481991,
       -0.014053  , -0.01262492, -0.01265096, -0.01482685, -0.00944871,
       -0.01336799, -0.01395603, -0.01665803, -0.01349533, -0.01427452,
       -0.01246583,  0.03858428, -0.01306307, -0.0135744 , -0.01347213,
       -0.0128637 , -0.0132215 , -0.01471855, -0.01643537, -0.01

In [10]:
with open("processed_data/neurons.json", "w") as file:
    json.dump(neurons, file)

In [3]:
def eval(model, start_index, data):
    np.seterr(all='raise')

    try:
        for i in range(15):
            runtime =  data[start_index + i + 1][0] - data[start_index + i][0]
            points = data[start_index + i][1]
            model.big_V[:len(points)] = points
            model.simple_run(0.01, runtime, show_progress=False)
    
        start_index += 15
        error = 0
        
        for i in range(15):
            runtime =  data[start_index + i + 1][0] - data[start_index + i][0]
            points = data[start_index + i][1]
    
            model.simple_run(0.01, runtime, show_progress=False)
            error += np.sum(np.square(model.big_V[:len(points)] - points))
            
    except FloatingPointError:
        error = 10 ** 10

    return error


In [4]:
def reduce2index(array_like):
    return [i for i, item in enumerate(array_like.flatten()) if item != 0]

def reduce(array_like):
    return [item for item in array_like.flatten() if item != 0]

def expand(values, indices, shape):
    size = 1

    for item in shape:
        size *= item

    items = [0] * size

    for i, index in enumerate(indices):
        items[index] = values[i]

    return np.array(items).reshape(shape)

In [5]:
class ModelFactory:
    def __init__(self, length, G_syn_index, G_gap_index):
        self.length = length
        self.G_syn_index = G_syn_index
        self.G_gap_index = G_gap_index

    def get_spec(self):
        return self.length, len(self.G_syn_index), len(self.G_gap_index)

    def build(self, inital_V_value, short_G_syn, short_G_gap, short_E_syn, G_leak, E_leak):

        big_V = np.array([inital_V_value] * self.length)
        
        G_syn = expand(short_G_syn, self.G_syn_index, (self.length, self.length))
        E_syn = expand(short_E_syn, self.G_syn_index, (self.length, self.length))
        
        G_gapjn = expand(short_G_gap, self.G_gap_index, (self.length, self.length))
        
        return NeuronNetwork(big_V, G_syn, G_gapjn, E_syn, labels=neurons, G_leak = G_leak, E_leak = E_leak)

    def build_random(self, rng):
        neuron_size, syn_size, gap_size = self.get_spec()

        inital_v = rng.uniform(-10, 10)
        G_syn = rng.uniform(0, 100, syn_size)
        E_syn = rng.uniform(-100, 100, syn_size)
        G_gap = rng.uniform(0, 100, gap_size)
        G_leak = rng.uniform(0, 100, neuron_size)
        E_leak = rng.uniform(-100, 100, neuron_size)

        return self.build(inital_v, G_syn, G_gap, E_syn, G_leak, E_leak)

In [6]:
factory = ModelFactory(len(neurons),
                       reduce2index(model.big_G_syn),
                       reduce2index(model.big_G_gap))

neuron_size, syn_size, gap_size = factory.get_spec()
rng = np.random.default_rng(3471)

models = [factory.build_random(rng) for i in range(10)]
evals = [eval(model, 50, trace_pairs) for model in models]

In [26]:
eval(model, 50, trace_pairs)

######################################################################################################################################################################################################################################################################################################################

53823.49032367453

In [8]:
280*280

78400

In [27]:
new_model.report()

Neurons 280 (280)
V_max = -0.12468994397378985 (0)
V_min = -0.12468994397378985 (0)


In [28]:
eval(new_model, 50, trace_pairs)

##

10000000000