In [1]:
# kan_mamote/test_kan_mammote.py

import os
import sys
import torch
import torch.nn as nn
import torch.optim as optim
import warnings
from typing import Optional, Tuple, Dict, List

# Suppress all warnings for cleaner output during testing,
# but be cautious in production code.
warnings.filterwarnings('ignore')


# --- Import Project Modules ---
try:
    from src.utils.config import KANMAMMOTEConfig
    from src.models.kan_mammote import KANMAMMOTE
    from src.models.continuous_mamba_block import ContinuousMambaBlock
    from src.models.k_mote import K_MOTE
    from src.models.moe_router import MoERouter
    from src.layers.basis_functions import FourierBasis, GaussianKernelBasis, WaveletBasis
    from src.models.kan.MatrixKANLayer import MatrixKANLayer # For Spline, explicit import
    from src.layers.dynamic_mamba_ssm import DynamicMambaSSM # The Mamba core itself
    from src.models.regularization import KANMAMMOTE_RegularizationLosses

    print("✅ All KAN-MAMMOTE modules imported successfully!")
except ImportError as e:
    print(f"❌ FATAL ERROR: Failed to import KAN-MAMMOTE modules. Please check your file structure and PYTHONPATH.")
    print(f"Error details: {e}")
    sys.exit(1)

# --- Global Test Configuration ---
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.float32 # Use float32 for most tests, switch to bfloat16/float16 for Mamba if needed

# kan_mamote/test_kan_mammote.py (or test_new.ipynb code block)

# ... (imports)

DEFAULT_CONFIG = KANMAMMOTEConfig(
    d_model=128,
    D_time=64, # Keep D_time as 64
    num_layers=2,
    input_feature_dim=10,
    output_dim_for_task=1,
    K_top=4,
    use_aux_features_router=True,
    raw_event_feature_dim=5,
    # --- FIX for Mamba Stability/Standard Configs ---
    mamba_d_state=128,  # A common d_state, often d_model
    mamba_d_conv=4,
    mamba_expand=2,
    mamba_headdim=32,   # nheads = d_ssm / 32
    mamba_d_ssm=None,   # Let d_ssm default to d_inner (d_inner = 2*128 = 256)
                        # So, d_ssm = 256. nheads = 256 / 32 = 8.
                        # dt_modulation_proj will be 64 -> 8.
    # ------------------------------------------------
    mamba_dt_min=0.001,
    mamba_dt_max=0.1,
    mamba_dt_init_floor=1e-4,
    mamba_bias=False,
    mamba_conv_bias=True,
    mamba_chunk_size=256,
    mamba_use_mem_eff_path=True,
    mamba_layer_idx=None,
    spline_grid_size=8,
    spline_degree=3,
    device=DEVICE,
    dtype=DTYPE
)
# ... (rest of the test file)

print(f"🔧 Testing on device: {DEVICE} with dtype: {DTYPE}")
print(f"🔧 Using default config: d_model={DEFAULT_CONFIG.d_model}, D_time={DEFAULT_CONFIG.D_time}, num_layers={DEFAULT_CONFIG.num_layers}")

# --- Helper Function for Running Tests ---
def run_test(name: str, test_func: callable) -> bool:
    print(f"\n--- Running Test: {name} ---")
    try:
        result = test_func()
        if result:
            print(f"✅ Test PASSED: {name}")
            return True
        else:
            print(f"⚠️ Test COMPLETED WITH WARNINGS: {name}")
            return True # Consider a warning as non-fatal for test suite completion
    except Exception as e:
        print(f"❌ Test FAILED: {name}")
        print(f"   Error: {e}")
        # import traceback
        # traceback.print_exc() # Uncomment for full traceback on failure
        return False

# --- Helper Function for Data Generation ---
def create_dummy_data(batch_size: int, seq_len: int, input_dim: int, aux_dim: int = 0) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    timestamps = torch.randn(batch_size, seq_len, device=DEVICE, dtype=DTYPE) # Use raw random normal
    # Apply normalization to timestamps, e.g., to [-1, 1]
    # For random normal, [-3, 3] usually covers 99.7% of data.
    # Normalize to [-1, 1] range:
    timestamps = torch.clamp(timestamps, -3.0, 3.0) / 3.0 # Normalize to approx [-1, 1]
    # Or if you want [0,1]:
    # timestamps = (timestamps - timestamps.min()) / (timestamps.max() - timestamps.min()) # Min-max over current batch
    
    # NOTE: For real data, you should calculate min/max over the *entire* dataset and use those fixed values.
    # For dummy data, a simple clamp/divide or in-batch min/max is fine for testing.
    
    features = torch.randn(batch_size, seq_len, input_dim, device=DEVICE, dtype=DTYPE)
    
    aux_features = None
    if aux_dim > 0:
        aux_features = torch.randn(batch_size, seq_len, aux_dim, device=DEVICE, dtype=DTYPE)
    
    return timestamps, features, aux_features

# --- Individual Component Test Functions ---

def test_01_config_initialization():
    """Tests if KANMAMMOTEConfig can be initialized and holds expected attributes."""
    config = KANMAMMOTEConfig(d_model=64, K_top=4, device='cpu')
    assert hasattr(config, 'd_model') and config.d_model == 64
    assert hasattr(config, 'K_top') and config.K_top == 4
    assert hasattr(config, 'device') and config.device == 'cpu'
    return True

def test_02_kan_mammote_initialization():
    """Tests if the full KANMAMMOTE model can be initialized."""
    model = KANMAMMOTE(DEFAULT_CONFIG).to(DEVICE)
    param_count = sum(p.numel() for p in model.parameters())
    assert param_count > 0, "Model should have parameters."
    assert next(model.parameters()).device == DEVICE, "Model not on correct device."
    print(f"   Model initialized successfully with {param_count:,} parameters.")
    return True

