# Stage-1 Model Surgery (Adaptive) — Liquid → Hybrid (Net2Net widen + new layers)

This notebook:
- Introspects your student checkpoint to infer d_model, n_layers, and vocab size.
- Widens every tensor dimension that depends on d_model by about +10% (rounded to a multiple of 8, and adjusted so the new width stays divisible by the number of attention heads).
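- Appends 2 classic (transformer-style) layers and 3 liquid layers after the existing ones, initialized with very small weights (normal with std 1e-3 for classic layers, Kaiming-uniform scaled by 1e-3 for liquid layers).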
- Writes out: a new checkpoint, config_stage1.json, freeze_mask.json, and surgery_report.md.

> Before running, set CHECKPOINT_PATH (a local path or a gs:// URI) and OUTPUT_DIR (a local directory) in Cell 2 to your own paths.

In [None]:
# Cell 1: install (optional if you've already installed requirements.txt)
# This uses the PyTorch cu121 wheel index.
# !pip install -r requirements.txt --index-url https://download.pytorch.org/whl/cu121

In [None]:
# Cell 2: imports, config, and IO paths
import os
import sys
import re
import json
import math
import time
import shutil
import subprocess
import datetime
from dataclasses import asdict
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from difflib import SequenceMatcher

# Device string picked from CUDA availability.
# NOTE(review): DEVICE is not referenced anywhere else in this notebook -- checkpoints are
# loaded with map_location="cpu" and the models are never moved, so the surgery runs on CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Let float32 matmuls use faster, reduced-internal-precision kernels where supported.
torch.set_float32_matmul_precision("high")
# Working dtype used when loading, widening and transplanting weights.
SURGERY_DTYPE = torch.float32

# ---------------------------
# TODO: set your IO paths
# ---------------------------
# CHECKPOINT_PATH: Stage-0 checkpoint; a local path or a gs:// URI (copied to a local
# temp location before loading).
CHECKPOINT_PATH = r"C:\Users\samsf\Liquid-LLM\Important-Model-Checkpoint\stage0.pt"
# OUTPUT_DIR: local directory for every artifact; it is created as soon as this cell runs.
OUTPUT_DIR = r"C:\Users\samsf\Liquid-LLM\Important-Model-Checkpoint\stage-1"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Surgery hyperparameters / knobs
# - WIDTH_SCALE: multiplier applied to the Stage-0 d_model before rounding.
# - ROUND_MULT: the widened d_model is rounded to a multiple of this.
# - ADD_CLASSIC_LAY / ADD_LIQUID_LAY: number of classic / liquid blocks appended after the
#   Stage-0 blocks (init_new_layers expects the classic ones first, then the liquid ones).
# - WIDTH_TOLERANCE: relative tolerance used when deciding whether a tensor dimension is
#   k * d_model (k = 1..8) and therefore has to be widened.
# - ENABLE_DRY_RUN: compare Stage-0 logits with a widen-only Stage-1 model (needs the
#   optional StudentLM import). DRY_RUN_BATCH x DRY_RUN_SEQ random tokens are fed in, and a
#   mean cosine similarity below DRY_RUN_COS_TGT is reported as "warn".
# - N_HEADS_OVERRIDE: when set, replaces the inferred attention-head count.
WIDTH_SCALE      = 1.10   # +10%
ROUND_MULT       = 8      # tensor-core friendly multiple
ADD_CLASSIC_LAY  = 2
ADD_LIQUID_LAY   = 3
WIDTH_TOLERANCE  = 0.06   # tolerate +/-6% when matching k*d_model dims
ENABLE_DRY_RUN   = False  # set True to verify widen-only logits similarity
DRY_RUN_BATCH    = 2
DRY_RUN_SEQ      = 16
DRY_RUN_COS_TGT  = 0.99
N_HEADS_OVERRIDE: Optional[int] = None  # set to int to override inference

# Dtype override (default float32). For bf16/fp16 surgery, adjust SURGERY_DTYPE above.

# Optional: add any package roots so imports resolve inside the notebook runtime.
# Each hint is looked up relative to the notebook directory and then each of its
# ancestors; the nearest match is appended to sys.path (once).
NOTEBOOK_DIR = Path.cwd()
PACKAGE_HINTS = [Path("vertex", "package", "Stage_1")] + [
    Path("vertex", "package", pkg_name, "src")
    for pkg_name in (
        "liquid_llm_vertex_pkg_4",
        "liquid_llm_vertex_pkg_3",
        "liquid_llm_vertex_pkg_4_annealing",
        "liquid_llm_vertex_pkg_1024",
        "liquid_llm_vertex_pkg_1024_next",
    )
]


def _find_package_root(rel_path: Path, roots: List[Path]) -> Optional[Path]:
    """Return the first resolved ``root / rel_path`` that exists as a directory, else None."""
    for root in roots:
        resolved = (root / rel_path).resolve()
        if resolved.is_dir():
            return resolved
    return None


added_sys_paths: List[str] = []
search_roots = [NOTEBOOK_DIR, *NOTEBOOK_DIR.parents]
for rel_path in PACKAGE_HINTS:
    found_root = _find_package_root(rel_path, search_roots)
    if found_root is None:
        continue
    found_str = str(found_root)
    if found_str in sys.path:
        continue
    sys.path.append(found_str)
    added_sys_paths.append(found_str)

if not added_sys_paths:
    print("WARNING: no package roots were added automatically; adjust PACKAGE_HINTS if imports fail.")
else:
    print("Added package roots:")
    for entry in added_sys_paths:
        print(f"- {entry}")

# ------------------------------------------------------------------
# Import the real model/config classes from the package
# ------------------------------------------------------------------
from Stage_1.models.config import ModelConfig
from Stage_1.models.blocks import LiquidBlock, ClassicBlock
from Stage_1.models.stage1_model import Stage1Model

try:  # optional Stage-0 student class (for dry-run cosine check if available)
    from liquid_llm.models.liquid import StudentLM
except Exception:
    StudentLM = None


In [None]:
# Cell 3: checkpoint IO + dimension inference utilities
# Wrapper prefixes commonly added by DDP/DataParallel ("module.") or training wrappers.
PREFIXES_TO_STRIP = ("module.", "model.", "student.", "state_dict.", "network.")


def strip_known_prefixes(key: str) -> str:
    """Remove every leading wrapper prefix from ``key``, however deeply nested.

    For example ``"module.model.blocks.0.w"`` becomes ``"blocks.0.w"``. Keys that do
    not start with any entry of ``PREFIXES_TO_STRIP`` are returned unchanged.
    """
    remaining = key
    while True:
        hit = next((p for p in PREFIXES_TO_STRIP if remaining.startswith(p)), None)
        if hit is None:
            return remaining
        remaining = remaining[len(hit):]


def resolve_checkpoint_path(path: str) -> str:
    """Return a local filesystem path for ``path``, downloading ``gs://`` URIs first.

    Local paths are returned unchanged. A ``gs://`` object is copied into the
    system temp directory (``tempfile.gettempdir()``, so this also works on
    Windows) using ``gsutil`` when it is on ``PATH``, otherwise via the optional
    ``gcsfs`` package.

    Raises:
        ValueError: if a ``gs://`` URI does not name an object (e.g. ends in "/").
        RuntimeError: if neither ``gsutil`` nor ``gcsfs`` is available, or the
            ``gcsfs`` download fails.
        subprocess.CalledProcessError: if ``gsutil cp`` exits with an error.
    """
    if not path.startswith("gs://"):
        return path
    import tempfile

    object_name = os.path.basename(path)
    if not object_name:
        raise ValueError(f"gs:// path must name an object, got {path!r}")
    local_name = os.path.join(tempfile.gettempdir(), object_name)
    if shutil.which("gsutil"):
        print(f"Copying {path} -> {local_name} via gsutil")
        subprocess.run(["gsutil", "cp", path, local_name], check=True)
        return local_name
    try:
        import gcsfs  # type: ignore
    except ImportError as exc:  # pragma: no cover - optional dependency path
        raise RuntimeError(
            "gs:// path provided but neither gsutil nor gcsfs is available."
        ) from exc

    print(f"Copying {path} -> {local_name} via gcsfs")
    try:
        fs = gcsfs.GCSFileSystem()
        with fs.open(path, "rb") as src, open(local_name, "wb") as dst:
            # Stream in chunks instead of holding the whole checkpoint in memory.
            shutil.copyfileobj(src, dst)
    except Exception as exc:  # pragma: no cover - network/auth failures
        raise RuntimeError(f"Failed to download {path} via gcsfs: {exc}") from exc
    return local_name


def canonicalize_state_dict(
    state_dict: Dict[str, torch.Tensor]
) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
    """Split a raw state dict into canonical tensors and non-tensor leftovers.

    Tensor keys have known wrapper prefixes removed via ``strip_known_prefixes``.
    Floating-point tensors are cast to ``SURGERY_DTYPE``; integer/bool tensors
    (counters, index buffers, masks) keep their dtype so large integer values are
    not rounded by a float cast. Non-tensor values are returned in ``aux`` under
    their original, unstripped keys.

    If two source keys collapse onto the same canonical name, the later one wins
    and a warning is printed, so the collision is no longer silent.

    Returns:
        ``(canonical_tensors, aux)``.
    """
    canonical: Dict[str, torch.Tensor] = {}
    origin: Dict[str, str] = {}
    aux: Dict[str, Any] = {}
    for key, value in state_dict.items():
        if not torch.is_tensor(value):
            aux[key] = value
            continue
        new_key = strip_known_prefixes(key)
        if new_key in origin:
            print(
                f"WARNING: checkpoint keys '{origin[new_key]}' and '{key}' both map to "
                f"'{new_key}'; keeping '{key}'"
            )
        origin[new_key] = key
        tensor = value.detach()
        canonical[new_key] = tensor.to(SURGERY_DTYPE) if tensor.is_floating_point() else tensor
    return canonical, aux


def load_checkpoint(path: str) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
    """Load a Stage-0 checkpoint and return ``(canonical_state_dict, metadata)``.

    Accepted payload layouts, checked in this order:

    * an ``nn.Module`` saved whole; its ``state_dict()`` is used;
    * a mapping with a ``"state_dict"`` dict entry; the other entries become metadata;
    * a mapping with tensors at the top level; the tensors form the state dict and
      everything else is metadata;
    * a mapping without top-level tensors that nests the weights under
      ``"model_state_dict"``, ``"model"``, ``"student"``, ``"module"`` or ``"network"``.

    Non-tensor values found inside the state dict are stored in
    ``metadata["non_tensor"]``.

    Raises:
        ValueError: for raw-tensor or non-mapping payloads, or when no tensors are found.
    """
    resolved = resolve_checkpoint_path(path)
    print(f"Loading checkpoint from {resolved}")
    # NOTE(security): torch.load unpickles the file; only load checkpoints you trust.
    # Recent torch releases (2.6+) default to weights_only=True, in which case a
    # checkpoint that pickles arbitrary Python objects will fail to load here.
    payload = torch.load(resolved, map_location="cpu")
    if isinstance(payload, nn.Module):
        payload = payload.state_dict()
    if isinstance(payload, torch.Tensor):
        raise ValueError("Expected a dict-like checkpoint payload, found a raw tensor.")
    if not isinstance(payload, dict):
        raise ValueError("Unsupported checkpoint format; expected a mapping or state_dict wrapper.")

    if isinstance(payload.get("state_dict"), dict):
        state_dict = payload["state_dict"]
        metadata = {k: v for k, v in payload.items() if k != "state_dict"}
    else:
        state_dict = {k: v for k, v in payload.items() if torch.is_tensor(v)}
        metadata = {k: v for k, v in payload.items() if not torch.is_tensor(v)}
        if not state_dict:
            # Common training-loop layouts nest the weights one level down.
            for wrapper in ("model_state_dict", "model", "student", "module", "network"):
                nested = payload.get(wrapper)
                if isinstance(nested, nn.Module):
                    nested = nested.state_dict()
                if isinstance(nested, dict) and any(torch.is_tensor(v) for v in nested.values()):
                    state_dict = nested
                    metadata = {k: v for k, v in payload.items() if k != wrapper}
                    break
    if not state_dict:
        raise ValueError(
            "No tensors found in checkpoint payload; expected a state_dict or a wrapper "
            "such as {'state_dict': ...} / {'model': ...}."
        )

    canonical, aux = canonicalize_state_dict(state_dict)
    if aux:
        metadata.setdefault("non_tensor", {}).update(aux)
    return canonical, metadata


def infer_layer_count(keys: Iterable[str]) -> int:
    """Return the number of stacked layers implied by parameter names.

    Collects every ``layers.<i>``, ``blocks.<i>`` or ``h.<i>`` segment that sits at
    the start of a key or right after a ``.`` and returns ``max(i) + 1``
    (0 when no such segment exists).
    """
    layer_index = re.compile(r"(?:^|\.)(layers|blocks|h)\.(\d+)")
    found = {
        int(hit.group(2))
        for name in keys
        for hit in layer_index.finditer(name)
    }
    return max(found) + 1 if found else 0


def infer_vocab_and_width(
    state_dict: Dict[str, torch.Tensor]
) -> Tuple[int, int, str]:
    """Locate the token-embedding matrix and return ``(d_model, vocab_size, key)``.

    Every 2-D tensor with at least 128 rows and 32 columns is a candidate, read as
    ``(vocab, width)``. Candidates are ranked by a name score, then by row count,
    then by key. The name score is:

    * +5 for "embed";
    * +2 for "token";
    * +1 for "lm_head"/"output";
    * -5 for positional tables whose name contains "pos" or "wpe".

    A larger row count wins ties because the vocabulary is normally the largest
    embedding table.

    Raises:
        ValueError: if no candidate matrix exists.
    """
    best: Optional[Tuple[int, int, str]] = None
    for key, tensor in state_dict.items():
        if tensor.ndim != 2:
            continue
        rows, cols = tensor.shape
        if rows < 128 or cols < 32:
            continue
        key_lower = key.lower()
        score = 0
        if "embed" in key_lower:
            score += 5
        if "token" in key_lower:
            score += 2
        if "lm_head" in key_lower or "output" in key_lower:
            score += 1
        # Positional tables (pos_embed, position_embeddings, GPT-2 style "wpe") also
        # match "embed"; without this penalty they can outrank the token embedding.
        if "pos" in key_lower or "wpe" in key_lower:
            score -= 5
        rank = (score, int(rows), key)
        if best is None or rank > best:
            best = rank
    if best is None:
        raise ValueError("Could not infer embedding matrix to determine d_model/vocab_size.")
    _, vocab, key = best
    width = int(state_dict[key].shape[1])
    return width, vocab, key


def infer_max_seq_len(state_dict: Dict[str, torch.Tensor]) -> int:
    """Guess the context length from the longest positional-embedding-like tensor.

    Only keys containing "pos" (case-insensitive) are considered. For 2-D tensors
    the length is the first axis; for 3-D and higher it is the second-to-last axis.
    1-D and 0-D tensors are ignored. Returns 2048 when nothing matches.
    """
    best: Optional[Tuple[int, str]] = None
    for name, weight in state_dict.items():
        if "pos" not in name.lower():
            continue
        if weight.ndim == 2:
            seq = weight.shape[0]
        elif weight.ndim >= 3:
            seq = weight.shape[-2]
        else:
            continue
        entry = (seq, name)
        if best is None or entry > best:
            best = entry
    if best is None:
        return 2048
    length, key = best
    print(f"Detected positional embedding candidate '{key}' with length={length}")
    return int(length)


def _search_heads_in_metadata(obj: Any) -> Optional[int]:
    if isinstance(obj, dict):
        for key, value in obj.items():
            key_l = key.lower()
            if key_l in {"n_heads", "num_heads", "num_attention_heads", "n_head"}:
                if isinstance(value, (int, float)):
                    return int(value)
            result = _search_heads_in_metadata(value)
            if result is not None:
                return result
    return None


def infer_n_heads(
    state_dict: Dict[str, torch.Tensor], metadata: Dict[str, Any], d_model: int
) -> Optional[int]:
    """Infer the attention-head count of the Stage-0 model.

    Lookup order:

    1. ``_search_heads_in_metadata(metadata)``;
    2. a positive single-element tensor whose key contains ``num_heads`` or ``n_heads``;
    3. a heuristic guess, printed as a warning: the first of 16, 12, 8, 4, 32 that
       divides ``d_model``, else ``d_model`` itself.

    Returns:
        The head count, or ``None`` only when ``d_model`` is not positive.
    """
    heads = _search_heads_in_metadata(metadata)
    if heads:
        return heads
    for key, tensor in state_dict.items():
        lower = key.lower()
        if tensor.numel() == 1 and ("num_heads" in lower or "n_heads" in lower):
            value = int(tensor.item())
            if value > 0:
                return value
    if d_model <= 0:
        return None
    # Heuristic fallback: fixed preference order (not tuned to any particular head_dim).
    guess = next((c for c in (16, 12, 8, 4, 32) if d_model % c == 0), d_model)
    print(
        f"WARNING: n_heads not found in checkpoint; guessing n_heads={guess} "
        f"(head_dim={d_model // guess}). Set N_HEADS_OVERRIDE if this is wrong."
    )
    return guess


def infer_model_stats(
    state_dict: Dict[str, torch.Tensor], metadata: Dict[str, Any]
) -> Dict[str, Any]:
    """Collect the Stage-0 architecture facts inferred from a canonical state dict.

    Returns a dict with keys ``d_model``, ``vocab_size``, ``embed_key``,
    ``n_layers``, ``max_seq_len`` and ``n_heads`` (the latter may be ``None``).
    """
    d_model, vocab_size, embed_key = infer_vocab_and_width(state_dict)
    # Dict-literal values are evaluated left to right, so the helper calls (and any
    # messages they print) happen in the same order as before.
    return {
        "d_model": d_model,
        "vocab_size": vocab_size,
        "embed_key": embed_key,
        "n_layers": infer_layer_count(state_dict.keys()),
        "max_seq_len": infer_max_seq_len(state_dict),
        "n_heads": infer_n_heads(state_dict, metadata, d_model),
    }


In [None]:
# Cell 4: shape scaling, Net2Net widening, and initialization helpers
def compute_target_width(
    d_model: int,
    scale: float,
    round_mult: int,
    n_heads: Optional[int] = None,
) -> int:
    """Return the widened Stage-1 model width.

    ``d_model * scale`` is rounded to the nearest multiple of ``round_mult`` and
    clamped to at least ``d_model``, so the surgery never shrinks the model. When
    ``n_heads`` is positive, the result is then rounded *up* to a multiple of
    ``lcm(round_mult, n_heads)``. That keeps it divisible by the head count and
    still aligned to ``round_mult``. Aligning to ``n_heads`` alone can break the
    ``round_mult`` alignment: 768 * 1.1 with 12 heads would give 852.

    A non-positive ``round_mult`` is treated as 1.
    """
    step = max(int(round_mult), 1)
    widened = int(round(d_model * scale / step) * step)
    widened = max(widened, d_model)
    if n_heads and n_heads > 0:
        heads = int(n_heads)
        align = step * heads // math.gcd(step, heads)
        remainder = widened % align
        if remainder:
            widened += align - remainder
    return widened


def scale_dim(dim: int, d_old: int, d_new: int, tol: float) -> int:
    """Map a Stage-0 tensor dimension to its Stage-1 size.

    If ``dim`` is within relative tolerance ``tol`` of ``k * d_old`` for some
    ``k`` in 1..8 (smallest ``k`` first), the result is ``k * d_new``. Otherwise,
    if ``dim`` is close to ``d_old`` relative to ``max(dim, d_old)``, it is scaled
    proportionally. Any other dimension (and any dim when ``d_old == 0``) is
    returned unchanged.
    """
    if d_old == 0:
        return dim
    multiple = next(
        (
            k
            for k in range(1, 9)
            if k * d_old != 0 and abs(dim - k * d_old) <= tol * (k * d_old)
        ),
        None,
    )
    if multiple is not None:
        return int(round(multiple * d_new))
    near_base = abs(dim - d_old) <= tol * max(dim, d_old)
    return int(round(dim * d_new / d_old)) if near_base else dim


def inverse_scale_dim(dim: int, d_old: int, d_new: int, tol: float) -> int:
    """Map a Stage-1 tensor dimension back to its Stage-0 size (mirror of ``scale_dim``).

    If ``dim`` is within relative tolerance ``tol`` of ``k * d_new`` for some ``k``
    in 1..8 (smallest ``k`` first), the result is ``k * d_old``. Otherwise, if
    ``dim`` is close to ``d_new`` relative to ``max(dim, d_new)``, it is scaled
    back proportionally. Any other dimension (and any dim when ``d_new == 0``) is
    returned unchanged.
    """
    if d_new == 0:
        return dim
    k = 1
    while k <= 8:
        widened_multiple = k * d_new
        if widened_multiple != 0 and abs(dim - widened_multiple) <= tol * widened_multiple:
            return int(round(k * d_old))
        k += 1
    near_width = abs(dim - d_new) <= tol * max(dim, d_new)
    return int(round(dim * d_old / d_new)) if near_width else dim


def scale_shape(shape: Tuple[int, ...], d_old: int, d_new: int, tol: float) -> Tuple[int, ...]:
    """Apply ``scale_dim`` to every axis of ``shape`` and return the Stage-1 shape."""
    scaled = [scale_dim(size, d_old, d_new, tol) for size in shape]
    return tuple(scaled)


def inverse_scale_shape(
    shape: Iterable[int], d_old: int, d_new: int, tol: float
) -> Tuple[int, ...]:
    """Apply ``inverse_scale_dim`` to every axis of ``shape`` and return the Stage-0 shape."""
    restored = [inverse_scale_dim(size, d_old, d_new, tol) for size in shape]
    return tuple(restored)


def key_similarity(dest_key: str, src_key: str) -> float:
    """Score how alike two parameter names are; higher is more similar (range 0..2).

    The score is the difflib character-sequence ratio plus the Jaccard overlap of
    the names' ``.``/``_``-separated tokens (0 when neither name has tokens).
    """
    def tokens(name: str) -> set:
        return {part for part in re.split(r"[._]", name) if part}

    dest_parts, src_parts = tokens(dest_key), tokens(src_key)
    union = dest_parts | src_parts
    jaccard = len(dest_parts & src_parts) / len(union) if union else 0.0
    return SequenceMatcher(None, dest_key, src_key).ratio() + jaccard


def is_new_layer_key(key: str, base_layers: int) -> bool:
    """Return True if ``key`` belongs to ``blocks.<i>`` with ``i >= base_layers``.

    Those are the blocks appended by the surgery. Keys outside ``blocks.``, or
    whose segment after ``blocks.`` is not a digit string, are never "new".
    """
    marker = "blocks."
    if not key.startswith(marker):
        return False
    index_text = key[len(marker):].split(".", 1)[0]
    return index_text.isdigit() and int(index_text) >= base_layers


def build_parameter_mapping(
    dest_state: Dict[str, torch.Tensor],
    src_state: Dict[str, torch.Tensor],
    base_layers: int,
    d_old: int,
    d_new: int,
    tol: float,
) -> Dict[str, str]:
    """Greedily map each Stage-1 (dest) key to the best-matching Stage-0 (src) key.

    For every dest key outside the newly appended blocks, the expected Stage-0
    shape is ``inverse_scale_shape(dest.shape)``. Unassigned source tensors with
    exactly that shape are candidates. Only if there are none, sources missing a
    leading singleton axis are considered: for example a ``(T, D)`` source for a
    ``(1, T', D')`` destination, which ``adapt_tensor`` handles with
    ``unsqueeze(0)``. The candidate with the highest ``key_similarity`` (ties
    broken by key) wins, and each source key is used at most once. Dest keys are
    visited in ``dest_state`` order, so the assignment is greedy.

    Returns:
        ``{dest_key: src_key}`` for every dest key that found a match.
    """
    mapping: Dict[str, str] = {}
    assigned: set[str] = set()
    for dest_key, dest_tensor in dest_state.items():
        if is_new_layer_key(dest_key, base_layers):
            continue
        expected_old_shape = inverse_scale_shape(dest_tensor.shape, d_old, d_new, tol)
        exact: List[Tuple[float, str]] = []
        unsqueezed: List[Tuple[float, str]] = []
        for src_key, src_tensor in src_state.items():
            if src_key in assigned:
                continue
            src_shape = tuple(src_tensor.shape)
            if src_shape == expected_old_shape:
                exact.append((key_similarity(dest_key, src_key), src_key))
            elif (
                # The source has one axis fewer than the destination; this branch was
                # unreachable when it compared dest rank against the expected shape,
                # which always has the same rank as dest.
                len(src_shape) + 1 == len(expected_old_shape)
                and dest_tensor.shape[0] == 1
                and src_shape == expected_old_shape[1:]
            ):
                unsqueezed.append((key_similarity(dest_key, src_key), src_key))
        # Prefer same-rank matches so unsqueeze candidates never displace them.
        candidates = exact or unsqueezed
        if candidates:
            best_key = max(candidates)[1]
            mapping[dest_key] = best_key
            assigned.add(best_key)
    return mapping


def widen_tensor(
    src: torch.Tensor,
    target_shape: Tuple[int, ...],
    generator: Optional[torch.Generator] = None,
) -> torch.Tensor:
    """Resize ``src`` to ``target_shape`` by index replication, then add small noise.

    Each axis whose size differs is resampled with evenly spaced, rounded indices
    (``linspace(0, old - 1, new)``), which duplicates entries when growing and
    subsamples when shrinking. If any axis changed, Gaussian noise with std
    ``max(0.01 * std(src), 1e-3)`` is added to the whole result to break the
    symmetry between duplicated units. Pass ``generator`` for reproducible noise.

    Raises:
        ValueError: if the rank differs from ``len(target_shape)`` or a resized
            axis has size 0.
    """
    tensor = src.detach().to(SURGERY_DTYPE)
    if tensor.shape == target_shape:
        return tensor.clone()
    if tensor.ndim != len(target_shape):
        raise ValueError(
            f"Cannot widen tensor of shape {tuple(tensor.shape)} to rank-{len(target_shape)} "
            f"shape {tuple(target_shape)}"
        )
    result = tensor
    changed = False
    for dim, new_size in enumerate(target_shape):
        old_size = result.shape[dim]
        if old_size == new_size:
            continue
        changed = True
        if old_size == 0:
            raise ValueError(f"Cannot widen dimension {dim} with size 0")
        indices = torch.linspace(0, old_size - 1, new_size).round().long()
        result = torch.index_select(result, dim, indices)
    result = result.clone()
    if changed:
        # std() of a single element is NaN (unbiased estimator), which would turn the
        # whole result into NaN; fall back to the 1e-3 floor instead.
        std = tensor.std().item() if tensor.numel() > 1 else 0.0
        if not math.isfinite(std):
            std = 0.0
        noise_std = max(abs(std) * 0.01, 1e-3)
        noise = torch.randn(result.shape, dtype=result.dtype, generator=generator)
        result.add_(noise * noise_std)
    return result


def adapt_tensor(
    src_tensor: torch.Tensor,
    dest_shape: Tuple[int, ...],
    d_old: int,
    d_new: int,
    tol: float,
) -> Tuple[Optional[torch.Tensor], str]:
    """Adapt a Stage-0 tensor to a Stage-1 destination shape.

    Returns ``(tensor, "copy")`` when the shapes already match, possibly after
    prepending a singleton axis. Returns ``(tensor, "widen")`` when every
    destination axis either equals the source axis or equals
    ``scale_dim(source axis)``.

    Axes are checked individually. An axis that only happens to equal
    ``k * d_model``, such as ``max_seq_len = 2 * d_model`` in a positional table,
    therefore stays valid when the destination left it unchanged. Returns
    ``(None, "skip")`` otherwise.
    """
    tensor = src_tensor.detach().to(SURGERY_DTYPE)
    dest_shape = tuple(dest_shape)
    if tuple(tensor.shape) == dest_shape:
        return tensor.clone(), "copy"
    # Allow unsqueeze on leading dimension
    if len(dest_shape) == tensor.ndim + 1 and dest_shape[0] == 1:
        tensor = tensor.unsqueeze(0)
    if tensor.ndim != len(dest_shape):
        return None, "skip"
    src_shape = tuple(tensor.shape)
    scaled_shape = scale_shape(src_shape, d_old, d_new, tol)
    compatible = all(
        want == have or want == scaled
        for want, have, scaled in zip(dest_shape, src_shape, scaled_shape)
    )
    if not compatible:
        return None, "skip"
    if src_shape == dest_shape:
        return tensor.clone(), "copy"
    return widen_tensor(tensor, dest_shape), "widen"


def transplant_weights(
    src_state: Dict[str, torch.Tensor],
    dest_state: Dict[str, torch.Tensor],
    base_layers: int,
    d_old: int,
    d_new: int,
    tol: float,
) -> Tuple[Dict[str, torch.Tensor], List[str], List[str], List[str], Dict[str, str]]:
    """Copy or widen Stage-0 weights into a copy of the Stage-1 state dict.

    Returns ``(updated_state, copied, widened, skipped, mapping)``:

    * ``copied`` lists dest keys filled with a same-shape copy;
    * ``widened`` lists dest keys filled through ``widen_tensor``;
    * ``skipped`` lists dest keys that keep their destination value. These are the
      new-layer keys, unmatched keys, and keys whose source could not be adapted.
    """
    mapping = build_parameter_mapping(dest_state, src_state, base_layers, d_old, d_new, tol)
    updated = dict(dest_state)
    buckets: Dict[str, List[str]] = {"copy": [], "widen": [], "skip": []}
    for dest_key, dest_tensor in dest_state.items():
        src_key = None if is_new_layer_key(dest_key, base_layers) else mapping.get(dest_key)
        if src_key is None:
            buckets["skip"].append(dest_key)
            continue
        adapted, status = adapt_tensor(
            src_state[src_key], tuple(dest_tensor.shape), d_old, d_new, tol
        )
        if adapted is None:
            buckets["skip"].append(dest_key)
            continue
        updated[dest_key] = adapted
        buckets["copy" if status == "copy" else "widen"].append(dest_key)
    return updated, buckets["copy"], buckets["widen"], buckets["skip"], mapping


def init_classic_block(block: nn.Module) -> None:
    """Initialize a newly appended classic (transformer-style) block with tiny weights.

    * LayerNorm: weight = 1 and bias = 0, skipped when the module has no affine parameters.
    * Linear: weight ~ N(0, 1e-3^2), bias = 0.
    * MultiheadAttention: the input projection(s) and ``out_proj.weight`` ~ N(0, 1e-3^2),
      with biases = 0.
    * Finally, every parameter whose name contains "scale" is filled with 1e-3.
    """
    for module in block.modules():
        if isinstance(module, nn.LayerNorm):
            # weight/bias are None when the norm was built without affine parameters.
            if module.weight is not None:
                nn.init.ones_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=1e-3)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.MultiheadAttention):
            # With kdim/vdim != embed_dim, in_proj_weight is None and separate
            # q/k/v projection weights are used instead.
            for attr in ("in_proj_weight", "q_proj_weight", "k_proj_weight", "v_proj_weight"):
                weight = getattr(module, attr, None)
                if weight is not None:
                    nn.init.normal_(weight, mean=0.0, std=1e-3)
            if module.in_proj_bias is not None:
                nn.init.zeros_(module.in_proj_bias)
            nn.init.normal_(module.out_proj.weight, mean=0.0, std=1e-3)
            if module.out_proj.bias is not None:
                nn.init.zeros_(module.out_proj.bias)
    with torch.no_grad():
        for name, param in block.named_parameters():
            if "scale" in name:
                param.fill_(1e-3)


def init_liquid_block(block: nn.Module) -> None:
    """Initialize a newly appended liquid block with small-scale weights.

    * LayerNorm: weight = 1 and bias = 0, skipped when the module has no affine parameters.
    * Linear: Kaiming-uniform (a = sqrt(5)) scaled by 1e-3, bias = 0.
    * MultiheadAttention: the input projection(s) and ``out_proj.weight`` get the same
      scaled Kaiming init, with biases = 0.
    * Finally, every parameter whose name contains "scale" or "gate" is filled with 1e-3.
    """
    for module in block.modules():
        if isinstance(module, nn.LayerNorm):
            # weight/bias are None when the norm was built without affine parameters.
            if module.weight is not None:
                nn.init.ones_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Linear):
            nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5))
            with torch.no_grad():
                module.weight.mul_(1e-3)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.MultiheadAttention):
            # With kdim/vdim != embed_dim, in_proj_weight is None and separate
            # q/k/v projection weights are used instead.
            for attr in ("in_proj_weight", "q_proj_weight", "k_proj_weight", "v_proj_weight"):
                weight = getattr(module, attr, None)
                if weight is not None:
                    nn.init.kaiming_uniform_(weight, a=math.sqrt(5))
                    with torch.no_grad():
                        weight.mul_(1e-3)
            if module.in_proj_bias is not None:
                nn.init.zeros_(module.in_proj_bias)
            nn.init.kaiming_uniform_(module.out_proj.weight, a=math.sqrt(5))
            with torch.no_grad():
                module.out_proj.weight.mul_(1e-3)
            if module.out_proj.bias is not None:
                nn.init.zeros_(module.out_proj.bias)
    with torch.no_grad():
        for name, param in block.named_parameters():
            if "scale" in name or "gate" in name:
                param.fill_(1e-3)


def init_new_layers(model: Stage1Model, base_layers: int, add_classic: int, add_liquid: int) -> None:
    """Apply small-scale initialization to the blocks appended after the Stage-0 layers.

    Blocks ``[base_layers, base_layers + add_classic)`` are expected to be
    ``ClassicBlock`` and the next ``add_liquid`` blocks ``LiquidBlock``; indices past
    ``len(model.blocks)`` are ignored. When the layout matches, each block gets the
    init for its slot. A block whose type differs from its slot is still
    initialized according to its actual type, with a warning, instead of silently
    keeping its default init. Blocks of any other type are reported and left as is.
    """
    total_blocks = len(model.blocks)
    classic_end = base_layers + add_classic
    stop = min(classic_end + add_liquid, total_blocks)
    for idx in range(base_layers, stop):
        block = model.blocks[idx]
        expected = "classic" if idx < classic_end else "liquid"
        if expected == "classic" and isinstance(block, ClassicBlock):
            init_classic_block(block)
        elif expected == "liquid" and isinstance(block, LiquidBlock):
            init_liquid_block(block)
        elif isinstance(block, ClassicBlock):
            print(f"WARNING: new block {idx} is classic but a {expected} block was expected")
            init_classic_block(block)
        elif isinstance(block, LiquidBlock):
            print(f"WARNING: new block {idx} is liquid but a {expected} block was expected")
            init_liquid_block(block)
        else:
            print(
                f"WARNING: new block {idx} has unexpected type {type(block).__name__}; "
                "left with its default initialization"
            )


def gather_layer_types(model: Stage1Model) -> List[str]:
    """Describe each block of ``model.blocks`` in order.

    Blocks are labelled "classic" or "liquid"; any other block type is labelled
    with its class name.
    """
    def _label(block: nn.Module) -> str:
        if isinstance(block, ClassicBlock):
            return "classic"
        if isinstance(block, LiquidBlock):
            return "liquid"
        return type(block).__name__

    return [_label(block) for block in model.blocks]


In [None]:
# Cell 5: artifact writers, reporting, and optional dry-run

def maybe_save_safetensors(
    state_dict: Dict[str, torch.Tensor], path: str, metadata: Dict[str, Any]
) -> Optional[str]:
    """Try to write ``state_dict`` to ``path`` in safetensors format.

    Only ``str``/``int``/``float`` metadata values are kept, stringified, because
    safetensors metadata must be str -> str.

    Returns:
        ``path`` on success. Returns ``None`` when safetensors cannot be imported or
        the save fails (safetensors rejects tensors that share memory, for example);
        in that case any partially written file is removed so the caller can fall
        back to ``torch.save`` cleanly.
    """
    try:
        from safetensors.torch import save_file  # type: ignore
    except Exception as exc:  # pragma: no cover - optional dependency (ImportError or broken install)
        print(f"safetensors save unavailable ({exc}); falling back to torch.save")
        return None

    meta = {k: str(v) for k, v in metadata.items() if isinstance(v, (str, int, float))}
    try:
        save_file(state_dict, path, metadata=meta)
    except Exception as exc:
        print(f"safetensors save failed ({exc}); falling back to torch.save")
        if os.path.exists(path):
            try:
                os.remove(path)
            except OSError:
                pass  # best effort: a stale partial file is not fatal
        return None
    return path


def save_checkpoint(
    model: Stage1Model,
    output_dir: str,
    metadata: Dict[str, Any],
) -> str:
    """Save ``model``'s state dict under ``output_dir`` and return the written path.

    The preferred target is ``stage1_surgery_<UTC YYYYmmdd_HHMMSS>.safetensors``,
    written via ``maybe_save_safetensors``. If that returns ``None``, the state is
    written with ``torch.save`` as ``{"state_dict": ..., "metadata": ...}`` to the
    matching ``.pt`` file instead. All tensors are detached and moved to CPU first.
    """
    # Timezone-aware "now": datetime.utcnow() is deprecated since Python 3.12.
    timestamp = datetime.datetime.now(datetime.timezone.utc).strftime("%Y%m%d_%H%M%S")
    base = f"stage1_surgery_{timestamp}"
    safepath = os.path.join(output_dir, base + ".safetensors")
    state = {k: v.detach().to("cpu") for k, v in model.state_dict().items()}
    saved = maybe_save_safetensors(state, safepath, metadata)
    if saved:
        return saved
    pypath = os.path.join(output_dir, base + ".pt")
    torch.save({"state_dict": state, "metadata": metadata}, pypath)
    return pypath


def write_json(path: str, payload: Dict[str, Any]) -> None:
    """Write ``payload`` to ``path`` as 2-space-indented JSON with sorted keys."""
    dump_options = {"indent": 2, "sort_keys": True}
    with open(path, mode="w") as out_file:
        json.dump(payload, out_file, **dump_options)


def write_freeze_mask(
    output_dir: str,
    base_layers: int,
    add_classic: int,
    total_layers: int,
) -> str:
    """Write ``freeze_mask.json`` mapping ``blocks.<i>`` to a boolean, and return its path.

    Only the newly appended classic blocks (indices ``base_layers`` through
    ``base_layers + add_classic - 1``) are ``true``; every other block is ``false``.
    NOTE(review): whether ``true`` means "train" or "freeze" is decided by whoever
    reads this file, which is not visible here.
    """
    classic_range = range(base_layers, base_layers + add_classic)
    mask = {f"blocks.{i}": i in classic_range for i in range(total_layers)}
    path = os.path.join(output_dir, "freeze_mask.json")
    with open(path, "w") as handle:
        json.dump(mask, handle, indent=2, sort_keys=True)
    return path


def write_config(output_dir: str, config: ModelConfig, extras: Dict[str, Any]) -> str:
    """Write ``config_stage1.json``: the dataclass ``config`` merged with ``extras``.

    Entries in ``extras`` override config fields with the same name. Returns the
    path of the written file.
    """
    merged = {**asdict(config), **extras}
    target = os.path.join(output_dir, "config_stage1.json")
    write_json(target, merged)
    return target


def write_report(
    output_dir: str,
    stats: Dict[str, Any],
    target_d_model: int,
    copied: List[str],
    widened: List[str],
    skipped: List[str],
    missing: List[str],
    unexpected: List[str],
    artifacts: Dict[str, str],
    layer_types: List[str],
    dry_run: Optional[Dict[str, Any]] = None,
    source_path: Optional[str] = None,
) -> str:
    """Write ``surgery_report.md`` into ``output_dir`` and return its path.

    The report has these sections:

    * Inferred Stage-0 stats;
    * Stage-1 target width and layout;
    * tensor copy/widen/init counts;
    * ``load_state_dict`` diagnostics;
    * optional dry-run results;
    * artifact paths.

    ``artifacts`` must contain the keys "weights", "config" and "freeze_mask".
    ``source_path`` defaults to the global ``CHECKPOINT_PATH``. The file is written
    as UTF-8 so non-ASCII paths work on every platform.
    """
    source = CHECKPOINT_PATH if source_path is None else source_path
    report_path = os.path.join(output_dir, "surgery_report.md")
    lines = [
        "# Stage-1 Surgery Report",
        "",
        f"Source checkpoint: `{source}`",
        f"Output checkpoint: `{artifacts['weights']}`",
        "",
        "## Inferred Stage-0",
        f"- n_layers: {stats['n_layers']}",
        f"- d_model: {stats['d_model']}",
        f"- vocab_size: {stats['vocab_size']}",
        f"- n_heads: {stats.get('n_heads')}",
        "",
        "## Stage-1 Target",
        f"- widened d_model: {stats['d_model']} -> {target_d_model}",
        f"- total layers: {len(layer_types)} (layout: {layer_types})",
        "",
        "## Tensor Copy Stats",
        f"- copied (same shape): {len(copied)}",
        f"- widened (Net2Net): {len(widened)}",
        f"- kept/init destination: {len(skipped)}",
        "",
        "## Load Diagnostics",
        f"- missing keys (strict=False): {len(missing)}",
        f"- unexpected keys: {len(unexpected)}",
    ]
    if dry_run is not None:
        lines.extend(
            [
                "",
                "## Dry-Run Cosine Similarity",
                f"- ran: {dry_run.get('status')}",
            ]
        )
        # Report the cosine for both "ok" and "warn"; the warn case is when it matters most.
        if "cosine" in dry_run:
            lines.append(f"- cosine similarity: {dry_run['cosine']:.6f}")
            lines.append(f"- widen missing keys: {dry_run.get('missing')}")
            lines.append(f"- widen unexpected keys: {dry_run.get('unexpected')}")
        if dry_run.get("message"):
            lines.append(f"- note: {dry_run['message']}")
    lines.extend(
        [
            "",
            "Artifacts:",
            f"- weights: {artifacts['weights']}",
            f"- config: {artifacts['config']}",
            f"- freeze mask: {artifacts['freeze_mask']}",
        ]
    )
    with open(report_path, "w", encoding="utf-8") as handle:
        handle.write("\n".join(lines) + "\n")
    return report_path


def attempt_dry_run(
    src_state: Dict[str, torch.Tensor],
    stats: Dict[str, Any],
    target_d_model: int,
    widen_pct: float,
) -> Optional[Dict[str, Any]]:
    """Optionally check that widening alone roughly preserves the Stage-0 function.

    The Stage-0 ``StudentLM`` and a widen-only ``Stage1Model`` (no added layers)
    are run on the same ``DRY_RUN_BATCH x DRY_RUN_SEQ`` random token batch. The
    mean per-sample cosine similarity of their flattened logits is then compared
    against ``DRY_RUN_COS_TGT``.

    Returns:
        * ``None`` when ``ENABLE_DRY_RUN`` is off.
        * ``{"status": "skipped", "message": ...}`` when any stage cannot run:
          import, construction, weight load, or the forward pass.
        * ``{"status": "ok"|"warn", "cosine", "missing", "unexpected"}`` otherwise.
    """
    if not ENABLE_DRY_RUN:
        return None
    if StudentLM is None:
        return {"status": "skipped", "message": "StudentLM import unavailable"}
    if stats.get("n_heads") is None:
        return {"status": "skipped", "message": "n_heads unknown; set N_HEADS_OVERRIDE"}
    try:
        student = StudentLM(
            stats["vocab_size"],
            d_model=stats["d_model"],
            n_layers=stats["n_layers"],
            n_heads=stats["n_heads"],
        ).to(SURGERY_DTYPE)
        load_res = student.load_state_dict(src_state, strict=False)
        # load_state_dict returns a (missing_keys, unexpected_keys) namedtuple, which is
        # always truthy; only report when there are real differences.
        if getattr(load_res, "missing_keys", None) or getattr(load_res, "unexpected_keys", None):
            print(f"StudentLM load differences: {load_res}")
    except Exception as exc:
        return {"status": "skipped", "message": f"failed to instantiate StudentLM: {exc}"}

    try:
        cfg = ModelConfig(
            vocab_size=stats["vocab_size"],
            d_model=stats["d_model"],
            n_heads=stats["n_heads"],
            n_layers=stats["n_layers"],
            max_seq_len=stats["max_seq_len"],
            widen_pct=widen_pct,
            add_classic=0,
            add_liquid=0,
        )
        widen_model = Stage1Model(cfg).to(SURGERY_DTYPE)
        dest_state = widen_model.state_dict()
        transplanted, _, _, _, _ = transplant_weights(
            src_state, dest_state, stats["n_layers"], stats["d_model"], target_d_model, WIDTH_TOLERANCE
        )
        missing, unexpected = widen_model.load_state_dict(transplanted, strict=False)
        widen_model.eval()
    except Exception as exc:
        return {"status": "skipped", "message": f"widen-only model load failed: {exc}"}

    def _logits(output: Any) -> torch.Tensor:
        # Accept raw logits, a tuple whose first item is the logits, or an object
        # exposing a ``.logits`` attribute.
        if isinstance(output, tuple):
            return output[0]
        return getattr(output, "logits", output)

    student.eval()
    try:
        with torch.no_grad():
            tokens = torch.randint(0, stats["vocab_size"], (DRY_RUN_BATCH, DRY_RUN_SEQ))
            logits0 = _logits(student(tokens))
            logits1 = _logits(widen_model(tokens.to(logits0.device)))
            cos = F.cosine_similarity(
                logits0.reshape(logits0.size(0), -1), logits1.reshape(logits1.size(0), -1)
            )
            cosine = float(cos.mean().item())
    except Exception as exc:
        # A failing comparison must not abort the whole surgery.
        return {"status": "skipped", "message": f"dry-run forward pass failed: {exc}"}
    status = "ok" if cosine >= DRY_RUN_COS_TGT else "warn"
    return {
        "status": status,
        "cosine": cosine,
        "missing": len(missing),
        "unexpected": len(unexpected),
    }


In [None]:
# Cell 6: run the Stage-1 surgery end-to-end

def run_stage1_surgery() -> Dict[str, Any]:
    """Run the Stage-1 surgery end to end and write all artifacts to OUTPUT_DIR.

    Steps:

    1. Load and canonicalize the Stage-0 checkpoint.
    2. Infer its shape stats; ``N_HEADS_OVERRIDE``, when set, wins over the
       inferred head count.
    3. Compute the widened width.
    4. Optionally run the widen-only dry run.
    5. Build the Stage-1 model with ``ADD_CLASSIC_LAY + ADD_LIQUID_LAY`` appended
       blocks and tiny-init them.
    6. Transplant and widen the Stage-0 weights.
    7. Save the checkpoint, config, freeze mask and markdown report.

    Returns:
        A summary dict containing:

        * stats and target width;
        * the copied/widened/skipped key lists;
        * ``load_state_dict`` missing/unexpected keys;
        * artifact paths;
        * the dry-run result;
        * the dest->src key mapping.

    Raises:
        ValueError: if n_heads cannot be inferred and N_HEADS_OVERRIDE is unset.
    """
    src_state, metadata = load_checkpoint(CHECKPOINT_PATH)
    stats = infer_model_stats(src_state, metadata)
    # An explicit override always wins; without one, an unknown head count is fatal.
    if N_HEADS_OVERRIDE is not None:
        stats["n_heads"] = int(N_HEADS_OVERRIDE)
    elif stats.get("n_heads") is None:
        raise ValueError(
            "Unable to infer n_heads from checkpoint metadata. Set N_HEADS_OVERRIDE manually."
        )

    target_d_model = compute_target_width(
        stats["d_model"], WIDTH_SCALE, ROUND_MULT, stats.get("n_heads")
    )
    effective_scale = target_d_model / float(stats["d_model"])
    widen_pct = (effective_scale - 1.0) * 100.0

    print("--- Stage-0 inference ---")
    for key in ["n_layers", "d_model", "vocab_size", "max_seq_len", "n_heads"]:
        print(f"{key:>12}: {stats.get(key)}")
    print(f"embed key: {stats['embed_key']}")

    dry_run = attempt_dry_run(src_state, stats, target_d_model, widen_pct)
    if dry_run is not None:
        print(f"Dry-run status: {dry_run.get('status')} (cosine={dry_run.get('cosine')})")

    config = ModelConfig(
        vocab_size=stats["vocab_size"],
        d_model=stats["d_model"],
        n_heads=stats["n_heads"],
        n_layers=stats["n_layers"],
        max_seq_len=stats["max_seq_len"],
        widen_pct=widen_pct,
        add_classic=ADD_CLASSIC_LAY,
        add_liquid=ADD_LIQUID_LAY,
    )
    model = Stage1Model(config).to(SURGERY_DTYPE)
    # Tiny-init the appended blocks first. transplant_weights leaves their destination
    # tensors untouched, so the load_state_dict below keeps this initialization.
    init_new_layers(model, stats["n_layers"], ADD_CLASSIC_LAY, ADD_LIQUID_LAY)

    dest_state = model.state_dict()
    transplanted, copied, widened, skipped, mapping = transplant_weights(
        src_state,
        dest_state,
        stats["n_layers"],
        stats["d_model"],
        target_d_model,
        WIDTH_TOLERANCE,
    )
    missing, unexpected = model.load_state_dict(transplanted, strict=False)

    layer_types = gather_layer_types(model)
    new_param_keys = {k for k in dest_state if is_new_layer_key(k, stats["n_layers"])}
    skipped_existing = [k for k in skipped if k not in new_param_keys]

    print("\n--- Stage-1 summary ---")
    print(f"target d_model: {target_d_model} (scale={effective_scale:.4f}, widen_pct={widen_pct:.2f})")
    print(f"blocks: base={stats['n_layers']} + classic={ADD_CLASSIC_LAY} + liquid={ADD_LIQUID_LAY}")
    print(f"copied tensors: {len(copied)}")
    print(f"widened tensors: {len(widened)}")
    print(f"new/init tensors: {len(skipped)} (new layers={len(new_param_keys)}, unmatched existing={len(skipped_existing)})")
    print(f"load_state_dict missing={len(missing)} unexpected={len(unexpected)}")

    # Timezone-aware "now" (datetime.utcnow() is deprecated); keep the trailing "Z" format.
    utc_now = datetime.datetime.now(datetime.timezone.utc)
    metadata_out = {
        "source_path": CHECKPOINT_PATH,
        "d_model_before": stats["d_model"],
        "d_model_after": target_d_model,
        "effective_width_scale": effective_scale,
        "added_classic": ADD_CLASSIC_LAY,
        "added_liquid": ADD_LIQUID_LAY,
        "timestamp": utc_now.isoformat().replace("+00:00", "Z"),
    }

    weights_path = save_checkpoint(model, OUTPUT_DIR, metadata_out)
    config_path = write_config(
        OUTPUT_DIR,
        config,
        {
            "layer_types": layer_types,
            "target_d_model": target_d_model,
            "effective_width_scale": effective_scale,
        },
    )
    freeze_path = write_freeze_mask(
        OUTPUT_DIR, stats["n_layers"], ADD_CLASSIC_LAY, len(layer_types)
    )
    report_path = write_report(
        OUTPUT_DIR,
        stats,
        target_d_model,
        copied,
        widened,
        skipped,
        list(missing),
        list(unexpected),
        {
            "weights": weights_path,
            "config": config_path,
            "freeze_mask": freeze_path,
        },
        layer_types,
        dry_run=dry_run,
    )

    print("\nArtifacts written:")
    print(f"- weights:      {weights_path}")
    print(f"- config:       {config_path}")
    print(f"- freeze mask:  {freeze_path}")
    print(f"- report:       {report_path}")

    return {
        "stats": stats,
        "target_d_model": target_d_model,
        "copied": copied,
        "widened": widened,
        "skipped": skipped,
        "missing": missing,
        "unexpected": unexpected,
        "weights_path": weights_path,
        "config_path": config_path,
        "freeze_path": freeze_path,
        "report_path": report_path,
        "dry_run": dry_run,
        "mapping": mapping,
    }


# Execute when the cell runs: performs the full surgery and writes every artifact.
# `results` is the summary dict returned by run_stage1_surgery(); the bare expression on
# the last line makes Jupyter display it as the cell output.
results = run_stage1_surgery()
results
