In [1]:
import sys
import hydra
from hydra.core.global_hydra import GlobalHydra

sys.path.append('..')

In [2]:
def get_model(model_name, **kwargs):
    # Clear the existing Hydra instance if it is already initialized
    if GlobalHydra().is_initialized():
        GlobalHydra().clear()
    # Initialize Hydra and load the configuration
    hydra.initialize(config_path="../configs/task/context_aggregator")
    assert model_name in ['s4', 'mamba', 'transformer']
    model_cfg = hydra.compose(config_name=model_name)
    model_cfg.x_dim = 128
    for key, value in kwargs.items():
        setattr(model_cfg, key, value)
    model = hydra.utils.instantiate(model_cfg)
    return model

In [6]:
s4 = get_model('s4', mixer_type="S4")
s4d = get_model('s4', mixer_type="S4D")
mamba1 = get_model('mamba', mixer_type="Mamba1")
mamba2 = get_model('mamba', mixer_type="Mamba2")
transf = get_model('transformer')

The version_base parameter is not specified.
Please specify a compatability version level, or None.
Will assume defaults for version 1.1
  hydra.initialize(config_path="../configs/task/context_aggregator")


In [7]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters())

print(f"S4 model parameters: {count_parameters(s4):,}")
print(f"S4D model parameters: {count_parameters(s4d):,}")
print(f"Mamba1 model parameters: {count_parameters(mamba1):,}")
print(f"Mamba2 model parameters: {count_parameters(mamba2):,}")
print(f"Transformer model parameters: {count_parameters(transf):,}")

S4 model parameters: 858,752
S4D model parameters: 727,680
Mamba1 model parameters: 4,131,456
Mamba2 model parameters: 4,039,712
Transformer model parameters: 2,174,592
