# ODE Flow Block: Gradient Flow Validation

This notebook validates that neural ODE blocks can replace transformer residual blocks without destroying gradient flow.

**Experiment**: Train two tiny language models (6 layers, 256 dim) on WikiText-2:
1. **Baseline**: Standard transformer
2. **Hybrid**: Layers 3-4 (1-indexed) replaced with a single ODE flow module

**Success criteria**:
- Gradients remain finite and non-vanishing
- Loss decreases comparably to baseline
- ODE module gradients are well-behaved

Runtime: ~10-15 minutes on a T4 GPU

In [None]:
# Install the third-party packages this notebook imports, using a notebook
# shell escape: torchdiffeq (odeint), datasets (load_dataset) and tiktoken
# (GPT-2 encoding). The -q flag keeps pip's output quiet.
!pip install torchdiffeq datasets tiktoken -q
print("Dependencies installed.")

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchdiffeq import odeint
import math
import matplotlib.pyplot as plt
from collections import defaultdict
import numpy as np

# Prefer CUDA when PyTorch can see it; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# Only a CUDA run has a GPU name to report (device index 0).
if device.type == 'cuda':
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## Model Components

In [None]:
class CausalSelfAttention(nn.Module):
    """Multi-head self-attention where a position sees only itself and earlier positions.

    Args:
        d_model: Feature width of the input and output; must be divisible by ``n_heads``.
        n_heads: Number of attention heads.
        max_seq_len: Side length of the precomputed lower-triangular mask buffer.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, d_model, n_heads, max_seq_len=512, dropout=0.1):
        super().__init__()
        assert d_model % n_heads == 0
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        # One fused projection yields queries, keys and values together.
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.proj = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
        # Ones on and below the diagonal mark the positions that may be attended to.
        causal = torch.ones(max_seq_len, max_seq_len).tril()
        self.register_buffer('mask', causal)

    def forward(self, x):
        """Attend causally over ``x`` of shape (B, T, C); the result has the same shape."""
        batch, seq_len, channels = x.shape

        fused = self.qkv(x).reshape(batch, seq_len, 3, self.n_heads, self.head_dim)
        # Reorder to (3, B, heads, T, head_dim) and split along the leading axis.
        q, k, v = fused.permute(2, 0, 3, 1, 4).unbind(0)

        scores = torch.matmul(q, k.transpose(-2, -1)) * (self.head_dim ** -0.5)
        # Future positions get -inf so softmax assigns them zero weight.
        blocked = self.mask[:seq_len, :seq_len] == 0
        scores = scores.masked_fill(blocked, float('-inf'))
        weights = self.dropout(F.softmax(scores, dim=-1))

        # Merge the heads back into a single feature axis before the output projection.
        merged = torch.matmul(weights, v).transpose(1, 2).reshape(batch, seq_len, channels)
        return self.proj(merged)


class MLP(nn.Module):
    """Feed-forward sub-layer: widen to 4x ``d_model``, apply GELU, project back, then dropout."""

    def __init__(self, d_model, dropout=0.1):
        super().__init__()
        hidden_width = 4 * d_model
        self.fc1 = nn.Linear(d_model, hidden_width)
        self.fc2 = nn.Linear(hidden_width, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Transform the last dimension of ``x``; the output width equals the input width."""
        activated = F.gelu(self.fc1(x))
        projected = self.fc2(activated)
        return self.dropout(projected)


class TransformerBlock(nn.Module):
    """Residual block with LayerNorm applied before each sub-layer (attention, then MLP)."""

    def __init__(self, d_model, n_heads, dropout=0.1):
        super().__init__()
        # Sub-modules are created in the order they are applied in forward().
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = CausalSelfAttention(d_model, n_heads, dropout=dropout)
        self.ln2 = nn.LayerNorm(d_model)
        self.mlp = MLP(d_model, dropout=dropout)

    def forward(self, x):
        """Add the attention update, then the MLP update, to the residual stream ``x``."""
        attn_update = self.attn(self.ln1(x))
        x = x + attn_update
        mlp_update = self.mlp(self.ln2(x))
        return x + mlp_update

