In [None]:
#@title üß∞ Pre-convert setup (installs + folders)
import os, sys, platform, torch

print("Python:", sys.version)
print("Platform:", platform.platform())
print("Torch:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))
    print("Compute capability:", torch.cuda.get_device_capability(0))
    print("CUDA runtime:", torch.version.cuda)

!pip -q install -U safetensors tqdm requests==2.32.4 huggingface_hub hf_transfer
!pip -q install -U comfy-kitchen

# optional: faster Hugging Face transfers
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
print("HF_HUB_ENABLE_HF_TRANSFER=1")

MODEL_DIR = "/content/models"
os.makedirs(MODEL_DIR, exist_ok=True)
print("MODEL_DIR:", MODEL_DIR)

In [None]:
#@title üì• Download BF16 model from Hugging Face
from huggingface_hub import hf_hub_download

REPO_ID = "black-forest-labs/FLUX.2-klein-4B" # or black-forest-labs/FLUX.2-klein-4B | artokun/blzib
FILENAME = "flux-2-klein-4b.safetensors" # or flux-2-klein-4b.safetensors | blzib.safetensors

input_path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)
print("Downloaded to:", input_path)
# After hf_hub_download(...)
downloaded_path = input_path  # so key-check + convert cells that expect downloaded_path still work
print("downloaded_path set to:", downloaded_path)


In [None]:
#@title üì• Download from Civitai (modelId + modelVersionId) slower
import os, re, requests
from tqdm.auto import tqdm

# ===== YOU SET THESE =====
CIVITAI_MODEL_ID = ""          # optional (info lookup only; not used by the download below)
CIVITAI_VERSION_ID = ""        # REQUIRED for a Civitai download: the modelVersionId
# optional; needed for models requiring login. Leave the literal empty to read the
# CIVITAI_API_TOKEN environment variable instead, so the key never has to be pasted
# into (and shared with) the notebook.
CIVITAI_API_TOKEN = "" or os.environ.get("CIVITAI_API_TOKEN", "")
FILENAME_OVERRIDE = ""         # optional; otherwise the server-provided filename is used
# =========================

def _filename_from_cd(cd: str):
    if not cd:
        return None
    m = re.search(r'filename\*?=(?:UTF-8\'\')?"?([^\";]+)"?', cd, flags=re.IGNORECASE)
    return m.group(1) if m else None

def civitai_download(model_version_id: str, out_dir: str, token: str = "", filename_override: str = "") -> str:
    """Stream one Civitai model version into ``out_dir`` and return the saved path.

    The filename is ``filename_override`` if given, else the server's Content-Disposition
    name (directory parts stripped), else ``civitai_<id>.safetensors``. Data is written to
    ``<name>.part`` and renamed only after the whole body has arrived, so an interrupted
    download never leaves a truncated file under the final name.

    Raises:
        requests.RequestException: on HTTP/network errors. When a token is set, the token
            is redacted from the exception message (requests embeds the URL in it).
    """
    from urllib.parse import quote

    url = f"https://civitai.com/api/download/models/{model_version_id}"
    # token via query string is explicitly supported; `params` lets requests URL-encode it
    params = {"token": token} if token else None
    os.makedirs(out_dir, exist_ok=True)

    try:
        with requests.get(url, params=params, stream=True, allow_redirects=True, timeout=(10, 120)) as r:
            r.raise_for_status()
            total = int(r.headers.get("Content-Length", "0") or "0")
            cd = r.headers.get("Content-Disposition", "")

            server_name = _filename_from_cd(cd)
            if server_name:
                # Never let a server-supplied name escape out_dir.
                server_name = os.path.basename(server_name.replace("\\", "/"))
            fname = filename_override.strip() or server_name or f"civitai_{model_version_id}.safetensors"
            out_path = os.path.join(out_dir, fname)
            part_path = out_path + ".part"

            try:
                with open(part_path, "wb") as f, tqdm(
                    total=total if total > 0 else None, unit="B", unit_scale=True, desc=f"Downloading {fname}"
                ) as pbar:
                    for chunk in r.iter_content(chunk_size=1024 * 1024):
                        if chunk:
                            f.write(chunk)
                            pbar.update(len(chunk))
                os.replace(part_path, out_path)
            except BaseException:
                # Drop the incomplete file (also on Ctrl-C) and re-raise.
                if os.path.exists(part_path):
                    os.remove(part_path)
                raise
    except requests.RequestException as e:
        if not token:
            raise
        msg = str(e)
        for secret in (token, quote(token, safe="")):
            msg = msg.replace(secret, "***")
        # `from None` hides the original (unredacted) exception from the traceback.
        raise type(e)(msg, response=getattr(e, "response", None), request=getattr(e, "request", None)) from None

    return out_path

