In [None]:
# Cell 1: Setup and imports
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
from liquidS4_audio import LiquidS4AudioClassifier
from audio_utils import create_dataloaders
import time
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np

# Seed every RNG this notebook can touch so that Restart Kernel -> Run All
# gives a repeatable run (weight init, dropout, shuffling, augmentation).
# The import lives here, directly under the import block, because the
# original import list does not include `random`.
import random

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the torch generators on all devices

# Prefer the GPU when one is visible; every later cell moves tensors to `device`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

print("🚀 Liquid S4 Audio Classifier Training")
print("=" * 50)


In [None]:
# Cell 2: Configuration
# All tunable hyperparameters live in this one dict; later cells look them up
# by key (batch size, learning rate, epochs, patience, model dimensions).
# NOTE(review): the visible cells take the class count from
# create_dataloaders(), not from the 'num_classes' entry below.
config = dict(
    batch_size=32,        # Liquid S4 can handle larger batches
    learning_rate=1e-3,
    epochs=150,
    patience=15,
    n_mels=128,
    num_classes=50,
    d_model=64,           # Smaller than Mamba
    n_layers=8,
    d_state=64,
    dropout=0.1,
)

# Echo the settings so the run is self-describing in the notebook output.
print("📋 Configuration:")
for key in config:
    value = config[key]
    print("  {}: {}".format(key, value))


In [None]:
# Cell 3: Create data loaders
print("📁 Creating data loaders...")

# create_dataloaders returns the train/val/test loaders plus the class count,
# which later cells use to size the model head and label the reports.
(
    train_loader,
    val_loader,
    test_loader,
    num_classes,
) = create_dataloaders(
    model_type='sequence',  # Liquid S4 uses sequence format
    batch_size=config['batch_size'],
    num_workers=2,
    augment=True,
    augment_factor=2,
)

print(
    "✅ Data loaded: {} train, {} val, {} test batches".format(
        len(train_loader), len(val_loader), len(test_loader)
    )
)
print("📊 Classes: {}".format(num_classes))

# Test data loading: pull one training batch and show its shape and labels.
sample_batch = next(iter(train_loader))
(sample_data, sample_labels) = sample_batch
print("📊 Sample batch shape: {}".format(sample_data.shape))
print("📊 Sample labels: {}".format(sample_labels[:5]))


In [None]:
# Cell 4: Create model
# Build the classifier from the config and the class count reported by the
# data loaders, then move it to the selected device.
print("🔧 Creating Liquid S4 model...")
model = LiquidS4AudioClassifier(
    n_mels=config['n_mels'],
    num_classes=num_classes,
    d_model=config['d_model'],
    n_layers=config['n_layers'],
    d_state=config['d_state'],
    dropout=config['dropout'],
    device=device
).to(device)

total_params = sum(p.numel() for p in model.parameters())
print(f"✅ Liquid S4 model created: {total_params:,} parameters")

# Test forward pass on one real batch.
# A freshly constructed nn.Module is in training mode, and torch.no_grad()
# only disables gradient tracking: dropout would still be random and any
# normalization layers would still update their running statistics. Run the
# smoke test in eval mode and then restore whatever mode the model was in.
dummy_batch = next(iter(train_loader))
dummy_input, _dummy_labels = dummy_batch
dummy_input = dummy_input.to(device)
was_training = model.training
model.eval()
with torch.no_grad():
    output = model(dummy_input)
model.train(was_training)
print(f"✅ Forward pass test: {dummy_input.shape} -> {output.shape}")

# The loss and the argmax in the later cells both need [batch, num_classes]
# logits; catch a mismatch here instead of in the middle of the first epoch.
assert tuple(output.shape) == (dummy_input.shape[0], num_classes), (
    f"expected logits of shape ({dummy_input.shape[0]}, {num_classes}), "
    f"got {tuple(output.shape)}"
)

# Model summary
print("\n📊 Model Architecture:")
print(f"  Input: [batch, seq_len, {config['n_mels']}]")
print(f"  Hidden: [batch, seq_len, {config['d_model']}]")
print(f"  Output: [batch, {num_classes}]")
print(f"  Parameters: {total_params:,}")


In [None]:
# Cell 5: Training setup
# Objective, optimizer and learning-rate schedule. The cosine schedule spans
# the configured number of epochs and is stepped once per epoch by the loop.
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(
    model.parameters(),
    lr=config['learning_rate'],
    weight_decay=0.01,
)
scheduler = CosineAnnealingLR(optimizer, T_max=config['epochs'])

# Training tracking: per-epoch histories plus early-stopping state.
train_losses, val_accuracies = [], []
best_val_acc, patience_counter = 0.0, 0

# Emit the setup summary as a single block of lines.
print("\n".join([
    "🚀 Starting training...",
    "📊 Training setup:",
    "  Loss: CrossEntropyLoss",
    "  Optimizer: AdamW (lr={}, wd=0.01)".format(config['learning_rate']),
    "  Scheduler: CosineAnnealingLR",
    "  Early stopping: {} epochs".format(config['patience']),
]))


In [None]:
# Cell 6: Training loop
# Each epoch: one optimization pass over train_loader, one evaluation pass
# over val_loader, one scheduler step, then the early-stopping check. The
# weights with the best validation accuracy so far are written to
# 'best_liquid_s4_model.pth'.

# Fail fast: the per-epoch averages below divide by the number of training
# batches and validation samples, so an empty loader would otherwise surface
# as a ZeroDivisionError only after a full epoch of work.
if len(train_loader) == 0 or len(val_loader) == 0:
    raise ValueError("train_loader and val_loader must each yield at least one batch")

