# Architecture Parameter Comparison

This notebook compares parameter counts across different architectures:

- LSTM
- Transformer
- Stack-RNN (with superposition)
- Stack-Transformer

We'll also determine optimal hyperparameters given a fixed parameter budget.


In [35]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import math
from typing import Dict, Optional

# Set plotting style.
# The style sheet is applied first: matplotlib's seaborn-v0_8 style sheets set
# font.family themselves, so an rcParams override made before style.use()
# would be overwritten by the sheet. Setting the font afterwards makes the
# Times New Roman choice actually take effect.
plt.style.use('seaborn-v0_8-whitegrid')
plt.rcParams["font.family"] = "Times New Roman"

In [36]:
def calculate_transformer_params(vocab_size: int, num_layers: int, d_model: int,
                                feedforward_size: int) -> int:
    """Calculate the parameter count of a transformer analytically.

    The count assumes a single embedding matrix shared between input and
    output (tied embeddings), attention projections with bias, a two-layer
    feedforward block with bias, two layer norms per layer, and one final
    layer norm.

    Including the four attention bias vectors makes the formula reproduce the
    counts of the instantiated models reported in this notebook (for example
    40,224 for vocab_size=64, num_layers=3, d_model=32, feedforward_size=128).

    Args:
        vocab_size: Number of rows in the (tied) embedding matrix.
        num_layers: Number of transformer layers.
        d_model: Model (embedding) dimension.
        feedforward_size: Hidden size of the feedforward block.

    Returns:
        Total number of parameters.
    """
    # Embeddings (tied with output)
    embedding_params = vocab_size * d_model

    # Q, K, V, O projections: four d_model x d_model matrices plus one
    # d_model-sized bias vector each.
    attention_params = 4 * d_model * d_model + 4 * d_model

    # Feedforward block: two linear layers, each with bias.
    ffn_params = (
        d_model * feedforward_size + feedforward_size +  # FFN layer 1
        feedforward_size * d_model + d_model  # FFN layer 2
    )

    # 2 layer norms per layer (weight and bias each)
    layer_norm_params = 2 * 2 * d_model

    per_layer = attention_params + ffn_params + layer_norm_params

    # Final layer norm (weight and bias)
    final_params = 2 * d_model

    return embedding_params + num_layers * per_layer + final_params


def calculate_lstm_params(vocab_size: int, num_layers: int, hidden_units: int) -> int:
    """Estimate the parameter count of an LSTM language model analytically.

    The estimate is the sum of an input embedding matrix, one learned initial
    hidden vector per layer, bias-free gate weights for every layer, and an
    output projection with bias. Every layer is assumed to take a
    hidden_units-sized input.

    NOTE(review): for vocab_size=64, num_layers=3, hidden_units=32 this
    formula gives 28,832, while the instantiated model reported in this
    notebook has 27,104 parameters, so treat the result as an approximation.

    Args:
        vocab_size: Vocabulary size used for both embedding and output.
        num_layers: Number of stacked LSTM layers.
        hidden_units: Hidden (and embedding) dimension.

    Returns:
        Estimated total number of parameters.
    """
    h = hidden_units
    num_gates = 4

    # Each gate has an input-to-hidden and a hidden-to-hidden matrix, both
    # h x h; no bias terms are counted.
    weights_per_layer = num_gates * (h * h + h * h)

    total = vocab_size * h                   # input embeddings
    total += num_layers * h                  # learned initial h per layer (c is not counted)
    total += num_layers * weights_per_layer  # gate weights of all layers
    total += h * vocab_size + vocab_size     # output projection with bias
    return total


def calculate_rnn_params(vocab_size: int, num_layers: int, hidden_units: int) -> int:
    """Estimate the parameter count of a simple RNN language model analytically.

    The estimate is the sum of an input embedding matrix, one learned initial
    hidden vector per layer, bias-free input and recurrent weights for every
    layer, and an output projection with bias.

    Args:
        vocab_size: Vocabulary size used for both embedding and output.
        num_layers: Number of stacked RNN layers.
        hidden_units: Hidden (and embedding) dimension.

    Returns:
        Estimated total number of parameters.
    """
    h = hidden_units

    # One input-to-hidden and one hidden-to-hidden matrix per layer, no bias.
    weights_per_layer = h * h + h * h

    components = (
        vocab_size * h,                   # input embeddings
        num_layers * h,                   # learned initial hidden states
        num_layers * weights_per_layer,   # recurrent layer weights
        h * vocab_size + vocab_size,      # output projection with bias
    )
    return sum(components)