def test_03_kan_mammote_forward_pass_full_sequence():
    """Tests the full KANMAMMOTE forward pass with a sequence of data."""
    model = KANMAMMOTE(DEFAULT_CONFIG).to(DEVICE)
    model.eval() # Set to eval mode

    batch_size, seq_len = 4, 16
    timestamps, features, aux_features = create_dummy_data(
        batch_size, seq_len, DEFAULT_CONFIG.input_feature_dim, DEFAULT_CONFIG.raw_event_feature_dim
    )

    with torch.no_grad():
        outputs, regularization_losses = model(timestamps, features, aux_features)

    expected_output_shape = (batch_size, seq_len, DEFAULT_CONFIG.output_dim_for_task)
    assert outputs.shape == expected_output_shape, \
        f"Output shape mismatch: Expected {expected_output_shape}, got {outputs.shape}"
    assert not torch.isnan(outputs).any(), "NaN values detected in model outputs."
    assert not torch.isinf(outputs).any(), "Inf values detected in model outputs."
    assert isinstance(regularization_losses, dict), "Regularization losses should be a dictionary."
    assert 'load_balance_loss' in regularization_losses, "Load balance loss not found."
    assert regularization_losses['load_balance_loss'].item() >= 0, "Load balance loss should be non-negative."
    print(f"   KANMAMMOTE full forward pass successful. Output shape: {outputs.shape}")
    return True

def test_04_kan_mammote_backward_pass_full_sequence():
    """Tests backward pass for the full KANMAMMOTE model."""
    model = KANMAMMOTE(DEFAULT_CONFIG).to(DEVICE)
    model.train() # Set to training mode

    batch_size, seq_len = 4, 16
    timestamps, features, aux_features = create_dummy_data(
        batch_size, seq_len, DEFAULT_CONFIG.input_feature_dim, DEFAULT_CONFIG.raw_event_feature_dim
    )
    
    # Generate dummy target matching model output shape
    # Need a forward pass to get actual output shape first
    with torch.no_grad():
        sample_outputs, _ = model(timestamps, features, aux_features)
    target = torch.randn_like(sample_outputs, device=DEVICE)

    optimizer = optim.Adam(model.parameters(), lr=0.001)
    optimizer.zero_grad()

    outputs, regularization_losses = model(timestamps, features, aux_features)
    task_loss = nn.MSELoss()(outputs, target)
    total_loss = task_loss + sum(reg_loss for reg_loss in regularization_losses.values())
    total_loss.backward()

    # Check for gradients
    grads_exist = 0
    nan_grads = 0
    inf_grads = 0
    for name, param in model.named_parameters():
        if param.requires_grad:
            if param.grad is not None:
                grads_exist += 1
                if torch.isnan(param.grad).any():
                    nan_grads += 1
                if torch.isinf(param.grad).any():
                    inf_grads += 1
            else:
                print(f"   Warning: Parameter {name} has requires_grad=True but no grad computed.")

    assert grads_exist > 0, "No gradients computed for any trainable parameter."
    assert nan_grads == 0, f"NaN gradients detected in {nan_grads} parameters."
    assert inf_grads == 0, f"Inf gradients detected in {inf_grads} parameters."
    print(f"   KANMAMMOTE backward pass successful. Gradients computed for {grads_exist} parameters.")
    return True

def test_05_kan_mammote_training_step():
    """Tests if KANMAMMOTE parameters update after a training step."""
    model = KANMAMMOTE(DEFAULT_CONFIG).to(DEVICE)
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    
    # Store initial parameters
    initial_params = [p.clone().detach() for p in model.parameters()]

    batch_size, seq_len = 4, 16
    timestamps, features, aux_features = create_dummy_data(
        batch_size, seq_len, DEFAULT_CONFIG.input_feature_dim, DEFAULT_CONFIG.raw_event_feature_dim
    )
    
    # Generate dummy target
    model.eval()
    with torch.no_grad():
        sample_outputs, _ = model(timestamps, features, aux_features)
    target = torch.randn_like(sample_outputs, device=DEVICE)

    model.train()
    optimizer.zero_grad()
    outputs, regularization_losses = model(timestamps, features, aux_features)
    task_loss = nn.MSELoss()(outputs, target)
    total_loss = task_loss + sum(reg_loss for reg_loss in regularization_losses.values())
    total_loss.backward()
    optimizer.step()

    params_changed = 0
    for initial_p, current_p in zip(initial_params, model.parameters()):
        if not torch.equal(initial_p, current_p):
            params_changed += 1
    
    assert params_changed > 0, "No parameters changed after optimizer step."
    print(f"   KANMAMMOTE training step successful. {params_changed} parameters changed.")
    return True

