In [None]:
%load_ext autoreload
# Re-import edited modules (e.g. `data`, `odynn`) before each cell executes, so changes
# to library code are picked up without restarting the kernel.
%autoreload 2

import sys
# Make modules in the parent directory importable. Relative sys.path entries resolve
# against the kernel's working directory; presumably `data`/`odynn` live there — TODO confirm.
sys.path.insert(0, '..')

In [None]:
"""Constants"""

# DT is not referenced by the visible cells (the per-recording `dt` is derived from the
# recording lengths further down). N_PARALLEL is the size of the parallel parameter axis
# passed to get_circ() for the optimisation.
DT, N_PARALLEL = 1, 1000

In [None]:
'''Get datasets, all names'''
import data

# All recordings, and the union of neuron names identified across them.
DATASETS, ALL_NAMES = data.get_all_data(), data.get_all_names()
print('All identified neurons :\n', ALL_NAMES)

In [None]:
'''Neuron to select'''

# Name of the neuron whose inputs are collected and whose activity is modelled below.
NEURON = "RID"

In [None]:
'''Get presynaptic neurons'''

# pyplot is matplotlib's supported interface (pylab is discouraged); it provides every
# plt.* function used in this notebook.
import matplotlib.pyplot as plt

dfs = data.get_synapses()
dfg = data.get_gaps()

# Rows with a positive entry in NEURON's column are treated as inputs to NEURON.
gaps = dfg.index[dfg[NEURON] > 0].tolist()
syns = dfs.index[dfs[NEURON] > 0].tolist()
# Order matters: later cells index gap inputs first, then synaptic inputs, then NEURON.
inputs = gaps + syns
print('Gaps :', gaps)
print('Synapses :', syns)

'''Plot graph'''
import networkx as nx

gap_edges = [(g, NEURON) for g in gaps]
syn_edges = [(s, NEURON) for s in syns]

G = nx.DiGraph()
G.add_nodes_from(inputs)
# Inputs are laid out on a circle; the selected neuron is placed at the centre.
pos = nx.circular_layout(G)
pos[NEURON] = [0, 0]
G.add_edges_from(gap_edges + syn_edges)
nx.draw_networkx_labels(G, pos)
nx.draw_networkx_edges(G, pos, gap_edges, edge_color='b')  # gap junctions: blue
nx.draw_networkx_edges(G, pos, syn_edges, edge_color='r')  # chemical synapses: red

plt.axis('off')
plt.show()

In [None]:
'''Get datasets with neuron labelled'''

print('Datasets with %s identified : ' % NEURON)
# Keep only the recordings in which NEURON was identified, printing their indices.
recs = []
for idx, frame in enumerate(DATASETS):
    if NEURON not in frame.columns:
        continue
    print(idx)
    recs.append(frame)
# Longest kept recording (0 when none contain NEURON); used to size the padded array.
max_len = max((len(r) for r in recs), default=0)

In [None]:
'''Collect input traces and initialize dt'''

import numpy as np
import torch

# Column layout shared by every later cell: gap inputs, synaptic inputs, then NEURON.
neurons = inputs + [NEURON]
# (time, recording, neuron); shorter recordings are zero-padded at the end.
traces = np.zeros((max_len, len(recs), len(neurons)))
dt = torch.zeros(len(recs), 1, 1)
# Vertical offsets that stack one trace per neuron in the overview plot.
offsets = [1.1 * j for j in range(len(neurons))]

plt.figure(figsize=(20, 20))
for i, rec in enumerate(recs):
    # Inputs not identified in this recording are treated as silent (all zeros).
    # reindex returns a new frame, so the shared DataFrames in DATASETS/recs are not mutated.
    n_steps = rec.shape[0]
    traces[:n_steps, i] = rec.reindex(columns=neurons, fill_value=0).to_numpy()
    # NOTE(review): 4000 is presumably the total recording duration (units unknown) — TODO confirm.
    dt[i] = 4000 / n_steps
    plt.subplot(len(recs), 1, i + 1)
    plt.plot(traces[:, i] + offsets)
    plt.yticks(offsets, neurons)

plt.show()
print('dt : ', dt)

In [None]:
'''Define connections'''

# Indices follow `inputs + [NEURON]`: gap inputs first, synaptic inputs next, and the
# modelled neuron last at index len(inputs). Each pair is (source, target), presumably
# (pre, post) for GapJunction/ChemSyn.
conn_g, conn_s = (
    [(src, len(inputs)) for src in range(len(gaps))],
    [(src, len(inputs)) for src in range(len(gaps), len(inputs))],
)

print(conn_g, conn_s)

In [None]:
'''Create circuit'''

