# COREML CONVERT - this is for taking pretrained QAT models/F32 variants and converting them to CoreML

## 1. Start with model loading and set the ENV paths

In [2]:
# 1.1) Environment & Paths

import os, torch, warnings
from diffusers import UNet2DConditionModel, AutoencoderKL

# EDIT THESE THREE
PATH_RUN = "/datasets/abradshaw/output_QAT/train_marigold_depth_qat"
PATH_QAT_SD = f"{PATH_RUN}/checkpoint/latest/unet_qat_state_dict.pt"  # QAT UNet weights (with 8-ch conv_in)
PATH_PRETRAINED_LOCAL = "/datasets/abradshaw/mgold_ckpts/stable-diffusion-2"  # local base (for VAE + UNet config)

PATH_EXPORT = os.path.join(PATH_RUN, "export_coreml_192")
os.makedirs(PATH_EXPORT, exist_ok=True)

assert os.path.isfile(PATH_QAT_SD), f"Missing QAT state-dict: {PATH_QAT_SD}"
print("Export dir:", PATH_EXPORT)
print("QAT sd:", PATH_QAT_SD)
print("Base (VAE/config):", PATH_PRETRAINED_LOCAL)

warnings.filterwarnings("once")
torch.set_grad_enabled(False)
device = torch.device("cpu")


Export dir: /datasets/abradshaw/output_QAT/train_marigold_depth_qat/export_coreml_192
QAT sd: /datasets/abradshaw/output_QAT/train_marigold_depth_qat/checkpoint/latest/unet_qat_state_dict.pt
Base (VAE/config): /datasets/abradshaw/mgold_ckpts/stable-diffusion-2


## 2. Load UNet config only, force 8 input channels like the trainer

In [3]:
import torch.nn as nn
import torch.nn.functional as F

# --- 2.1 Load UNet config only, force 8 input channels ---
unet_cfg = UNet2DConditionModel.load_config(os.path.join(PATH_PRETRAINED_LOCAL, "unet"))
unet_cfg["in_channels"] = 8  # we trained on [rgb_latent(4) + target_latent(4)]
unet = UNet2DConditionModel.from_config(unet_cfg).to(device)  # empty FP32 skeleton

# --- 2.2 Sanitize activations to match QAT training (GELU->ReLU, SiLU->Hardswish) ---
_BAD2GOOD = {
    nn.GELU: lambda _: nn.ReLU(inplace=False),
    nn.SiLU: lambda _: nn.Hardswish(),
}
_BAD_FUNCS = {F.gelu, F.silu}
def _relu(x): return F.relu(x, inplace=False)

def sanitize_only(module: nn.Module):
    for name, child in list(module.named_children()):
        replaced = False
        for bad_cls, make_good in _BAD2GOOD.items():
            if isinstance(child, bad_cls):
                setattr(module, name, make_good(child))
                replaced = True
                break
        sanitize_only(getattr(module, name) if replaced else child)
    for attr_name, attr_val in vars(module).items():
        if callable(attr_val) and attr_val in _BAD_FUNCS:
            setattr(module, attr_name, _relu)

sanitize_only(unet)

# Disable memory-efficient attention (PyTorch export prefers standard ops)
try:
    unet.disable_xformers_memory_efficient_attention()
except Exception:
    pass
unet.eval()

# --- 2.3 Load QAT state-dict; observers/fake-quant keys are ignored by strict=False ---
sd = torch.load(PATH_QAT_SD, map_location="cpu", weights_only=True)
missing, unexpected = unet.load_state_dict(sd, strict=False)
print(f"[UNet load] missing={len(missing)} unexpected={len(unexpected)}")
if unexpected:
    print("  (expected: observer/fake-quant keys show up here and are ignored)")

# quick forward sanity
with torch.no_grad():
    x = torch.randn(1, 8, 64, 64)
    t = torch.tensor([0])
    cond = torch.randn(1, 77, 1024)
    y = unet(x, t, encoder_hidden_states=cond).sample
    print("UNet forward OK →", tuple(y.shape))


vae = AutoencoderKL.from_pretrained(PATH_PRETRAINED_LOCAL, subfolder="vae").to(device).eval()
print("VAE loaded.")
fixed_embed = torch.randn(1, 77, 1024, dtype=torch.float32)
print("fixed_embed loaded.")

[UNet load] missing=0 unexpected=3087
  (expected: observer/fake-quant keys show up here and are ignored)
UNet forward OK → (1, 4, 64, 64)
VAE loaded.
fixed_embed loaded.


## 3.  Export wrappers (Encoder / UNetStep / Decoder) this way we can pass in F32 inouts ot the model and convert each wrapper separately instead of a monolithic pipeline

In [4]:
# 3.1) Export wrappers (Encoder / UNetStep / Decoder)
import torch.nn as nn