# This cell is an alternative to the Hugging Face download. With no version id set it
# keeps an existing `downloaded_path` instead of aborting, so "Run all" still reaches
# the convert cell; with neither source available it fails as before.
if CIVITAI_VERSION_ID.strip():
    downloaded_path = civitai_download(CIVITAI_VERSION_ID.strip(), MODEL_DIR, CIVITAI_API_TOKEN.strip(), FILENAME_OVERRIDE.strip())
    print("‚úÖ Downloaded:", downloaded_path)
elif "downloaded_path" in globals():
    print("CIVITAI_VERSION_ID not set; skipping Civitai download and keeping downloaded_path:", downloaded_path)
else:
    raise AssertionError("Set CIVITAI_VERSION_ID (modelVersionId).")

In [None]:
#@title üßæ Key check: list safetensors keys (prints + saves full list) For adding more support
import os, safetensors

assert downloaded_path.endswith(".safetensors"), f"Not a .safetensors file: {downloaded_path}"

with safetensors.safe_open(downloaded_path, framework="pt") as f:
    keys = list(f.keys())
    meta = f.metadata() or {}

print("‚úÖ Key count:", len(keys))
print("‚úÖ First 120 keys:")
for k in keys[:120]:
    print(k)

keys_txt = "/content/model_keys.txt"
with open(keys_txt, "w", encoding="utf-8") as w:
    for k in keys:
        w.write(k + "\n")

print("Saved full key list to:", keys_txt)
print("Metadata keys:", list(meta.keys())[:40])
print("Has 'model.diffusion_model.' prefix?", any("model.diffusion_model." in k for k in keys))

In [None]:
#@title üß± Write converter module to disk (convert_nvfp4.py)
%%writefile convert_nvfp4.py
"""
convert_nvfp4.py

Supports:
- Flux.2-Klein-9b / Flux.2-Klein-4b  (txt_attn_mode switching)
- Z-Image-Turbo / Z-Image-Base

FP8 formats:
- "e4m3fn" -> float8_e4m3fn
- "e5m2"   -> float8_e5m2

Options that apply to all model types:
- blacklisted_2d_mode:
    * "bf16": keep blacklist-matched tensors BF16
    * "fp8" : if blacklist-matched tensor is a 2D ".weight", store it as FP8 (chosen fp8_format)
- nvfp4_fail_fallback:
    * "bf16": NVFP4 failure -> BF16
    * "fp8" : NVFP4 failure -> FP8 (chosen fp8_format), else BF16 if FP8 also fails

Flux Klein txt_attn_mode:
- "nvfp4" : try NVFP4 then fallback via nvfp4_fail_fallback
- "bf16"  : keep BF16
- "fp8"   : store txt_attn 2D weights as FP8 (chosen fp8_format)

FP8 checkpoint fields:
- store {base}.weight_scale and {base}.input_scale as float32 scalars.
"""

from __future__ import annotations

import os
import json
from collections import OrderedDict
from typing import Dict, Any, Tuple, List

import torch
import safetensors
import safetensors.torch
from tqdm.auto import tqdm

try:
    import comfy_kitchen as ck
    from comfy_kitchen.tensor import TensorCoreNVFP4Layout