In [None]:
class ODEFunc(nn.Module):
    """Learned vector field dH/dt = F(H, t, u) evaluated by the ODE solver.

    The state is shifted by a time embedding (and, when ``self.control`` is
    set, a control embedding), passed through attention and an MLP in
    parallel, and the summed result is multiplied by a learnable scalar.

    Args:
        d_model: Feature width of the state.
        n_heads: Number of attention heads in the dynamics.
        control_dim: Width of the optional control vector ``u``.
        dropout: Dropout probability forwarded to the attention and MLP modules.
    """

    def __init__(self, d_model, n_heads, control_dim=4, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.control_dim = control_dim

        # Maps the scalar integration time to a d_model-wide embedding.
        self.time_embed = nn.Sequential(
            nn.Linear(1, d_model),
            nn.SiLU(),
            nn.Linear(d_model, d_model),
        )

        # Maps the control vector to a d_model-wide embedding.
        self.control_embed = nn.Linear(control_dim, d_model)

        # Dynamics: attention and MLP branches, each behind its own LayerNorm.
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = CausalSelfAttention(d_model, n_heads, dropout=dropout)
        self.ln2 = nn.LayerNorm(d_model)
        self.mlp = MLP(d_model, dropout=dropout)

        # Learnable gain, initialised small to keep the dynamics stable.
        self.output_scale = nn.Parameter(torch.tensor(0.1))

        # Optional control tensor; the owner assigns it before integrating.
        self.control = None

    def forward(self, t, x):
        """Return dH/dt for scalar time ``t`` and state ``x`` of shape (B, T, D)."""
        batch, _, _ = x.shape

        # The same time embedding is used for every sequence: (B, 1) -> (B, D) -> (B, 1, D).
        t_column = t.view(1, 1).expand(batch, 1)
        t_emb = self.time_embed(t_column).unsqueeze(1)

        # Without a control signal the conditioning term is just zero.
        c_emb = self.control_embed(self.control).unsqueeze(1) if self.control is not None else 0

        h = x + t_emb + c_emb

        # Both branches read the same conditioned state; attention is evaluated first.
        attn_term = self.attn(self.ln1(h))
        mlp_term = self.mlp(self.ln2(h))
        dh = attn_term + mlp_term

        return self.output_scale * dh


class ODEFlowBlock(nn.Module):
    """Continuous stand-in for k transformer blocks: integrates ``ODEFunc`` from t=0 to t=1.

    Args:
        d_model: Feature width of the state.
        n_heads: Number of attention heads inside the vector field.
        control_dim: Width of the optional control vector.
        n_steps: Number of intervals in the time grid (``n_steps + 1`` grid points).
        dropout: Dropout probability forwarded to the vector field.
    """

    def __init__(self, d_model, n_heads, control_dim=4, n_steps=4, dropout=0.1):
        super().__init__()
        self.func = ODEFunc(d_model, n_heads, control_dim, dropout)
        self.n_steps = n_steps
        # Evenly spaced time grid on [0, 1]; stored as a buffer so it moves with the module.
        time_grid = torch.linspace(0, 1, n_steps + 1)
        self.register_buffer('integration_times', time_grid)

    def forward(self, x, control=None):
        """Integrate the state ``x`` over the time grid and return the state at the last time."""
        # The vector field reads the control from its own attribute during integration.
        self.func.control = control
        # Euler on the fixed grid keeps the compute cost predictable.
        trajectory = odeint(self.func, x, self.integration_times, method='euler')
        return trajectory[-1]

In [None]:
class BaselineTransformer(nn.Module):
    """Decoder-only language model built from ``n_layers`` standard transformer blocks.

    Args:
        vocab_size: Number of token ids; sets the embedding and output sizes.
        d_model: Feature width of the residual stream.
        n_heads: Attention heads per block.
        n_layers: Number of transformer blocks.
        max_seq_len: Number of learned position embeddings.
        dropout: Dropout probability forwarded to every block.
    """

    def __init__(self, vocab_size, d_model=256, n_heads=4, n_layers=6, max_seq_len=128, dropout=0.1):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_seq_len, d_model)
        layers = [TransformerBlock(d_model, n_heads, dropout) for _ in range(n_layers)]
        self.blocks = nn.ModuleList(layers)
        self.ln_f = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size, bias=False)
        self.max_seq_len = max_seq_len

    def forward(self, idx):
        """Map token ids ``idx`` of shape (B, T) to logits over the vocabulary."""
        _, seq_len = idx.shape
        positions = torch.arange(seq_len, device=idx.device)
        # Position embeddings broadcast across the batch dimension.
        hidden = self.tok_emb(idx) + self.pos_emb(positions)
        for layer in self.blocks:
            hidden = layer(hidden)
        return self.head(self.ln_f(hidden))