LATENT_SF = 0.18215

class EncoderWrapper(nn.Module):
    def __init__(self, vae: AutoencoderKL):
        super().__init__(); self.vae = vae
    def forward(self, rgb_norm: torch.Tensor) -> torch.Tensor:
        # [-1,1]-normalized RGB → latent(4,H/8,W/8)
        h = self.vae.encoder(rgb_norm)
        mean, logvar = torch.chunk(self.vae.quant_conv(h), 2, dim=1)
        return mean * LATENT_SF

class UNetStepWrapper(nn.Module):
    def __init__(self, unet: UNet2DConditionModel, fixed_embed: torch.Tensor):
        super().__init__(); self.unet = unet
        self.register_buffer("fixed_embed", fixed_embed, persistent=False)
    def forward(self, rgb_latent: torch.Tensor, target_latent: torch.Tensor, t_f32: torch.Tensor) -> torch.Tensor:
        # concat latents; convert float timestep → int64
        x = torch.cat([rgb_latent, target_latent], dim=1)
        t = t_f32.to(torch.int64).reshape(-1)
        cond = self.fixed_embed.expand(x.shape[0], -1, -1)
        return self.unet(x, t, encoder_hidden_states=cond).sample

class DecoderWrapper(nn.Module):
    def __init__(self, vae: AutoencoderKL):
        super().__init__(); self.vae = vae
    def forward(self, depth_latent: torch.Tensor) -> torch.Tensor:
        z = self.vae.post_quant_conv(depth_latent / LATENT_SF)
        stacked = self.vae.decoder(z)
        return stacked.mean(dim=1, keepdim=True)  # [B,1,H,W]

enc = EncoderWrapper(vae).eval()
step = UNetStepWrapper(unet, fixed_embed).eval()
dec  = DecoderWrapper(vae).eval()
print("Wrappers prepared.")


Wrappers prepared.


## 4. Save each torch script, fall back to a trace if there is conditional logic that scirpt deos not support

In [6]:
import os
import torch

# ablations
#B, H, W = 1, 512, 512
#B, H, W = 1, 320, 320
#B, H, W = 1, 256, 256
B, H, W = 1, 256, 192
h, w = H//8, W//8

ex_rgb = torch.randn(B,3,H,W, dtype=torch.float32)
ex_lat = torch.randn(B,4,h,w, dtype=torch.float32)
ex_t   = torch.tensor([0.0], dtype=torch.float32)

def save_ts(m, ex, path):
    try:
        ts = torch.jit.script(m)
    except Exception as e:
        print(f"[TS] script failed for {path}: {e}\n→ tracing instead.")
        ts = torch.jit.trace(m, ex)
    ts.save(path)
    print("[TS] saved:", path)
    return ts

enc_ts  = save_ts(enc,  (ex_rgb,),              os.path.join(PATH_EXPORT, "Encoder.ts"))
step_ts = save_ts(step, (ex_lat,ex_lat,ex_t),   os.path.join(PATH_EXPORT, "UNetStep.ts"))
dec_ts  = save_ts(dec,  (ex_lat,),              os.path.join(PATH_EXPORT, "Decoder.ts"))


[TS] script failed for /datasets/abradshaw/output_QAT/train_marigold_depth_qat/export_coreml_192/Encoder.ts: function definitions aren't supported:
  File "/home/abradshaw/Marigold/venv/marigold/lib/python3.10/site-packages/diffusers/models/autoencoders/autoencoder_kl.py", line 181
        processors = {}
    
        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
        ~~~ <--- HERE
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()

→ tracing instead.
[TS] saved: /datasets/abradshaw/output_QAT/train_marigold_depth_qat/export_coreml_192/Encoder.ts
[TS] script failed for /datasets/abradshaw/output_QAT/train_marigold_depth_qat/export_coreml_192/UNetStep.ts: keyword-arg expansion is not supported:
  File "/home/abradshaw/Marigold/venv/marigold/lib/python3.10/site-packages/diffusers/models/unets/unet_2d_condition.py", line 1173
            cross_attenti

## 5. Convert the torch scripts into CoreML packages straight from pytorch scurpts to corenk (ONNX supporrted later)

In [None]:
# TorchScript → Core ML only (no ONNX fallback)

import coremltools as ct
import numpy as np 

enc_ml = ct.convert(
    enc_ts,
    convert_to="mlprogram",
    inputs=[ct.TensorType(name="rgb_norm", shape=ex_rgb.shape, dtype=np.float16)], # dtype 16 is only supported for iOS 16+
    compute_precision=ct.precision.FLOAT16, # lowest precision supported in coreML execp if you do post traingn qunaitixaiotn which is not a targert precisons supported for int8
    minimum_deployment_target=ct.target.iOS16,
)
enc_ml.save(os.path.join(PATH_EXPORT, "Encoder.mlpackage"))