except Exception as e:
    ck = None
    TensorCoreNVFP4Layout = None
    _IMPORT_ERR = e

SUPPORTED_TYPES = ("Flux.2-Klein-9b", "Flux.2-Klein-4b", "Z-Image-Turbo", "Z-Image-Base")

_FP8_DTYPE = {
    "e4m3fn": torch.float8_e4m3fn,
    "e5m2": torch.float8_e5m2,
}
_FP8_META_FORMAT = {
    "e4m3fn": "float8_e4m3fn",
    "e5m2": "float8_e5m2",
}

def _sm_str() -> str:
    if not torch.cuda.is_available():
        return "CPU"
    major, minor = torch.cuda.get_device_capability()
    return f"SM{major}{minor}"

def _profile(model_type: str, match_official_flux9b: bool = True) -> Tuple[List[str], List[str]]:
    """
    Returns (BLACKLIST, EXTRA_BF16)
    EXTRA_BF16 always stays BF16.
    """
    if model_type == "Z-Image-Base":
        blacklist = [
            "attention.out", "layers.0.", "layers.29.", "adaLN_modulation", "norm",
            "final_layer", "cap_embedder", "x_embedder", "noise_refiner", "context_refiner", "t_embedder"
        ]
        return blacklist, []

    if model_type == "Z-Image-Turbo":
        blacklist = ["cap_embedder", "x_embedder", "noise_refiner", "context_refiner", "t_embedder", "final_layer"]
        return blacklist, []

    if model_type in ("Flux.2-Klein-9b", "Flux.2-Klein-4b"):
        blacklist = [
            "bias",
            "img_in", "txt_in", "time_in", "vector_in", "guidance_in",
            "final_layer", "class_embedding",
            "single_stream_modulation", "double_stream_modulation_img", "double_stream_modulation_txt",
        ]
        extra_bf16 = []
        if match_official_flux9b and model_type == "Flux.2-Klein-9b":
            extra_bf16 = ["double_blocks.0.img_attn.qkv", "double_blocks.0.txt_attn.qkv"]
        return blacklist, extra_bf16

    raise ValueError(f"Unsupported model_type: {model_type}")

def _is_2d_weight(k: str, v: torch.Tensor) -> bool:
    return (v.ndim == 2) and k.endswith(".weight")

def _is_txt_attn_weight(k: str, v: torch.Tensor) -> bool:
    return ("txt_attn" in k) and _is_2d_weight(k, v)

def _base_keys(k: str) -> Tuple[str, str]:
    base_k_file = k[:-len(".weight")]
    if "model.diffusion_model." in base_k_file:
        base_k_meta = base_k_file.split("model.diffusion_model.")[-1]
    else:
        base_k_meta = base_k_file
    return base_k_file, base_k_meta

def _fp8_scale(t: torch.Tensor, fp8_dtype: torch.dtype,
              strategy: str = "absmax",
              percentile: float = 99.9,
              slack: float = 1.05) -> torch.Tensor:
    max_fp8 = torch.finfo(fp8_dtype).max
    if strategy == "absmax":
        a = torch.amax(t.abs())
    elif strategy == "absmax_slack":
        a = torch.amax(t.abs()) * float(slack)
    elif strategy == "percentile":
        a = torch.quantile(t.abs().flatten().float(), float(percentile) / 100.0)
    else:
        raise ValueError("fp8_scale_strategy must be absmax | absmax_slack | percentile")
    return (a / max_fp8).clamp(min=1e-12).float()

