In [12]:
import os

import numpy as np
import pandas as pd
import torch
import torchode as to
from hydra import compose, initialize
from omegaconf import OmegaConf

from bioplnn.utils import AttrDict, initialize_model

In [13]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Newly created floating-point tensors (torch.ones, torch.randn, ... below)
# default to float64.
torch.set_default_dtype(torch.float64)

In [14]:
# Compose the Hydra config (data=cifar10v1, model=topo_sunny), resolve all
# interpolations, and wrap the result so nested keys can be read and written
# as attributes (e.g. config.model.rnn_kwargs).
with initialize(version_base=None, config_path="config/", job_name="testing"):
    config = compose(
        config_name="config",
        overrides=["data=cifar10v1", "model=topo_sunny"],
    )
config = AttrDict(OmegaConf.to_container(config, resolve=True))

# Directory holding the .csv/.npy data files loaded below. Set the SUNNY_ROOT
# environment variable to point elsewhere instead of editing this cell.
# os.path.join(..., "") guarantees a trailing separator, because later cells
# build file paths by plain string concatenation (sunny_root + "file.npy").
sunny_root = os.path.join(
    os.environ.get("SUNNY_ROOT", "/om/user/sunnyd/dros-vision/"), ""
)

In [15]:
# Load data: the neuron table, the per-column input movie, and the table that
# assigns each input neuron (by root_id) to a visual column.
neuron_ids = pd.read_csv(sunny_root + "visual_neuron_ids.csv")
column_data = torch.tensor(np.load(sunny_root + "moving_mnist_vision.npy"))
input_neuron_ids = pd.read_csv(sunny_root + "visual_column_assignment.csv")

# root_id -> row position in `neuron_ids`.
lookup = dict(zip(neuron_ids["root_id"].to_list(), neuron_ids.index))

# Series.map yields NaN for any root_id absent from `lookup`. That would
# silently make the index tensor floating-point and only fail much later
# (as sparse indices), so reject it here with a clear message.
input_rows = input_neuron_ids["root_id"].map(lookup)
unmapped = input_rows.isna()
if unmapped.any():
    raise ValueError(
        f"{int(unmapped.sum())} root_id(s) in visual_column_assignment.csv "
        "are missing from visual_neuron_ids.csv"
    )
input_indices = torch.tensor(input_rows.to_numpy(dtype=np.int64))

# Shift column ids by one (presumably 1-based in the CSV) to 0-based indices.
column_ids = torch.tensor(input_neuron_ids["column_ids"].to_numpy() - 1)

In [16]:
# Sizes of the full neuron table and of the column-assignment (input) table.
num_neurons, num_input_neurons = neuron_ids.shape[0], input_neuron_ids.shape[0]

In [17]:
def save_connectivity():
    """Build the sparse recurrent and input connectivity tensors and save them.

    Writes three files, at the paths configured in ``config.model.rnn_kwargs``:

    * ``connectivity_hh``: the adjacency matrix from
      ``visual_adj_matrix_20240711.npy``, converted to sparse COO.
    * ``connectivity_ih``: a ``(num_neurons, num_input_neurons)`` sparse COO
      matrix holding a single 1.0 per column ``j``, at row ``input_indices[j]``.
    * ``input_indices``: the row index of each input neuron.

    Reads module-level state: ``sunny_root``, ``input_indices``,
    ``num_neurons``, ``num_input_neurons`` and ``config``.
    """
    # torch.from_numpy shares the array's memory. torch.tensor would create a
    # second full dense copy of the (presumably num_neurons x num_neurons)
    # matrix just before sparsifying it.
    dense_hh = np.load(sunny_root + "visual_adj_matrix_20240711.npy")
    connectivity_hh = torch.from_numpy(dense_hh).to_sparse_coo()
    del dense_hh  # release the dense array as soon as it is no longer needed

    # Input connectivity: input j feeds network neuron input_indices[j] with
    # unit weight. Row indices come from input_indices; column indices are
    # 0..num_input_neurons-1, matching the declared sparse shape below.
    indices_ih = torch.stack(
        (
            input_indices,
            torch.arange(num_input_neurons),
        )
    )
    values_ih = torch.ones(indices_ih.shape[1])

    connectivity_ih = torch.sparse_coo_tensor(
        indices_ih,
        values_ih,
        (num_neurons, num_input_neurons),
        check_invariants=True,
    ).coalesce()

    # Persist connectivity and input indices at the configured paths.
    torch.save(connectivity_hh, config.model.rnn_kwargs.connectivity_hh)
    torch.save(connectivity_ih, config.model.rnn_kwargs.connectivity_ih)
    torch.save(input_indices, config.model.rnn_kwargs.input_indices)


