In [3]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torchode as to
from hydra import compose, initialize
from omegaconf import OmegaConf

from bioplnn.utils import AttrDict, initialize_model

In [4]:
# Pick the compute device: CUDA when PyTorch reports it as available, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Release cached CUDA allocator blocks and select the "high" float32 matmul
# precision mode for the rest of the session.
torch.cuda.empty_cache()
torch.set_float32_matmul_precision("high")

In [None]:
# IPython shell escape: run the `nvidia-smi` command and show its output in the cell.
!nvidia-smi

In [None]:
# Compose the Hydra configuration from ./config with the MNIST-connectome data
# and connectome model overrides.
with initialize(version_base=None, config_path="config/", job_name="testing"):
    hydra_cfg = compose(
        config_name="config",
        overrides=["data=mnist_conn", "model=conn"],
    )

# Resolve interpolations into a plain container, then wrap it in AttrDict so
# entries can be read with attribute access (e.g. config.model).
config = AttrDict(OmegaConf.to_container(hydra_cfg, resolve=True))

# Build the model from the `model` section of the config and keep a handle on
# its recurrent module.
model = initialize_model(**config.model)
rnn = model.rnn

# Ratio of the length of rnn.hh.values to in_features * out_features.
# NOTE(review): if `values` holds the stored non-zero weights, this is the
# density of the matrix rather than its sparsity -- confirm the label.
hh_fill_ratio = rnn.hh.values.shape[0] / (
    rnn.hh.in_features * rnn.hh.out_features
)
print(f"HH sparsity {hh_fill_ratio:.2%}")

In [7]:
# Load data
# NOTE(review): absolute, user-specific data root. Every path below is built by
# string concatenation, so the value must keep its trailing slash.
sunny_root = "/om/user/sunnyd/dros-vision/"
neuron_ids = pd.read_csv(sunny_root + "visual_neuron_ids.csv")
num_neurons = len(neuron_ids)
input_neuron_ids = pd.read_csv(sunny_root + "visual_column_assignment.csv")
num_input_neurons = len(input_neuron_ids)
column_data = torch.tensor(
    np.load(sunny_root + "moving_mnist_vision.npy"), dtype=torch.float32
)

# Map every root_id to its index label in `neuron_ids`, then translate the
# input neurons' root_ids once and reuse the result below.
lookup = dict(zip(neuron_ids["root_id"].to_list(), neuron_ids.index))
input_rows = input_neuron_ids["root_id"].map(lookup)
if input_rows.isna().any():
    # Fail here with the offending ids instead of letting NaN indices surface
    # later as an obscure indexing error.
    unknown_ids = input_neuron_ids.loc[input_rows.isna(), "root_id"].tolist()
    raise ValueError(
        f"{len(unknown_ids)} input root_id(s) are missing from "
        f"visual_neuron_ids.csv, e.g. {unknown_ids[:5]}"
    )
input_rows = input_rows.to_numpy(dtype=np.int64)

# One-hot projection of shape (num_neurons, num_input_neurons): column j has a
# single 1 in the row of the j-th input neuron. Scattering into a zero matrix
# gives the same array as np.eye(num_neurons)[:, input_rows] without first
# allocating the full num_neurons x num_neurons identity.
input_projection = np.zeros((num_neurons, num_input_neurons))
input_projection[input_rows, np.arange(num_input_neurons)] = 1.0

input_indices = torch.tensor(input_rows, dtype=torch.long)
# The CSV's column ids are shifted down by one before use
# (presumably 1-based in the file -- TODO confirm).
column_ids = torch.tensor((input_neuron_ids["column_ids"] - 1).to_numpy())

