# Run on Kaggle or Colab (the export code below moves the pipeline to "cuda", so use a GPU runtime)

In [None]:
!pip install diffusers accelerate safetensors transformers
!pip install onnxruntime-gpu onnx
!pip install onnxruntime onnx

# Convert the InstructPix2Pix UNet and VAE to ONNX

In [None]:
import PIL
import requests
import torch
from diffusers import StableDiffusionInstructPix2PixPipeline, EulerAncestralDiscreteScheduler
import onnx
import torch.nn as nn

# Load the model in float16 and move it to the GPU.
# safety_checker=None is passed so the pipeline is built without a safety checker.
model_id = "timbrooks/instruct-pix2pix"
pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(model_id, torch_dtype=torch.float16, safety_checker=None)
pipe.to("cuda")
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)

# Set models to evaluation mode before tracing them for export
pipe.unet.eval()
pipe.vae.eval()

# Define dimensions used to build the dummy (tracing) inputs
batch_size = 1
latent_channels = 4
image_latent_channels = 4
height = 64  # 512/8 due to vae_scale_factor=8
width = 64
sequence_length = 77
hidden_size = 768

# --- Export UNet ---
# Dummy inputs are created on the CPU, then moved to CUDA and cast to half so
# they match the float16 pipeline loaded above.
dummy_latents = torch.randn(batch_size, latent_channels, height, width).to("cuda").half()
dummy_image_latents = torch.randn(batch_size, image_latent_channels, height, width).to("cuda").half()
dummy_timestep = torch.ones(batch_size).to("cuda").half()
dummy_encoder_hidden_states = torch.randn(batch_size, sequence_length, hidden_size).to("cuda").half()
# The UNet input is the latents and the image latents concatenated on the
# channel axis (4 + 4 = 8 channels).
dummy_model_input = torch.cat([dummy_latents, dummy_image_latents], dim=1)

onnx_path_unet = "instruct_pix2pix_unet.onnx"
# Export under no_grad: tracing runs a real forward pass, and without it
# autograd would record a graph for that pass and keep activations alive.
with torch.no_grad():
    torch.onnx.export(
        pipe.unet,
        (dummy_model_input, dummy_timestep, dummy_encoder_hidden_states),
        onnx_path_unet,
        input_names=["latent_model_input", "timestep", "encoder_hidden_states"],
        output_names=["noise_pred"],
        # Only the batch axis (and the text sequence length) is marked dynamic;
        # the spatial size stays fixed at the traced height/width.
        dynamic_axes={
            "latent_model_input": {0: "batch_size"},
            "timestep": {0: "batch_size"},
            "encoder_hidden_states": {0: "batch_size", 1: "sequence_length"},
            "noise_pred": {0: "batch_size"}
        },
        opset_version=17,
        do_constant_folding=True,
        export_params=True
    )

# --- Export VAE Encoder with Wrapper Logic ---
class VAEEncoderWrapper(nn.Module):
    """Tensor-in / tensor-out adapter around ``vae.encode``.

    ``vae.encode`` returns an object exposing ``latent_dist``; this wrapper
    reduces that to a single latents tensor.
    """

    def __init__(self, vae):
        """Store the VAE whose ``encode`` method is wrapped."""
        super().__init__()
        self.vae = vae

    def forward(self, image):
        """Return the mode of the latent distribution computed for ``image``."""
        # Taking the mode (rather than sampling) keeps the output deterministic.
        return self.vae.encode(image).latent_dist.mode()

# Wrap the pipeline's VAE so the exported graph takes an image tensor and
# returns a latents tensor.
vae_encoder_wrapper = VAEEncoderWrapper(pipe.vae)
vae_encoder_wrapper.eval()

# Dummy 512x512 RGB image in half precision on the GPU, used only for tracing.
dummy_image = torch.randn(1, 3, 512, 512).to("cuda").half()
onnx_path_vae_encoder = "instruct_pix2pix_vae_encoder.onnx"
# Export under no_grad: tracing runs a real forward pass, and without it
# autograd would record a graph for that pass and keep activations alive.
with torch.no_grad():
    torch.onnx.export(
        vae_encoder_wrapper,
        dummy_image,
        onnx_path_vae_encoder,
        input_names=["image"],
        output_names=["latents"],
        # Only the batch axis is dynamic; the image size is fixed at the traced shape.
        dynamic_axes={"image": {0: "batch_size"}, "latents": {0: "batch_size"}},
        opset_version=17,
        do_constant_folding=True,
        export_params=True
    )

# --- Export VAE Decoder with Wrapper Logic ---
class VAEDecoderWrapper(nn.Module):
    """Tensor-in / tensor-out adapter around ``vae.decode``.

    The incoming latents are divided by the VAE's configured scaling factor
    before decoding, and the first element of the decode result is returned.
    """

    def __init__(self, vae):
        """Store the VAE and cache ``vae.config.scaling_factor``."""
        super().__init__()
        self.vae = vae
        self.scaling_factor = vae.config.scaling_factor

    def forward(self, latents):
        """Decode ``latents`` and return the resulting image tensor."""
        # Undo the latent scaling before handing the tensor to the decoder.
        unscaled = latents / self.scaling_factor
        # Element 0 of the decode result is the image tensor.
        return self.vae.decode(unscaled)[0]

# Wrap the pipeline's VAE so the exported graph takes a latents tensor and
# returns the decoded image tensor.
vae_decoder_wrapper = VAEDecoderWrapper(pipe.vae)
vae_decoder_wrapper.eval()

# Dummy latents in half precision on the GPU, used only for tracing.
# NOTE: this rebinds dummy_latents, which was first assigned for the UNet export.
dummy_latents = torch.randn(1, 4, 64, 64).to("cuda").half()
onnx_path_vae_decoder = "instruct_pix2pix_vae_decoder.onnx"
# Export under no_grad: tracing runs a real forward pass, and without it
# autograd would record a graph for that pass and keep activations alive.
with torch.no_grad():
    torch.onnx.export(
        vae_decoder_wrapper,
        dummy_latents,
        onnx_path_vae_decoder,
        input_names=["latents"],
        output_names=["decoded_image"],
        # Only the batch axis is dynamic; the latent size is fixed at the traced shape.
        dynamic_axes={"latents": {0: "batch_size"}, "decoded_image": {0: "batch_size"}},
        opset_version=17,
        do_constant_folding=True,
        export_params=True
    )

# Verify exports
# The checker is given the file path rather than a loaded ModelProto: this
# avoids pulling each model into Python memory, and the ONNX docs require the
# path form for models larger than 2GB.
for path in [onnx_path_unet, onnx_path_vae_encoder, onnx_path_vae_decoder]:
    onnx.checker.check_model(path)
    print(f"Verified {path}")

print("Export completed successfully")