From 347804d7d3a3f5b55d83ea9b09b58e8107eafc1e Mon Sep 17 00:00:00 2001 From: Asfiya Baig Date: Mon, 11 Aug 2025 13:22:18 -0700 Subject: [PATCH 1/2] Enable Stable Diffusion 3.5 medium, large, and controlnet pipelines Signed-off-by: Asfiya Baig --- demo/Diffusion/README.md | 10 +- demo/Diffusion/demo_controlnet_sd35.py | 67 +++-- demo/Diffusion/demo_diffusion/dd_argparse.py | 47 ++-- .../demo_diffusion/model/__init__.py | 4 +- .../demo_diffusion/model/base_model.py | 2 + .../demo_diffusion/model/controlnet.py | 225 +++++++++++++++++ .../model/diffusion_transformer.py | 229 +++++------------- demo/Diffusion/demo_diffusion/model/load.py | 5 +- .../pipeline/diffusion_pipeline.py | 11 +- .../demo_diffusion/pipeline/flux_pipeline.py | 2 +- .../pipeline/stable_diffusion_35_pipeline.py | 178 ++++++++------ .../pipeline/stable_diffusion_pipeline.py | 2 +- 12 files changed, 456 insertions(+), 326 deletions(-) create mode 100644 demo/Diffusion/demo_diffusion/model/controlnet.py diff --git a/demo/Diffusion/README.md b/demo/Diffusion/README.md index 69d36c8f..a92e1d17 100755 --- a/demo/Diffusion/README.md +++ b/demo/Diffusion/README.md @@ -7,7 +7,7 @@ This demo application ("demoDiffusion") showcases the acceleration of Stable Dif ### Clone the TensorRT OSS repository ```bash -git clone git@github.com:NVIDIA/TensorRT.git -b release/10.13 --single-branch +git clone git@github.com:NVIDIA/TensorRT.git -b release/sd35 --single-branch cd TensorRT ``` @@ -210,7 +210,7 @@ Run the command below to generate an image using Stable Diffusion 3 and Stable D python3 demo_txt2img_sd3.py "A vibrant street wall covered in colorful graffiti, the centerpiece spells \"SD3 MEDIUM\", in a storm of colors" --version sd3 --hf-token=$HF_TOKEN # Stable Diffusion 3.5-medium -python3 demo_txt2img_sd35.py "a beautiful photograph of Mt. Fuji during cherry blossom" --version=3.5-medium --denoising-steps=30 --guidance-scale 3.5 --hf-token=$HF_TOKEN --bf16 +python3 demo_txt2img_sd35.py "a beautiful photograph of Mt. Fuji during cherry blossom" --version=3.5-medium --denoising-steps=30 --guidance-scale 3.5 --hf-token=$HF_TOKEN --bf16 --download-onnx-models # Stable Diffusion 3.5-large python3 demo_txt2img_sd35.py "a beautiful photograph of Mt. Fuji during cherry blossom" --version=3.5-large --denoising-steps=30 --guidance-scale 3.5 --hf-token=$HF_TOKEN --bf16 --download-onnx-models @@ -234,13 +234,13 @@ Note that a denosing-percentage is applied to the number of denoising-steps when ```bash # Depth -python3 demo_controlnet_sd35.py "a photo of a man" --controlnet-type depth --hf-token=$HF_TOKEN --denoising-steps 40 --guidance-scale 4.5 --bf16 +python3 demo_controlnet_sd35.py "a photo of a man" --controlnet-type depth --hf-token=$HF_TOKEN --denoising-steps 40 --guidance-scale 4.5 --bf16 --download-onnx-models # Canny -python3 demo_controlnet_sd35.py "A Night time photo taken by Leica M11, portrait of a Japanese woman in a kimono, looking at the camera, Cherry blossoms" --controlnet-type canny --hf-token=$HF_TOKEN --denoising-steps 60 --guidance-scale 3.5 --bf16 +python3 demo_controlnet_sd35.py "A Night time photo taken by Leica M11, portrait of a Japanese woman in a kimono, looking at the camera, Cherry blossoms" --controlnet-type canny --hf-token=$HF_TOKEN --denoising-steps 60 --guidance-scale 3.5 --bf16 --download-onnx-models # Blur -python3 demo_controlnet_sd35.py "generated ai art, a tiny, lost rubber ducky in an action shot close-up, surfing the humongous waves, inside the tube, in the style of Kelly Slater" --controlnet-type blur --hf-token=$HF_TOKEN --denoising-steps 60 --guidance-scale 3.5 --bf16 +python3 demo_controlnet_sd35.py "generated ai art, a tiny, lost rubber ducky in an action shot close-up, surfing the humongous waves, inside the tube, in the style of Kelly Slater" --controlnet-type blur --hf-token=$HF_TOKEN --denoising-steps 60 --guidance-scale 3.5 --bf16 --download-onnx-models ``` ### Generate a video guided by an initial image using Stable Video Diffusion diff --git a/demo/Diffusion/demo_controlnet_sd35.py b/demo/Diffusion/demo_controlnet_sd35.py index e2dd69db..f4b1d87c 100644 --- a/demo/Diffusion/demo_controlnet_sd35.py +++ b/demo/Diffusion/demo_controlnet_sd35.py @@ -43,7 +43,7 @@ def parseArgs(): parser.add_argument( "--max-sequence-length", type=int, - default=77, + default=256, help="Maximum sequence length to use with the prompt.", ) parser.add_argument( @@ -55,17 +55,15 @@ def parseArgs(): ) parser.add_argument( "--controlnet-type", - nargs="+", type=str, - default=["canny"], - help="Controlnet type, can be `None`, `str` or `str` list from ['canny', 'depth', 'blur']", + default="canny", + help="Controlnet type (single type only), can be 'canny', 'depth', 'blur', etc.", ) parser.add_argument( "--controlnet-scale", - nargs="+", type=float, - default=[1.0], - help="The outputs of the controlnet are multiplied by `controlnet_scale` before they are added to the residual in the original unet, can be `None`, `float` or `float` list", + default=1.0, + help="The outputs of the controlnet are multiplied by `controlnet_scale` before they are added to the residual in the original Transformer", ) return parser.parse_args() @@ -99,25 +97,15 @@ def process_demo_args(args): ) # Controlnet configuration - if not isinstance(args.controlnet_type, list): - raise ValueError( - f"`--controlnet-type` must be of type `str` or `str` list, but is {type(args.controlnet_type)}" - ) + if not isinstance(args.controlnet_type, str): + raise ValueError(f"`--controlnet-type` must be of type `str`, but is {type(args.controlnet_type)}") # Controlnet configuration - if not isinstance(args.controlnet_scale, list): - raise ValueError( - f"`--controlnet-scale`` must be of type `float` or `float` list, but is {type(args.controlnet_scale)}" - ) - - # Check number of ControlNets to ControlNet scales - if len(args.controlnet_type) != len(args.controlnet_scale): - raise ValueError( - f"Numbers of ControlNets {len(args.controlnet_type)} should be equal to number of ControlNet scales {len(args.controlnet_scale)}." - ) + if not isinstance(args.controlnet_scale, float): + raise ValueError(f"`--controlnet-scale` must be of type `float`, but is {type(args.controlnet_scale)}") # Convert controlnet scales to tensor - controlnet_scale = torch.FloatTensor(args.controlnet_scale) + controlnet_scale = torch.FloatTensor([args.controlnet_scale]) # Check images input_images = [] @@ -125,22 +113,23 @@ def process_demo_args(args): for image in args.control_image: input_images.append(Image.open(image)) else: - for controlnet in args.controlnet_type: - if controlnet == "canny": - canny_image = image_module.download_image("https://huggingface.co/datasets/diffusers/diffusers-images-docs/resolve/main/canny.png") - input_images.append(canny_image.resize((args.height, args.width))) - elif controlnet == "depth": - depth_image = image_module.download_image( - "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/marigold/marigold_einstein_lcm_depth.png" - ) - input_images.append(depth_image.resize((args.height, args.width))) - elif controlnet == "blur": - blur_image = image_module.download_image( - "https://huggingface.co/datasets/diffusers/diffusers-images-docs/resolve/main/blur.png" - ) - input_images.append(blur_image.resize((args.height, args.width))) - else: - raise ValueError(f"You should implement the conditonal image of this controlnet: {controlnet}") + if args.controlnet_type == "canny": + canny_image = image_module.download_image( + "https://huggingface.co/datasets/diffusers/diffusers-images-docs/resolve/main/canny.png" + ) + input_images.append(canny_image.resize((args.height, args.width))) + elif args.controlnet_type == "depth": + depth_image = image_module.download_image( + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/marigold/marigold_einstein_lcm_depth.png" + ) + input_images.append(depth_image.resize((args.height, args.width))) + elif args.controlnet_type == "blur": + blur_image = image_module.download_image( + "https://huggingface.co/datasets/diffusers/diffusers-images-docs/resolve/main/blur.png" + ) + input_images.append(blur_image.resize((args.height, args.width))) + else: + raise ValueError(f"You should implement the conditonal image of this controlnet: {args.controlnet_type}") assert len(input_images) > 0 kwargs_run_demo = { @@ -149,7 +138,7 @@ def process_demo_args(args): "height": args.height, "width": args.width, "control_image": input_images, - "controlnet_scales": controlnet_scale, + "controlnet_scale": controlnet_scale, "batch_count": args.batch_count, "num_warmup_runs": args.num_warmup_runs, "use_cuda_graph": args.use_cuda_graph, diff --git a/demo/Diffusion/demo_diffusion/dd_argparse.py b/demo/Diffusion/demo_diffusion/dd_argparse.py index e5289153..79596ddd 100644 --- a/demo/Diffusion/demo_diffusion/dd_argparse.py +++ b/demo/Diffusion/demo_diffusion/dd_argparse.py @@ -310,36 +310,27 @@ def process_pipeline_args(args: argparse.Namespace) -> Tuple[Dict[str, Any], Dic # int8 support if args.int8 and not any(args.version.startswith(prefix) for prefix in ("xl", "1.4", "1.5", "2.1")): raise ValueError("int8 quantization is only supported for SDXL, SD1.4, SD1.5 and SD2.1 pipelines.") - - # fp8 support validation - if args.fp8: - # Check version compatibility - supported_versions = ("xl", "1.4", "1.5", "2.1", "3.5-large") - if not (any(args.version.startswith(prefix) for prefix in supported_versions) or is_flux): - raise ValueError( - "fp8 quantization is only supported for SDXL, SD1.4, SD1.5, SD2.1, SD3.5-large and FLUX pipelines." - ) - - # Check controlnet compatibility - if hasattr(args, "controlnet_type") and args.version != "xl-1.0": + + # fp8 support + if args.fp8 and not ( + any(args.version.startswith(prefix) for prefix in ("xl", "1.4", "1.5", "2.1", "3.5-large")) or is_flux + ): + raise ValueError( + "fp8 quantization is only supported for SDXL, SD1.4, SD1.5, SD2.1, SD3.5-large and FLUX pipelines." + ) + + if args.fp8 and hasattr(args, "controlnet_type"): + if args.version != "xl-1.0": raise ValueError("fp8 controlnet quantization is only supported for SDXL.") - # Check for conflicting quantization - if args.int8: - raise ValueError("Cannot apply both int8 and fp8 quantization, please choose only one.") - - # Check GPU compute capability - if sm_version < 89: - raise ValueError( - f"Cannot apply FP8 quantization for GPU with compute capability {sm_version / 10.0}. A minimum compute capability of 8.9 is required." - ) - - # Check SD3.5-large specific requirement - if args.version == "3.5-large" and not args.download_onnx_models: - raise ValueError( - "Native FP8 quantization is not supported for SD3.5-large. Please pass --download-onnx-models." - ) - + if args.fp8 and args.int8: + raise ValueError("Cannot apply both int8 and fp8 quantization, please choose only one.") + + if args.fp8 and sm_version < 89: + raise ValueError( + f"Cannot apply FP8 quantization for GPU with compute capability {sm_version / 10.0}. Only Ada and Hopper are supported." + ) + # TensorRT ModelOpt quantization level if args.quantization_level == 0.0: def override_quant_level(level: float, dtype_str: str): diff --git a/demo/Diffusion/demo_diffusion/model/__init__.py b/demo/Diffusion/demo_diffusion/model/__init__.py index f3c51461..a74e274f 100644 --- a/demo/Diffusion/demo_diffusion/model/__init__.py +++ b/demo/Diffusion/demo_diffusion/model/__init__.py @@ -30,8 +30,8 @@ FluxTransformerModel, SD3_MMDiTModel, SD3TransformerModel, - SD3TransformerModelControlNet, ) +from demo_diffusion.model.controlnet import SD3ControlNet from demo_diffusion.model.gan import VQGANModel from demo_diffusion.model.load import unload_torch_model from demo_diffusion.model.lora import FLUXLoraLoader, SDLoraLoader, merge_loras @@ -71,7 +71,7 @@ "SD3_MMDiTModel", "FluxTransformerModel", "SD3TransformerModel", - "SD3TransformerModelControlNet", + "SD3ControlNet", # gan "VQGANModel", # lora diff --git a/demo/Diffusion/demo_diffusion/model/base_model.py b/demo/Diffusion/demo_diffusion/model/base_model.py index 175170f7..d99d8eb1 100644 --- a/demo/Diffusion/demo_diffusion/model/base_model.py +++ b/demo/Diffusion/demo_diffusion/model/base_model.py @@ -42,6 +42,7 @@ def __init__( bf16=False, int8=False, fp8=False, + fp4=False, max_batch_size=16, text_maxlen=77, embedding_dim=768, @@ -63,6 +64,7 @@ def __init__( self.bf16 = bf16 self.int8 = int8 self.fp8 = fp8 + self.fp4 = fp4 self.compression_factor = compression_factor self.min_batch = 1 diff --git a/demo/Diffusion/demo_diffusion/model/controlnet.py b/demo/Diffusion/demo_diffusion/model/controlnet.py new file mode 100644 index 00000000..d87aa946 --- /dev/null +++ b/demo/Diffusion/demo_diffusion/model/controlnet.py @@ -0,0 +1,225 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +import os + +import torch + +from demo_diffusion.dynamic_import import import_from_diffusers +from demo_diffusion.model import base_model, load, optimizer + +# List of models to import from diffusers.models +models_to_import = ["SD3Transformer2DModel", "SD3ControlNetModel"] +for model in models_to_import: + globals()[model] = import_from_diffusers(model, "diffusers.models") + + +class SD3ControlNetWrapper(torch.nn.Module): + def __init__(self, controlnet): + super().__init__() + self.controlnet = controlnet + + def forward(self, hidden_states, controlnet_cond, conditioning_scale, pooled_projections, timestep): + params = { + "hidden_states": hidden_states, + "pooled_projections": pooled_projections, + "timestep": timestep, + "controlnet_cond": controlnet_cond, + "conditioning_scale": conditioning_scale, + } + out = self.controlnet(**params)["controlnet_block_samples"] + return torch.stack(out, dim=0) + + +class SD3ControlNet(base_model.BaseModel): + + def __init__( + self, + version, + controlnet, + pipeline, + device, + hf_token, + verbose, + framework_model_dir, + fp16=False, + tf32=False, + bf16=False, + int8=False, + fp8=False, + max_batch_size=16, + build_strongly_typed=False, + do_classifier_free_guidance=False, + ): + super(SD3ControlNet, self).__init__( + version, + pipeline, + device=device, + hf_token=hf_token, + verbose=verbose, + framework_model_dir=framework_model_dir, + fp16=fp16, + tf32=tf32, + bf16=bf16, + int8=int8, + fp8=fp8, + max_batch_size=max_batch_size, + ) + self.path = load.get_path(version, pipeline, controlnet) + self.subfolder = "controlnet_{}".format(controlnet) + self.controlnet_model_dir = load.get_checkpoint_dir( + self.framework_model_dir, self.version, self.pipeline, self.subfolder + ) + self.transformer_model_dir = load.get_checkpoint_dir( + self.framework_model_dir, self.version, self.pipeline, "transformer" + ) + if not os.path.exists(self.controlnet_model_dir): + self.config = SD3ControlNetModel.load_config(self.path, token=self.hf_token) + else: + print(f"[I] Load SD3ControlNetModel config from: {self.controlnet_model_dir}") + self.config = SD3ControlNetModel.load_config(self.controlnet_model_dir) + self.xB = 2 if do_classifier_free_guidance else 1 # batch multiplier + self.build_strongly_typed = build_strongly_typed + + def get_model(self, torch_inference=""): + model_opts = ( + {"torch_dtype": torch.float16} if self.fp16 else {"torch_dtype": torch.bfloat16} if self.bf16 else {} + ) + if not load.is_model_cached(self.controlnet_model_dir, model_opts, self.hf_safetensor): + model = SD3ControlNetModel.from_pretrained(self.path, **model_opts, use_safetensors=self.hf_safetensor).to( + self.device + ) + model.save_pretrained(self.controlnet_model_dir, **model_opts) + else: + print(f"[I] Load SD3ControlNetModel model from: {self.controlnet_model_dir}") + model = SD3ControlNetModel.from_pretrained(self.controlnet_model_dir, **model_opts).to(self.device) + + # Load transformer model for pos_embed + transformer = SD3Transformer2DModel.from_pretrained(self.transformer_model_dir, **model_opts).to(self.device) + + if hasattr(model.config, "use_pos_embed") and model.config.use_pos_embed is False: + pos_embed = model._get_pos_embed_from_transformer(transformer) + model.pos_embed = pos_embed.to(model.dtype).to(model.device) + # Free transformer model + del transformer + + model = optimizer.optimize_checkpoint(model, torch_inference) + model = SD3ControlNetWrapper(model) + return model + + def get_input_names(self): + return ["hidden_states", "controlnet_cond", "conditioning_scale", "pooled_projections", "timestep"] + + def get_output_names(self): + return ["controlnet_block_samples"] + + def get_dynamic_axes(self): + xB = "2B" if self.xB == 2 else "B" + dynamic_axes = { + "hidden_states": {0: xB, 2: "H", 3: "W"}, + "controlnet_cond": {0: xB, 2: "H", 3: "W"}, + "pooled_projections": {0: xB}, + "timestep": {0: xB}, + } + return dynamic_axes + + def get_input_profile( + self, + batch_size: int, + image_height: int, + image_width: int, + static_batch: bool, + static_shape: bool, + ): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + ( + min_batch, + max_batch, + _, + _, + _, + _, + min_latent_height, + max_latent_height, + min_latent_width, + max_latent_width, + ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_shape) + + input_profile = { + "hidden_states": [ + (self.xB * min_batch, self.config["in_channels"], min_latent_height, min_latent_width), + (self.xB * batch_size, self.config["in_channels"], latent_height, latent_width), + (self.xB * max_batch, self.config["in_channels"], max_latent_height, max_latent_width), + ], + "timestep": [(self.xB * min_batch,), (self.xB * batch_size,), (self.xB * max_batch,)], + "pooled_projections": [ + (self.xB * min_batch, self.config["pooled_projection_dim"]), + (self.xB * batch_size, self.config["pooled_projection_dim"]), + (self.xB * max_batch, self.config["pooled_projection_dim"]), + ], + "controlnet_cond": [ + (self.xB * min_batch, self.config["in_channels"], min_latent_height, min_latent_width), + (self.xB * batch_size, self.config["in_channels"], latent_height, latent_width), + (self.xB * max_batch, self.config["in_channels"], max_latent_height, max_latent_width), + ], + } + return input_profile + + def get_shape_dict(self, batch_size, image_height, image_width): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + shape_dict = { + "hidden_states": (self.xB * batch_size, self.config["in_channels"], latent_height, latent_width), + "timestep": (self.xB * batch_size,), + "pooled_projections": (self.xB * batch_size, self.config["pooled_projection_dim"]), + "controlnet_cond": (self.xB * batch_size, self.config["in_channels"], latent_height, latent_width), + "conditioning_scale": (1,), + "controlnet_block_samples": ( + self.config["num_layers"], + self.xB * batch_size, + latent_height // 2 * latent_width // 2, + self.config["num_attention_heads"] * self.config["attention_head_dim"], + ), + } + return shape_dict + + def get_sample_input(self, batch_size, image_height, image_width, static_shape): + dtype = torch.float16 if self.fp16 else torch.bfloat16 if self.bf16 else torch.float32 + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + sample_input = ( + torch.randn( + self.xB * batch_size, + self.config["in_channels"], + latent_height, + latent_width, + dtype=dtype, + device=self.device, + ), + torch.randn( + self.xB * batch_size, + self.config["in_channels"], + latent_height, + latent_width, + dtype=dtype, + device=self.device, + ), + torch.tensor(1.0, dtype=dtype, device=self.device), + torch.randn(self.xB * batch_size, self.config["pooled_projection_dim"], dtype=dtype, device=self.device), + torch.randn(self.xB * batch_size, dtype=torch.float32, device=self.device), + ) + + return sample_input diff --git a/demo/Diffusion/demo_diffusion/model/diffusion_transformer.py b/demo/Diffusion/demo_diffusion/model/diffusion_transformer.py index 5feb07c0..7f773ced 100644 --- a/demo/Diffusion/demo_diffusion/model/diffusion_transformer.py +++ b/demo/Diffusion/demo_diffusion/model/diffusion_transformer.py @@ -27,7 +27,7 @@ from demo_diffusion.utils_sd3.sd3_impls import BaseModel as BaseModelSD3 # List of models to import from diffusers.models -models_to_import = ["FluxTransformer2DModel", "SD3Transformer2DModel", "SD3ControlNetModel"] +models_to_import = ["FluxTransformer2DModel", "SD3Transformer2DModel"] for model in models_to_import: globals()[model] = import_from_diffusers(model, "diffusers.models") @@ -323,52 +323,6 @@ def optimize(self, onnx_graph): return super().optimize(onnx_graph, fuse_mha_qkv_int8=True) return super().optimize(onnx_graph) - -class Transformer3DControlNetModel(torch.nn.Module): - def __init__(self, transformer, controlnets) -> None: - super().__init__() - self.transformer = transformer - self.controlnets = controlnets - - def forward( - self, - hidden_states, - encoder_hidden_states, - pooled_projections, - timestep, - controlnet_image, - controlnet_scales, - ): - for i, (scale, controlnet) in enumerate(zip(controlnet_scales, self.controlnets)): - block_samples = controlnet( - hidden_states=hidden_states, - timestep=timestep, - pooled_projections=pooled_projections, - controlnet_cond=controlnet_image, - conditioning_scale=scale, - return_dict=False, - )[0] - - # merge samples - if i == 0: - control_block_samples = block_samples - else: - control_block_samples = [ - control_block_sample + block_sample - for control_block_sample, block_sample in zip(control_block_samples[0], block_samples[0]) - ] - control_block_samples = (tuple(control_block_samples),) - - noise_pred = self.transformer( - hidden_states=hidden_states, - timestep=timestep, - encoder_hidden_states=encoder_hidden_states, - pooled_projections=pooled_projections, - block_controlnet_hidden_states=control_block_samples, - ) - return noise_pred - - class UpcastLayer(torch.nn.Module): def __init__(self, base_layer: torch.nn.Module, upcast_to: torch.dtype): super().__init__() @@ -409,6 +363,9 @@ def __init__( fp16=False, tf32=False, bf16=False, + fp8=False, + int8=False, + fp4=False, max_batch_size=16, text_maxlen=256, build_strongly_typed=False, @@ -426,6 +383,9 @@ def __init__( fp16=fp16, tf32=tf32, bf16=bf16, + fp8=fp8, + int8=int8, + fp4=fp4, max_batch_size=max_batch_size, text_maxlen=text_maxlen, ) @@ -443,6 +403,7 @@ def __init__( self.weight_streaming_budget_percentage = weight_streaming_budget_percentage self.out_channels = self.config.get("out_channels") self.xB = 2 if do_classifier_free_guidance else 1 # batch multiplier + self.num_controlnet_layers = 19 # Can be queried from the ControlNet model config def get_model(self, torch_inference=""): model_opts = ( @@ -470,12 +431,15 @@ def get_model(self, torch_inference=""): return model def get_input_names(self): - return [ + input_names = [ "hidden_states", "encoder_hidden_states", "pooled_projections", "timestep", ] + if not self.fp8: + input_names.append("block_controlnet_hidden_states") + return input_names def get_output_names(self): return ["latent"] @@ -489,6 +453,8 @@ def get_dynamic_axes(self): "timestep": {0: xB}, "latent": {0: xB, 2: "H", 3: "W"}, } + if not self.fp8: + dynamic_axes["block_controlnet_hidden_states"] = {1: xB, 2: "latent_dim"} return dynamic_axes def get_input_profile( @@ -531,6 +497,28 @@ def get_input_profile( ], "timestep": [(self.xB * min_batch,), (self.xB * batch_size,), (self.xB * max_batch,)], } + if not self.fp8: + input_profile["block_controlnet_hidden_states"] = [ + ( + self.num_controlnet_layers, + self.xB * min_batch, + min_latent_height // self.config["patch_size"] * min_latent_width // self.config["patch_size"], + self.config["num_attention_heads"] * self.config["attention_head_dim"], + ), + ( + self.num_controlnet_layers, + self.xB * batch_size, + latent_height // self.config["patch_size"] * latent_width // self.config["patch_size"], + self.config["num_attention_heads"] * self.config["attention_head_dim"], + ), + ( + self.num_controlnet_layers, + self.xB * max_batch, + max_latent_height // self.config["patch_size"] * max_latent_width // self.config["patch_size"], + self.config["num_attention_heads"] * self.config["attention_head_dim"], + ), + ] + return input_profile def get_shape_dict(self, batch_size, image_height, image_width): @@ -542,6 +530,13 @@ def get_shape_dict(self, batch_size, image_height, image_width): "timestep": (self.xB * batch_size,), "latent": (self.xB * batch_size, self.out_channels, latent_height, latent_width), } + if not self.fp8: + shape_dict["block_controlnet_hidden_states"] = ( + self.num_controlnet_layers, + self.xB * batch_size, + latent_height // self.config["patch_size"] * latent_width // self.config["patch_size"], + self.config["num_attention_heads"] * self.config["attention_head_dim"], + ) return shape_dict def get_sample_input(self, batch_size, image_height, image_width, static_shape): @@ -567,130 +562,18 @@ def get_sample_input(self, batch_size, image_height, image_width, static_shape): torch.randn(self.xB * batch_size, self.config["pooled_projection_dim"], dtype=dtype, device=self.device), torch.randn(self.xB * batch_size, dtype=torch.float32, device=self.device), ) - return sample_input - - -class SD3TransformerModelControlNet(SD3TransformerModel): - def __init__( - self, - version, - pipeline, - device, - hf_token, - verbose, - framework_model_dir, - fp16=False, - tf32=False, - bf16=False, - max_batch_size=16, - text_maxlen=256, - build_strongly_typed=False, - weight_streaming=False, - weight_streaming_budget_percentage=None, - do_classifier_free_guidance=False, - controlnets=None, - ): - super(SD3TransformerModelControlNet, self).__init__( - version, - pipeline, - device=device, - hf_token=hf_token, - verbose=verbose, - framework_model_dir=framework_model_dir, - fp16=fp16, - tf32=tf32, - bf16=bf16, - max_batch_size=max_batch_size, - text_maxlen=text_maxlen, - build_strongly_typed=build_strongly_typed, - weight_streaming=weight_streaming, - weight_streaming_budget_percentage=weight_streaming_budget_percentage, - do_classifier_free_guidance=do_classifier_free_guidance, - ) - self.controlnets = load.get_path(version, pipeline, controlnets) if controlnets else None - - def get_model(self, torch_inference=""): - model = super().get_model(torch_inference) - cnet_model_opts = {"torch_dtype": torch.float16} if self.fp16 else {"torch_dtype": torch.bfloat16} if self.bf16 else {} - controlnets = torch.nn.ModuleList( - [SD3ControlNetModel.from_pretrained(path, **cnet_model_opts, use_safetensors=self.hf_safetensor).to(self.device) for path in self.controlnets] - ) - for controlnet in controlnets: - if hasattr(controlnet.config, "use_pos_embed") and controlnet.config.use_pos_embed is False: - pos_embed = controlnet._get_pos_embed_from_transformer(model) - controlnet.pos_embed = pos_embed.to(controlnet.dtype).to(controlnet.device) - model = Transformer3DControlNetModel(model, controlnets) - model = optimizer.optimize_checkpoint(model, torch_inference) - return model - - def get_input_names(self): - return super().get_input_names() + ["controlnet_image", "controlnet_scales"] - - def get_dynamic_axes(self): - xB = "2B" if self.xB == 2 else "B" - dynamic_axes = super().get_dynamic_axes() - dynamic_axes.update({ - "controlnet_image": {0: xB, 2: "H", 3: "W"}, - "controlnet_scales": {0: "S"} - }) - return dynamic_axes - - def get_input_profile( - self, - batch_size: int, - image_height: int, - image_width: int, - static_batch: bool, - static_shape: bool, - ): - latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) - ( - min_batch, - max_batch, - _, - _, - _, - _, - min_latent_height, - max_latent_height, - min_latent_width, - max_latent_width, - ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_shape) - - input_profile = super().get_input_profile(batch_size, image_height, image_width, static_batch, static_shape) - input_profile.update({ - "controlnet_image": [ - (self.xB * min_batch, self.config["in_channels"], min_latent_height, min_latent_width), - (self.xB * batch_size, self.config["in_channels"], latent_height, latent_width), - (self.xB * max_batch, self.config["in_channels"], max_latent_height, max_latent_width), - ], - "controlnet_scales": [(len(self.controlnets) * min_batch,), (len(self.controlnets) * batch_size,), (len(self.controlnets) * max_batch,)], - }) - return input_profile - - def get_shape_dict(self, batch_size, image_height, image_width): - latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) - shape_dict = super().get_shape_dict(batch_size, image_height, image_width) - shape_dict.update({ - "controlnet_image": (self.xB * batch_size, self.config["in_channels"], latent_height, latent_width), - "controlnet_scales": (len(self.controlnets),), - }) - return shape_dict - - def get_sample_input(self, batch_size, image_height, image_width, static_shape): - dtype = torch.float16 if self.fp16 else torch.bfloat16 if self.bf16 else torch.float32 - latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) - sample_input = super().get_sample_input(batch_size, image_height, image_width, static_shape) - sample_input += ( - torch.randn( - self.xB * batch_size, - self.config["in_channels"], - latent_height, - latent_width, - dtype=dtype, - device=self.device, - ), - torch.randn(len(self.controlnets), dtype=dtype, device=self.device), - ) + if not self.fp8: + sample_input += ( + { + "block_controlnet_hidden_states": torch.randn( + self.num_controlnet_layers, + self.xB * batch_size, + latent_height // self.config["patch_size"] * latent_width // self.config["patch_size"], + self.config["num_attention_heads"] * self.config["attention_head_dim"], + dtype=dtype, + device=self.device, + ), + } + ) return sample_input diff --git a/demo/Diffusion/demo_diffusion/model/load.py b/demo/Diffusion/demo_diffusion/model/load.py index 4716a298..1848ba1a 100644 --- a/demo/Diffusion/demo_diffusion/model/load.py +++ b/demo/Diffusion/demo_diffusion/model/load.py @@ -25,9 +25,10 @@ import sys from typing import List, Optional -import onnx import torch +import onnx + def onnx_graph_needs_external_data(onnx_graph: onnx.ModelProto) -> bool: """Return true if ONNX graph needs to store external data.""" @@ -45,7 +46,7 @@ def get_path(version: str, pipeline: "pipeline.DiffusionPipeline", controlnets: if version == "xl-1.0": return ["diffusers/controlnet-canny-sdxl-1.0"] if version == "3.5-large": - return [f"stabilityai/stable-diffusion-3.5-large-controlnet-{modality}" for modality in controlnets] + return f"stabilityai/stable-diffusion-3.5-large-controlnet-{controlnets}" return ["lllyasviel/sd-controlnet-" + modality for modality in controlnets] if version in ("1.4", "1.5") and pipeline.is_inpaint(): diff --git a/demo/Diffusion/demo_diffusion/pipeline/diffusion_pipeline.py b/demo/Diffusion/demo_diffusion/pipeline/diffusion_pipeline.py index 626b903b..217499e1 100755 --- a/demo/Diffusion/demo_diffusion/pipeline/diffusion_pipeline.py +++ b/demo/Diffusion/demo_diffusion/pipeline/diffusion_pipeline.py @@ -143,6 +143,7 @@ def __init__( weight_streaming=False, text_encoder_weight_streaming_budget_percentage=None, denoiser_weight_streaming_budget_percentage=None, + controlnet=None, ): """ Initializes the Diffusion pipeline. @@ -190,6 +191,8 @@ def __init__( Weight streaming budget as a percentage of the size of total streamable weights for the text encoder model. denoiser_weight_streaming_budget_percentage (`int`, defaults to None): Weight streaming budget as a percentage of the size of total streamable weights for the denoiser model. + controlnet (str, defaults to None): + Type of ControlNet to use for the pipeline. """ self.bf16 = bf16 self.dd_path = dd_path @@ -218,7 +221,7 @@ def __init__( self.text_encoder_weight_streaming_budget_percentage = text_encoder_weight_streaming_budget_percentage self.denoiser_weight_streaming_budget_percentage = denoiser_weight_streaming_budget_percentage - self.stages = self.get_model_names(self.pipeline_type) + self.stages = self.get_model_names(self.pipeline_type, controlnet) # config to store additional info self.config = {} if torch_fallback: @@ -248,7 +251,9 @@ def __init__( try: scheduler_class = scheduler_class_map[scheduler] except KeyError: - raise ValueError(f"Unsupported scheduler {scheduler}. Should be one of {list(scheduler_class.keys())}.") + raise ValueError( + f"Unsupported scheduler {scheduler}. Should be one of {list(scheduler_class_map.keys())}." + ) self.scheduler = make_scheduler(scheduler_class, version, pipeline_type, hf_token, framework_model_dir) self.torch_inference = torch_inference @@ -286,7 +291,7 @@ def FromArgs(cls, args: argparse.Namespace, pipeline_type: PIPELINE_TYPE) -> Dif @classmethod @abc.abstractmethod - def get_model_names(cls, pipeline_type: PIPELINE_TYPE) -> List[str]: + def get_model_names(cls, pipeline_type: PIPELINE_TYPE, controlnet_type: str = None) -> List[str]: """Return a list of model names used by this pipeline.""" raise NotImplementedError("get_model_names cannot be called from the abstract base class.") diff --git a/demo/Diffusion/demo_diffusion/pipeline/flux_pipeline.py b/demo/Diffusion/demo_diffusion/pipeline/flux_pipeline.py index 57241c61..cd97ae1e 100644 --- a/demo/Diffusion/demo_diffusion/pipeline/flux_pipeline.py +++ b/demo/Diffusion/demo_diffusion/pipeline/flux_pipeline.py @@ -163,7 +163,7 @@ def FromArgs(cls, args: argparse.Namespace, pipeline_type: PIPELINE_TYPE) -> Flu ) @classmethod - def get_model_names(cls, pipeline_type: PIPELINE_TYPE) -> List[str]: + def get_model_names(cls, pipeline_type: PIPELINE_TYPE, controlnet_type: str = None) -> List[str]: """Return a list of model names used by this pipeline. Overrides: diff --git a/demo/Diffusion/demo_diffusion/pipeline/stable_diffusion_35_pipeline.py b/demo/Diffusion/demo_diffusion/pipeline/stable_diffusion_35_pipeline.py index 7dac9771..4ea2e3b2 100644 --- a/demo/Diffusion/demo_diffusion/pipeline/stable_diffusion_35_pipeline.py +++ b/demo/Diffusion/demo_diffusion/pipeline/stable_diffusion_35_pipeline.py @@ -20,6 +20,7 @@ import inspect import os import time +import warnings from typing import Any, List, Union import tensorrt as trt @@ -32,8 +33,8 @@ from demo_diffusion import path as path_module from demo_diffusion.model import ( CLIPWithProjModel, + SD3ControlNet, SD3TransformerModel, - SD3TransformerModelControlNet, T5Model, VAEEncoderModel, VAEModel, @@ -68,7 +69,7 @@ def __init__( pipeline_type=PIPELINE_TYPE.TXT2IMG, guidance_scale: float = 7.0, max_sequence_length: int = 256, - controlnets=None, + controlnet=None, **kwargs, ): """ @@ -84,20 +85,18 @@ def __init__( Higher guidance scale encourages to generate images that are closely linked to the text prompt, usually at the expense of lower image quality. max_sequence_length (`int`, defaults to 256): Maximum sequence length to use with the `prompt`. - controlnets (str): - Which ControlNet/ControlNets to use. + controlnet (str): + Which ControlNet to use. """ - super().__init__( - version=version, - pipeline_type=pipeline_type, - **kwargs - ) + super().__init__(version=version, pipeline_type=pipeline_type, controlnet=controlnet, **kwargs) self.fp16 = True if not self.bf16 else False self.force_weakly_typed_t5 = False + self.config["clip_g_torch_fallback"] = True + self.config["clip_l_torch_fallback"] = True self.config["clip_hidden_states"] = True - self.controlnets = controlnets + self.controlnet = controlnet self.guidance_scale = guidance_scale self.do_classifier_free_guidance = self.guidance_scale > 1 @@ -115,8 +114,12 @@ def FromArgs(cls, args: argparse.Namespace, pipeline_type: PIPELINE_TYPE) -> Sta DO_RETURN_LATENTS = False # Resolve all paths. + controlnet_type = args.controlnet_type if "controlnet_type" in args else None dd_path = path_module.resolve_path( - cls.get_model_names(pipeline_type), args, pipeline_type, cls._get_pipeline_uid(args.version) + cls.get_model_names(pipeline_type, controlnet_type), + args, + pipeline_type, + cls._get_pipeline_uid(args.version), ) return cls( @@ -125,7 +128,7 @@ def FromArgs(cls, args: argparse.Namespace, pipeline_type: PIPELINE_TYPE) -> Sta pipeline_type=pipeline_type, guidance_scale=args.guidance_scale, max_sequence_length=args.max_sequence_length, - controlnets=args.controlnet_type if "controlnet_type" in args else None, + controlnet=controlnet_type, bf16=args.bf16, low_vram=args.low_vram, torch_fallback=args.torch_fallback, @@ -145,14 +148,15 @@ def FromArgs(cls, args: argparse.Namespace, pipeline_type: PIPELINE_TYPE) -> Sta ) @classmethod - def get_model_names(cls, pipeline_type: PIPELINE_TYPE) -> List[str]: + def get_model_names(cls, pipeline_type: PIPELINE_TYPE, controlnet_type: str = None) -> List[str]: """Return a list of model names used by this pipeline. Overrides: DiffusionPipeline.get_model_names """ if pipeline_type.is_controlnet(): - return ["clip_l", "clip_g", "t5", "transformer", "vae", "vae_encoder"] + assert controlnet_type, "ControlNet type must be specified for ControlNet pipelines" + return ["clip_l", "clip_g", "t5", "transformer", "vae", "vae_encoder", f"controlnet_{controlnet_type}"] return ["clip_l", "clip_g", "t5", "transformer", "vae"] def download_onnx_models(self, model_name: str, model_config: dict[str, Any]) -> None: @@ -160,12 +164,6 @@ def download_onnx_models(self, model_name: str, model_config: dict[str, Any]) -> raise ValueError( "ONNX models can be downloaded only for the following precisions: BF16, FP8. This pipeline is running in FP16." ) - if self.version == "3.5-medium": - raise ValueError( - "ONNX models can be downloaded only for the large variant of Stable Diffusion 3.5. This pipeline is running the medium variant." - ) - if self.controlnets: - raise ValueError("ONNX model download is not supported for ControlNet models.") hf_download_path = "-".join([load.get_path(self.version, self.pipeline_type.name), "tensorrt"]) model_path = model_config["onnx_opt_path"] @@ -177,6 +175,9 @@ def download_onnx_models(self, model_name: str, model_config: dict[str, Any]) -> dirname = os.path.join(model_name, "fp8") elif self.bf16: dirname = os.path.join(model_name, "bf16") + elif "controlnet" in model_name: + hf_download_path_cnet = hf_download_path.replace("large", "controlnets") + dirname = f"controlnet_{self.controlnet}" elif model_name in self.stages: dirname = model_name else: @@ -184,7 +185,7 @@ def download_onnx_models(self, model_name: str, model_config: dict[str, Any]) -> dirname = os.path.join("ONNX", dirname) snapshot_download( - repo_id=hf_download_path, + repo_id=hf_download_path if "controlnet" not in model_name else hf_download_path_cnet, allow_patterns=os.path.join(dirname, "*"), local_dir=base_dir, token=self.hf_token, @@ -241,6 +242,9 @@ def _initialize_models(self, framework_model_dir, int8, fp8, fp4): self.bf16 = True if int8 or fp8 or fp4 else self.bf16 self.fp16 = True if not self.bf16 else False self.tf32=True + self.fp8 = fp8 + self.int8 = int8 + self.fp4 = fp4 if "clip_l" in self.stages: self.models["clip_l"] = CLIPWithProjModel( **models_args, @@ -274,26 +278,29 @@ def _initialize_models(self, framework_model_dir, int8, fp8, fp4): ) if "transformer" in self.stages: - transformer_args = { - "bf16": self.bf16, - "fp16": self.fp16, - "tf32": self.tf32, - "text_maxlen": self.models["t5"].text_maxlen + self.models["clip_g"].text_maxlen, - "build_strongly_typed": not self.controlnets, - "weight_streaming": self.weight_streaming, - "do_classifier_free_guidance": self.do_classifier_free_guidance, - } - if self.controlnets: - self.models["transformer"] = SD3TransformerModelControlNet( - **models_args, - **transformer_args, - controlnets=self.controlnets, - ) - else: - self.models["transformer"] = SD3TransformerModel( - **models_args, - **transformer_args, - ) + self.models["transformer"] = SD3TransformerModel( + **models_args, + fp16=self.fp16, + tf32=self.tf32, + bf16=self.bf16, + fp8=self.fp8, + int8=self.int8, + fp4=self.fp4, + text_maxlen=self.models["t5"].text_maxlen + self.models["clip_g"].text_maxlen, + build_strongly_typed=False, + weight_streaming=self.weight_streaming, + do_classifier_free_guidance=self.do_classifier_free_guidance, + ) + + if f"controlnet_{self.controlnet}" in self.stages: + self.models[f"controlnet_{self.controlnet}"] = SD3ControlNet( + **models_args, + fp16=self.fp16, + tf32=self.tf32, + bf16=self.bf16, + do_classifier_free_guidance=self.do_classifier_free_guidance, + controlnet=self.controlnet, + ) if "vae" in self.stages: self.models["vae"] = VAEModel(**models_args, fp16=self.fp16, tf32=self.tf32, bf16=self.bf16) @@ -314,7 +321,7 @@ def _initialize_models(self, framework_model_dir, int8, fp8, fp4): if "vae" in self.stages and self.models["vae"] is not None else 16 ) - if "canny" in self.controlnets: + if "canny" in self.controlnet: self.image_processor = SD3CannyImageProcessor() else: self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) @@ -387,7 +394,7 @@ def _tokenize( if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): removed_text = tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) - TRT_LOGGER.warning( + warnings.warn( "The following part of your input was truncated because `max_sequence_length` is set to " f" {max_sequence_length} tokens: {removed_text}" ) @@ -638,6 +645,15 @@ def retrieve_timesteps( timesteps = scheduler.timesteps return timesteps, num_inference_steps + def get_control_block_samples(self, params_controlnet, controlnet_name="controlnet"): + # Predict the controlnet block samples + if self.torch_inference or self.torch_fallback[controlnet_name]: + block_samples = self.torch_models[controlnet_name](**params_controlnet) + else: + block_samples = self.run_engine(controlnet_name, params_controlnet)["controlnet_block_samples"].clone() + + return block_samples + def denoise_latents( self, latents: torch.Tensor, @@ -647,7 +663,8 @@ def denoise_latents( guidance_scale: float, denoiser="transformer", control_image=None, - cond_scale=None, + controlnet_scale=None, + controlnet_keep=None, ) -> torch.Tensor: do_autocast = self.torch_inference != "" and self.models[denoiser].fp16 with torch.autocast("cuda", enabled=do_autocast): @@ -659,14 +676,45 @@ def denoise_latents( # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep_inp = timestep.expand(latents_model_input.shape[0]) + controlnet_name = f"controlnet_{self.controlnet}" + if control_image is not None: + cond_scale = controlnet_scale * controlnet_keep[step_index] + + cast_to = ( + torch.float16 + if self.models[controlnet_name].fp16 + else torch.bfloat16 if self.models[controlnet_name].bf16 else torch.float32 + ) + params_controlnet = { + "hidden_states": latents_model_input, + "timestep": timestep_inp, + "pooled_projections": pooled_prompt_embeds, + "controlnet_cond": control_image, + "conditioning_scale": cond_scale.to(self.device).to(cast_to), + } + + control_block_samples = self.get_control_block_samples(params_controlnet, controlnet_name) + else: + latent_shape = latents_model_input.shape + # Initialize control block samples with zeros. Hard-coding some dimensions that can only be queried if a controlnet is used. + control_block_samples = torch.zeros( + self.models["transformer"].num_controlnet_layers, + latent_shape[0], # batch size + latent_shape[2] // 2 * latent_shape[3] // 2, + self.models["transformer"].config["num_attention_heads"] + * self.models["transformer"].config["attention_head_dim"], + dtype=latents.dtype, + device=latents.device, + ) params = { "hidden_states": latents_model_input, "timestep": timestep_inp, "encoder_hidden_states": prompt_embeds, "pooled_projections": pooled_prompt_embeds, } - if control_image is not None: - params.update({"controlnet_image": control_image, "controlnet_scales": cond_scale}) + if not self.fp8: + params["block_controlnet_hidden_states"] = control_block_samples + # Predict the noise residual if self.torch_inference or self.torch_fallback[denoiser]: noise_pred = self.torch_models[denoiser](**params)["sample"] @@ -743,7 +791,7 @@ def infer( image_height: int, image_width: int, control_image=None, - controlnet_scales=None, + controlnet_scale=None, control_guidance_start: Union[float, List[float]] = 0.0, control_guidance_end: Union[float, List[float]] = 1.0, warmup=False, @@ -763,9 +811,8 @@ def infer( Width (in pixels) of the image to be generated. Must be a multiple of 8. control_image (PIL.Image.Image): The control image to guide the image generation. - controlnet_scales (torch.Tensor): - A tensor which containes ControlNet scales, essential for multi ControlNet. - Must be equal to number of Controlnets. + controlnet_scale (torch.Tensor): + A tensor which contains ControlNet scale, essential for multi ControlNet. control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0): The percentage of total steps at which the ControlNet starts applying. control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0): @@ -778,14 +825,6 @@ def infer( assert len(prompt) == len(negative_prompt) self.batch_size = len(prompt) - # controlnet guidance start and end - if self.controlnets: - mult = len(self.controlnets) - control_guidance_start, control_guidance_end = ( - mult * [control_guidance_start], - mult * [control_guidance_end], - ) - # Spatial dimensions of latent tensor assert image_height % (self.vae_scale_factor * self.patch_size) == 0, ( f"image height not supported {image_height}" @@ -836,23 +875,15 @@ def infer( ) # 4. Prepare control image - cond_scale = None + controlnet_keep = [] if control_image is not None: # Process controlnet_scales - controlnet_keep = [] for i in range(len(timesteps)): keeps = [ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) - for s, e in zip(control_guidance_start, control_guidance_end) + for s, e in zip([control_guidance_start], [control_guidance_end]) ] - controlnet_keep.append(keeps[0] if len(self.controlnets) == 1 else keeps) - if isinstance(controlnet_keep[i], list): - cond_scale = [c * s for c, s in zip(controlnet_scales, controlnet_keep[i])] - else: - controlnet_cond_scale = controlnet_scales - if isinstance(controlnet_cond_scale, list): - controlnet_cond_scale = controlnet_cond_scale[0] - cond_scale = controlnet_cond_scale * controlnet_keep[i] + controlnet_keep.append(keeps[0]) control_image = self.prepare_image( image=control_image, @@ -865,8 +896,9 @@ def infer( with self.model_memory_manager(["vae_encoder"], low_vram=self.low_vram): control_image = self.encode_image(control_image) - # 5 Denoise - with self.model_memory_manager(["transformer"], low_vram=self.low_vram): + # 5. Denoise + denoiser_list = ["transformer", f"controlnet_{self.controlnet}"] if self.controlnet else ["transformer"] + with self.model_memory_manager(denoiser_list, low_vram=self.low_vram): latents = self.denoise_latents( latents=latents, prompt_embeds=prompt_embeds, @@ -874,7 +906,9 @@ def infer( timesteps=timesteps, guidance_scale=self.guidance_scale, control_image=control_image, - cond_scale=cond_scale, + # TODO: support multiple controlnets + controlnet_scale=controlnet_scale, + controlnet_keep=controlnet_keep, ) # 6. Decode Latents diff --git a/demo/Diffusion/demo_diffusion/pipeline/stable_diffusion_pipeline.py b/demo/Diffusion/demo_diffusion/pipeline/stable_diffusion_pipeline.py index 3ae7b687..01f6e2a8 100644 --- a/demo/Diffusion/demo_diffusion/pipeline/stable_diffusion_pipeline.py +++ b/demo/Diffusion/demo_diffusion/pipeline/stable_diffusion_pipeline.py @@ -49,10 +49,10 @@ CLIPModel, CLIPWithProjModel, SDLoraLoader, - UNet2DConditionControlNetModel, UNetModel, UNetXLModel, UNetXLModelControlNet, + UNet2DConditionControlNetModel, VAEEncoderModel, VAEModel, get_clip_embedding_dim, From 488b23171b44beafefefea0dc4c0af096e95e2b1 Mon Sep 17 00:00:00 2001 From: Asfiya Baig Date: Mon, 11 Aug 2025 13:54:25 -0700 Subject: [PATCH 2/2] Update FP8 support validation Signed-off-by: Asfiya Baig --- demo/Diffusion/demo_diffusion/dd_argparse.py | 43 ++++++++++++-------- 1 file changed, 26 insertions(+), 17 deletions(-) diff --git a/demo/Diffusion/demo_diffusion/dd_argparse.py b/demo/Diffusion/demo_diffusion/dd_argparse.py index 79596ddd..ccfeb46f 100644 --- a/demo/Diffusion/demo_diffusion/dd_argparse.py +++ b/demo/Diffusion/demo_diffusion/dd_argparse.py @@ -311,25 +311,34 @@ def process_pipeline_args(args: argparse.Namespace) -> Tuple[Dict[str, Any], Dic if args.int8 and not any(args.version.startswith(prefix) for prefix in ("xl", "1.4", "1.5", "2.1")): raise ValueError("int8 quantization is only supported for SDXL, SD1.4, SD1.5 and SD2.1 pipelines.") - # fp8 support - if args.fp8 and not ( - any(args.version.startswith(prefix) for prefix in ("xl", "1.4", "1.5", "2.1", "3.5-large")) or is_flux - ): - raise ValueError( - "fp8 quantization is only supported for SDXL, SD1.4, SD1.5, SD2.1, SD3.5-large and FLUX pipelines." - ) - - if args.fp8 and hasattr(args, "controlnet_type"): - if args.version != "xl-1.0": + # fp8 support validation + if args.fp8: + # Check version compatibility + supported_versions = ("xl", "1.4", "1.5", "2.1", "3.5-large") + if not (any(args.version.startswith(prefix) for prefix in supported_versions) or is_flux): + raise ValueError( + "fp8 quantization is only supported for SDXL, SD1.4, SD1.5, SD2.1, SD3.5-large and FLUX pipelines." + ) + + # Check controlnet compatibility + if hasattr(args, "controlnet_type") and args.version != "xl-1.0": raise ValueError("fp8 controlnet quantization is only supported for SDXL.") - if args.fp8 and args.int8: - raise ValueError("Cannot apply both int8 and fp8 quantization, please choose only one.") - - if args.fp8 and sm_version < 89: - raise ValueError( - f"Cannot apply FP8 quantization for GPU with compute capability {sm_version / 10.0}. Only Ada and Hopper are supported." - ) + # Check for conflicting quantization + if args.int8: + raise ValueError("Cannot apply both int8 and fp8 quantization, please choose only one.") + + # Check GPU compute capability + if sm_version < 89: + raise ValueError( + f"Cannot apply FP8 quantization for GPU with compute capability {sm_version / 10.0}. A minimum compute capability of 8.9 is required." + ) + + # Check SD3.5-large specific requirement + if args.version == "3.5-large" and not args.download_onnx_models: + raise ValueError( + "Native FP8 quantization is not supported for SD3.5-large. Please pass --download-onnx-models." + ) # TensorRT ModelOpt quantization level if args.quantization_level == 0.0: