## This notebook is used to prepare the quantization data for the Marigold depth estimation model. Focusing on the preperation of the Unet and VAE models.

## UNET

### 1. Load the model and save as a torch.export graph for op inspection

In [1]:
import torch
from diffusers import UNet2DConditionModel
from torch.export import export, save
from torch.export import ExportedProgram
import torch.nn.functional as F   

#monkey patch
#F.gelu = F.relu
#F.silu = F.relu 

CKPT  = "prs-eth/marigold-depth-v1-1"
unet  = UNet2DConditionModel.from_pretrained(CKPT, subfolder="unet").cpu()
unet.disable_xformers_memory_efficient_attention()

example = (
    torch.randn(1, 8, 64, 64),   # latent
    torch.tensor([0]),           # timestep
    torch.randn(1, 77, 1024)     # text enc
)

gm_unet: ExportedProgram = export(unet, example)
save(gm_unet, "unet_fp32.ep")
gm_unet.graph.print_tabular()

  from .autonotebook import tqdm as notebook_tqdm


opcode         name                                                                     target                                                                   args                                                                                                                                                           kwargs
-------------  -----------------------------------------------------------------------  -----------------------------------------------------------------------  -------------------------------------------------------------------------------------------------------------------------------------------------------------  ---------------------------------------------------------------------------
placeholder    p_time_embedding_linear_1_weight                                         p_time_embedding_linear_1_weight                                         ()                                                                                                                  

### 2. Find unsupported ops like GELU and GroupNorm for swapping

In [2]:
unsupported_ops = set()
for n in gm_unet.graph.nodes:
    if n.op in ("call_function", "call_module") and "quant" not in str(n.target):
        unsupported_ops.add(str(n.target))
print("Unsupported ops:", unsupported_ops) # helps us specific the next cell

Unsupported ops: {'aten.permute.default', 'aten.sin.default', 'aten.cos.default', 'aten.exp.default', 'aten.unsqueeze.default', 'aten.add.Tensor', 'aten.expand.default', 'aten.mul.Tensor', 'aten.linear.default', 'aten.transpose.int', '<built-in function getitem>', 'aten.scaled_dot_product_attention.default', 'aten.cat.default', 'aten.group_norm.default', 'aten._to_copy.default', 'aten.relu.default', 'aten.dropout.default', 'aten.conv2d.default', 'aten.slice.Tensor', 'aten.view.default', 'aten.upsample_nearest2d.vec', 'aten.clone.default', 'aten.layer_norm.default', 'aten.arange.start', 'aten.div.Tensor', 'aten.split.Tensor'}


### 3. Patch the model ops and apply to UNet

In [3]:
# import torch.nn as nn

# def patch_model_for_qat(module: nn.Module):
#     for name, child in module.named_children():
#         # swap layers first …
#         if isinstance(child, nn.GELU):
#             setattr(module, name, nn.ReLU(inplace=True))
#         elif isinstance(child, nn.SiLU):
#             setattr(module, name, nn.ReLU(inplace=True))
#         elif isinstance(child, nn.GroupNorm):
#             setattr(module, name, nn.BatchNorm2d(child.num_channels))
#         elif isinstance(child, nn.LayerNorm):
#             setattr(module, name, nn.Identity())
#         # … then recurse regardless of what it is now
#         patch_model_for_qat(getattr(module, name))

import torch.nn as nn
import torch.nn.functional as F

# --- module-level swaps -------------------------------------------------
_BAD2GOOD = {
    nn.GELU: lambda _: nn.ReLU(inplace=False),
    nn.SiLU: lambda _: nn.Hardswish(),
}

# --- functional activation swaps ---------------------------------------
_BAD_FUNCS = {F.gelu, F.silu}
_GOOD_FUNC = F.relu

