# SNN Language Zone Demo: Prosody-Driven Attention

This notebook demonstrates the **Spiking Neural Network (SNN) Language Zone** with **Multi-Channel Prosody Attention** on:
1. **GoEmotions** (text emotion classification) - primary demo
2. **MNIST** (digit recognition) - baseline comparison

## Components Showcased
- Multi-Channel Spiking Attention (amplitude/pitch/boundary)
- Prosody-Modulated GIF Neurons
- SNN Transformer Operations
- Spike⇄Continuous Bridges

In [None]:
# Install dependencies if needed
# !pip install torch torchvision datasets transformers numpy matplotlib

In [None]:
import sys
sys.path.insert(0, '../')  # Add parent directory to path

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from datasets import load_dataset
from transformers import AutoTokenizer
from torch.utils.data import DataLoader, Subset
import torchvision
import torchvision.transforms as transforms

# Import SNN Language Zone components
from src.core.language_zone.prosody_attention import ProsodyAttentionBridge
from src.core.language_zone.prosody_gif import ProsodyModulatedGIF
from src.core.language_zone.full_language_zone import FullLanguageZone
from src.core.language_zone.gif_neuron import GIFNeuron
from src.core.language_zone.synapsis import Synapsis

print("✅ Imports successful")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

---
# Part 1: GoEmotions Text Classification

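GoEmotions is a dataset of about 58k Reddit comments. Each comment is labeled with one or more of 27 emotion categories or `neutral`, giving 28 labels in total. Its informal text (capitalization, exclamation marks, emotive vocabulary) makes it a natural test bed for prosody-driven attention.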
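Each record's `labels` field is a list, because a comment can carry several emotions at once. Below is a minimal sketch of turning that list into a fixed-length multi-hot target. It assumes the 28-label `simplified` configuration loaded in the next cell. `to_multi_hot` is a hypothetical helper and is not part of `datasets` or this repo. The record is made up.

```python
from typing import List, Sequence

NUM_LABELS = 28  # 27 emotions + "neutral" in the "simplified" configuration

def to_multi_hot(labels: Sequence[int], num_classes: int = NUM_LABELS) -> List[float]:
    """Turn a list of label indices into a fixed-length multi-hot target vector."""
    target = [0.0] * num_classes
    for label in labels:
        if not 0 <= label < num_classes:
            raise ValueError(f"label {label} is outside [0, {num_classes})")
        target[label] = 1.0
    return target

# Made-up record in the same shape as dataset['train'][i]
record = {"text": "Thank you, this made my day!", "labels": [15, 17]}  # gratitude, joy
target = to_multi_hot(record["labels"])
print(len(target), sum(target))  # 28 2.0
```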

In [None]:
# Load GoEmotions dataset
print("Loading GoEmotions dataset...")
dataset = load_dataset("google-research-datasets/go_emotions", "simplified")

# Take a small subset for quick demo
train_data = dataset['train'].select(range(1000))
test_data = dataset['test'].select(range(200))

print(f"Train samples: {len(train_data)}")
print(f"Test samples: {len(test_data)}")
print(f"\nExample:")
print(f"Text: {train_data[0]['text']}")
print(f"Labels: {train_data[0]['labels']}")

In [None]:
# Visualize emotion distribution
emotion_names = [
    'admiration', 'amusement', 'anger', 'annoyance', 'approval', 'caring', 'confusion',
    'curiosity', 'desire', 'disappointment', 'disapproval', 'disgust', 'embarrassment',
    'excitement', 'fear', 'gratitude', 'grief', 'joy', 'love', 'nervousness',
    'optimism', 'pride', 'realization', 'relief', 'remorse', 'sadness', 'surprise', 'neutral'
]

# Count labels
label_counts = [0] * 28
for sample in train_data:
    for label in sample['labels']:
        label_counts[label] += 1

plt.figure(figsize=(15, 5))
plt.bar(range(28), label_counts)
plt.xticks(range(28), emotion_names, rotation=45, ha='right')
plt.ylabel('Count')
plt.title('Emotion Distribution in Training Set')
plt.tight_layout()
plt.show()

## Prosody Extraction Demo

Let's see how prosody channels are extracted from emotional text.
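The channel labels printed in the next cell hint at the cues involved:

- amplitude from CAPS and `!`
- pitch from emotive words and `?`
- boundary from punctuation

The sketch below turns those hints into a self-contained heuristic. The word list, the weights and `extract_prosody_sketch` itself are illustrative assumptions. The notebook does not show how `ProsodyAttentionBridge.extract_prosody` actually scores tokens.

```python
import string
from typing import List, Tuple

# Hypothetical mini-lexicon of emotive words (assumption, not the bridge's real list)
EMOTIVE_WORDS = {"love", "amazing", "incredible", "sad", "disappointed", "wow", "best"}

def extract_prosody_sketch(tokens: List[str]) -> Tuple[List[float], List[float], List[float]]:
    """Score each token on three heuristic channels, each clipped to [0, 1]."""
    amp, pitch, boundary = [], [], []
    for tok in tokens:
        letters = [c for c in tok if c.isalpha()]
        word = tok.strip(string.punctuation).lower()
        # Amplitude: all-caps words (2+ letters) and each '!' read as "louder"
        is_caps = len(letters) >= 2 and all(c.isupper() for c in letters)
        amp.append(min(1.0, 0.5 * is_caps + 0.25 * tok.count("!")))
        # Pitch: emotive vocabulary and question marks
        pitch.append(min(1.0, 0.5 * (word in EMOTIVE_WORDS) + 0.5 * ("?" in tok)))
        # Boundary: token ends with clause- or sentence-level punctuation
        boundary.append(1.0 if tok and tok[-1] in ".!?;:," else 0.0)
    return amp, pitch, boundary

tokens = "This is AMAZING! I absolutely love it!".split()
amp, pitch, boundary = extract_prosody_sketch(tokens)
print(amp)       # [0.0, 0.0, 0.75, 0.0, 0.0, 0.0, 0.25]
print(pitch)     # [0.0, 0.0, 0.5, 0.0, 0.0, 0.5, 0.0]
print(boundary)  # [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0]
```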

In [None]:
# Initialize prosody attention bridge
prosody_bridge = ProsodyAttentionBridge(
    attention_preset='emotional',  # High sensitivity to emotional content
    k_winners=5
)

# Example texts with different emotional content
test_texts = [
    "This is AMAZING! I absolutely love it!",
    "I'm so disappointed and sad about this.",
    "The quick brown fox jumps over the lazy dog.",
    "WOW! This is incredible! Best day ever!!!"
]

for i, text in enumerate(test_texts):
    # Tokenize (simple whitespace split for demo)
    tokens = text.split()
    
    # Extract prosody channels
    amp, pitch, boundary = prosody_bridge.extract_prosody(tokens)
    
    print(f"\n{'='*60}")
    print(f"Text {i+1}: {text}")
    print(f"Tokens: {tokens}")
    print(f"Amplitude (CAPS/!): {amp}")
    print(f"Pitch (emotive/?): {pitch}")
    print(f"Boundary (punct): {boundary}")
    
    # Compute attention
    token_ids = list(range(len(tokens)))
    result = prosody_bridge.compute_attention_gains(
        token_ids=token_ids,
        token_strings=tokens
    )
    
    print(f"Salience: {result['salience']}")
    print(f"Winner indices: {result['winners_idx']}")
    print(f"Winner tokens: {[tokens[w] for w in result['winners_idx']]}")
    print(f"Attention gain (μ): {result['mu_scalar']:.3f}")

In [None]:
# Visualize prosody channels
text_idx = 0  # "This is AMAZING! I absolutely love it!"
tokens = test_texts[text_idx].split()
amp, pitch, boundary = prosody_bridge.extract_prosody(tokens)

fig, axes = plt.subplots(3, 1, figsize=(12, 6), sharex=True)

axes[0].stem(amp, basefmt=' ')
axes[0].set_ylabel('Amplitude')
axes[0].set_title('Prosody Channels: "' + test_texts[text_idx] + '"')
axes[0].grid(True, alpha=0.3)

axes[1].stem(pitch, basefmt=' ', linefmt='C1-', markerfmt='C1o')
axes[1].set_ylabel('Pitch')
axes[1].grid(True, alpha=0.3)

axes[2].stem(boundary, basefmt=' ', linefmt='C2-', markerfmt='C2o')
axes[2].set_ylabel('Boundary')
axes[2].set_xlabel('Token Index')
axes[2].set_xticks(range(len(tokens)))
axes[2].set_xticklabels(tokens, rotation=45, ha='right')
axes[2].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## GIF Neuron Modulation Demo

Compare a standard GIF neuron with a prosody-modulated one, and show how attention gains change the spiking output.
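The summary at the end of this notebook states that attention gains modulate GIF neuron thresholds. The sketch below illustrates that idea with a plain binary leaky integrate-and-fire (LIF) neuron, which is simpler than the notebook's `GIFNeuron`. A gain above 1 lowers the firing threshold, so the same input produces more spikes inside the attended window. The leak, the threshold and the threshold formula are assumptions chosen for illustration. `strength` only mirrors the spirit of `attention_modulation_strength`.

```python
from typing import List, Sequence

def lif_with_attention(
    inputs: Sequence[float],
    gains: Sequence[float],
    tau: float = 0.9,        # membrane leak per step (assumed value)
    threshold: float = 1.0,  # base firing threshold (assumed value)
    strength: float = 0.5,   # how strongly gain lowers the threshold (assumed formula)
) -> List[int]:
    """Binary leaky integrate-and-fire neuron; attention gain > 1 lowers the threshold."""
    v = 0.0
    spikes: List[int] = []
    for x, g in zip(inputs, gains):
        v = tau * v + x  # leak, then integrate this step's input
        theta = threshold / (1.0 + strength * (g - 1.0))  # gain 1 -> base threshold
        if v >= theta:
            spikes.append(1)
            v -= theta  # soft reset keeps the overshoot
        else:
            spikes.append(0)
    return spikes

seq_len = 20
inputs = [0.3] * seq_len  # constant drive, for a clean comparison
flat = [1.0] * seq_len    # uniform attention
peaks = [3.0 if 5 <= t < 10 else 1.0 for t in range(seq_len)]  # same window as the demo

base = lif_with_attention(inputs, flat)
emo = lif_with_attention(inputs, peaks)
print("spikes in t=5-9:", sum(base[5:10]), "->", sum(emo[5:10]))
```

The sketch uses a soft reset (subtracting the threshold), so input that drives the membrane past threshold is not discarded. A hard reset to 0 is the other common choice.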

In [None]:
# Create GIF neurons with/without prosody modulation
gif_standard = GIFNeuron(input_dim=64, hidden_dim=128, L=16)
gif_prosody = ProsodyModulatedGIF(
    input_dim=64, 
    hidden_dim=128, 
    L=16,
    attention_modulation_strength=0.5  # Strong modulation for demo
)

# Create test input
batch_size = 1
seq_len = 20
x = torch.randn(batch_size, seq_len, 64)

# Case 1: No attention (baseline)
spikes_baseline, _ = gif_standard(x)

# Case 2: Uniform attention
attention_uniform = torch.ones(batch_size, seq_len)
spikes_uniform, _ = gif_prosody(x, attention_gains=attention_uniform)

# Case 3: High attention on some timesteps (simulating emotional content)
attention_emotional = torch.ones(batch_size, seq_len)
attention_emotional[0, 5:10] = 3.0  # High attention for "emotional" timesteps
spikes_emotional, _ = gif_prosody(x, attention_gains=attention_emotional)

# Visualize spike counts
plt.figure(figsize=(14, 4))

plt.subplot(1, 3, 1)
plt.imshow(spikes_baseline[0].T.detach().numpy(), aspect='auto', cmap='hot')
plt.colorbar(label='Spike count')
plt.title('Standard GIF (no modulation)')
plt.ylabel('Neuron Index')
plt.xlabel('Time')

plt.subplot(1, 3, 2)
plt.imshow(spikes_uniform[0].T.detach().numpy(), aspect='auto', cmap='hot')
plt.colorbar(label='Spike count')
plt.title('Prosody GIF (uniform attention)')
plt.xlabel('Time')

plt.subplot(1, 3, 3)
plt.imshow(spikes_emotional[0].T.detach().numpy(), aspect='auto', cmap='hot')
plt.colorbar(label='Spike count')
plt.title('Prosody GIF (emotional peaks)')
plt.xlabel('Time')

plt.tight_layout()
plt.show()

# Print statistics
print(f"Spike counts:")
print(f"  Baseline: {spikes_baseline.sum():.0f}")
print(f"  Uniform attention: {spikes_uniform.sum():.0f}")
print(f"  Emotional peaks: {spikes_emotional.sum():.0f}")
print(f"\nSpikes during emotional peak (t=5-9):")
print(f"  Baseline: {spikes_baseline[0, 5:10].sum():.0f}")
print(f"  Emotional: {spikes_emotional[0, 5:10].sum():.0f}")
print(f"  Ratio: {spikes_emotional[0, 5:10].sum() / (spikes_baseline[0, 5:10].sum() + 1e-8):.2f}x")

## Full Language Zone Demo on GoEmotions

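Build a simple emotion classifier on top of the prosody-driven SNN and run a forward pass. The model is not trained in this notebook, so its predictions are effectively random.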
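No training loop is included yet; full training is listed under Next Steps. GoEmotions is multi-label, so a common choice is an independent sigmoid per emotion with binary cross-entropy. That differs from the softmax used in the inference cell below.

Below is a stdlib sketch of that loss and of thresholded prediction. In PyTorch the equivalent loss is `torch.nn.BCEWithLogitsLoss`. The 0.5 threshold is an assumption that you would normally tune on validation data.

```python
import math
from typing import List, Sequence

def sigmoid(z: float) -> float:
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)

def bce_with_logits(logits: Sequence[float], targets: Sequence[float]) -> float:
    """Mean binary cross-entropy over classes, computed from raw logits."""
    if len(logits) != len(targets):
        raise ValueError("logits and targets must have the same length")
    total = 0.0
    for z, y in zip(logits, targets):
        # Stable form of -[y*log(sigmoid(z)) + (1 - y)*log(1 - sigmoid(z))]
        total += max(z, 0.0) - z * y + math.log1p(math.exp(-abs(z)))
    return total / len(logits)

def predict_labels(logits: Sequence[float], threshold: float = 0.5) -> List[int]:
    """Each class is decided independently, so a comment can get zero, one or several labels."""
    return [i for i, z in enumerate(logits) if sigmoid(z) >= threshold]

print(round(bce_with_logits([0.0, 0.0], [1.0, 0.0]), 4))  # 0.6931 (= ln 2)
print(predict_labels([2.0, -1.0, 0.0]))                   # [0, 2]
```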

In [None]:
# Simple tokenizer (character-level for demo)
class SimpleCharTokenizer:
    def __init__(self, vocab_size=128):
        self.vocab_size = vocab_size
    
    def encode(self, text, max_len=50):
        # Map each character to its Unicode code point, clamped to vocab_size-1
        tokens = [min(ord(c), self.vocab_size-1) for c in text[:max_len]]
        # Pad
        tokens = tokens + [0] * (max_len - len(tokens))
        return tokens
    
    def get_token_strings(self, text, max_len=50):
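        # Return whitespace-split words for prosody extraction.
        # Note: these are word-level while encode() is character-level,
        # so index i of the two lists generally refers to different text.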
        words = text.split()[:max_len]
        return words + [''] * (max_len - len(words))

tokenizer = SimpleCharTokenizer(vocab_size=256)

# Test tokenization
sample_text = "This is amazing!"
tokens = tokenizer.encode(sample_text, max_len=20)
print(f"Sample text: {sample_text}")
print(f"Token IDs: {tokens[:len(sample_text)]}")
print(f"Decoded: {''.join([chr(t) for t in tokens if t > 0])}")

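`encode` produces one ID per character, but `get_token_strings` returns one string per word. If the language zone expects the two inputs to be position-aligned (an assumption; the notebook does not say), one option is to give every character position the word it belongs to. `char_aligned_words` is a hypothetical helper, not part of the repo.

```python
import re
from typing import List

def char_aligned_words(text: str, max_len: int = 50) -> List[str]:
    """Return, for each character position, the whitespace-delimited word covering it.

    Whitespace and padding positions map to '', so index i lines up with
    SimpleCharTokenizer.encode(text, max_len)[i].
    """
    aligned = [""] * max_len
    for match in re.finditer(r"\S+", text):  # one match per word, with its character span
        for pos in range(match.start(), min(match.end(), max_len)):
            aligned[pos] = match.group()
    return aligned

print(char_aligned_words("Wow! ok", max_len=10))
# ['Wow!', 'Wow!', 'Wow!', 'Wow!', '', 'ok', 'ok', '', '', '']
```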
In [None]:
# Create emotion classifier
class EmotionClassifier(nn.Module):
    def __init__(self, vocab_size=256, embed_dim=64, hidden_dim=128, num_emotions=28):
        super().__init__()
        
        # Use prosody-modulated language zone
        self.language_zone = FullLanguageZone(
            vocab_size=vocab_size,
            embed_dim=embed_dim,
            hidden_dim=hidden_dim,
            attention_preset='emotional',  # High prosody sensitivity
            num_experts=4  # Small for demo
        )
        
        # Classification head
        self.classifier = nn.Linear(vocab_size, num_emotions)
    
    def forward(self, input_ids, token_strings=None):
        # Get language zone output
        logits, info = self.language_zone(input_ids, token_strings)
        
        # Pool over sequence (mean)
        pooled = logits.mean(dim=1)  # (batch, vocab_size)
        
        # Classify
        emotion_logits = self.classifier(pooled)
        
        return emotion_logits, info

model = EmotionClassifier().to(device)
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

In [None]:
# Quick inference test
model.eval()
with torch.no_grad():
    # Prepare sample
    text = "This is absolutely amazing! I love it!"
    token_ids = torch.tensor([tokenizer.encode(text, max_len=30)]).to(device)
    token_strings = [tokenizer.get_token_strings(text, max_len=30)]
    
    # Forward pass
    emotion_logits, info = model(token_ids, token_strings)
    
    print(f"Input: {text}")
    print(f"\nProsody stats:")
    print(f"  Mean gain: {info['prosody_stats']['mean_gain']:.3f}")
    print(f"  Max gain: {info['prosody_stats']['max_gain']:.3f}")
    print(f"  Winner tokens: {info['attention']['winners'][0]}")
    
    # Top predicted emotions (untrained model, so effectively random)
    probs = torch.softmax(emotion_logits, dim=-1)
    top5 = torch.topk(probs[0], 5)
    print(f"\nTop 5 predicted emotions:")
    for val, idx in zip(top5.values, top5.indices):
        print(f"  {emotion_names[idx]}: {val:.3f}")

## Comparing Attention Presets

Run the same sentence through each preset and compare the per-token salience and the selected winners (marked ★).
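A preset plausibly amounts to a different weighting of the three channels before k-winners-take-all (k-WTA) picks the top tokens. The sketch below shows how two weightings can select different winners for the same sentence. `PRESET_WEIGHTS` reuses two of the notebook's preset names. The weights, the channel values and `kwta_winners` are hypothetical and are not the library's configuration.

```python
from typing import Dict, List, Sequence, Tuple

# Hypothetical (amplitude, pitch, boundary) weights; the real preset values are not shown here
PRESET_WEIGHTS: Dict[str, Tuple[float, float, float]] = {
    "analytical": (0.2, 0.2, 1.0),  # leans on structural cues (punctuation)
    "emotional": (1.0, 1.0, 0.2),   # leans on loudness and emotive words
}

def kwta_winners(
    channels: Tuple[Sequence[float], Sequence[float], Sequence[float]],
    weights: Tuple[float, float, float],
    k: int = 2,
) -> Tuple[List[float], List[int]]:
    """Weighted salience per token, then k-WTA (winner indices returned in token order)."""
    amp, pitch, boundary = channels
    w_a, w_p, w_b = weights
    salience = [w_a * a + w_p * p + w_b * b for a, p, b in zip(amp, pitch, boundary)]
    # Rank token indices by salience (stable sort keeps earlier tokens first on ties)
    ranked = sorted(range(len(salience)), key=lambda i: salience[i], reverse=True)
    return salience, sorted(ranked[:k])

tokens = ["However,", "this", "is", "AMAZING", "and", "I", "love", "it."]
# Hand-set channel values, for illustration only
amp = [0.0, 0.0, 0.0, 0.5, 0.0, 0.0, 0.0, 0.0]
pitch = [0.0, 0.0, 0.0, 0.5, 0.0, 0.0, 0.5, 0.0]
boundary = [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0]

for name, weights in PRESET_WEIGHTS.items():
    _, winners = kwta_winners((amp, pitch, boundary), weights, k=2)
    print(name, [tokens[i] for i in winners])
# analytical ['However,', 'it.']
# emotional ['AMAZING', 'love']
```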

In [None]:
# Test different presets
presets = ['analytical', 'emotional', 'historical', 'streaming']
text = "WOW! This is absolutely INCREDIBLE and AMAZING!"
tokens = text.split()
token_ids = list(range(len(tokens)))

results = {}
for preset in presets:
    bridge = ProsodyAttentionBridge(attention_preset=preset, k_winners=3)
    result = bridge.compute_attention_gains(token_ids, tokens)
    results[preset] = result

# Plot comparison
fig, axes = plt.subplots(2, 2, figsize=(14, 8))
axes = axes.flatten()

for i, preset in enumerate(presets):
    sal = results[preset]['salience']
    winners = results[preset]['winners_idx']
    
    axes[i].bar(range(len(tokens)), sal, color=['red' if j in winners else 'blue' for j in range(len(tokens))])
    axes[i].set_title(f'{preset.capitalize()} (μ={results[preset]["mu_scalar"]:.2f})')
    axes[i].set_xticks(range(len(tokens)))
    axes[i].set_xticklabels(tokens, rotation=45, ha='right')
    axes[i].set_ylabel('Salience')
    axes[i].grid(True, alpha=0.3)
    
    # Annotate winners
    for w in winners:
        axes[i].text(w, sal[w] + 0.05, '★', ha='center', fontsize=16, color='red')

plt.suptitle(f'Attention Preset Comparison: "{text}"', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

---
# Part 2: MNIST Baseline (Vision)

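A quick sanity check of the same SNN building blocks (`Synapsis` and `GIFNeuron`) on digit recognition. The network is untrained, so accuracy should be near chance (about 10% for 10 classes).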
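`SNNMNISTClassifier` below feeds each image as a single time step, so the spiking layers see no temporal structure. A common alternative, not used in this notebook, is rate coding. The image is presented for T steps, and each pixel fires with probability equal to its intensity.

Below is a stdlib sketch of rate coding. It expects intensities in [0, 1], which means applying it before the `Normalize` transform. `rate_encode` is a hypothetical helper. Stacking its output per image would give a (batch, T, 784) input. Whether the notebook's `Synapsis`/`GIFNeuron` layers handle T > 1 as intended is an assumption.

```python
import random
from typing import List, Sequence

def rate_encode(pixels: Sequence[float], num_steps: int = 8, seed: int = 0) -> List[List[int]]:
    """Bernoulli rate coding: at each step, a pixel with intensity p in [0, 1] spikes with probability p."""
    if any(not 0.0 <= p <= 1.0 for p in pixels):
        raise ValueError("rate coding expects intensities in [0, 1] (apply it before Normalize)")
    rng = random.Random(seed)  # seeded so the spike trains are reproducible
    return [[1 if rng.random() < p else 0 for p in pixels] for _ in range(num_steps)]

trains = rate_encode([0.0, 1.0, 0.5], num_steps=1000)  # shape (num_steps, num_pixels)
rates = [sum(step[j] for step in trains) / len(trains) for j in range(3)]
print(rates)  # first exactly 0.0, second exactly 1.0, third close to 0.5
```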

In [None]:
# Load MNIST
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

mnist_train = torchvision.datasets.MNIST(
    root='./data', train=True, download=True, transform=transform
)
mnist_test = torchvision.datasets.MNIST(
    root='./data', train=False, download=True, transform=transform
)

# Use subset for quick demo
mnist_train_small = Subset(mnist_train, range(1000))
mnist_test_small = Subset(mnist_test, range(200))

train_loader = DataLoader(mnist_train_small, batch_size=32, shuffle=True)
test_loader = DataLoader(mnist_test_small, batch_size=32, shuffle=False)

print(f"MNIST train: {len(mnist_train_small)} samples")
print(f"MNIST test: {len(mnist_test_small)} samples")

# Visualize samples
fig, axes = plt.subplots(2, 5, figsize=(12, 5))
for i, ax in enumerate(axes.flat):
    img, label = mnist_train[i]
    ax.imshow(img.squeeze(), cmap='gray')
    ax.set_title(f'Label: {label}')
    ax.axis('off')
plt.tight_layout()
plt.show()

In [None]:
# Simple SNN classifier for MNIST
class SNNMNISTClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        
        # Flatten 28x28 image to 784
        # Use 2-layer SNN
        self.synapsis1 = Synapsis(784, 256)
        self.gif1 = GIFNeuron(256, 256, L=8)
        
        self.synapsis2 = Synapsis(256, 128)
        self.gif2 = GIFNeuron(128, 128, L=8)
        
        self.classifier = nn.Linear(128, 10)
    
    def forward(self, x):
        # x: (batch, 1, 28, 28)
        batch_size = x.shape[0]
        
        # Flatten and add time dimension
        x = x.view(batch_size, 1, 784)  # (batch, 1, 784): a single time step
        
        # Layer 1
        h1, _ = self.synapsis1(x)
        s1, _ = self.gif1(h1)
        
        # Layer 2
        h2, _ = self.synapsis2(s1)
        s2, _ = self.gif2(h2)
        
        # Pool spikes over time and classify
        pooled = s2.mean(dim=1)  # (batch, 128)
        logits = self.classifier(pooled)
        
        return logits, {'spike_rate_l1': s1.mean(), 'spike_rate_l2': s2.mean()}

mnist_model = SNNMNISTClassifier().to(device)
print(f"MNIST model parameters: {sum(p.numel() for p in mnist_model.parameters()):,}")

In [None]:
# Quick inference test
mnist_model.eval()
with torch.no_grad():
    # Get a batch
    images, labels = next(iter(test_loader))
    images, labels = images.to(device), labels.to(device)
    
    # Forward pass
    logits, info = mnist_model(images)
    preds = logits.argmax(dim=1)
    
    # Calculate accuracy
    acc = (preds == labels).float().mean()
    
    print(f"Random initialization accuracy: {acc:.2%}")
    print(f"Spike rate layer 1: {info['spike_rate_l1']:.3f}")
    print(f"Spike rate layer 2: {info['spike_rate_l2']:.3f}")
    
    # Visualize predictions
    fig, axes = plt.subplots(2, 5, figsize=(12, 5))
    for i, ax in enumerate(axes.flat):
        ax.imshow(images[i].cpu().squeeze(), cmap='gray')
        correct = '✓' if preds[i] == labels[i] else '✗'
        ax.set_title(f'True:{labels[i]} Pred:{preds[i]} {correct}')
        ax.axis('off')
    plt.suptitle('SNN Predictions (Untrained)', fontsize=14)
    plt.tight_layout()
    plt.show()

---
# Summary

## Key Takeaways

1. **Prosody Extraction**: Amplitude/pitch/boundary channels capture surface cues of emotional content in text (capitalization, `!`/`?`, emotive words, punctuation)
2. **Attention Modulation**: k-WTA selects the most salient tokens and modulates GIF neuron thresholds
3. **Preset Configurations**: Different presets (analytical, emotional, historical, streaming) emphasize different prosody features
4. **SNN Components**: GIF neurons + Synapsis layers provide spike-based processing (energy efficiency is not yet measured in this notebook)

## Component Status
- ✅ Multi-Channel Spiking Attention
- ✅ Prosody-Modulated GIF Neurons
- ✅ SNN Transformer Operations
- ✅ Spike⇄Continuous Bridges
- ✅ Full Language Zone Integration

## Next Steps
- Full training on GoEmotions
- Liquid MoE integration (async routing)
- Energy tracking and optimization
- Multi-task learning across emotions