def calculate_stack_rnn_params(vocab_size: int, num_layers: int, hidden_units: int,
                             stack_size: int, controller: str = 'lstm') -> int:
    """Estimate the parameter count of a superposition Stack-RNN analytically.

    The estimate is the controller's own count plus two stack-specific linear
    layers: one producing 3 action scores and one producing the pushed vector
    of size stack_size, both with bias.

    NOTE(review): in the counts recorded in this notebook (hidden_units=32),
    growing stack_size from 20 to 32 adds 161 parameters per unit of
    stack_size, whereas this formula adds hidden_units + 1 = 33. The
    controller presumably also has weights for the stack reading that are not
    counted here; verify against the model code.

    Args:
        vocab_size: Vocabulary size used for both embedding and output.
        num_layers: Number of controller layers.
        hidden_units: Controller hidden dimension.
        stack_size: Dimension of the vectors stored on the stack.
        controller: 'lstm' selects the LSTM formula; any other value selects
            the simple RNN formula.

    Returns:
        Estimated total number of parameters.
    """
    # Pick the controller formula; anything other than 'lstm' means simple RNN.
    controller_count = calculate_lstm_params if controller == 'lstm' else calculate_rnn_params
    controller_params = controller_count(vocab_size, num_layers, hidden_units)

    # Action layer for a single stack: hidden_units -> 3, with bias.
    action_layer = hidden_units * 3 + 3

    # Pushed-value layer: hidden_units -> stack_size, with bias.
    push_layer = hidden_units * stack_size + stack_size

    return controller_params + (action_layer + push_layer)


def calculate_stack_transformer_params(vocab_size: int, num_layers: int, num_stack_layers: int,
                                     d_model: int, feedforward_size: int, stack_size: int) -> int:
    """Calculate the parameter count of a superposition Stack-Transformer analytically.

    Regular layers are counted as attention projections with bias, a
    two-layer feedforward block with bias, and two layer norms. Stack layers
    replace the attention projections with three bias-free linear maps
    (actions, pushed vector, stack reading) and keep the feedforward block and
    layer norms.

    Including the attention biases in the regular layers makes the formula
    reproduce the stack-transformer counts recorded in this notebook when it
    is given d_model + 1 regular layers and one stack layer (for example
    59,568 for vocab_size=64, num_layers=17, num_stack_layers=1, d_model=16,
    feedforward_size=64, stack_size=16).

    This function takes explicit layer counts and does not parse a layer-spec
    string. NOTE(review): the spec format
    'd_model-num_regular.stack_type-stack_size.num_stack' is unverified here;
    confirm it against the model's layer-spec parser before relying on it.

    Args:
        vocab_size: Number of rows in the (tied) embedding matrix.
        num_layers: Number of regular transformer layers.
        num_stack_layers: Number of stack attention layers.
        d_model: Model (embedding) dimension.
        feedforward_size: Hidden size of the feedforward block.
        stack_size: Dimension of the vectors stored on the stack.

    Returns:
        Total number of parameters.
    """
    # Embeddings (tied)
    embedding_params = vocab_size * d_model

    # Feedforward block and layer norms are shared by both layer types.
    ffn_params = (
        d_model * feedforward_size + feedforward_size +  # FFN layer 1
        feedforward_size * d_model + d_model  # FFN layer 2
    )
    layer_norm_params = 2 * 2 * d_model  # 2 layer norms (weight and bias each)

    # Regular transformer layers: Q, K, V, O projections, each a
    # d_model x d_model matrix plus a d_model-sized bias vector.
    regular_per_layer = (
        4 * d_model * d_model + 4 * d_model +
        ffn_params +
        layer_norm_params
    )

    # Stack attention layers: three bias-free linear maps replace attention.
    stack_per_layer = (
        d_model * 3 +  # Action layer (no bias)
        d_model * stack_size +  # Input to pushed vector (no bias)
        stack_size * d_model +  # Stack reading to output (no bias)
        ffn_params +
        layer_norm_params
    )

    # Final layer norm
    final_params = 2 * d_model

    total_params = (
        embedding_params +
        num_layers * regular_per_layer +
        num_stack_layers * stack_per_layer +
        final_params
    )

    return total_params

