# Import

In [None]:
!pip install diffusers accelerate safetensors transformers
!pip install onnxruntime-gpu onnx
!pip install onnxruntime onnx

# Convert_ONNX

In [None]:
import PIL
import requests
import torch
from diffusers import StableDiffusionInstructPix2PixPipeline, EulerAncestralDiscreteScheduler
import onnx
import torch.nn as nn

# Load the InstructPix2Pix pipeline in half precision and move it to the GPU
model_id = "timbrooks/instruct-pix2pix"
pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    safety_checker=None,
)
pipe.to("cuda")
# Replace the pipeline's scheduler with Euler Ancestral, reusing its config
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)

# Switch the two sub-models exported below to evaluation mode
pipe.unet.eval()
pipe.vae.eval()

# Dimensions of the dummy tensors used for the exports
batch_size = 1
latent_channels = image_latent_channels = 4
height = width = 64  # 512/8 due to vae_scale_factor=8
sequence_length = 77
hidden_size = 768

# --- Export UNet ---
# Dummy inputs: half-precision CUDA tensors. The noisy latents and the image
# latents are stacked along the channel axis to form the UNet's sample input.
dummy_latents = torch.randn(batch_size, latent_channels, height, width).to("cuda", torch.float16)
dummy_image_latents = torch.randn(batch_size, image_latent_channels, height, width).to("cuda", torch.float16)
dummy_timestep = torch.ones(batch_size).to("cuda", torch.float16)
dummy_encoder_hidden_states = torch.randn(batch_size, sequence_length, hidden_size).to("cuda", torch.float16)
dummy_model_input = torch.cat((dummy_latents, dummy_image_latents), dim=1)

onnx_path_unet = "instruct_pix2pix_unet.onnx"

# Export settings are collected first so the export call itself stays short.
unet_export_options = dict(
    input_names=["latent_model_input", "timestep", "encoder_hidden_states"],
    output_names=["noise_pred"],
    # Batch is dynamic everywhere; the text sequence length is dynamic too.
    dynamic_axes={
        "latent_model_input": {0: "batch_size"},
        "timestep": {0: "batch_size"},
        "encoder_hidden_states": {0: "batch_size", 1: "sequence_length"},
        "noise_pred": {0: "batch_size"},
    },
    opset_version=17,
    do_constant_folding=True,
    export_params=True,
)
torch.onnx.export(
    pipe.unet,
    (dummy_model_input, dummy_timestep, dummy_encoder_hidden_states),
    onnx_path_unet,
    **unet_export_options,
)

# --- Export VAE Encoder with Wrapper Logic ---
class VAEEncoderWrapper(nn.Module):
    """Expose ``vae.encode`` as a module whose output is a plain latent tensor.

    ``vae.encode`` returns an object carrying a ``latent_dist``; this wrapper
    returns that distribution's mode so the forward pass is deterministic.
    """

    def __init__(self, vae):
        super().__init__()
        self.vae = vae

    def forward(self, image):
        """Return the mode of the latent distribution produced for ``image``."""
        return self.vae.encode(image).latent_dist.mode()

# nn.Module.eval() returns the module itself, so build and switch in one step
vae_encoder_wrapper = VAEEncoderWrapper(pipe.vae).eval()

# Dummy input: one half-precision CUDA tensor of shape (1, 3, 512, 512)
dummy_image = torch.randn(1, 3, 512, 512).to("cuda", torch.float16)
onnx_path_vae_encoder = "instruct_pix2pix_vae_encoder.onnx"
torch.onnx.export(
    vae_encoder_wrapper,
    dummy_image,
    onnx_path_vae_encoder,
    export_params=True,
    opset_version=17,
    do_constant_folding=True,
    input_names=["image"],
    output_names=["latents"],
    # Only the batch dimension is left dynamic
    dynamic_axes={
        "image": {0: "batch_size"},
        "latents": {0: "batch_size"},
    },
)

# --- Export VAE Decoder with Wrapper Logic ---
class VAEDecoderWrapper(nn.Module):
    """Expose ``vae.decode`` as a module that also undoes the latent scaling.

    The scaling factor is read once from ``vae.config.scaling_factor`` at
    construction time.
    """

    def __init__(self, vae):
        super().__init__()
        self.vae = vae
        self.scaling_factor = vae.config.scaling_factor

    def forward(self, latents):
        """Divide ``latents`` by the scaling factor, decode, return item 0."""
        # Element 0 of the decode result is the image tensor.
        return self.vae.decode(latents / self.scaling_factor)[0]

# nn.Module.eval() returns the module itself, so build and switch in one step
vae_decoder_wrapper = VAEDecoderWrapper(pipe.vae).eval()

