In [1]:
from bbml.foundations.gpt2 import GPT2Foundation, GPTConfig

In [2]:
foundation = GPT2Foundation(  # based on nanoGPT
    GPTConfig(),
    None
)

loading weights from pretrained gpt: gpt2
forcing vocab_size=50257, block_size=1024, bias=True
number of parameters: 123.65M


In [3]:
foundation.run(foundation.input_model(text="The quick brown fox jumps over", max_new_tokens=10)).text

'The quick brown fox jumps over it at the start of its fall to catch the'

In [4]:

from enum import Enum
import math
import torch
from torch import nn
from torch.nn import init
from torch.nn import functional as F


class SplitLinear(nn.Module):

    def __init__(self, bias: bool, out_features: int | None = None, device=None, dtype=None):
        super().__init__()
        self.splits: nn.ModuleList = nn.ModuleList()

        if bias:
            self.bias = nn.Parameter(
                torch.empty(out_features, device=device, dtype=dtype)
            )
        else:
            self.register_parameter('bias', None)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        outs = []
        for split in self.splits:
            outs.append(split(x))
        out = torch.cat(outs, dim=-1)
        if self.bias is not None:
            out = out + self.bias
        return out
    

class ShareLinearState(str, Enum):
    ORIGINAL = "original"  # Forward uses original weights
    CALIBRATING = "calibrating"  # Forward uses original weights + tracks inputs
    COMPRESSED = "compressed"  # Forward uses basis @ coefficient


class ShareLinear(nn.Module):
    state = ShareLinearState.ORIGINAL

    def __init__(
        self,
        in_features: int,
        basis_features: int,
        out_features: int,
        device=None,
        dtype=None
    ):
        super().__init__()
        self.in_features = in_features
        self.basis_features = basis_features
        self.out_features = out_features
        
        self.basis = nn.Parameter(
            torch.empty(basis_features, in_features, device=device, dtype=dtype)
        )
        self.coefficient = nn.Parameter(
            torch.empty(out_features, basis_features, device=device, dtype=dtype)
        )
        self.original = nn.Parameter(
            torch.empty(out_features, in_features, device=device, dtype=dtype)
        )
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.state == ShareLinearState.ORIGINAL:
            return F.linear(x, self.original)

        b = F.linear(x, self.basis)
        out = F.linear(b, self.coefficient)
        return out
    
    def extra_repr(self) -> str:
        return f"in={self.in_features}, basis={self.basis_features}, out={self.out_features}, state={self.state.value}"

        

In [5]:
from collections import defaultdict
import re

from typing import Literal
from pydantic import BaseModel, model_validator

class WeightConfig(BaseModel):
    pattern: str
    split: bool|Literal["qkv", "heads"] = False
    qkv_order: str = "qkv"
    n_head: int|None = None

    @model_validator(mode="after")
    def qkv_string(self):
        if self.split:
            if not self.qkv_order or set(self.qkv_order.lower()) != {"q", "k", "v"}:
                raise ValueError("qkv_order must contain exactly 'q', 'k', and 'v' characters when split='qkv'")
        return self

    @model_validator(mode="after")
    def validate_heads_split(self):
        if self.split == "heads" and self.n_head is None:
            raise ValueError("n_head must be set when split='heads'")
        return self

class SplitConfig(BaseModel):
    weight_types: dict[str, WeightConfig]
    block_pattern: str|None = None