def test_06_continuous_mamba_block_forward():
    """Tests a single ContinuousMambaBlock's forward pass including state management."""
    block = ContinuousMambaBlock(d_model=DEFAULT_CONFIG.d_model, config=DEFAULT_CONFIG).to(DEVICE)
    block.eval()

    batch_size = 8
    # Dummy inputs for a single timestep
    uk_current = torch.randn(batch_size, DEFAULT_CONFIG.d_model, device=DEVICE, dtype=DTYPE)
    tk_current = torch.randn(batch_size, 1, device=DEVICE, dtype=DTYPE)
    tk_previous = torch.randn(batch_size, 1, device=DEVICE, dtype=DTYPE)

    # Initialize states (first time step, so all zeros)
    # Re-derive state shapes from Mamba2 logic for robustness
    d_state = DEFAULT_CONFIG.mamba_d_state
    d_conv = DEFAULT_CONFIG.mamba_d_conv
    mamba_expand = DEFAULT_CONFIG.mamba_expand
    nheads = DEFAULT_CONFIG.d_model // DEFAULT_CONFIG.mamba_headdim # nheads calculation
    headdim = DEFAULT_CONFIG.mamba_headdim
    d_inner_for_conv_state = DEFAULT_CONFIG.d_model * mamba_expand

    current_conv_state = torch.zeros(batch_size, d_inner_for_conv_state, d_conv - 1, device=DEVICE, dtype=DTYPE)
    current_ssm_state = torch.zeros(batch_size, nheads, headdim, d_state, device=DEVICE, dtype=DTYPE)

    with torch.no_grad():
        hk_current, expert_weights, expert_mask, next_conv_state, next_ssm_state = block(
            uk_current, tk_current, tk_previous, current_conv_state, current_ssm_state
        )

    expected_hk_shape = (batch_size, DEFAULT_CONFIG.d_model)
    assert hk_current.shape == expected_hk_shape, f"hk_current shape mismatch: {hk_current.shape}"
    assert next_conv_state.shape == current_conv_state.shape, "next_conv_state shape mismatch."
    assert next_ssm_state.shape == current_ssm_state.shape, "next_ssm_state shape mismatch."
    assert not torch.equal(current_conv_state, next_conv_state), "Conv state did not update."
    assert not torch.equal(current_ssm_state, next_ssm_state), "SSM state did not update."
    
    assert not torch.isnan(hk_current).any(), "NaN in hk_current."
    assert not torch.isinf(hk_current).any(), "Inf in hk_current."

    print(f"   ContinuousMambaBlock forward successful. hk_current shape: {hk_current.shape}")
    print(f"   States updated from input states: Conv update detected: {not torch.equal(current_conv_state, next_conv_state)}, SSM update detected: {not torch.equal(current_ssm_state, next_ssm_state)}")
    return True


def test_07_k_mote_forward_pass():
    """Tests K_MOTE's forward pass with and without auxiliary features."""
    k_mote = K_MOTE(DEFAULT_CONFIG).to(DEVICE)
    k_mote.eval()

    batch_size = 8
    timestamp_input = torch.randn(batch_size, 1, device=DEVICE, dtype=DTYPE)
    auxiliary_features = torch.randn(batch_size, DEFAULT_CONFIG.raw_event_feature_dim, device=DEVICE, dtype=DTYPE)

    with torch.no_grad():
        # Test with auxiliary features
        emb_with_aux, weights_with_aux, mask_with_aux = k_mote(timestamp_input, auxiliary_features)
        assert emb_with_aux.shape == (batch_size, DEFAULT_CONFIG.D_time), "K_MOTE output shape mismatch with aux."
        assert weights_with_aux.shape == (batch_size, 4), "K_MOTE weights shape mismatch with aux."
        assert mask_with_aux.shape == (batch_size, 4), "K_MOTE mask shape mismatch with aux."
        assert not torch.isnan(emb_with_aux).any(), "NaN in K_MOTE output with aux."

        # Test without auxiliary features (pass None)
        emb_no_aux, weights_no_aux, mask_no_aux = k_mote(timestamp_input, None)
        assert emb_no_aux.shape == (batch_size, DEFAULT_CONFIG.D_time), "K_MOTE output shape mismatch without aux."
        print("   K_MOTE forward pass successful with and without auxiliary features.")
        return True

def test_08_moe_router_forward_pass():
    """Tests MoERouter's forward pass."""
    router = MoERouter(input_dim=(1 + DEFAULT_CONFIG.raw_event_feature_dim), num_experts=4, config=DEFAULT_CONFIG).to(DEVICE)
    router.eval()

    batch_size = 8
    timestamp_input = torch.randn(batch_size, 1, device=DEVICE, dtype=DTYPE)
    auxiliary_features = torch.randn(batch_size, DEFAULT_CONFIG.raw_event_feature_dim, device=DEVICE, dtype=DTYPE)
    router_input = torch.cat([timestamp_input, auxiliary_features], dim=1)

    with torch.no_grad():
        logits, weights = router(timestamp_input, auxiliary_features)
    
    assert logits.shape == (batch_size, 4), "Router logits shape mismatch."
    assert weights.shape == (batch_size, 4), "Router weights shape mismatch."
    assert torch.allclose(weights.sum(dim=-1), torch.ones(batch_size, device=DEVICE)), "Router weights do not sum to 1."
    assert not torch.isnan(logits).any(), "NaN in router logits."
    print("   MoERouter forward pass successful.")
    return True

def test_09_basis_functions_forward_pass():
    """Tests forward pass for each individual basis function."""
    # Note: KANLayer internally transforms the 1D input to D_time dimensions first.
    # So we provide (N, D_T) to the BasisFunction, not (N, 1).
    # This aligns with how KANLayer's forward calls basis_function.
    
    batch_size = 8
    d_time = DEFAULT_CONFIG.D_time
    test_input = torch.randn(batch_size, d_time, device=DEVICE, dtype=DTYPE) # Input to basis function from KANLayer's x_prime

    # Fourier Basis
    fourier_basis = FourierBasis(d_time, DEFAULT_CONFIG).to(DEVICE)
    fourier_output = fourier_basis(test_input)
    assert fourier_output.shape == (batch_size, d_time), "Fourier Basis output shape mismatch."
    assert not torch.isnan(fourier_output).any(), "NaN in Fourier Basis output."

    # Spline Basis (MatrixKANLayer)
    # MatrixKANLayer takes (N, in_dim) and returns (y, ...)
    spline_basis = MatrixKANLayer(in_dim=1, out_dim=d_time, num=DEFAULT_CONFIG.spline_grid_size, k=DEFAULT_CONFIG.spline_degree, device=DEVICE).to(DEVICE)
    # MatrixKANLayer directly expects (N, 1) if its in_dim is 1
    # We should test it with what K_MOTE provides to it directly (timestamp_input)
    timestamp_input_for_spline = torch.randn(batch_size, 1, device=DEVICE, dtype=DTYPE)
    spline_output_tuple = spline_basis(timestamp_input_for_spline) # This returns a tuple
    spline_output = spline_output_tuple[0] # The main output
    assert spline_output.shape == (batch_size, d_time), "Spline Basis output shape mismatch."
    assert not torch.isnan(spline_output).any(), "NaN in Spline Basis output."

    # Gaussian Kernel Basis
    gaussian_basis = GaussianKernelBasis(d_time, DEFAULT_CONFIG).to(DEVICE)
    gaussian_output = gaussian_basis(test_input)
    assert gaussian_output.shape == (batch_size, d_time), "Gaussian Basis output shape mismatch."
    assert not torch.isnan(gaussian_output).any(), "NaN in Gaussian Basis output."

    # Wavelet Basis
    wavelet_basis = WaveletBasis(d_time, DEFAULT_CONFIG).to(DEVICE)
    wavelet_output = wavelet_basis(test_input)
    assert wavelet_output.shape == (batch_size, d_time), "Wavelet Basis output shape mismatch."
    assert not torch.isnan(wavelet_output).any(), "NaN in Wavelet Basis output."

    print("   All Basis Functions forward passes successful.")
    return True


