In [1]:
cd kan-mammote/

[Errno 2] No such file or directory: 'kan-mammote/'
/mnt/c/Users/peera/Desktop/kan-mammotev2


In [2]:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from typing import List, Tuple, Dict, Optional
import sys
import os

# Now, import the modules from your KAN-MAMMOTE project
try:
    from src.utils.config import KANMAMMOTEConfig
    from src.models.kan_mammote import KANMAMMOTE # <<< IMPORTANT: Import the main KANMAMMOTE model
    from src.models.k_mote import K_MOTE # For expert names and direct access (though via KANMAMMOTE)
    from src.models.kan.MatrixKANLayer import MatrixKANLayer # Explicit import for type checking
except ImportError as e:
    print(f"FATAL ERROR: Could not import KAN-MAMMOTE modules. Please check your path setup and file structure.")
    print(f"Error details: {e}")
    sys.exit(1)

# Configure Matplotlib for saving figures (non-interactive backend)
plt.switch_backend('Agg')
sns.set_theme(style="whitegrid") # Use seaborn for better aesthetics

def check_expert_usage(model: KANMAMMOTE, config: KANMAMMOTEConfig):
    """
    Report and plot how often each K-MOTE expert is selected by the router.

    A dummy batch is pushed through the full model as a smoke test, then the
    K-MOTE module of the first Mamba block (``model.mamba_blocks[0].k_mote``)
    is called directly on the first timestamp of every sequence to obtain the
    expert selection mask. Per-expert selection rates are printed and saved as
    a bar chart to ``kmote_expert_usage.png``.

    Args:
        model: KANMAMMOTE instance exposing ``mamba_blocks[0].k_mote``.
        config: Configuration object; this function reads
            ``input_feature_dim``, ``raw_event_feature_dim``,
            ``use_aux_features_router``, ``K_top``, ``device`` and ``dtype``.

    NOTE(review): the run recorded in this notebook stopped right after the
    banner printed below with "ValueError: not enough values to unpack
    (expected 5, got 3)". No unpacking in this function expects five values,
    so the error presumably originates inside the model's forward pass --
    confirm there.
    """
    print("\n--- Checking Expert Usage ---")

    # Dummy data for the forward pass.
    batch_size = 64
    seq_len = 100
    input_feature_dim = config.input_feature_dim

    # Uniformly spaced timestamps shared by every sequence: (B, L).
    timestamps = torch.linspace(0, 10, steps=seq_len).unsqueeze(0).repeat(batch_size, 1)
    features = torch.randn(batch_size, seq_len, input_feature_dim)  # (B, L, F)

    # Auxiliary features are only created when the router is configured to use them.
    auxiliary_features = None
    if config.use_aux_features_router and config.raw_event_feature_dim > 0:
        auxiliary_features = torch.randn(batch_size, seq_len, config.raw_event_feature_dim)

    timestamps = timestamps.to(config.device, dtype=config.dtype)
    features = features.to(config.device, dtype=config.dtype)
    if auxiliary_features is not None:
        auxiliary_features = auxiliary_features.to(config.device, dtype=config.dtype)

    # The model only hands back loss values, not the router's raw decisions, so
    # the first block's K-MOTE is queried directly for the selection mask.
    k_mote_instance = model.mamba_blocks[0].k_mote

    # Only the first timestamp (and matching auxiliary row) of each sequence is
    # fed to the router for this check.
    sample_tk_current = timestamps[:, 0].unsqueeze(1)  # (batch_size, 1)
    sample_aux_features = auxiliary_features[:, 0, :] if auxiliary_features is not None else None

    # Everything here is inference only: keep both calls under no_grad so no
    # autograd graph is built and the mask can be converted to NumPy safely.
    with torch.no_grad():
        # End-to-end forward pass; its outputs are not needed for the plot.
        model(timestamps, features, auxiliary_features)
        _, _, expert_selection_mask = k_mote_instance(sample_tk_current, sample_aux_features)

    expert_names = k_mote_instance.expert_names

    # Sum selections over the batch: presumably (batch_size, num_experts) ->
    # (num_experts,) -- TODO confirm the mask layout against K_MOTE.forward.
    total_selections_per_expert = expert_selection_mask.detach().sum(dim=0).cpu().numpy()

    num_samples_for_viz = batch_size
    # Share of samples in which each expert was picked, as a percentage.
    usage_percentages = (total_selections_per_expert / num_samples_for_viz) * 100

    print(f"K-MOTE Expert Usage (K_top={config.K_top}, from first block's K-MOTE):")
    for i, name in enumerate(expert_names):
        print(f"- {name}: {usage_percentages[i]:.2f}% selected on average per sample")

    # Plotting
    plt.figure(figsize=(9, 6))
    sns.barplot(x=expert_names, y=usage_percentages, palette="viridis")
    plt.title(f'K-MOTE Expert Usage Percentage (K_top={config.K_top})')
    plt.xlabel('Expert Type')
    plt.ylabel('Average Selection Count (% of samples where expert was chosen)')
    # Each bar is a per-expert rate, so scale the axis from the data (with at
    # least a 0-100% range) instead of multiplying by K_top, which squashes
    # the bars whenever K_top > 1.
    tallest_bar = float(usage_percentages.max()) if usage_percentages.size else 0.0
    plt.ylim(0, max(100.0, tallest_bar) + 10)
    plt.tight_layout()
    plt.savefig('kmote_expert_usage.png')
    print("Expert usage plot saved to kmote_expert_usage.png")
    plt.close()