# Dummy input: one half-precision CUDA tensor of shape (1, 4, 64, 64)
dummy_latents = torch.randn(1, 4, 64, 64).to("cuda", torch.float16)
onnx_path_vae_decoder = "instruct_pix2pix_vae_decoder.onnx"
torch.onnx.export(
    vae_decoder_wrapper,
    dummy_latents,
    onnx_path_vae_decoder,
    export_params=True,
    opset_version=17,
    do_constant_folding=True,
    input_names=["latents"],
    output_names=["decoded_image"],
    # Only the batch dimension is left dynamic
    dynamic_axes={
        "latents": {0: "batch_size"},
        "decoded_image": {0: "batch_size"},
    },
)

# Verify exports
# The checker is given the file path rather than a ModelProto loaded with
# onnx.load(): the path form is the one ONNX documents for models that exceed
# the 2 GB protobuf limit, so the check keeps working if an exported graph
# grows past that size.
for path in [onnx_path_unet, onnx_path_vae_encoder, onnx_path_vae_decoder]:
    onnx.checker.check_model(path)
    print(f"Verified {path}")

print("Export completed successfully")

# Test_inference

In [None]:
import numpy as np
import torch
import onnx
import onnxruntime as ort
from diffusers import StableDiffusionInstructPix2PixPipeline, EulerAncestralDiscreteScheduler
from diffusers.utils import load_image
from transformers import CLIPTokenizer, CLIPTextModel

# Load the original pipeline; it supplies every component not exported to ONNX
model_id = "timbrooks/instruct-pix2pix"
pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    safety_checker=None,
)
pipe.to("cuda")
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)

# Open one ONNX Runtime session per exported model, all on the CUDA provider
cuda_providers = ['CUDAExecutionProvider']
unet_session = ort.InferenceSession("instruct_pix2pix_unet.onnx", providers=cuda_providers)
vae_encoder_session = ort.InferenceSession("instruct_pix2pix_vae_encoder.onnx", providers=cuda_providers)
vae_decoder_session = ort.InferenceSession("instruct_pix2pix_vae_decoder.onnx", providers=cuda_providers)

# CLIP tokenizer and text encoder stay in PyTorch
tokenizer, text_encoder = pipe.tokenizer, pipe.text_encoder
text_encoder.eval()

# Scheduler, image processor and VAE scale factor come from the pipeline
scheduler, image_processor = pipe.scheduler, pipe.image_processor
vae_scale_factor = pipe.vae_scale_factor

# Run parameters
prompt, negative_prompt = "white balance", ""
num_inference_steps = 10
guidance_scale, image_guidance_scale = 7.5, 1.5
batch_size = 1
height = width = 512
# Guidance is enabled only when neither scale falls below its threshold
do_classifier_free_guidance = not (guidance_scale <= 1.0 or image_guidance_scale < 1.0)

# Load the input image, resize it to 512x512 and preprocess it into a
# half-precision CUDA tensor
image = load_image("/kaggle/input/test-data/IMG_1685.png").resize((512, 512))
image_tensor = image_processor.preprocess(image).to("cuda").half()

# Run the ONNX VAE encoder (float16 numpy in), then move its "latents" output
# back to the GPU as a half-precision tensor
image_np = image_tensor.cpu().numpy().astype(np.float16)
encoder_outputs = vae_encoder_session.run(["latents"], {"image": image_np})
image_latents = torch.from_numpy(encoder_outputs[0]).to("cuda").half()