In [57]:
# Use empirical parameter counts instead
import sys
import torch
import torch.nn as nn

# Add the rau module to the import path.
# The location can be overridden with the RAU_SRC_DIR environment variable; the
# default is the original hard-coded local checkout. The membership check keeps
# re-running this cell from appending duplicate entries to sys.path.
import os
RAU_SRC_DIR = os.environ.get('RAU_SRC_DIR', '/Users/agiats/Projects/lm_inductive_bias/src/rau/src')
if RAU_SRC_DIR not in sys.path:
    sys.path.append(RAU_SRC_DIR)

from rau.models.transformer.unidirectional_encoder import get_unidirectional_transformer_encoder
from rau.models.rnn.language_model import get_simple_rnn_language_model, get_lstm_language_model
from rau.models.stack_nn.transformer.unidirectional_encoder import get_unidirectional_stack_transformer_encoder
from rau.models.stack_nn.rnn.language_model import get_stack_rnn_language_model
from rau.models.stack_nn.transformer.parse import parse_stack_transformer_layers
from rau.models.stack_nn.rnn.parse import parse_stack_rnn_stack


def count_parameters(model: nn.Module) -> int:
    """Return the total element count of the model's trainable parameters.

    Only parameters with requires_grad set are included; frozen parameters
    are skipped.
    """
    trainable_total = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable_total += param.numel()
    return trainable_total


def get_empirical_params(architecture: str, vocab_size: int, **kwargs) -> int:
    """Build a model of the given architecture and count its trainable parameters.

    Args:
        architecture: One of 'lstm', 'transformer', 'stack-rnn' or
            'stack-transformer'.
        vocab_size: Used as both the input and the output vocabulary size.
        **kwargs: Architecture-specific hyperparameters:
            - 'lstm': hidden_units, num_layers
            - 'transformer': num_layers, d_model, num_heads, feedforward_size
            - 'stack-rnn': hidden_units, num_layers, stack_size
            - 'stack-transformer': d_model, num_layers, num_stack_layers,
              stack_size, num_heads, feedforward_size

    Returns:
        Number of trainable parameters of the instantiated model.

    Raises:
        ValueError: If ``architecture`` is not one of the supported names.
        KeyError: If a hyperparameter required by the architecture is missing
            from ``kwargs``.
    """

    if architecture == 'lstm':
        model = get_lstm_language_model(
            input_vocabulary_size=vocab_size,
            output_vocabulary_size=vocab_size,
            hidden_units=kwargs['hidden_units'],
            layers=kwargs['num_layers'],
            dropout=0.0,
            learned_hidden_state=True,
            use_padding=False
        )

    elif architecture == 'transformer':
        model = get_unidirectional_transformer_encoder(
            input_vocabulary_size=vocab_size,
            output_vocabulary_size=vocab_size,
            tie_embeddings=True,
            num_layers=kwargs['num_layers'],
            d_model=kwargs['d_model'],
            num_heads=kwargs['num_heads'],
            feedforward_size=kwargs['feedforward_size'],
            dropout=0.0,
            use_padding=False
        )

    elif architecture == 'stack-rnn':
        # The controller is always an LSTM here.
        stack_spec = f"superposition-{kwargs['stack_size']}"
        stack = parse_stack_rnn_stack(stack_spec)
        model = get_stack_rnn_language_model(
            input_vocabulary_size=vocab_size,
            output_vocabulary_size=vocab_size,
            hidden_units=kwargs['hidden_units'],
            layers=kwargs['num_layers'],
            controller='lstm',
            stack=stack,
            dropout=0.0,
            learned_hidden_state=True,
            use_padding=False,
            tag=None
        )

    elif architecture == 'stack-transformer':
        # Format: d_model-num_regular.stack_type-stack_size.num_stack
        # NOTE(review): the counts recorded in this notebook for the
        # stack-transformer configs (e.g. 59,568 for d_model=16) equal the
        # size of a model with d_model + 1 regular layers plus one stack
        # layer, which suggests the leading "{d_model}-" is being read as a
        # layer count. Confirm the expected spec format against
        # parse_stack_transformer_layers before trusting these numbers.
        layer_spec = f"{kwargs['d_model']}-{kwargs['num_layers']}.superposition-{kwargs['stack_size']}.{kwargs['num_stack_layers']}"
        layers = parse_stack_transformer_layers(layer_spec)
        model = get_unidirectional_stack_transformer_encoder(
            input_vocabulary_size=vocab_size,
            output_vocabulary_size=vocab_size,
            tie_embeddings=True,
            layers=layers,
            d_model=kwargs['d_model'],
            num_heads=kwargs['num_heads'],
            feedforward_size=kwargs['feedforward_size'],
            dropout=0.0,
            use_padding=False
        )

    else:
        # Without this branch `model` would be unbound below and the caller
        # would see an UnboundLocalError instead of a clear message.
        raise ValueError(
            f"Unknown architecture {architecture!r}; expected one of "
            "'lstm', 'transformer', 'stack-rnn', 'stack-transformer'"
        )

    return count_parameters(model)


