In [2]:
import torch
print(torch.__version__)
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence
import csv
import os
from tqdm import tqdm

2.6.0+cu126


# KAN-MAMMOTE as Plug-and-Play Time Encoder

The code below shows how to replace LeTE with KAN-MAMMOTE for time encoding in the EventBasedMNIST task. KAN-MAMMOTE can be used as a drop-in replacement that generates richer, adaptive time embeddings.

# Comparison: LeTE vs KAN-MAMMOTE

## Key Differences:

### **LeTE (Original)**
- **Fixed architecture**: Fourier + Spline with fixed mixing ratio (p=0.5)
- **Manual hyperparameter tuning**: Need to set `p` manually
- **Limited adaptability**: Cannot learn optimal mixing for each timestamp

### **KAN-MAMMOTE (Replacement)**
- **Adaptive architecture**: Router dynamically selects best experts for each timestamp
- **Richer expert diversity**: Fourier + Spline + Gaussian Kernel + Wavelet experts
- **Automatic optimization**: No manual hyperparameter tuning for expert mixing
- **Better temporal modeling**: Continuous-time Mamba for sequence modeling

## Expected Benefits:
1. **Better accuracy** due to adaptive expert selection
2. **More interpretable** temporal patterns through expert usage analysis
3. **Plug-and-play compatibility** with existing LSTM architectures
4. **Scalable** to more complex temporal datasets

## Usage as Feature Extractor:
KAN-MAMMOTE works perfectly as a feature extractor - just replace the time encoder and the model learns better temporal representations automatically!

In [1]:
# Step 1: Import KAN-MAMMOTE modules
import sys
sys.path.append('../../')  # Add project root to path

from src.models.kan_mammote import KAN_MAMOTE_Model
from src.utils.config import KANMAMOTEConfig
from src.losses.simple_losses import recommended_feature_loss

# Step 2: Create KAN-MAMMOTE configuration
config = KANMAMOTEConfig(
    # Time embedding dimensions
    D_time=32,  # Match LeTE's embedding_dim
    D_time_per_expert=8,  # 32/4 = 8 per expert
    num_experts=4,  # Fourier, Spline, Gaussian, Wavelet
    
    # MoE settings
    K_top=2,  # Use top-2 experts
    use_aux_features_router=False,  # Just timestamps, no auxiliary features
    
    # Raw event features (in this case, just timestamps)
    raw_event_feature_dim=0,  # No additional features beyond timestamps
    
    # Mamba block settings (for sequential processing)
    hidden_dim_mamba=32,  # Match output dimension
    state_dim_mamba=16,
    
    # Regularization (can be adjusted)
    lambda_moe_load_balancing=0.01,
    lambda_sobolev_l2=0.0,  # Start without additional regularization
    lambda_total_variation_l1=0.0
)

# Step 3: Define KAN-MAMMOTE-based Classifier
class KANMAMOTELSTMClassifier(nn.Module):
    def __init__(self, input_size=784, embedding_dim=32, hidden_dim=128, num_classes=10):
        super(KANMAMOTELSTMClassifier, self).__init__()
        self.embedding_dim = embedding_dim
        
        # KAN-MAMMOTE time encoder (replaces LeTE)
        self.time_encoder = KAN_MAMOTE_Model(config)
        
        # LSTM and classifier (same as before)
        self.lstm = nn.LSTM(input_size=embedding_dim, hidden_size=hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, num_classes)
        
    def forward(self, x, lengths):
        # Convert pixel indices to timestamps (normalize to [0,1])
        timestamps = x.float() / 784.0  # Normalize pixel indices
        timestamps = timestamps.unsqueeze(-1)  # (batch, seq_len, 1)
        
        # Create dummy event features (since we only have timestamps)
        batch_size, seq_len = x.shape
        event_features = torch.zeros(batch_size, seq_len, 0, device=x.device)  # Empty features
        
        # Get time embeddings from KAN-MAMMOTE
        time_embeddings, moe_info = self.time_encoder(timestamps, event_features)
        # time_embeddings shape: (batch, seq_len, embedding_dim)
        
        # Continue with LSTM processing (same as before)
        packed = pack_padded_sequence(time_embeddings, lengths.cpu(), batch_first=True, enforce_sorted=False)
        _, (h_n, c_n) = self.lstm(packed)
        h_n = h_n[-1]  # (batch, hidden_dim)
        out = self.fc(h_n)  # (batch, num_classes)
        
        return out, moe_info  # Return MoE info for potential regularization

Using MLP fallback.


TypeError: KANMAMOTEConfig.__init__() got an unexpected keyword argument 'D_time_per_expert'

In [None]:
# Step 4: Train with KAN-MAMMOTE (Feature Extraction Mode)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Create KAN-MAMMOTE model
kan_mammote_model = KANMAMOTELSTMClassifier(
    input_size=784, 
    embedding_dim=32, 
    hidden_dim=128, 
    num_classes=10
).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(kan_mammote_model.parameters(), lr=1e-3)

# Modified evaluation function
def evaluate_kan_mammote(model, data_loader, device, criterion):
    model.eval()
    total_loss = 0.0
    total_correct = 0
    total_samples = 0
    with torch.no_grad():
        for padded_events, lengths, labels in data_loader:
            padded_events = padded_events.float().to(device)
            lengths = lengths.to(device)
            labels = labels.to(device)
            
            outputs, moe_info = model(padded_events, lengths)
            loss = criterion(outputs, labels)

            total_loss += loss.item() * labels.size(0)
            preds = outputs.argmax(dim=1)
            total_correct += (preds == labels).sum().item()
            total_samples += labels.size(0)
    
    avg_loss = total_loss / total_samples
    acc = total_correct / total_samples
    return avg_loss, acc

# Training loop for KAN-MAMMOTE
print("Training KAN-MAMMOTE model...")
num_epochs = 50  # Fewer epochs for demo

for epoch in range(num_epochs):
    kan_mammote_model.train()
    total_loss = 0.0
    total_correct = 0
    total_samples = 0

    for batch_idx, (padded_events, lengths, labels) in enumerate(tqdm(train_loader)):
        padded_events = padded_events.float().to(device)
        lengths = lengths.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        
        # Forward pass with KAN-MAMMOTE
        outputs, moe_info = kan_mammote_model(padded_events, lengths)
        
        # Simple loss (just main task + minimal MoE regularization)
        main_loss = criterion(outputs, labels)
        total_loss_with_reg, loss_dict = recommended_feature_loss(main_loss, moe_info)
        
        total_loss_with_reg.backward()
        optimizer.step()

        total_loss += main_loss.item() * labels.size(0)
        preds = outputs.argmax(dim=1)
        total_correct += (preds == labels).sum().item()
        total_samples += labels.size(0)

    avg_loss = total_loss / total_samples
    acc = total_correct / total_samples

    # Testing
    test_loss, test_acc = evaluate_kan_mammote(kan_mammote_model, test_loader, device, criterion)

    print(f"KAN-MAMMOTE Epoch [{epoch+1}/{num_epochs}] "
          f"Train Loss: {avg_loss:.4f}, Train Accuracy: {acc:.4f}, "
          f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.4f}")

print("KAN-MAMMOTE training completed!")

In [4]:
#print cuda available
import torch
if torch.cuda.is_available():
    print("CUDA is available")