step_ml = ct.convert(
    step_ts,
    convert_to="mlprogram",
    inputs=[
        ct.TensorType(name="rgb_latent",    shape=ex_lat.shape, dtype=np.float16),
        ct.TensorType(name="target_latent", shape=ex_lat.shape, dtype=np.float16),
        ct.TensorType(name="t_f32",         shape=(1,), dtype=np.float32),
    ],
    compute_precision=ct.precision.FLOAT16,
    minimum_deployment_target=ct.target.iOS16,
)
step_ml.save(os.path.join(PATH_EXPORT, "UNetStep.mlpackage"))

dec_ml = ct.convert(
    dec_ts,
    convert_to="mlprogram",
    inputs=[ct.TensorType(name="depth_latent", shape=ex_lat.shape, dtype=np.float16)],
    compute_precision=ct.precision.FLOAT16,
    minimum_deployment_target=ct.target.iOS16,
)
dec_ml.save(os.path.join(PATH_EXPORT, "Decoder.mlpackage"))


  return _StrictVersion(version)
scikit-learn version 1.7.0 is not supported. Minimum required version: 0.17. Maximum required version: 1.1.2. Disabling scikit-learn conversion API.
Torch version 2.4.1+cu121 has not been tested with coremltools. You may run into unexpected errors. Torch 2.2.0 is the most recent version that has been tested.


Failed to load _MLModelProxy: No module named 'coremltools.libcoremlpython'
  res = mb.const(val=dtype(x.val), name=node.name)