def visualize_expert_functions(model: KANMAMMOTE, config: KANMAMMOTEConfig, num_test_points: int = 200):
    """
    Plot the response curve of every K-MOTE expert on a shared timestamp axis.

    The experts are taken from the K-MOTE module of the first Mamba block.
    Each expert is evaluated on an evenly spaced grid of timestamps, its
    output is averaged over the last dimension, and all curves are drawn on
    one figure saved as ``kmote_expert_functions.png``.

    Args:
        model: KANMAMMOTE instance exposing ``mamba_blocks[0].k_mote``.
        config: Configuration object; ``kan_grid_range``, ``device`` and
            ``dtype`` are read here.
        num_test_points: Number of timestamps in the evaluation grid.
    """
    print("\n--- Visualizing Individual Expert Functions ---")

    # K-MOTE of the first block is used as the representative instance.
    k_mote_module = model.mamba_blocks[0].k_mote
    expert_names = k_mote_module.expert_names

    # Evaluation grid: always covers at least [-10, 10], wider if the
    # configured KAN grid range extends further.
    lower_bound = min(config.kan_grid_range[0], -10.0)
    upper_bound = max(config.kan_grid_range[1], 10.0)
    grid = torch.linspace(lower_bound, upper_bound, num_test_points)
    test_timestamps = grid.unsqueeze(1).to(config.device, dtype=config.dtype)  # (num_test_points, 1)

    def mean_response(expert):
        """Evaluate one expert on the grid and average over its last dimension."""
        with torch.no_grad():
            raw = expert(test_timestamps)
            # MatrixKANLayer returns a tuple whose first entry is the main output.
            response = raw[0] if isinstance(expert, MatrixKANLayer) else raw
            return response.mean(dim=-1).cpu().numpy()

    plt.figure(figsize=(10, 7))
    for expert_name in expert_names:
        curve = mean_response(k_mote_module.experts[expert_name])
        x_axis = test_timestamps.cpu().numpy().flatten()
        plt.plot(x_axis, curve, label=f'{expert_name} Basis', linewidth=2)

    plt.title('K-MOTE Individual Expert Basis Functions (Mean across D_time)', fontsize=14)
    plt.xlabel('Timestamp Input', fontsize=12)
    plt.ylabel('Basis Function Output (Mean)', fontsize=12)
    plt.legend(fontsize=10)
    plt.grid(True, linestyle='--', alpha=0.7)
    # Zero reference line.
    plt.axhline(0, color='gray', linestyle='-', linewidth=0.8)
    plt.tight_layout()
    plt.savefig('kmote_expert_functions.png')
    print("Individual expert functions plot saved to kmote_expert_functions.png")
    plt.close()