def patch_model_for_qat(module: nn.Module):
    """Recursively replace unsupported modules and functional activations."""
    # 1) swap child modules
    for name, child in list(module.named_children()):
        for bad_cls, make_good in _BAD2GOOD.items():
            if isinstance(child, bad_cls):
                setattr(module, name, make_good(child))
                child = getattr(module, name)          # updated ref
                break
        patch_model_for_qat(child)                     # recurse

    # 2) swap stored functional refs
    for attr_name, attr_val in vars(module).items():
        if callable(attr_val) and attr_val in _BAD_FUNCS:
            setattr(module, attr_name, _GOOD_FUNC)



In [4]:
patch_model_for_qat(unet)

In [5]:
#quick sanity check
bad_types = (nn.GELU, nn.SiLU, nn.GroupNorm, nn.LayerNorm)

assert not any(isinstance(m, bad_types) for m in unet.modules()), \
       "At least one forbidden layer slipped through!"
print("Layer-type sweep is clean.")

gm_patched = export(unet, example)
ops_left = {str(n.target) for n in gm_patched.graph.nodes
            if "gelu" in str(n.target) 
            or "silu" in str(n.target)
            or "group_norm" in str(n.target)
            or "layer_norm" in str(n.target)}
print("Ops still present:", ops_left)

AssertionError: At least one forbidden layer slipped through!

### 4. Add quantization stubs, prepare the model, and save

In [6]:
# for state_dict we must rebuild the quant graph at runtime
from torch.ao.quantization import get_default_qat_qconfig, prepare_qat
import torch, copy

torch.backends.quantized.engine = "fbgemm" 
# Set QAT config and prepare
unet.qconfig = get_default_qat_qconfig("fbgemm")
qat_model = copy.deepcopy(unet)
qat_model.train()   
prepare_qat(qat_model, inplace=True)

torch.save(qat_model.state_dict(), "unet_qat_ready.pt")
print("weights written → unet_qat_ready.pt")

# Save the fully prepared model (structure + weights)
# torch.save(qat_model, "unet_qat_prepared.pth")  # full model object with stubs
# print("Full QAT module saved → unet_qat_prepared.pth")



weights written → unet_qat_ready.pt


In [10]:
# for full model using dill, includes enitre graph
from torch.ao.quantization import get_default_qat_qconfig, prepare_qat
import torch, copy, dill

torch.backends.quantized.engine = "fbgemm" 
# Set QAT config and prepare
unet.qconfig = get_default_qat_qconfig("fbgemm")
qat_model = copy.deepcopy(unet)
qat_model.train()   
prepare_qat(qat_model, inplace=True)

torch.save(
    qat_model,
    "unet_qat_dev.pth",
    pickle_module=dill,
    pickle_protocol=dill.HIGHEST_PROTOCOL,
)
print("Full QAT module saved → unet_qat_dev.pth  (via dill)")



Full QAT module saved → unet_qat_dev.pth  (via dill)


### 5. Sanity check – Load and run dummy input through QAT model

In [None]:
import torch
from torch.ao.quantization import get_default_qat_qconfig, prepare_qat
from diffusers import UNet2DConditionModel

# -------- rebuild identical QAT skeleton --------------------------
qat_unet = UNet2DConditionModel.from_pretrained(CKPT, subfolder="unet").cpu()
patch_model_for_qat(qat_unet)                          # your patcher
torch.backends.quantized.engine = "fbgemm"
qat_unet.qconfig = get_default_qat_qconfig("fbgemm")
qat_unet.train()
prepare_qat(qat_unet, inplace=True)                    # add observers
qat_unet.eval()

# -------- load weights-only checkpoint ----------------------------
state_dict = torch.load("unet_qat_ready.pt", map_location="cpu")
qat_unet.load_state_dict(state_dict)

# -------- dummy forward sanity check ------------------------------
with torch.no_grad():
    latent   = torch.randn(1, 8, 64, 64)
    timestep = torch.tensor([0])
    cond     = torch.randn(1, 77, 1024)
    out      = qat_unet(latent, timestep, cond)

print("State-dict QAT model → output shape:", out.sample.shape)




  state_dict = torch.load("unet_qat_ready.pt", map_location="cpu")


State-dict QAT model → output shape: torch.Size([1, 4, 64, 64])


