In [12]:
import jax
import jax.numpy as jnp
from jax import random
import optax
import equinox as eqx
from functools import partial
import numpy as np
import matplotlib.pyplot as plt

from foundational_ssm.models import S5
from omegaconf import OmegaConf

In [9]:
conf = {
    'model': {
        'input_dim': 182,
        'output_dim': 2,
        'd_state': 64,
        'num_layers': 2,
        'hidden_dim': 64,
        'dropout': 0.1,
        'ssm_core':'s5'
    },
    'optimizer': {
        'lr': 0.0005,
        'weight_decay': 0.01  # Added common parameter
    },
    'training': {
        'batch_size': 64,
        'epochs': 2000
    },
    'device': 'cuda',
    'framework': 'jax'
}

In [13]:
args = OmegaConf.create(conf)

model_key = random.PRNGKey(0)
model = S5(
    key= model_key,
    num_blocks=args.model.num_layers,
    N=args.model.input_dim,
    ssm_size=args.model.d_state,
    ssm_blocks=1,
    H=args.model.hidden_dim,
    output_dim=args.model.output_dim,
)

In [27]:
eqx.is_inexact_array

<function equinox._filters.is_inexact_array(element: Any) -> bool>

In [None]:
dynamic, static = eqx.partition(model, eqx.is_inexact_array)

Original model type: <class 'foundational_ssm.models.s5.S5'>
Filtered model type: <class 'foundational_ssm.models.s5.S5'>


In [28]:
eqx.partition(model, eqx.is_array)[0].__static_attributes__

('blocks', 'linear_encoder', 'linear_layer')

In [30]:
dynamic, static = eqx.partition(model, eqx.is_array)

dynamic.linear_encoder

Linear(
  weight=f32[64,182],
  bias=f32[64],
  in_features=182,
  out_features=64,
  use_bias=True
)

In [40]:
dynamic.blocks[0].ssm

S5Layer(
  Lambda_re=f32[32],
  Lambda_im=f32[32],
  B=f32[32,64,2],
  C=f32[64,32,2],
  D=f32[64],
  log_step=f32[32,1],
  H=None,
  P=None,
  conj_sym=None,
  clip_eigs=None,
  discretisation=None,
  step_rescale=None
)

In [69]:
leaves, leaves_def = jax.tree.flatten(dynamic)

In [70]:
leaves_def.children()[0].node_data()[1]

('weight', 'bias'), (('in_features', 182), ('out_features', 64), ('use_bias', True))

In [85]:
leaves_def.children()[0]

PyTreeDef(CustomNode(Linear[('weight', 'bias'), (('in_features', 182), ('out_features', 64), ('use_bias', True))], [*, *]))

In [86]:
leaves_def.children()[1].node_data()

(list, None)

In [90]:
import jax.tree_util as jtu

for path, value in jtu.tree_flatten_with_path(model)[0]:
    if  eqx.is_array(value):
        print(".".join(str(p) for p in path), value.shape if hasattr(value, 'shape') else value, eqx.is_array(value))

.linear_encoder..weight (64, 182) True
.linear_encoder..bias (64,) True
.blocks.[0]..norm..first_time_index..init () True
.blocks.[0]..norm..state_index..init.[0] (64,) True
.blocks.[0]..norm..state_index..init.[1] (64,) True
.blocks.[0]..ssm..Lambda_re (32,) True
.blocks.[0]..ssm..Lambda_im (32,) True
.blocks.[0]..ssm..B (32, 64, 2) True
.blocks.[0]..ssm..C (64, 32, 2) True
.blocks.[0]..ssm..D (64,) True
.blocks.[0]..ssm..log_step (32, 1) True
.blocks.[0]..glu..w1..weight (64, 64) True
.blocks.[0]..glu..w1..bias (64,) True
.blocks.[0]..glu..w2..weight (64, 64) True
.blocks.[0]..glu..w2..bias (64,) True
.blocks.[1]..norm..first_time_index..init () True
.blocks.[1]..norm..state_index..init.[0] (64,) True
.blocks.[1]..norm..state_index..init.[1] (64,) True
.blocks.[1]..ssm..Lambda_re (32,) True
.blocks.[1]..ssm..Lambda_im (32,) True
.blocks.[1]..ssm..B (32, 64, 2) True
.blocks.[1]..ssm..C (64, 32, 2) True
.blocks.[1]..ssm..D (64,) True
.blocks.[1]..ssm..log_step (32, 1) True
.blocks.[1].