# --- Main Execution Block ---
if __name__ == "__main__":
    # Script entry point: build a config, instantiate the model in eval mode,
    # then run the expert-usage check and the expert-function plot.

    # 1. Configuration, grouped by concern and merged into one constructor call.
    run_device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # NOTE(review): the four basis families configured below suggest four
    # experts; if so, K_top=4 selects every expert for every sample -- confirm
    # this is intended for a usage check.
    model_settings = dict(
        d_model=128,            # main model dimension
        D_time=64,              # output dimension of the K-MOTE experts
        num_layers=2,           # number of ContinuousMambaBlocks
        input_feature_dim=10,   # dummy input feature dimension
        output_dim_for_task=1,  # dummy output dimension of the final task
        K_top=4,                # number of experts picked by the MoE router
        use_aux_features_router=False,  # True if the router consumes auxiliary features
        raw_event_feature_dim=0,        # must be > 0 when the flag above is True
    )
    mamba_settings = dict(
        mamba_d_state=64,  # SSM state dimension
        mamba_headdim=32,  # per-head dimension; d_model must be divisible by it
    )
    kan_settings = dict(
        kan_noise_scale=0.1,
        kan_grid_range=[-1, 1],  # initial grid range for the spline (MatrixKANLayer)
    )
    fourier_settings = dict(
        fourier_k_prime=10,
        fourier_learnable_params=True,
    )
    rkhs_settings = dict(
        rkhs_num_mixture_components=10,
        rkhs_learnable_params=True,
    )
    wavelet_settings = dict(
        wavelet_num_wavelets=10,
        wavelet_mother_type='mexican_hat',  # or 'morlet'
        wavelet_learnable_params=True,
    )
    spline_settings = dict(
        spline_grid_size=8,
        spline_degree=3,
        kan_sp_trainable=True,  # spline scale is trainable
        kan_sb_trainable=True,  # base scale is trainable
    )
    # Not used by the checks below, but part of a complete configuration.
    training_settings = dict(
        learning_rate=1e-3,
        batch_size=32,
        sequence_length=100,
        device=run_device,
        dtype=torch.float32,
    )

    config = KANMAMMOTEConfig(
        **model_settings,
        **mamba_settings,
        **kan_settings,
        **fourier_settings,
        **rkhs_settings,
        **wavelet_settings,
        **spline_settings,
        **training_settings,
    )

    # Warn (but do not abort) when the head dimension does not divide d_model.
    leftover = config.d_model % config.mamba_headdim
    if leftover != 0:
        print(f"Warning: d_model ({config.d_model}) is not perfectly divisible by mamba_headdim ({config.mamba_headdim}). This might cause issues with Mamba's internal structure.")

    print(f"Model will run on device: {config.device}")

    # 2. Instantiate the KANMAMMOTE model and switch it to eval mode so the
    #    visualizations are taken in inference mode.
    model = KANMAMMOTE(config)
    model = model.to(config.device, dtype=config.dtype)
    model.eval()

    print("\nKAN-MAMMOTE Model (from src/models/kan_mammote.py) initialized successfully.")
    print(f"K-MOTE D_time (output dimension of individual experts): {config.D_time}")
    print(f"K-MOTE K_top (number of active experts chosen by router): {config.K_top}")

    # 3. Run the checks; each one saves a PNG into the working directory.
    check_expert_usage(model, config)
    visualize_expert_functions(model, config)

    print("\nAll checks complete. Please review 'kmote_expert_usage.png' and 'kmote_expert_functions.png' in your current directory.")

  from .autonotebook import tqdm as notebook_tqdm


Model will run on device: cuda

KAN-MAMMOTE Model (from src/models/kan_mammote.py) initialized successfully.
K-MOTE D_time (output dimension of individual experts): 64
K-MOTE K_top (number of active experts chosen by router): 4

--- Checking Expert Usage ---


ValueError: not enough values to unpack (expected 5, got 3)