class SplitLinearWrapper:
    def __init__(self, model, config):
        self.model = model
        self.config = config

    
        self.original_weights = []

        named_modules = {k:v for k,v in model.named_modules()}

        for name, module in named_modules.items():
            for wtype, wtype_cfg in self.config.weight_types.items():
                if re.match(wtype_cfg.pattern, name) is not None:
                    name_parts = name.split(".")
                    parent_name = ".".join(name_parts[:-1])
                    parent = named_modules[parent_name]
                    list_id = None
                    if name_parts[-1].isdigit():
                        list_id = int(name_parts[-1])
                    
                    layer_num = re.search(config.block_pattern, name).group(1)

                    self.original_weights.append({
                        "name": name,
                        "module": module,
                        "parent": parent,
                        "list_id": list_id,
                        "last_name_part": name_parts[-1],
                        "layer": layer_num,
                        "weight_type": wtype,
                    })

        self.all_basislinears = {}
        self.weight_types = defaultdict(list)

        for wt_dict in self.original_weights:
            wtype_cfg = self.config.weight_types[wt_dict["weight_type"]]
            module = wt_dict["module"]
            name = wt_dict["name"]
            has_bias = module.bias is not None
            in_feats = module.in_features
            out_feats = module.out_features
            split_linear = SplitLinear(bias=has_bias, out_features=out_feats)
            if has_bias:
                split_linear.bias.data = module.bias.data
            setattr(wt_dict["parent"], wt_dict["last_name_part"], split_linear)

            if wtype_cfg.split == "qkv" or wtype_cfg.split == "heads":
                assert out_feats % 3 == 0
                qkv_feats = out_feats // 3 
            
            if wtype_cfg.split == "qkv":
                cur_ind = 0            
                for qkv_part in wtype_cfg.qkv_order:
                    to_part = ShareLinear(in_feats, in_feats, qkv_feats)    
                    to_part.original.data = module.weight.data[cur_ind:cur_ind+qkv_feats,:]
                    cur_ind += qkv_feats
                    
                    split_linear.splits.append(to_part)
                    self.all_basislinears[f"{name}.{qkv_part}"] = to_part
                
            elif wtype_cfg.split == "heads":
                n_heads = wtype_cfg.n_head
                head_dim = qkv_feats // n_heads

                cur_ind = 0
                for qkv_part in wtype_cfg.qkv_order:
                    for head_num in range(n_heads):
                        to_head = ShareLinear(in_feats, in_feats, head_dim)
                        to_head.original.data = module.weight.data[cur_ind:cur_ind+head_dim,:]  # [out_dim, in_dim] -> [head_dim, in_dim]
                        cur_ind += head_dim

                        split_linear.splits.append(to_head)
                        self.all_basislinears[f"{name}.{qkv_part}.{head_num}"] = to_head
                

            else: # no split
                in_linear = ShareLinear(in_feats, in_feats, out_feats)
                in_linear.original.data = module.weight.data

                split_linear.splits.append(in_linear)
                self.all_basislinears[name] = in_linear
        


In [6]:
%%writefile config.yaml
weight_types:
    attn_c_attn:
        pattern: '.*\.attn\.c_attn$'
        split: heads  # qkv, heads
        qkv_order: qkv  # string
        n_head: 12
    attn_c_proj: 
        pattern: '.*\.attn\.c_proj$'
    mlp_c_fc: 
        pattern: '.*\.mlp\.c_fc$'
    mlp_c_proj: 
        pattern: '.*\.mlp\.c_proj$'

block_pattern: 'h\.(\d+)'

Overwriting config.yaml


In [7]:
import yaml
with open("config.yaml") as f:
    cfg = yaml.safe_load(f)

In [8]:
splitcfg = SplitConfig(**cfg)

In [9]:
wrapper = SplitLinearWrapper(foundation, splitcfg)

In [10]:
# ============================================================================
# VERIFICATION TEST: Compare original vs wrapped model outputs
# ============================================================================
# This test verifies that SplitLinear correctly reproduces the original Linear behavior.
# We use deterministic forward passes (not generation) to isolate weight transfer issues.

import copy