Converting PyTorch Frontend ==> MIL Ops: 100%|█████████▉| 300/301 [00:00<00:00, 4281.72 ops/s]
Running MIL frontend_pytorch pipeline: 100%|██████████| 5/5 [00:00<00:00, 187.53 passes/s]
Running MIL default pipeline: 100%|██████████| 78/78 [00:03<00:00, 24.07 passes/s] 
Running MIL backend_mlprogram pipeline: 100%|██████████| 12/12 [00:00<00:00, 240.03 passes/s]
Converting PyTorch Frontend ==> MIL Ops:   0%|          | 0/2412 [00:00<?, ? ops/s]Saving value type of int64 into a builtin type of int32, might lose precision!
Converting PyTorch Frontend ==> MIL Ops: 100%|█████████▉| 2411/2412 [00:00<00:00, 2994.41 ops/s]
Running MIL frontend_pytorch pipeline: 100%|██████████| 5/5 [00:00<00:00, 32.94 passes/s]
Running MIL default pipeline: 100%|██████████| 78/78 [00:30<00:00,  2.57 passes/s]
Running MIL backend_mlprogram pipeline: 100%|██████████| 12/12 [00:00<00:00, 

## 5. Convert to CoreML with ONNX catch

In [None]:
# Core ML conversion (ML Program, FP16) + ONNX fallback
import coremltools as ct

def to_coreml_from_ts(ts_module, inputs, out_name):
    mlmodel = ct.convert(
        ts_module,
        convert_to="mlprogram",
        inputs=inputs,
        compute_precision=ct.precision.FLOAT16,
        minimum_deployment_target=ct.target.iOS15,
    )
    out_path = os.path.join(PATH_EXPORT, out_name)
    mlmodel.save(out_path)
    print("[CoreML] saved:", out_path)
    return out_path

def to_coreml_from_onnx(onnx_path, inputs, out_name):
    mlmodel = ct.convert(
        onnx_path,
        convert_to="mlprogram",
        inputs=inputs,
        compute_precision=ct.precision.FLOAT16,
        minimum_deployment_target=ct.target.iOS15,
    )
    out_path = os.path.join(PATH_EXPORT, out_name)
    mlmodel.save(out_path)
    print("[CoreML][ONNX] saved:", out_path)
    return out_path

# Encoder
try:
    enc_ml = to_coreml_from_ts(
        enc_ts,
        [ct.TensorType(name="rgb_norm", shape=ex_rgb.shape)],
        "Encoder.mlpackage",
    )
except Exception as e:
    print("[CoreML] Encoder TS failed; fallback to ONNX:", repr(e))
    onnx_path = os.path.join(PATH_EXPORT, "Encoder.onnx")
    torch.onnx.export(enc, (ex_rgb,), onnx_path, opset_version=17,
                      input_names=["rgb_norm"], output_names=["rgb_latent"], dynamic_axes=None)
    enc_ml = to_coreml_from_onnx(
        onnx_path, [ct.TensorType(name="rgb_norm", shape=ex_rgb.shape)], "Encoder.mlpackage"
    )

# UNetStep
try:
    step_ml = to_coreml_from_ts(
        step_ts,
        [
            ct.TensorType(name="rgb_latent",    shape=ex_lat.shape),
            ct.TensorType(name="target_latent", shape=ex_lat.shape),
            ct.TensorType(name="t_f32",         shape=(1,)),
        ],
        "UNetStep.mlpackage",
    )
except Exception as e:
    print("[CoreML] UNetStep TS failed; fallback to ONNX:", repr(e))
    onnx_path = os.path.join(PATH_EXPORT, "UNetStep.onnx")
    torch.onnx.export(step, (ex_lat, ex_lat, ex_t), onnx_path, opset_version=17,
                      input_names=["rgb_latent","target_latent","t_f32"], output_names=["noise_pred"], dynamic_axes=None)
    step_ml = to_coreml_from_onnx(
        onnx_path,
        [
            ct.TensorType(name="rgb_latent",    shape=ex_lat.shape),
            ct.TensorType(name="target_latent", shape=ex_lat.shape),
            ct.TensorType(name="t_f32",         shape=(1,)),
        ],
        "UNetStep.mlpackage",
    )

# Decoder
try:
    dec_ml = to_coreml_from_ts(
        dec_ts,
        [ct.TensorType(name="depth_latent", shape=ex_lat.shape)],
        "Decoder.mlpackage",
    )
except Exception as e:
    print("[CoreML] Decoder TS failed; fallback to ONNX:", repr(e))
    onnx_path = os.path.join(PATH_EXPORT, "Decoder.onnx")
    torch.onnx.export(dec, (ex_lat,), onnx_path, opset_version=17,
                      input_names=["depth_latent"], output_names=["depth"], dynamic_axes=None)
    dec_ml = to_coreml_from_onnx(
        onnx_path, [ct.TensorType(name="depth_latent", shape=ex_lat.shape)], "Decoder.mlpackage"
    )


-----

## Optional: Post training quantization (QAT) is required for CoreML conversion. If you first convert to CoreML without QAT, the CoreML package (not yet supported) will not work.

In [None]:
import os
import coremltools as ct
from coremltools.models.neural_network import quantization_utils
from coremltools.models.utils import load_spec

BASE = "/datasets/abradshaw/output_QAT/train_marigold_depth_qat/export_coreml"
MODELS = ["Encoder", "UNetStep", "Decoder"]

def quantize_pkg(pkg_path: str):
    coreml_dir  = os.path.join(pkg_path, "Data", "com.apple.CoreML")
    spec_path   = os.path.join(coreml_dir, "model.mlmodel")
    weights_dir = os.path.join(coreml_dir, "weights")
    assert os.path.isfile(spec_path), f"missing {spec_path}"
    assert os.path.isdir(weights_dir), f"missing {weights_dir}"

    print(f"\n→ Quantizing weights to int8: {pkg_path}")
    spec = load_spec(spec_path)

    # Quantize weights in the spec (linear int8; no calibration, weights-only)
    qspec = quantization_utils._quantize_spec_weights(  # underscore is fine here
        spec, nbits=8, quantization_mode="linear"
    )

    # Re-wrap with weights_dir so it can save as a full .mlpackage
    ml_q = ct.models.MLModel(qspec, weights_dir=weights_dir)

    out_pkg = pkg_path.replace(".mlpackage", "-INT8.mlpackage")
    ml_q.save(out_pkg)
    print(f"   Saved: {out_pkg}")
    return out_pkg

if __name__ == "__main__":
    for name in MODELS:
        pkg = os.path.join(BASE, f"{name}.mlpackage")
        if os.path.isdir(pkg):
            quantize_pkg(pkg)
        else:
            print(f"!! Missing {pkg}")

    print("\nDone. Open the *-INT8.mlpackage in Xcode → Performance to benchmark.")



→ Quantizing weights to int8: /datasets/abradshaw/output_QAT/train_marigold_depth_qat/export_coreml/Encoder.mlpackage
   Saved: /datasets/abradshaw/output_QAT/train_marigold_depth_qat/export_coreml/Encoder-INT8.mlpackage

→ Quantizing weights to int8: /datasets/abradshaw/output_QAT/train_marigold_depth_qat/export_coreml/UNetStep.mlpackage
   Saved: /datasets/abradshaw/output_QAT/train_marigold_depth_qat/export_coreml/UNetStep-INT8.mlpackage

→ Quantizing weights to int8: /datasets/abradshaw/output_QAT/train_marigold_depth_qat/export_coreml/Decoder.mlpackage
   Saved: /datasets/abradshaw/output_QAT/train_marigold_depth_qat/export_coreml/Decoder-INT8.mlpackage

Done. Open the *-INT8.mlpackage in Xcode → Performance to benchmark.


In [4]:
import coremltools as ct
from coremltools.models.neural_network.quantization_utils import _quantize_spec_weights
from coremltools.models import MLModel

# Step 1: Load model spec directly from .mlmodel inside the .mlpackage
spec = ct.utils.load_spec("/datasets/abradshaw/export_coreml_Full_F16/Decoder.mlpackage/Data/com.apple.CoreML/model.mlmodel")
weights_dir = "/datasets/abradshaw/export_coreml_Full_F16/Decoder.mlpackage/Data/com.apple.CoreML/weights"

# Step 2: Quantize weights using the *private* function (bypasses automatic wrapping)
qspec = _quantize_spec_weights(
    spec,
    nbits=8,
    quantization_mode="linear"
)

# Step 3: Manually wrap with original weights directory
qmodel = MLModel(qspec, weights_dir=weights_dir)

# Step 4: Save as a new quantized model package
qmodel.save("Decoder_fp8.mlpackage")


## Weight precision checker for the optional conversion using PTQ

In [None]:
import os

BASE = "/datasets/abradshaw/output_QAT/train_marigold_depth_qat/export_coreml"
NAMES = ["Encoder","UNetStep","Decoder"]

def weight_size(pkg):
    w = os.path.join(pkg, "Data","com.apple.CoreML","weights","weight.bin")
    return os.path.getsize(w) if os.path.isfile(w) else 0

print("Model         FP16 (MB)   INT8 (MB)   shrink")
for n in NAMES:
    p16 = os.path.join(BASE, f"{n}.mlpackage")
    p8  = os.path.join(BASE, f"{n}-INT8.mlpackage")
    s16 = weight_size(p16) / (1024**2)
    s8  = weight_size(p8)  / (1024**2)
    if s16 and s8:
        r = (1 - s8/s16) * 100
        print(f"{n:<12} {s16:9.2f}   {s8:9.2f}   {r:6.1f}%")
    else:
        print(f"{n:<12} (missing weight.bin)")



=== Encoder.mlpackage ===
weight.bin                                                      65.17 MB
TOTAL weights: 65.17 MB

=== UNetStep.mlpackage ===
weight.bin                                                    1606.60 MB
TOTAL weights: 1,606.60 MB

=== Decoder.mlpackage ===
weight.bin                                                      94.41 MB
TOTAL weights: 94.41 MB


# DEBUG - some helper functions for debugging and sanity checks on the conversion script

## For cell 1.2 for debug checks on the snaitzer and funcitnal ops remval ttaht are not supperotedd in the sciriting

In [None]:
import torch.nn as nn
import torch.nn.functional as F

# --- 2.1 Load UNet config only, force 8 input channels ---
unet_cfg = UNet2DConditionModel.load_config(os.path.join(PATH_PRETRAINED_LOCAL, "unet"))
unet_cfg["in_channels"] = 8  # we trained on [rgb_latent(4) + target_latent(4)]
unet = UNet2DConditionModel.from_config(unet_cfg).to(device)  # empty FP32 skeleton



def count_bad_activations(model):
    n_gelu_mod = n_silu_mod = n_gelu_fun = n_silu_fun = 0
    for m in model.modules():                        # module instances
        if isinstance(m, nn.GELU): n_gelu_mod += 1
        if isinstance(m, nn.SiLU): n_silu_mod += 1
    for m in model.modules():                        # cached functional refs
        for _, attr in vars(m).items():
            if attr is F.gelu: n_gelu_fun += 1
            if attr is F.silu: n_silu_fun += 1
    return (n_gelu_mod + n_gelu_fun, n_silu_mod + n_silu_fun)

gelu_cnt, silu_cnt = count_bad_activations(unet)
print(f"[before sanitize] GELU-like: {gelu_cnt} | SiLU-like: {silu_cnt}")
# (optional) warn if nothing to replace
if gelu_cnt == 0 and silu_cnt == 0:
    print(" No GELU/SiLU found; sanitize_only may be redundant.")


# --- 2.2 Sanitize activations to match QAT training (GELU->ReLU, SiLU->Hardswish) ---
_BAD2GOOD = {
    nn.GELU: lambda _: nn.ReLU(inplace=False),
    nn.SiLU: lambda _: nn.Hardswish(),
}
_BAD_FUNCS = {F.gelu, F.silu}
def _relu(x): return F.relu(x, inplace=False)

def sanitize_only(module: nn.Module):
    for name, child in list(module.named_children()):
        replaced = False
        for bad_cls, make_good in _BAD2GOOD.items():
            if isinstance(child, bad_cls):
                setattr(module, name, make_good(child))
                replaced = True
                break
        sanitize_only(getattr(module, name) if replaced else child)
    for attr_name, attr_val in vars(module).items():
        if callable(attr_val) and attr_val in _BAD_FUNCS:
            setattr(module, attr_name, _relu)

sanitize_only(unet)

# Disable memory-efficient attention (PyTorch export prefers standard ops)
try:
    unet.disable_xformers_memory_efficient_attention()
except Exception:
    pass
unet.eval()

# --- 2.3 Load QAT state-dict; observers/fake-quant keys are ignored by strict=False ---
sd = torch.load(PATH_QAT_SD, map_location="cpu", weights_only=True)
missing, unexpected = unet.load_state_dict(sd, strict=False)
print(f"[UNet load] missing={len(missing)} unexpected={len(unexpected)}")
if unexpected:
    print("  (expected: observer/fake-quant keys show up here and are ignored)")

# quick forward sanity
with torch.no_grad():
    x = torch.randn(1, 8, 64, 64)
    t = torch.tensor([0])
    cond = torch.randn(1, 77, 1024)
    y = unet(x, t, encoder_hidden_states=cond).sample
    print("UNet forward OK →", tuple(y.shape))


# 2.4 — Activation parity check (matches what you trained)
import torch.nn as nn
from torch.export import export  # PyTorch 2.1+; if missing, see note below.

# Count Hardswish modules (should be >0) and verify no nn.SiLU left.
hswish_count = sum(isinstance(m, nn.Hardswish) for m in unet.modules())
silu_count   = sum(isinstance(m, nn.SiLU) for m in unet.modules())
print(f"Hardswish modules: {hswish_count} | SiLU modules: {silu_count}")

# Exported graph still uses inline GELU in transformer MLPs (expected).
gm = export(unet, (x, t, cond))
has_gelu = any("gelu" in (getattr(n.target, "__name__", "") or str(n.target))
               for n in gm.graph.nodes)
print("GELU present in exported graph:", has_gelu)  # should be True

# Optional: inspect a compact op table
try:
    gm.graph.print_tabular()
except Exception:
    pass

vae = AutoencoderKL.from_pretrained(PATH_PRETRAINED_LOCAL, subfolder="vae").to(device).eval()
print("VAE loaded.")
fixed_embed = torch.randn(1, 77, 1024, dtype=torch.float32)
print("fixed_embed loaded.")



## For running after cell 4, to check if the ts scirpoted models are well behaved

In [None]:
import os, math, random, torch
from collections import Counter

torch.set_grad_enabled(False)
device = torch.device("cpu")

# --- reload TS ---
enc_ts  = torch.jit.load(os.path.join(PATH_EXPORT, "Encoder.ts")).eval()
step_ts = torch.jit.load(os.path.join(PATH_EXPORT, "UNetStep.ts")).eval()
dec_ts  = torch.jit.load(os.path.join(PATH_EXPORT, "Decoder.ts")).eval()

# --- helpers ---
def _shape(t): return tuple(t.shape)

def check_graph(name, m):
    g = m.inlined_graph if hasattr(m, "inlined_graph") else m.graph
    kinds = [n.kind() for n in g.nodes()]
    c = Counter(kinds)
    bad_signatures = [
        "quantize_per_tensor", "aten::fake_quantize_per_tensor_affine",
        "xformers", "aten::_scaled_dot_product_efficient_attention"
    ]
    has_bad = [k for k in kinds if any(b in k for b in bad_signatures)]
    print(f"[graph] {name}: {len(kinds)} nodes | top ops: {c.most_common(8)}")
    if has_bad:
        print(f"  bad unexpected ops present: {sorted(set(has_bad))}")
    else:
        print("  good no obvious bad ops found")

def assert_close(a, b, name, rtol=1e-3, atol=1e-3):
    try:
        torch.testing.assert_close(a, b, rtol=rtol, atol=atol)
        print(f"  good {name} close (rtol={rtol}, atol={atol})")
    except AssertionError as e:
        max_abs = (a - b).abs().max().item()
        print(f"  bad {name} mismatch (max_abs={max_abs:.4e})")

def no_nans(t, name):
    if torch.isnan(t).any() or torch.isinf(t).any():
        raise RuntimeError(f"{name} has NaN/Inf")
    else:
        print(f" good {name} finite")

# --- run N randomized trials with static shapes ---
N = 3
B, H, W = 1, 512, 512
h, w = H//8, W//8

for trial in range(1, N+1):
    torch.manual_seed(1234 + trial)
    print(f"\n=== Sanity trial {trial}/{N} ===")

    rgb = torch.randn(B,3,H,W, dtype=torch.float32)
    lat = torch.randn(B,4,h,w, dtype=torch.float32)
    t_f = torch.tensor([float(random.randrange(0, 50))], dtype=torch.float32)

    # Eager refs (wrappers you built earlier)
    rgb_latent_ref = enc(rgb)
    noise_ref      = step(lat, lat, t_f)
    depth_ref      = dec(lat)

    # TS runs
    rgb_latent_ts = enc_ts(rgb)
    noise_ts      = step_ts(lat, lat, t_f)
    depth_ts      = dec_ts(lat)

    # Shape checks
    print("  shapes:",
          "enc", _shape(rgb_latent_ts),
          "| step", _shape(noise_ts),
          "| dec", _shape(depth_ts))
    assert _shape(rgb_latent_ts) == (B,4,h,w)
    assert _shape(noise_ts)      == (B,4,h,w)
    assert _shape(depth_ts)      == (B,1,H,W) or _shape(depth_ts) == (B,1,h,w)

    # Finite checks
    no_nans(rgb_latent_ts, "Encoder.ts out")
    no_nans(noise_ts,      "UNetStep.ts out")
    no_nans(depth_ts,      "Decoder.ts out")

    # Numeric closeness (script/trace vs eager)
    # Tolerances are loose because tracing can fold constants & reorder ops.
    assert_close(rgb_latent_ts, rgb_latent_ref, "Encoder TS≈eager", rtol=1e-3, atol=2e-3)
    assert_close(noise_ts,      noise_ref,      "UNetStep TS≈eager", rtol=2e-3, atol=3e-3)
    assert_close(depth_ts,      depth_ref,      "Decoder TS≈eager",  rtol=1e-3, atol=2e-3)

# Graph hygiene
check_graph("Encoder.ts",  enc_ts)
check_graph("UNetStep.ts", step_ts)
check_graph("Decoder.ts",  dec_ts)

print("\nTS sanity pass completed")


-----

In [None]:
# Flops counter
import torch
from torch.profiler import profile, ProfilerActivity

def profile_model(name, path, input_tensor):
    model = torch.jit.load(path).eval().to("cpu")
    with profile(
        activities=[ProfilerActivity.CPU],
        record_shapes=True,
        with_flops=True
    ) as prof:
        if isinstance(input_tensor, tuple):
            model(*input_tensor)
        else:
            model(input_tensor)

    total_flops = sum(
        evt.flops for evt in prof.key_averages()
        if getattr(evt, "flops", None) is not None
    )
    total_macs = total_flops / 2
    flops_G = total_flops / 1e9
    macs_G  = total_macs  / 1e9

    print(f"\n{name} → {flops_G:.3f} G-FLOPs, {macs_G:.3f} G-MACs")
    print(prof.key_averages().table(sort_by="flops", row_limit=10))



rgb_latent     = torch.randn(1, 4, 32, 24)  # Output of encode_rgb
target_latent  = torch.randn(1, 4, 32, 24)  # Noisy or clean latent
timesteps_f32  = torch.tensor([500.0])     # Single timestep as float32

# Update paths to point to your actual .ts files
profile_model("Encoder",   "/datasets/abradshaw/export_coreml_Full_F16/Encoder.ts",   torch.randn(1, 3, 256, 256))
profile_model("Decoder",   "/datasets/abradshaw/export_coreml_Full_F16/Decoder.ts",   torch.randn(1, 4, 32, 32))
profile_model("UNetStep", "/datasets/abradshaw/export_coreml_Full_F16/UNetStep.ts", (rgb_latent, target_latent, timesteps_f32))




Encoder → 2.179 G-FLOPs, 1.090 G-MACs
-----------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                 Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  Total KFLOPs  
-----------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                             aten::mm         0.61%       1.936ms         0.61%       1.937ms     645.544us             3   1610612.736  
                                          aten::addmm         0.19%     603.935us         0.22%     705.560us     705.560us             1    536870.912  
                                            aten::add         4.26%      13.554ms         4.26%      13.554ms       1.232ms            11     31981.568  
                                     

In [13]:
enc = torch.jit.load("/datasets/abradshaw/export_coreml_Full_F16/Encoder.ts").eval()
x = torch.randn(1,3,256,192)
print("Encoder input shape:", tuple(x.shape), "→ output shape:", tuple(enc(x).shape))


Encoder input shape: (1, 3, 256, 192) → output shape: (1, 4, 32, 24)


In [None]:
import torch
from ptflops import get_model_complexity_info
import torch
from ptflops import get_model_complexity_info

# Profile Encoder (.ts expects input [1,3,256,256])
encoder = torch.jit.load("/datasets/abradshaw/export_coreml_Full_F16/Encoder.ts")
encoder.eval()
macs_enc, params_enc = get_model_complexity_info(
    encoder,
    (3, 256, 256),
    as_strings=False,
    print_per_layer_stat=False,
    verbose=False
)
flops_enc = macs_enc * 2
print(f"Encoder.ts → MACs: {macs_enc:,}, FLOPs: {flops_enc:,}, Params: {params_enc:,}")

# %%
# Profile Decoder (.ts expects input [1,4,32,32])
decoder = torch.jit.load("/datasets/abradshaw/export_coreml_Full_F16/Decoder.ts")
decoder.eval()
macs_dec, params_dec = get_model_complexity_info(
    decoder,
    (4, 32, 32),
    as_strings=False,
    print_per_layer_stat=False,
    verbose=False
)
flops_dec = macs_dec * 2
print(f"Decoder.ts → MACs: {macs_dec:,}, FLOPs: {flops_dec:,}, Params: {params_dec:,}")

# %%
# Profile UNetStep by combining two latent inputs along channels
# UNetStep.ts expects (rgb_latent, target_latent, t), we approximate by merging latents

# Load module
unet_step = torch.jit.load("/datasets/abradshaw/export_coreml_Full_F16/UNetStep.ts")
unet_step.eval()

# Define wrapper to accept merged latent
def unet_step_wrapper(x):
    # x: Tensor of shape (1,8,32,32) => split into two latents
    rgb_latent, target_latent = torch.chunk(x, 2, dim=1)
    t = torch.tensor([0.0])
    return unet_step(rgb_latent, target_latent, t)

macs_step, params_step = get_model_complexity_info(
    unet_step_wrapper,
    (8, 32, 32),
    as_strings=False,
    print_per_layer_stat=False,
    verbose=False
)
flops_step = macs_step * 2
print(f"UNetStep.ts → MACs: {macs_step:,}, FLOPs: {flops_step:,}, Params: {params_step:,}")


RuntimeError: register_forward_hook is not supported on ScriptModules

In [None]:
import torch.nn as nn

gelu_like = []
silu_like = []
relu_like = []
hswish_like = []

for n, m in unet.named_modules():
    if isinstance(m, nn.GELU):   gelu_like.append(n)
    if isinstance(m, nn.SiLU):   silu_like.append(n)
    if isinstance(m, nn.ReLU):   relu_like.append(n)
    if isinstance(m, nn.Hardswish): hswish_like.append(n)

print(f"GELU left: {len(gelu_like)} | SiLU left: {len(silu_like)}")
print(f"ReLU found: {len(relu_like)} | Hardswish found: {len(hswish_like)}")
assert len(gelu_like) == 0 and len(silu_like) == 0, "Found GELU/SiLU; sanitize_only didn't fully apply."

print("✅ Sanitized activations in place (ReLU/Hardswish); no GELU/SiLU remain.")

NameError: name 'unet' is not defined

In [9]:
# 7A.1 — List the "unexpected" keys we ignored when loading strict=False
qat_sd = torch.load(PATH_QAT_SD, map_location="cpu", weights_only=True)
unexpected_like = [k for k in qat_sd.keys()
                   if ("activation_post_process" in k) or ("weight_fake_quant" in k) or ("_fake_quant" in k)]
print(f"[QAT ckpt] fake-quant/observer-like keys in file: {len(unexpected_like)} (show 10)")
print(unexpected_like[:10])

# 7A.2 — Make sure the *model* has none of these tensors after load
has_obs_attr = []
for n, m in unet.named_modules():
    ap = getattr(m, "activation_post_process", None)
    wfq = getattr(m, "weight_fake_quant", None)
    if ap is not None or wfq is not None:
        has_obs_attr.append((n, type(m).__name__, ap is not None, wfq is not None))

print(f"[model] modules still carrying observer/fake-quant attributes: {len(has_obs_attr)}")
assert len(has_obs_attr) == 0, "Some modules still have observer/fake-quant attributes."

# 7A.3 — Conv-in audit (8 channels)
w = dict(unet.named_parameters()).get("conv_in.weight", None)
assert w is not None, "conv_in.weight not found"
print("conv_in.weight shape:", tuple(w.shape))
assert w.shape[1] == 8, "conv_in is NOT 8-channel — something is off."

print("✅ Observers stripped from runtime graph; conv_in is 8ch; QAT weights loaded.")


[QAT ckpt] fake-quant/observer-like keys in file: 3087 (show 10)
['conv_in.weight_fake_quant.fake_quant_enabled', 'conv_in.weight_fake_quant.observer_enabled', 'conv_in.weight_fake_quant.scale', 'conv_in.weight_fake_quant.zero_point', 'conv_in.weight_fake_quant.activation_post_process.eps', 'conv_in.weight_fake_quant.activation_post_process.min_val', 'conv_in.weight_fake_quant.activation_post_process.max_val', 'conv_in.activation_post_process.fake_quant_enabled', 'conv_in.activation_post_process.observer_enabled', 'conv_in.activation_post_process.scale']
[model] modules still carrying observer/fake-quant attributes: 0
conv_in.weight shape: (320, 8, 3, 3)
✅ Observers stripped from runtime graph; conv_in is 8ch; QAT weights loaded.


In [10]:
import torch.nn.functional as F

bad_funcs = {F.silu, F.gelu}
bad_refs = []
for n, m in unet.named_modules():
    for attr, val in vars(m).items():
        if callable(val) and val in bad_funcs:
            bad_refs.append((n, attr, val.__name__))

print("functional SiLU/GELU refs left:", bad_refs)
assert not bad_refs, "Found stored functional SiLU/GELU references."

functional SiLU/GELU refs left: []
