# Stage-1 Model Surgery (Adaptive) — Liquid → Hybrid (Net2Net widen + new layers)

This notebook:
- Introspects your student checkpoint to infer d_model, n_layers, and vocab size.
- Widens all tensors that depend on d_model by +10% (rounded to a multiple of 8).
- Appends 2 classic (transformer-style) layers (tiny-out init) and 3 liquid layers (small-scale init).
- Writes out: a new checkpoint, config_stage1.json, freeze_mask.json, and surgery_report.md.

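> Set CHECKPOINT_PATH below to your real checkpoint (the default is a placeholder, and the surgery cell skips the run until it is changed). OUTPUT_DIR defaults to `./_stage1_out`; change it if the artifacts should be written elsewhere.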

In [None]:

# Cell 2: imports, config, and IO paths
import os
import re
import json
import math
import time
import shutil
import datetime
from collections import Counter, defaultdict, OrderedDict
from typing import Any, Dict, Iterable, List, Optional, Tuple

import torch
import torch.nn.init as nn_init

# Request the "high" float32 matmul precision mode when this torch build
# exposes the switch; the hasattr guard keeps the cell runnable on builds
# that do not have it.
if hasattr(torch, "set_float32_matmul_precision"):
    torch.set_float32_matmul_precision("high")

# Every tensor is converted to this dtype while it is copied/widened and again
# right before the new checkpoint is saved.
SURGERY_DTYPE = torch.float32

# Source checkpoint. The default contains "/path/to", which _is_placeholder()
# recognises, so the surgery is skipped until this is changed.
CHECKPOINT_PATH = "/path/to/stage0_student.pt"
# Directory that receives the checkpoint, config, freeze mask and report.
OUTPUT_DIR = "./_stage1_out"

# Multiplicative widening factor applied to the inferred d_model.
WIDTH_SCALE = 1.10
# The widened d_model is rounded up to a multiple of this value.
ROUND_MULT = 8
# Number of "classic" layers appended after the existing stack.
ADD_CLASSIC_LAY = 2
# Number of "liquid" layers appended after the new classic layers.
ADD_LIQUID_LAY = 3
# Relative tolerance used when deciding whether a tensor dimension counts as
# k * d_model (see scale_dim_like_dmodel).
WIDTH_TOLERANCE = 0.06
# Written verbatim into config_stage1.json as "max_seq_len".
DEFAULT_MAX_SEQ = 4096
# Largest multiple k of d_model that scale_dim_like_dmodel tries to match.
MAX_K_MATCH = 8

# Optional user-supplied layer regex, tried before the built-in patterns.
# Group 1 must capture the layer index and group 2 the per-layer subkey.
CUSTOM_LAYER_REGEX = ""  # e.g., r"^transformer\.h\.(\d+)\.(.+)$"


# Fixed seed so the random jitter and the synthesized-layer initialisation are
# reproducible from run to run.
torch.manual_seed(0)


In [None]:

# Cell 3: checkpoint IO + layer schema discovery
# Wrapper prefixes that strip_known_prefixes() removes from the front of a
# parameter key (repeatedly, until none of them matches any more).
PREFIXES_TO_STRIP = (
    "module.",
    "model.",
    "student.",
    "state_dict.",
    "network.",
)


def _is_placeholder(value: Optional[str]) -> bool:
    """Tell whether *value* is missing or still looks like a template path.

    ``None``, an empty string and a whitespace-only string all count as
    placeholders, as does any string containing one of the known template
    markers.
    """
    cleaned = "" if value is None else value.strip()
    if cleaned == "":
        return True
    for marker in ("/path/to", "CHANGE_ME", "your/checkpoint", "gs://bucket"):
        if marker in cleaned:
            return True
    return False


def strip_known_prefixes(key: str) -> str:
    """Remove every leading wrapper prefix listed in PREFIXES_TO_STRIP.

    The prefix list is scanned repeatedly, so stacked wrappers such as
    ``module.model.`` are peeled off completely. Scanning stops after the
    first full pass in which no prefix matched.
    """
    current = key
    stripped_something = True
    while stripped_something:
        stripped_something = False
        for prefix in PREFIXES_TO_STRIP:
            if not current.startswith(prefix):
                continue
            current = current[len(prefix) :]
            stripped_something = True
    return current


def _safe_torch_load(path: str) -> Any:
    """Load a torch checkpoint onto the CPU, preferring the restricted loader.

    ``torch.load(..., weights_only=True)`` is tried first. The *only* fallback
    is for torch builds whose ``torch.load`` does not accept the
    ``weights_only`` keyword (signalled by ``TypeError``); in that case the
    file is loaded with the default arguments.

    Any other failure, including a checkpoint that the restricted loader
    refuses, propagates to the caller unchanged. The function never retries
    with ``weights_only=False`` on its own.

    Args:
        path: Filesystem path of the checkpoint.

    Returns:
        Whatever object ``torch.load`` deserialises from *path*.
    """
    try:
        return torch.load(path, map_location="cpu", weights_only=True)
    except TypeError:
        # Older torch: no weights_only keyword.
        # SECURITY: this is a full pickle load, which can execute arbitrary
        # code. Only point it at checkpoints from a trusted source.
        return torch.load(path, map_location="cpu")


def _safe_load_any(path: str) -> Any:
    """Load a checkpoint, choosing the reader from the file extension.

    Paths ending in ``.safetensors`` (case-insensitive) are read with
    ``safetensors.torch.load_file`` when that package can be imported.
    Everything else, and ``.safetensors`` paths when the package is missing,
    is handed to ``_safe_torch_load``.
    """
    wants_safetensors = path.lower().endswith(".safetensors")
    if not wants_safetensors:
        return _safe_torch_load(path)
    try:
        from safetensors.torch import load_file
    except ImportError:
        # NOTE(review): this hands a .safetensors path to the torch loader;
        # presumably that only works if the file is really a torch pickle.
        return _safe_torch_load(path)
    return load_file(path, device="cpu")


def _extract_state_dict_root(obj: Any) -> Dict[str, Any]:
    """Pick the mapping inside a checkpoint object that holds the parameters.

    Resolution order:
      1. *obj* itself, when more than 60% of its values are tensors.
      2. The first of ``state_dict``, ``model``, ``module``,
         ``model_state_dict``, ``ema_state_dict`` whose value is a dict.
      3. The first dict-valued entry of *obj*, in iteration order.

    Raises:
        TypeError: if *obj* is not a dict or none of the rules above applies.
    """
    if isinstance(obj, (dict, OrderedDict)):
        entries = list(obj.values())
        if entries:
            tensor_count = sum(isinstance(entry, torch.Tensor) for entry in entries)
            if tensor_count / len(entries) > 0.6:
                return obj

        preferred = ("state_dict", "model", "module", "model_state_dict", "ema_state_dict")
        for name in preferred:
            if name in obj and isinstance(obj[name], (dict, OrderedDict)):
                return obj[name]

        first_nested = next(
            (entry for entry in obj.values() if isinstance(entry, (dict, OrderedDict))),
            None,
        )
        if first_nested is not None:
            return first_nested

    raise TypeError("Checkpoint must be a mapping that contains a state_dict-like mapping.")


def _flatten_nested_state_dict(sd_like: Any, prefix: str = "") -> Dict[str, torch.Tensor]:
    """Collapse a possibly nested mapping into ``{dotted.path: tensor}``.

    Nested dicts are walked with an explicit stack. Tensor leaves are detached
    and moved to the CPU; every other leaf type is ignored. A non-dict
    *sd_like* contributes nothing.

    Raises:
        ValueError: if no tensor leaf was found at all.
    """
    collected: Dict[str, torch.Tensor] = {}
    pending = [(prefix, sd_like)]
    while pending:
        path, mapping = pending.pop()
        if not isinstance(mapping, (dict, OrderedDict)):
            continue
        for name, child in mapping.items():
            dotted = f"{path}.{name}" if path != "" else f"{path}{name}"
            if isinstance(child, (dict, OrderedDict)):
                # Defer nested mappings; they are expanded on a later pop.
                pending.append((dotted, child))
                continue
            if isinstance(child, torch.Tensor):
                collected[dotted] = child.detach().cpu()
    if len(collected) == 0:
        raise ValueError("No tensor leaves found in checkpoint. Is this a valid state_dict?")
    return collected


def load_stage0_state_dict(path: str) -> Tuple["OrderedDict[str, torch.Tensor]", Dict[str, Any]]:
    """Load a checkpoint and return its parameters plus side metadata.

    The checkpoint is loaded, the most likely state_dict mapping is extracted
    and flattened to dotted keys, wrapper prefixes are stripped, and the
    result is returned sorted by the original (unstripped) key.

    Metadata holds the non-parameter entries of the checkpoint:
      * wrapped checkpoint (``{"state_dict": ..., "epoch": ...}``): every
        top-level entry except the extracted state_dict;
      * bare state_dict: only the top-level values that are neither tensors
        nor nested dicts. Previously every parameter tensor ended up in the
        metadata in this case, because no value is identical to the root
        mapping itself.
    Two counters are added: ``original_keys`` (top-level entries in the
    extracted mapping, unless the checkpoint already has that key) and
    ``flat_keys`` (tensors after flattening).

    Args:
        path: Filesystem path of the stage-0 checkpoint.

    Returns:
        ``(canonical_state_dict, metadata)``.

    Raises:
        ValueError: if *path* is still a placeholder, or no tensors are found.
        FileNotFoundError: if *path* is not an existing file.
        TypeError: if the loaded object contains no state_dict-like mapping.
    """
    if _is_placeholder(path):
        raise ValueError("CHECKPOINT_PATH is a placeholder; please provide a real checkpoint path before running surgery.")
    if not os.path.isfile(path):
        raise FileNotFoundError(f"Checkpoint not found: {path}")

    checkpoint = _safe_load_any(path)
    root = _extract_state_dict_root(checkpoint)
    flat = _flatten_nested_state_dict(root)
    # NOTE(review): two raw keys that differ only by a stripped prefix would
    # collapse to the same canonical key here, and the later one would win.
    canonical = OrderedDict((strip_known_prefixes(k), v) for k, v in sorted(flat.items()))

    metadata: Dict[str, Any] = {}
    if isinstance(checkpoint, dict):
        if checkpoint is root:
            # Bare state_dict: tensors and nested dicts are parameters (the
            # flattener consumed them), so only keep the remaining leaves.
            metadata = {
                k: v
                for k, v in checkpoint.items()
                if not isinstance(v, (torch.Tensor, dict))
            }
        else:
            metadata = {k: v for k, v in checkpoint.items() if v is not root}

    metadata.setdefault("original_keys", len(root) if isinstance(root, dict) else None)
    metadata["flat_keys"] = len(canonical)
    return canonical, metadata


# Built-in regexes for "<container>.<index>.<subkey>" parameter names. They are
# tried in order by _match_with_patterns(); group 1 is the layer index and
# group 2 is the remainder of the key (the per-layer subkey).
# NOTE(review): the 2nd, 5th and 6th entries look fully covered by the first
# pattern's alternation; confirm before pruning.
_LAYER_PATTERNS = [
    r"^(?:module\.)?(?:model\.)?(?:transformer\.)?(?:h|layers|blocks|block|layer|encoder\.layers|backbone\.layers)\.(\d+)\.(.+)$",
    r"^(?:module\.)?(?:model\.)?(?:encoder)\.(?:layers)\.(\d+)\.(.+)$",
    r"^(?:module\.)?(?:model\.)?(?:decoder)\.(?:layers)\.(\d+)\.(.+)$",
    r"^(?:module\.)?(?:model\.)?(?:gpt_neox)\.(?:layers)\.(\d+)\.(.+)$",
    r"^(?:module\.)?(?:model\.)?(?:transformer)\.(?:blocks|h)\.(\d+)\.(.+)$",
    r"^(?:module\.)?(?:model\.)?(?:blocks|layers)\.(\d+)\.(.+)$",
]

def _match_with_patterns(key: str):
    """Split *key* into ``(layer_index, subkey)`` using the layer regexes.

    CUSTOM_LAYER_REGEX, when non-empty, is tried first, followed by each entry
    of _LAYER_PATTERNS in order. The first pattern that matches decides the
    result. Returns ``None`` when nothing matches.
    """
    candidates = [CUSTOM_LAYER_REGEX] if CUSTOM_LAYER_REGEX else []
    candidates.extend(_LAYER_PATTERNS)
    for pattern in candidates:
        hit = re.match(pattern, key)
        if hit is None:
            continue
        return int(hit.group(1)), hit.group(2)
    return None

def _fallback_numeric_segment(key: str):
    """Split *key* at its first all-digit dotted segment.

    Returns ``(index, remainder)`` where *remainder* is everything after that
    segment joined with dots (an empty string when the digit segment is the
    last one), or ``None`` when no segment is numeric.
    """
    segments = key.split(".")
    position = next((i for i, segment in enumerate(segments) if segment.isdigit()), None)
    if position is None:
        return None
    remainder = ".".join(segments[position + 1 :])
    return int(segments[position]), remainder

def discover_layer_schema(state_dict: Dict[str, torch.Tensor]) -> Tuple[List[int], List[str]]:
    """Discover which keys belong to numbered layers and what a layer contains.

    Every key is first run through ``_match_with_patterns``. Only if *no* key
    matches is the generic ``_fallback_numeric_segment`` splitter used instead
    (entries with an empty subkey are ignored there).

    The canonical schema is the subkey set shared by the largest number of
    layers; on a tie, the set that was encountered first wins.

    Args:
        state_dict: Flat mapping of parameter name to tensor.

    Returns:
        ``(layer_indices, layer_subkeys)``: the sorted layer indices that were
        seen, and the canonical subkeys in alphabetical order.

    Raises:
        ValueError: if no layered parameter could be identified; the message
            lists up to 50 sample keys.
    """
    subkeys_by_layer: Dict[int, set] = defaultdict(set)

    for key in state_dict:
        hit = _match_with_patterns(key)
        if hit:
            idx, sub = hit
            subkeys_by_layer[idx].add(sub)

    if not subkeys_by_layer:
        # No regex matched anything: fall back to the first numeric segment.
        for key in state_dict:
            hit = _fallback_numeric_segment(key)
            if hit and hit[1]:
                subkeys_by_layer[hit[0]].add(hit[1])

    if not subkeys_by_layer:
        sample = list(state_dict)[:50]
        raise ValueError(
            "Could not discover layered parameters. Examples of keys:\n  - "
            + "\n  - ".join(sample)
        )

    layer_indices = sorted(subkeys_by_layer)

    # most_common keeps first-encountered order among equal counts, so ties
    # resolve to the earliest-seen subkey set.
    schema_votes = Counter(frozenset(subkeys) for subkeys in subkeys_by_layer.values())
    canonical_set = schema_votes.most_common(1)[0][0]

    return layer_indices, sorted(canonical_set)


def infer_model_stats(state_dict: Dict[str, torch.Tensor], layer_indices: List[int]) -> Dict[str, Any]:
    """Infer layer count, d_model and vocabulary size from parameter shapes.

    d_model is taken from, in order of preference: the most common length of
    1-D tensors whose key contains ``ln`` or ``norm``; the last dimension of
    the embedding matrix; the last dimension of ``lm_head.weight``; the most
    common length of any 1-D tensor.

    The vocabulary size is the first dimension of the embedding matrix, else
    of ``lm_head.weight``, else the hard-coded default 50257.

    Embedding selection prefers the first embedding-like key that does *not*
    look positional or segment-related (``pos``, ``token_type``, ``segment``
    in the name). Without that, a table such as ``pos_emb.weight`` that comes
    earlier in iteration order would be mistaken for the token embedding and
    the sequence length reported as the vocabulary size. If only such keys
    exist, the first of them is still used, as before.

    Args:
        state_dict: Flat mapping of parameter name to tensor; non-tensor
            values are skipped.
        layer_indices: Layer indices found by layer discovery.

    Returns:
        Dict with ``n_layers``, ``d_model``, ``vocab_size``, ``embed_key`` and
        ``lm_head_key``.

    Raises:
        ValueError: if d_model cannot be inferred from any tensor.
    """
    embedding_suffixes = (
        "embedding.weight",
        "embeddings.weight",
        "wte.weight",
        "tok_embeddings.weight",
        "word_embeddings.weight",
    )
    positional_hints = ("pos", "token_type", "segment")

    n_layers = max(layer_indices) + 1 if layer_indices else 0
    norm_sizes: List[int] = []
    one_d_sizes: List[int] = []
    embed_shape: Optional[Tuple[int, ...]] = None
    embed_key: Optional[str] = None
    # First positional/segment-looking embedding; used only as a last resort.
    aux_embed: Optional[Tuple[str, Tuple[int, ...]]] = None
    lm_head_shape: Optional[Tuple[int, ...]] = None
    lm_head_key: Optional[str] = None

    for key, tensor in state_dict.items():
        if not isinstance(tensor, torch.Tensor):
            continue
        lower_key = key.lower()
        if tensor.ndim == 1:
            size = int(tensor.shape[0])
            one_d_sizes.append(size)
            if "ln" in lower_key or "norm" in lower_key:
                norm_sizes.append(size)
        looks_like_embedding = "emb.weight" in lower_key or lower_key.endswith(embedding_suffixes)
        if embed_shape is None and looks_like_embedding:
            if any(hint in lower_key for hint in positional_hints):
                if aux_embed is None:
                    aux_embed = (key, tuple(tensor.shape))
            else:
                embed_shape = tuple(tensor.shape)
                embed_key = key
        if lm_head_shape is None and lower_key.endswith("lm_head.weight"):
            lm_head_shape = tuple(tensor.shape)
            lm_head_key = key

    if embed_shape is None and aux_embed is not None:
        embed_key, embed_shape = aux_embed

    def _most_common(values: List[int]) -> Optional[int]:
        """Return the most frequent value, or None for an empty list."""
        if not values:
            return None
        return Counter(values).most_common(1)[0][0]

    d_model = _most_common(norm_sizes)
    if d_model is None and embed_shape is not None and len(embed_shape) >= 2:
        d_model = int(embed_shape[-1])
    if d_model is None and lm_head_shape is not None and len(lm_head_shape) >= 2:
        d_model = int(lm_head_shape[-1])
    if d_model is None:
        d_model = _most_common(one_d_sizes)
    if d_model is None:
        raise ValueError("Unable to infer d_model from the checkpoint tensors.")

    if embed_shape is not None:
        vocab_size = int(embed_shape[0])
    elif lm_head_shape is not None:
        vocab_size = int(lm_head_shape[0])
        # No embedding found: report the output head as the vocab-sized tensor.
        embed_key = lm_head_key
    else:
        vocab_size = 50257

    return {
        "n_layers": n_layers,
        "d_model": int(d_model),
        "vocab_size": int(vocab_size),
        "embed_key": embed_key,
        "lm_head_key": lm_head_key,
    }


In [None]:

# Cell 4: shape scaling, Net2Net widening, and initialization helpers
def compute_target_width(d_model: int, scale: float, round_mult: int) -> int:
    """Return the widened model width.

    The width is ``ceil(d_model * scale)``, rounded up to the next multiple of
    *round_mult* (when *round_mult* is greater than 1), and never smaller than
    the original *d_model*.

    The product is rounded to six decimals before the ceiling is taken.
    Binary floating point makes ``400 * 1.10`` evaluate to
    ``440.00000000000006``; a plain ceiling would turn that into 441 and the
    multiple-of-8 step into 448, even though 440 is the exact answer.

    Args:
        d_model: Current model width.
        scale: Multiplicative widening factor.
        round_mult: Alignment for the result; values of 1 or less disable it.

    Returns:
        The target width as an int.
    """
    widened = math.ceil(round(d_model * scale, 6))
    if round_mult > 1:
        # Ceiling division without going through floating point.
        widened = -(-widened // round_mult) * round_mult
    return int(max(d_model, widened))


def scale_dim_like_dmodel(dim: int, d0: int, d1: int, tol: float = WIDTH_TOLERANCE, max_k: int = MAX_K_MATCH) -> Optional[int]:
    """Map a dimension that is roughly ``k * d0`` onto ``k * d1``.

    Multipliers ``k = 1 .. max_k`` are tried in ascending order. *dim* matches
    a multiplier when it lies within ``max(round(k * d0 * tol), 1)`` of
    ``k * d0``; the first match wins. Returns ``None`` when *dim* or *d0* is
    not positive, or when no multiplier matches.
    """
    if min(d0, dim) <= 0:
        return None
    multiplier = 1
    while multiplier <= max_k:
        anchor = multiplier * d0
        slack = max(int(round(anchor * tol)), 1)
        if abs(dim - anchor) <= slack:
            return int(multiplier * d1)
        multiplier += 1
    return None


def compute_new_shape(shape: Tuple[int, ...], d0: int, d1: int) -> Tuple[Tuple[int, ...], bool]:
    """Work out the post-widening shape of a tensor.

    Only 1-D and 2-D shapes are considered; each axis that
    ``scale_dim_like_dmodel`` recognises as a multiple of *d0* is replaced by
    the matching multiple of *d1*. Shapes of any other rank pass through
    untouched.

    Returns:
        ``(new_shape, changed)`` where *changed* says whether any axis moved.
    """
    dims = list(shape)
    touched = False
    if len(shape) in (1, 2):
        for axis, size in enumerate(shape):
            candidate = scale_dim_like_dmodel(size, d0, d1)
            if candidate is None or candidate == size:
                continue
            dims[axis] = candidate
            touched = True
    return tuple(dims), touched


def widen_tensor(tensor: torch.Tensor, new_shape: Tuple[int, ...]) -> torch.Tensor:
    """Resize *tensor* to *new_shape* by replicating entries, then add jitter.

    The tensor is converted to SURGERY_DTYPE on the CPU. If it already has the
    requested shape, a plain clone is returned. Otherwise each differing axis
    is handled in turn: a size-1 axis is broadcast, any other axis is
    resampled by picking evenly spaced (rounded) source indices. Finally,
    Gaussian noise is added to the *whole* result, with a standard deviation
    of 1% of the input's standard deviation (1e-3 when that is not positive).

    Raises:
        ValueError: if an axis that must change has a non-positive size.
    """
    widened = tensor.to(dtype=SURGERY_DTYPE, device="cpu")
    if tuple(widened.shape) == new_shape:
        return widened.clone()

    for dim_index, wanted in enumerate(new_shape):
        have = widened.shape[dim_index]
        if have == wanted:
            continue
        if have <= 0:
            raise ValueError(f"Cannot widen axis {dim_index} with size {have} to {wanted} for tensor.")
        if have == 1:
            broadcast_shape = list(widened.shape)
            broadcast_shape[dim_index] = wanted
            widened = widened.expand(*broadcast_shape).clone()
            continue
        # Evenly spaced source positions, rounded to the nearest valid index.
        positions = torch.linspace(0, have - 1, wanted, dtype=torch.float32)
        gather_idx = positions.round().clamp(0, have - 1).to(torch.long)
        widened = torch.index_select(widened, dim_index, gather_idx)

    if tuple(widened.shape) != new_shape:
        widened = widened.reshape(new_shape)

    shape_changed = tuple(tensor.shape) != new_shape
    if tensor.numel() > 0 and shape_changed:
        spread = tensor.float().std().item() if tensor.numel() > 1 else 0.0
        noise_std = 0.01 * spread if spread > 0 else 1e-3
        if noise_std > 0:
            widened = widened + torch.randn_like(widened) * noise_std
    return widened


def maybe_widen_tensor(key: str, tensor: torch.Tensor, d0: int, d1: int, target_shape: Optional[Tuple[int, ...]] = None) -> Tuple[torch.Tensor, str, Tuple[int, ...], Tuple[int, ...]]:
    """Copy or widen one tensor and report which of the two happened.

    Without an explicit *target_shape* the new shape is derived from the
    tensor's own shape via ``compute_new_shape``. With one, the tensor is
    widened whenever its shape differs from it. *key* is accepted but unused.

    Returns:
        ``(tensor, op, old_shape, new_shape)`` with *op* being ``"copied"``
        (both shapes equal the original) or ``"widened"``.
    """
    cpu_copy = tensor.detach().cpu().to(dtype=SURGERY_DTYPE)
    shape_before = tuple(cpu_copy.shape)
    if target_shape is not None:
        needs_widening = target_shape != shape_before
    else:
        target_shape, needs_widening = compute_new_shape(shape_before, d0, d1)
    if needs_widening and target_shape is not None:
        return widen_tensor(cpu_copy, target_shape), "widened", shape_before, target_shape
    return cpu_copy.clone(), "copied", shape_before, shape_before


def infer_fallback_shape(subkey: str, d1: int) -> Tuple[int, ...]:
    """Guess a parameter shape from its name alone (case-insensitive).

    * anything containing ``bias``: ``(d1,)``
    * ``weight`` of a norm layer (name contains ``ln`` or ``norm``): ``(d1,)``
    * any other ``weight``: ``(d1, d1)``
    * ``scale`` (without ``weight`` or ``bias``): scalar ``()``
    * everything else: ``(d1,)``
    """
    name = subkey.lower()
    vector = (d1,)
    if "bias" in name:
        return vector
    if "weight" in name:
        is_norm = "ln" in name or "norm" in name
        return vector if is_norm else (d1, d1)
    if "scale" in name:
        return ()
    return vector


def synthesize_parameter(subkey: str, target_shape: Optional[Tuple[int, ...]], d1: int) -> torch.Tensor:
    """Create a freshly initialised tensor for a parameter of a new layer.

    The shape is *target_shape* when one is given, otherwise it is guessed
    from the name via ``infer_fallback_shape``. The test is ``is None`` on
    purpose: the scalar shape ``()`` is falsy, so a truthiness test would
    discard it and could replace a scalar with a ``(d1,)`` vector.

    Initialisation by rank (names compared case-insensitively):
      * 0-D: 1e-3 if the name contains ``scale``, else 0.0.
      * 1-D: ones for norm weights, zeros for biases, 1e-3 for scales,
        Kaiming-uniform times 1e-3 for other weights, zeros otherwise.
      * 2-D: Kaiming-uniform (``a=sqrt(5)``) times 1e-3.
      * higher rank: zeros.

    Args:
        subkey: Per-layer parameter name.
        target_shape: Desired shape, or None to infer it from *subkey*.
        d1: Widened model width, used only when the shape is inferred.

    Returns:
        A new tensor of dtype SURGERY_DTYPE.
    """
    if target_shape is None:
        shape = infer_fallback_shape(subkey, d1)
    else:
        shape = tuple(target_shape)
    lower = subkey.lower()

    if len(shape) == 0:
        value = 1e-3 if "scale" in lower else 0.0
        return torch.tensor(value, dtype=SURGERY_DTYPE)

    tensor = torch.empty(shape, dtype=SURGERY_DTYPE)
    if len(shape) == 1:
        if "weight" in lower and ("ln" in lower or "norm" in lower):
            tensor.fill_(1.0)
        elif "bias" in lower:
            tensor.zero_()
        elif "scale" in lower:
            tensor.fill_(1e-3)
        elif "weight" in lower:
            # kaiming_uniform_ needs at least 2 dims; unsqueeze returns a view,
            # so the in-place init writes through to `tensor`.
            nn_init.kaiming_uniform_(tensor.unsqueeze(0))
            tensor.mul_(1e-3)
        else:
            tensor.zero_()
    elif len(shape) == 2:
        nn_init.kaiming_uniform_(tensor, a=math.sqrt(5))
        tensor.mul_(1e-3)
    else:
        tensor.zero_()
    return tensor


def build_new_state_dict(
    state_dict: Dict[str, torch.Tensor],
    stats: Dict[str, Any],
    layer_subkeys: List[str],
    target_d_model: int,
) -> Tuple["OrderedDict[str, torch.Tensor]", Dict[str, List[str]], List[Tuple[str, Tuple[int, ...], Tuple[int, ...]]], List[Dict[str, Any]]]:
    """Assemble the widened state dict including the appended layers.

    Steps:
      1. Split keys into per-layer tensors and global tensors. Global tensors
         are copied/widened immediately under their original key.
      2. Derive one target shape per canonical subkey from the first layer
         that has it (name-based guess if no layer does).
      3. Re-emit existing layer tensors as ``layers.{idx}.{subkey}``,
         whatever prefix the source key used.
      4. Synthesize ``ADD_CLASSIC_LAY + ADD_LIQUID_LAY`` new layers with the
         canonical subkeys, numbered after the existing ones.

    Differences from the previous implementation:
      * The generic numeric-segment splitter is used only when *no* key
        matches the layer regexes, which mirrors ``discover_layer_schema``.
        Before, it was applied per key, so a global parameter such as
        ``head.0.weight`` was filed under layer 0 and then lost or misplaced.
      * Layer tensors whose subkey is not in *layer_subkeys* are no longer
        dropped. They are emitted after the canonical ones (alphabetically),
        widened according to their own shape.

    Args:
        state_dict: Flat source parameters.
        stats: Must provide ``d_model`` and ``n_layers``.
        layer_subkeys: Canonical per-layer subkeys.
        target_d_model: Widened model width.

    Returns:
        ``(sd_new, operations, widen_samples, new_layer_info)`` where
        *operations* maps ``copied`` / ``widened`` / ``synthesized`` to key
        lists, *widen_samples* holds up to five ``(key, old, new)`` shape
        examples, and *new_layer_info* lists index and type of each new layer.
    """
    d0 = stats["d_model"]
    d1 = target_d_model
    n_layers = stats["n_layers"]
    sd_new: "OrderedDict[str, torch.Tensor]" = OrderedDict()
    operations: Dict[str, List[str]] = {"copied": [], "widened": [], "synthesized": []}
    widen_samples: List[Tuple[str, Tuple[int, ...], Tuple[int, ...]]] = []
    new_layer_info: List[Dict[str, Any]] = []
    layer_values: Dict[int, Dict[str, torch.Tensor]] = defaultdict(dict)

    def _emit(key: str, tensor: torch.Tensor, target_shape: Optional[Tuple[int, ...]] = None) -> None:
        """Copy or widen one tensor into sd_new and record the operation."""
        new_tensor, op, old_shape, new_shape = maybe_widen_tensor(key, tensor, d0, d1, target_shape)
        sd_new[key] = new_tensor
        operations[op].append(key)
        if op == "widened" and len(widen_samples) < 5:
            widen_samples.append((key, old_shape, new_shape))

    # Step 1: decide once whether the generic splitter is needed at all.
    pattern_matches = {key: _match_with_patterns(key) for key in state_dict}
    use_fallback = not any(pattern_matches.values())

    for key, tensor in state_dict.items():
        match = pattern_matches[key]
        if not match and use_fallback:
            match = _fallback_numeric_segment(key)
        if match:
            idx, subkey = match
            if subkey:
                layer_values[idx][subkey] = tensor
                continue
        _emit(key, tensor)

    # Step 2: one target shape per canonical subkey.
    layer_target_shapes: Dict[str, Optional[Tuple[int, ...]]] = {}
    for subkey in layer_subkeys:
        sample_tensor = None
        for idx in range(n_layers):
            candidate = layer_values.get(idx, {}).get(subkey)
            if candidate is not None:
                sample_tensor = candidate
                break
        if sample_tensor is not None:
            shape, _ = compute_new_shape(tuple(sample_tensor.shape), d0, d1)
        else:
            shape = infer_fallback_shape(subkey, d1)
        layer_target_shapes[subkey] = shape

    # Step 3: existing layers, canonical subkeys first, then layer-specific extras.
    canonical = set(layer_subkeys)
    for idx in range(n_layers):
        values = layer_values.get(idx, {})
        extras = sorted(sub for sub in values if sub not in canonical)
        for subkey in list(layer_subkeys) + extras:
            source_tensor = values.get(subkey)
            if source_tensor is None:
                continue
            # Extras have no shared target shape; None makes the helper derive
            # the widened shape from the tensor itself.
            _emit(f"layers.{idx}.{subkey}", source_tensor, layer_target_shapes.get(subkey))

    # Step 4: brand-new layers appended after the existing stack.
    total_new_layers = ADD_CLASSIC_LAY + ADD_LIQUID_LAY
    for offset in range(total_new_layers):
        layer_idx = n_layers + offset
        layer_type = "classic" if offset < ADD_CLASSIC_LAY else "liquid"
        new_layer_info.append({"index": layer_idx, "type": layer_type})
        for subkey in layer_subkeys:
            key = f"layers.{layer_idx}.{subkey}"
            sd_new[key] = synthesize_parameter(subkey, layer_target_shapes.get(subkey), d1)
            operations["synthesized"].append(key)

    return sd_new, operations, widen_samples, new_layer_info


In [None]:

    # Cell 5: artifact writers, reporting, and utility helpers
    def build_layer_types(n_layers: int) -> List[str]:
        """List the type tag of every layer in the post-surgery stack.

        All *n_layers* pre-existing layers are tagged ``"liquid"``, followed by
        ADD_CLASSIC_LAY ``"classic"`` tags and ADD_LIQUID_LAY ``"liquid"`` tags.
        """
        layer_tags: List[str] = []
        layer_tags.extend("liquid" for _ in range(n_layers))
        layer_tags.extend("classic" for _ in range(ADD_CLASSIC_LAY))
        layer_tags.extend("liquid" for _ in range(ADD_LIQUID_LAY))
        return layer_tags


    def build_freeze_mask(n_layers: int) -> Dict[str, bool]:
        """Build the ``layers.{i} -> bool`` mask written to freeze_mask.json.

        The flag is True exactly for the appended classic layers (indices
        ``n_layers`` up to ``n_layers + ADD_CLASSIC_LAY - 1``) and False for all
        pre-existing layers and the appended liquid layers.
        NOTE(review): presumably True means "frozen during training"; confirm
        against whatever consumes the mask.
        """
        first_classic = n_layers
        first_new_liquid = n_layers + ADD_CLASSIC_LAY
        layer_count = first_new_liquid + ADD_LIQUID_LAY
        return {
            f"layers.{index}": first_classic <= index < first_new_liquid
            for index in range(layer_count)
        }


    def save_artifacts(
        sd_new: "OrderedDict[str, torch.Tensor]",
        stats: Dict[str, Any],
        target_d_model: int,
        layer_types: List[str],
        operations: Dict[str, List[str]],
        new_layer_info: List[Dict[str, Any]],
        widen_samples: List[Tuple[str, Tuple[int, ...], Tuple[int, ...]]],
        metadata: Dict[str, Any],
    ) -> Dict[str, str]:
        """Write the four surgery artifacts into OUTPUT_DIR.

        Files written (the directory is created if needed):
          * ``stage1_surgery.pt``: ``{"state_dict": ..., "note": ...}`` with
            every tensor detached, on CPU and in SURGERY_DTYPE.
          * ``config_stage1.json``: old/new width and depth, vocab size, layer
            types and max sequence length.
          * ``freeze_mask.json``: output of ``build_freeze_mask``.
          * ``surgery_report.md``: human-readable summary.

        Fixes over the previous version: the report lines are joined with a
        proper ``"\\n"`` separator (the literal had been split across two
        source lines, which does not parse), and the timestamp no longer uses
        the deprecated ``datetime.utcnow()``. Its text format is unchanged.

        Args:
            sd_new: Post-surgery state dict.
            stats: Must provide ``d_model``, ``n_layers`` and ``vocab_size``.
            target_d_model: Widened model width.
            layer_types: Per-layer type tags for the config.
            operations: Key lists per operation; duplicates are counted once.
            new_layer_info: ``{"index", "type"}`` for each appended layer.
            widen_samples: ``(key, old_shape, new_shape)`` examples.
            metadata: Checkpoint metadata; only its keys are reported.

        Returns:
            Mapping of artifact name (``checkpoint``, ``config``,
            ``freeze_mask``, ``report``) to the path that was written.
        """
        os.makedirs(OUTPUT_DIR, exist_ok=True)
        total_layers = stats["n_layers"] + ADD_CLASSIC_LAY + ADD_LIQUID_LAY

        checkpoint_path = os.path.join(OUTPUT_DIR, "stage1_surgery.pt")
        state_to_save = OrderedDict((k, v.detach().cpu().to(dtype=SURGERY_DTYPE)) for k, v in sd_new.items())
        torch.save({"state_dict": state_to_save, "note": "Stage-1 surgery checkpoint"}, checkpoint_path)

        config = {
            "d_model_old": stats["d_model"],
            "d_model_new": target_d_model,
            "n_layers_old": stats["n_layers"],
            "n_layers_new": total_layers,
            "vocab_size": stats["vocab_size"],
            "layer_types": layer_types,
            "max_seq_len": DEFAULT_MAX_SEQ,
        }
        config_path = os.path.join(OUTPUT_DIR, "config_stage1.json")
        with open(config_path, "w", encoding="utf-8") as f:
            json.dump(config, f, indent=2)

        freeze_mask = build_freeze_mask(stats["n_layers"])
        freeze_path = os.path.join(OUTPUT_DIR, "freeze_mask.json")
        with open(freeze_path, "w", encoding="utf-8") as f:
            json.dump(freeze_mask, f, indent=2)

        # A key can in principle be recorded twice; count distinct keys only.
        copied_count = len(set(operations.get("copied", [])))
        widened_count = len(set(operations.get("widened", [])))
        synthesized_count = len(set(operations.get("synthesized", [])))

        report_lines = [
            "# Stage-1 Surgery Report",
            "",
            "## Source",
            f"- Checkpoint path: {CHECKPOINT_PATH}",
            f"- Inferred layers: {stats['n_layers']}",
            f"- Inferred d_model: {stats['d_model']}",
            f"- Inferred vocab size: {stats['vocab_size']}",
            "",
            "## Target",
            f"- Target d_model: {target_d_model}",
            f"- Total layers after surgery: {total_layers}",
            f"- Added classic layers: {ADD_CLASSIC_LAY}",
            f"- Added liquid layers: {ADD_LIQUID_LAY}",
            "",
            "## Operations",
            f"- Copied tensors: {copied_count}",
            f"- Widened tensors: {widened_count}",
            f"- Synthesized tensors: {synthesized_count}",
            "",
            "## New layers",
        ]
        for info in new_layer_info:
            report_lines.append(f"- layers.{info['index']}: {info['type']}")
        if widen_samples:
            report_lines.append("")
            report_lines.append("## Sample widened tensors")
            for key, old_shape, new_shape in widen_samples:
                report_lines.append(f"- {key}: {old_shape} -> {new_shape}")
        if metadata:
            report_lines.append("")
            report_lines.append("## Checkpoint metadata keys")
            for meta_key in sorted(metadata.keys()):
                report_lines.append(f"- {meta_key}")

        # Aware "now" in UTC, made naive again so isoformat() carries no
        # "+00:00" suffix and the trailing "Z" stays the only zone marker.
        generated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
        report_lines.append("")
        report_lines.append(f"Generated on {generated_at.isoformat()}Z")

        report_path = os.path.join(OUTPUT_DIR, "surgery_report.md")
        with open(report_path, "w", encoding="utf-8") as f:
            f.write("\n".join(report_lines))

        return {
            "checkpoint": checkpoint_path,
            "config": config_path,
            "freeze_mask": freeze_path,
            "report": report_path,
        }


In [None]:

# Cell 6: run the Stage-1 surgery end-to-end
def run_stage1_surgery() -> Optional[Dict[str, Any]]:
    """Run the whole Stage-1 surgery and print a summary.

    Loads CHECKPOINT_PATH, discovers the layer schema, infers the model
    statistics, builds the widened/extended state dict, writes the artifacts
    and prints what was done.

    Returns:
        ``None`` (after printing a hint) while CHECKPOINT_PATH is still a
        placeholder; otherwise a dict with the new state dict, the inferred
        stats, the target width, layer types, operation lists and counts,
        widen samples, new-layer info, artifact paths and checkpoint metadata.

    Raises:
        ValueError: if layer discovery yields no layer indices.
    """
    if _is_placeholder(CHECKPOINT_PATH):
        print("CHECKPOINT_PATH is a placeholder. Update it to run surgery.")
        return None

    started_at = time.time()
    state_dict, metadata = load_stage0_state_dict(CHECKPOINT_PATH)
    layer_indices, layer_subkeys = discover_layer_schema(state_dict)
    if not layer_indices:
        print("Layer discovery returned no layer indices. Showing first 50 parameter keys:")
        for name in list(state_dict.keys())[:50]:
            print(f"  - {name}")
        raise ValueError("Layer discovery did not find any layer indices.")

    # Show up to ten parameter names of the lowest-numbered layer so the
    # detected layout can be checked by eye.
    first_idx = layer_indices[0]
    sample_keys: List[str] = []
    for name in state_dict.keys():
        hit = _match_with_patterns(name) or _fallback_numeric_segment(name)
        if not hit or hit[0] != first_idx:
            continue
        sample_keys.append(name)
        if len(sample_keys) >= 10:
            break

    print(f"Layer subkeys (first 10): {layer_subkeys[:10]}")
    if not sample_keys:
        print(f"No parameter keys found for layer index {first_idx} using detected patterns.")
    else:
        print(f"Example parameter keys for layer index {first_idx} (first {len(sample_keys)}):")
        for name in sample_keys:
            print(f"  - {name}")

    stats = infer_model_stats(state_dict, layer_indices)
    stats["layer_subkeys"] = layer_subkeys
    target_d_model = compute_target_width(stats["d_model"], WIDTH_SCALE, ROUND_MULT)
    sd_new, operations, widen_samples, new_layer_info = build_new_state_dict(
        state_dict, stats, layer_subkeys, target_d_model
    )
    layer_types = build_layer_types(stats["n_layers"])
    artifact_paths = save_artifacts(
        sd_new,
        stats,
        target_d_model,
        layer_types,
        operations,
        new_layer_info,
        widen_samples,
        metadata,
    )
    duration = time.time() - started_at

    op_counts: Dict[str, int] = {}
    for op_name, op_keys in operations.items():
        op_counts[op_name] = len(set(op_keys))

    summary = [
        "--- Stage-0 inference ---",
        f"n_layers:   {stats['n_layers']}",
        f"d_model:    {stats['d_model']}",
        f"vocab_size: {stats['vocab_size']}",
        f"layer schema keys: {len(layer_subkeys)} entries",
        "--- Stage-1 summary ---",
        f"target d_model: {target_d_model}",
        f"added classic layers: {ADD_CLASSIC_LAY}",
        f"added liquid layers:  {ADD_LIQUID_LAY}",
        f"copied tensors:      {op_counts.get('copied', 0)}",
        f"widened tensors:     {op_counts.get('widened', 0)}",
        f"synthesized tensors: {op_counts.get('synthesized', 0)}",
        f"artifacts: {json.dumps(artifact_paths, indent=2)}",
        f"elapsed: {duration:.2f}s",
    ]
    for line in summary:
        print(line)

    outcome: Dict[str, Any] = {
        "state_dict": sd_new,
        "stats": stats,
        "target_d_model": target_d_model,
        "layer_types": layer_types,
        "operations": operations,
        "operation_counts": op_counts,
        "widen_samples": widen_samples,
        "new_layer_info": new_layer_info,
        "artifacts": artifact_paths,
        "metadata": metadata,
    }
    return outcome


# Run the surgery when this cell executes. `results` is None while
# CHECKPOINT_PATH is still a placeholder, otherwise the summary dict returned
# by run_stage1_surgery().
results = run_stage1_surgery()
# Bare expression as the last statement of the cell (presumably so the
# notebook shows the returned value as cell output).
results


In [None]:

        # Cell 7: inspect the synthesized checkpoint
        # Preview of the surgery output held in `results`: the first ten keys of
        # the new state dict, up to three widened-tensor examples and the
        # per-operation counts. The "\n" prefixes below had been split across
        # two source lines (an unterminated string literal); they are restored.
        if isinstance(results, dict) and results.get("state_dict") is not None:
            sd_preview = results["state_dict"]
            all_keys = list(sd_preview.keys())
            print("First 10 keys:")
            for key in all_keys[:10]:
                print(" ", key)
            widen_samples = results.get("widen_samples", [])
            if widen_samples:
                print("\nSample widened tensors:")
                for key, old_shape, new_shape in widen_samples[:3]:
                    print(f" - {key}: {old_shape} -> {new_shape}")
            op_counts = results.get("operation_counts", {})
            print("\nOperation counts:")
            for name in sorted(op_counts.keys()):
                print(f" - {name}: {op_counts[name]}")
        else:
            # `results` is None when the run was skipped.
            print("No results available — set CHECKPOINT_PATH and rerun the surgery cell.")