# Test our empirical counting: instantiate each configuration below, print its
# trainable-parameter count, and collect the results into a DataFrame.
# Shared settings: every config uses this vocabulary size, and every config
# except the stack-transformer ones uses this layer count.
vocab_size = 64
num_layers = 3

print("=== Empirical Parameter Counts ===\n")

# Test configurations: (display name, architecture key, hyperparameters).
# The hyperparameter dict is forwarded to get_empirical_params as **kwargs.
configs = [
    ('LSTM (h=32)', 'lstm', {'hidden_units': 32, 'num_layers': num_layers}),
    ('LSTM (h=64)', 'lstm', {'hidden_units': 64, 'num_layers': num_layers}),
    ('LSTM (h=128)', 'lstm', {'hidden_units': 128, 'num_layers': num_layers}),

    # Transformers: feedforward_size is 4 * d_model in every entry.
    ('Transformer (d=32)', 'transformer', {'d_model': 32, 'num_heads': 4, 'feedforward_size': 128, 'num_layers': num_layers}),
     ('Transformer (d=48)', 'transformer', {'d_model': 48, 'num_heads': 4, 'feedforward_size': 192, 'num_layers': num_layers}),
     ('Transformer (d=56)', 'transformer', {'d_model': 56, 'num_heads': 4, 'feedforward_size': 224, 'num_layers': num_layers}),
    ('Transformer (d=64)', 'transformer', {'d_model': 64, 'num_heads': 4, 'feedforward_size': 256, 'num_layers': num_layers}),
    ('Transformer (d=128)', 'transformer', {'d_model': 128, 'num_heads': 4, 'feedforward_size': 512, 'num_layers': num_layers}),

    # Stack-RNNs: hidden size crossed with stack size.
    ('Stack-RNN (h=32, s=20)', 'stack-rnn', {'hidden_units': 32, 'stack_size': 20, 'num_layers': num_layers}),
    ('Stack-RNN (h=32, s=32)', 'stack-rnn', {'hidden_units': 32, 'stack_size': 32, 'num_layers': num_layers}),
    ('Stack-RNN (h=32, s=64)', 'stack-rnn', {'hidden_units': 32, 'stack_size': 64, 'num_layers': num_layers}),
    ('Stack-RNN (h=64, s=20)', 'stack-rnn', {'hidden_units': 64, 'stack_size': 20, 'num_layers': num_layers}),
    ('Stack-RNN (h=64, s=32)', 'stack-rnn', {'hidden_units': 64, 'stack_size': 32, 'num_layers': num_layers}),
    ('Stack-RNN (h=64, s=64)', 'stack-rnn', {'hidden_units': 64, 'stack_size': 64, 'num_layers': num_layers}),

    # Stack-Transformers: these pass num_layers=1 and num_stack_layers=1
    # explicitly instead of the shared num_layers above.
    # NOTE(review): the counts recorded for these entries are larger than the
    # 3-layer d=32 transformer despite the much smaller d_model; verify the
    # layer spec built in get_empirical_params.
    ('Stack-Transformer (d=16, s=16)', 'stack-transformer',
     {'d_model': 16, 'num_heads': 4, 'feedforward_size': 64, 'stack_size': 16, 'num_layers': 1, 'num_stack_layers': 1}),
    ('Stack-Transformer (d=20, s=8)', 'stack-transformer',
     {'d_model': 20, 'num_heads': 4, 'feedforward_size': 80, 'stack_size': 8, 'num_layers': 1, 'num_stack_layers': 1}),
    ('Stack-Transformer (d=20, s=16)', 'stack-transformer',
     {'d_model': 20, 'num_heads': 4, 'feedforward_size': 80, 'stack_size': 16, 'num_layers': 1, 'num_stack_layers': 1}),
    ('Stack-Transformer (d=24, s=16)', 'stack-transformer',
     {'d_model': 24, 'num_heads': 4, 'feedforward_size': 96, 'stack_size': 16, 'num_layers': 1, 'num_stack_layers': 1}),
]