num_epochs = config['epochs']
start_time = time.time()

for epoch in range(num_epochs):
    # Training phase
    model.train()
    total_loss = 0.0

    with tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}') as pbar:
        for data, targets in pbar:
            data, targets = data.to(device), targets.to(device)

            optimizer.zero_grad()
            outputs = model(data)
            loss = criterion(outputs, targets)
            loss.backward()

            # Gradient clipping for stability
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

            optimizer.step()

            # Read the scalar once; each .item() call copies from the device.
            batch_loss = loss.item()
            total_loss += batch_loss

            pbar.set_postfix({'Loss': f'{batch_loss:.4f}'})

    # Mean of the per-batch losses for this epoch.
    avg_train_loss = total_loss / len(train_loader)
    train_losses.append(avg_train_loss)

    # Validation phase
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for data, targets in val_loader:
            data, targets = data.to(device), targets.to(device)
            outputs = model(data)
            # argmax avoids binding `_`, which IPython uses for the last output.
            predicted = outputs.argmax(dim=1)
            total += targets.size(0)
            correct += (predicted == targets).sum().item()

    val_acc = correct / total
    val_accuracies.append(val_acc)

    # The cosine schedule is stepped once per epoch (T_max is in epochs).
    scheduler.step()

    print(f'Epoch {epoch+1:3d} | Loss: {avg_train_loss:.4f} | Val Acc: {val_acc:.4f}')

    # Early stopping: checkpoint on strict improvement, otherwise count the
    # stalled epoch and stop once the patience budget is used up.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        patience_counter = 0
        torch.save(model.state_dict(), 'best_liquid_s4_model.pth')
        print(f'💾 New best model saved! Val Acc: {val_acc:.4f}')
    else:
        patience_counter += 1
        if patience_counter >= config['patience']:
            print(f'🛑 Early stopping after {epoch+1} epochs')
            break

training_time = time.time() - start_time
print(f"⏱️ Training completed in {training_time:.2f} seconds")


In [None]:
# Cell 7: Test evaluation
# Reload the best checkpoint written by the training loop and score it on the
# held-out test loader, collecting predictions for the per-class report.
print("🧪 Evaluating on test set...")
# map_location remaps the saved tensors onto the current device, so the load
# also works when the checkpoint was written on a different device.
model.load_state_dict(torch.load('best_liquid_s4_model.pth', map_location=device))
model.eval()

correct = 0
total = 0
test_predictions = []
test_targets = []

with torch.no_grad():
    for data, targets in test_loader:
        data, targets = data.to(device), targets.to(device)
        outputs = model(data)
        # argmax avoids binding `_`, which IPython uses for the last output.
        predicted = outputs.argmax(dim=1)
        total += targets.size(0)
        correct += (predicted == targets).sum().item()

        test_predictions.extend(predicted.cpu().numpy())
        test_targets.extend(targets.cpu().numpy())

test_acc = correct / total
print(f'🎯 Final Test Accuracy: {test_acc:.4f}')

# Calculate per-class accuracy
# NOTE: confusion_matrix and seaborn are not used in this cell; the
# visualization cell that follows relies on these imports.
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns

# Pass the full label set explicitly. Without `labels`, scikit-learn infers
# the classes from the data, and a class missing from both the test targets
# and the predictions makes the inferred count differ from len(target_names),
# which raises a ValueError.
class_ids = list(range(num_classes))

print("\n📊 Per-class Performance:")
print(classification_report(
    test_targets,
    test_predictions,
    labels=class_ids,
    target_names=[f'Class_{i}' for i in class_ids],
    zero_division=0,  # report 0 for classes with no samples/predictions, without warnings
))


In [None]:
# Cell 8: Results summary and visualization
# Prints the headline numbers from the earlier cells, then plots the training
# curves and the test-set confusion matrix.
print("\n📊 TRAINING SUMMARY")
print("=" * 50)
print("Model: Liquid S4 Audio Classifier")
print(f"Parameters: {total_params:,}")
print(f"Best Validation Accuracy: {best_val_acc:.4f}")
print(f"Final Test Accuracy: {test_acc:.4f}")
print(f"Training Time: {training_time:.2f} seconds")
print("Training completed!")

# Plot training curves
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

# Loss curve
ax1.plot(train_losses, label='Training Loss', color='blue')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.set_title('Training Loss')
ax1.legend()
ax1.grid(True)

# Accuracy curve
ax2.plot(val_accuracies, label='Validation Accuracy', color='green')
ax2.axhline(y=best_val_acc, color='red', linestyle='--', label=f'Best: {best_val_acc:.4f}')
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Accuracy')
ax2.set_title('Validation Accuracy')
ax2.legend()
ax2.grid(True)

plt.tight_layout()
plt.show()

# Confusion Matrix
# Fix the label set so the matrix is always num_classes x num_classes. Without
# `labels`, classes absent from both targets and predictions are dropped and
# the matrix no longer lines up with the num_classes tick labels below.
class_ids = list(range(num_classes))
tick_labels = [f'C{i}' for i in class_ids]
cm = confusion_matrix(test_targets, test_predictions, labels=class_ids)

# Separate names keep `fig` pointing at the training-curves figure above.
cm_fig, cm_ax = plt.subplots(figsize=(12, 10))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=tick_labels,
            yticklabels=tick_labels,
            ax=cm_ax)
cm_ax.set_title('Confusion Matrix - Liquid S4 Audio Classifier')
cm_ax.set_xlabel('Predicted')
cm_ax.set_ylabel('Actual')
plt.show()
