# MNIST Event-Based Classification with KAN-MOTE

This notebook demonstrates how to use KAN-MOTE (Kernel-Mixture-of-Time-Experts) as a time encoder for event-based MNIST classification, replacing the original LeTE (Learnable Time Encoding) approach.

## Key Features:
- **Adaptive Time Encoding**: KAN-MOTE dynamically selects the best combination of temporal experts
- **Multiple Expert Types**: Fourier, Spline, Gaussian Kernel, and Wavelet experts
- **Top-K Expert Selection**: Efficient computation by activating only the most relevant experts
- **Plug-and-Play**: Direct replacement for LeTE in existing architectures

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence
import csv
import os
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt

# Add project root to path for imports
import sys
sys.path.append('.')

# Import KAN-MOTE components
from src.models.k_mote import K_MOTE
from src.utils.config import KANMAMOTEConfig
from src.losses.simple_losses import recommended_feature_loss

# Report the runtime environment (PyTorch build and GPU availability).
print(f"PyTorch version: {torch.__version__}")
_cuda_available = torch.cuda.is_available()
print(f"CUDA available: {_cuda_available}")
if _cuda_available:
    print(f"CUDA device: {torch.cuda.get_device_name()}")

ModuleNotFoundError: No module named 'matplotlib'

## Dataset Preparation

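We convert each MNIST image into an event sequence. Every pixel whose intensity exceeds a threshold becomes an event; the threshold is 0.1, applied after `ToTensor` scales pixels to [0, 1]. The pixel's flattened row-major position (0–783) serves as its timestamp, so events are ordered by position.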

In [None]:
class EventBasedMNIST(Dataset):
    """
    Converts MNIST images to event sequences.

    Each event is a pixel whose value is strictly greater than ``threshold``.
    Its "timestamp" is the pixel's flattened (row-major) index, so a 28x28
    image yields indices in ``[0, 783]``. Events are stored in ascending order.

    ``threshold`` is compared against whatever values ``transform`` produces
    (e.g. ``[0, 1]`` after ``transforms.ToTensor()``; raw ``0-255`` if no
    transform is given).

    An image with no pixel above the threshold gets a single dummy event at
    index 0. This guarantees every sequence has length >= 1, which
    ``pack_padded_sequence`` requires.
    """
    def __init__(self, root, train=True, threshold=0.1, transform=None, download=True):
        self.root = root
        self.train = train
        self.threshold = threshold
        self.transform = transform

        # Load MNIST dataset
        self.data = datasets.MNIST(root=self.root, train=self.train,
                                   transform=self.transform, download=download)

        # Convert images to event sequences
        self.event_data = []
        self.labels = []
        self.process_data()

    @staticmethod
    def _image_to_events(img, threshold):
        """Return the ascending 1-D int64 tensor of flattened pixel indices above ``threshold``.

        Falls back to ``tensor([0])`` (a single dummy event) when no pixel qualifies.
        """
        if isinstance(img, torch.Tensor):
            img_flat = img.reshape(-1)
        else:
            # e.g. a PIL image when no transform is supplied
            img_flat = torch.as_tensor(np.asarray(img)).reshape(-1)

        # nonzero() returns shape (num_events, 1). flatten() always gives a 1-D
        # result: a single event stays shape (1,), and "no events" stays shape
        # (0,). squeeze() would instead collapse one event to a 0-d tensor and
        # leave the empty case undetected.
        events = torch.nonzero(img_flat > threshold).flatten()
        if events.numel() == 0:
            return torch.zeros(1, dtype=torch.long)
        # nonzero() already yields ascending indices; the sort just makes the
        # temporal ordering explicit.
        return torch.sort(events).values

    def process_data(self):
        """Convert each (image, label) pair in ``self.data`` to an event sequence."""
        for img, label in tqdm(self.data, desc="Processing MNIST to events"):
            self.event_data.append(self._image_to_events(img, self.threshold))
            self.labels.append(label)

    def __len__(self):
        return len(self.event_data)

    def __getitem__(self, idx):
        return self.event_data[idx], self.labels[idx]

def custom_collate_fn(batch):
    """Collate ``(events, label)`` pairs into padded tensors for the LSTM.

    Args:
        batch: sequence of ``(events, label)`` pairs, where ``events`` is a
            tensor whose first dimension is the number of events and ``label``
            is a class index.

    Returns:
        ``(padded_events, lengths, labels)``:
        - ``padded_events``: events right-padded with 0 to the longest
          sequence, batch-first.
        - ``lengths``: each sequence's original length (int64).
        - ``labels``: the class labels (int64).
    """
    pairs = list(batch)
    sequences = [events for events, _ in pairs]
    targets = [label for _, label in pairs]
    seq_lengths = [seq.shape[0] for seq in sequences]

    lengths_out = torch.tensor(seq_lengths, dtype=torch.long)
    labels_out = torch.tensor(targets, dtype=torch.long)
    labels_out, lengths_out = torch.tensor(targets, dtype=torch.long), lengths_out

    # Pad with 0; the padded positions are skipped later by pack_padded_sequence.
    padded = pad_sequence(sequences, batch_first=True, padding_value=0)
    return padded, lengths_out, labels_out

## KAN-MOTE Configuration

Setting up the configuration for KAN-MOTE with appropriate parameters for MNIST event sequences.

In [None]:
# Step 1: Import KAN-MOTE modules (updated for DyGMamba-style)
import sys
sys.path.append('.')  # Add project root to path

from src.models.k_mote import K_MOTE
from src.utils.config import KANMAMOTEConfig
from src.losses.simple_losses import recommended_feature_loss

# Step 2: K-MOTE hyper-parameters, grouped by concern and unpacked into one config.
_time_embedding_settings = dict(
    D_time=32,              # total time-embedding width; the classifier reshapes to embedding_dim=32
    D_time_per_expert=8,    # 4 experts x 8 = 32
    num_experts=4,          # Fourier, Spline, Gaussian, Wavelet
)
_routing_settings = dict(
    K_top=2,                         # route each timestamp to its top-2 experts
    use_aux_features_router=False,   # router sees timestamps only
    raw_event_feature_dim=0,         # no event features beyond the timestamp
)
_mamba_settings = dict(
    hidden_dim_mamba=32,
    state_dim_mamba=16,
    num_mamba_layers=2,
    gamma=0.5,              # time-difference scaling factor
)
_regularization_settings = dict(
    lambda_moe_load_balancing=0.01,
    lambda_sobolev_l2=0.0,          # extra smoothness regularisation disabled for now
    lambda_total_variation_l1=0.0,
)

config = KANMAMOTEConfig(
    **_time_embedding_settings,
    **_routing_settings,
    **_mamba_settings,
    **_regularization_settings,
)
del _time_embedding_settings, _routing_settings, _mamba_settings, _regularization_settings

# Step 3: Define KAN-MOTE-based Classifier (simplified, no full KAN-MAMMOTE)
class KANMOTELSTMClassifier(nn.Module):
    """Simplified classifier: K-MOTE time encoding -> single-layer LSTM -> linear head.

    NOTE: the "Model Architecture" cell later in this notebook redefines
    ``KANMOTELSTMClassifier``; that later definition is the one instantiated
    for training.

    Args:
        input_size: number of pixel positions. Event indices are divided by
            this value, mapping them into [0, 1).
        embedding_dim: per-timestep width used to reshape K-MOTE's output.
            Presumably it must equal ``config.D_time`` — TODO confirm.
        hidden_dim: LSTM hidden size.
        num_classes: number of output classes.
        config: K-MOTE configuration. When omitted, the notebook-level
            ``config`` is used, which keeps the original call signature working.

    Raises:
        ValueError: if no config is passed and none exists at module level.
    """

    def __init__(self, input_size=784, embedding_dim=32, hidden_dim=128, num_classes=10, config=None):
        # Zero-argument super() stays correct even after the name
        # KANMOTELSTMClassifier is rebound by the later redefinition.
        super().__init__()
        if config is None:
            # The parameter shadows the module-level name, so look it up explicitly.
            config = globals().get("config")
            if config is None:
                raise ValueError(
                    "No K-MOTE `config` was passed and none is defined at module level."
                )
        self.embedding_dim = embedding_dim
        self.input_size = input_size

        # KAN-MOTE time encoder (just K-MOTE, not full KAN-MAMMOTE)
        self.time_encoder = K_MOTE(config)

        self.lstm = nn.LSTM(input_size=embedding_dim, hidden_size=hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, num_classes)

    def forward(self, x, lengths):
        """Classify padded event sequences.

        Args:
            x: (batch, seq_len) event indices.
            lengths: (batch,) true sequence lengths.

        Returns:
            ``(logits, moe_info)``:
            - ``logits``: shape (batch, num_classes).
            - ``moe_info``: per-position routing tensors, reshaped to
              (batch, seq_len, -1).
        """
        batch_size, seq_len = x.shape

        # Normalise indices, then flatten to the (batch*seq_len, 1) layout K-MOTE expects.
        timestamps_flat = (x.float() / float(self.input_size)).reshape(batch_size * seq_len, 1)

        # Only timestamps are available, so no auxiliary event features are passed.
        time_embeddings, expert_weights, expert_mask = self.time_encoder(timestamps_flat, None)
        time_embeddings = time_embeddings.view(batch_size, seq_len, self.embedding_dim)

        # Packing makes the LSTM skip padded positions.
        packed = pack_padded_sequence(time_embeddings, lengths.cpu(), batch_first=True, enforce_sorted=False)
        _, (h_n, _) = self.lstm(packed)
        out = self.fc(h_n[-1])  # final hidden state of the last LSTM layer -> (batch, num_classes)

        # MoE info for regularization.
        # NOTE(review): 'router_logits' holds the routing weights, not raw logits;
        # confirm this is what the loss expects.
        moe_info = {
            'expert_weights': expert_weights.view(batch_size, seq_len, -1),
            'expert_mask': expert_mask.view(batch_size, seq_len, -1),
            'router_logits': expert_weights.view(batch_size, seq_len, -1)
        }

        return out, moe_info

## Model Architecture

LSTM classifier using KAN-MOTE for time encoding, replacing the original LeTE approach.

In [None]:
class KANMOTELSTMClassifier(nn.Module):
    """
    LSTM classifier using KAN-MOTE for adaptive time encoding.

    Architecture:
    1. KAN-MOTE time encoder: converts pixel positions to rich temporal embeddings
    2. LSTM: processes the sequence of time embeddings
    3. Classifier: dropout + final linear classification layer

    Args:
        input_size: number of pixel positions. Indices are divided by
            ``input_size - 1``, mapping them into [0, 1].
        embedding_dim: per-timestep width used to reshape K-MOTE's output.
            Presumably it must equal ``config.D_time`` — TODO confirm.
        hidden_dim: LSTM hidden size.
        num_classes: number of output classes.
        config: K-MOTE configuration (required despite the ``None`` default).

    Raises:
        ValueError: if ``config`` is None.
    """
    def __init__(self, input_size=784, embedding_dim=32, hidden_dim=128, num_classes=10, config=None):
        super().__init__()
        if config is None:
            raise ValueError("KANMOTELSTMClassifier requires a K-MOTE `config`.")
        self.embedding_dim = embedding_dim
        self.input_size = input_size

        # KAN-MOTE time encoder
        self.time_encoder = K_MOTE(config)

        # LSTM for sequence processing. This is a single-layer LSTM, and
        # nn.LSTM's `dropout` only acts *between* stacked layers. The former
        # dropout=0.1 was therefore a no-op that only raised a UserWarning.
        # Regularisation comes from self.dropout on the final hidden state.
        self.lstm = nn.LSTM(input_size=embedding_dim, hidden_size=hidden_dim,
                            batch_first=True)

        # Classification head
        self.fc = nn.Linear(hidden_dim, num_classes)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x, lengths):
        """
        Forward pass

        Args:
            x: (batch_size, seq_len) - event positions/timestamps
            lengths: (batch_size,) - sequence lengths

        Returns:
            logits: (batch_size, num_classes)
            moe_info: dictionary with MoE information for regularization. It
                covers every (batch, seq_len) position, including padded ones,
                which the LSTM skips via packing.
        """
        batch_size, seq_len = x.shape

        # Normalize timestamps to [0, 1] range
        timestamps = x.float() / (self.input_size - 1)
        timestamps = timestamps.view(-1, 1)  # (batch * seq_len, 1)

        # Get time embeddings from KAN-MOTE; only timestamps, no auxiliary features.
        time_embeddings, expert_weights, expert_mask = self.time_encoder(
            timestamp_input=timestamps,
            auxiliary_features=None
        )

        # Reshape back to sequence format
        time_embeddings = time_embeddings.view(batch_size, seq_len, self.embedding_dim)

        # Process through LSTM; packing excludes padded positions.
        packed = pack_padded_sequence(time_embeddings, lengths.cpu(),
                                      batch_first=True, enforce_sorted=False)
        _, (h_n, _) = self.lstm(packed)

        # Use final hidden state for classification
        h_final = h_n[-1]  # (batch_size, hidden_dim)
        h_final = self.dropout(h_final)
        logits = self.fc(h_final)  # (batch_size, num_classes)

        # Prepare MoE info for potential regularization.
        # NOTE(review): 'router_logits' holds the routing weights, not raw
        # logits; confirm this is what recommended_feature_loss expects.
        moe_info = {
            'expert_weights': expert_weights.view(batch_size, seq_len, -1),
            'expert_mask': expert_mask.view(batch_size, seq_len, -1),
            'router_logits': expert_weights.view(batch_size, seq_len, -1)  # For load balancing
        }

        return logits, moe_info

## Training Setup

Setting up the training pipeline with data loaders, model, and optimization.

In [None]:
# Seed the torch and numpy RNGs for reproducibility.
torch.manual_seed(42)
np.random.seed(42)

# Both splits share the same event-conversion settings; only `train` differs.
transform = transforms.ToTensor()
_split_kwargs = {"root": "./data", "threshold": 0.1, "transform": transform, "download": True}
train_dataset = EventBasedMNIST(train=True, **_split_kwargs)
test_dataset = EventBasedMNIST(train=False, **_split_kwargs)

# Loaders: custom_collate_fn pads each batch's variable-length sequences.
batch_size = 256
_loader_kwargs = {"batch_size": batch_size, "collate_fn": custom_collate_fn, "num_workers": 2}
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_kwargs)
del _split_kwargs, _loader_kwargs

print(f"Training samples: {len(train_dataset)}")
print(f"Test samples: {len(test_dataset)}")
print(f"Batch size: {batch_size}")

# Distribution of events-per-image on the training split.
sample_lengths = [len(events) for events, _ in train_dataset]
_len_mean, _len_std = np.mean(sample_lengths), np.std(sample_lengths)
_len_min, _len_max = np.min(sample_lengths), np.max(sample_lengths)
print(f"Event sequence lengths - Mean: {_len_mean:.1f}, "
      f"Std: {_len_std:.1f}, "
      f"Min: {_len_min}, Max: {_len_max}")

In [None]:
# Model setup: use the GPU when one is available.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# embedding_dim here is the width the classifier reshapes K-MOTE output to.
model = KANMOTELSTMClassifier(
    input_size=784,
    embedding_dim=32,
    hidden_dim=128,
    num_classes=10,
    config=config,
)
model = model.to(device)

# Loss, optimizer, and a scheduler that halves the LR when the loss passed to
# scheduler.step() stops improving (patience=5 epochs).
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)

# Parameter counts
_all_params = list(model.parameters())
total_params = sum(p.numel() for p in _all_params)
trainable_params = sum(p.numel() for p in _all_params if p.requires_grad)
del _all_params
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

## Training and Evaluation Functions

In [None]:
def evaluate_model(model, data_loader, device, criterion):
    """Compute sample-weighted mean loss and accuracy of ``model`` over ``data_loader``.

    The model is switched to eval mode and run without gradient tracking.

    Args:
        model: module called as ``model(padded_events, lengths)`` that returns
            ``(logits, moe_info)``.
        data_loader: iterable of ``(padded_events, lengths, labels)`` batches.
        device: device the batch tensors are moved to.
        criterion: loss function applied to ``(logits, labels)``.

    Returns:
        ``(avg_loss, accuracy)`` as Python floats.
    """
    model.eval()
    loss_sum, n_correct, n_seen = 0.0, 0, 0

    with torch.no_grad():
        for batch_events, batch_lengths, batch_labels in tqdm(data_loader, desc="Evaluating"):
            batch_events, batch_lengths, batch_labels = (
                t.to(device) for t in (batch_events, batch_lengths, batch_labels)
            )

            logits, _ = model(batch_events, batch_lengths)
            n_batch = batch_labels.size(0)

            # Weight by batch size so a short final batch does not skew the mean.
            loss_sum += criterion(logits, batch_labels).item() * n_batch
            n_correct += (logits.argmax(dim=1) == batch_labels).sum().item()
            n_seen += n_batch

    return loss_sum / n_seen, n_correct / n_seen

def analyze_expert_usage(model, data_loader, device, num_batches=10, expert_names=None):
    """Report how often each K-MOTE expert is selected on real (non-padded) events.

    Runs ``model`` in eval mode over at most ``num_batches`` batches of
    ``data_loader``. It accumulates ``moe_info['expert_mask']`` only over the
    valid positions of each sequence (index < ``lengths[b]``). Padded positions
    are excluded: counting them would inflate whichever expert handles the
    padding value and deflate every percentage.

    Args:
        model: module returning ``(logits, moe_info)``, where
            ``moe_info['expert_mask']`` has shape (batch, seq_len, num_experts).
        data_loader: iterable of ``(padded_events, lengths, labels)`` batches.
        device: device the inputs are moved to.
        num_batches: maximum number of batches to inspect.
        expert_names: optional display names, one per expert. Defaults to
            Fourier/Spline/Gaussian/Wavelet. Generic names are used when the
            count does not match the number of experts.

    Returns:
        numpy array with, per expert, the percentage of valid events whose mask
        entry is set. If several experts are selected per event (top-K
        routing), the values sum to more than 100. Returns all zeros when no
        events were inspected.
    """
    if expert_names is None:
        expert_names = ['Fourier', 'Spline', 'Gaussian', 'Wavelet']

    model.eval()
    expert_usage = None
    total_tokens = 0

    with torch.no_grad():
        for i, (padded_events, lengths, _labels) in enumerate(data_loader):
            if i >= num_batches:
                break

            padded_events = padded_events.to(device)
            lengths = lengths.to(device)

            _, moe_info = model(padded_events, lengths)
            expert_mask = moe_info['expert_mask']  # (batch, seq_len, num_experts)

            # (batch, seq_len) boolean mask: True for real events, False for padding.
            positions = torch.arange(expert_mask.shape[1], device=expert_mask.device)
            valid = positions.unsqueeze(0) < lengths.to(expert_mask.device).unsqueeze(1)

            usage = expert_mask[valid].float().sum(dim=0).cpu()  # (num_experts,)
            expert_usage = usage if expert_usage is None else expert_usage + usage
            total_tokens += int(valid.sum().item())

    if expert_usage is None:
        # No batches were inspected; size the result from the names.
        expert_usage = torch.zeros(len(expert_names))

    # Normalize to get percentages (guard against an empty inspection).
    if total_tokens > 0:
        expert_usage_pct = expert_usage / total_tokens * 100
    else:
        expert_usage_pct = torch.zeros_like(expert_usage)

    if len(expert_names) != expert_usage_pct.numel():
        expert_names = [f"Expert {k}" for k in range(expert_usage_pct.numel())]

    print("Expert Usage Analysis:")
    for name, usage in zip(expert_names, expert_usage_pct):
        print(f"  {name}: {usage:.1f}%")

    return expert_usage_pct.numpy()

## Training Loop

Training the KAN-MOTE model with regularization and monitoring.

In [None]:
# Training setup
num_epochs = 50
log_file = "training_log_kan_mote.csv"
results_dir = "results"
os.makedirs(results_dir, exist_ok=True)

# Per-epoch history, consumed by the plotting / summary cells.
# NOTE: train_losses store the regularised objective (cross-entropy plus the
# extra terms from recommended_feature_loss). test_losses are plain
# cross-entropy, so the two curves are not strictly comparable.
train_losses = []
train_accuracies = []
test_losses = []
test_accuracies = []
learning_rates = []

# CSV logging setup. mode='w' truncates any log left by a previous run, so a
# separate delete step is unnecessary.
with open(log_file, mode='w', newline='') as f:
    writer = csv.writer(f)
    writer.writerow(["epoch", "train_loss", "train_accuracy", "test_loss",
                     "test_accuracy", "lr", "moe_loss", "main_loss"])

print("Starting training with KAN-MOTE...")
print("=" * 60)

for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    total_main_loss = 0.0
    total_moe_loss = 0.0
    total_correct = 0
    total_samples = 0

    # Training loop
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")
    for batch_idx, (padded_events, lengths, labels) in enumerate(progress_bar):
        padded_events = padded_events.to(device)
        lengths = lengths.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()

        # Forward pass
        outputs, moe_info = model(padded_events, lengths)

        # Main classification loss
        main_loss = criterion(outputs, labels)

        # Total loss with MoE regularization (the returned loss_dict is currently unused)
        total_loss_with_reg, loss_dict = recommended_feature_loss(main_loss, moe_info)

        # Backward pass
        total_loss_with_reg.backward()

        # Clip the global gradient norm to 1.0
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()

        # Statistics, weighted by batch size so epoch means are exact.
        # Each .item() is read once to avoid redundant device syncs.
        n_batch = labels.size(0)
        reg_loss_value = total_loss_with_reg.item()
        main_loss_value = main_loss.item()
        total_loss += reg_loss_value * n_batch
        total_main_loss += main_loss_value * n_batch
        total_moe_loss += (reg_loss_value - main_loss_value) * n_batch

        preds = outputs.argmax(dim=1)
        total_correct += (preds == labels).sum().item()
        total_samples += n_batch

        # Update progress bar
        current_acc = total_correct / total_samples
        progress_bar.set_postfix({
            'Loss': f'{total_loss/total_samples:.4f}',
            'Acc': f'{current_acc:.4f}'
        })

    # Calculate epoch statistics
    avg_loss = total_loss / total_samples
    avg_main_loss = total_main_loss / total_samples
    avg_moe_loss = total_moe_loss / total_samples
    train_acc = total_correct / total_samples
    # LR actually used during this epoch (read before the scheduler may lower it).
    current_lr = optimizer.param_groups[0]['lr']

    # Evaluation
    test_loss, test_acc = evaluate_model(model, test_loader, device, criterion)

    # Learning rate scheduling on the test loss
    scheduler.step(test_loss)

    # Store results
    train_losses.append(avg_loss)
    train_accuracies.append(train_acc)
    test_losses.append(test_loss)
    test_accuracies.append(test_acc)
    learning_rates.append(current_lr)

    # Print epoch results
    print(f"Epoch [{epoch+1}/{num_epochs}]:")
    print(f"  Train - Loss: {avg_loss:.4f}, Acc: {train_acc:.4f}")
    print(f"  Test  - Loss: {test_loss:.4f}, Acc: {test_acc:.4f}")
    print(f"  MoE Loss: {avg_moe_loss:.6f}, LR: {current_lr:.6f}")
    print()

    # Save to CSV
    with open(log_file, mode='a', newline='') as f:
        writer = csv.writer(f)
        writer.writerow([epoch+1, avg_loss, train_acc, test_loss, test_acc,
                         current_lr, avg_moe_loss, avg_main_loss])

    # Expert usage analysis every 10 epochs. analyze_expert_usage prints its
    # own "Expert Usage Analysis:" header, so none is printed here.
    if (epoch + 1) % 10 == 0:
        analyze_expert_usage(model, test_loader, device, num_batches=5)
        print()

print("Training completed!")
if test_accuracies:
    print(f"Final Test Accuracy: {test_accuracies[-1]:.4f}")

## Results Visualization and Analysis

In [None]:
# Plot training curves. Epochs are numbered from 1 on the x-axis, matching the
# "Epoch [k/N]" console output and the CSV `epoch` column.
epochs = list(range(1, len(train_losses) + 1))
plt.figure(figsize=(15, 5))

# Loss curves (train loss includes the MoE regulariser; test loss is plain CE)
plt.subplot(1, 3, 1)
plt.plot(epochs, train_losses, label='Train Loss', alpha=0.8)
plt.plot(epochs, test_losses, label='Test Loss', alpha=0.8)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training and Test Loss')
plt.legend()
plt.grid(True, alpha=0.3)

# Accuracy curves
plt.subplot(1, 3, 2)
plt.plot(epochs, train_accuracies, label='Train Accuracy', alpha=0.8)
plt.plot(epochs, test_accuracies, label='Test Accuracy', alpha=0.8)
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.title('Training and Test Accuracy')
plt.legend()
plt.grid(True, alpha=0.3)

# Learning rate (log scale, since ReduceLROnPlateau changes it multiplicatively)
plt.subplot(1, 3, 3)
plt.plot(epochs, learning_rates, alpha=0.8)
plt.xlabel('Epoch')
plt.ylabel('Learning Rate')
plt.title('Learning Rate Schedule')
plt.yscale('log')
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(f'{results_dir}/training_curves_kan_mote.png', dpi=300, bbox_inches='tight')
plt.show()

# Print final statistics
print("=" * 60)
print("FINAL RESULTS")
print("=" * 60)
if train_accuracies and test_accuracies:
    best_train_acc = max(train_accuracies)
    best_test_acc = max(test_accuracies)
    print(f"Best Train Accuracy: {best_train_acc:.4f} (Epoch {train_accuracies.index(best_train_acc)+1})")
    print(f"Best Test Accuracy:  {best_test_acc:.4f} (Epoch {test_accuracies.index(best_test_acc)+1})")
    print(f"Final Train Accuracy: {train_accuracies[-1]:.4f}")
    print(f"Final Test Accuracy:  {test_accuracies[-1]:.4f}")
else:
    print("No epochs were recorded; nothing to summarize.")

In [None]:
# Detailed expert usage analysis over more test batches
print("Detailed Expert Usage Analysis")
print("=" * 40)

expert_usage = analyze_expert_usage(model, test_loader, device, num_batches=20)

# Bar chart of per-expert usage, one colour per expert.
expert_names = ['Fourier', 'Spline', 'Gaussian', 'Wavelet']
colors = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4']

fig, ax = plt.subplots(figsize=(10, 6))
ax.bar(expert_names, expert_usage, color=colors, alpha=0.8, edgecolor='black', linewidth=1)
ax.set_ylabel('Usage Percentage (%)')
ax.set_title('Expert Usage Distribution in KAN-MOTE')
ax.grid(True, alpha=0.3, axis='y')

# Label each bar with its percentage, placed 1 unit above the bar top.
for bar_index, pct in zip(range(len(expert_names)), expert_usage):
    ax.text(bar_index, pct + 1, f'{pct:.1f}%', ha='center', va='bottom', fontweight='bold')

fig.tight_layout()
fig.savefig(f'{results_dir}/expert_usage_kan_mote.png', dpi=300, bbox_inches='tight')
plt.show()

# Checkpoint: trained weights plus training history and usage statistics.
# NOTE(review): `config` is stored as a pickled Python object, so reloading
# likely needs torch.load(..., weights_only=False) on recent PyTorch — confirm.
checkpoint = {
    'model_state_dict': model.state_dict(),
    'config': config,
    'train_accuracies': train_accuracies,
    'test_accuracies': test_accuracies,
    'expert_usage': expert_usage,
}
torch.save(checkpoint, f'{results_dir}/kan_mote_mnist_model.pth')

print(f"Model and results saved to {results_dir}/")

## Summary

This notebook shows how to apply KAN-MOTE to event-based MNIST classification. Key points:

### Performance Benefits:
- **Adaptive Expert Selection**: KAN-MOTE automatically learns which temporal patterns work best for different timestamps
- **Rich Temporal Modeling**: Multiple expert types (Fourier, Spline, Gaussian, Wavelet) capture diverse temporal patterns
- **Plug-and-Play**: Direct replacement for LeTE with minimal architecture changes

### Expert Usage Insights:
The expert usage analysis reveals which temporal patterns are most useful for MNIST event classification, providing interpretability into the model's temporal processing.

### Comparison with LeTE:
- **LeTE**: Fixed mixing ratio (p=0.5) between Fourier and Spline
- **KAN-MOTE**: Dynamic expert selection with 4 expert types and adaptive mixing

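This suggests KAN-MOTE is a promising time encoding method for temporal sequence tasks. Confirming an advantage over LeTE requires a head-to-head run against a LeTE baseline trained under identical settings.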