def test_10_dynamic_mamba_ssm_step_method():
    """Tests DynamicMambaSSM's step method for single-token processing with states."""
    # Mamba2's d_model refers to the overall input/output dimension,
    # but its internal calculations for state depend on headdim, d_state, expand.
    # The Mamba2 class itself uses d_model as its 'hidden_size'.
    
    d_model_mamba = DEFAULT_CONFIG.d_model
    k_mote_delta_t_embedding_dim = DEFAULT_CONFIG.D_time # The dim of delta_t_embedding

    mamba_ssm = DynamicMambaSSM(
        d_model=d_model_mamba,
        k_mote_delta_t_embedding_dim=k_mote_delta_t_embedding_dim,
        d_state=DEFAULT_CONFIG.mamba_d_state,
        d_conv=DEFAULT_CONFIG.mamba_d_conv,
        expand=DEFAULT_CONFIG.mamba_expand,
        headdim=DEFAULT_CONFIG.mamba_headdim,
        dt_min=DEFAULT_CONFIG.mamba_dt_min,
        dt_max=DEFAULT_CONFIG.mamba_dt_max,
        dt_init_floor=DEFAULT_CONFIG.mamba_dt_init_floor,
        bias=DEFAULT_CONFIG.mamba_bias,
        conv_bias=DEFAULT_CONFIG.mamba_conv_bias,
        device=DEVICE,
        dtype=DTYPE
    ).to(DEVICE)
    mamba_ssm.eval()

    batch_size = 8
    # Input to step method (single token)
    hidden_states_step = torch.randn(batch_size, 1, d_model_mamba, device=DEVICE, dtype=DTYPE)
    
    # dt_modulation_step comes from dt_modulation_proj(delta_t_embedding) (B, nheads)
    nheads = d_model_mamba // DEFAULT_CONFIG.mamba_headdim
    dt_modulation_step_input = torch.randn(batch_size, nheads, device=DEVICE, dtype=DTYPE)

    # Initialize states (zeros for first step)
    d_state = DEFAULT_CONFIG.mamba_d_state
    d_conv = DEFAULT_CONFIG.mamba_d_conv
    mamba_expand = DEFAULT_CONFIG.mamba_expand
    d_inner_for_conv_state = d_model_mamba * mamba_expand

    conv_state = torch.zeros(batch_size, d_inner_for_conv_state, d_conv - 1, device=DEVICE, dtype=DTYPE)
    ssm_state = torch.zeros(batch_size, nheads, DEFAULT_CONFIG.mamba_headdim, d_state, device=DEVICE, dtype=DTYPE)

    with torch.no_grad():
        output, next_conv_state, next_ssm_state = mamba_ssm.step(
            hidden_states=hidden_states_step,
            conv_state=conv_state,
            ssm_state=ssm_state,
            dt_modulation_step=dt_modulation_step_input
        )
    
    assert output.shape == (batch_size, 1, d_model_mamba), "DynamicMambaSSM step output shape mismatch."
    assert next_conv_state.shape == conv_state.shape, "DynamicMambaSSM next_conv_state shape mismatch."
    assert next_ssm_state.shape == ssm_state.shape, "DynamicMambaSSM next_ssm_state shape mismatch."
    assert not torch.equal(conv_state, next_conv_state), "Conv state did not update in DynamicMambaSSM.step."
    assert not torch.equal(ssm_state, next_ssm_state), "SSM state did not update in DynamicMambaSSM.step."

    assert not torch.isnan(output).any(), "NaN in DynamicMambaSSM step output."
    print("   DynamicMambaSSM step method successful with state updates.")
    return True

def test_11_regularization_losses():
    """Tests the regularization loss computations."""
    reg_handler = KANMAMMOTE_RegularizationLosses(DEFAULT_CONFIG)
    
    # Test load_balance_loss
    batch_size = 16
    num_experts = 4 # Hardcoded in K_MOTE
    dummy_expert_weights = torch.rand(batch_size, num_experts, device=DEVICE, dtype=DTYPE)
    # Make them sum to 1 per row for realistic softmax weights
    dummy_expert_weights = dummy_expert_weights / dummy_expert_weights.sum(dim=-1, keepdim=True)

    load_loss = reg_handler.compute_load_balance_loss(dummy_expert_weights)
    assert isinstance(load_loss, torch.Tensor), "Load balance loss is not a tensor."
    assert load_loss.item() >= 0, "Load balance loss should be non-negative."
    print(f"   Load Balance Loss computed: {load_loss.item():.4f}")

    # Test Sobolev and Total Variation (just check if they run without error as they are stubs)
    model_for_stub_loss = KANMAMMOTE(DEFAULT_CONFIG).to(DEVICE) # Need a model instance for these stubs
    sobolev_loss = reg_handler.compute_sobolev_l2_loss(model_for_stub_loss)
    total_variation_loss = reg_handler.compute_total_variation_loss(model_for_stub_loss)
    
    assert isinstance(sobolev_loss, torch.Tensor) and sobolev_loss.item() == 0.0, "Sobolev loss stub not working as expected."
    assert isinstance(total_variation_loss, torch.Tensor) and total_variation_loss.item() == 0.0, "Total Variation loss stub not working as expected."
    print("   Sobolev L2 and Total Variation loss stubs ran successfully.")
    return True

