In [None]:
import torch
from transformers import CLIPTextModel, CLIPTokenizer

# Hugging Face repo that hosts every InstructPix2Pix component in its own subfolder.
model_id = "timbrooks/instruct-pix2pix"

# CLIP tokenizer and text encoder, loaded from their respective subfolders of the repo.
tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder").eval()

# Dummy input for tracing. It is built with the real tokenizer rather than `torch.randint`,
# so the sequence length comes from the tokenizer (no hard-coded 77 / vocab size 49408) and
# the ids form a well-formed tokenized prompt instead of arbitrary vocabulary indices.
dummy_input = tokenizer(
    "a photo of a cat",
    padding="max_length",
    max_length=tokenizer.model_max_length,
    truncation=True,
    return_tensors="pt",
).input_ids  # shape (1, tokenizer.model_max_length)

# Export to ONNX. The text encoder yields two tensors (last hidden state, pooled output).
# Both are named explicitly and both get a dynamic batch axis. Previously only the first one
# was named, so the pooled output got an auto-generated name and no declared dynamic batch dim.
with torch.no_grad():  # tracing only; no autograd graph is needed
    torch.onnx.export(
        text_encoder,
        dummy_input,
        "text_encoder.onnx",
        input_names=["input_text"],
        output_names=["text_embeddings", "pooled_embeddings"],
        dynamic_axes={
            "input_text": {0: "batch_size"},
            "text_embeddings": {0: "batch_size"},
            "pooled_embeddings": {0: "batch_size"},
        },
        opset_version=17,
    )

print("Text Encoder ONNX exported successfully!")


In [None]:
from diffusers import UNet2DConditionModel, StableDiffusionInstructPix2PixPipeline

# Load the full InstructPix2Pix pipeline in half precision and take its UNet for export.
# `pipe` stays defined at notebook level in case later cells use it.
# NOTE(review): if nothing else needs `pipe`, loading only the UNet via
# `UNet2DConditionModel.from_pretrained(model_id, subfolder="unet", torch_dtype=torch.float16)`
# would avoid also loading the VAE and text encoder.
model_id = "timbrooks/instruct-pix2pix"
pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
    model_id, torch_dtype=torch.float16, safety_checker=None
)
unet = pipe.unet.to("cuda").eval()

# Every floating-point dummy input must match the dtype the pipeline was loaded with.
dtype = torch.float16

# Dummy-input geometry. The noisy latents and the image latents are concatenated along the
# channel axis below, so each half is `in_channels // 2` channels wide. That width, the text
# sequence length and the text embedding width are read from the loaded components instead of
# being hard-coded (previously 4, 77 and 768).
batch_size = 1
latent_channels = unet.config.in_channels // 2
latent_height = 64  # NOTE(review): presumably 512 px / VAE downsampling factor 8 — confirm
latent_width = 64
text_seq_len = pipe.tokenizer.model_max_length
text_hidden_size = unet.config.cross_attention_dim

dummy_noise_latents = torch.randn(
    batch_size, latent_channels, latent_height, latent_width, device="cuda", dtype=dtype
)
dummy_image_latents = torch.randn(
    batch_size, latent_channels, latent_height, latent_width, device="cuda", dtype=dtype
)
# (batch, 2 * latent_channels, H, W): noise latents followed by image latents.
dummy_latents = torch.cat([dummy_noise_latents, dummy_image_latents], dim=1)
assert dummy_latents.shape[1] == unet.config.in_channels, "latent channels do not match the UNet"

# A single timestep value; its shape is not declared dynamic below.
# NOTE(review): assumes the UNet expands a length-1 timestep to the batch size — confirm.
dummy_timestep = torch.tensor([1], device="cuda", dtype=dtype)
dummy_text_embeddings = torch.randn(
    batch_size, text_seq_len, text_hidden_size, device="cuda", dtype=dtype
)

# Export to ONNX. The output now gets the same dynamic batch axis as the inputs; previously
# only the inputs were declared dynamic.
with torch.no_grad():  # tracing only; no autograd graph is needed
    torch.onnx.export(
        unet,
        (dummy_latents, dummy_timestep, dummy_text_embeddings),
        "unet.onnx",
        input_names=["latents", "timestep", "text_embeddings"],
        output_names=["predicted_noise"],
        dynamic_axes={
            "latents": {0: "batch_size"},
            "text_embeddings": {0: "batch_size"},
            "predicted_noise": {0: "batch_size"},
        },
        opset_version=17,
    )

print("UNet ONNX exported successfully!")