def verify_split_linear_equivalence():
    """Verify that splitting a linear layer preserves its forward pass behavior."""
    print("=" * 60)
    print("VERIFICATION TEST: SplitLinear Weight Transfer")
    print("=" * 60)
    
    # Test 1: Single Linear (no split)
    print("\n[Test 1] Single Linear → SplitLinear (no split)")
    torch.manual_seed(42)
    original_linear = nn.Linear(768, 3072, bias=True)
    x = torch.randn(2, 10, 768)
    
    with torch.no_grad():
        expected = original_linear(x)
    
    # Create SplitLinear wrapper
    split_linear = SplitLinear(bias=True, out_features=3072)
    split_linear.bias.data = original_linear.bias.data.clone()
    
    share = ShareLinear(768, 768, 3072)
    share.original.data = original_linear.weight.data.clone()
    split_linear.splits.append(share)
    
    with torch.no_grad():
        actual = split_linear(x)
    
    diff = (expected - actual).abs().max().item()
    status = "✅ PASS" if diff < 1e-5 else "❌ FAIL"
    print(f"  Max difference: {diff:.2e} {status}")
    
    # Test 2: QKV Split (3-way split)
    print("\n[Test 2] QKV Linear → SplitLinear (3-way split)")
    torch.manual_seed(42)
    qkv_linear = nn.Linear(768, 2304, bias=True)  # 768 * 3 = 2304
    x = torch.randn(2, 10, 768)
    
    with torch.no_grad():
        expected = qkv_linear(x)
    
    split_linear = SplitLinear(bias=True, out_features=2304)
    split_linear.bias.data = qkv_linear.bias.data.clone()
    
    # Split into Q, K, V (each 768)
    for i, name in enumerate(['q', 'k', 'v']):
        start = i * 768
        end = start + 768
        share = ShareLinear(768, 768, 768)
        share.original.data = qkv_linear.weight.data[start:end, :].clone()
        split_linear.splits.append(share)
    
    with torch.no_grad():
        actual = split_linear(x)
    
    diff = (expected - actual).abs().max().item()
    status = "✅ PASS" if diff < 1e-5 else "❌ FAIL"
    print(f"  Max difference: {diff:.2e} {status}")
    
    # Test 3: Head Split (36-way split for 12 heads × 3 QKV)
    print("\n[Test 3] QKV Linear → SplitLinear (36-way head split)")
    torch.manual_seed(42)
    qkv_linear = nn.Linear(768, 2304, bias=True)
    x = torch.randn(2, 10, 768)
    
    with torch.no_grad():
        expected = qkv_linear(x)
    
    split_linear = SplitLinear(bias=True, out_features=2304)
    split_linear.bias.data = qkv_linear.bias.data.clone()
    
    n_heads = 12
    head_dim = 64  # 768 / 12
    cur_ind = 0
    for qkv_part in ['q', 'k', 'v']:
        for head_num in range(n_heads):
            share = ShareLinear(768, 768, head_dim)
            share.original.data = qkv_linear.weight.data[cur_ind:cur_ind+head_dim, :].clone()
            cur_ind += head_dim
            split_linear.splits.append(share)
    
    with torch.no_grad():
        actual = split_linear(x)
    
    diff = (expected - actual).abs().max().item()
    status = "✅ PASS" if diff < 1e-5 else "❌ FAIL"
    print(f"  Max difference: {diff:.2e} {status}")
    
    print("\n" + "=" * 60)
    return diff < 1e-5

# Run the isolated verification tests
all_passed = verify_split_linear_equivalence()


VERIFICATION TEST: SplitLinear Weight Transfer

[Test 1] Single Linear → SplitLinear (no split)
  Max difference: 1.07e-06 ✅ PASS

[Test 2] QKV Linear → SplitLinear (3-way split)
  Max difference: 1.07e-06 ✅ PASS

[Test 3] QKV Linear → SplitLinear (36-way head split)
  Max difference: 1.07e-06 ✅ PASS



In [16]:
# ============================================================================
# FULL MODEL VERIFICATION: Test the SplitLinearWrapper on the actual model
# ============================================================================
# This creates a fresh model and compares outputs before/after wrapping

def verify_full_model_wrapper():
    """Test the complete SplitLinearWrapper on GPT-2."""
    print("=" * 60)
    print("FULL MODEL VERIFICATION: SplitLinearWrapper on GPT-2")
    print("=" * 60)
    
    # Create a fresh model
    print("\n[1] Loading fresh GPT-2 model...")
    fresh_model = GPT2Foundation(GPTConfig(), None)
    fresh_model.model.eval()
    
    # Create deterministic test input
    torch.manual_seed(123)
    test_input = torch.randint(0, 1000, (1, 20))
    
    # Get original output BEFORE wrapping
    print("[2] Getting original model output...")
    with torch.no_grad():
        original_output = fresh_model.model(test_input)
        original_logits = original_output.clone() if isinstance(original_output, torch.Tensor) else original_output[0].clone()
    
    # Apply the wrapper
    print("[3] Applying SplitLinearWrapper...")
    test_wrapper = SplitLinearWrapper(fresh_model, splitcfg)
    fresh_model.model.eval()
    
    # Get output AFTER wrapping
    print("[4] Getting wrapped model output...")
    with torch.no_grad():
        wrapped_output = fresh_model.model(test_input)
        wrapped_logits = wrapped_output if isinstance(wrapped_output, torch.Tensor) else wrapped_output[0]
    
    # Compare
    print("[5] Comparing outputs...")
    diff = (original_logits - wrapped_logits).abs()
    max_diff = diff.max().item()
    mean_diff = diff.mean().item()
    
    print(f"\n  Original logits shape: {original_logits.shape}")
    print(f"  Wrapped logits shape:  {wrapped_logits.shape}")
    print(f"  Max absolute difference:  {max_diff:.2e}")
    print(f"  Mean absolute difference: {mean_diff:.2e}")
    
    passed = False
    if torch.allclose(original_logits, wrapped_logits, rtol=1e-4, atol=1e-4):
        print("✅ PASS")
        passed = True
    
    print("\n" + "=" * 60)
    return passed