DEBUG: Triton Kernels Available: True
✅ All KAN-MAMMOTE modules imported successfully!
🔧 Testing on device: cuda with dtype: torch.float32
🔧 Using default config: d_model=128, D_time=64, num_layers=2


In [2]:

import traceback
# --- Helper Function for Running Tests ---
def run_test(name: str, test_func: callable) -> bool:
    print(f"\n--- Running Test: {name} ---")
    try:
        result = test_func()
        if result:
            print(f"✅ Test PASSED: {name}")
            return True
        else:
            print(f"⚠️ Test COMPLETED WITH WARNINGS: {name}")
            return True # Consider a warning as non-fatal for test suite completion
    except Exception as e:
        print(f"❌ Test FAILED: {name}")
        print(f"   Error: {e}")
        # import traceback
        # traceback.print_exc() # Uncomment for full traceback on failure
        return False

# --- Helper Function for Data Generation ---
def create_dummy_data(batch_size: int, seq_len: int, input_dim: int, aux_dim: int = 0) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    timestamps = torch.randn(batch_size, seq_len, device=DEVICE, dtype=DTYPE) * 10.0 # Random timestamps
    features = torch.randn(batch_size, seq_len, input_dim, device=DEVICE, dtype=DTYPE)
    
    aux_features = None
    if aux_dim > 0:
        aux_features = torch.randn(batch_size, seq_len, aux_dim, device=DEVICE, dtype=DTYPE)
    
    return timestamps, features, aux_features

# --- Individual Component Test Functions ---

def test_01_config_initialization():
    """Tests if KANMAMMOTEConfig can be initialized and holds expected attributes."""
    config = KANMAMMOTEConfig(d_model=64, K_top=4, device='cpu')
    assert hasattr(config, 'd_model') and config.d_model == 64
    assert hasattr(config, 'K_top') and config.K_top == 4
    assert hasattr(config, 'device') and config.device == 'cpu'
    return True

def test_02_kan_mammote_initialization():
    """Tests if the full KANMAMMOTE model can be initialized and is on the correct device."""
    config = DEFAULT_CONFIG
    
    model = KANMAMMOTE(config).to(DEVICE) # Ensure model is moved right after creation

    param_count = sum(p.numel() for p in model.parameters())
    assert param_count > 0, "Model should have parameters."

    # FIX for 'not on correct device. Expected cuda, got cuda:0'
    for name, param in model.named_parameters():
        assert param.device.type == DEVICE and \
               (param.device.index == 0 if DEVICE == 'cuda' else True), \
               f"Parameter '{name}' (shape {param.shape}) not on correct device. Expected {DEVICE}, got {param.device}"

    print(f"   Model initialized successfully with {param_count:,} parameters on {DEVICE}.")
    return True

# Helper to create dummy data needs to support auxiliary features
def create_dummy_data(batch_size: int, seq_len: int, input_dim: int, aux_dim: int = 0) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    timestamps = torch.randn(batch_size, seq_len, device=DEVICE, dtype=DTYPE) * 10.0
    features = torch.randn(batch_size, seq_len, input_dim, device=DEVICE, dtype=DTYPE)
    
    aux_features = None
    if aux_dim > 0: # Only create if aux_dim is specified
        aux_features = torch.randn(batch_size, seq_len, aux_dim, device=DEVICE, dtype=DTYPE)
    
    return timestamps, features, aux_features

def test_03_kan_mammote_forward_pass_full_sequence():
    """Tests the full KANMAMMOTE forward pass with a sequence of data."""
    model = KANMAMMOTE(DEFAULT_CONFIG).to(DEVICE)
    model.eval()

    batch_size, seq_len = 4, 16
    timestamps, features, aux_features = create_dummy_data(
        batch_size, seq_len, DEFAULT_CONFIG.input_feature_dim, DEFAULT_CONFIG.raw_event_feature_dim
    )

    try:
        with torch.no_grad():
            outputs, regularization_losses = model(timestamps, features, aux_features)

        expected_output_shape = (batch_size, seq_len, DEFAULT_CONFIG.output_dim_for_task)
        assert outputs.shape == expected_output_shape, \
            f"Output shape mismatch: Expected {expected_output_shape}, got {outputs.shape}"
        assert not torch.isnan(outputs).any(), "NaN values detected in model outputs."
        assert not torch.isinf(outputs).any(), "Inf values detected in model outputs."
        assert isinstance(regularization_losses, dict), "Regularization losses should be a dictionary."
        assert 'load_balance_loss' in regularization_losses, "Load balance loss not found."
        assert regularization_losses['load_balance_loss'].item() >= 0, "Load balance loss should be non-negative."
        print(f"   KANMAMMOTE full forward pass successful. Output shape: {outputs.shape}")
        return True
    except Exception as e:
        print(f"❌ Test FAILED: 03 Kan Mammote Forward Pass Full Sequence")
        print(f"   Error: {e}")
        traceback.print_exc()
        return False