def _fp8_pack(new_sd: Dict[str, torch.Tensor], quant_map: Dict[str, Any],
              k: str, base_k_file: str, base_k_meta: str,
              v_tensor: torch.Tensor,
              fp8_dtype: torch.dtype, fp8_meta_fmt: str,
              fp8_scale_strategy: str, fp8_percentile: float, fp8_slack: float) -> bool:
    """Store ``v_tensor`` as per-tensor FP8 under key ``k``.

    On success writes three entries to ``new_sd`` — ``k`` (FP8 data, CPU),
    ``{base_k_file}.weight_scale`` and ``{base_k_file}.input_scale`` (float32 scalars,
    the latter fixed at 1.0) — and records ``{"format": fp8_meta_fmt}`` for
    ``base_k_meta`` in ``quant_map["layers"]``.

    Returns:
        True on success. On any failure nothing is written (all outputs are produced
        before the first write), a RuntimeWarning names the tensor and error, and False
        is returned so the caller can fall back.
    """
    import warnings

    try:
        layers = quant_map["layers"]
        scale = _fp8_scale(v_tensor, fp8_dtype, strategy=fp8_scale_strategy, percentile=fp8_percentile, slack=fp8_slack)
        q = ck.quantize_per_tensor_fp8(v_tensor, scale, fp8_dtype)
        staged = {
            k: q.cpu(),
            f"{base_k_file}.weight_scale": scale.float().cpu(),               # float32 scalar
            f"{base_k_file}.input_scale": torch.ones((), dtype=torch.float32),  # float32 scalar default
        }
    except Exception as e:
        # Best-effort by design (callers count the failure and fall back), but say why.
        warnings.warn(f"FP8 quantization failed for {k!r}: {e!r}", RuntimeWarning, stacklevel=2)
        return False

    # Commit only once everything exists, so a failure never leaves half-written entries.
    new_sd.update(staged)
    layers[base_k_meta] = {"format": fp8_meta_fmt}
    return True

def _nvfp4_pack(new_sd: Dict[str, torch.Tensor], quant_map: Dict[str, Any],
                base_k_file: str, base_k_meta: str,
                v_tensor: torch.Tensor) -> None:
    """Quantize ``v_tensor`` with comfy-kitchen's TensorCoreNVFP4Layout and store it.

    Each tensor returned by ``state_dict_tensors`` is saved (on CPU) as
    ``{base_k_file}.weight{suffix}``, and ``{"format": "nvfp4"}`` is recorded for
    ``base_k_meta`` in ``quant_map["layers"]``.

    Raises:
        Whatever comfy-kitchen raises. All outputs are produced before anything is
        written, so on failure ``new_sd``/``quant_map`` are unchanged and the caller's
        fallback starts from a clean slate.
    """
    qdata, params = TensorCoreNVFP4Layout.quantize(v_tensor)
    tensors = TensorCoreNVFP4Layout.state_dict_tensors(qdata, params)
    staged = {f"{base_k_file}.weight{suffix}": t.cpu() for suffix, t in tensors.items()}
    layers = quant_map["layers"]
    new_sd.update(staged)
    layers[base_k_meta] = {"format": "nvfp4"}