# One result row per configuration that could be built successfully.
results = []
for name, arch, params in configs:
    try:
        param_count = get_empirical_params(arch, vocab_size, **params)
        results.append({
            'Architecture': name,
            'Parameters': param_count,
            # RNN-style configs carry 'hidden_units', transformer-style ones 'd_model'.
            'Hidden/d_model': params.get('hidden_units') or params.get('d_model'),
            # '-' marks architectures without a stack.
            'Stack Size': params.get('stack_size', '-')
        })
        print(f"{name}: {param_count:,} parameters")
    except Exception as e:
        # A failing configuration is reported and left out of the results
        # table; the remaining configurations still run.
        print(f"{name}: Error - {e}")

df_empirical = pd.DataFrame(results)
print("\n=== Parameter Summary ===")
# `display` is not imported in this cell; it is presumably the IPython
# built-in available in notebook sessions.
display(df_empirical)

=== Empirical Parameter Counts ===

LSTM (h=32): 27,104 parameters
LSTM (h=64): 103,360 parameters
LSTM (h=128): 403,328 parameters
Transformer (d=32): 40,224 parameters
Transformer (d=48): 87,984 parameters
Transformer (d=56): 118,776 parameters
Transformer (d=64): 154,176 parameters
Transformer (d=128): 603,264 parameters
Stack-RNN (h=32, s=20): 30,423 parameters
Stack-RNN (h=32, s=32): 32,355 parameters
Stack-RNN (h=32, s=64): 37,507 parameters
Stack-RNN (h=64, s=20): 109,975 parameters
Stack-RNN (h=64, s=32): 113,827 parameters
Stack-RNN (h=64, s=64): 124,099 parameters
Stack-Transformer (d=16, s=16): 59,568 parameters
Stack-Transformer (d=20, s=8): 111,340 parameters
Stack-Transformer (d=20, s=16): 111,660 parameters
Stack-Transformer (d=24, s=16): 187,848 parameters

=== Parameter Summary ===


Unnamed: 0,Architecture,Parameters,Hidden/d_model,Stack Size
0,LSTM (h=32),27104,32,-
1,LSTM (h=64),103360,64,-
2,LSTM (h=128),403328,128,-
3,Transformer (d=32),40224,32,-
4,Transformer (d=48),87984,48,-
5,Transformer (d=56),118776,56,-
6,Transformer (d=64),154176,64,-
7,Transformer (d=128),603264,128,-
8,"Stack-RNN (h=32, s=20)",30423,32,20
9,"Stack-RNN (h=32, s=32)",32355,32,32