class HybridODETransformer(nn.Module):
    """Transformer LM whose middle layers are replaced by a single ODE flow block.

    The stack is ``ode_start`` standard blocks, one ``ODEFlowBlock``, then
    ``n_layers - ode_end`` standard blocks, so the ODE block stands in for the
    layers at 0-based indices ``ode_start`` through ``ode_end - 1``.

    Args:
        vocab_size: Number of token ids; sets the embedding and output sizes.
        d_model: Feature width of the residual stream.
        n_heads: Attention heads per block and inside the ODE vector field.
        n_layers: Depth of the transformer being approximated.
        ode_start: Number of standard blocks placed before the ODE block.
        ode_end: 0-based index of the first standard block after the ODE block.
        control_dim: Width of the optional control vector.
        n_steps: Number of Euler steps taken by the ODE block.
        max_seq_len: Number of learned position embeddings.
        dropout: Dropout probability forwarded to every sub-module.

    Raises:
        ValueError: If ``0 <= ode_start <= ode_end <= n_layers`` does not hold.
    """

    def __init__(self, vocab_size, d_model=256, n_heads=4, n_layers=6,
                 ode_start=2, ode_end=4, control_dim=4, n_steps=4,
                 max_seq_len=128, dropout=0.1):
        super().__init__()

        # range() with a negative count is silently empty, and ode_start > ode_end
        # would build more than n_layers blocks, so reject bad ranges up front.
        if not 0 <= ode_start <= ode_end <= n_layers:
            raise ValueError(
                "Expected 0 <= ode_start <= ode_end <= n_layers, got "
                f"ode_start={ode_start}, ode_end={ode_end}, n_layers={n_layers}"
            )

        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_seq_len, d_model)

        # Early layers (standard)
        self.early_blocks = nn.ModuleList([TransformerBlock(d_model, n_heads, dropout)
                                           for _ in range(ode_start)])

        # ODE flow block (stands in for layers ode_start .. ode_end - 1)
        self.ode_block = ODEFlowBlock(d_model, n_heads, control_dim, n_steps, dropout)

        # Late layers (standard)
        self.late_blocks = nn.ModuleList([TransformerBlock(d_model, n_heads, dropout)
                                          for _ in range(n_layers - ode_end)])

        self.ln_f = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size, bias=False)
        self.max_seq_len = max_seq_len
        self.control_dim = control_dim

    def forward(self, idx, control=None):
        """Map token ids ``idx`` of shape (B, T) to vocabulary logits.

        ``control`` is handed to the ODE block unchanged; ``None`` runs the
        flow without control conditioning.
        """
        B, T = idx.shape
        tok = self.tok_emb(idx)
        pos = self.pos_emb(torch.arange(T, device=idx.device))
        x = tok + pos

        # Early layers
        for block in self.early_blocks:
            x = block(x)

        # ODE flow
        x = self.ode_block(x, control)

        # Late layers
        for block in self.late_blocks:
            x = block(x)

        x = self.ln_f(x)
        return self.head(x)

## Data Loading

In [None]:
from datasets import load_dataset
import tiktoken

# Fetch the raw WikiText-2 training split
print("Loading WikiText-2...")
dataset = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')

# Drop blank rows and join the rest into one newline-separated string
enc = tiktoken.get_encoding('gpt2')
text = '\n'.join(row['text'] for row in dataset if row['text'].strip())

# Encode with the GPT-2 encoding and keep the ids as a single int64 tensor
tokens = torch.tensor(enc.encode(text), dtype=torch.long)
print(f"Total tokens: {len(tokens):,}")

vocab_size = enc.n_vocab
print(f"Vocab size: {vocab_size:,}")