def convert_to_nvfp4(
    input_path: str,
    output_path: str,
    model_type: str,
    device: str = "cuda",
    match_official: bool = True,

    # Flux Klein only:
    txt_attn_mode: str = "nvfp4",            # "nvfp4" | "bf16" | "fp8"

    # FP8 (used for txt_attn fp8, blacklisted fp8, and nvfp4 fallback fp8):
    fp8_format: str = "e4m3fn",              # "e4m3fn" | "e5m2"
    fp8_scale_strategy: str = "absmax",      # absmax | absmax_slack | percentile
    fp8_slack: float = 1.05,
    fp8_percentile: float = 99.9,

    # Apply to all model types:
    blacklisted_2d_mode: str = "bf16",       # "bf16" | "fp8"
    nvfp4_fail_fallback: str = "bf16",       # "bf16" | "fp8"

    verbose: bool = True,
) -> Dict[str, Any]:
    """Quantize a safetensors checkpoint to NVFP4 and write it to ``output_path``.

    Each tensor is routed by the first matching rule:
      1. Non-floating-point tensors are copied through unchanged.
      2. EXTRA_BF16 substring match -> BF16.
      3. Blacklist substring match -> FP8 if ``blacklisted_2d_mode == "fp8"`` and the tensor
         is a 2-D ``.weight``, otherwise BF16.
      4. Flux Klein 2-D ``txt_attn`` weight -> ``txt_attn_mode``.
      5. Any other 2-D ``.weight`` -> NVFP4.
      6. Everything else -> BF16.
    A failed NVFP4 pack falls back per ``nvfp4_fail_fallback``; a failed FP8 pack falls
    back to BF16.

    Args:
        input_path: Source ``.safetensors`` file.
        output_path: Destination file. Saved to ``<output_path>.partial`` first and
            renamed only after a complete write.
        model_type: One of ``SUPPORTED_TYPES``; selects the blacklist profile.
        device: "cuda" or "cpu"; where quantization runs.
        match_official: For Flux.2-Klein-9b, also keep double_blocks.0 img/txt qkv in BF16.
        txt_attn_mode, fp8_format, fp8_scale_strategy, fp8_slack, fp8_percentile,
        blacklisted_2d_mode, nvfp4_fail_fallback: see the module docstring.
        verbose: Print the configuration and the final summary.

    Returns:
        Summary dict: output size plus per-format tensor counters.

    Raises:
        RuntimeError: comfy-kitchen failed to import, or CUDA requested but unavailable.
        ValueError: an option has an unsupported value.
    """
    if ck is None or TensorCoreNVFP4Layout is None:
        raise RuntimeError(f"comfy-kitchen import failed: {_IMPORT_ERR}\nInstall: pip install comfy-kitchen")

    if model_type not in SUPPORTED_TYPES:
        raise ValueError(f"model_type must be one of: {SUPPORTED_TYPES}")

    if device not in ("cuda", "cpu"):
        raise ValueError("device must be 'cuda' or 'cpu'")
    if device == "cuda" and not torch.cuda.is_available():
        raise RuntimeError("CUDA requested but torch.cuda.is_available() is False")

    if txt_attn_mode not in ("nvfp4", "bf16", "fp8"):
        raise ValueError("txt_attn_mode must be nvfp4 | bf16 | fp8")
    if fp8_format not in _FP8_DTYPE:
        raise ValueError("fp8_format must be e4m3fn | e5m2")

    if blacklisted_2d_mode not in ("bf16", "fp8"):
        raise ValueError("blacklisted_2d_mode must be bf16 | fp8")
    if nvfp4_fail_fallback not in ("bf16", "fp8"):
        raise ValueError("nvfp4_fail_fallback must be bf16 | fp8")

    fp8_dtype = _FP8_DTYPE[fp8_format]
    fp8_meta  = _FP8_META_FORMAT[fp8_format]
    is_flux = model_type.startswith("Flux.2-Klein")

    blacklist, extra_bf16 = _profile(model_type, match_official_flux9b=match_official)

    # stream read keys + metadata
    with safetensors.safe_open(input_path, framework="pt") as f:
        keys = list(f.keys())
        orig_meta = f.metadata() or {}

    quant_map = {"format_version": "1.0", "layers": {}}
    new_sd: Dict[str, torch.Tensor] = {}
    counts = {"nvfp4_ok": 0, "nvfp4_fail": 0, "fp8_ok": 0, "fp8_fail": 0, "bf16_keep": 0, "passthrough": 0}

    if verbose:
        print(f"üîß Input:  {input_path}")
        print(f"üíæ Output: {output_path}")
        print(f"üß† Model:  {model_type}")
        print(f"üß† Device: {device} ({_sm_str()})")
        print(f"üéØ match_official={match_official} | extra_bf16={extra_bf16}")
        print(f"üßä fp8_format={fp8_format} | blacklisted_2d_mode={blacklisted_2d_mode} | nvfp4_fail_fallback={nvfp4_fail_fallback}")
        if is_flux:
            print(f"üéõÔ∏è txt_attn_mode={txt_attn_mode}")
        print(f"üì¶ Keys: {len(keys)}")

    def route(k: str, v: torch.Tensor) -> str:
        """Pick "bf16" | "fp8" | "nvfp4" for one tensor (rules 1-6 of the docstring)."""
        if not v.is_floating_point():
            return "bf16"   # handled as pass-through by keep_unquantized
        if any(s in k for s in extra_bf16):
            return "bf16"
        if any(s in k for s in blacklist):
            return "fp8" if (blacklisted_2d_mode == "fp8" and _is_2d_weight(k, v)) else "bf16"
        if is_flux and _is_txt_attn_weight(k, v):
            return txt_attn_mode
        return "nvfp4" if _is_2d_weight(k, v) else "bf16"

    def keep_unquantized(k: str, v: torch.Tensor) -> None:
        """Store ``v`` as BF16; non-floating tensors are copied as-is (a BF16 cast would corrupt them)."""
        if v.is_floating_point():
            new_sd[k] = v.to(dtype=torch.bfloat16)
            counts["bf16_keep"] += 1
        else:
            new_sd[k] = v
            counts["passthrough"] += 1

    def pack_fp8(k, v, base_k_file, base_k_meta, v_tensor) -> None:
        """FP8-pack one weight, falling back to BF16 on failure."""
        if _fp8_pack(new_sd, quant_map, k, base_k_file, base_k_meta, v_tensor,
                     fp8_dtype, fp8_meta, fp8_scale_strategy, fp8_percentile, fp8_slack):
            counts["fp8_ok"] += 1
        else:
            counts["fp8_fail"] += 1
            keep_unquantized(k, v)

    def pack_nvfp4(k, v, base_k_file, base_k_meta, v_tensor) -> None:
        """NVFP4-pack one weight, falling back per ``nvfp4_fail_fallback`` on failure."""
        try:
            _nvfp4_pack(new_sd, quant_map, base_k_file, base_k_meta, v_tensor)
        except Exception:
            counts["nvfp4_fail"] += 1
            if nvfp4_fail_fallback == "fp8":
                pack_fp8(k, v, base_k_file, base_k_meta, v_tensor)
            else:
                keep_unquantized(k, v)
        else:
            counts["nvfp4_ok"] += 1

    with safetensors.safe_open(input_path, framework="pt") as f:
        for k in tqdm(keys, desc="Converting", unit="tensor"):
            v = f.get_tensor(k)
            mode = route(k, v)
            if mode == "bf16":
                keep_unquantized(k, v)
                continue

            base_k_file, base_k_meta = _base_keys(k)
            v_tensor = v.to(device=device, dtype=torch.bfloat16)
            try:
                if mode == "fp8":
                    pack_fp8(k, v, base_k_file, base_k_meta, v_tensor)
                else:
                    pack_nvfp4(k, v, base_k_file, base_k_meta, v_tensor)
            finally:
                # Drop the (possibly GPU) working copy before loading the next tensor.
                del v_tensor

    # The new quantization map wins over any stale one in the source metadata.
    final_metadata = OrderedDict()
    final_metadata["_quantization_metadata"] = json.dumps(quant_map)
    for mk, mv in orig_meta.items():
        final_metadata.setdefault(mk, mv)

    # Write to a temp name first so a failed/interrupted save never leaves a truncated
    # file that looks like a finished conversion.
    tmp_path = output_path + ".partial"
    try:
        safetensors.torch.save_file(new_sd, tmp_path, metadata=final_metadata)
        os.replace(tmp_path, output_path)
    except BaseException:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise

    total_bytes = os.path.getsize(output_path)
    summary = {
        "model_type": model_type,
        "input_path": input_path,
        "output_path": output_path,
        "output_gb": round(total_bytes / (1024**3), 3),
        "nvfp4_ok": counts["nvfp4_ok"],
        "nvfp4_fail": counts["nvfp4_fail"],
        "fp8_ok": counts["fp8_ok"],
        "fp8_fail": counts["fp8_fail"],
        "bf16_kept": counts["bf16_keep"],
        "passthrough_kept": counts["passthrough"],
        "sm": _sm_str(),
        "fp8_format": fp8_format,
        "blacklisted_2d_mode": blacklisted_2d_mode,
        "nvfp4_fail_fallback": nvfp4_fail_fallback,
    }
    if verbose:
        print("‚úÖ Done:", summary)
    return summary