def test_04_kan_mammote_backward_pass_full_sequence():
    """Tests backward pass for the full KANMAMMOTE model."""
    model = KANMAMMOTE(DEFAULT_CONFIG).to(DEVICE)
    model.train()

    batch_size, seq_len = 4, 16
    # FIX: Ensure aux_features are generated and passed
    timestamps, features, aux_features = create_dummy_data(
        batch_size, seq_len, DEFAULT_CONFIG.input_feature_dim, DEFAULT_CONFIG.raw_event_feature_dim
    )
    
    with torch.no_grad():
        sample_outputs, _ = model(timestamps, features, aux_features) # Pass aux_features
    target = torch.randn_like(sample_outputs, device=DEVICE)

    optimizer = optim.Adam(model.parameters(), lr=0.001)
    optimizer.zero_grad()

    outputs, regularization_losses = model(timestamps, features, aux_features) # Pass aux_features
    task_loss = nn.MSELoss()(outputs, target)
    total_loss = task_loss + sum(reg_loss for reg_loss in regularization_losses.values())
    total_loss.backward()

    grads_exist = 0
    nan_grads = 0
    inf_grads = 0
    for name, param in model.named_parameters():
        if param.requires_grad:
            if param.grad is not None:
                grads_exist += 1
                if torch.isnan(param.grad).any():
                    nan_grads += 1
                if torch.isinf(param.grad).any():
                    inf_grads += 1
            else:
                print(f"   Warning: Parameter {name} has requires_grad=True but no grad computed.")

    assert grads_exist > 0, "No gradients computed for any trainable parameter."
    assert nan_grads == 0, f"NaN gradients detected in {nan_grads} parameters."
    assert inf_grads == 0, f"Inf gradients detected in {inf_grads} parameters."
    print(f"   KANMAMMOTE backward pass successful. Gradients computed for {grads_exist} parameters.")
    return True

def test_05_kan_mammote_training_step():
    """Tests if KANMAMMOTE parameters update after a training step."""
    model = KANMAMMOTE(DEFAULT_CONFIG).to(DEVICE)
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    
    initial_params = [p.clone().detach() for p in model.parameters()]

    batch_size, seq_len = 4, 16
    # FIX: Ensure aux_features are generated and passed
    timestamps, features, aux_features = create_dummy_data(
        batch_size, seq_len, DEFAULT_CONFIG.input_feature_dim, DEFAULT_CONFIG.raw_event_feature_dim
    )
    
    model.eval()
    with torch.no_grad():
        sample_outputs, _ = model(timestamps, features, aux_features) # Pass aux_features
    target = torch.randn_like(sample_outputs, device=DEVICE)

    model.train()
    optimizer.zero_grad()
    outputs, regularization_losses = model(timestamps, features, aux_features) # Pass aux_features
    task_loss = nn.MSELoss()(outputs, target)
    total_loss = task_loss + sum(reg_loss for reg_loss in regularization_losses.values())
    total_loss.backward()
    optimizer.step()

    params_changed = 0
    for initial_p, current_p in zip(initial_params, model.parameters()):
        if not torch.equal(initial_p, current_p):
            params_changed += 1
    
    assert params_changed > 0, "No parameters changed after optimizer step."
    print(f"   KANMAMMOTE training step successful. {params_changed} parameters changed.")
    return True

def test_06_continuous_mamba_block_forward():
    """Tests a single ContinuousMambaBlock's forward_sequence pass."""
    block = ContinuousMambaBlock(d_model=DEFAULT_CONFIG.d_model, config=DEFAULT_CONFIG).to(DEVICE)
    block.eval()

    batch_size, seq_len = 8, 16 # Test with a sequence
    
    # Inputs for forward_sequence: hidden_states (u) and delta_t_embedding
    hidden_states_input = torch.randn(batch_size, seq_len, DEFAULT_CONFIG.d_model, device=DEVICE, dtype=DTYPE)
    
    # Generate a dummy delta_t_embedding. Its last dim should be D_time (64).
    dummy_delta_t_embedding = torch.randn(batch_size, seq_len, DEFAULT_CONFIG.D_time, device=DEVICE, dtype=DTYPE)

    with torch.no_grad():
        # FIX: Call forward_sequence directly, passing the correct arguments
        output_sequence = block.forward_sequence(
            hidden_states=hidden_states_input,
            delta_t_embedding=dummy_delta_t_embedding
        )
    
    expected_output_shape = (batch_size, seq_len, DEFAULT_CONFIG.d_model)
    assert output_sequence.shape == expected_output_shape, f"ContinuousMambaBlock forward_sequence output shape mismatch: {output_sequence.shape}"
    assert not torch.isnan(output_sequence).any(), "NaN in ContinuousMambaBlock forward_sequence output."
    assert not torch.isinf(output_sequence).any(), "Inf in ContinuousMambaBlock forward_sequence output."

    print(f"   ContinuousMambaBlock forward_sequence successful. Output shape: {output_sequence.shape}")
    return True




def test_07_k_mote_forward_pass():
    """Tests K_MOTE's forward pass with and without auxiliary features."""
    k_mote = K_MOTE(DEFAULT_CONFIG).to(DEVICE)
    k_mote.eval()

    batch_size = 8
    timestamp_input = torch.randn(batch_size, 1, device=DEVICE, dtype=DTYPE)
    auxiliary_features_for_k_mote = torch.randn(batch_size, DEFAULT_CONFIG.raw_event_feature_dim, device=DEVICE, dtype=DTYPE) # FIX: Use a specific variable name for clarity

    with torch.no_grad():
        # Test with auxiliary features (this now should work correctly)
        emb_with_aux, weights_with_aux, mask_with_aux = k_mote(timestamp_input, auxiliary_features_for_k_mote)
        assert emb_with_aux.shape == (batch_size, DEFAULT_CONFIG.D_time), "K_MOTE output shape mismatch with aux."
        assert weights_with_aux.shape == (batch_size, 4), "K_MOTE weights shape mismatch with aux."
        assert mask_with_aux.shape == (batch_size, 4), "K_MOTE mask shape mismatch with aux."
        assert not torch.isnan(emb_with_aux).any(), "NaN in K_MOTE output with aux."

        # Test without auxiliary features (pass None)
        # Ensure DEFAULT_CONFIG.use_aux_features_router is respected by K_MOTE's internal logic
        # (router will just use timestamp_input)
        emb_no_aux, weights_no_aux, mask_no_aux = k_mote(timestamp_input, None)
        assert emb_no_aux.shape == (batch_size, DEFAULT_CONFIG.D_time), "K_MOTE output shape mismatch without aux."
        print("   K_MOTE forward pass successful with and without auxiliary features.")
        return True