In [None]:
def get_batch(tokens, batch_size, seq_len, device):
    """Sample ``batch_size`` random windows and return (inputs, next-token targets).

    Both returned tensors have ``batch_size`` rows of ``seq_len`` tokens and
    live on ``device``; each target row is its input row shifted by one.
    """
    starts = torch.randint(len(tokens) - seq_len - 1, (batch_size,))
    # Row r gathers tokens[starts[r] : starts[r] + seq_len] in one indexing call.
    window = starts.unsqueeze(1) + torch.arange(seq_len)
    inputs = tokens[window].to(device)
    targets = tokens[window + 1].to(device)
    return inputs, targets

## Training Loop

In [None]:
def _is_ode_parameter(name):
    """Return True when a dotted parameter name has a component that starts with 'ode'."""
    # Compare whole name components: a plain substring test would also match
    # unrelated names such as 'encoder', 'decoder', 'model' or 'node'.
    components = name.lower().replace('_', '.').split('.')
    return any(part.startswith('ode') for part in components)


def compute_gradient_stats(model):
    """Compute L2 gradient norms over the parameters of ``model``.

    Parameters whose ``.grad`` is ``None`` are skipped.

    Args:
        model: Module whose ``named_parameters()`` are inspected.

    Returns:
        dict with ``'total_grad_norm'`` (L2 norm over all gradients) and, when
        at least one ODE parameter has a gradient, ``'ode_grad_norm'`` (L2
        norm over those gradients only).
    """
    total_sq = 0.0
    ode_sq = 0.0
    saw_ode_grad = False
    for name, param in model.named_parameters():
        if param.grad is None:
            continue
        # detach() replaces the deprecated .data access.
        grad_sq = param.grad.detach().norm(2).item() ** 2
        total_sq += grad_sq
        if _is_ode_parameter(name):
            ode_sq += grad_sq
            saw_ode_grad = True

    stats = {}
    if saw_ode_grad:
        stats['ode_grad_norm'] = ode_sq ** 0.5
    stats['total_grad_norm'] = total_sq ** 0.5
    return stats


def train_model(model, tokens, n_steps=500, batch_size=32, seq_len=64, lr=3e-4, log_interval=50):
    """Train ``model`` with AdamW on random windows of ``tokens`` and record metrics.

    The model is moved to the module-level ``device`` and left in training mode.

    Args:
        model: Language model called as ``model(x)`` and returning logits.
        tokens: Token ids to sample training windows from.
        n_steps: Number of optimizer steps.
        batch_size: Sequences per step.
        seq_len: Tokens per sequence.
        lr: AdamW learning rate.
        log_interval: Print a progress line every this many steps.

    Returns:
        defaultdict(list) with per-step ``'loss'`` and ``'grad_norm'`` and,
        when ODE parameters received gradients, ``'ode_grad_norm'``. The
        gradient norms are the raw values measured before clipping.
    """
    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

    metrics = defaultdict(list)

    model.train()
    for step in range(n_steps):
        x, y = get_batch(tokens, batch_size, seq_len, device)

        optimizer.zero_grad()
        logits = model(x)
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1))
        loss.backward()

        # Measure gradients BEFORE clipping: clip_grad_norm_ rescales them in
        # place, so measuring afterwards would cap every logged norm near 1.0
        # and hide exactly the vanishing/exploding behaviour being validated.
        grad_stats = compute_gradient_stats(model)

        # Gradient clipping (affects only the update, not the logged norms)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        optimizer.step()

        # Log metrics
        loss_value = loss.item()
        metrics['loss'].append(loss_value)
        metrics['grad_norm'].append(grad_stats['total_grad_norm'])
        if 'ode_grad_norm' in grad_stats:
            metrics['ode_grad_norm'].append(grad_stats['ode_grad_norm'])

        if step % log_interval == 0:
            ode_info = f", ODE grad: {grad_stats['ode_grad_norm']:.4f}" if 'ode_grad_norm' in grad_stats else ""
            print(f"Step {step:4d} | Loss: {loss_value:.4f} | Grad norm: {grad_stats['total_grad_norm']:.4f}{ode_info}")

    return metrics

## Run Experiment

In [None]:
# Model hyperparameters shared by the baseline and the hybrid model
config = dict(
    vocab_size=vocab_size,
    d_model=256,
    n_heads=4,
    n_layers=6,
    max_seq_len=128,
    dropout=0.1,
)