# Encode prompt using CLIP
def encode_prompt(prompt, negative_prompt, device, num_images_per_prompt, do_classifier_free_guidance):
    """Encode the prompt (and, for guidance, the negative prompt) with CLIP.

    Uses the module-level ``tokenizer`` and ``text_encoder``.

    Args:
        prompt: Text handed to the tokenizer for the conditional branch.
        negative_prompt: Text handed to the tokenizer for the negative
            branches; only used when ``do_classifier_free_guidance`` is true.
        device: Device the token ids are moved to before encoding.
        num_images_per_prompt: Number of copies to make of each embedding.
        do_classifier_free_guidance: When true, the result stacks three groups
            along dim 0 in the order [prompt, negative, negative], which is
            the order the guidance step expects when it splits the UNet
            output with ``chunk(3)``.

    Returns:
        The text encoder's first output for the prompt, repeated
        ``num_images_per_prompt`` times and, with guidance enabled, followed
        by two equally repeated copies of the negative-prompt embeddings.
    """

    def _embed(text):
        # Tokenize to the model's maximum length and run the text encoder.
        token_batch = tokenizer(
            text,
            padding="max_length",
            max_length=tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        with torch.no_grad():
            hidden = text_encoder(token_batch.input_ids.to(device))[0]
        # Duplicate each prompt's embedding *before* the guidance groups are
        # concatenated. Repeating after the concat (as was done previously)
        # interleaves the groups when num_images_per_prompt > 1, so chunk(3)
        # would no longer split them into text / image / unconditional.
        count, seq_len, _ = hidden.shape
        hidden = hidden.repeat(1, num_images_per_prompt, 1)
        return hidden.view(count * num_images_per_prompt, seq_len, -1)

    prompt_embeds = _embed(prompt)
    if do_classifier_free_guidance:
        negative_prompt_embeds = _embed(negative_prompt)
        prompt_embeds = torch.cat([prompt_embeds, negative_prompt_embeds, negative_prompt_embeds])
    return prompt_embeds

prompt_embeds = encode_prompt(prompt, negative_prompt, "cuda", 1, do_classifier_free_guidance)

# Configure the scheduler for this run BEFORE reading init_noise_sigma.
# init_noise_sigma is derived from the scheduler's current sigma table, and
# set_timesteps rebuilds that table for the requested number of steps. Reading
# it first (as was done previously) can scale the initial noise for a
# different schedule than the one the loop then steps through. The stock
# diffusers pipeline uses this same order.
scheduler.set_timesteps(num_inference_steps, device="cuda")
timesteps = scheduler.timesteps

# Prepare initial noise latents, scaled to the scheduler's starting sigma
latents_shape = (batch_size, 4, height // vae_scale_factor, width // vae_scale_factor)
latents = torch.randn(latents_shape, device="cuda", dtype=torch.float16) * scheduler.init_noise_sigma

# Prepare image latents for CFG: the first two groups carry the image
# conditioning, the third uses all-zero latents
if do_classifier_free_guidance:
    uncond_image_latents = torch.zeros_like(image_latents)
    image_latents = torch.cat([image_latents, image_latents, uncond_image_latents], dim=0)

# Diffusion loop

# The prompt embeddings are identical for every step, so the GPU -> CPU copy
# and float16 conversion are done once instead of once per iteration.
encoder_hidden_states_np = prompt_embeds.cpu().numpy().astype(np.float16)

for t in timesteps:
    # With guidance the latents are tripled so one UNet call covers all three
    # guidance branches.
    if do_classifier_free_guidance:
        latent_model_input = torch.cat([latents] * 3)
    else:
        latent_model_input = latents

    # Scale the noisy latents for this timestep, then append the image latents
    # along the channel axis (the image latents themselves are not scaled).
    scaled_latent_model_input = scheduler.scale_model_input(latent_model_input, t)
    scaled_latent_model_input = torch.cat([scaled_latent_model_input, image_latents], dim=1)

    # ONNX Runtime is fed float16 numpy arrays.
    latent_input_np = scaled_latent_model_input.cpu().numpy().astype(np.float16)
    # One timestep value per batch row, built directly on the CPU so there is
    # no CUDA allocation followed by an immediate copy back.
    timestep_np = torch.full(
        (latent_model_input.shape[0],), t.item(), device="cpu", dtype=torch.float16
    ).numpy()

    # Run UNet inference with ONNX
    noise_pred = unet_session.run(
        ["noise_pred"],
        {
            "latent_model_input": latent_input_np,
            "timestep": timestep_np,
            "encoder_hidden_states": encoder_hidden_states_np,
        },
    )[0]
    noise_pred = torch.from_numpy(noise_pred).to("cuda").half()

    # Combine the three predictions: start from the unconditional one, add the
    # text direction weighted by guidance_scale and the image direction
    # weighted by image_guidance_scale.
    if do_classifier_free_guidance:
        noise_pred_text, noise_pred_image, noise_pred_uncond = noise_pred.chunk(3)
        noise_pred = (
            noise_pred_uncond
            + guidance_scale * (noise_pred_text - noise_pred_image)
            + image_guidance_scale * (noise_pred_image - noise_pred_uncond)
        )

    # Advance the latents by one scheduler step
    latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]

# Run the ONNX VAE decoder on the final latents (float16 numpy in) and move
# its "decoded_image" output back to the GPU as a half-precision tensor
latents_np = latents.cpu().numpy().astype(np.float16)
decoder_outputs = vae_decoder_session.run(["decoded_image"], {"latents": latents_np})
decoded_image = torch.from_numpy(decoder_outputs[0]).to("cuda").half()

# Convert the decoded tensor to PIL images and keep the first one
pil_images = image_processor.postprocess(decoded_image, output_type="pil")
output_image = pil_images[0]

# Write the result to the Kaggle working directory
output_image.save("/kaggle/working/output_image.png")
print("Inference completed, image saved as 'output_image.png'")