def test_08_moe_router_forward_pass():
    """Tests MoERouter's forward pass."""
    # Router expects (1 + raw_event_feature_dim) as input_dim
    router_input_dim_calc = 1 + DEFAULT_CONFIG.raw_event_feature_dim if DEFAULT_CONFIG.use_aux_features_router else 1
    router = MoERouter(input_dim=router_input_dim_calc, num_experts=4, config=DEFAULT_CONFIG).to(DEVICE)
    router.eval()

    batch_size = 8
    timestamp_input = torch.randn(batch_size, 1, device=DEVICE, dtype=DTYPE)
    
    # FIX: Prepare router_input according to use_aux_features_router
    if DEFAULT_CONFIG.use_aux_features_router and DEFAULT_CONFIG.raw_event_feature_dim > 0:
        auxiliary_features_for_router = torch.randn(batch_size, DEFAULT_CONFIG.raw_event_feature_dim, device=DEVICE, dtype=DTYPE)
        # Pass the auxiliary features explicitly
        logits, weights = router(timestamp_input, auxiliary_features_for_router)
    else:
        # Pass None if aux features are not used
        logits, weights = router(timestamp_input, None) # This path already works fine for input_dim=1

    assert logits.shape == (batch_size, 4), "Router logits shape mismatch."
    assert weights.shape == (batch_size, 4), "Router weights shape mismatch."
    assert torch.allclose(weights.sum(dim=-1), torch.ones(batch_size, device=DEVICE)), "Router weights do not sum to 1."
    assert not torch.isnan(logits).any(), "NaN in router logits."
    print("   MoERouter forward pass successful.")
    return True

# ... (tests 09, 10, 11 - no changes to these tests' logic) ...
# (The prints I added for KANLayer and K_MOTE still need to be in your
#  src/layers/kan_base_layer.py and src/models/k_mote.py files)

def test_10_dynamic_mamba_ssm_step_method():
    """Tests DynamicMambaSSM's step method for single-token processing with states."""
    import traceback

    print("\n=== DynamicMambaSSM Step Method Test ===")
    d_model_mamba = DEFAULT_CONFIG.d_model
    k_mote_delta_t_embedding_dim = DEFAULT_CONFIG.D_time

    print("-> Initializing DynamicMambaSSM module...")
    mamba_ssm = DynamicMambaSSM(
        d_model=d_model_mamba,
        k_mote_delta_t_embedding_dim=k_mote_delta_t_embedding_dim,
        d_state=DEFAULT_CONFIG.mamba_d_state,
        d_conv=DEFAULT_CONFIG.mamba_d_conv,
        expand=DEFAULT_CONFIG.mamba_expand,
        headdim=DEFAULT_CONFIG.mamba_headdim,
        dt_min=DEFAULT_CONFIG.mamba_dt_min,
        dt_max=DEFAULT_CONFIG.mamba_dt_max,
        dt_init_floor=DEFAULT_CONFIG.mamba_dt_init_floor,
        bias=DEFAULT_CONFIG.mamba_bias,
        conv_bias=DEFAULT_CONFIG.mamba_conv_bias,
        d_ssm=DEFAULT_CONFIG.mamba_d_ssm,
        device=DEVICE,
        dtype=DTYPE
    ).to(DEVICE)
    mamba_ssm.eval()

    batch_size = 8
    print("-> Generating input tensors...")
    hidden_states_step = torch.randn(batch_size, 1, d_model_mamba, device=DEVICE, dtype=DTYPE)
    d_state = DEFAULT_CONFIG.mamba_d_state
    d_conv = DEFAULT_CONFIG.mamba_d_conv
    mamba_expand = DEFAULT_CONFIG.mamba_expand
    d_inner_effective = mamba_expand * d_model_mamba
    d_ssm_effective = d_inner_effective if DEFAULT_CONFIG.mamba_d_ssm is None else DEFAULT_CONFIG.mamba_d_ssm
    ngroups_effective = 1
    conv_channels_for_state = d_ssm_effective + 2 * ngroups_effective * d_state
    nheads = d_ssm_effective // DEFAULT_CONFIG.mamba_headdim
    dt_modulation_step_input = torch.randn(batch_size, nheads, device=DEVICE, dtype=DTYPE)
    conv_state = torch.zeros(batch_size, conv_channels_for_state, d_conv - 1, device=DEVICE, dtype=DTYPE)
    initial_conv_state_copy = conv_state.clone().detach()
    ssm_state = torch.zeros(batch_size, nheads, DEFAULT_CONFIG.mamba_headdim, d_state, device=DEVICE, dtype=DTYPE)
    initial_ssm_state_copy = ssm_state.clone().detach()

    print(f"   [Input] hidden_states_step shape: {hidden_states_step.shape}, min: {hidden_states_step.min().item():.4f}, max: {hidden_states_step.max().item():.4f}")
    print(f"   [Input] dt_modulation_step_input shape: {dt_modulation_step_input.shape}")
    print(f"   [State] conv_state shape: {conv_state.shape}, sum: {conv_state.sum().item():.4f}")
    print(f"   [State] ssm_state shape: {ssm_state.shape}, sum: {ssm_state.sum().item():.4f}")

    try:
        print("-> Running DynamicMambaSSM.step()...")
        with torch.no_grad():
            output, next_conv_state, next_ssm_state = mamba_ssm.step(
                hidden_states=hidden_states_step,
                conv_state=conv_state,
                ssm_state=ssm_state,
                dt_modulation_step=dt_modulation_step_input
            )

        print(f"   [Output] output shape: {output.shape}, min: {output.min().item():.4f}, max: {output.max().item():.4f}, abs_sum: {output.abs().sum().item():.4f}")
        print(f"   [State] next_conv_state shape: {next_conv_state.shape}, changed: {not torch.equal(initial_conv_state_copy, next_conv_state)}")
        print(f"   [State] next_ssm_state shape: {next_ssm_state.shape}, changed: {not torch.equal(initial_ssm_state_copy, next_ssm_state)}")

        assert output.shape == (batch_size, 1, d_model_mamba), f"DynamicMambaSSM step output shape mismatch: {output.shape}"
        assert next_conv_state.shape == conv_state.shape, f"DynamicMambaSSM next_conv_state shape mismatch: {next_conv_state.shape}"
        assert next_ssm_state.shape == ssm_state.shape, f"DynamicMambaSSM next_ssm_state shape mismatch: {next_ssm_state.shape}"
        assert not torch.equal(initial_conv_state_copy, next_conv_state), "Conv state did not update in DynamicMambaSSM.step."
        assert not torch.equal(initial_ssm_state_copy, next_ssm_state), "SSM state did not update in DynamicMambaSSM.step."
        assert not torch.isnan(output).any(), "NaN in DynamicMambaSSM step output."

        print("✅ DynamicMambaSSM step method successful with state updates and valid output.")
        return True
    except Exception as e:
        print("❌ DynamicMambaSSM step method FAILED.")
        print(f"   Error: {e}")
        traceback.print_exc()
        return False