# Keyword arguments passed to train_model() for both runs
train_config = dict(
    n_steps=500,
    batch_size=32,
    seq_len=64,
    lr=3e-4,
)

# Banner: a rule, the title, and another rule, one per line
print("="*60, "EXPERIMENT: Gradient Flow Validation", "="*60, sep="\n")
print(f"Model config: {config['n_layers']} layers, d={config['d_model']}, heads={config['n_heads']}")
print(f"Training: {train_config['n_steps']} steps, batch={train_config['batch_size']}, seq_len={train_config['seq_len']}")
print()

In [None]:
# Baseline run: banner, seeded construction, parameter count, then training
print("="*60, "Training BASELINE transformer...", "="*60, sep="\n")

# Seed before building the model so its initial weights are reproducible.
torch.manual_seed(42)
baseline = BaselineTransformer(**config)
n_params_baseline = sum(map(torch.numel, baseline.parameters()))
print(f"Parameters: {n_params_baseline:,}")
print()

baseline_metrics = train_model(baseline, tokens, **train_config)

In [None]:
# Train hybrid ODE model
print("\n" + "="*60)
print("Training HYBRID ODE transformer...")
# With ode_start=2 and ode_end=4 the model keeps 2 early and 2 late blocks, so
# exactly two layers (1-indexed layers 3-4) are replaced, not "2-4".
print("(Layers 3-4 replaced with ODE flow, 4 Euler steps)")
print("="*60)

# Same seed as the baseline run so initialisation is reproducible.
torch.manual_seed(42)
hybrid = HybridODETransformer(
    **config,
    ode_start=2,  # ODE block replaces the blocks at 0-based indices 2 and 3
    ode_end=4,
    control_dim=4,
    n_steps=4  # 4 Euler steps
)
n_params_hybrid = sum(p.numel() for p in hybrid.parameters())
print(f"Parameters: {n_params_hybrid:,}")
print()

hybrid_metrics = train_model(hybrid, tokens, **train_config)

## Results Visualization

In [None]:
# One row of three panels: training loss, total gradient norm, ODE-block gradient norm.
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# Smooth curves for visualization
def smooth(x, window=20):
    """Return the moving average of ``x`` over ``window`` consecutive samples.

    Only full windows are averaged, so the result has
    ``len(x) - window + 1`` entries. A window longer than the series is
    shrunk to the series length (giving a single overall mean); without this,
    ``np.convolve(..., mode='valid')`` swaps its operands and returns values
    that are not averages. An empty series yields an empty array.

    Args:
        x: 1-D sequence of numbers.
        window: Number of samples per average; clamped to ``[1, len(x)]``.

    Returns:
        1-D float ``numpy.ndarray`` of window means.
    """
    values = np.asarray(x, dtype=float)
    if values.size == 0:
        return values
    window = max(1, min(int(window), values.size))
    return np.convolve(values, np.ones(window) / window, mode='valid')

# Loss and total-gradient-norm panels share one layout, so draw both in a loop
for ax, metric_key, y_label, panel_title in (
    (axes[0], 'loss', 'Loss', 'Training Loss'),
    (axes[1], 'grad_norm', 'Gradient Norm', 'Total Gradient Norm'),
):
    ax.plot(smooth(baseline_metrics[metric_key]), label='Baseline', alpha=0.8)
    ax.plot(smooth(hybrid_metrics[metric_key]), label='Hybrid ODE', alpha=0.8)
    ax.set_xlabel('Step')
    ax.set_ylabel(y_label)
    ax.set_title(panel_title)
    ax.legend()
    ax.grid(True, alpha=0.3)

# Third panel: gradient norm of the ODE block alone (hybrid run only)
ax = axes[2]
ode_grad_history = hybrid_metrics['ode_grad_norm']
if ode_grad_history:
    ax.plot(smooth(ode_grad_history), color='C1', alpha=0.8)
    ax.set_xlabel('Step')
    ax.set_ylabel('ODE Block Gradient Norm')
    ax.set_title('ODE Block Gradient Flow')
    ax.grid(True, alpha=0.3)

    # Dashed reference line at the mean of the unsmoothed ODE gradient norms
    mean_grad = np.mean(ode_grad_history)
    std_grad = np.std(ode_grad_history)
    ax.axhline(mean_grad, color='red', linestyle='--', alpha=0.5, label=f'Mean: {mean_grad:.3f}')
    ax.legend()

