## Settings


In [None]:
# Board configuration, read by the next cell:
#   S2_IP            -> passed as `eth_ip` to the single-chip setup
#   STM_IP           -> passed as `stm_ip` to the 48-node board setup
#   MULTI_NODE_BOARD -> "1" selects the 48-node board, anything else the single chip
%env S2_IP=192.168.1.7
%env STM_IP=192.168.1.6
%env MULTI_NODE_BOARD=0

In [None]:
import os

# Fail early if any of the settings from the %env cell above is missing.
if not all(key in os.environ for key in ("S2_IP", "STM_IP", "MULTI_NODE_BOARD")):
    raise KeyError("Environment incomplete!")

S2_IP = os.environ["S2_IP"]  # used as eth_ip for the single-chip setup
STM_IP = os.environ["STM_IP"]  # used as stm_ip for the 48-node board
MULTI_NODE_BOARD = os.environ["MULTI_NODE_BOARD"] == "1"
print(S2_IP, STM_IP, "multi node board" if MULTI_NODE_BOARD else "single node board")

## Imports

In [None]:
import numpy as np
import matplotlib.tri as mtri
import matplotlib.pyplot as plt
from spinnaker2 import snn, hardware
from spinnaker2.experiment_backend import BackendSettings, ExperimentBackendType
from spinnaker2.experiment_backends.backend_settings import ROUTING
import scipy.sparse as sp
import scipy.sparse.linalg as spla
import skfem as fem
from skfem.models.poisson import laplace

## Generate mesh

In [None]:
def poisson_unitdisk_variable_f(nrefs: int):
    """
    Assemble a P1 finite-element Poisson system on the unit disk.

    Discretizes  Δu = f(x, y)  in the unit disk with u = 0 on the boundary, where

        f(x, y) = 12 - 60 (x - 0.25)^2 - 60 (y + 0.13)^2.

    ``laplace`` assembles the bilinear form ∫ ∇u·∇v dx, which corresponds to
    -Δu, so the load functional below integrates -f.

    Parameters
    ----------
    nrefs : int
        Number of refinements passed to ``MeshTri.init_circle``.

    Returns
    -------
    A : scipy.sparse.csr_matrix
        Stiffness matrix with the boundary DOFs enforced via ``fem.enforce``.
    b : numpy.ndarray
        Right-hand-side vector with the boundary DOFs enforced.
    basis : skfem.CellBasis
        The P1 basis; ``basis.mesh`` is the triangle mesh.
    """
    # init_circle is a classmethod; call it on the class rather than on a
    # freshly constructed default mesh.
    mesh = fem.MeshTri.init_circle(nrefs)
    basis = fem.CellBasis(mesh, fem.ElementTriP1())

    def f_fun(v, w):
        # Load integrand (-f) * v; w.x holds the coordinates of the quadrature points.
        x, y = w.x
        f_val = -(12 - 60 * (x - 0.25) ** 2 - 60 * (y + 0.13) ** 2)
        return f_val * v

    # Stiffness matrix: ∫ ∇u·∇v dx
    A = laplace.assemble(basis)

    # Right-hand side: ∫ (-f) v dx
    b = fem.asm(fem.LinearForm(f_fun), basis)

    # Homogeneous Dirichlet condition (u = 0) on the boundary DOFs.
    D = basis.get_dofs()
    A, b = fem.enforce(A, b, D=D)

    return A.tocsr(), b, basis

