In [None]:
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoMultivariateNormal,AutoLowRankMultivariateNormal,AutoLaplaceApproximation,AutoIAFNormal
from pyro.optim import Adam
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.io import loadmat
from rfest.utils import fetch_data,build_design_matrix


import os
# Seed Pyro's global random number generators once, before anything stochastic
# runs, so that Restart-Kernel -> Run-All reproduces the same results.
SEED = 0
pyro.set_rng_seed(SEED)

def define_adjacency_matrix(n_time_bins, n_frequency_bins):
    """Build the 4-neighbour adjacency matrix of a (time x frequency) grid.

    Grid cell (i, j) is flattened to node ``i * n_frequency_bins + j``.
    Two nodes are adjacent (entry 1.0) when they are consecutive in time at
    the same frequency, or consecutive in frequency at the same time bin.
    Every other entry, including the diagonal, is 0.0.

    Args:
        n_time_bins: number of grid rows (time bins).
        n_frequency_bins: number of grid columns (frequency bins).

    Returns:
        Symmetric float64 ndarray of shape (N, N),
        where N = n_time_bins * n_frequency_bins.
    """
    n_nodes = n_time_bins * n_frequency_bins
    adjacency = np.zeros((n_nodes, n_nodes))
    # Row-major node ids laid out on the grid: grid[i, j] == i * n_frequency_bins + j.
    grid = np.arange(n_nodes).reshape(n_time_bins, n_frequency_bins)

    # Pair every cell with its successor along each axis and link them both ways.
    axis_pairs = (
        (grid[:-1, :], grid[1:, :]),  # time neighbours (same frequency)
        (grid[:, :-1], grid[:, 1:]),  # frequency neighbours (same time bin)
    )
    for earlier, later in axis_pairs:
        src, dst = earlier.ravel(), later.ravel()
        adjacency[src, dst] = 1
        adjacency[dst, src] = 1

    return adjacency
# Define the model

def model(X, Y, adjacency_matrix):
    """Linear-Gaussian regression with a graph-structured Gaussian prior on beta.

    Prior:      beta ~ MultivariateNormal(0, precision_matrix=P), where P has
                ``alpha`` on the diagonal, ``rho`` at every off-diagonal entry
                with ``adjacency_matrix == 1``, and 0 elsewhere.
    Likelihood: Y ~ Normal(X @ beta, 1), one observation per entry of Y.

    ``alpha`` and ``rho`` are registered with ``pyro.param`` (not sampled), so
    SVI fits them as point estimates alongside the guide.

    Args:
        X: design matrix; assumed (num_data, num_features) -- TODO confirm.
        Y: observed responses, one per row of X.
        adjacency_matrix: square array/tensor (num_features x num_features);
            entries equal to 1 mark neighbouring features.

    Note:
        P is only positive definite for suitable (alpha, rho). For a fully
        connected graph it requires alpha > rho, and the interval constraints
        below do NOT enforce that.
        A CAR-style alternative would use off-diagonals -rho and diagonal
        1 + rho**2, all scaled by alpha / (1 - rho**2).
        For Poisson counts, swap the likelihood for
        Poisson(exp(X @ beta)) (log link).
    """
    adjacency = torch.as_tensor(adjacency_matrix)
    num_features = adjacency.shape[0]

    # Float initial values strictly inside each constraint interval. An init
    # on or beyond an interval boundary lands on a saturated sigmoid in the
    # unconstrained space, which effectively freezes the parameter there.
    alpha = pyro.param("alpha", torch.tensor(1.0),
                       constraint=dist.constraints.interval(0.0001, 10000))
    rho = pyro.param("rho", torch.tensor(0.5),
                     constraint=dist.constraints.interval(0.00001, 0.9999))

    # Assemble P out-of-place: rho on neighbour entries, then alpha on the diagonal.
    off_diagonal = torch.where(adjacency == 1, rho, torch.zeros_like(rho))
    diagonal_mask = torch.eye(num_features) == 1
    precision_matrix = torch.where(diagonal_mask, alpha, off_diagonal)

    beta = pyro.sample(
        "beta",
        dist.MultivariateNormal(torch.zeros(num_features), precision_matrix=precision_matrix),
    )

    with pyro.plate("data", len(Y)):
        mu = X.matmul(beta)
        pyro.sample("y", dist.Normal(mu, 1.0), obs=Y)




