Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions code/export/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
133 changes: 133 additions & 0 deletions code/export/checkpoint_to_safetensors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

"""
Convert a .pt checkpoint to a .safetensors file (fp16 or fp32).

Usage:
PYTHONPATH=code python code/export/checkpoint_to_safetensors.py \\
--checkpoint models/PreDecoderModelMemory_v1.0.94.pt \\
--model-id 1 [--fp16]
"""

from __future__ import annotations

import argparse
from pathlib import Path

import torch

from export.safetensors_utils import _build_minimal_cfg, save_safetensors
from model.factory import ModelFactory


def _load_checkpoint_state_dict(checkpoint_path: str, device: str) -> dict:
"""
Load a state dict from a .pt checkpoint, handling multiple saved formats:
- bare state dict (keys are layer names)
- {"model_state_dict": ...}
- {"state_dict": ...}
Also strips the DDP "module." prefix if present.
"""
raw = torch.load(checkpoint_path, map_location=device)

if isinstance(raw, dict):
if "model_state_dict" in raw:
state_dict = raw["model_state_dict"]
elif "state_dict" in raw:
state_dict = raw["state_dict"]
else:
# Assume it is a bare state dict
state_dict = raw
else:
raise ValueError(
f"Unexpected checkpoint format: expected a dict, got {type(raw).__name__}"
)

# Strip DDP "module." prefix if present
fixed = {}
for k, v in state_dict.items():
new_key = k[len("module."):] if k.startswith("module.") else k
fixed[new_key] = v
return fixed


def main():
parser = argparse.ArgumentParser(
description="Convert a .pt pre-decoder checkpoint to SafeTensors format."
)
parser.add_argument(
"--checkpoint",
required=True,
help="Path to the input .pt checkpoint file.",
)
parser.add_argument(
"--model-id",
type=int,
required=True,
help="Public model ID (1..5).",
)
parser.add_argument(
"--fp16",
action="store_true",
default=False,
help="Save model weights in float16 (default: float32).",
)
parser.add_argument(
"--output",
default=None,
help=(
"Output .safetensors file path. "
"If not provided, auto-generates '{stem}_fp16.safetensors' or '{stem}_fp32.safetensors' "
"next to the input checkpoint."
),
)
parser.add_argument(
"--device",
default="cpu",
help="Device to load the checkpoint on (default: cpu).",
)
args = parser.parse_args()

dtype = "fp16" if args.fp16 else "fp32"
checkpoint_path = Path(args.checkpoint)

if not checkpoint_path.exists():
raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

# Determine output path
if args.output:
output_path = Path(args.output)
else:
output_path = checkpoint_path.parent / f"{checkpoint_path.stem}_{dtype}.safetensors"

print(f"Loading checkpoint: {checkpoint_path}")
state_dict = _load_checkpoint_state_dict(str(checkpoint_path), args.device)

print(f"Building model architecture for model_id={args.model_id} ...")
cfg = _build_minimal_cfg(args.model_id)
model = ModelFactory.create_model(cfg).to(args.device)

model.load_state_dict(state_dict, strict=True)
print(f" Parameters: {sum(p.numel() for p in model.parameters()):,}")

if dtype == "fp16":
model = model.half()
print(" Converted to float16")

print(f"Saving to: {output_path}")
save_safetensors(model, str(output_path), model_id=args.model_id, dtype=dtype)

size_mb = output_path.stat().st_size / (1024 ** 2)
print(f"Done. File size: {size_mb:.2f} MB")


if __name__ == "__main__":
main()
122 changes: 122 additions & 0 deletions code/export/safetensors_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

"""
SafeTensors save/load utilities for fp16/fp32 pre-decoder models.

No quantization or ModelOpt dependencies.
"""

from __future__ import annotations

import json
from pathlib import Path
from typing import Optional

import torch
from safetensors import safe_open
from safetensors.torch import save_file, load_file

from model.registry import get_model_spec
from model.factory import ModelFactory
from workflows.config_validator import apply_public_defaults_and_model
from omegaconf import OmegaConf


