In [15]:
# %% [markdown]
# # Continuous-Depth Transformers with Learned Control Dynamics
# ## Recreation & Falsification Experiments
#
# This notebook implements the hybrid ODE-Transformer architecture.
# It includes specific "Popperian" patches to verify:
# 1. The explicit injection mechanism of the control signal `u`.
# 2. The mathematical continuity of the learned dynamics (Fixed vs. Adaptive check).

# %% [code]
# 1. INSTALL DEPENDENCIES
import sys
!{sys.executable} -m pip install torchdiffeq

import math
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchdiffeq import odeint_adjoint as odeint

# Configuration based on your paper (Section 4)
CONFIG = {
    'd_model': 256,
    'n_heads': 4,
    'n_layers': 6,          # Total effective layers
    'ode_layer_start': 2,   # Replace layers 2 & 3
    'ode_layer_end': 4,
    'vocab_size': 33278,    # WikiText-2 vocab size approx
    'seq_len': 32,
    'control_dim': 4,       # Low-dimensional control u
    'batch_size': 32,
    'lr': 0.001,
    'epochs': 1             # Short run for demonstration
}

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Running on: {DEVICE}")

# %% [markdown]
# ## 2. Architecture Definitions (With Popperian Patches)

# %% [code]
class CausalSelfAttention(nn.Module):
    def __init__(self, d_model, n_heads, max_len=512):
        super().__init__()
        self.c_attn = nn.Linear(d_model, 3 * d_model)
        self.c_proj = nn.Linear(d_model, d_model)
        self.n_heads = n_heads
        self.d_model = d_model
        self.register_buffer("bias", torch.tril(torch.ones(max_len, max_len))
                                     .view(1, 1, max_len, max_len))

    def forward(self, x):
        B, T, C = x.size()
        q, k, v = self.c_attn(x).split(self.d_model, dim=2)
        k = k.view(B, T, self.n_heads, C // self.n_heads).transpose(1, 2)
        q = q.view(B, T, self.n_heads, C // self.n_heads).transpose(1, 2)
        v = v.view(B, T, self.n_heads, C // self.n_heads).transpose(1, 2)

        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
        att = F.softmax(att, dim=-1)
        y = att @ v
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.c_proj(y)

class MLP(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.GELU(),
            nn.Linear(4 * d_model, d_model)
        )
    def forward(self, x): return self.net(x)

class StandardBlock(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = CausalSelfAttention(d_model, n_heads)
        self.ln2 = nn.LayerNorm(d_model)
        self.mlp = MLP(d_model)

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x

# --- THE PATCH: Explicit Control Injection ---
class ControlledVectorField(nn.Module):
    """
    Explicitly defines F(h, t, u) = MLP(Concat(h, u)).
    This resolves the "Black Box Injection" ambiguity.
    """
    def __init__(self, d_model, control_dim):
        super().__init__()
        # Explicit injection: Concatenate hidden state + control vector
        input_dim = d_model + control_dim

        self.net = nn.Sequential(
            nn.Linear(input_dim, d_model * 2),
            nn.Softplus(), # Smooth activation for better ODE properties
            nn.Linear(d_model * 2, d_model)
        )

        # Stability Initialization: Initialize output to near-zero
        # This keeps the initial vector field 'flat', aiding stability.
        nn.init.normal_(self.net[-1].weight, mean=0, std=0.01)
        nn.init.constant_(self.net[-1].bias, 0)

    def forward(self, t, h, u):
        # h: [Batch, Seq, Dim]
        # u: [Batch, Control_Dim]

        # Broadcast u across the sequence dimension
        seq_len = h.shape[1]
        u_expanded = u.unsqueeze(1).expand(-1, seq_len, -1)

        # Concat and project
        state = torch.cat([h, u_expanded], dim=-1)
        return self.net(state)

# --- THE PATCH: Hybrid Block with Fixed/Adaptive Toggle ---
class ContinuousDepthBlock(nn.Module):
    def __init__(self, d_model, control_dim):
        super().__init__()
        self.vector_field = ControlledVectorField(d_model, control_dim)

        # Learned scale parameter initialized to 0.1 (Section 3.3)
        self.alpha = nn.Parameter(torch.tensor(0.1))

        # Training defaults (Fixed Step)
        self.train_method = 'euler'
        self.train_options = {'step_size': 0.25} # 4 steps for t=[0,1]

    def forward(self, h, u, use_adaptive=False):
        # Wrapper to bind 'u' so solver sees f(t, h)
        def func(t, x):
            return self.alpha * self.vector_field(t, x, u)

        integration_times = torch.tensor([0, 1]).float().to(h.device)

        # Switch between Fixed (Train) and Adaptive (Analysis)
        if use_adaptive:
            method = 'dopri5'
            options = {}
        else:
            method = self.train_method
            options = self.train_options

        # ODESolve
        # Collect all parameters that `func` depends on
        adjoint_params = tuple(self.vector_field.parameters()) + (self.alpha,)
        state = odeint(func, h, integration_times, method=method, options=options, adjoint_params=adjoint_params)
        return state[1]

class HybridTransformer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.embed = nn.Embedding(config['vocab_size'], config['d_model'])
        self.pos_embed = nn.Parameter(torch.zeros(1, config['seq_len'], config['d_model']))

        # 1. Discrete Early Layers (0-1)
        self.early_layers = nn.ModuleList([
            StandardBlock(config['d_model'], config['n_heads'])
            for _ in range(config['ode_layer_start'])
        ])

        # 2. Continuous ODE Block (Replaces 2-3)
        self.ode_block = ContinuousDepthBlock(config['d_model'], config['control_dim'])

        # 3. Discrete Late Layers (4-5)
        # Note: We subtract the layers we 'skipped' to keep indices aligned
        remaining_layers = config['n_layers'] - config['ode_layer_end']
        self.late_layers = nn.ModuleList([
            StandardBlock(config['d_model'], config['n_heads'])
            for _ in range(remaining_layers)
        ])

        self.ln_f = nn.LayerNorm(config['d_model'])
        self.head = nn.Linear(config['d_model'], config['vocab_size'], bias=False)

    def forward(self, idx, u, use_adaptive=False):
        B, T = idx.size()
        x = self.embed(idx) + self.pos_embed[:, :T, :]

        # Early Discrete
        for layer in self.early_layers:
            x = layer(x)

        # Continuous Control
        x = self.ode_block(x, u, use_adaptive=use_adaptive)

        # Late Discrete
        for layer in self.late_layers:
            x = layer(x)

        logits = self.head(self.ln_f(x))
        return logits

# %% [markdown]
# ## 3. Data & Training Setup

# %% [code]
# Dummy Data Generator (Replacing WikiText download for instant runnability)
# In real exp, replace this with 'datasets' load
def get_batch(config):
    data = torch.randint(0, config['vocab_size'], (config['batch_size'], config['seq_len'] + 1)).to(DEVICE)
    x = data[:, :-1]
    y = data[:, 1:]

    # Random control signal for training
    # We want the model to learn to be robust to ANY u
    u = torch.randn(config['batch_size'], config['control_dim']).to(DEVICE)
    return x, y, u

model = HybridTransformer(CONFIG).to(DEVICE)
optimizer = torch.optim.AdamW(model.parameters(), lr=CONFIG['lr'])

print(f"Model Parameters: {sum(p.numel() for p in model.parameters())}")
print("Starting Training (Fixed Euler Steps)...")

# Simple Training Loop
model.train()
start_time = time.time()

for step in range(50): # Small steps for demo
    x, y, u = get_batch(CONFIG)

    # Forward pass (Fixed Steps)
    logits = model(x, u, use_adaptive=False)

    # Fix: Use .reshape() instead of .view() for non-contiguous tensors
    loss = F.cross_entropy(logits.reshape(-1, CONFIG['vocab_size']), y.reshape(-1))

    optimizer.zero_grad()
    loss.backward()

    # Gradient clipping (standard for transformers)
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()

    if step % 10 == 0:
        print(f"Step {step} | Loss: {loss.item():.4f}")

print(f"Training finished in {time.time() - start_time:.2f}s")

# %% [markdown]
# ## 4. The Popperian Falsification Test
#
# We now test if the model is a "ResNet in disguise" (failed hypothesis) or a "Continuous Flow" (confirmed hypothesis).
# We compare the output of the **Fixed 4-Step Solver** (used in training) vs. an **Adaptive Dopri5 Solver**.

# %% [code]
def validate_continuity_hypothesis(model, config):
    print("\n--- RUNNING POPPERIAN CONTINUITY TEST ---")
    model.eval()

    # Get a probe batch
    x, _, u = get_batch(config)

    with torch.no_grad():
        # 1. Run Fixed Step (What it learned)
        logits_fixed = model(x, u, use_adaptive=False)

        # 2. Run Adaptive Step (The theoretical limit)
        # If the dynamics are smooth/continuous, this should yield a very similar result.
        # If the dynamics are discrete/overfit, this will diverge.
        logits_adaptive = model(x, u, use_adaptive=True)

        # Measure divergence in the hidden space (before readout)
        # We access the block output directly for cleaner measurement
        # Re-running just the block flow:
        emb = model.embed(x) + model.pos_embed[:, :config['seq_len'], :]
        for l in model.early_layers: emb = l(emb)

        h_fixed = model.ode_block(emb, u, use_adaptive=False)
        h_adaptive = model.ode_block(emb, u, use_adaptive=True)

        # Relative Error
        diff = (h_fixed - h_adaptive).norm()
        base = h_adaptive.norm()
        rel_error = (diff / base).item()

    print(f"Fixed (4-step) vs Adaptive (Dopri5) Divergence: {rel_error:.4%}")

    threshold = 0.10 # 10% tolerance
    if rel_error < threshold:
        print(f"PASS: Divergence < {threshold*100}%. The model approximates a continuous vector field.")
        print("Verdict: The 'Continuous' claim is robust.")
    else:
        print(f"FAIL: Divergence > {threshold*100}%. The model is effectively a discrete ResNet.")
        print("Verdict: The 'Continuous' claim is falsified. Needs 'Consistency Loss' in training.")

# Run the test
validate_continuity_hypothesis(model, CONFIG)

Running on: cuda
Model Parameters: 20471041
Starting Training (Fixed Euler Steps)...
Step 0 | Loss: 10.5640
Step 10 | Loss: 10.5746
Step 20 | Loss: 10.5987
Step 30 | Loss: 10.5581
Step 40 | Loss: 10.6220
Training finished in 2.02s

--- RUNNING POPPERIAN CONTINUITY TEST ---
Fixed (4-step) vs Adaptive (Dopri5) Divergence: 0.0674%
PASS: Divergence < 10.0%. The model approximates a continuous vector field.
Verdict: The 'Continuous' claim is robust.


In [16]:
# Check the learned output scale
final_alpha = model.ode_block.alpha.item()
print(f"Final learned alpha: {final_alpha:.6f}")

# Interpretation logic
if final_alpha < 0.05:
    print("Result: The model suppressed the dynamics (closer to Identity).")
elif final_alpha > 0.15:
    print("Result: The model amplified the dynamics (needed more 'time').")
else:
    print("Result: The model maintained the stable regime (close to initialization).")

Final learned alpha: 0.068112
Result: The model maintained the stable regime (close to initialization).


In [17]:
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
import numpy as np

def probe_decision_dynamics(model, config):
    print("--- Probing Representation Dynamics ---")
    # Get a batch of data with mixed sentiment targets
    x, y, u = get_batch(config) # u should be mixed +1/-1

    # We need to extract states at specific depths.
    # We'll use the adaptive solver to record the trajectory.
    times = torch.tensor([0, 0.5, 0.67, 0.8, 1.0]).float().to(DEVICE)

    with torch.no_grad():
        # 1. Embed and Early Layers
        emb = model.embed(x) + model.pos_embed[:, :config['seq_len'], :]
        for l in model.early_layers: emb = l(emb)

        # 2. Run ODE Solver explicitly asking for evaluation at specific times
        # Note: We need to modify the forward pass slightly or just call odeint directly here
        def func(t, h): return model.ode_block.alpha * model.ode_block.vector_field(t, h, u)

        # Explicitly define adjoint_params for func
        adjoint_params = tuple(model.ode_block.vector_field.parameters()) + (model.ode_block.alpha,)

        # trajectory shape: [Times, Batch, Seq, Dim]
        trajectory = odeint(func, emb, times, method='dopri5', adjoint_params=adjoint_params)

    # Train a simple classifier at each depth
    # Target: 1 if u > 0 (Positive), 0 if u < 0 (Negative)
    targets = (u[:, 0] > 0).cpu().numpy().astype(int)

    results = {}
    for i, t_val in enumerate(times):
        # Pool output (e.g., mean over sequence) for classification
        states = trajectory[i].mean(dim=1).cpu().numpy()

        # Simple Logic: Train on half, test on half
        split = len(states) // 2
        clf = LogisticRegression(max_iter=1000).fit(states[:split], targets[:split])
        acc = clf.score(states[split:], targets[split:])

        print(f"Depth {t_val:.2f}: Linear Separation Accuracy = {acc:.1%}")
        results[t_val.item()] = acc

# Run it
probe_decision_dynamics(model, CONFIG)

--- Probing Representation Dynamics ---
Depth 0.00: Linear Separation Accuracy = 62.5%
Depth 0.50: Linear Separation Accuracy = 62.5%
Depth 0.67: Linear Separation Accuracy = 62.5%
Depth 0.80: Linear Separation Accuracy = 62.5%
Depth 1.00: Linear Separation Accuracy = 62.5%