def float_to_signed_sparse(matrix, x_bits=6, scale=None):
    """
    Quantize a sparse matrix to signed integers of a given bit width.

    Each stored entry is divided by ``scale``, rounded to the nearest integer
    (``np.round``, ties to even) and clipped to the signed ``x_bits``-bit range
    ``[-2**(x_bits-1), 2**(x_bits-1) - 1]``, so that
    ``matrix ≈ int_matrix * scale``.

    Parameters
    ----------
    matrix : sparse matrix or array_like
        Matrix to quantize; non-sparse input is converted to CSR first.
    x_bits : int, optional
        Total bit width including the sign bit. Must be in 2..32 because the
        result is stored as int32.
    scale : float, optional
        Quantization step. If omitted, it is chosen so the largest absolute
        entry maps to the largest positive integer (1.0 for an all-zero or
        empty matrix).

    Returns
    -------
    int_matrix : scipy.sparse.csr_matrix (int32)
        Quantized matrix. No ``eliminate_zeros()`` is applied, so entries that
        round to zero remain stored.
    scale : float
        The scale that was used.

    Raises
    ------
    ValueError
        If ``x_bits`` is outside 2..32.
    """
    if not 2 <= x_bits <= 32:
        raise ValueError(f"x_bits must be between 2 and 32 for int32 output, got {x_bits}")

    if not sp.issparse(matrix):
        matrix = sp.csr_matrix(matrix)

    magnitude_bits = x_bits - 1  # one bit is reserved for the sign
    max_int = 2**magnitude_bits - 1
    min_int = -2**magnitude_bits

    # float32 keeps memory low; astype() copies, so the caller's matrix is untouched.
    coo = matrix.astype(np.float32).tocoo()
    # Merge duplicate (row, col) entries first so every position is quantized
    # and clipped exactly once; otherwise tocsr() below would sum the clipped
    # parts and could leave the representable range.
    coo.sum_duplicates()

    # compute scale if not given
    if scale is None:
        max_val = np.max(np.abs(coo.data)) if coo.nnz else 0
        scale = max_val / max_int if max_val != 0 else 1.0

    # Vectorized quantization; divide in float64 so rounding is not limited by
    # float32 precision at large bit widths.
    scaled = coo.data.astype(np.float64) / scale
    int_data = np.clip(np.round(scaled), min_int, max_int).astype(np.int32)

    int_matrix = sp.coo_matrix(
        (int_data, (coo.row, coo.col)), shape=coo.shape, dtype=np.int32
    ).tocsr()
    return int_matrix, scale


In [None]:
# Refinement level handed to MeshTri.init_circle.
nrefs = 5

# A: stiffness matrix, b: load vector, m: the P1 basis (m.mesh is the mesh).
A, b, m = poisson_unitdisk_variable_f(nrefs)

# 21-bit signed quantization of A; A is approximately A_quant * A_scale.
A_quant, A_scale = float_to_signed_sparse(A, x_bits=21)

print(f"Number of mesh nodes {b.shape[0]}")

## Solve system using conjugate gradient

In [None]:
import warnings

# Solve with conjugate gradient on the de-quantized matrix, so the result
# reflects the precision left after the 21-bit quantization of A.
u, cg_info = spla.cg(A_quant * A_scale, b)
if cg_info != 0:
    # spla.cg reports info > 0 when the tolerance was not reached within
    # maxiter and info < 0 on illegal input/breakdown; don't plot silently.
    warnings.warn(f"Conjugate gradient did not converge (info={cg_info})")

# Reference solution: direct sparse solve of the unquantized system.
u_exact = spla.spsolve(A, b)

tri = mtri.Triangulation(m.mesh.p[0, :], m.mesh.p[1, :], m.mesh.t.T)
fig = plt.figure(figsize=(12, 5))

# One 3-D surface panel per solution, identical apart from data/colormap/title.
panels = (
    (121, u, 'viridis', 'Numerical Solution'),
    (122, u_exact, 'plasma', 'Exact Solution'),
)
axes = []
for position, values, cmap, title in panels:
    ax = fig.add_subplot(position, projection='3d')
    surf = ax.plot_trisurf(tri, values, cmap=cmap, edgecolor='none')
    ax.set_title(title)
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    # Hide the tick labels on all three axes.
    ax.set_yticklabels([])
    ax.set_xticklabels([])
    ax.set_zticklabels([])
    plt.colorbar(surf, ax=ax)
    axes.append(ax)
ax1, ax2 = axes

plt.show()

## Hyperparams