# Run full model verification
# Note: This will load a new model, so it takes a moment
model_test_passed = verify_full_model_wrapper()


FULL MODEL VERIFICATION: SplitLinearWrapper on GPT-2

[1] Loading fresh GPT-2 model...
loading weights from pretrained gpt: gpt2
forcing vocab_size=50257, block_size=1024, bias=True
number of parameters: 123.65M
[2] Getting original model output...
[3] Applying SplitLinearWrapper...
[4] Getting wrapped model output...
[5] Comparing outputs...

  Original logits shape: torch.Size([1, 20, 50257])
  Wrapped logits shape:  torch.Size([1, 20, 50257])
  Max absolute difference:  1.14e-04
  Mean absolute difference: 2.05e-05
✅ PASS



In [12]:
foundation.run(foundation.input_model(text="The quick brown fox jumps over", max_new_tokens=10)).text

'The quick brown fox jumps over the startled soldier and swallows him in a fierce'

In [13]:
wrapper.all_basislinears

{'model.transformer.h.0.attn.c_attn.q.0': ShareLinear(in=768, basis=768, out=64, state=original),
 'model.transformer.h.0.attn.c_attn.q.1': ShareLinear(in=768, basis=768, out=64, state=original),
 'model.transformer.h.0.attn.c_attn.q.2': ShareLinear(in=768, basis=768, out=64, state=original),
 'model.transformer.h.0.attn.c_attn.q.3': ShareLinear(in=768, basis=768, out=64, state=original),
 'model.transformer.h.0.attn.c_attn.q.4': ShareLinear(in=768, basis=768, out=64, state=original),
 'model.transformer.h.0.attn.c_attn.q.5': ShareLinear(in=768, basis=768, out=64, state=original),
 'model.transformer.h.0.attn.c_attn.q.6': ShareLinear(in=768, basis=768, out=64, state=original),
 'model.transformer.h.0.attn.c_attn.q.7': ShareLinear(in=768, basis=768, out=64, state=original),
 'model.transformer.h.0.attn.c_attn.q.8': ShareLinear(in=768, basis=768, out=64, state=original),
 'model.transformer.h.0.attn.c_attn.q.9': ShareLinear(in=768, basis=768, out=64, state=original),
 'model.transformer.

In [14]:
wrapper.weight_types

defaultdict(list, {})

In [15]:
list(foundation.named_modules())

[('',
  GPT2Foundation(
    (model): GPT(
      (transformer): ModuleDict(
        (wte): Embedding(50257, 768)
        (wpe): Embedding(1024, 768)
        (drop): Dropout(p=0.0, inplace=False)
        (h): ModuleList(
          (0-11): 12 x Block(
            (ln_1): LayerNorm()
            (attn): CausalSelfAttention(
              (c_attn): SplitLinear(
                (splits): ModuleList(
                  (0-35): 36 x ShareLinear(in=768, basis=768, out=64, state=original)
                )
              )
              (c_proj): SplitLinear(
                (splits): ModuleList(
                  (0): ShareLinear(in=768, basis=768, out=768, state=original)
                )
              )
              (attn_dropout): Dropout(p=0.0, inplace=False)
              (resid_dropout): Dropout(p=0.0, inplace=False)
            )
            (ln_2): LayerNorm()
            (mlp): MLP(
              (c_fc): SplitLinear(
                (splits): ModuleList(
                  (0): ShareLin