# Write the figure to disk before showing it
plt.tight_layout()
plt.savefig('gradient_flow_results.png', dpi=150, bbox_inches='tight')
plt.show()

print("\nResults saved to 'gradient_flow_results.png'")

In [None]:
# Summary statistics
print("\n" + "="*60)
print("SUMMARY")
print("="*60)

print("\nParameter counts:")
print(f"  Baseline: {n_params_baseline:,}")
print(f"  Hybrid:   {n_params_hybrid:,} ({100*n_params_hybrid/n_params_baseline:.1f}% of baseline)")

# Average the last 50 logged losses of each run once and reuse the values below.
baseline_final_loss = np.mean(baseline_metrics['loss'][-50:])
hybrid_final_loss = np.mean(hybrid_metrics['loss'][-50:])

print("\nFinal loss (last 50 steps avg):")
print(f"  Baseline: {baseline_final_loss:.4f}")
print(f"  Hybrid:   {hybrid_final_loss:.4f}")

print("\nGradient statistics:")
print(f"  Baseline grad norm (mean ± std): {np.mean(baseline_metrics['grad_norm']):.4f} ± {np.std(baseline_metrics['grad_norm']):.4f}")
print(f"  Hybrid grad norm (mean ± std):   {np.mean(hybrid_metrics['grad_norm']):.4f} ± {np.std(hybrid_metrics['grad_norm']):.4f}")

# Count unhealthy steps up front so the conclusion below can actually use them.
ode_grads = np.array(hybrid_metrics['ode_grad_norm'], dtype=float)
vanishing = int(np.sum(ode_grads < 1e-6))
exploding = int(np.sum(ode_grads > 100))
non_finite = int(np.sum(~np.isfinite(ode_grads)))

if ode_grads.size:
    print(f"  ODE block grad norm (mean ± std): {ode_grads.mean():.4f} ± {ode_grads.std():.4f}")

    # Check for vanishing/exploding gradients
    print(f"\n  Vanishing gradient steps (<1e-6): {vanishing}")
    print(f"  Exploding gradient steps (>100):  {exploding}")
    print(f"  Non-finite gradient steps:        {non_finite}")

print("\n" + "="*60)
print("CONCLUSION")
print("="*60)
loss_ratio = hybrid_final_loss / baseline_final_loss
loss_ok = bool(np.isfinite(loss_ratio)) and loss_ratio < 1.1
# The success message claims finite, non-vanishing, non-exploding gradients,
# so every one of those conditions has to hold; a healthy mean is not enough.
grads_ok = (
    ode_grads.size > 0
    and non_finite == 0
    and vanishing == 0
    and exploding == 0
    and ode_grads.mean() > 0.01
)
if loss_ok and grads_ok:
    print("✓ ODE flow block maintains healthy gradient flow")
    print("✓ Training loss comparable to baseline")
    print("✓ No evidence of vanishing/exploding gradients in ODE block")
else:
    print("⚠ Results require further investigation")

## Bonus: Control Signal Effect (Quick Test)

In [None]:
# Quick test: does the control signal change outputs?
# Evaluation mode, so the two forward passes below are not perturbed by dropout.
hybrid.eval()

# Get a test sequence
test_x, _ = get_batch(tokens, 1, 32, device)

with torch.no_grad():
    # No control
    logits_base = hybrid(test_x, control=None)

    # With control signal: one row per test sequence, width taken from the
    # model rather than hard-coded so the cell follows the configuration.
    # NOTE: train_model calls model(x) without a control, so control_embed is
    # presumably still at its initial weights here; this checks wiring only.
    control = torch.randn(test_x.size(0), hybrid.control_dim, device=device)
    logits_ctrl = hybrid(test_x, control=control)

    # Measure difference
    diff = (logits_ctrl - logits_base).abs().mean().item()
    print(f"Mean absolute logit difference with control signal: {diff:.4f}")
    print("(Nonzero difference confirms control signal affects outputs)")