In [1]:
import sys
sys.path.insert(0,'..')
from gp3.inference.mfsvi import MFSVI
from gp3.likelihoods.likelihoods import Poisson
from gp3.utils import data as sim
from gp3.kernels.kernels import rbf
from plotly.offline import download_plotlyjs, init_notebook_mode, plot, iplot
import plotly.graph_objs as go
from plotly import tools
from IPython.display import display
init_notebook_mode(connected=True)
import GPy
import warnings
warnings.filterwarnings('ignore')
from tqdm import trange
import numpy as np
from lif import LIFLike, LIFSim

In [2]:
# Simulate data: a latent surface f over a 2-D input grid
# (presumably 20 x 20 equispaced points on [0, 100]^2),
# then spike data generated from f by the LIF simulator.
#
# Seed NumPy's global RNG so that Restart & Run All reproduces the same data.
# NOTE(review): this assumes sim.sim_f / LIFSim.sim draw from np.random's
# global state -- TODO confirm. If they use their own generator, this seed
# has no effect on them.
SEED = 0
np.random.seed(SEED)

# Log of the two rbf hyperparameters used to draw f, and the constant mean of f.
# The two hyperparameters are presumably (lengthscale, variance); confirm
# against gp3.kernels.rbf.
# The SVI cell below hard-codes these same values; keep the two cells in sync.
KERNEL_LOG_PARAMS = np.log(np.array([40., 1.0]))
PRIOR_MEAN = 4.

X = sim.sim_X_equispaced(D=2, N_dim=20, lower=0, upper=100)
# An earlier variant subtracted a quadratic bowl, 1e-3 * sum((X - 50)**2, axis=1), from f.
f = sim.sim_f(X, rbf, KERNEL_LOG_PARAMS, mu=PRIOR_MEAN)

lif_gen = LIFSim()
spikes = lif_gen.sim(f)

# Side by side: the latent surface and the simulated spike data at each input point.
trace_func = go.Scatter3d(x=X[:, 0], y=X[:, 1], z=f, mode='markers', marker=dict(size=2), name="cell shape")
trace_draws = go.Scatter3d(x=X[:, 0], y=X[:, 1], z=spikes, mode='markers', marker=dict(size=2), name="spike times")
fig = tools.make_subplots(
    rows=1,
    cols=2,
    specs=[[{'is_3d': True}, {'is_3d': True}]],
    subplot_titles=("cell shape", "spike times"),
)
fig.append_trace(trace_func, 1, 1)
fig.append_trace(trace_draws, 1, 2)
iplot(fig)

This is the format of your plot grid:
[ (1,1) scene1 ]  [ (1,2) scene2 ]



In [3]:
# Fit the latent surface with MFSVI (presumably mean-field stochastic
# variational inference) under the LIF spike likelihood.
# The hyperparameters and constant mean below are the same values used to
# simulate f in the data cell above: log([40., 1.0]) and 4.
N_SVI_ITERS = 500
svi_kernel_log_params = np.log(np.array([40., 1.0]))
mu = np.full(X.shape[0], 4.)  # one float entry of 4.0 per input point; presumably the GP prior mean

# The likelihood and the observation argument each get their own int32 copy
# of the spike data, matching the original call exactly.
svi_likelihood = LIFLike(ts=spikes.astype(np.int32))
inf_svi = MFSVI(rbf, svi_kernel_log_params, svi_likelihood, X, spikes.astype(np.int32), mu)
inf_svi.run(N_SVI_ITERS)

ELBO: -12017.71 | KL: 11437.41 | logL: -580.31: 100%|██████████| 500/500 [00:07<00:00, 63.66it/s] 


In [4]:
# Compare the ground truth (left) with the fitted SVI model:
# its predictions (middle) and exp(q_S) (right).
# NOTE(review): exp(q_S) is plotted as "svi variances", which suggests q_S
# stores log-variances of the variational posterior -- confirm against MFSVI.
svi_pred = inf_svi.predict()
svi_var = np.exp(inf_svi.q_S)

trace_svipred = go.Scatter3d(x=X[:, 0], y=X[:, 1], z=svi_pred, mode='markers', marker=dict(size=2), name="svi prediction")
trace_svivar = go.Scatter3d(x=X[:, 0], y=X[:, 1], z=svi_var, mode='markers', marker=dict(size=2), name="svi variances")

# One 3-D scene per column. The dicts are built separately rather than
# aliased with `* 3`.
scene_specs = [[{'is_3d': True} for _ in range(3)]]
fig = tools.make_subplots(rows=1, cols=3, specs=scene_specs)
for column, trace in enumerate((trace_func, trace_svipred, trace_svivar), start=1):
    fig.append_trace(trace, 1, column)
iplot(fig)

This is the format of your plot grid:
[ (1,1) scene1 ]  [ (1,2) scene2 ]  [ (1,3) scene3 ]