In [None]:

# Load rfest example dataset 2 and expand the stimulus into a time-lagged
# design matrix (build_design_matrix with `timelags` lags).
timelags = 30
dat = fetch_data(2)
X = dat['X']
Xdsgn = build_design_matrix(X, timelags)

X = torch.tensor(Xdsgn, dtype=torch.float32)
Y = torch.tensor(dat['y'], dtype=torch.float32)

# Fully connected graph: every off-diagonal entry is 1, so `model` places rho
# there. The diagonal is set to 2 so that `adjacency_matrix == 1` excludes it.
# The size comes from the data instead of a hard-coded 750.
n_features = X.shape[1]
adjacency_matrix = torch.ones((n_features, n_features))
adjacency_matrix.fill_diagonal_(2)

In [None]:
# Alternative dataset: stimulus and response stored as MATLAB files in ./data.
# Running this cell replaces X, Y and adjacency_matrix from the previous cell.
timelags = 25
data_dir = os.path.join(os.getcwd(), 'data')

X = loadmat(os.path.join(data_dir, 'X.mat'))['X']
Xdsgn = build_design_matrix(X, timelags)

X = torch.tensor(Xdsgn, dtype=torch.float32)
Y = torch.tensor(loadmat(os.path.join(data_dir, 'y.mat'))['Y'].flatten(), dtype=torch.float32)

# Fully connected graph: off-diagonals 1 (rho in `model`), diagonal 2 so that
# `adjacency_matrix == 1` excludes it.
n_features = X.shape[1]
adjacency_matrix = torch.ones((n_features, n_features))
adjacency_matrix.fill_diagonal_(2)

In [None]:
# Fit the model with stochastic variational inference. The guide is a
# low-rank multivariate normal over the latent `beta`; `alpha` and `rho` are
# pyro.param's inside `model` and are optimised as point estimates.
pyro.clear_param_store()  # fresh parameters on every re-run of this cell
guide = AutoLowRankMultivariateNormal(model)

adam_params = {"lr": 0.01}
optimizer = Adam(adam_params)
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())

num_steps = 300
log_every = 50  # must be < num_steps, otherwise only step 0 is ever reported
losses = []
for step in range(num_steps):
    loss = svi.step(X, Y, adjacency_matrix)
    losses.append(loss)
    if step % log_every == 0 or step == num_steps - 1:
        print(f"Step {step}, loss: {loss}")




# Marginal credible intervals (mean +/- 2.5 sd, ~98.8% for Gaussian marginals)
# taken from the fitted guide's posterior distribution. Using get_posterior()
# works for both AutoMultivariateNormal and AutoLowRankMultivariateNormal.
# The guide's raw `scale` parameter is not, in general, the marginal sd.
n_sd = 2.5
with torch.no_grad():
    posterior = guide.get_posterior()
    post_mean = posterior.mean
    post_sd = posterior.variance.sqrt()
A = post_mean + n_sd * post_sd  # upper bound
B = post_mean - n_sd * post_sd  # lower bound

# Keep only weights whose interval excludes zero (both bounds share a sign).
# NOTE(review): 15 x 50 assumes 750 features laid out this way -- confirm
# against the design-matrix layout for the dataset in use.
heatmap_shape = (15, 50)
significant = ((B > 0) | (A < 0)).numpy().reshape(heatmap_shape)
beta = post_mean.numpy().reshape(heatmap_shape)
ax = sns.heatmap(beta * significant)
ax.set_title("Posterior mean of beta (zeroed where the interval contains 0)")

# Print the learned parameters. The guide's parameters are vectors/matrices,
# and .item() only works on single-element tensors, so report their shape.
print("Learned parameters:")
for name, value in pyro.get_param_store().items():
    value = value.detach()
    print(name, value.item() if value.numel() == 1 else tuple(value.shape))

In [None]:
# To model Poisson spike counts instead of Gaussian responses, replace the likelihood inside `model` with:

# mu = torch.exp(X.matmul(beta))  # compute the expected response using log link
# y = pyro.sample("y", dist.Poisson(mu), obs=Y)