from odynn.circuit import Circuit
from odynn.models import LeakyIntegrate, ChemSyn, GapJunction

def get_circ(N_parallel = 500, dt=dt, e_center=None):
    """Build the circuit made of every input neuron plus the selected output neuron.

    Input neurons (the first ``len(inputs)`` rows) get ``LeakyIntegrate`` default
    parameters and the output neuron (last row) gets random ones. Chemical synapses
    follow ``conn_s`` and gap junctions follow ``conn_g``. Every component holds
    ``N_parallel`` parameter sets.

    Args:
        N_parallel: number of parallel parameter sets.
        dt: time step(s) handed to every component. Defaults to the per-recording ``dt``
            computed earlier (bound when this cell is executed).
        e_center: optional per-synapse centre for the synaptic ``E`` parameter
            (presumably the reversal potential). ``E`` is drawn as
            ``e_center + U(0, 0.2)`` per parallel set. When None, the prior
            ``(-1, 0, 0, 1)`` originally hard-coded here (for NEURON = 'RID') is used if
            there are exactly 4 chemical synapses. Otherwise ``E`` keeps the value from
            ``ChemSyn.get_random``.

    Returns:
        ``Circuit(neurons, chemical_synapses, gap_junctions)``.

    Raises:
        ValueError: if ``e_center`` does not have one entry per chemical synapse.
    """
    n_syn = len(conn_s)

    n_par = LeakyIntegrate.get_default(len(inputs), N_parallel)
    pout = LeakyIntegrate.get_random(1, N_parallel)
    # Stack the input-neuron rows with the output-neuron row (last).
    pn = {k: np.concatenate((n_par[k], pout[k]), 0) for k in pout}
    n = LeakyIntegrate(init_p=pn, tensors=True, dt=dt)

    ps = ChemSyn.get_random(n_syn, N_parallel)
    if e_center is None and n_syn == 4:
        # Prior previously hard-coded for exactly 4 synapses; kept as the default so RID is unchanged.
        e_center = [-1, 0, 0, 1]
    if e_center is not None:
        e_center = np.asarray(e_center, dtype=float).reshape(-1, 1)
        if e_center.shape[0] != n_syn:
            raise ValueError('e_center has %d entries but there are %d chemical synapses'
                             % (e_center.shape[0], n_syn))
        ps['E'] = np.repeat(e_center, N_parallel, axis=-1) + np.random.rand(n_syn, N_parallel) * 0.2
    s = ChemSyn([c[0] for c in conn_s], [c[1] for c in conn_s],
                init_p=ps, tensors=True, dt=dt)

    pg = GapJunction.get_random(len(conn_g), N_parallel)
    g = GapJunction([c[0] for c in conn_g], [c[1] for c in conn_g],
                    init_p=pg, tensors=True, dt=dt)
    return Circuit(n, s, g)


get_circ().plot(labels=dict(enumerate(inputs + [NEURON])), img_size=5)

In [None]:
'''Correlation to initialize E'''

from scipy import stats

# Mean (over recordings) Pearson correlation between each synaptic input and NEURON.
# `traces` columns are gaps + syns + [NEURON], so synaptic input i is column len(gaps) + i
# and NEURON is the last column.
correlations = np.zeros((len(syns), 1))

for r, rec in enumerate(recs):
    # Only use the real samples; rows past len(rec) are zero padding.
    n_steps = len(rec)
    out_trace = traces[:n_steps, r, -1]
    for i in range(len(syns)):
        in_trace = traces[:n_steps, r, len(gaps) + i]
        # corrcoef is NaN for a constant trace (e.g. an input zero-filled because it was not
        # identified in this recording); count that as no correlation instead of poisoning the mean.
        with np.errstate(invalid='ignore', divide='ignore'):
            c = np.corrcoef(in_trace, out_trace)[0, 1]
        correlations[i] += np.nan_to_num(c)
correlations /= len(recs)

In [None]:
'''Optimize'''

from tqdm import tqdm
import seaborn as sns
from odynn import optim
import pickle

# (time, 1, recording, neuron, 1): singleton axes added around the recorded traces.
target = torch.Tensor(traces[:, None, :, :, None])
init = target[0]
# vmask is 1 only for the last neuron (NEURON); vadd is the recorded data with NEURON zeroed.
# Presumably calculate() simulates the masked neuron and clamps the others to vadd — TODO confirm.
vmask = torch.zeros((1, 1, target.shape[-2], 1))
vmask[..., -1, :] = 1
vadd = target.clone()
vadd[..., -1, :] = 0
print(init.shape, target.shape)

"""Optimize out neuron"""
circuit = get_circ(N_PARALLEL)
# Replace the synaptic E with the mean input/output correlation (the same value for every
# parallel set) and make the new tensor trainable.
circuit._synapses._param['E'] = torch.Tensor(np.tile(correlations, (1, N_PARALLEL))).requires_grad_()

def load_param(name='params%s' % NEURON):
    """Overwrite the circuit's parameters with values from a pickled dict.

    The file is expected to map each parameter name of the neurons, synapses and gap
    junctions to an array (presumably the dict written by the optimisation checkpoint).
    Each value becomes a fresh trainable float tensor.

    Args:
        name: path of the pickle file; the default is bound when this cell runs.
    """
    # Only load trusted files: unpickling can execute arbitrary code.
    with open(name, 'rb') as fh:
        saved = pickle.load(fh)
    components = (circuit._neurons, circuit._synapses, circuit._gaps)
    for component in components:
        for pname in component._parameter_names:
            component._param[pname] = torch.Tensor(saved[pname]).requires_grad_()
# load_param()

# Vertical offsets that stack one trace per neuron (inputs first, NEURON last).
ALIGN = [1.1*n for n in range(target.shape[-2])]
def plots(y, traces, loss):
    """Plot each recording against the simulated output of the best parallel model.

    Args:
        y: simulated traces from ``circuit.calculate``, indexed here as
            ``y[time, 0, recording, neuron, model]`` (layout assumed from that indexing).
        traces: recorded target tensor with the same layout and a singleton last axis.
        loss: presumably one entry per parallel model; ``loss.argmin()`` picks the model shown.
    """
    # The best model does not depend on the recording, so select it once.
    best = loss.argmin()
    labels = inputs + [NEURON]
    for i in range(len(recs)):
        plt.figure(figsize=(15,15))
        plt.subplot(211)
        # Recorded inputs, recorded NEURON (red, scaled x2), simulated NEURON (black dashed).
        plt.plot(traces[:,0,i,:-1,0].detach().numpy() + ALIGN[:-1], linewidth=1)
        plt.plot(2*traces[:,0,i,-1,0].detach().numpy() + ALIGN[-1], linewidth=1.2, color='r')
        plt.plot(2*y[:,0,i,-1,best].detach().numpy() + ALIGN[-1], linewidth=1.1, linestyle='--', color='k')
        # ALIGN holds vertical offsets, so the neuron names belong on the y axis (x is time);
        # hide only the x ticks and the frame so the labels stay visible.
        plt.yticks(ALIGN, labels)
        plt.xticks([])
        plt.box(False)
        plt.subplot(212)
        # Heatmap rows: every recorded trace, then the best simulated NEURON trace.
        best_cat = torch.cat((traces[:,0,i,:,0], y[:,0,i,-1:,best]), dim=1).detach().numpy().T
        sns.heatmap(best_cat, cmap='jet', vmin=0, vmax=1)
        plt.show()
        plt.close()
    
losses = []
params = [v for v in circuit.parameters.values()]
optimizer = torch.optim.Adam(params, lr=0.001)

N_EPOCHS = 701
CHECKPOINT_EVERY = 10
# Same default file name that load_param() reads, so a checkpoint can be restored directly.
CHECKPOINT_PATH = 'params%s' % NEURON
best_saved = float('inf')  # lowest min-loss already written to CHECKPOINT_PATH

for t in tqdm(range(N_EPOCHS)):
    y = circuit.calculate(torch.zeros(traces.shape[0]), init, vmask=vmask, vadd=vadd)

    loss = optim.loss_mse(y, target)
    losses.append(loss.detach().numpy())

    if t % CHECKPOINT_EVERY == 0:
        plots(y, target, loss)
        # Compare against the best checkpoint so far (losses[-1] is this very loss), and
        # save before optimizer.step() so the stored parameters are the ones that produced it.
        current_best = loss.min().item()
        if current_best < best_saved:
            best_saved = current_best
            with open(CHECKPOINT_PATH, 'wb') as f:
                pickle.dump({k: v.detach().numpy() for k, v in circuit.parameters.items()}, f)

    # Update parameters
    optimizer.zero_grad()
    loss.mean().backward()
    # Among neuron parameters only the output neuron (last row) is trained: clear the
    # gradients of the input neurons, which keep their default parameters.
    for v in circuit._neurons.parameters.values():
        if v.grad is not None:
            v.grad[:-1].zero_()

    optimizer.step()

    circuit.apply_constraints()

    print(loss.mean().detach().numpy(), loss.min().detach().numpy())
                
        

# One point per optimisation iteration; one curve per loss entry, on a log scale.
plt.plot(losses, linewidth=0.2)
plt.yscale('log')
plt.xlabel('iteration')
plt.ylabel('loss')
plt.title('Training loss for %s' % NEURON)
plt.show()