def test_11_regularization_losses():
    """Tests the regularization loss computations."""
    reg_handler = KANMAMMOTE_RegularizationLosses(DEFAULT_CONFIG)
    
    # Test load_balance_loss
    batch_size = 16
    num_experts = 4 # Hardcoded in K_MOTE
    dummy_expert_weights = torch.rand(batch_size, num_experts, device=DEVICE, dtype=DTYPE)
    # Make them sum to 1 per row for realistic softmax weights
    dummy_expert_weights = dummy_expert_weights / dummy_expert_weights.sum(dim=-1, keepdim=True)

    load_loss = reg_handler.compute_load_balance_loss(dummy_expert_weights)
    assert isinstance(load_loss, torch.Tensor), "Load balance loss is not a tensor."
    assert load_loss.item() >= 0, "Load balance loss should be non-negative."
    print(f"   Load Balance Loss computed: {load_loss.item():.4f}")

    # Test Sobolev and Total Variation (just check if they run without error as they are stubs)
    model_for_stub_loss = KANMAMMOTE(DEFAULT_CONFIG).to(DEVICE) # Need a model instance for these stubs
    sobolev_loss = reg_handler.compute_sobolev_l2_loss(model_for_stub_loss)
    total_variation_loss = reg_handler.compute_total_variation_loss(model_for_stub_loss)
    
    assert isinstance(sobolev_loss, torch.Tensor) and sobolev_loss.item() == 0.0, "Sobolev loss stub not working as expected."
    assert isinstance(total_variation_loss, torch.Tensor) and total_variation_loss.item() == 0.0, "Total Variation loss stub not working as expected."
    print("   Sobolev L2 and Total Variation loss stubs ran successfully.")
    return True


# --- Main Execution ---
if __name__ == "__main__":
    all_tests_passed = True
    print("--- Starting KAN-MAMMOTE Component Tests ---")
    print(f"Running tests on device: {DEVICE} with dtype: {DTYPE}")
    print(f"Using default config: d_model={DEFAULT_CONFIG.d_model}, D_time={DEFAULT_CONFIG.D_time}, num_layers={DEFAULT_CONFIG.num_layers}")
    #print config device
    print(f"KANMAMMOTEConfig: {DEFAULT_CONFIG}")
    print(f"KANMAMMOTEConfig device: {DEFAULT_CONFIG.device}, dtype: {DEFAULT_CONFIG.dtype}")
    # List of tests to run
    tests = [
        test_01_config_initialization,
        test_02_kan_mammote_initialization,
        test_03_kan_mammote_forward_pass_full_sequence,
        test_04_kan_mammote_backward_pass_full_sequence,
        test_05_kan_mammote_training_step,
        test_06_continuous_mamba_block_forward, # Tests the updated forward with states
        test_07_k_mote_forward_pass,
        test_08_moe_router_forward_pass,
        test_09_basis_functions_forward_pass,
        test_10_dynamic_mamba_ssm_step_method, # Critical for recurrent flow
    ]

    for test_func in tests:
        if not run_test(test_func.__name__.replace('test_', '').replace('_', ' ').title(), test_func):
            all_tests_passed = False

    print("\n--- All Component Tests Finished ---")
    if all_tests_passed:
        print("✅ CONGRATULATIONS! All KAN-MAMMOTE components appear to be working correctly.")
    else:
        print("❌ WARNING: Some KAN-MAMMOTE components failed or completed with warnings. Please review the logs above.")

--- Starting KAN-MAMMOTE Component Tests ---
Running tests on device: cuda with dtype: torch.float32
Using default config: d_model=128, D_time=64, num_layers=2
KANMAMMOTEConfig: <src.utils.config.KANMAMMOTEConfig object at 0x71d355247950>
KANMAMMOTEConfig device: cuda, dtype: torch.float32

--- Running Test: 01 Config Initialization ---
✅ Test PASSED: 01 Config Initialization

--- Running Test: 02 Kan Mammote Initialization ---
Initializing KAN-MAMMOTE with config: <src.utils.config.KANMAMMOTEConfig object at 0x71d355247950>
DEBUG_DMS_INIT: dt_modulation_proj in_features=64, out_features=8 (expected 64 -> 8)
DEBUG_DMS_INIT: Calculated self.nheads=8
DEBUG_DMS_INIT: dt_modulation_proj in_features=64, out_features=8 (expected 64 -> 8)
DEBUG_DMS_INIT: Calculated self.nheads=8
KANMAMMOTE init: Pre-calculated conv_channels_for_state=512, nheads_for_state=8
   Model initialized successfully with 338,381 parameters on cuda.
✅ Test PASSED: 02 Kan Mammote Initialization

--- Running Test: 03 Kan