# Continuous Transformer
What if we take the Transformer to its continuous limit?

In [1]:
# Model settings
DIM = 1024
DEPTH = 6
VOCAB_SIZE = 256
USE_CHECKPOINT = True  # Gradient checkpointing

# Training settings
BASE_LR = 3e-4
WEIGHT_DECAY = 0.01
BATCH_SIZE = 1
SEQ_LENGTH = 8192 # Character level
GRAD_CLIP = 1.0

# Checkpoint settings
CHECKPOINT_EVERY = 1000
PRINT_EVERY = 100

# Generation settings
GEN_LENGTH = 200
GEN_TEMPERATURE = 0.8
GEN_TOP_P = 0.9

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import math
import os
import urllib.request
from torch.utils.checkpoint import checkpoint
from tqdm import tqdm
import mmap

In [3]:
def prepare_tinystories_dataset(cache_dir='~/.cache/continuous_transformer'):
    """
    Download TinyStories dataset if needed.
    Returns path to the text file - no preprocessing needed!
    """
    cache_dir = os.path.expanduser(cache_dir)
    os.makedirs(cache_dir, exist_ok=True)
    
    train_path = os.path.join(cache_dir, 'tinystories_train.txt')
    
    if not os.path.exists(train_path):
        print("Downloading TinyStories training data...")
        url = "https://huggingface.co/datasets/roneneldan/TinyStories/resolve/main/TinyStoriesV2-GPT4-train.txt"
        urllib.request.urlretrieve(url, train_path)
        print(f"Downloaded to {train_path}")
    
    file_size = os.path.getsize(train_path)
    print(f"Using dataset: {train_path} ({file_size/1e6:.1f}MB)")
    
    return train_path


class TinyStoriesDataset(Dataset):
    def __init__(self, data_path: str, seq_length: int, stride: int = 512):
        self.data_path = data_path
        self.seq_length = seq_length
        self.stride = stride
        
        # Get file size without loading into memory
        self.file_size = os.path.getsize(data_path)
        
        # Calculate dataset length with stride to reduce overlapping sequences
        self.length = (self.file_size - seq_length - 1) // stride
        
        # Lazy initialization - file opened on first access per worker
        self._file = None
        self._mmap = None
        
    def _ensure_mmap(self):
        """Lazy open file handle - called once per worker process"""
        if self._mmap is None:
            self._file = open(self.data_path, 'rb')  # Binary mode for mmap
            self._mmap = mmap.mmap(self._file.fileno(), 0, access=mmap.ACCESS_READ)
        return self._mmap
    
    def __len__(self) -> int:
        return self.length
    
    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Stream a single sequence from disk
        
        Args:
            idx: Sequence index
            
        Returns:
            x: Input tokens [L]
            y: Target tokens [L]
        """
        # Get mmap (opens file if this is first call in this worker)
        mm = self._ensure_mmap()
        
        # Calculate file position with stride
        pos = idx * self.stride
        
        # Read only what we need from disk
        mm.seek(pos)
        chunk = mm.read(self.seq_length + 1)
        
        # Decode and filter to ASCII
        try:
            text = chunk.decode('utf-8', errors='ignore')
        except:
            text = chunk.decode('latin-1', errors='ignore')
        
        # Keep only ASCII characters
        text = ''.join(c for c in text if ord(c) < 128)
        
        # Pad if needed (near end of file)
        if len(text) < self.seq_length + 1:
            text = text + ' ' * (self.seq_length + 1 - len(text))
        
        # Convert to tokens
        x = torch.tensor([ord(c) for c in text[:self.seq_length]], dtype=torch.long)
        y = torch.tensor([ord(c) for c in text[1:self.seq_length + 1]], dtype=torch.long)
        
        return x, y
    
    def __del__(self):
        """Clean up file handles"""
        if self._mmap is not None:
            self._mmap.close()
        if self._file is not None:
            self._file.close()

In [4]:
# Setup dataset
data_path = prepare_tinystories_dataset()

Using dataset: /home/midori/.cache/continuous_transformer/tinystories_train.txt (2227.8MB)


In [5]:
def hippo_freqs(dim: int) -> torch.Tensor:
    """
    Compute HiPPO-LegS frequency initialization for optimal history reconstruction.
    
    Args:
        dim: Hidden dimension size
        
    Returns:
        Frequency tensor of shape [dim]
    """
    n = torch.arange(dim)
    freqs = (2 * n + 1) ** 0.5
    freqs = freqs / freqs.max()
    return torch.exp(-freqs * np.log(10_000))

def exact_projection(z: torch.Tensor, target_mean: float = 0.0, target_std: float = 1.0) -> torch.Tensor:
    """
    Exact closed-form projection onto stable manifold S1(σ₁, σ₂).
    Geodesic projection solving the normalization constraint analytically.
    
    Args:
        z: Input tensor of shape [*, D]
        target_mean: Target mean σ₁
        target_std: Target standard deviation σ₂
        
    Returns:
        Projected tensor of shape [*, D]
    """
    mean = z.mean(dim=-1, keepdim=True)  # α(x)
    var = ((z - mean) ** 2).mean(dim=-1, keepdim=True)  # β(x)
    z_projected = (z - mean) / (torch.sqrt(var) + 1e-6) * target_std + target_mean
    return z_projected

In [6]:
class SpectralField(nn.Module):
    """
    Recursive operator combining spatial, interaction, and temporal transformations.
    Implements: z_{t+1} = Π(z_t + dt · F[z_t])
    where F is a composition of rational filtering, gating, and temporal convolution.
    """
    def __init__(self, dim: int):
        super().__init__()
        
        # Rational field parameters: R(z) = F·z / (|G·z| + ε)
        self.F = nn.Parameter(torch.complex(torch.randn(dim), torch.randn(dim)) * 0.02)
        self.G = nn.Parameter(torch.complex(torch.randn(dim), torch.randn(dim)) * 0.02)
        
        # Gating parameters: σ(z) = tanh(α·|z| + β·φ)
        self.alpha = nn.Parameter(torch.ones(dim) * 0.5)  # Content coefficient
        self.beta = nn.Parameter(torch.zeros(dim))  # Position coefficient
        
        # Temporal kernel parameters: K(t) = exp((-γ + iω)t)
        self.log_gamma = nn.Parameter(torch.randn(dim) * 0.5 - 2.0)  # Decay rate
        self.omega = nn.Parameter(hippo_freqs(dim))  # Base frequencies
        self.omega_mod = nn.Parameter(torch.randn(dim) * 0.02)  # Frequency modulation
        
        self.dt = nn.Parameter(torch.tensor(0.1))  # Integration timestep

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """
        Args:
            z: Input tensor [B, L, D]
            
        Returns:
            Output tensor [B, L, D]
        """
        B, L, D = z.shape
        
        # Spatial transformation in frequency domain
        Z_x = torch.fft.fft(z, dim=-1)  # [B, L, D]
        denom = torch.abs(Z_x * self.G) + 0.1
        Z_filtered = Z_x * self.F / denom
        z_spatial = torch.fft.ifft(Z_filtered, dim=-1)  # [B, L, D]
        
        # Position-aware gating
        t = torch.arange(L, device=z.device)  # [L]
        phi = (t.view(L, 1) * self.omega.view(1, D) + math.pi) % (2*math.pi) - math.pi  # [L, D]
        intensity = z_spatial.abs()  # [B, L, D]
        gate = torch.tanh(intensity * self.alpha + phi * self.beta)  # [B, L, D]
        z_gated = z_spatial * gate  # [B, L, D]
        
        # Temporal convolution via spectral multiplication
        Z_t = torch.fft.fft(z_gated, dim=1)  # [B, L, D]
        omega_total = self.omega + torch.sin(self.omega_mod) * 0.1  # [D]
        lambda_complex = torch.complex(-torch.exp(self.log_gamma), omega_total)  # [D]
        kernel = torch.exp(lambda_complex * t.unsqueeze(1))  # [L, D]
        K_freq = torch.fft.fft(kernel, dim=0, n=L)  # [L, D]
        dz = torch.fft.ifft(Z_t * K_freq.unsqueeze(0), dim=1)  # [B, L, D]
        
        # Euler integration with manifold projection
        z_next = z + dz * self.dt  # [B, L, D]
        z_next = exact_projection(z_next)  # [B, L, D]
        
        return z_next

In [7]:
class ContinuousTransformer(nn.Module):
    """
    Continuous-time transformer with recursive operator application.
    """
    def __init__(self, dim: int = 1024, depth: int = 6, vocab: int = 256, use_checkpoint: bool = True):
        super().__init__()
        self.dim = dim
        self.depth = depth
        self.use_checkpoint = use_checkpoint
        
        self.embed = nn.Embedding(vocab, dim)
        self.to_complex = nn.Linear(dim, dim * 2)
        self.operator = SpectralField(dim)
        self.depth_emb = nn.Parameter(torch.randn(depth, dim * 2) * 0.02)
        self.out_re = nn.Linear(dim, vocab)
        self.out_im = nn.Linear(dim, vocab)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            x: Input token indices [B, L]
            
        Returns:
            logits: Output logits [B, L, V]
            z: Final complex state [B, L, D]
        """
        # Token embedding to complex space
        emb = self.embed(x)  # [B, L, D]
        z_raw = self.to_complex(emb)  # [B, L, 2D]
        z = torch.complex(z_raw[..., :self.dim], z_raw[..., self.dim:])  # [B, L, D]
        
        # Recursive operator application
        for i in range(self.depth):
            # Depth-dependent bias injection
            d_bias = self.depth_emb[i]  # [2D]
            z = z + torch.complex(d_bias[:self.dim], d_bias[self.dim:])  # [B, L, D]
            
            # Apply operator with optional gradient checkpointing
            if self.training and self.use_checkpoint:
                z = checkpoint(self.operator, z, use_reentrant=False)
            else:
                z = self.operator(z)
        
        # Project to vocabulary logits
        logits = self.out_re(z.real) + self.out_im(z.imag)  # [B, L, V]
        return logits, z

In [None]:
@torch.no_grad()
def generate(model: nn.Module, prompt: str, device: str = 'cuda') -> None:
    """
    Autoregressive text generation.
    
    Args:
        model: Trained ContinuousTransformer instance
        prompt: Initial text prompt
        device: Computation device
    """
    model.eval()
    tokens = [ord(c) for c in prompt]
    print(f"\nGeneration: {prompt}", end="", flush=True)
    
    for _ in range(GEN_LENGTH):
        ctx = torch.tensor([tokens[-1024:]], device=device)  # [1, L']
        logits, _ = model(ctx)  # [1, L', V]
        
        # Nucleus (top-p) sampling
        probs = torch.softmax(logits[0, -1] / GEN_TEMPERATURE, dim=-1)  # [V]
        sorted_probs, idx = torch.sort(probs, descending=True)
        cumsum = torch.cumsum(sorted_probs, dim=0)
        cutoff = (cumsum > GEN_TOP_P).float()
        cutoff[1:] = cutoff[:-1].clone()
        cutoff[0] = 0
        probs[idx[cutoff.bool()]] = 0
        probs = probs / probs.sum()  # Renormalize
        
        next_token = torch.multinomial(probs, 1).item()
        tokens.append(next_token)
        print(chr(next_token), end="", flush=True)
    
    print("\n")
    model.train()

# Training

In [9]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

# Create dataset and dataloader
dataset = TinyStoriesDataset(data_path, seq_length=SEQ_LENGTH)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# Initialize model
model = ContinuousTransformer(dim=DIM, depth=DEPTH, vocab=VOCAB_SIZE, use_checkpoint=USE_CHECKPOINT).to(device)
params = sum(p.numel() for p in model.parameters())
print(f"Parameters: {params/1e6:.2f}M")

Device: cuda
GPU: NVIDIA GeForce RTX 3050 Ti Laptop GPU
Parameters: 2.91M


In [10]:
# Checkpoint restoration
import glob
checkpoints = glob.glob("continuous_transformer_step*.pt")
if checkpoints:
    steps = [int(cp.split("step")[1].split(".")[0]) for cp in checkpoints]
    latest_step = max(steps)
    latest_checkpoint = f"continuous_transformer_step{latest_step:05d}.pt"
    print(f"Loading checkpoint: {latest_checkpoint}")
    model.load_state_dict(torch.load(latest_checkpoint, map_location=device))
    start_step = latest_step
    print(f"Resuming from step {start_step}")
else:
    start_step = 0
    print("No checkpoint found.")

# Optimizer with homeostatic learning rate control
opt = optim.AdamW(model.parameters(), lr=BASE_LR, weight_decay=WEIGHT_DECAY)
ema_loss = None