# save_connectivity()

# connectivity_hh = torch.load(
#     config.model.rnn_kwargs.connectivity_hh, weights_only=True
# )
# connectivity_ih = torch.load(
#     config.model.rnn_kwargs.connectivity_ih, weights_only=True
# )
# input_indices = torch.load(
#     config.model.rnn_kwargs.input_indices, weights_only=True
# )

In [19]:
# Switch off the `bias` rnn kwarg before building the model; the ODE
# right-hand side `f` adds its own scalar `bias` term taken from `args`.
rnn_kwargs = config.model.rnn_kwargs
rnn_kwargs.bias = False

built_model = initialize_model(config.data.dataset, config.model)
rnn = built_model.rnn

In [20]:
def f(t, y, args):
    """Right-hand side of the rate ODE integrated by torchode.

    Computes
        dh/dt = (1 / tau) * (-leak * h + bias + rnn.hh(h) + input_gain * rnn.ih(x_t))
    where ``x_t`` is the input frame selected by the current time ``t``.

    Args:
        t: Current integration time, supplied by torchode. It is converted
            with ``int()``, so it must contain a single element.
        y: State, laid out as (batch, num_neurons), matching ``y0``.
            It is transposed to neurons-first before calling ``rnn.hh`` /
            ``rnn.ih`` (presumably the layout those layers expect).
        args: Dict with required keys ``"x"``, ``"bias"``, ``"leak"`` and
            ``"tau"``, plus optional keys:
            ``"frames_per_unit"``: input frames per unit time, default 10.
            ``"input_gain"``: multiplier on the input drive, default 10.
            ``x`` is indexed as ``x[:, frame, column_ids]``; it is assumed
            to be (batch, frames, columns) (TODO confirm).

    Returns:
        dh/dt, with the same layout as ``y``.
    """
    h = y.t()

    x = args["x"]
    bias = args["bias"]
    leak = args["leak"]
    tau = args["tau"]
    frames_per_unit = args.get("frames_per_unit", 10)
    input_gain = args.get("input_gain", 10)

    # Piecewise-constant input: map continuous time to a frame index. Clamp
    # it so that t on the final frame boundary (e.g. t == 1 with exactly 10
    # frames) does not index one past the last frame.
    frame = min(int(t * frames_per_unit), x.shape[1] - 1)
    x_t = x[:, frame, column_ids].t()

    dh = 1 / tau * (-h * leak + bias + rnn.hh(h) + input_gain * rnn.ih(x_t))

    return dh.t()


# Route calls to the module (e.g. from torchode's ODETerm) through the ODE
# right-hand side above.
rnn.forward = f

In [21]:
# Seed the RNG so the random initial state, and hence the trajectory, is
# reproducible across runs.
SEED = 0
torch.manual_seed(SEED)

# Random initial state for every neuron, batch of one: shape (1, num_neurons).
# It is drawn on the CPU and then moved, so the values do not depend on the
# device.
y0 = torch.randn(num_neurons).unsqueeze(0).to(device)
# 20 evenly spaced output times on [0, 1], batch of one.
t_eval = torch.linspace(0, 1, 20).unsqueeze(0).to(device)

term = to.ODETerm(rnn, with_args=True)
step_method = to.Dopri5(term=term)
step_size_controller = to.IntegralController(atol=1e-6, rtol=1e-3, term=term)
solver = to.AutoDiffAdjoint(step_method, step_size_controller).to(device)
# NOTE(review): torch.compile on an nn.Module compiles only its forward().
# The solver.solve(...) call below goes through the original, uncompiled
# method. Confirm whether compiling `solve` itself was intended.
solver = torch.compile(solver)

In [22]:
# Solve the ODE from y0 over t_eval, driven by the first sample of column_data.
ode_args = dict(
    x=column_data[:1].to(device),
    bias=0.1,
    leak=20.0,
    tau=10.0,
)
problem = to.InitialValueProblem(y0=y0, t_eval=t_eval)
sol = solver.solve(problem, args=ode_args)

# State at each requested t_eval point.
ys = sol.ys

In [23]:
ys.size()  # expected (batch, len(t_eval), num_neurons)

torch.Size([1, 20, 47521])