In [8]:
def save_connectivity():
    """Build the recurrent and input connectivity tensors and save them.

    Reads the notebook globals ``sunny_root``, ``input_indices``,
    ``input_neuron_ids``, ``num_neurons``, ``num_input_neurons`` and
    ``config``. Three objects are written with ``torch.save`` to the paths
    stored under ``config.model.rnn_kwargs``:

    * ``connectivity_hh`` -- the adjacency matrix from
      ``visual_adj_matrix_20240711.npy`` as a coalesced sparse COO tensor;
    * ``connectivity_ih`` -- a coalesced sparse COO tensor of shape
      ``(num_neurons, num_input_neurons)`` holding a 1 at
      ``(input_indices[j], j)`` for every input neuron ``j``;
    * ``input_indices`` -- the index tensor itself.
    """
    # Recurrent connectivity: dense array from disk -> sparse COO.
    adjacency = np.load(sunny_root + "visual_adj_matrix_20240711.npy")
    hh_sparse = torch.tensor(adjacency).to_sparse_coo().coalesce()

    # Input connectivity: row = target neuron, column = input neuron position.
    n_inputs = input_neuron_ids.shape[0]
    ih_coords = torch.stack((input_indices, torch.arange(n_inputs)))
    ih_weights = torch.ones(ih_coords.shape[1])
    ih_sparse = torch.sparse_coo_tensor(
        ih_coords,
        ih_weights,
        (num_neurons, num_input_neurons),
        check_invariants=True,
    ).coalesce()

    # Persist everything at the locations named in the model config.
    rnn_paths = config.model.rnn_kwargs
    torch.save(hh_sparse, rnn_paths.connectivity_hh)  # type: ignore
    torch.save(ih_sparse, rnn_paths.connectivity_ih)  # type: ignore
    torch.save(input_indices, rnn_paths.input_indices)  # type: ignore


# save_connectivity()
# connectivity_hh = torch.load(
#     config.model.rnn_kwargs.connectivity_hh, weights_only=True
# )
# connectivity_ih = torch.load(
#     config.model.rnn_kwargs.connectivity_ih, weights_only=True
# )
# input_indices = torch.load(
#     config.model.rnn_kwargs.input_indices, weights_only=True
# )

# pool = torch.nn.AdaptiveMaxPool2d((100, 100))
# connectivity = pool(connectivity_hh.unsqueeze(0).to_dense().float())
# plt.imshow(
#     connectivity.squeeze(0).detach().cpu().numpy(),
#     cmap="gray",
# )

In [9]:
def f(t: torch.Tensor, y: torch.Tensor, args: dict):
    """Right-hand side of the rate ODE driven by a frame sequence.

    Evaluates ``(-leak * h + bias + rnn.hh(h) + input_gain * rnn.ih(x_t)) / tau``
    where ``h`` is ``y`` transposed and ``x_t`` is the input frame selected by
    the current time. Uses the notebook-global ``rnn`` for the recurrent
    (``hh``) and input (``ih``) projections.

    Args:
        t: Current solver time. A single-element tensor (or scalar) selects the
            same frame for every batch row; a tensor with one entry per batch
            row selects a frame per row.
        y: Current state, a 2-D tensor that is transposed before being passed
            to ``rnn.hh`` (presumably ``(batch, neurons)`` -- confirm).
        args: Mapping with the required keys ``"x"`` (indexed as
            ``x[batch, frame, column]``), ``"bias"``, ``"leak"``, ``"tau"`` and
            ``"column_ids"`` (index into the last axis of ``x``). Optional keys:
            ``"frame_rate"`` (frames per unit of ``t``, default 10) and
            ``"input_gain"`` (scale on the input drive, default 10).

    Returns:
        The time derivative, laid out like ``y``.
    """
    h = y.t()

    x = args["x"]
    bias = args["bias"]
    leak = args["leak"]
    column_ids = args["column_ids"]
    tau = args["tau"]
    frame_rate = args.get("frame_rate", 10)
    input_gain = args.get("input_gain", 10)

    # Clamp the frame index to the clip: the solver can be evaluated at (or
    # past) the final time, where t * frame_rate falls off the end of x, and a
    # negative index would silently wrap around to the last frames.
    last_frame = x.shape[1] - 1
    t_flat = torch.as_tensor(t).reshape(-1)
    if t_flat.numel() == 1:
        frame = min(max(int(t_flat[0] * frame_rate), 0), last_frame)
        x_t = x[:, frame, column_ids].t()
    else:
        # One time per batch row: pick each row's own frame.
        frames = (t_flat * frame_rate).long().clamp(0, last_frame).to(x.device)
        rows = torch.arange(x.shape[0], device=x.device)
        x_t = x[rows, frames][:, column_ids].t()

    dh = 1 / tau * (-h * leak + bias + rnn.hh(h) + input_gain * rnn.ih(x_t))

    return dh.t()


