In [1]:
from diffusers import UNet2DConditionModel
import torch, torch.fx as fx

CKPT = "prs-eth/marigold-depth-v1-1"

# 1️⃣  Disable every fused/compiled attention path
import torch.backends.cuda as bk
bk.enable_flash_sdp(False); bk.enable_mem_efficient_sdp(False)

unet = UNet2DConditionModel.from_pretrained(
          CKPT, subfolder="unet",
          low_cpu_mem_usage=True).cpu()

# diffusers >=0.23 has an explicit helper too
unet.disable_xformers_memory_efficient_attention()   # no-op if xformers absent :contentReference[oaicite:0]{index=0}

# 2️⃣  Optional: treat Transformer blocks as leafs so the graph stays small
class DiffusersTracer(fx.Tracer):
    def is_leaf_module(self, m, qualname):
        if "Transformer2DModel" in m.__class__.__name__:
            return True
        return super().is_leaf_module(m, qualname)

tracer = DiffusersTracer()
gm = tracer.trace(unet)

gm.graph.print_tabular()          # quick visual check
gm.recompile()                    # be sure the new GraphModule runs


  from .autonotebook import tqdm as notebook_tqdm


TraceError: Proxy object cannot be iterated. This can be attempted when the Proxy is used in a loop or as a *args or **kwargs function argument. See the torch.fx docs on pytorch.org for a more detailed explanation of what types of control flow can be traced, and check out the Proxy docstring for help troubleshooting Proxy iteration errors

In [7]:
import torch
from diffusers import UNet2DConditionModel
from torch.export import export
from torch.ao.quantization.qconfig_mapping import (
    get_default_qat_qconfig_mapping)
from torch.ao.quantization.quantize_pt2e import (
    prepare_qat_pt2e, convert_pt2e)
import torch, inspect
from torch.ao.quantization.quantize_pt2e import prepare_qat_pt2e
print("torch:", torch.__version__)
print(inspect.signature(prepare_qat_pt2e))

CKPT = "prs-eth/marigold-depth-v1-1"
unet = UNet2DConditionModel.from_pretrained(CKPT, subfolder="unet").cpu()
unet.disable_xformers_memory_efficient_attention()

example = (torch.randn(1, 8, 64, 64),     # latent
           torch.tensor([0]),             # timestep
           torch.randn(1, 77, 1024))      # text enc

gm = export(unet, example)                # records the loop as guards
qmap = get_default_qat_qconfig_mapping("x86")
gm_qat = prepare_qat_pt2e(gm, qmap)
# fine-tune ...
gm_int8 = convert_pt2e(gm_qat)
gm_int8.save("unet_int8.pt")



torch: 2.4.1+cu121
(model: torch.fx.graph_module.GraphModule, quantizer: torch.ao.quantization.quantizer.quantizer.Quantizer) -> torch.fx.graph_module.GraphModule


KeyboardInterrupt: 

## torch export for unet graph tracing

In [None]:
import torch
from diffusers import UNet2DConditionModel
from torch.export import export, save
# working
CKPT  = "prs-eth/marigold-depth-v1-1"
unet  = UNet2DConditionModel.from_pretrained(CKPT, subfolder="unet").cpu()
unet.disable_xformers_memory_efficient_attention()

example = (
    torch.randn(1, 8, 64, 64),   # latent
    torch.tensor([0]),           # timestep
    torch.randn(1, 77, 1024)     # text enc
)

gm_unet = export(unet, example)       # ← this is already a GraphModule
gm_unet.graph.print_tabular()         # nicely formatted table
# or:
print(gm_unet.graph)                  # raw ATen graph

save(gm_unet, "unet_fp32.ep") 

  from .autonotebook import tqdm as notebook_tqdm


opcode         name                                                                     target                                                                   args                                                                                                                                                           kwargs
-------------  -----------------------------------------------------------------------  -----------------------------------------------------------------------  -------------------------------------------------------------------------------------------------------------------------------------------------------------  ---------------------------------------------------------------------------
placeholder    p_time_embedding_linear_1_weight                                         p_time_embedding_linear_1_weight                                         ()                                                                                                                  

In [4]:
from torch.export import load

# 1. load the exported program
ep = load("unet_fp32.ep")

# 2. get the GraphModule that does the work
gm = ep.module()          # <-- now it's an ordinary nn.Module

# 3. run the same inputs you used for export
example = (
    torch.randn(1, 8, 64, 64),   # latent
    torch.tensor([0]),           # timestep
    torch.randn(1, 77, 1024)     # text enc
)

out1 = gm(*example)       # exported graph
out2 = unet(*example)     # original eager model

print("max |Δ| =", (out1.sample - out2.sample).abs().max())



max |Δ| = tensor(0., grad_fn=<MaxBackward1>)


In [1]:
import importlib
import torch
from diffusers import UNet2DConditionModel
from torch.export import export, save
from torch.utils._pytree import (
    register_pytree_node, SUPPORTED_NODES   # SUPPORTED_NODES = current registry
)

# ------------------------------------------------------------------
# 1) Locate UNet2DConditionOutput regardless of diffusers version
# ------------------------------------------------------------------
def _find_unet_output_class():
    # new layout (>=0.24)
    try:
        return importlib.import_module(
            "diffusers.models.unets.unet_2d_condition"
        ).UNet2DConditionOutput
    except (ModuleNotFoundError, AttributeError):
        pass
    # old layout (<=0.23)
    try:
        return importlib.import_module(
            "diffusers.models.unet_2d_condition"
        ).UNet2DConditionOutput
    except (ModuleNotFoundError, AttributeError):
        pass
    raise RuntimeError("Could not locate UNet2DConditionOutput")