In [None]:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from typing import List, Tuple, Dict, Optional
import sys
import os


# Now, import the modules from your KAN-MAMMOTE project
try:
    from src.utils.config import KANMAMMOTEConfig
    from src.models.kan_mammote import KANMAMMOTE # <<< IMPORTANT: Import the main KANMAMMOTE model
    from src.models.k_mote import K_MOTE # Used for type hinting and expert access
    from src.models.kan.MatrixKANLayer import MatrixKANLayer # Explicit import for type checking
except ImportError as e:
    print(f"FATAL ERROR: Could not import KAN-MAMMOTE modules. Please check your path setup and file structure.")
    print(f"Error details: {e}")
    sys.exit(1)

# Configure Matplotlib for saving figures (non-interactive backend)
plt.switch_backend('Agg')
sns.set_theme(style="whitegrid") # Use seaborn for better aesthetics

def get_expert_output_snapshot(k_mote_instance: K_MOTE, test_timestamps: torch.Tensor) -> Dict[str, np.ndarray]:
    """
    Evaluate every K-MOTE expert on a fixed timestamp grid.

    The module is switched to eval mode for the duration of the evaluation
    and put back into its previous mode afterwards, so taking a snapshot in
    the middle of training does not leave it in eval mode.

    Args:
        k_mote_instance: K-MOTE module exposing ``expert_names`` and an
            ``experts`` mapping from name to expert module.
        test_timestamps: Timestamps passed to each expert; callers in this
            file supply a (num_test_points, 1) tensor.

    Returns:
        Mapping from expert name to a NumPy array holding that expert's
        output averaged over its last dimension.
    """
    snapshot = {}

    # Remember the current mode so it can be restored on exit.
    # Assumes K_MOTE is a torch.nn.Module (it already exposes .eval()).
    was_training = k_mote_instance.training
    k_mote_instance.eval()

    try:
        with torch.no_grad():  # visualization only, no gradients needed
            for name in k_mote_instance.expert_names:
                expert_module = k_mote_instance.experts[name]
                expert_output = expert_module(test_timestamps)

                # MatrixKANLayer returns a tuple (y, preacts, postacts, postspline);
                # only the main output y is plotted.
                if isinstance(expert_module, MatrixKANLayer):
                    expert_output = expert_output[0]

                # Collapse the last dimension to get one curve per expert.
                snapshot[name] = expert_output.mean(dim=-1).cpu().numpy()
    finally:
        k_mote_instance.train(was_training)

    return snapshot

