# 🎯 Comprehensive MNIST Embedding Comparison

This notebook compares the performance of different time embedding approaches on MNIST:
1. **Baseline LSTM** - No dedicated time embedding; pixel positions are looked up in a learned `nn.Embedding` table
2. **LSTM + LETE** - With Learning Time Embedding (LeTE)
3. **LSTM + KAN-MAMMOTE** - With Improved KAN-MAMMOTE embedding

## 📊 Key Metrics to Compare:
- **Accuracy**: Classification performance
- **Training Speed**: Time per epoch
- **Parameter Count**: Model complexity
- **Convergence**: Training stability
- **Temporal Modeling**: How well each method captures temporal patterns

In [None]:
# ============================================================================
# 📦 IMPORTS AND SETUP
# ============================================================================

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import time
import warnings
warnings.filterwarnings('ignore')

# Import our models
import sys
import os
sys.path.append(os.path.join(os.getcwd(), 'src'))

from src.models import KAN_MAMMOTE_Model, ImprovedKANMAMOTE  # Improved version as default
from src.LETE.LeTE import CombinedLeTE
from src.utils.config import KANMAMOTEConfig

# Set up device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"🖥️  Using device: {device}")
if torch.cuda.is_available():
    print(f"   GPU: {torch.cuda.get_device_name()}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed(42)

print("✅ All imports successful!")

## 📁 Data Setup

We'll convert MNIST images to event-based sequences where each pixel whose normalized intensity exceeds a threshold becomes an event with:
- **Timestamp**: Pixel position (row * width + col)
- **Features**: Normalized pixel intensity
- **Label**: Digit class (0-9)
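
As a rough illustration of the conversion described above, here is a dependency-free sketch on a toy 3x4 grid. The helper name `image_to_events` and the toy values are assumptions made for this example; the notebook's own dataset class does the same thing with tensors on 28x28 images.

```python
def image_to_events(image, threshold=0.1):
    """Return (timestamps, intensities) for pixels strictly above threshold.

    `image` is a list of rows; timestamps follow reading order (row * width + col).
    """
    width = len(image[0])
    timestamps, intensities = [], []
    for row, pixels in enumerate(image):
        for col, value in enumerate(pixels):
            if value > threshold:
                timestamps.append(row * width + col)
                intensities.append(value)
    return timestamps, intensities


# Toy 3x4 "image" (hypothetical values); 0.05 falls below the threshold.
toy_image = [
    [0.0, 0.9, 0.0, 0.0],
    [0.0, 0.8, 0.05, 0.0],
    [0.0, 0.0, 0.7, 0.0],
]

timestamps, intensities = image_to_events(toy_image)
print(timestamps)   # [1, 5, 10]
print(intensities)  # [0.9, 0.8, 0.7]
```

Because the scan runs row by row, the timestamps come out already sorted, which is why the sequence order matches natural reading order.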

In [None]:
# ============================================================================
# 🎲 EVENT-BASED MNIST DATASET
# ============================================================================

class EventBasedMNIST(Dataset):
    """
    Convert MNIST images to event-based sequences.
    Each pixel above `threshold` becomes an event with timestamp = pixel position.
    """
    
    def __init__(self, root='./data', train=True, threshold=0.1, max_events=None):
        """
        Args:
            root: Data directory
            train: Training or test set
            threshold: Minimum pixel intensity (after normalization) to consider as event
            max_events: Maximum events per image (for memory efficiency)
        """
        self.threshold = threshold
        self.max_events = max_events
        
        # Load MNIST dataset
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])
        
        self.dataset = torchvision.datasets.MNIST(
            root=root, train=train, download=True, transform=transform
        )
        
        print(f"📊 Loaded {'training' if train else 'test'} set: {len(self.dataset)} samples")
    
    def __len__(self):
        return len(self.dataset)
    
    def __getitem__(self, idx):
        image, label = self.dataset[idx]
        
        # Convert to event sequence
        image = image.squeeze(0)  # Remove channel dimension: (28, 28)
        
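        # Select pixels above threshold. The image is already normalized, so
        # background pixels sit near -0.42 and a threshold of 0.1 corresponds
        # to a raw intensity of roughly 0.16.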
        mask = image > self.threshold
        rows, cols = torch.where(mask)
        
        if len(rows) == 0:
            # Handle empty events (shouldn't happen with MNIST)
            events = torch.zeros(1, dtype=torch.long)
            features = torch.zeros(1, 1, dtype=torch.float)
            length = 1
        else:
            # Create timestamps: row * width + col (0-783 for 28x28)
            timestamps = rows * 28 + cols
            
            # Sort by timestamp (natural reading order)
            sorted_idx = torch.argsort(timestamps)
            timestamps = timestamps[sorted_idx]
            
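            # Optional: limit the number of events. This keeps the first max_events
            # in reading order, so the bottom rows of dense digits are dropped.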
            if self.max_events and len(timestamps) > self.max_events:
                timestamps = timestamps[:self.max_events]
                sorted_idx = sorted_idx[:self.max_events]
            
            # Extract pixel intensities as features
            intensities = image[rows[sorted_idx], cols[sorted_idx]]
            features = intensities.unsqueeze(1)  # (seq_len, 1)
            
            events = timestamps
            length = len(events)
        
        return events, features, length, label

def collate_fn(batch):
    """
    Custom collate function for variable-length sequences.
    """
    events, features, lengths, labels = zip(*batch)
    
    # Pad sequences
    events_padded = pad_sequence(events, batch_first=True, padding_value=0)
    features_padded = pad_sequence(features, batch_first=True, padding_value=0.0)
    
    lengths = torch.tensor(lengths)
    labels = torch.tensor(labels)
    
    return events_padded, features_padded, lengths, labels

# Create datasets
train_dataset = EventBasedMNIST(train=True, threshold=0.1, max_events=200)
test_dataset = EventBasedMNIST(train=False, threshold=0.1, max_events=200)

# Create data loaders
batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

print(f"📦 Data loaders created:")
print(f"   Train: {len(train_loader)} batches")
print(f"   Test: {len(test_loader)} batches")

# Test data loading
sample_batch = next(iter(train_loader))
events, features, lengths, labels = sample_batch
print(f"\n📋 Sample batch:")
print(f"   Events shape: {events.shape}")
print(f"   Features shape: {features.shape}")
print(f"   Lengths: {lengths[:5]}")
print(f"   Labels: {labels[:5]}")

## 🏗️ Model Definitions

We'll define three different LSTM-based models:
1. **Baseline LSTM**: Timestamps → learned position embedding (`nn.Embedding`) → LSTM → Classifier
2. **LSTM + LETE**: Timestamps → LETE → LSTM → Classifier
3. **LSTM + KAN-MAMMOTE**: Timestamps → KAN-MAMMOTE → LSTM → Classifier
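
Since parameter count is one of the comparison metrics, it can help to see where the baseline's parameters come from. The sketch below is a back-of-the-envelope calculation, assuming the baseline dimensions used in this notebook (a 784-entry position table of width 32, a 1-to-32 feature projection, a 2-layer LSTM with 64 inputs and 128 hidden units, and a 128-64-10 classifier) and PyTorch's `nn.LSTM` layout with two bias vectors per layer. The helper names are made up for this example.

```python
def lstm_param_count(input_size, hidden_size, num_layers):
    """Parameters of a unidirectional LSTM with input and hidden biases per layer."""
    total = 0
    for layer in range(num_layers):
        layer_input = input_size if layer == 0 else hidden_size
        # 4 gates, each with input weights, recurrent weights, and two bias vectors
        total += 4 * (layer_input * hidden_size + hidden_size * hidden_size + 2 * hidden_size)
    return total


def linear_param_count(in_features, out_features):
    """Weights plus biases of a fully connected layer."""
    return in_features * out_features + out_features


position_table = 784 * 32                      # learned lookup table for timestamps
feature_projection = linear_param_count(1, 32)  # pixel intensity projection
lstm = lstm_param_count(64, 128, 2)
classifier = linear_param_count(128, 64) + linear_param_count(64, 10)

baseline_total = position_table + feature_projection + lstm + classifier
print(f"LSTM: {lstm:,}")               # LSTM: 231,424
print(f"Total: {baseline_total:,}")    # Total: 265,482
```

Under these assumptions the LSTM accounts for most of the baseline's parameters, so differences between the three models mainly reflect the size of the time-embedding module and the width of the LSTM input.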

In [None]:
# ============================================================================
# 🔥 MODEL 1: BASELINE LSTM (No Time Embedding)
# ============================================================================

class BaselineLSTM(nn.Module):
    """
    Baseline LSTM model without a dedicated time-embedding module.
    Timestamps index a learned lookup table; pixel intensities pass through a linear projection.
    """
    
    def __init__(self, input_size=784, hidden_dim=128, num_layers=2, num_classes=10):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        
        # Simple embedding for timestamps (position embedding)
        self.timestamp_embedding = nn.Embedding(input_size, 32)
        
        # Feature projection (pixel intensity)
        self.feature_projection = nn.Linear(1, 32)
        
        # LSTM layers
        self.lstm = nn.LSTM(
            input_size=64,  # 32 (timestamp) + 32 (feature)
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=0.2
        )
        
        # Classifier
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, 64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, num_classes)
        )
        
    def forward(self, events, features, lengths):
        batch_size = events.size(0)
        
        # Embed timestamps
        timestamp_emb = self.timestamp_embedding(events)  # (batch, seq_len, 32)
        
        # Project features
        feature_emb = self.feature_projection(features)  # (batch, seq_len, 32)
        
        # Concatenate embeddings
        combined = torch.cat([timestamp_emb, feature_emb], dim=-1)  # (batch, seq_len, 64)
        
        # Pack sequences for LSTM
        packed = pack_padded_sequence(combined, lengths.cpu(), batch_first=True, enforce_sorted=False)
        
        # LSTM forward pass
        lstm_out, (h_n, c_n) = self.lstm(packed)
        
        # Use last hidden state for classification
        final_hidden = h_n[-1]  # (batch, hidden_dim)
        
        # Classify
        logits = self.classifier(final_hidden)
        
        return logits

# Create baseline model
baseline_model = BaselineLSTM().to(device)
baseline_params = sum(p.numel() for p in baseline_model.parameters() if p.requires_grad)
print(f"🔥 Baseline LSTM created: {baseline_params:,} parameters")

In [None]:
# ============================================================================
# 🌟 MODEL 2: LSTM + LETE EMBEDDING
# ============================================================================

class LSTM_LETE(nn.Module):
    """
    LSTM model with LETE (Learning Time Embedding) for temporal modeling.
    """
    
    def __init__(self, input_size=784, hidden_dim=128, num_layers=2, num_classes=10, lete_dim=64):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.lete_dim = lete_dim
        
        # LETE for time embedding
        self.lete = CombinedLeTE(dim=lete_dim, p=0.5, layer_norm=True, scale=True)
        
        # Feature projection (pixel intensity)
        self.feature_projection = nn.Linear(1, 32)
        
        # LSTM layers
        self.lstm = nn.LSTM(
            input_size=lete_dim + 32,  # LETE embedding + feature embedding
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=0.2
        )
        
        # Classifier
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, 64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, num_classes)
        )
        
    def forward(self, events, features, lengths):
        batch_size = events.size(0)
        
        # Normalize timestamps to [0, 1] range for LETE
        timestamps_normalized = events.float() / 783.0  # 784 - 1
        
        # Apply LETE embedding
        lete_emb = self.lete(timestamps_normalized)  # (batch, seq_len, lete_dim)
        
        # Project features
        feature_emb = self.feature_projection(features)  # (batch, seq_len, 32)
        
        # Concatenate embeddings
        combined = torch.cat([lete_emb, feature_emb], dim=-1)  # (batch, seq_len, lete_dim + 32)
        
        # Pack sequences for LSTM
        packed = pack_padded_sequence(combined, lengths.cpu(), batch_first=True, enforce_sorted=False)
        
        # LSTM forward pass
        lstm_out, (h_n, c_n) = self.lstm(packed)
        
        # Use last hidden state for classification
        final_hidden = h_n[-1]  # (batch, hidden_dim)
        
        # Classify
        logits = self.classifier(final_hidden)
        
        return logits

# Create LETE model
lete_model = LSTM_LETE().to(device)
lete_params = sum(p.numel() for p in lete_model.parameters() if p.requires_grad)
print(f"🌟 LSTM + LETE created: {lete_params:,} parameters")

In [None]:
# ============================================================================
# 🚀 MODEL 3: LSTM + KAN-MAMMOTE EMBEDDING
# ============================================================================

class LSTM_KAN_MAMMOTE(nn.Module):
    """
    LSTM model with Improved KAN-MAMMOTE embedding for temporal modeling.
    """
    
    def __init__(self, input_size=784, hidden_dim=128, num_layers=2, num_classes=10):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        
        # KAN-MAMMOTE configuration
        self.kan_config = KANMAMOTEConfig(
            D_time=64,  # Time embedding dimension
            n_experts=8,  # Number of experts
            hidden_dim_mamba=64,  # C-Mamba hidden dimension
            state_dim_mamba=32,  # State dimension
            num_mamba_layers=2,  # Number of Mamba layers
            gamma=0.3,  # Time difference scaling
            # Faster-KAN parameters
            kan_grid_size=5,
            kan_grid_min=-2.0,
            kan_grid_max=2.0,
            kan_spline_scale=0.667,
            kan_num_layers=2,
            kan_hidden_dim=64
        )
        
        # Improved KAN-MAMMOTE for time embedding
        self.kan_mammote = ImprovedKANMAMOTE(self.kan_config)
        
        # Feature projection (pixel intensity)
        self.feature_projection = nn.Linear(1, 32)
        
        # LSTM layers
        self.lstm = nn.LSTM(
            input_size=self.kan_config.hidden_dim_mamba + 32,  # KAN-MAMMOTE + features
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=0.2
        )
        
        # Classifier
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, 64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, num_classes)
        )
        
    def forward(self, events, features, lengths):
        batch_size = events.size(0)
        
        # Normalize timestamps to [0, 1] range for KAN-MAMMOTE
        timestamps = events.float() / 783.0  # 784 - 1
        timestamps = timestamps.unsqueeze(-1)  # (batch, seq_len, 1)
        
        # For KAN-MAMMOTE, we don't use additional event features (use empty)
        # The model focuses on temporal patterns
        empty_features = torch.zeros(batch_size, timestamps.size(1), 0, device=timestamps.device)
        
        # Apply KAN-MAMMOTE embedding
        kan_emb, kan_info = self.kan_mammote(timestamps, empty_features)
        # kan_emb: (batch, seq_len, hidden_dim_mamba)
        
        # Project pixel features
        feature_emb = self.feature_projection(features)  # (batch, seq_len, 32)
        
        # Concatenate KAN-MAMMOTE embeddings with pixel features
        combined = torch.cat([kan_emb, feature_emb], dim=-1)
        
        # Pack sequences for LSTM
        packed = pack_padded_sequence(combined, lengths.cpu(), batch_first=True, enforce_sorted=False)
        
        # LSTM forward pass
        lstm_out, (h_n, c_n) = self.lstm(packed)
        
        # Use last hidden state for classification
        final_hidden = h_n[-1]  # (batch, hidden_dim)
        
        # Classify
        logits = self.classifier(final_hidden)
        
        return logits, kan_info

# Create KAN-MAMMOTE model
kan_model = LSTM_KAN_MAMMOTE().to(device)
kan_params = sum(p.numel() for p in kan_model.parameters() if p.requires_grad)
print(f"🚀 LSTM + KAN-MAMMOTE created: {kan_params:,} parameters")

# Summary of all models
print(f"\n📊 Model Comparison:")
print(f"   Baseline LSTM:     {baseline_params:,} parameters")
print(f"   LSTM + LETE:       {lete_params:,} parameters")
print(f"   LSTM + KAN-MAMMOTE: {kan_params:,} parameters")

## 🎯 Training Setup

Define training and evaluation functions that work for all three models.
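
One detail worth spelling out: the epoch loss is accumulated as batch loss times batch size and then divided by the number of samples, rather than averaged over batches. With a batch size of 64, the last MNIST training batch holds only 32 samples, so a plain mean over batches would overweight it. The sketch below illustrates the difference with made-up loss values; `epoch_mean_loss` is a hypothetical helper, not part of the notebook.

```python
def epoch_mean_loss(batch_losses, batch_sizes):
    """Per-sample mean loss: weight each batch's mean loss by its size."""
    total = sum(loss * n for loss, n in zip(batch_losses, batch_sizes))
    return total / sum(batch_sizes)


# Hypothetical per-batch mean losses; the last batch is much smaller.
batch_losses = [0.5, 0.5, 2.0]
batch_sizes = [64, 64, 2]

weighted = epoch_mean_loss(batch_losses, batch_sizes)
naive = sum(batch_losses) / len(batch_losses)

print(f"weighted: {weighted:.4f}")  # weighted: 0.5231
print(f"naive:    {naive:.4f}")     # naive:    1.0000
```

The weighted value is the one that matches the loss averaged over all samples, which keeps the reported numbers comparable across different batch sizes.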

In [None]:
# ============================================================================
# 🏋️ TRAINING AND EVALUATION FUNCTIONS
# ============================================================================

def train_model(model, train_loader, test_loader, model_name, num_epochs=10):
    """
    Train a model and track performance metrics.
    """
    print(f"\n🏋️ Training {model_name}...")
    
    # Setup optimizer and loss
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)
    criterion = nn.CrossEntropyLoss()
    
    # Tracking metrics
    train_losses = []
    train_accs = []
    test_losses = []
    test_accs = []
    epoch_times = []
    
    best_test_acc = 0.0
    
    for epoch in range(num_epochs):
        epoch_start = time.time()
        
        # ========== TRAINING PHASE ==========
        model.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0
        
        train_bar = tqdm(train_loader, desc=f"{model_name} Epoch {epoch+1}/{num_epochs}")
        
        for batch_idx, (events, features, lengths, labels) in enumerate(train_bar):
            events, features, lengths, labels = events.to(device), features.to(device), lengths.to(device), labels.to(device)
            
            optimizer.zero_grad()
            
            # Forward pass; models that also return diagnostics (KAN-MAMMOTE) yield a tuple
            outputs = model(events, features, lengths)
            if isinstance(outputs, tuple):
                outputs = outputs[0]
            
            loss = criterion(outputs, labels)
            loss.backward()
            
            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            
            optimizer.step()
            
            # Statistics
            train_loss += loss.item() * labels.size(0)
            _, predicted = outputs.max(1)
            train_total += labels.size(0)
            train_correct += predicted.eq(labels).sum().item()
            
            # Update progress bar
            train_bar.set_postfix({
                'Loss': f'{loss.item():.4f}',
                'Acc': f'{100.*train_correct/train_total:.2f}%'
            })
        
        # ========== EVALUATION PHASE ==========
        model.eval()
        test_loss = 0.0
        test_correct = 0
        test_total = 0
        
        with torch.no_grad():
            for events, features, lengths, labels in test_loader:
                events, features, lengths, labels = events.to(device), features.to(device), lengths.to(device), labels.to(device)
                
                # Forward pass; keep only the logits if the model returns a tuple
                outputs = model(events, features, lengths)
                if isinstance(outputs, tuple):
                    outputs = outputs[0]
                
                loss = criterion(outputs, labels)
                
                test_loss += loss.item() * labels.size(0)
                _, predicted = outputs.max(1)
                test_total += labels.size(0)
                test_correct += predicted.eq(labels).sum().item()
        
        # Calculate metrics
        train_loss = train_loss / train_total
        train_acc = 100. * train_correct / train_total
        test_loss = test_loss / test_total
        test_acc = 100. * test_correct / test_total
        
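        # Update learning rate. The scheduler monitors test loss, so the test set
        # also acts as a validation set and the reported accuracy is slightly optimistic.
        scheduler.step(test_loss)
        
        # Record metrics (epoch_time covers both the training and evaluation passes)
        epoch_time = time.time() - epoch_start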
        train_losses.append(train_loss)
        train_accs.append(train_acc)
        test_losses.append(test_loss)
        test_accs.append(test_acc)
        epoch_times.append(epoch_time)
        
        # Track the best test accuracy (no checkpoint is saved)
        if test_acc > best_test_acc:
            best_test_acc = test_acc
        
        print(f"Epoch {epoch+1}/{num_epochs}:")
        print(f"  Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
        print(f"  Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.2f}%")
        print(f"  Time: {epoch_time:.1f}s, Best Acc: {best_test_acc:.2f}%")
    
    return {
        'train_losses': train_losses,
        'train_accs': train_accs,
        'test_losses': test_losses,
        'test_accs': test_accs,
        'epoch_times': epoch_times,
        'best_test_acc': best_test_acc,
        'final_test_acc': test_accs[-1],
        'avg_epoch_time': np.mean(epoch_times)
    }

print("✅ Training functions ready!")

## 🧪 Experiment Execution

Now let's train all three models and compare their performance!

In [None]:
# ============================================================================
# 🧪 RUN EXPERIMENTS
# ============================================================================

# Training configuration
NUM_EPOCHS = 15
results = {}

print("🎯 Starting comprehensive embedding comparison experiments...")
print(f"📊 Training for {NUM_EPOCHS} epochs each")

# Train all models
models_to_test = [
    (baseline_model, "Baseline LSTM"),
    (lete_model, "LSTM + LETE"),
    (kan_model, "LSTM + KAN-MAMMOTE")
]

for model, name in models_to_test:
    print(f"\n{'='*60}")
    print(f"🚀 Training {name}...")
    print(f"{'='*60}")
    
    try:
        # Train the model
        result = train_model(model, train_loader, test_loader, name, NUM_EPOCHS)
        results[name] = result
        
        print(f"\n✅ {name} training completed!")
        print(f"   Best Test Accuracy: {result['best_test_acc']:.2f}%")
        print(f"   Final Test Accuracy: {result['final_test_acc']:.2f}%")
        print(f"   Average Epoch Time: {result['avg_epoch_time']:.1f}s")
        
        # Clear GPU memory
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    
    except Exception as e:
        print(f"❌ Error training {name}: {e}")
        results[name] = None

print(f"\n🎉 All experiments completed!")
print(f"📊 Results summary:")
for name, result in results.items():
    if result is not None:
        print(f"   {name}: {result['best_test_acc']:.2f}% (best), {result['final_test_acc']:.2f}% (final)")
    else:
        print(f"   {name}: Failed")

## 📊 Results Analysis & Visualization

Let's analyze and visualize the results to understand the performance differences.
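
When reading the comparison, keep in mind that the difference between two accuracies is measured in percentage points, which is not the same as a relative change, and that accuracy per 1K parameters depends heavily on model size. The sketch below shows both calculations; the function names are assumptions for this example and the input numbers are placeholders, not results from this notebook.

```python
def accuracy_gain(candidate_acc, baseline_acc):
    """Return (percentage-point difference, relative change in percent)."""
    points = candidate_acc - baseline_acc
    relative = 100.0 * points / baseline_acc
    return points, relative


def accuracy_per_k_params(acc, params):
    """Accuracy (in percent) divided by the parameter count in thousands."""
    return acc / (params / 1000)


# Placeholder numbers for illustration only, NOT measured results.
points, relative = accuracy_gain(98.0, 96.0)
per_k = accuracy_per_k_params(98.0, 200_000)

print(f"{points:+.2f} pp, {relative:+.2f}% relative")  # +2.00 pp, +2.08% relative
print(f"{per_k:.2f} accuracy points per 1K parameters")  # 0.49 accuracy points per 1K parameters
```

Reporting the gap in percentage points avoids ambiguity when accuracies are already close to 100%.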

In [None]:
# ============================================================================
# 📊 RESULTS ANALYSIS & VISUALIZATION
# ============================================================================

# Filter successful results
successful_results = {name: result for name, result in results.items() if result is not None}

if len(successful_results) == 0:
    print("❌ No successful experiments to analyze")
else:
    print(f"📊 Analyzing {len(successful_results)} successful experiments...")
    
    # Create comprehensive visualization
    fig, axes = plt.subplots(2, 3, figsize=(18, 12))
    fig.suptitle('🎯 MNIST Embedding Comparison Results', fontsize=16, fontweight='bold')
    
    # Colors for different models
    colors = {'Baseline LSTM': '#FF6B6B', 'LSTM + LETE': '#4ECDC4', 'LSTM + KAN-MAMMOTE': '#45B7D1'}
    
    # Plot 1: Training Loss
    ax1 = axes[0, 0]
    for name, result in successful_results.items():
        epochs = range(1, len(result['train_losses']) + 1)
        ax1.plot(epochs, result['train_losses'], label=name, color=colors.get(name, 'gray'), linewidth=2)
    ax1.set_title('📉 Training Loss', fontweight='bold')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    ax1.legend()
    ax1.grid(True, alpha=0.3)
    
    # Plot 2: Training Accuracy
    ax2 = axes[0, 1]
    for name, result in successful_results.items():
        epochs = range(1, len(result['train_accs']) + 1)
        ax2.plot(epochs, result['train_accs'], label=name, color=colors.get(name, 'gray'), linewidth=2)
    ax2.set_title('📈 Training Accuracy', fontweight='bold')
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Accuracy (%)')
    ax2.legend()
    ax2.grid(True, alpha=0.3)
    
    # Plot 3: Test Accuracy
    ax3 = axes[0, 2]
    for name, result in successful_results.items():
        epochs = range(1, len(result['test_accs']) + 1)
        ax3.plot(epochs, result['test_accs'], label=name, color=colors.get(name, 'gray'), linewidth=2)
    ax3.set_title('🎯 Test Accuracy', fontweight='bold')
    ax3.set_xlabel('Epoch')
    ax3.set_ylabel('Accuracy (%)')
    ax3.legend()
    ax3.grid(True, alpha=0.3)
    
    # Plot 4: Final Performance Comparison
    ax4 = axes[1, 0]
    model_names = list(successful_results.keys())
    best_accs = [result['best_test_acc'] for result in successful_results.values()]
    final_accs = [result['final_test_acc'] for result in successful_results.values()]
    
    x = np.arange(len(model_names))
    width = 0.35
    
    bars1 = ax4.bar(x - width/2, best_accs, width, label='Best Test Acc', alpha=0.8, color='#45B7D1')
    bars2 = ax4.bar(x + width/2, final_accs, width, label='Final Test Acc', alpha=0.8, color='#96CEB4')
    
    ax4.set_title('🏆 Final Performance Comparison', fontweight='bold')
    ax4.set_xlabel('Model')
    ax4.set_ylabel('Accuracy (%)')
    ax4.set_xticks(x)
    ax4.set_xticklabels(model_names, rotation=45, ha='right')
    ax4.legend()
    ax4.grid(True, alpha=0.3)
    
    # Add value labels on bars
    for bar in bars1:
        height = bar.get_height()
        ax4.text(bar.get_x() + bar.get_width()/2., height + 0.5,
                f'{height:.1f}%', ha='center', va='bottom', fontweight='bold')
    
    for bar in bars2:
        height = bar.get_height()
        ax4.text(bar.get_x() + bar.get_width()/2., height + 0.5,
                f'{height:.1f}%', ha='center', va='bottom', fontweight='bold')
    
    # Plot 5: Training Time Comparison
    ax5 = axes[1, 1]
    avg_times = [result['avg_epoch_time'] for result in successful_results.values()]
    bars = ax5.bar(model_names, avg_times, color=[colors.get(name, 'gray') for name in model_names], alpha=0.8)
    ax5.set_title('⏱️ Training Time Comparison', fontweight='bold')
    ax5.set_xlabel('Model')
    ax5.set_ylabel('Avg Time per Epoch (s)')
    plt.setp(ax5.get_xticklabels(), rotation=45, ha='right')  # rotate labels without resetting ticks
    ax5.grid(True, alpha=0.3)
    
    # Add value labels
    for bar in bars:
        height = bar.get_height()
        ax5.text(bar.get_x() + bar.get_width()/2., height + 0.5,
                f'{height:.1f}s', ha='center', va='bottom', fontweight='bold')
    
    # Plot 6: Parameter Count Comparison
    ax6 = axes[1, 2]
    param_counts = [baseline_params, lete_params, kan_params]
    model_labels = ['Baseline LSTM', 'LSTM + LETE', 'LSTM + KAN-MAMMOTE']
    bars = ax6.bar(model_labels, param_counts, color=['#FF6B6B', '#4ECDC4', '#45B7D1'], alpha=0.8)
    ax6.set_title('🔢 Parameter Count Comparison', fontweight='bold')
    ax6.set_xlabel('Model')
    ax6.set_ylabel('Parameters')
    plt.setp(ax6.get_xticklabels(), rotation=45, ha='right')  # rotate labels without resetting ticks
    ax6.grid(True, alpha=0.3)
    
    # Add value labels
    for bar in bars:
        height = bar.get_height()
        ax6.text(bar.get_x() + bar.get_width()/2., height + 5000,
                f'{int(height/1000)}K', ha='center', va='bottom', fontweight='bold')
    
    plt.tight_layout()
    plt.show()
    
    # Print detailed comparison table
    print("\n" + "="*80)
    print("📊 DETAILED COMPARISON RESULTS")
    print("="*80)
    
    print(f"{'Model':<20} {'Best Acc':<10} {'Final Acc':<10} {'Avg Time':<10} {'Parameters':<12}")
    print("-" * 80)
    
    param_map = {'Baseline LSTM': baseline_params, 'LSTM + LETE': lete_params, 'LSTM + KAN-MAMMOTE': kan_params}
    
    for name, result in successful_results.items():
        print(f"{name:<20} {result['best_test_acc']:<10.2f} {result['final_test_acc']:<10.2f} {result['avg_epoch_time']:<10.1f} {param_map[name]:<12,}")
    
    # Calculate improvements
    if 'Baseline LSTM' in successful_results:
        baseline_acc = successful_results['Baseline LSTM']['best_test_acc']
        print(f"\n🚀 PERFORMANCE IMPROVEMENTS vs Baseline:")
        print("-" * 50)
        
        for name, result in successful_results.items():
            if name != 'Baseline LSTM':
                improvement = result['best_test_acc'] - baseline_acc
                print(f"{name:<20} {improvement:+.2f} percentage points")
    
    print("\n" + "="*80)
    print("🎯 CONCLUSION")
    print("="*80)
    
    # Find best performing model
    best_model = max(successful_results.items(), key=lambda x: x[1]['best_test_acc'])
    print(f"🏆 Best performing model: {best_model[0]}")
    print(f"   Best accuracy: {best_model[1]['best_test_acc']:.2f}%")
    print(f"   Parameters: {param_map[best_model[0]]:,}")
    print(f"   Avg training time: {best_model[1]['avg_epoch_time']:.1f}s per epoch")
    
    # Efficiency analysis
    print(f"\n⚡ EFFICIENCY ANALYSIS:")
    for name, result in successful_results.items():
        params = param_map[name]
        acc = result['best_test_acc']
        time_per_epoch = result['avg_epoch_time']
        
        efficiency = acc / (params / 1000)  # Accuracy per 1K parameters
        speed_efficiency = acc / time_per_epoch  # Accuracy per second
        
        print(f"{name:<20} Acc/1K params: {efficiency:.2f}, Acc/sec: {speed_efficiency:.2f}")

print("\n✅ Analysis complete!")

## 🔍 Detailed Analysis

Let's dive deeper into the temporal modeling capabilities and examine specific aspects of each approach.
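
The expert-balance check in the next cell boils down to: turn per-event expert scores into probabilities, average them over events, and look at the spread across experts. Here is a minimal dependency-free sketch of that idea, assuming the scores are unnormalized logits (the notebook's `expert_weights` may already be normalized, in which case the softmax step should be skipped). The sample standard deviation is used because `torch.std` applies Bessel's correction by default. All names and inputs are made up for illustration.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of scores."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def mean_usage(per_event_logits):
    """Average routing probability per expert across all events."""
    probs = [softmax(logits) for logits in per_event_logits]
    n_experts = len(probs[0])
    return [sum(p[k] for p in probs) / len(probs) for k in range(n_experts)]


def sample_std(values):
    """Standard deviation with Bessel's correction (divides by n - 1)."""
    mean = sum(values) / len(values)
    return math.sqrt(sum((v - mean) ** 2 for v in values) / (len(values) - 1))


# Hypothetical logits: two events routed over four experts.
uniform = mean_usage([[0.0] * 4, [0.0] * 4])
skewed = mean_usage([[5.0, 0.0, 0.0, 0.0], [5.0, 0.0, 0.0, 0.0]])

for name, usage in [("uniform", uniform), ("skewed", skewed)]:
    spread = sample_std(usage)
    verdict = "balanced" if spread < 0.05 else "imbalanced"
    print(f"{name}: std={spread:.4f} -> {verdict}")
```

With equal logits every expert gets the same share and the spread is zero; when one expert dominates, the spread rises well above the 0.05 threshold used in the analysis cell.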

In [None]:
# ============================================================================
# 🔍 DETAILED TEMPORAL ANALYSIS
# ============================================================================

if 'LSTM + KAN-MAMMOTE' in successful_results:
    print("🔍 Performing detailed KAN-MAMMOTE temporal analysis...")
    
    # Get a batch for analysis
    kan_model.eval()
    with torch.no_grad():
        sample_batch = next(iter(test_loader))
        events, features, lengths, labels = sample_batch
        events, features, lengths, labels = events.to(device), features.to(device), lengths.to(device), labels.to(device)
        
        # Get detailed KAN-MAMMOTE information
        outputs, kan_info = kan_model(events, features, lengths)
        
        print(f"\n📊 KAN-MAMMOTE Temporal Analysis:")
        print(f"   Batch size: {events.shape[0]}")
        print(f"   Max sequence length: {events.shape[1]}")
        print(f"   Average sequence length: {lengths.float().mean():.1f}")
        
        # Analyze temporal differences
        if 'temporal_differences' in kan_info:
            temporal_diffs = kan_info['temporal_differences']
            print(f"   Temporal differences shape: {temporal_diffs.shape}")
            print(f"   Temporal differences range: [{temporal_diffs.min():.4f}, {temporal_diffs.max():.4f}]")
            print(f"   Temporal differences std: {temporal_diffs.std():.4f}")
        
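        # Analyze expert usage if available. This assumes expert_weights holds
        # unnormalized scores of shape (batch, seq_len, n_experts); the mean below
        # also averages over padded positions.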
        if 'kmote_info' in kan_info and 'expert_weights' in kan_info['kmote_info']:
            expert_weights = kan_info['kmote_info']['expert_weights']
            expert_usage = torch.softmax(expert_weights, dim=-1).mean(dim=(0, 1))
            
            print(f"\n🎯 Expert Usage Analysis:")
            for i, usage in enumerate(expert_usage):
                print(f"   Expert {i}: {usage:.1%}")
            
            # Check if experts are balanced
            expert_std = expert_usage.std()
            if expert_std < 0.05:
                print(f"   ✅ Experts are well-balanced (std: {expert_std:.4f})")
            else:
                print(f"   ⚠️  Expert usage is imbalanced (std: {expert_std:.4f})")
        
        # Visualize temporal patterns for a few samples
        fig, axes = plt.subplots(2, 2, figsize=(12, 10))
        fig.suptitle('🔍 KAN-MAMMOTE Temporal Pattern Analysis', fontsize=14, fontweight='bold')
        
        # Show temporal differences for first 4 samples
        for i in range(min(4, events.shape[0])):
            ax = axes[i // 2, i % 2]
            
            seq_len = lengths[i].item()
            sample_timestamps = events[i, :seq_len].cpu().numpy()
            
            if 'temporal_differences' in kan_info:
                sample_diffs = temporal_diffs[i, :seq_len].cpu().numpy()
                
                # Plot temporal differences
                ax.plot(sample_timestamps, sample_diffs.mean(axis=1), 'b-', alpha=0.7, label='Temporal Diffs')
                ax.fill_between(sample_timestamps, 
                               sample_diffs.mean(axis=1) - sample_diffs.std(axis=1),
                               sample_diffs.mean(axis=1) + sample_diffs.std(axis=1),
                               alpha=0.3, color='blue')
            
            ax.set_title(f'Sample {i+1} (Label: {labels[i].item()}, Len: {seq_len})')
            ax.set_xlabel('Timestamp')
            ax.set_ylabel('Temporal Difference')
            ax.grid(True, alpha=0.3)
            if 'temporal_differences' in kan_info:
                ax.legend()  # only add a legend when a labeled curve was drawn
        
        plt.tight_layout()
        plt.show()

print("\n✅ Detailed analysis complete!")

## 🎯 Experiment Conclusions

### Key Findings:

1. **Performance Comparison**: 
   - Compare accuracy improvements of each embedding method
   - Analyze convergence speed and stability

2. **Efficiency Analysis**:
   - Parameter efficiency (accuracy per parameter)
   - Training time efficiency
   - Memory usage patterns

3. **Temporal Modeling Quality**:
   - KAN-MAMMOTE's temporal difference modeling
   - LETE's Fourier + Spline approach
   - Baseline's simple position embeddings

### Recommendations:

Use the measured results above to judge which embedding approach provides the best balance of:
- **Accuracy**: Classification performance
- **Efficiency**: Parameters and training time
- **Robustness**: Consistent performance across different scenarios
- **Interpretability**: Understanding of temporal patterns

### Future Work:

1. **Extended Evaluation**: Test on more complex temporal datasets
2. **Ablation Studies**: Analyze individual components of KAN-MAMMOTE
3. **Hyperparameter Optimization**: Fine-tune each model for optimal performance
4. **Architectural Variants**: Explore different LSTM configurations
5. **Real-world Applications**: Apply to actual event-based vision tasks