In [19]:
# # ── minimal, safe loader for the dill file ─────────────────────────
# import dill, torch
# torch.backends.quantized.engine = "fbgemm"   # backend first!

# qat = torch.load("unet_qat_dev.pth",
#                  pickle_module=dill,
#                  map_location="cpu")

# # ---------- initialise observer buffers once -----------------------
# qat.train()                                   # observers/fake-quants active
# with torch.no_grad():
#     _ = qat(torch.randn(1, 8, 64, 64),
#             torch.tensor([0]),
#             torch.randn(1, 77, 1024))         # warm-up pass
# qat.eval()                                    # freeze for inference

# # ---------- real sanity check --------------------------------------
# with torch.no_grad():
#     out = qat(torch.randn(1, 8, 64, 64),
#               torch.tensor([0]),
#               torch.randn(1, 77, 1024))
# print("dill-pickle QAT model →", out.sample.shape)   # should print [1, 4, 64, 64]


---

### VAE

### 1. Load the model and save as a torch.export graph for op inspection

In [1]:
import torch, torch.nn.functional as F
from diffusers import AutoencoderKL
from torch.export import export, save

#monkey patch
F.gelu = F.relu
F.silu = F.relu 

CKPT = "prs-eth/marigold-depth-v1-1"
vae  = AutoencoderKL.from_pretrained(CKPT, subfolder="vae").cpu().eval()

example_rgb = torch.randn(1, 3, 512, 512)       # typical input

gm_vae = export(vae, (example_rgb,))            # sample_posterior = False
save(gm_vae, "vae_fp32.ep")
print("VAE exported ➜  vae_fp32.ep")
gm_vae.graph.print_tabular()

  from .autonotebook import tqdm as notebook_tqdm