# --- Main Execution Block ---
if __name__ == "__main__":
    # Script entry point: train the model briefly on synthetic data and record
    # how each K-MOTE expert's response curve changes from epoch to epoch.

    # 1. Configuration for the KAN-MAMMOTE model
    config = KANMAMMOTEConfig(
        d_model=128,          # Main model dimension
        D_time=64,            # Output dimension of K-MOTE experts
        num_layers=2,         # Number of ContinuousMambaBlocks
        input_feature_dim=10, # Dummy input feature dimension
        output_dim_for_task=1, # Dummy output dimension for the final task
        K_top=2,              # Number of top experts selected by MoE router
        use_aux_features_router=False, # Set to True if K-MOTE router uses auxiliary_features
        raw_event_feature_dim=0, # Must be > 0 if use_aux_features_router is True

        # Mamba2 specific parameters
        mamba_d_state=64,     # State dimension for Mamba SSM
        mamba_headdim=32,     # Dimension per head for Mamba. d_model must be divisible by headdim.

        # KANLayer and Basis Function Parameters (relevant for K-MOTE experts)
        kan_noise_scale=0.1,
        kan_grid_range=[-1, 1], # Initial grid range for Spline (MatrixKANLayer)

        # Fourier Basis specific
        fourier_k_prime=10,
        fourier_learnable_params=True,

        # RKHS / Gaussian Kernel Basis specific
        rkhs_num_mixture_components=10,
        rkhs_learnable_params=True,

        # Wavelet Basis specific
        wavelet_num_wavelets=10,
        wavelet_mother_type='mexican_hat', # or 'morlet'
        wavelet_learnable_params=True,

        # Spline Basis (MatrixKANLayer) specific
        spline_grid_size=8,
        spline_degree=3,
        kan_sp_trainable=True, # Scale spline trainable
        kan_sb_trainable=True, # Scale base trainable

        # Training parameters for dummy loop
        learning_rate=1e-3,
        num_epochs=5,         # Epochs to train; one snapshot before training plus one per epoch
        batch_size=32,
        sequence_length=100,
        device='cuda' if torch.cuda.is_available() else 'cpu',
        dtype=torch.float32
    )

    # Warn (but do not abort) when the head dimension does not divide d_model.
    if config.d_model % config.mamba_headdim != 0:
        print(f"Warning: d_model ({config.d_model}) is not perfectly divisible by mamba_headdim ({config.mamba_headdim}). This might cause issues with Mamba's internal structure.")

    print(f"Model will run on device: {config.device}")

    # 2. Instantiate the model, optimizer, and loss
    model = KANMAMMOTE(config).to(config.device, dtype=config.dtype)
    optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)
    criterion = nn.MSELoss()

    # Snapshots are taken at epoch 0 (before training) and after every epoch.
    epochs_to_capture = list(range(config.num_epochs + 1))
    num_test_points_viz = 200 # Number of points for plotting functions

    # Fixed test timestamps so curves from different epochs are comparable.
    # The grid always covers at least [-10, 10].
    test_timestamps_viz = torch.linspace(
        min(config.kan_grid_range[0], -10.0),
        max(config.kan_grid_range[1], 10.0),
        num_test_points_viz
    ).unsqueeze(1).to(config.device, dtype=config.dtype)

    expert_function_snapshots = {} # {epoch: {expert_name: output_data}}

    # K-MOTE of the first Mamba block is used as the representative instance.
    k_mote_instance_for_viz = model.mamba_blocks[0].k_mote
    expert_names = k_mote_instance_for_viz.expert_names

    # --- Capture initial state (Epoch 0) ---
    print("\n--- Capturing Expert Functions at Epoch 0 (Initial State) ---")
    expert_function_snapshots[0] = get_expert_output_snapshot(k_mote_instance_for_viz, test_timestamps_viz)
    print("Captured expert functions at Epoch 0.")

    # --- Dummy Training Loop ---
    print("\n--- Starting Dummy Training for Expert Evolution ---")
    num_batches_per_epoch = 10 # Simulate multiple batches per epoch

    for epoch in range(1, config.num_epochs + 1):
        # The snapshot helper switches K-MOTE to eval mode, so training mode
        # is (re-)enabled at the start of every epoch.
        model.train()
        total_loss = 0.0

        for batch_idx in range(num_batches_per_epoch):
            # Fresh dummy data for each batch
            timestamps_batch = torch.linspace(
                0, 10, steps=config.sequence_length
            ).unsqueeze(0).repeat(config.batch_size, 1).to(config.device, dtype=config.dtype)
            features_batch = torch.randn(
                config.batch_size, config.sequence_length, config.input_feature_dim
            ).to(config.device, dtype=config.dtype)

            auxiliary_features_batch = None
            if config.use_aux_features_router and config.raw_event_feature_dim > 0:
                auxiliary_features_batch = torch.randn(
                    config.batch_size, config.sequence_length, config.raw_event_feature_dim
                ).to(config.device, dtype=config.dtype)

            # Synthetic regression target built from timestamps and features,
            # repeated along the last axis to match output_dim_for_task.
            target_output_batch = (
                torch.sin(timestamps_batch * 0.5) * 2 + torch.cos(features_batch.mean(dim=-1)) + 5
            ).unsqueeze(-1).repeat(1, 1, config.output_dim_for_task).to(config.device, dtype=config.dtype)

            optimizer.zero_grad()
            model_output, regularization_losses = model(timestamps_batch, features_batch, auxiliary_features_batch)

            task_loss = criterion(model_output, target_output_batch)

            # All regularization terms are added to the task loss unweighted.
            total_regularization_loss = sum(regularization_losses.values())

            loss = task_loss + total_regularization_loss
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        avg_loss = total_loss / num_batches_per_epoch
        print(f"Epoch {epoch}/{config.num_epochs}, Avg Loss: {avg_loss:.4f}")

        # --- Capture state after this epoch ---
        print(f"Capturing expert functions after Epoch {epoch}...")
        expert_function_snapshots[epoch] = get_expert_output_snapshot(k_mote_instance_for_viz, test_timestamps_viz)
        print(f"Snapshot taken for Epoch {epoch}.")

    print("\n--- Dummy Training Complete ---")

    # --- Plotting Expert Function Evolution ---
    print("\n--- Generating Expert Function Evolution Plot ---")

    # Size the subplot grid from the number of experts instead of assuming
    # exactly four: two columns, as many rows as needed. Four experts still
    # give the original 2x2 layout at 18x12 inches.
    num_experts = len(expert_names)
    ncols = 2 if num_experts > 1 else 1
    nrows = max(1, -(-num_experts // ncols))  # ceiling division
    fig, axes = plt.subplots(
        nrows=nrows, ncols=ncols, figsize=(9 * ncols, 6 * nrows), squeeze=False
    )
    axes = axes.flatten() # One flat sequence of axes, one per expert

    # One color per captured epoch
    colors = sns.color_palette("viridis", n_colors=len(epochs_to_capture))

    # The x-axis is identical for every curve; convert it once.
    x_values = test_timestamps_viz.cpu().numpy().flatten()

    for ax, expert_name in zip(axes, expert_names):
        ax.set_title(f'{expert_name} Basis Function Evolution', fontsize=14)
        ax.set_xlabel('Timestamp Input', fontsize=12)
        ax.set_ylabel('Basis Function Output (Mean)', fontsize=12)
        ax.grid(True, linestyle='--', alpha=0.7)
        ax.axhline(0, color='gray', linestyle='-', linewidth=0.8) # Zero reference line

        for i, epoch in enumerate(epochs_to_capture):
            ax.plot(x_values,
                    expert_function_snapshots[epoch][expert_name],
                    label=f'Epoch {epoch}',
                    color=colors[i],
                    linewidth=2 if epoch == 0 else 1.5, # Initial state drawn thicker
                    alpha=1.0 if epoch == 0 else 0.8)   # ...and fully opaque
        ax.legend(loc='best', fontsize=10)

    # Hide any leftover panel (e.g. an odd number of experts).
    for unused_ax in axes[num_experts:]:
        unused_ax.set_visible(False)

    plt.tight_layout()
    plt.savefig('kmote_expert_functions_evolution.png')
    print("Expert function evolution plot saved to kmote_expert_functions_evolution.png")
    plt.close()

    print("\nAll visualization complete. Please review 'kmote_expert_functions_evolution.png' in your current directory.")

Model will run on device: cuda

--- Capturing Expert Functions at Epoch 0 (Initial State) ---
Captured expert functions at Epoch 0.

--- Starting Dummy Training for Expert Evolution ---
Epoch 1/5, Avg Loss: 10.4433
Capturing expert functions after Epoch 1...
Snapshot taken for Epoch 1.
Epoch 2/5, Avg Loss: 2.3719
Capturing expert functions after Epoch 2...
Snapshot taken for Epoch 2.
Epoch 3/5, Avg Loss: 2.1017
Capturing expert functions after Epoch 3...
Snapshot taken for Epoch 3.
Epoch 4/5, Avg Loss: 2.0524
Capturing expert functions after Epoch 4...
Snapshot taken for Epoch 4.
Epoch 5/5, Avg Loss: 2.0449
Capturing expert functions after Epoch 5...
Snapshot taken for Epoch 5.

--- Dummy Training Complete ---

--- Generating Expert Function Evolution Plot ---
Expert function evolution plot saved to kmote_expert_functions_evolution.png

All visualization complete. Please review 'kmote_expert_functions_evolution.png' in your current directory.