| Parameter | Value | Description | Effect |
| --------- | ----- | ----------- | ------ |
| npm       | $8$ | neurons per node | |
| r         | $npm \over 2$ | number of neurons with positive sign according to $\Gamma$ | |
| num_timesteps | $5000$ | number of timesteps for the simulation | |
| sys_tick_in_s | $10^{-3}$ | seconds per timestep | |
| gamma ($\vert\Gamma\vert$) | $50$ | $\vert\Gamma\vert$ | increase $\rightarrow$ larger spike contributions $\rightarrow$ faster movement of $\dot x$ and less energy, but possible overshoot (too small $\rightarrow$ stuck/biased) |
| theta ($\theta$) | ${1 \over 2} \vert \Gamma \vert ^2$ | neuron spike threshold (also: reset by subtraction of $\theta$) | |
| lambda_max ($\lambda_{max}$) | absolute value of the largest-magnitude eigenvalue of $A$ |  | |
| dt | $1 \over {4 \lambda_{max}}$ | integration step size. $dt$ should be $< {1 \over {2 \lambda_{max}}}$. In the [paper](https://arxiv.org/pdf/2501.10526#page=22&zoom=auto,0,200) $2^{-12}$ is used. | increase $\rightarrow$ faster, fewer steps (more energy efficient), but can become unstable if too high |
| tau ($\tau$) | $30$ | used for $\lambda_d$ and $\lambda_v$ | |
| lambda_d ($\lambda_d$) | $1 \over {\tau dt}$ | readout leak | |
| lambda_v ($\lambda_v$) | $2 \over {\tau dt}$ | membrane leak | |
| omega_n ($\omega_n$) | $2.0$ | ? something with k_p and k_i | increase $\rightarrow$ faster convergence (potential oscillation/instability) |
| zeta ($\zeta$) | $4.0$ | ? something with k_i | increase $\rightarrow$ stronger damping (less oscillation, slower convergence) |
| k_p ($k_p$) | $\omega_n^2$ | ? neuron model param | |
| k_i ($k_i$) | $2 \omega_n \zeta$ | ? neuron model param | |
| sigma ($\sigma_v$) | $0.00225$ | standard deviation of the Gaussian noise, $\sigma_v$ (from paper) | |
| steady_state | $0.4$ | ? neuron model param and used for obtaining solution | |


In [None]:
# --- network layout ---
npm = 8          # neurons per mesh node
r = npm // 2     # how many of them carry a positive sign

# --- simulation timing ---
num_timesteps = 5000
sys_tick_in_s = 1e-3     # seconds per timestep (1 ms is also the default)

# --- spike encoding ---
gamma = 50
theta = 0.5 * gamma ** 2     # spike threshold |Gamma|^2 / 2

# --- integration step from the spectrum of A ---
# NOTE(review): eigsh assumes a symmetric matrix; confirm A is still symmetric
# after fem.enforce (if enforce only rewrites the constrained rows it is not).
eigvals, _ = spla.eigsh(A, k=1, which='LM')   # largest-magnitude eigenvalue
lambda_max = np.abs(eigvals[0])
dt = 1.0 / (4 * lambda_max)   # kept below the 1 / (2 * lambda_max) bound

# --- leaks ---
tau = 30
lambda_d = 1 / (tau * dt)     # readout leak
lambda_v = 2 / (tau * dt)     # membrane leak

# --- neuron-model gains ---
omega_n = 2.0
zeta = 4.0
k_p = omega_n ** 2
k_i = 2 * omega_n * zeta

sigma = 0.00225        # standard deviation of the Gaussian noise (sigma_v)
steady_state = 0.4

In [None]:
def create_conn(A, npm):
    """
    Build the synapse list for the recurrent projection from a system matrix.

    Every non-zero entry ``A[i, j]`` becomes one connection from neuron
    ``j * npm`` (first neuron of mesh node ``j``) to neuron ``i * npm`` (first
    neuron of mesh node ``i``) with weight ``A[i, j]`` and delay 0. Entries
    stored as explicit zeros are skipped, matching ``A.nonzero()``.

    Parameters
    ----------
    A : sparse matrix or 2-D array
        Matrix whose non-zeros define the connections. Weights are cast to
        int (fractional parts truncated toward zero).
    npm : int
        Neurons per mesh node.

    Returns
    -------
    list of [pre_neuron_id, post_neuron_id, weight, delay]
        One entry per non-zero, in the storage order of ``A`` (row-major for
        CSR or dense input).
    """
    # One vectorized pass over the COO triplets instead of a CSR scalar
    # lookup per entry.
    coo = sp.coo_matrix(A)
    keep = coo.data != 0  # drop explicit zeros, as A.nonzero() does
    rows = coo.row[keep].astype(np.int64)
    cols = coo.col[keep].astype(np.int64)
    weights = coo.data[keep]

    conns = np.column_stack((
        cols * npm,            # pre_neuron_id
        rows * npm,            # post_neuron_id
        weights,               # weight
        np.zeros_like(rows),   # delay
    ))
    return conns.astype(int).tolist()

nmesh = len(b)  # one mesh node per entry of the right-hand side

# The "gb" list holds one value per mesh node (b_i * gamma); every other entry
# is a single value shared by all neurons.
neuron_params = {
    "gb": [b_i * gamma for b_i in b],
    "threshold": theta,
    "scale": A_scale * (gamma ** 2),
    "dt": dt,
    "gamma": gamma,
    "lambda_d": lambda_d,
    "lambda_v": lambda_v,
    "k_p": k_p,
    "k_i": k_i,
    "sigma": sigma,
    "steady_state": steady_state,
}

nb_neurons = nmesh * npm
print(f"Total number of neurons: {nb_neurons}")

In [None]:
cores_per_chip = 148
chips_per_board = 48
max_cores = chips_per_board * cores_per_chip if MULTI_NODE_BOARD else cores_per_chip

# Ceiling divisions: spread the mesh nodes as evenly as possible over the
# available cores, then count how many cores that layout occupies.
meshes_per_core = -(-nmesh // max_cores)
neurons_per_core = meshes_per_core * npm
nb_cores = -(-nb_neurons // neurons_per_core)

print(f"Number of cores used: {nb_cores}")
print(f"Meshes per core: {meshes_per_core}")
print(f"Neurons per core: {neurons_per_core}")

max_neurons_per_core = 2048
if neurons_per_core > max_neurons_per_core:
    raise ValueError(f"Neurons per core ({neurons_per_core}) exceed maximum allowed ({max_neurons_per_core})")

pop = snn.Population(size=nb_neurons, neuron_model="neurofem_2048", params=neuron_params, name="pop", record=["x_mean"])
# neurons_per_core is a multiple of npm, presumably so that core boundaries
# fall on mesh-node boundaries -- confirm with the mapper's fill order.
pop.set_max_atoms_per_core(neurons_per_core)


conns = create_conn(A_quant, npm)

# Fraction of all rows * cols entries that are non-zero. For a sparse matrix
# `.size` is the number of *stored* entries, so it cannot be the denominator,
# and np.count_nonzero does not operate on sparse objects.
A_density = A_quant.count_nonzero() / (A_quant.shape[0] * A_quant.shape[1])
print(f"Sparse matrix A density: {A_density*100:.4f}%")

conns_density = len(conns) / ((nmesh * npm)**2)
print(f"Connection density: {conns_density*100:.6f}%")

proj = snn.Projection(pre=pop, post=pop, connections=conns, name="proj")

net = snn.Network("Ax=b Network")
net.add(pop, proj)

In [None]:
settings = BackendSettings()
settings.routing_type = ROUTING.C2C
settings.rebuild_apps = True

# NOTE(review): `settings` above is never passed to the hardware object, and
# the "experiment_backend_settings" key receives an ExperimentBackendType
# rather than the BackendSettings instance -- confirm against the spinnaker2
# hardware API which keyword each value belongs to.
hardware_kwargs = {
    "experiment_backend_settings": ExperimentBackendType.SPINNMAN2,
}

if MULTI_NODE_BOARD:
    hardware_kwargs["stm_ip"] = STM_IP
    hw = hardware.SpiNNcloud48NodeBoard(**hardware_kwargs)
else:
    # Single chip is addressed by its Ethernet IP.
    hardware_kwargs["eth_ip"] = S2_IP
    hw = hardware.SpiNNaker2Chip(**hardware_kwargs)

hw.run(net, num_timesteps, debug=False, mapping_only=False, sys_tick_in_s=sys_tick_in_s) # low frequency mode (150Hz)

def get_solution(x_means):
    """
    Extract one normalized solution value per mesh node from the recording.

    Mesh node ``i`` is read at index
    ``(i % meshes_per_core) + (i // meshes_per_core) * npm * meshes_per_core``:
    slot ``i % meshes_per_core`` inside the block of ``npm * meshes_per_core``
    entries (one core's worth of neurons) with block number
    ``i // meshes_per_core``. Presumably the neuron model records each node's
    x_mean in the first slots of its core -- verify against the model.
    Each value is divided by ``num_timesteps * steady_state + 1``.

    Uses the module-level ``nmesh``, ``meshes_per_core``, ``npm``,
    ``num_timesteps`` and ``steady_state``. Returns a NumPy array of length
    ``nmesh``.
    """
    normaliser = num_timesteps * steady_state + 1
    core_stride = npm * meshes_per_core
    return np.array([
        x_means[i % meshes_per_core + (i // meshes_per_core) * core_stride] / normaliser
        for i in range(nmesh)
    ])

x_mean = pop.get_x_mean()
solution = get_solution(x_mean)

In [None]:
# Side-by-side 3-D surfaces: SpiNNaker2 result vs. the direct-solver reference.
fig = plt.figure(figsize=(12, 5))

panel_specs = [
    (121, solution, 'viridis', 'Numerical Solution (SpiNNaker2)'),
    (122, u_exact, 'plasma', 'Exact Solution'),
]
panel_axes = []
for subplot_pos, z_values, colormap, panel_title in panel_specs:
    ax = fig.add_subplot(subplot_pos, projection='3d')
    surf = ax.plot_trisurf(tri, z_values, cmap=colormap, edgecolor='none')
    ax.set_title(panel_title)
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    # No tick labels on any axis.
    ax.set_yticklabels([])
    ax.set_xticklabels([])
    ax.set_zticklabels([])
    plt.colorbar(surf, ax=ax)
    panel_axes.append(ax)
ax1, ax2 = panel_axes

plt.show()