VAE exported ➜  vae_fp32.ep
opcode         name                                                              target                                                            args                                                                                                                                                            kwargs
-------------  ----------------------------------------------------------------  ----------------------------------------------------------------  --------------------------------------------------------------------------------------------------------------------------------------------------------------  ------------------------------------------
placeholder    p_encoder_conv_in_weight                                          p_encoder_conv_in_weight                                          ()                                                                                                                                                              {

### 2. Find unsupported ops like GELU and GroupNorm for swapping

In [2]:
unsupported_ops = {
    str(n.target)
    for n in gm_vae.graph.nodes
    if n.op in ("call_function", "call_module") and "quant" not in str(n.target)
}
print("Unsupported ops in VAE:", unsupported_ops)


Unsupported ops in VAE: {'aten.upsample_nearest2d.vec', 'aten.clone.default', 'aten.split.Tensor', 'aten.dropout.default', 'aten.relu.default', 'aten.pad.default', 'aten.linear.default', 'aten.scaled_dot_product_attention.default', 'aten._to_copy.default', '<built-in function getitem>', 'aten.view.default', 'aten.add.Tensor', 'aten.div.Tensor', 'aten.group_norm.default', 'aten.transpose.int', 'aten.conv2d.default'}


### 3. Patch the model ops and apply to VAE

In [2]:
import torch.nn as nn
import torch.nn.functional as F

F.dropout = lambda x, p=0.0, train=False, inplace=False: x        # keep
if hasattr(F, "scaled_dot_product_attention"):
    def _fake_sdpa(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False):
        w = (q @ k.transpose(-2, -1)) / (q.size(-1) ** 0.5)
        return (w.softmax(-1) @ v)
    F.scaled_dot_product_attention = _fake_sdpa

_BAD2GOOD = {
    nn.GELU:      lambda _: nn.ReLU(inplace=False),      # ← NOT in-place
    nn.SiLU:      lambda _: nn.ReLU(inplace=False),
    nn.GroupNorm: lambda _: nn.Identity(),
    #nn.GroupNorm: lambda m: nn.InstanceNorm2d(m.num_channels,
                                          #eps=1e-3, affine=True),
    nn.LayerNorm: lambda _: nn.Identity(),
}
_BAD_FUNCS = {F.gelu, F.silu}
_GOOD_FUNC = lambda x: F.relu(x, inplace=False)         # functional, no mutate

def patch_model_for_qat(m: nn.Module):
    for name, child in list(m.named_children()):
        for bad, make_good in _BAD2GOOD.items():
            if isinstance(child, bad):
                setattr(m, name, make_good(child))
                child = getattr(m, name)
                break
        patch_model_for_qat(child)
    for attr, val in vars(m).items():
        if callable(val) and val in _BAD_FUNCS:
            setattr(m, attr, _GOOD_FUNC)


# import torch.nn as nn
# import torch.nn.functional as F

# _BAD2GOOD = {
#     nn.GELU:      lambda _: nn.ReLU(inplace=True),
#     nn.SiLU:      lambda _: nn.ReLU(inplace=True),
#     nn.GroupNorm: lambda _: nn.Identity(),
#     nn.LayerNorm: lambda _: nn.Identity(),
# }
# _BAD_FUNCS = {F.gelu, F.silu}
# _GOOD_FUNC = F.relu

# def patch_model_for_qat(module: nn.Module):
#     for name, child in list(module.named_children()):
#         for bad_cls, make_good in _BAD2GOOD.items():
#             if isinstance(child, bad_cls):
#                 setattr(module, name, make_good(child))
#                 child = getattr(module, name)
#                 break
#         patch_model_for_qat(child)
#     for attr, val in vars(module).items():
#         if callable(val) and val in _BAD_FUNCS:
#             setattr(module, attr, _GOOD_FUNC)



In [4]:
patch_model_for_qat(vae)

In [5]:
#quick sanity check
bad_types = (nn.GELU, nn.SiLU, nn.GroupNorm, nn.LayerNorm)

assert not any(isinstance(m, bad_types) for m in vae.modules()), \
        "At least one forbidden layer slipped through!"
print("Layer-type sweep is clean.")

gm_patched = export(vae, (example_rgb,))
left = {str(n.target) for n in gm_patched.graph.nodes
        if any(k in str(n.target) for k in ("gelu","silu","group_norm","layer_norm"))}
print("Ops still present:", left)


Layer-type sweep is clean.
Ops still present: set()


### 4. Add quantization stubs, prepare the model, and save

In [6]:
# for state_dict we must rebuild the quant graph at runtime
import torch, copy, dill
from torch.ao.quantization import get_default_qat_qconfig, prepare_qat

torch.backends.quantized.engine = "fbgemm"

# -- 4.1  weights-only checkpoint -------------------------------------------
vae.qconfig = get_default_qat_qconfig("fbgemm")   # attach default QAT cfg
qat_vae = copy.deepcopy(vae)                      # deep-copy keeps orig safe
qat_vae.train()                                   # observers need train()
prepare_qat(qat_vae, inplace=True)                # inserts fake-quant/observers

torch.save(qat_vae.state_dict(), "vae_qat_ready.pt")
print("weights written →  vae_qat_ready.pt")



weights written →  vae_qat_ready.pt


In [7]:
# for full model using dill, includes enitre graph
from torch.ao.quantization import get_default_qat_qconfig, prepare_qat
import torch, copy, dill

torch.backends.quantized.engine = "fbgemm" 
# Set QAT config and prepare
vae.qconfig = get_default_qat_qconfig("fbgemm")   # attach default QAT cfg
qat_vae = copy.deepcopy(vae)                      # deep-copy keeps orig safe
qat_vae.train()                                   # observers need train()
prepare_qat(qat_vae, inplace=True)                # inserts fake-quant/observers

torch.save(
    qat_vae,
    "vae_qat_dev.pth",
    pickle_module=dill,
    pickle_protocol=dill.HIGHEST_PROTOCOL,
)
print("Full QAT module saved →  vae_qat_dev.pth  (via dill)")

Full QAT module saved →  vae_qat_dev.pth  (via dill)


### 5. Sanity check – Load and run dummy input through QAT model

In [None]:
import torch
from torch.ao.quantization import get_default_qat_qconfig, prepare_qat
from diffusers import AutoencoderKL

# ---- 5.1  Re-instantiate *vanilla* VAE -------------------------------------
qat_vae = AutoencoderKL.from_pretrained(CKPT, subfolder="vae").cpu()

patch_model_for_qat(qat_vae)                      # <-- your sanitizer
torch.backends.quantized.engine = "fbgemm"
qat_vae.qconfig = get_default_qat_qconfig("fbgemm")
qat_vae.train()
prepare_qat(qat_vae, inplace=True)                # add observers
qat_vae.eval()                                    # switch to eval for test

# ---- 5.2  Load weights-only checkpoint -------------------------------------
state_dict = torch.load("vae_qat_ready.pt", map_location="cpu")
qat_vae.load_state_dict(state_dict)


with torch.no_grad():
    x_dummy = torch.randn(1, 3, 512, 512)
    # full round-trip → returns AutoencoderKLOutput
    out = qat_vae(x_dummy)

print(type(out))                 # <class 'diffusers.utils.outputs.DecoderOutput'>
print(out.keys())                # ('sample',)
print("reconstruction shape:", out.sample.shape)

# with torch.no_grad():
#     x_dummy  = torch.randn(1, 3, 512, 512)
#     recon    = qat_vae(x_dummy)                   # encode→decode happens
# print("State-dict QAT VAE ✓ output shape:", recon.shape)

  state_dict = torch.load("vae_qat_ready.pt", map_location="cpu")


<class 'diffusers.models.autoencoders.vae.DecoderOutput'>
odict_keys(['sample'])
reconstruction shape: torch.Size([1, 3, 512, 512])


## sandbox
---

In [5]:
# ----------------- PTQ Convert + Dummy Inference -----------------
import torch
from torch import nn
from torch.ao.quantization import (
    QuantStub, DeQuantStub,
    get_default_qat_qconfig, prepare_qat, convert,
    disable_observer
)
from diffusers import UNet2DConditionModel

# 1) Load & patch FP32 UNet (GELU→ReLU, GN→BN, etc)
CKPT = "prs-eth/marigold-depth-v1-1"
fp32 = UNet2DConditionModel.from_pretrained(CKPT, subfolder="unet").cpu()
patch_model_for_qat(fp32)

# 1b) Patch out the dtype-cast so quantized tensors don’t break
def _patched_time_embed(self, sample, timestep):
    return self.time_proj(timestep)   # drop `.to(dtype=sample.dtype)`
fp32.get_time_embed = _patched_time_embed.__get__(fp32, UNet2DConditionModel)

# 2) Wrap with QuantStub/DeQuantStub so we can feed pure float32
class QuantWrapper(nn.Module):
    def __init__(self, unet):
        super().__init__()
        self.quant_lat  = QuantStub()
        self.quant_cond = QuantStub()
        self.unet       = unet
        self.dequant    = DeQuantStub()
    def forward(self, latent, timestep, cond):
        lq = self.quant_lat(latent)
        cq = self.quant_cond(cond)
        out = self.unet(lq, timestep, cq)
        out.sample = self.dequant(out.sample)
        return out

model = QuantWrapper(fp32)

# 3) Insert observers/fake-quant
model.qconfig = get_default_qat_qconfig("fbgemm")
prepare_qat(model, inplace=True)

# 4) Calibration pass (dummy data just to fill min/max)
model.train()
for _ in range(5):
    _ = model(
        torch.randn(1, 8, 64, 64),
        torch.tensor([0]),
        torch.randn(1, 77, 1024),
    )

# 5) Freeze observers & BN, switch to eval
disable_observer(model)
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        m.eval()
model.eval()

# 6) Convert to real INT8
int8_model = convert(model, inplace=False)

# 7) Dummy inference still works now that we’ve patched get_time_embed
with torch.no_grad():
    out = int8_model(
        torch.randn(1, 8, 64, 64),
        torch.tensor([0]),
        torch.randn(1, 77, 1024),
    )
print("✓ dummy INT8 inference OK – output shape:", out.sample.shape)

# 8) Save both checkpoints
torch.save(model.state_dict(),    "unet_ptq_fakequant.pt")
torch.save(int8_model.state_dict(), "unet_ptq_int8.pt")
print("✔ Saved unet_ptq_fakequant.pt and unet_ptq_int8.pt")


NotImplementedError: Could not run 'quantized::linear' with arguments from the 'CPU' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'quantized::linear' is only available for these backends: [Meta, QuantizedCPU, QuantizedCUDA, BackendSelect, Python, FuncTorchDynamicLayerBackMode, Functionalize, Named, Conjugate, Negative, ZeroTensor, ADInplaceOrView, AutogradOther, AutogradCPU, AutogradCUDA, AutogradXLA, AutogradMPS, AutogradXPU, AutogradHPU, AutogradLazy, AutogradMeta, Tracer, AutocastCPU, AutocastXPU, AutocastCUDA, FuncTorchBatched, BatchedNestedTensor, FuncTorchVmapMode, Batched, VmapMode, FuncTorchGradWrapper, PythonTLSSnapshot, FuncTorchDynamicLayerFrontMode, PreDispatch, PythonDispatcher].

Meta: registered at ../aten/src/ATen/core/MetaFallbackKernel.cpp:23 [backend fallback]
QuantizedCPU: registered at ../aten/src/ATen/native/quantized/cpu/qlinear.cpp:1317 [kernel]
QuantizedCUDA: registered at ../aten/src/ATen/native/quantized/cudnn/Linear.cpp:359 [kernel]
BackendSelect: fallthrough registered at ../aten/src/ATen/core/BackendSelectFallbackKernel.cpp:3 [backend fallback]
Python: registered at ../aten/src/ATen/core/PythonFallbackKernel.cpp:153 [backend fallback]
FuncTorchDynamicLayerBackMode: registered at ../aten/src/ATen/functorch/DynamicLayer.cpp:497 [backend fallback]
Functionalize: registered at ../aten/src/ATen/FunctionalizeFallbackKernel.cpp:349 [backend fallback]
Named: registered at ../aten/src/ATen/core/NamedRegistrations.cpp:7 [backend fallback]
Conjugate: registered at ../aten/src/ATen/ConjugateFallback.cpp:17 [backend fallback]
Negative: registered at ../aten/src/ATen/native/NegateFallback.cpp:18 [backend fallback]
ZeroTensor: registered at ../aten/src/ATen/ZeroTensorFallback.cpp:86 [backend fallback]
ADInplaceOrView: fallthrough registered at ../aten/src/ATen/core/VariableFallbackKernel.cpp:86 [backend fallback]
AutogradOther: registered at ../aten/src/ATen/core/VariableFallbackKernel.cpp:53 [backend fallback]
AutogradCPU: registered at ../aten/src/ATen/core/VariableFallbackKernel.cpp:57 [backend fallback]
AutogradCUDA: registered at ../aten/src/ATen/core/VariableFallbackKernel.cpp:65 [backend fallback]
AutogradXLA: registered at ../aten/src/ATen/core/VariableFallbackKernel.cpp:69 [backend fallback]
AutogradMPS: registered at ../aten/src/ATen/core/VariableFallbackKernel.cpp:77 [backend fallback]
AutogradXPU: registered at ../aten/src/ATen/core/VariableFallbackKernel.cpp:61 [backend fallback]
AutogradHPU: registered at ../aten/src/ATen/core/VariableFallbackKernel.cpp:90 [backend fallback]
AutogradLazy: registered at ../aten/src/ATen/core/VariableFallbackKernel.cpp:73 [backend fallback]
AutogradMeta: registered at ../aten/src/ATen/core/VariableFallbackKernel.cpp:81 [backend fallback]
Tracer: registered at ../torch/csrc/autograd/TraceTypeManual.cpp:297 [backend fallback]
AutocastCPU: fallthrough registered at ../aten/src/ATen/autocast_mode.cpp:209 [backend fallback]
AutocastXPU: fallthrough registered at ../aten/src/ATen/autocast_mode.cpp:351 [backend fallback]
AutocastCUDA: fallthrough registered at ../aten/src/ATen/autocast_mode.cpp:165 [backend fallback]
FuncTorchBatched: registered at ../aten/src/ATen/functorch/LegacyBatchingRegistrations.cpp:731 [backend fallback]
BatchedNestedTensor: registered at ../aten/src/ATen/functorch/LegacyBatchingRegistrations.cpp:758 [backend fallback]
FuncTorchVmapMode: fallthrough registered at ../aten/src/ATen/functorch/VmapModeRegistrations.cpp:27 [backend fallback]
Batched: registered at ../aten/src/ATen/LegacyBatchingRegistrations.cpp:1075 [backend fallback]
VmapMode: fallthrough registered at ../aten/src/ATen/VmapModeRegistrations.cpp:33 [backend fallback]
FuncTorchGradWrapper: registered at ../aten/src/ATen/functorch/TensorWrapper.cpp:207 [backend fallback]
PythonTLSSnapshot: registered at ../aten/src/ATen/core/PythonFallbackKernel.cpp:161 [backend fallback]
FuncTorchDynamicLayerFrontMode: registered at ../aten/src/ATen/functorch/DynamicLayer.cpp:493 [backend fallback]
PreDispatch: registered at ../aten/src/ATen/core/PythonFallbackKernel.cpp:165 [backend fallback]
PythonDispatcher: registered at ../aten/src/ATen/core/PythonFallbackKernel.cpp:157 [backend fallback]


In [None]:
# ================================================================
#  Minimal “convert-only” sanity cell (no extra wrappers, no INT-8
#  forward test).  Converts and saves without crashing.
# ================================================================
import torch, copy
from torch.ao.quantization import get_default_qat_qconfig, prepare_qat, convert
from diffusers import UNet2DConditionModel

torch.backends.quantized.engine = "fbgemm"          # static-INT8 backend
CKPT = "prs-eth/marigold-depth-v1-1"

# ----------------------------------------------------------------
# 0)  fresh, already-patched UNet (you wrote patch_model_for_qat)
# ----------------------------------------------------------------
fp32 = UNet2DConditionModel.from_pretrained(CKPT, subfolder="unet").cpu()
patch_model_for_qat(fp32)                           # GELU→ReLU, GN→BN …
fp32.qconfig = get_default_qat_qconfig("fbgemm")

# ----------------------------------------------------------------
# 1)  insert observers / fake-quant layers
# ----------------------------------------------------------------
qat = copy.deepcopy(fp32).train()
prepare_qat(qat, inplace=True)

# ----------------------------------------------------------------
# 2)  *one* calibration forward – fills observer stats
# ----------------------------------------------------------------
with torch.no_grad():
    latent = torch.randn(1, 8, 64, 64)
    tstep  = torch.tensor([0])
    cond   = torch.randn(1, 77, 1024)
    _      = qat(latent, tstep, cond)               # fake-quant graph runs

print("✓ observers populated – ready to convert")

qat.eval()  # convert() requires eval mode

# ----------------------
# ------------------------------------------
# 3)  bake INT-8 kernels
# ----------------------------------------------------------------
int8_net = convert(qat, inplace=False)              # success == “sanitised”

print("✓ convert() succeeded – INT-8 weights produced")

# ----------------------------------------------------------------
# 4)  save both checkpoints
# ----------------------------------------------------------------
torch.save(qat.state_dict(),  "unet_qat_fakequant.pt")  # resume-training point
torch.save(int8_net.state_dict(), "unet_qat_int8.pt")   # deploy-time weights
print("✔ wrote  unet_qat_fakequant.pt  and  unet_qat_int8.pt")


✓ observers populated – ready to convert
✓ convert() succeeded – INT-8 weights produced
✔ wrote  unet_qat_fakequant.pt  and  unet_qat_int8.pt


In [15]:
import torch
import torch.nn.quantized as nnq

# Should NOT raise an error
m = nnq.Linear(4, 4)
x = torch.quantize_per_tensor(torch.randn(1, 4), scale=1.0, zero_point=0, dtype=torch.quint8)
out = m(x)
print("✓ Quantized linear op works")


✓ Quantized linear op works