# Install f as this rnn instance's `forward` attribute. Because it is set on
# the instance (not the class), f is stored as a plain function and receives
# exactly the arguments passed in -- no implicit `self`.
# NOTE(review): presumably rnn is a torch.nn.Module, so rnn(t, y, args)
# dispatches to f(t, y, args) -- confirm.
rnn.forward = f

In [12]:
# Inputs for one solve: the first entry of column_data (the slice keeps the
# leading axis), a standard-normal initial state with one value per row of
# neuron_ids, and 10 evenly spaced evaluation times on [0, 1]. y0 and t_eval
# each get a leading axis of size 1 via unsqueeze(0).
x = column_data[:1].to(device)
y0 = torch.randn(len(neuron_ids)).unsqueeze(0).to(device)
t_eval = torch.linspace(0.0, 1.0, 10).unsqueeze(0).to(device)

# Wrap rnn as the ODE term; with_args=True presumably makes torchode forward
# the `args` given to solver.solve() as the third argument of the dynamics
# function -- confirm against the torchode docs.
term = to.ODETerm(rnn, with_args=True)  # type: ignore
# Dopri5 stepper with an integral step-size controller using the tolerances below.
step_method = to.Dopri5(term=term)
step_size_controller = to.IntegralController(atol=1e-6, rtol=1e-3, term=term)
solver = to.AutoDiffAdjoint(
    step_method,  # type: ignore
    step_size_controller,  # type: ignore
).to(device)
# Alternatives kept for reference (disabled): the BacksolveAdjoint solver and
# wrapping the solver in torch.compile.
# solver = to.BacksolveAdjoint(term, step_method, step_size_controller).to(
#     device
# )
# solver = torch.compile(solver)

In [None]:
# Solve ODE
torch.cuda.empty_cache()
problem = to.InitialValueProblem(y0=y0, t_eval=t_eval)  # type: ignore
sol = solver.solve(
    problem,
    args={
        "x": x,
        "bias": 0.1,
        "leak": 20.0,
        "tau": 10.0,
        "column_ids": column_ids,
    },
)
# Swap the first two axes of the solution so that, once the size-1 axis is
# squeezed away below, rows are evaluation steps and columns are state entries.
ys = sol.ys.transpose(0, 1)

# Plot 25 randomly chosen state entries. The sampling range comes from the
# solution itself rather than a hard-coded neuron count, so the cell keeps
# working if the connectome size changes.
traces = ys.squeeze()
sampled = torch.randint(0, traces.shape[-1], (25,))
plt.plot(traces[:, sampled].detach().cpu().numpy())
plt.xlabel("evaluation step")
plt.ylabel("state")
plt.title("25 randomly sampled state trajectories")
plt.show()

In [None]:
# ys was produced by sol.ys.transpose(0, 1), so its first axis indexes the
# evaluation steps. ys[0] is the state at the first evaluation time (the random
# initial condition), which does not depend on the integrated dynamics. Read
# out the state at the last evaluation step instead.
preds = model.out_layer(ys[-1])
# Random soft targets matching preds in shape, dtype and device. The softmax
# turns each row into a probability distribution, which is the form
# CrossEntropyLoss expects for floating-point (class-probability) targets.
labels = torch.rand_like(preds).softmax(dim=-1)
print(f"preds: {preds}")
print(f"labels: {labels}")
print(f"preds shape: {preds.shape}")
print(f"labels shape: {labels.shape}")

# Single optimisation step, as a smoke test that gradients flow end to end.
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss = torch.nn.CrossEntropyLoss()(preds, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(f"loss: {loss}")