UNet2DConditionOutput = _find_unet_output_class()

# ------------------------------------------------------------------
# 2) Register as a pytree **only if not registered yet**
# ------------------------------------------------------------------
if UNet2DConditionOutput not in SUPPORTED_NODES:
    def _flatten(o: UNet2DConditionOutput):
        return ((o.sample,), None)      # children, context

    def _unflatten(ctx, children):
        (sample,) = children
        return UNet2DConditionOutput(sample=sample)

    register_pytree_node(
        UNet2DConditionOutput, _flatten, _unflatten,
        serialized_type_name="UNet2DConditionOutput"
    )

# ------------------------------------------------------------------
# 3) Load the fp32 UNet
# ------------------------------------------------------------------
CKPT = "prs-eth/marigold-depth-v1-1"
unet = UNet2DConditionModel.from_pretrained(CKPT, subfolder="unet").cpu()
unet.disable_xformers_memory_efficient_attention()   # avoid un-exportable ops
unet.eval()

# ------------------------------------------------------------------
# 4) Example inputs (realistic shapes)
# ------------------------------------------------------------------
example_inputs = (
    torch.randn(1, 8, 64, 64),   # latent
    torch.tensor([0]),           # timestep
    torch.randn(1, 77, 1024)     # text embedding
)

# ------------------------------------------------------------------
# 5) Export  ➜  ExportedProgram  ➜  .ep file
# ------------------------------------------------------------------
ep = export(unet, example_inputs)
save(ep, "unet_fp32.ep")

print("✅  Exported UNet saved to: unet_fp32.ep")


  from .autonotebook import tqdm as notebook_tqdm


✅  Exported UNet saved to: unet_fp32.ep


In [10]:
# Find ops that lack a quantisation pattern
from torch.fx.passes.graph_draw import FxGraphDrawer
 
unsupported = []
for n in gm.graph.nodes:
    if n.op == "call_function":
        # Replace with your backend's supported op list
        if n.target not in torch._inductor.lowering._registered_ops:
            unsupported.append(n.target)

print("Ops with no Inductor lowering:", set(unsupported))
# view the graph visually
FxGraphDrawer(gm, "unet").run()  


ModuleNotFoundError: No module named 'torch.fx.passes.graph_draw'

In [None]:
# hwhe nweneed ot wuanitze
from torch.ao.quantization.qconfig_mapping import get_default_qat_qconfig_mapping
from torch.ao.quantization.quantizer import X86InductorQuantizer
from torch.ao.quantization.quantize_pt2e import prepare_qat_pt2e, convert_pt2e

qmap      = get_default_qat_qconfig_mapping("x86")
quantizer = X86InductorQuantizer().set_global(qmap)

gm_qat = prepare_qat_pt2e(gm, quantizer)   # inserts fake-quant nodes
# … fine-tune gm_qat for a few epochs …
gm_int8 = convert_pt2e(gm_qat)
gm_int8.save("unet_int8.pt")


In [7]:
# import torch
# from diffusers import AutoencoderKL
# from torch.export import export
# import inspect

# CKPT = "prs-eth/marigold-depth-v1-1"

# vae = AutoencoderKL.from_pretrained(CKPT, subfolder="vae").cpu()
# vae.eval()                                   # switch off dropout
# class VAEEncodeDecode(torch.nn.Module):
#     def __init__(self, core):
#         super().__init__()
#         self.core = core
#     def forward(self, x):
#         # Encode RGB → latent
#         lat = self.core.encode(x).latent_dist.sample()
#         # Decode latent → reconstruction
#         img = self.core.decode(lat).sample
#         return lat, img            # returns a tuple

# wrapper = VAEEncodeDecode(vae)
# rgb = torch.randn(1, 3, 512, 512)
# gm = export(wrapper, (rgb,))        # gm is already a GraphModule
# gm.graph.print_tabular()            # pretty table
# # or just:
# print(gm.graph)                   # raw text

  
import torch
from diffusers import AutoencoderKL
from torch.export import export

vae = AutoencoderKL.from_pretrained(
        "prs-eth/marigold-depth-v1-1", subfolder="vae").cpu().eval()

rgb = torch.randn(1, 3, 512, 512)

# AutoencoderKL.forward does: encode → (optionally sample) → decode
gm_vae = export(vae, (rgb,))            # default sample_posterior = False
gm_vae.graph.print_tabular()            # full encoder + decoder graph
print(gm_vae.graph)

save(gm_vae, "vae_fp32.ep")
print("VAE exported ➜  vae_fp32.ep")



opcode         name                                                              target                                                            args                                                                                                                                                            kwargs
-------------  ----------------------------------------------------------------  ----------------------------------------------------------------  --------------------------------------------------------------------------------------------------------------------------------------------------------------  ------------------------------------------
placeholder    p_encoder_conv_in_weight                                          p_encoder_conv_in_weight                                          ()                                                                                                                                                              {}
placeholder    p_encoder_c

In [9]:
ep_loaded = load("vae_fp32.ep")   # the ExportedProgram

gm = ep_loaded.module()           # ← call with no args, returns GraphModule
out_graph = gm(rgb)               # run the graph
out_eager = vae(rgb)              # original eager result

print("max |Δ| =", (out_graph.sample - out_eager.sample).abs().max())  # ≈ 0


max |Δ| = tensor(0., grad_fn=<MaxBackward1>)