step = start_step
try:
    while True:
        for batch_idx, (x, y) in enumerate(dataloader):
            x = x.to(device)  # [B, L]
            y = y.to(device)  # [B, L]
            
            # Forward pass
            logits, _ = model(x)  # [B, L, V]
            loss = nn.CrossEntropyLoss()(logits.view(-1, VOCAB_SIZE), y.view(-1))
            
            # Backward pass
            opt.zero_grad()
            loss.backward()
            
            # Homeostatic learning rate adjustment
            with torch.no_grad():
                # Compute gradient statistics
                grad_norm = sum(p.grad.norm()**2 for p in model.parameters() if p.grad is not None)**0.5
                
                # Update exponential moving average of loss
                if ema_loss is None: 
                    ema_loss = loss.item()
                ema_loss = 0.95 * ema_loss + 0.05 * loss.item()
                
                # Adaptive learning rate based on gradient stability
                stability = 1.0 / (grad_norm.item() + 1e-6)
                reactive_lr = BASE_LR / (1.0 + np.exp(-(stability - 0.5)))
                
                # Update optimizer learning rate
                for param_group in opt.param_groups:
                    param_group['lr'] = reactive_lr
            
            # Gradient clipping and optimization step
            nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
            opt.step()
            
            # Logging
            if step % PRINT_EVERY == 0:
                acc = (logits.argmax(-1) == y).float().mean()
                epoch = step // len(dataloader)
                batch_in_epoch = step % len(dataloader)
                print(f"Step {step:05d} (Epoch {epoch}, Batch {batch_in_epoch}/{len(dataloader)}) | Loss: {loss.item():.4f} | Acc: {acc.item()*100:.1f}% | LR: {reactive_lr:.2e} | Stability: {stability:.2f}")
            
            # Checkpointing and generation
            if step % CHECKPOINT_EVERY == 0 and step > start_step:
                generate(model, "Once upon a time", device=device)
                torch.save(model.state_dict(), f"continuous_transformer_step{step:05d}.pt")
                print(f"  Checkpoint saved: step {step}\n")
            
            step += 1
        
except KeyboardInterrupt:
    print("\nTraining interrupted by user")
    print(f"Final step: {step}")
    
    # Save final checkpoint
    final_checkpoint = f"continuous_transformer_step{step:05d}.pt"
    torch.save(model.state_dict(), final_checkpoint)
    print(f"Final checkpoint saved: {final_checkpoint}")

Loading checkpoint: continuous_transformer_step134641.pt
Resuming from step 134641
Step 134700 (Epoch 0, Batch 134700/4351064) | Loss: 0.9487 | Acc: 71.2% | LR: 1.15e-04 | Stability: 0.02
Step 134800 (Epoch 0, Batch 134800/4351064) | Loss: 0.8964 | Acc: 72.4% | LR: 1.14e-04 | Stability: 0.01
Step 134900 (Epoch 0, Batch 134900/4351064) | Loss: 0.8981 | Acc: 72.0% | LR: 1.15e-04 | Stability: 0.03
Step 135000 (Epoch 0, Batch 135000/4351064) | Loss: 0.9494 | Acc: 70.4% | LR: 1.17e-04 | Stability: 0.05

Generation: Once upon a time to get him. They were happy and explain and took a little bird and protend.
<|endoftext|>
Once upon a time, there was a little boy named Ben. Ben sail in his friends and pretty. The cat on it a littl

  Checkpoint saved: step 135000

Step 135100 (Epoch 0, Batch 135100/4351064) | Loss: 0.9194 | Acc: 71.4% | LR: 1.17e-04 | Stability: 0.05
Step 135200 (Epoch 0, Batch 135200/4351064) | Loss: 0.9166 | Acc: 72.0% | LR: 1.17e-04 | Stability: 0.05

Training interrupted b

## References

- **A Mathematical Explanation of Transformers for Large Language Models and GPTs** - Tai et al. (2024). [arXiv:2510.03989](https://arxiv.org/abs/2510.03989)
- **HiPPO: Recurrent Memory with Optimal Polynomial Projections** - Gu et al. (2020). NeurIPS. [arXiv:2008.07669](https://arxiv.org/abs/2008.07669)
- **Efficiently Modeling Long Sequences with Structured State Spaces (S4)** - Gu et al. (2022). ICLR. [arXiv:2111.00396](https://arxiv.org/abs/2111.00396)
- **Neural Ordinary Differential Equations** - Chen et al. (2018). NeurIPS. [arXiv:1806.07366](https://arxiv.org/abs/1806.07366)
- **Deep Complex Networks** - Trabelsi et al. (2017). ICLR. [arXiv:1705.09792](https://arxiv.org/abs/1705.09792)
- **FNet: Mixing Tokens with Fourier Transforms** - Lee-Thorp et al. (2021). NAACL. [arXiv:2105.03824](https://arxiv.org/abs/2105.03824)
- **Language Modeling with Gated Convolutional Networks** - Dauphin et al. (2017). ICML. [arXiv:1612.08083](https://arxiv.org/abs/1612.08083)
- **Informer: Beyond Efficient Transformer for Long Sequence Time-Series Forecasting** - Zhou et al. (2021). AAAI. [arXiv:2012.07436](https://arxiv.org/abs/2012.07436)
- **Liquid Time-constant Networks** - Hasani et al. (2020). AAAI. [arXiv:2006.04439](https://arxiv.org/abs/2006.04439)
- **TinyStories: How Small Can Language Models Be and Still Speak Coherent English?** - Eldan & Li (2023). [arXiv:2305.07759](https://arxiv.org/abs/2305.07759)