In [None]:
#@title üç≥ Convert (custom output filename)
import os
from convert_nvfp4 import convert_to_nvfp4

input_path = downloaded_path  # HF or Civitai download
OUT_DIR = "/content/models"
os.makedirs(OUT_DIR, exist_ok=True)

MODEL_TYPE = "Flux.2-Klein-4b"  # "Z-Image-Turbo" | "Z-Image-Base" | "Flux.2-Klein-9b" | "Flux.2-Klein-4b"
OUT_NAME  = "FluxKlein4b_nvfp4"  # change to whatever
output_path = os.path.join(OUT_DIR, f"{OUT_NAME}.safetensors")

# =========================
# COMMON (applies to all)
# =========================
DEVICE = "cuda"
MATCH_OFFICIAL = True

# Pick ONE FP8 format (no chain):
FP8_FORMAT = "e4m3fn"     # "e4m3fn" | "e5m2"

# FP8 scaling knobs:
FP8_SCALE_STRATEGY = "absmax"   # "absmax" | "absmax_slack" | "percentile"
FP8_SLACK = 1.05
FP8_PERCENTILE = 99.9

# NVFP4 failure fallback:
NVFP4_FAIL_FALLBACK = "bf16"    # "bf16" | "fp8"

# =========================
# Z-IMAGE options
# =========================
# What to do with blacklist-matched 2D ".weight" tensors:
ZIMAGE_BLACKLISTED_2D_MODE = "bf16"  # "bf16" | "fp8"

