In [7]:
import argparse
import json
import os
from pathlib import Path
from typing import Any, Dict

import torch
from huggingface_hub import HfApi, HfFolder, create_repo, upload_folder
from safetensors.torch import (
    save_file as save_safetensors,
    load_file as load_safetensors,
)
from safetensors.torch import save_model, load_file

In [6]:

# Top-level checkpoint entries that hold training-only state (optimizer,
# LR scheduler, callback, loop and AMP scaler state). "hyper_parameters" is
# intentionally absent so it can optionally be kept.
# NOTE(review): not referenced by the export code in this notebook, which
# whitelists "state_dict" (and optionally "hyper_parameters") instead.
TRAINING_KEYS: set[str] = set(
    (
        "optimizer_states",
        "lr_schedulers",
        "callbacks",
        "loops",
        "amp_scaler",
    )
)

def _storage_ptr(tensor: torch.Tensor) -> int:
    """Return the data pointer of the storage backing ``tensor``."""
    # Tensor.untyped_storage() (torch >= 2.0) avoids the TypedStorage
    # deprecation warning that Tensor.storage() emits; fall back for old torch.
    if hasattr(tensor, "untyped_storage"):
        return tensor.untyped_storage().data_ptr()
    return tensor.storage().data_ptr()


def _dedupe_state_dict(sd: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """Return a copy of ``sd`` that safetensors can serialize.

    safetensors rejects entries that share storage (e.g. tied weights) and
    entries that are not contiguous. For each storage, the first entry seen is
    kept (made contiguous if necessary). Every later entry that aliases an
    already-seen storage is replaced by a contiguous clone.

    Parameters
    ----------
    sd:
        Mapping of parameter/buffer names to tensors. It is not modified.

    Returns
    -------
    Dict[str, torch.Tensor]
        A new dict with the same keys in the same order and equal values.
    """
    seen: dict[int, str] = {}
    deduped: dict[str, torch.Tensor] = {}

    for name, tensor in sd.items():
        ptr = _storage_ptr(tensor)
        if ptr in seen:
            # Plain clone() preserves strides of dense views such as w.t(),
            # which would still be non-contiguous; force a dense row-major copy.
            deduped[name] = tensor.clone(memory_format=torch.contiguous_format)
            print(f" Cloned shared tensor '{name}' (alias of '{seen[ptr]}')")
        else:
            # contiguous() returns the tensor itself when it is already contiguous.
            deduped[name] = tensor.contiguous()
            seen[ptr] = name
    return deduped

def _tensors_identical(a: torch.Tensor, b: torch.Tensor) -> bool:
    """Return True if ``a`` and ``b`` have equal shape, dtype and element values.

    Unlike a bare ``torch.equal``, NaN entries at the same positions are treated
    as equal. A tensor containing NaN is then not reported as differing from
    an exact copy of itself.
    """
    if a.shape != b.shape or a.dtype != b.dtype:
        return False
    a, b = a.cpu(), b.cpu()
    if torch.equal(a, b):
        return True
    # Only floating/complex tensors can hold NaN; any other inequality is real.
    if not (a.is_floating_point() or a.is_complex()):
        return False
    a_nan = torch.isnan(a)
    if not torch.equal(a_nan, torch.isnan(b)):
        return False
    return torch.equal(a[~a_nan], b[~a_nan])


def _validate_state_dict(original_sd: Dict[str, torch.Tensor], saved_file: Path) -> None:
    """Check that ``saved_file`` holds exactly the tensors in ``original_sd``.

    The file is reloaded with safetensors, and each tensor is compared by
    shape, dtype and element values. NaNs at matching positions count as
    equal. The comparison is value-based, not a raw byte comparison (e.g.
    ``-0.0`` and ``0.0`` compare equal).

    Raises
    ------
    ValueError
        If a key is missing/extra or any tensor differs in shape, dtype or value.
    """
    loaded_sd: Dict[str, torch.Tensor] = load_safetensors(str(saved_file))  # type: ignore[arg-type]

    if original_sd.keys() != loaded_sd.keys():
        missing = original_sd.keys() - loaded_sd.keys()
        extra = loaded_sd.keys() - original_sd.keys()
        raise ValueError(
            f"Key mismatch between in-memory and saved weights. Missing: {missing}, Extra: {extra}"
        )

    for key, original in original_sd.items():
        if not _tensors_identical(original, loaded_sd[key]):
            raise ValueError(f"Tensor mismatch for key '{key}'")

    print("Validation passed: saved weights are byte-identical to the in-memory state_dict")


def _strip_checkpoint(
    ckpt_path: Path,
    output_dir: Path,
    keep_hparams: bool = True,
    validate: bool = True,
) -> Path:
    """Export the weights of a checkpoint as a standalone ``.safetensors`` file.

    Only ``ckpt["state_dict"]`` is written to the safetensors file. Every other
    top-level checkpoint entry is dropped. Tensors that share storage are
    de-aliased before saving.

    Parameters
    ----------
    ckpt_path:
        Checkpoint file to read (``str`` is also accepted).
    output_dir:
        Directory to write into; created if missing (``str`` is also accepted).
    keep_hparams:
        If true and the checkpoint has a ``"hyper_parameters"`` entry, also
        write it to ``<output_dir>/hparams.json``. Values that ``json`` cannot
        encode natively are written as their ``str()``. The file name does not
        depend on the checkpoint, so exports into the same directory overwrite
        each other's ``hparams.json``.
    validate:
        If true, reload the written file and compare it with the in-memory
        weights.

    Returns
    -------
    Path
        ``<output_dir>/<checkpoint stem>_minimal.safetensors``.

    Raises
    ------
    KeyError
        If the checkpoint has no ``"state_dict"`` entry.
    ValueError
        If validation finds a key or tensor mismatch.
    """
    ckpt_path = Path(ckpt_path)
    output_dir = Path(output_dir)

    # NOTE(review): torch.load unpickles the file, so only use it on trusted
    # checkpoints. On torch >= 2.6 the default weights_only=True may reject
    # checkpoints whose hyper_parameters hold arbitrary Python objects —
    # confirm against the installed torch version.
    ckpt: Dict[str, Any] = torch.load(ckpt_path, map_location="cpu")

    if "state_dict" not in ckpt:
        raise KeyError(
            f"Checkpoint {ckpt_path} has no 'state_dict' entry; found keys: {list(ckpt)}"
        )

    state_dict = _dedupe_state_dict(ckpt["state_dict"])

    output_dir.mkdir(parents=True, exist_ok=True)
    target = output_dir / f"{ckpt_path.stem}_minimal.safetensors"
    save_safetensors(state_dict, str(target))

    if validate:
        _validate_state_dict(state_dict, target)

    if keep_hparams and "hyper_parameters" in ckpt:
        # Serialize before touching the file so an unencodable value cannot
        # leave a truncated hparams.json behind.
        payload = json.dumps(ckpt["hyper_parameters"], indent=2, default=str)
        (output_dir / "hparams.json").write_text(payload, encoding="utf-8")

    return target

In [None]:
# Root shared by the training run's checkpoints and the exported weights.
STORAGE_ROOT = Path("/projects/prjs1134/data/projects/biodt/storage")

INPUT_CKPT = (
    STORAGE_ROOT
    / "runs_storage"
    / "19-01-40"
    / "checkpoints"
    / "epoch=419-val_loss=0.01776.ckpt"
)
OUTPUT_DIR = STORAGE_ROOT / "weights" / "cleaned"

# Export weights only (no hparams.json); validation runs with its default (True).
cleaned = _strip_checkpoint(
    ckpt_path=INPUT_CKPT,
    output_dir=OUTPUT_DIR,
    keep_hparams=False,
)
print(f"Saved minimal weights: {cleaned}")

In [None]:
# Reload the exported weights via the path returned by the export cell rather
# than repeating a hard-coded location that can drift out of sync with
# INPUT_CKPT / OUTPUT_DIR and silently load a stale file.
cleaned_w = load_file(str(cleaned))