def _build_minimal_cfg(model_id: int):
"""Build a minimal inference config for model_id without a full Hydra setup."""
spec = get_model_spec(model_id)
cfg = OmegaConf.create({
"model_id": model_id,
"distance": spec.receptive_field,
"n_rounds": spec.receptive_field,
"data": {"code_rotation": "XV"},
})
return apply_public_defaults_and_model(cfg, spec)


def save_safetensors(
model: torch.nn.Module,
path: str,
model_id: int,
dtype: str = "fp32",
extra_metadata: Optional[dict] = None,
) -> None:
"""
Save a pre-decoder model to a SafeTensors file.

Args:
model: The model to save (should already be on cpu or target device).
path: Output file path (e.g. "model_fp32.safetensors").
model_id: Public model ID (1..5).
dtype: "fp32" or "fp16".
extra_metadata: Optional dict of additional string metadata to embed.
"""
if dtype not in ("fp32", "fp16"):
raise ValueError(f"dtype must be 'fp32' or 'fp16', got: {dtype!r}")

spec = get_model_spec(model_id)

metadata = {
"model_id": str(model_id),
"quant_format": dtype,
"model_version": spec.model_version,
"receptive_field": str(spec.receptive_field),
"num_filters": json.dumps(list(spec.num_filters)),
"kernel_size": json.dumps(list(spec.kernel_size)),
}
if extra_metadata:
for k, v in extra_metadata.items():
metadata[str(k)] = str(v)

save_file(model.state_dict(), path, metadata=metadata)


def load_safetensors(
safetensors_path: str,
model_id: Optional[int] = None,
device: str = "cuda",
):
"""
Load a pre-decoder model from a SafeTensors file.

Args:
safetensors_path: Path to the .safetensors file.
model_id: If provided, overrides the model_id stored in metadata.
device: Target device string (e.g. "cuda", "cpu", "cuda:0").

Returns:
(model, metadata) where model is the loaded nn.Module and metadata is a dict.
"""
# Read metadata without loading tensors
with safe_open(safetensors_path, framework="pt", device="cpu") as f:
metadata = dict(f.metadata())

# Resolve model_id
if model_id is None:
if "model_id" not in metadata:
raise ValueError(
f"SafeTensors file has no 'model_id' in metadata and model_id was not provided: "
f"{safetensors_path}"
)
model_id = int(metadata["model_id"])
else:
model_id = int(model_id)

cfg = _build_minimal_cfg(model_id)
model = ModelFactory.create_model(cfg).to(device)

state_dict = load_file(safetensors_path, device=device)
model.load_state_dict(state_dict, strict=True)

if metadata.get("quant_format") == "fp16":
model = model.half()

return model, metadata
1 change: 1 addition & 0 deletions code/requirements_public_inference.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,4 @@ torch
stim
pymatching
matplotlib
safetensors>=0.4.0
22 changes: 22 additions & 0 deletions code/workflows/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,28 @@ def _load_model(cfg, dist):
print(f"🚀 Loading model for task: {cfg.workflow.task}")

_ensure_inference_io_channels(cfg)

# SafeTensors path: load fp16/fp32 model from SafeTensors file
safetensors_path = os.environ.get("PREDECODER_SAFETENSORS_CHECKPOINT", "").strip()
if safetensors_path:
from export.safetensors_utils import load_safetensors
if dist.rank == 0:
print(f"Loading model from SafeTensors: {safetensors_path}")
model, metadata = load_safetensors(
safetensors_path,
model_id=getattr(cfg, "model_id", None),
device=str(dist.device),
)
model = torch.compile(model, disable=True)
if dist.rank == 0:
dtype = metadata.get("quant_format", "fp32")
param_count = sum(p.numel() for p in model.parameters())
print(f" dtype: {dtype}")
print(f" Model parameters: {param_count:,}")
if metadata.get("quant_format") == "fp16":
cfg.enable_fp16 = True
return model

model = ModelFactory.create_model(cfg).to(dist.device)

if dist.rank == 0:
Expand Down