# =========================
# FLUX KLEIN options
# =========================
# txt_attn handling (only used for Flux.2-Klein-*):
FLUX_TXT_ATTN_MODE = "nvfp4"        # "nvfp4" | "bf16" | "fp8"

summary = convert_to_nvfp4(
    input_path=input_path,
    output_path=output_path,
    model_type=MODEL_TYPE,
    device=DEVICE,
    match_official=MATCH_OFFICIAL,

    # Flux Klein only:
    txt_attn_mode=FLUX_TXT_ATTN_MODE,

    # FP8 (used anywhere we choose FP8):
    fp8_format=FP8_FORMAT,
    fp8_scale_strategy=FP8_SCALE_STRATEGY,
    fp8_slack=FP8_SLACK,
    fp8_percentile=FP8_PERCENTILE,

    # New behaviors:
    blacklisted_2d_mode=(ZIMAGE_BLACKLISTED_2D_MODE if MODEL_TYPE.startswith("Z-Image") else "bf16"),
    nvfp4_fail_fallback=NVFP4_FAIL_FALLBACK,

    verbose=True,
)

print("‚úÖ Wrote:", output_path)
summary

In [None]:
#@title üîê Hugging Face login (token)
from getpass import getpass
HF_TOKEN = getpass("Paste your Hugging Face token (write access):").strip()
assert HF_TOKEN, "Token is required"
print("Token received (not printing it).")


In [None]:
#@title üèóÔ∏è Create repo 'quanttesting' and upload NVFP4 model
import os
from huggingface_hub import HfApi

api = HfApi(token=HF_TOKEN)
who = api.whoami()
username = who["name"]

repo_name = "YOUR REPO HERE" # Change
repo_id = f"{username}/{repo_name}"

api.create_repo(repo_id=repo_id, repo_type="model", exist_ok=True)
print("Repo:", repo_id)

# Upload the NVFP4 safetensors (uses LFS automatically for large files)
api.upload_file(
    path_or_fileobj=output_path,
    path_in_repo=os.path.basename(output_path),
    repo_id=repo_id,
    repo_type="model",
    commit_message="Upload NVFP4-converted", # can change
)
print("‚úÖ Upload complete.")
