# 🧠 Quick Prototype - All-in-One Workspace

**Purpose:** Rapid prototyping and experimentation for the entire team

**Use this notebook to:**
- Test ideas quickly
- Experiment with different approaches
- Share code snippets
- Debug integration issues

**Team:** All members
**Phase:** Days 1-7 (MVP)

## 📚 Setup & Imports

In [None]:
# Core libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# PyTorch & SNN
import torch
import torch.nn as nn
import snntorch as snn

# Data generation
import neurokit2 as nk

# Visualization
import plotly.graph_objects as go
import plotly.express as px

# Utilities
from tqdm.notebook import tqdm
import warnings
warnings.filterwarnings('ignore')

# Set random seeds
np.random.seed(42)
torch.manual_seed(42)

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Matplotlib config
plt.style.use('default')
sns.set_palette("husl")
%matplotlib inline

## 🔬 Section 1: Quick Data Generation Test

**Owner:** CS3 / Data Engineer

Generate two synthetic ECG samples to verify that neurokit2 works.

In [None]:
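# Generate two 10 s synthetic ECGs at 250 Hz (2,500 samples each)
# Note: the 120 bpm trace differs from the normal one in rate only (tachycardia),
# so it is a stand-in for an abnormal class rather than an irregular rhythm
ecg_normal = nk.ecg_simulate(duration=10, sampling_rate=250, heart_rate=70)
ecg_arrhythmia = nk.ecg_simulate(duration=10, sampling_rate=250, heart_rate=120)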

# Plot
fig, axes = plt.subplots(2, 1, figsize=(15, 6))

axes[0].plot(ecg_normal)
axes[0].set_title('Normal ECG (70 bpm)')
axes[0].set_xlabel('Sample')
axes[0].set_ylabel('Amplitude')

axes[1].plot(ecg_arrhythmia, color='red')
axes[1].set_title('Arrhythmia ECG (120 bpm)')
axes[1].set_xlabel('Sample')
axes[1].set_ylabel('Amplitude')

plt.tight_layout()
plt.show()

print(f"✅ Generated ECG with {len(ecg_normal)} samples")
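To check that the simulated traces actually beat at the requested rate, a rough heart-rate estimate from R-peak spacing is enough. The sketch below uses a hypothetical helper, `estimate_bpm`, that is not part of neurokit2: it assumes R-peaks are local maxima above a fraction of the normalized signal's range and that beats are at least `min_gap_s` apart. It is only a sanity check; neurokit2's `nk.ecg_peaks` is the proper detector for real analysis.

```python
import numpy as np

def estimate_bpm(signal, sampling_rate, threshold=0.6, min_gap_s=0.25):
    """Rough heart-rate estimate from R-peak spacing (sanity check only).

    Assumptions (hypothetical helper, not a clinical QRS detector):
    - R-peaks are the tallest deflections, so a peak must exceed
      `threshold` of the min-max normalized signal.
    - Consecutive beats are at least `min_gap_s` seconds apart; within a gap
      the first candidate is kept.
    """
    x = np.asarray(signal, dtype=float)
    x = (x - x.min()) / (x.max() - x.min() + 1e-12)

    # Candidate peaks: samples above the threshold that are local maxima
    is_peak = (x[1:-1] > x[:-2]) & (x[1:-1] >= x[2:]) & (x[1:-1] > threshold)
    candidates = np.where(is_peak)[0] + 1

    # Accept a candidate only if it is far enough from the last accepted peak
    min_gap = int(min_gap_s * sampling_rate)
    peaks = []
    for idx in candidates:
        if not peaks or idx - peaks[-1] >= min_gap:
            peaks.append(idx)

    if len(peaks) < 2:
        return float("nan")
    # Mean R-R interval in seconds -> beats per minute
    rr = np.diff(peaks) / sampling_rate
    return 60.0 / rr.mean()
```

Usage in this notebook: `print(f"{estimate_bpm(ecg_normal, 250):.0f} bpm (requested 70)")`. Expect small deviations because `ecg_simulate` adds some heart-rate variability.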

## 🧪 Section 2: Spike Encoding Test

**Owner:** CS2 / SNN Expert

Convert a continuous signal into a spike train.

In [None]:
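# Simple rate coding
def rate_encode(signal, num_steps=100, gain=1.0):
    """Convert a 1-D signal into a single-neuron spike train using rate coding.

    The signal is min-max normalized, averaged over `num_steps` equal-width bins
    (assumes len(signal) >= num_steps), and each bin spikes with probability
    clip(rate * gain, 0, 1).
    """
    signal = np.asarray(signal, dtype=float)
    # Normalize signal to [0, 1] (epsilon guards against a constant signal)
    signal_norm = (signal - signal.min()) / (signal.max() - signal.min() + 1e-12)

    # Average the normalized signal within each bin
    bins = np.linspace(0, len(signal), num_steps+1, dtype=int)
    rates = np.array([signal_norm[bins[i]:bins[i+1]].mean() for i in range(num_steps)])

    # One Bernoulli trial per bin; clip so gain > 1 cannot push probabilities past 1
    spike_prob = np.clip(rates * gain, 0.0, 1.0)
    spikes = np.random.rand(num_steps) < spike_prob

    return spikes.astype(float)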

# Test encoding
spike_train = rate_encode(ecg_normal, num_steps=100)

# Visualize
fig, axes = plt.subplots(2, 1, figsize=(15, 6))

# Original signal
axes[0].plot(ecg_normal)
axes[0].set_title('Original ECG Signal')
axes[0].set_ylabel('Amplitude')

# Spike train
spike_times = np.where(spike_train)[0]
axes[1].eventplot(spike_times, colors='black')
axes[1].set_title('Encoded Spike Train (Rate Coding)')
axes[1].set_xlabel('Time Step')
axes[1].set_ylabel('Neuron')
axes[1].set_ylim([0.5, 1.5])

plt.tight_layout()
plt.show()

print(f"✅ Encoded to {spike_train.sum():.0f} spikes ({spike_train.sum()/len(spike_train)*100:.1f}% active)")
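Rate coding makes the spike probability at each time step proportional to the local signal intensity, so over many steps the empirical firing rate should approach that probability. The sketch below checks this on constant inputs with a hypothetical helper, `bernoulli_spikes`, that mirrors the Bernoulli sampling in `rate_encode` and clips the probability to [0, 1] (since `rate * gain` can exceed 1 when `gain > 1`). snnTorch ships a tensor version of the idea as `snntorch.spikegen.rate`.

```python
import numpy as np

def bernoulli_spikes(prob, num_steps, rng):
    """Draw a rate-coded spike train: one Bernoulli trial per time step.

    `prob` is clipped to [0, 1] first, because gain * rate can exceed 1.
    """
    p = float(np.clip(prob, 0.0, 1.0))
    return (rng.random(num_steps) < p).astype(float)

rng = np.random.default_rng(0)
for intensity in [0.05, 0.3, 0.8]:
    train = bernoulli_spikes(intensity, num_steps=10_000, rng=rng)
    # The empirical firing rate should sit close to the requested probability
    print(f"p={intensity:.2f} -> observed rate {train.mean():.3f}")
```

With only 100 steps, as in the ECG encoding above, the observed rate per bin is just a single 0/1 draw, which is why rate codes usually need many time steps per value.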

## 🧠 Section 3: Simple SNN Test

**Owner:** CS2 / SNN Expert

Create and test a basic snnTorch model.

In [None]:
# Simple SNN with LIF neurons
class SimpleSNN(nn.Module):
    def __init__(self, input_size=100, hidden=64, output=2, beta=0.9):
        super().__init__()
        
        self.fc1 = nn.Linear(input_size, hidden)
        self.lif1 = snn.Leaky(beta=beta)
        
        self.fc2 = nn.Linear(hidden, output)
        self.lif2 = snn.Leaky(beta=beta)
        
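    def forward(self, x):
        """x: [time_steps, batch, input_size] -> (spikes, membrane), each [time_steps, batch, output]"""
        # Initialize membrane potentials
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()

        # Record output-layer spikes and membrane potentials at every step
        spk2_rec = []
        mem2_rec = []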
        
        # Process each timestep
        for step in range(x.size(0)):
            cur1 = self.fc1(x[step])
            spk1, mem1 = self.lif1(cur1, mem1)
            
            cur2 = self.fc2(spk1)
            spk2, mem2 = self.lif2(cur2, mem2)
            
            spk2_rec.append(spk2)
            mem2_rec.append(mem2)
        
        return torch.stack(spk2_rec), torch.stack(mem2_rec)

# Create model
model = SimpleSNN().to(device)
print(model)
print(f"\n✅ Model has {sum(p.numel() for p in model.parameters()):,} parameters")
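Each `snn.Leaky` layer is a leaky integrate-and-fire (LIF) neuron. As we read the snnTorch tutorials, with the defaults (threshold 1.0, reset by subtraction) the membrane potential decays by `beta` each step, integrates the input current, spikes when it exceeds the threshold, and has the threshold subtracted on the following step. The NumPy sketch below (hypothetical `lif_simulate`) reproduces that update for one neuron so the effect of `beta` can be explored without PyTorch; it illustrates the dynamics and is not snnTorch's code, and it ignores the surrogate gradient used for training.

```python
import numpy as np

def lif_simulate(current, beta=0.9, threshold=1.0):
    """Simulate one leaky integrate-and-fire neuron with reset by subtraction.

    U[t] = beta * U[t-1] + I[t] - S[t-1] * threshold
    S[t] = 1 if U[t] > threshold else 0
    """
    mem = 0.0
    spk_prev = 0.0
    mem_rec, spk_rec = [], []
    for i_t in current:
        mem = beta * mem + i_t - spk_prev * threshold  # leak, integrate, reset
        spk = 1.0 if mem > threshold else 0.0          # fire
        mem_rec.append(mem)
        spk_rec.append(spk)
        spk_prev = spk
    return np.array(spk_rec), np.array(mem_rec)

# Constant input of 0.4: with beta=0.5 the potential settles at 0.4 / (1 - 0.5) = 0.8,
# below threshold, so the neuron never fires; with beta=0.9 it does fire.
for beta in (0.5, 0.9):
    spk, _ = lif_simulate(np.full(50, 0.4), beta=beta)
    print(f"beta={beta}: {int(spk.sum())} spikes in 50 steps")
```

The steady-state check `I / (1 - beta) > threshold` is a quick way to reason about whether a constant input can drive a layer to spike at all.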

In [None]:
# Test forward pass
test_input = torch.randn(100, 1, 100).to(device)  # [time_steps, batch, features]
spikes, membrane = model(test_input)

print("✅ Forward pass successful!")
print(f"   Output shape: {spikes.shape}")
print(f"   Total spikes: {spikes.sum().item():.0f}")
print(f"   Sparsity: {(1 - spikes.sum().item() / spikes.numel()) * 100:.1f}%")
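The sparsity printed above is one minus the fraction of (time step, sample, neuron) entries that spiked. Looking at per-neuron firing rates as well makes dead or saturated output neurons easy to spot. A small hypothetical NumPy helper, `spike_stats`, for arrays shaped `[time_steps, batch, neurons]` such as `spikes.detach().cpu().numpy()`:

```python
import numpy as np

def spike_stats(spk):
    """Summarize a binary spike array shaped [time_steps, batch, neurons]."""
    spk = np.asarray(spk, dtype=float)
    total = spk.sum()
    sparsity = 1.0 - total / spk.size                # fraction of silent entries
    rate_per_neuron = spk.mean(axis=(0, 1))          # spike probability per step, per neuron
    silent = np.where(rate_per_neuron == 0)[0]       # neurons that never fired
    saturated = np.where(rate_per_neuron == 1)[0]    # neurons that fired every step
    return {"total": int(total), "sparsity": sparsity,
            "rate_per_neuron": rate_per_neuron,
            "silent": silent, "saturated": saturated}
```

For an untrained model, a silent or saturated output neuron usually means the weight scale or `beta` needs adjusting before training.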

## 📊 Section 4: Visualization Tests

**Owner:** CS4 / Deployment

Test visualization components for the demo.

In [None]:
# Interactive spike raster plot with Plotly
def plot_spike_raster(spikes, title="Spike Raster Plot"):
    """
    Create interactive spike raster plot
    spikes: [time_steps, neurons] tensor
    """
    spikes_np = spikes.detach().cpu().numpy()
    
    # Find spike times and neuron indices
    spike_times, neurons = np.where(spikes_np > 0)
    
    fig = go.Figure()
    
    fig.add_trace(go.Scatter(
        x=spike_times,
        y=neurons,
        mode='markers',
        marker=dict(size=3, color='black'),
        name='Spikes'
    ))
    
    fig.update_layout(
        title=title,
        xaxis_title="Time Step",
        yaxis_title="Neuron Index",
        height=400
    )
    
    return fig

# Plot output spikes from SNN
fig = plot_spike_raster(spikes[:, 0, :], "SNN Output Spike Pattern")
fig.show()

print("✅ Interactive visualization working!")
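For static figures (e.g. slides), the same raster can be drawn with Matplotlib's `eventplot`, which expects one array of event times per neuron rather than a 2-D matrix. A minimal conversion sketch; `spikes_to_events` is a hypothetical name:

```python
import numpy as np

def spikes_to_events(spk):
    """Convert a [time_steps, neurons] spike matrix into per-neuron spike-time arrays."""
    spk = np.asarray(spk)
    # One array of time-step indices per neuron (column); empty if it never fired
    return [np.flatnonzero(spk[:, n] > 0) for n in range(spk.shape[1])]
```

Usage: `plt.eventplot(spikes_to_events(spikes[:, 0, :].detach().cpu().numpy()), colors='black')` draws neuron `n` on row `n`.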

## 🎯 Section 5: Integration Test

**Owner:** CS1 / Team Lead

End-to-end pipeline test.

In [None]:
def test_full_pipeline():
    """Test the complete data → model → prediction pipeline"""
    print("Testing full pipeline...\n")
    
    # 1. Generate data
    print("1️⃣ Generating ECG...")
    ecg = nk.ecg_simulate(duration=10, sampling_rate=250, heart_rate=70)
    print(f"   ✅ Shape: {ecg.shape}")
    
    # 2. Encode to spikes
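    print("\n2️⃣ Encoding to spikes...")
    spikes = rate_encode(ecg, num_steps=100)
    # Copy the single spike train to the 100 input features the model expects:
    # [100] -> [time_steps=100, batch=1, features=100]; every input neuron sees identical spikes
    spikes_tensor = torch.tensor(spikes, dtype=torch.float32).view(-1, 1, 1).repeat(1, 1, 100).to(device)
    print(f"   ✅ Shape: {spikes_tensor.shape}")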
    
    # 3. Model inference
    print("\n3️⃣ Running SNN inference...")
    with torch.no_grad():
        output_spikes, _ = model(spikes_tensor)
    print(f"   ✅ Output shape: {output_spikes.shape}")
    
    # 4. Get prediction
    print("\n4️⃣ Making prediction...")
    spike_counts = output_spikes.sum(dim=0)  # Sum over time -> [batch, num_classes]
    prediction = spike_counts.argmax(dim=1)
    # Softmax over raw spike counts is a rough score, not a calibrated probability
    confidence = torch.softmax(spike_counts, dim=1)
    print(f"   ✅ Prediction: {prediction.item()} (softmax score: {confidence.max().item():.2%})")
    
    print("\n" + "="*50)
    print("✅ FULL PIPELINE WORKING!")
    print("="*50)

test_full_pipeline()
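Step 2 of the pipeline copies one spike train into all 100 input features, so every input neuron sees identical spikes. One alternative, sketched here assuming the model keeps `input_size=100`, is to give each input neuron its own independent Bernoulli draw from the same binned rates (a simple population rate code). The function name `population_rate_encode` and its defaults are illustrative, not project code.

```python
import numpy as np

def population_rate_encode(signal, num_steps=100, num_neurons=100, gain=1.0, rng=None):
    """Encode a 1-D signal as [num_steps, num_neurons] independent rate-coded spike trains."""
    rng = np.random.default_rng() if rng is None else rng
    x = np.asarray(signal, dtype=float)
    x = (x - x.min()) / (x.max() - x.min() + 1e-12)  # normalize to [0, 1]

    # Mean normalized intensity in each of num_steps equal-width bins
    edges = np.linspace(0, len(x), num_steps + 1, dtype=int)
    rates = np.array([x[edges[i]:edges[i + 1]].mean() for i in range(num_steps)])

    # Same per-step probability for every neuron, but independent draws per neuron
    prob = np.clip(rates * gain, 0.0, 1.0)[:, None]  # [num_steps, 1], broadcast over neurons
    return (rng.random((num_steps, num_neurons)) < prob).astype(np.float32)
```

Usage in the pipeline: `spikes_tensor = torch.from_numpy(population_rate_encode(ecg)).unsqueeze(1).to(device)` gives the same `[100, 1, 100]` shape as before, but the input neurons now carry independent samples of the rate rather than one copied train.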

## 📝 Notes & TODO

### What's Working:
- ✅ Data generation (neurokit2)
- ✅ Spike encoding (rate coding)
- ✅ Basic SNN model (snnTorch)
- ✅ Visualization (plotly)
- ✅ End-to-end pipeline

### Next Steps:
1. **CS3:** Implement better spike encoding (temporal/latency)
2. **CS2:** Train model on real dataset
3. **CS2:** Implement STDP learning
4. **CS4:** Build Flask API around this
5. **Bio:** Validate that predictions make clinical sense
6. **CS1:** Refactor into src/ modules
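For the STDP item, a common starting point is the pair-based rule with exponential traces: a synapse is strengthened when the presynaptic spike precedes the postsynaptic one and weakened in the opposite order, by an amount that decays with the time gap. The NumPy sketch below is a generic textbook version, not a team decision: the function name `stdp_update`, the amplitudes `a_plus`/`a_minus`, the time constants `tau_plus`/`tau_minus` (in time steps), and the weight bounds are all illustrative assumptions. Simultaneous pre/post spikes are deliberately ignored.

```python
import numpy as np

def stdp_update(pre_spikes, post_spikes, w, a_plus=0.01, a_minus=0.012,
                tau_plus=20.0, tau_minus=20.0, w_min=0.0, w_max=1.0):
    """Pair-based STDP with exponential traces (all-to-all pairing).

    pre_spikes:  [T, n_pre]  binary
    post_spikes: [T, n_post] binary
    w:           [n_pre, n_post] weights; an updated copy is returned
    """
    w = np.array(w, dtype=float)              # work on a copy
    x_pre = np.zeros(pre_spikes.shape[1])     # decaying memory of past pre spikes
    y_post = np.zeros(post_spikes.shape[1])   # decaying memory of past post spikes
    decay_pre, decay_post = np.exp(-1.0 / tau_plus), np.exp(-1.0 / tau_minus)

    for t in range(pre_spikes.shape[0]):
        pre = pre_spikes[t].astype(float)
        post = post_spikes[t].astype(float)
        # Decay traces first so they only hold spikes from earlier steps
        x_pre *= decay_pre
        y_post *= decay_post
        # Potentiation: post spikes now, pre spiked earlier (pre -> post)
        w += a_plus * np.outer(x_pre, post)
        # Depression: pre spikes now, post spiked earlier (post -> pre)
        w -= a_minus * np.outer(pre, y_post)
        np.clip(w, w_min, w_max, out=w)
        # Add the current spikes to the traces for future steps
        x_pre += pre
        y_post += post
    return w
```

A pre spike at step 2 followed by a post spike at step 7 raises the weight by `a_plus * exp(-5 / tau_plus)`; reversing the order lowers it by `a_minus * exp(-5 / tau_minus)`. Wiring this into the snnTorch layers would mean applying it to `fc1.weight` (transposed, since `nn.Linear` stores `[out, in]`) outside of autograd.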

### Known Issues:
- Model not trained yet (random predictions)
- Simple rate encoding (try temporal encoding)
- No data augmentation
- No clinical validation
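For the temporal-encoding item, latency (time-to-first-spike) coding is one option: each input fires exactly once, and larger values fire earlier. The sketch below assumes a linear mapping of a normalized value x in [0, 1] to spike time round((1 - x)(T - 1)); both the mapping and the name `latency_encode` are illustrative. snnTorch also provides `snntorch.spikegen.latency` for tensors; check its documentation for the exact time mapping it uses.

```python
import numpy as np

def latency_encode(values, num_steps=100):
    """Time-to-first-spike coding: one spike per input, earlier for larger values.

    values: 1-D array, min-max normalized to [0, 1] internally.
    Returns a [num_steps, len(values)] binary array.
    """
    v = np.asarray(values, dtype=float)
    v = (v - v.min()) / (v.max() - v.min() + 1e-12)
    # Linear mapping: x = 1 spikes at step 0, x = 0 at the last step
    spike_times = np.round((1.0 - v) * (num_steps - 1)).astype(int)
    spikes = np.zeros((num_steps, v.size), dtype=np.float32)
    spikes[spike_times, np.arange(v.size)] = 1.0
    return spikes
```

Applied to the per-bin means of an ECG, each bin becomes one input neuron that fires once, so the R-peak bins fire first; this is far sparser than rate coding, at the cost of being sensitive to small amplitude changes.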

### Team Communication:
**Add your notes here for team members!**