From fb84060fe446cd81d311c6cc858f10777ffb96e4 Mon Sep 17 00:00:00 2001
From: hlky
Date: Thu, 10 Aug 2023 11:17:47 +0100
Subject: [PATCH] update

---
 AITemplate/ait/ait.py                        |   4 +-
 AITemplate/ait/compile/clip.py               |  18 ++-
 AITemplate/ait/compile/controlnet.py         |  12 +-
 AITemplate/ait/compile/unet.py               |  38 ++---
 AITemplate/ait/compile/util.py               |  22 +--
 AITemplate/ait/compile/vae.py                |  16 +--
 AITemplate/ait/load.py                       |   5 +-
 AITemplate/ait/modeling/clip.py              | 121 +++++++++++-----
 AITemplate/ait/modeling/controlnet.py        |  27 ++--
 AITemplate/ait/modeling/resnet.py            |  14 +-
 AITemplate/ait/modeling/unet_2d_condition.py | 142 ++++++++++++-------
 AITemplate/ait/modeling/unet_blocks.py       |   7 +-
 AITemplate/ait/modeling/vae.py               |   4 +
 AITemplate/ait/util/mapping/clip.py          |  34 +----
 AITemplate/ait/util/mapping/unet.py          |  14 +-
 AITemplate/clip.py                           |  59 ++++++--
 AITemplate/unet.py                           |  19 +--
 AITemplate/vae.py                            |  15 +-
 18 files changed, 325 insertions(+), 246 deletions(-)

diff --git a/AITemplate/ait/ait.py b/AITemplate/ait/ait.py
index 78db724..449ba3f 100644
--- a/AITemplate/ait/ait.py
+++ b/AITemplate/ait/ait.py
@@ -122,7 +122,7 @@ def test_unet(
             timesteps=timesteps_pt,
             encoder_hidden_states=text_embeddings_pt,
             benchmark=benchmark,
-            add_embeds=add_embeds
+            add_embeds=add_embeds if xl else None,
         )
         print(output.shape)
         return output
@@ -158,6 +158,7 @@ def test_vae(
         width: int = 64,
         dtype="float16",
         device="cuda",
+        benchmark: bool = False,
     ):
         if "vae" not in self.modules:
             raise ValueError("vae module not loaded")
@@ -167,6 +168,7 @@ def test_vae(
         output = vae_inference(
             self.modules["vae"],
             vae_input=vae_input,
+            benchmark=benchmark,
         )
         print(output.shape)
         return output
diff --git a/AITemplate/ait/compile/clip.py b/AITemplate/ait/compile/clip.py
index 8e6eb96..9513919 100644
--- a/AITemplate/ait/compile/clip.py
+++ b/AITemplate/ait/compile/clip.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-
+import sys
 from aitemplate.compiler import compile_model
 from aitemplate.frontend import IntVar, Tensor
 from aitemplate.testing import detect_target
@@ -30,8 +30,10 @@ def compile_clip(
     dim=768,
     num_heads=12,
     depth=12,
-    use_fp16_acc=False,
-    convert_conv_to_gemm=False,
+    output_hidden_states=False,
+    text_projection_dim=None,
+    use_fp16_acc=True,
+    convert_conv_to_gemm=True,
     act_layer="gelu",
     constants=True,
     model_name="CLIPTextModel",
@@ -49,6 +51,8 @@ def compile_clip(
         causal=causal,
         mask_seq=mask_seq,
         act_layer=act_layer,
+        output_hidden_states=output_hidden_states,
+        text_projection_dim=text_projection_dim,
     )
 
     ait_mod.name_parameter_tensor()
@@ -62,17 +66,19 @@ def compile_clip(
         batch_size = IntVar(values=list(batch_size), name="batch_size")
 
     input_ids_ait = Tensor(
-        [batch_size, seqlen], name="input0", dtype="int64", is_input=True
+        [batch_size, seqlen], name="input_ids", dtype="int64", is_input=True
     )
     position_ids_ait = Tensor(
-        [batch_size, seqlen], name="input1", dtype="int64", is_input=True
+        [batch_size, seqlen], name="position_ids", dtype="int64", is_input=True
     )
+
     Y = ait_mod(input_ids=input_ids_ait, position_ids=position_ids_ait)
     mark_output(Y)
 
     target = detect_target(
         use_fp16_acc=use_fp16_acc, convert_conv_to_gemm=convert_conv_to_gemm
     )
+    dll_name = model_name + ".dll" if sys.platform == "win32" else model_name + ".so"
     compile_model(
-        Y, target, work_dir, model_name, constants=params_ait if constants else None
+        Y, target, work_dir, model_name, constants=params_ait if constants else None, dll_name=dll_name
     )
diff --git a/AITemplate/ait/compile/controlnet.py b/AITemplate/ait/compile/controlnet.py
index 5e0ae58..0b2f01c 100644
--- a/AITemplate/ait/compile/controlnet.py
+++ b/AITemplate/ait/compile/controlnet.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import sys
 import torch
 from aitemplate.compiler import compile_model
 from aitemplate.frontend import IntVar, Tensor
@@ -72,14 +73,14 @@ def compile_controlnet(
     embedding_size = IntVar(values=list(clip_chunks), name="embedding_size")
 
     latent_model_input_ait = Tensor(
-        [batch_size, height_d, width_d, 4], name="input0", is_input=True
+        [batch_size, height_d, width_d, 4], name="latent_model_input", is_input=True
     )
-    timesteps_ait = Tensor([batch_size], name="input1", is_input=True)
+    timesteps_ait = Tensor([batch_size], name="timesteps", is_input=True)
     text_embeddings_pt_ait = Tensor(
-        [batch_size, embedding_size, hidden_dim], name="input2", is_input=True
+        [batch_size, embedding_size, hidden_dim], name="encoder_hidden_states", is_input=True
     )
     controlnet_condition_ait = Tensor(
-        [batch_size, height_c, width_c, 3], name="input3", is_input=True
+        [batch_size, height_c, width_c, 3], name="control_hint", is_input=True
     )
 
     Y = ait_mod(
@@ -93,6 +94,7 @@ def compile_controlnet(
     target = detect_target(
         use_fp16_acc=use_fp16_acc, convert_conv_to_gemm=convert_conv_to_gemm
     )
+    dll_name = model_name + ".dll" if sys.platform == "win32" else model_name + ".so"
     compile_model(
-        Y, target, work_dir, model_name, constants=params_ait if constants else None
+        Y, target, work_dir, model_name, constants=params_ait if constants else None, dll_name=dll_name,
     )
diff --git a/AITemplate/ait/compile/unet.py b/AITemplate/ait/compile/unet.py
index 68be153..44531f2 100644
--- a/AITemplate/ait/compile/unet.py
+++ b/AITemplate/ait/compile/unet.py
@@ -12,9 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
# +import sys import torch from aitemplate.compiler import compile_model -from aitemplate.frontend import IntVar, Tensor +from aitemplate.frontend import IntVar, Tensor, DynamicProfileStrategy from aitemplate.testing import detect_target from ..modeling.unet_2d_condition import ( @@ -109,10 +110,9 @@ def compile_unet( pt_mod = pt_mod.eval() params_ait = map_unet(pt_mod, dim=dim, in_channels=in_channels, conv_in_key="conv_in_weight", dtype=dtype) - static_shape = width[0] == width[1] and height[0] == height[1] and batch_size[0] == batch_size[1] + static_shape = width[0] == width[1] and height[0] == height[1] if static_shape: - batch_size = batch_size[0] * 2 # double batch size for unet height = height[0] // down_factor width = width[0] // down_factor height_d = height @@ -132,20 +132,18 @@ def compile_unet( height_8_d = height_8 width_8_d = width_8 else: - batch_size = batch_size[0], batch_size[1] * 2 # double batch size for unet - batch_size = IntVar(values=list(batch_size), name="batch_size") - height = height[0] // down_factor, height[1] // down_factor - width = width[0] // down_factor, width[1] // down_factor + height = [x // down_factor for x in height] + width = [x // down_factor for x in width] height_d = IntVar(values=list(height), name="height_d") width_d = IntVar(values=list(width), name="width_d") height_1_d = IntVar(values=list(height), name="height_1_d") width_1_d = IntVar(values=list(width), name="width_1_d") - height_2 = height[0] // 2, height[1] // 2 - width_2 = width[0] // 2, width[1] // 2 - height_4 = height[0] // 4, height[1] // 4 - width_4 = width[0] // 4, width[1] // 4 - height_8 = height[0] // 8, height[1] // 8 - width_8 = width[0] // 8, width[1] // 8 + height_2 = [x // 2 for x in height] + width_2 = [x // 2 for x in width] + height_4 = [x // 4 for x in height] + width_4 = [x // 4 for x in width] + height_8 = [x // 8 for x in height] + width_8 = [x // 8 for x in width] height_2_d = IntVar(values=list(height_2), name="height_2_d") width_2_d = IntVar(values=list(width_2), name="width_2_d") height_4_d = IntVar(values=list(height_4), name="height_4_d") @@ -153,6 +151,9 @@ def compile_unet( height_8_d = IntVar(values=list(height_8), name="height_8_d") width_8_d = IntVar(values=list(width_8), name="width_8_d") + batch_size = batch_size[0], batch_size[1] * 2 # double batch size for unet + batch_size = IntVar(values=list(batch_size), name="batch_size") + if static_shape: embedding_size = 77 else: @@ -161,18 +162,18 @@ def compile_unet( latent_model_input_ait = Tensor( - [batch_size, height_d, width_d, in_channels], name="input0", is_input=True, dtype=dtype + [batch_size, height_d, width_d, in_channels], name="latent_model_input", is_input=True, dtype=dtype ) - timesteps_ait = Tensor([batch_size], name="input1", is_input=True, dtype=dtype) + timesteps_ait = Tensor([batch_size], name="timesteps", is_input=True, dtype=dtype) text_embeddings_pt_ait = Tensor( - [batch_size, embedding_size, hidden_dim], name="input2", is_input=True, dtype=dtype + [batch_size, embedding_size, hidden_dim], name="encoder_hidden_states", is_input=True, dtype=dtype ) class_labels = None #TODO: better way to handle this, enables class_labels for x4-upscaler if in_channels == 7: class_labels = Tensor( - [batch_size], name="input3", dtype="int64", is_input=True + [batch_size], name="class_labels", dtype="int64", is_input=True ) add_embeds = None @@ -287,6 +288,7 @@ def compile_unet( target = detect_target( use_fp16_acc=use_fp16_acc, convert_conv_to_gemm=convert_conv_to_gemm ) + dll_name = model_name + 
".dll" if sys.platform == "win32" else model_name + ".so" compile_model( - Y, target, work_dir, model_name, constants=params_ait if constants else None, do_optimize_graph=False if xl else True + Y, target, work_dir, model_name, constants=params_ait if constants else None, dll_name=dll_name, ) diff --git a/AITemplate/ait/compile/util.py b/AITemplate/ait/compile/util.py index 90cc1bc..e3e9067 100644 --- a/AITemplate/ait/compile/util.py +++ b/AITemplate/ait/compile/util.py @@ -12,11 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. # -def mark_output(y): - if type(y) is not tuple: - y = (y,) - for i in range(len(y)): - y[i]._attrs["is_output"] = True - y[i]._attrs["name"] = "output_%d" % (i) - y_shape = [d._attrs["values"] for d in y[i]._attrs["shape"]] - print("AIT output_{} shape: {}".format(i, y_shape)) +def mark_output(ys): + if type(ys) != tuple: + ys = (ys, ) + for i in range(len(ys)): + y = ys[i] + if type(y) == tuple: + for yy in y: + y_shape = [d._attrs["values"] for d in yy._attrs["shape"]] + y_name = yy._attrs["name"] + print("AIT {} shape: {}".format(y_name, y_shape)) + else: + y_shape = [d._attrs["values"] for d in y._attrs["shape"]] + y_name = y._attrs["name"] + print("AIT {} shape: {}".format(y_name, y_shape)) diff --git a/AITemplate/ait/compile/vae.py b/AITemplate/ait/compile/vae.py index d4509f2..ef0e5ba 100644 --- a/AITemplate/ait/compile/vae.py +++ b/AITemplate/ait/compile/vae.py @@ -13,6 +13,7 @@ # limitations under the License. # +import sys import torch from aitemplate.compiler import compile_model from aitemplate.frontend import IntVar, Tensor @@ -28,8 +29,8 @@ def compile_vae( batch_size=(1, 8), height=(64, 2048), width=(64, 2048), - use_fp16_acc=False, - convert_conv_to_gemm=False, + use_fp16_acc=True, + convert_conv_to_gemm=True, model_name="AutoencoderKL", constants=True, block_out_channels=[128, 256, 512, 512], @@ -92,7 +93,7 @@ def compile_vae( ait_input = Tensor( shape=[batch_size, height_d, width_d, 3 if vae_encode else latent_channels], - name="vae_input", + name="pixels" if vae_encode else "latent", is_input=True, dtype=dtype ) @@ -100,7 +101,7 @@ def compile_vae( if vae_encode: sample = Tensor( shape=[batch_size, height_d, width_d, latent_channels], - name="vae_sample", + name="random_sample", is_input=True, dtype=dtype, ) @@ -116,10 +117,7 @@ def compile_vae( target = detect_target( use_fp16_acc=use_fp16_acc, convert_conv_to_gemm=convert_conv_to_gemm ) + dll_name = model_name + ".dll" if sys.platform == "win32" else model_name + ".so" compile_model( - Y, - target, - work_dir, - model_name, - constants=params_ait if constants else None, + Y, target, work_dir, model_name, constants=params_ait if constants else None, dll_name=dll_name, ) diff --git a/AITemplate/ait/load.py b/AITemplate/ait/load.py index 1aaff34..1357030 100644 --- a/AITemplate/ait/load.py +++ b/AITemplate/ait/load.py @@ -139,8 +139,9 @@ def diffusers_unet( ): return UNet2DConditionModel.from_pretrained( hf_hub_or_path, - subfolder=subfolder, - revision=revision, + subfolder="unet" if not hf_hub_or_path.endswith("unet") else None, + variant="fp16", + use_safetensors=True, torch_dtype=torch_dtype_from_str(dtype) ) diff --git a/AITemplate/ait/modeling/clip.py b/AITemplate/ait/modeling/clip.py index eab0d9c..b03a79d 100644 --- a/AITemplate/ait/modeling/clip.py +++ b/AITemplate/ait/modeling/clip.py @@ -61,7 +61,8 @@ def __init__( self.to_k = nn.Linear(context_dim, inner_dim, bias=False, dtype=dtype) self.to_v = nn.Linear(context_dim, 
inner_dim, bias=False, dtype=dtype) self.to_out = nn.Sequential( - nn.Linear(inner_dim, query_dim, dtype=dtype), nn.Dropout(dropout, dtype=dtype) + nn.Linear(inner_dim, query_dim, dtype=dtype), + nn.Dropout(dropout, dtype=dtype), ) def forward(self, x, context=None, mask=None, residual=None): @@ -108,7 +109,9 @@ def forward(self, x): class FeedForward(nn.Module): - def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0, dtype="float16"): + def __init__( + self, dim, dim_out=None, mult=4, glu=False, dropout=0.0, dtype="float16" + ): super().__init__() inner_dim = int(dim * mult) dim_out = default(dim_out, dim) @@ -121,7 +124,9 @@ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0, dtype="flo ) self.net = nn.Sequential( - project_in, nn.Dropout(dropout, dtype=dtype), nn.Linear(inner_dim, dim_out, dtype=dtype) + project_in, + nn.Dropout(dropout, dtype=dtype), + nn.Linear(inner_dim, dim_out, dtype=dtype), ) def forward(self, x, residual=None): @@ -145,28 +150,28 @@ def __init__( gated_ff=True, checkpoint=True, only_cross_attention=False, - dtype="float16" + dtype="float16", ): super().__init__() - self.only_cross_attention=only_cross_attention + self.only_cross_attention = only_cross_attention self.attn1 = CrossAttention( query_dim=dim, context_dim=context_dim if only_cross_attention else None, heads=n_heads, dim_head=d_head, dropout=dropout, - dtype=dtype + dtype=dtype, ) self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff, dtype=dtype) if context_dim is not None: self.attn2 = CrossAttention( - query_dim=dim, - context_dim=context_dim, - heads=n_heads, - dim_head=d_head, - dropout=dropout, - dtype=dtype - ) + query_dim=dim, + context_dim=context_dim, + heads=n_heads, + dim_head=d_head, + dropout=dropout, + dtype=dtype, + ) else: self.attn2 = None self.norm1 = nn.LayerNorm(dim, dtype=dtype) @@ -177,7 +182,11 @@ def __init__( self.param = (dim, n_heads, d_head, context_dim, gated_ff, checkpoint) def forward(self, x, context=None): - x = self.attn1(self.norm1(x), residual=x, context=context if self.only_cross_attention else None) + x = self.attn1( + self.norm1(x), + residual=x, + context=context if self.only_cross_attention else None, + ) if self.attn2 is not None: x = self.attn2(self.norm2(x), context=context, residual=x) x = self.ff(self.norm3(x), residual=x) @@ -185,7 +194,9 @@ def forward(self, x, context=None): def Normalize(in_channels, dtype="float16"): - return nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype) + return nn.GroupNorm( + num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype + ) class SpatialTransformer(nn.Module): @@ -207,7 +218,7 @@ def __init__( context_dim=None, use_linear_projection=False, only_cross_attention=False, - dtype="float16" + dtype="float16", ): super().__init__() self.in_channels = in_channels @@ -225,7 +236,13 @@ def __init__( self.transformer_blocks = nn.ModuleList( [ BasicTransformerBlock( - inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim, only_cross_attention=only_cross_attention, dtype=dtype + inner_dim, + n_heads, + d_head, + dropout=dropout, + context_dim=context_dim, + only_cross_attention=only_cross_attention, + dtype=dtype, ) for d in range(depth) ] @@ -532,12 +549,17 @@ def forward( hidden_states = inputs_embeds for _, encoder_layer in enumerate(self.layers): - if output_hidden_states: + if output_hidden_states and encoder_states is not None: encoder_states = encoder_states + (hidden_states,) layer_outputs = 
encoder_layer(hidden_states) hidden_states = layer_outputs - return hidden_states + last_hidden_state = hidden_states + output = last_hidden_state + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + output = encoder_states + return output class CLIPTextEmbeddings(nn.Module): @@ -596,6 +618,7 @@ class CLIPTextTransformer(nn.Module): def __init__( self, hidden_size=768, + text_projection_dim=None, output_attentions=False, output_hidden_states=False, use_return_dict=False, @@ -620,10 +643,19 @@ def __init__( act_layer=act_layer, ) self.final_layer_norm = nn.LayerNorm(hidden_size) + if text_projection_dim is not None: + self.text_projection = nn.Linear( + hidden_size, text_projection_dim, bias=False + ) + else: + self.text_projection = None self.output_attentions = output_attentions self.output_hidden_states = output_hidden_states self.use_return_dict = use_return_dict + self.hidden_size = hidden_size + self.seq_len = seq_len + self.num_layers = num_hidden_layers def forward( self, @@ -637,27 +669,40 @@ def forward( r""" Returns: """ - output_attentions = ( - output_attentions - if output_attentions is not None - else self.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.use_return_dict - - if input_ids is None: - raise ValueError("You have to specify either input_ids") + batch = ops.size()(input_ids)[0] hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids) - encoder_outputs = self.encoder( - inputs_embeds=hidden_states, + encoder_output = self.encoder( + inputs_embeds=hidden_states, output_hidden_states=self.output_hidden_states ) - - last_hidden_state = encoder_outputs + if self.output_hidden_states: + last_hidden_state = encoder_output[-1] + else: + last_hidden_state = encoder_output last_hidden_state = self.final_layer_norm(last_hidden_state) - return last_hidden_state + + argmax = ops.argmax(-1)(input_ids) + pooled_output = ops.index_select(dim=1)(last_hidden_state, argmax) + pooled_output = ops.reshape()(pooled_output, [batch, self.hidden_size]) + last_hidden_state._attrs["is_output"] = True + last_hidden_state._attrs["name"] = "last_hidden_state" + pooled_output._attrs["is_output"] = True + pooled_output._attrs["name"] = "pooled_output" + output = ( + last_hidden_state, + pooled_output, + ) + if self.text_projection is not None: + text_embeds = self.text_projection(pooled_output) + text_embeds._attrs["is_output"] = True + text_embeds._attrs["name"] = "text_embeds" + output = output + (text_embeds,) + + if self.output_hidden_states: + for idx, hidden_state in enumerate(encoder_output[:-1]): + hidden_state._attrs["is_output"] = True + hidden_state._attrs["name"] = f"hidden_state_{idx}" + output = output + (hidden_state,) + + return output \ No newline at end of file diff --git a/AITemplate/ait/modeling/controlnet.py b/AITemplate/ait/modeling/controlnet.py index 2f2408a..4490087 100644 --- a/AITemplate/ait/modeling/controlnet.py +++ b/AITemplate/ait/modeling/controlnet.py @@ -254,19 +254,14 @@ def forward( ] mid_block_res_sample = mid_block_res_sample * conditioning_scale - # return (down_block_res_samples, mid_block_res_sample) - return ( - down_block_res_samples[0], - down_block_res_samples[1], - down_block_res_samples[2], - down_block_res_samples[3], - down_block_res_samples[4], - down_block_res_samples[5], - down_block_res_samples[6], - down_block_res_samples[7], - 
down_block_res_samples[8], - down_block_res_samples[9], - down_block_res_samples[10], - down_block_res_samples[11], - mid_block_res_sample, - ) + output = () + + for i in range(len(down_block_res_samples)): + down_block_res_samples[i]._attrs["is_output"] = True + down_block_res_samples[i]._attrs["name"] = f"down_block_res_sample_{i}" + output += (down_block_res_samples[i],) + mid_block_res_sample._attrs["is_output"] = True + mid_block_res_sample._attrs["name"] = "mid_block_res_sample" + output += (mid_block_res_sample,) + + return output diff --git a/AITemplate/ait/modeling/resnet.py b/AITemplate/ait/modeling/resnet.py index 7374c4a..5837dcf 100644 --- a/AITemplate/ait/modeling/resnet.py +++ b/AITemplate/ait/modeling/resnet.py @@ -13,7 +13,7 @@ # limitations under the License. # from aitemplate.compiler import ops -from aitemplate.frontend import nn +from aitemplate.frontend import nn, Tensor def get_shape(x): @@ -58,11 +58,17 @@ def __init__( else: self.Conv2d_0 = conv - def forward(self, x): + def forward(self, x, upsample_size=None): if self.use_conv_transpose: return self.conv(x) - - x = nn.Upsampling2d(scale_factor=2.0, mode="nearest")(x) + out = None + if upsample_size is not None: + out = ops.size()(x) + out[1] = upsample_size[1] + out[2] = upsample_size[2] + out = [x._attrs["int_var"] for x in out] + out = Tensor(out) + x = nn.Upsampling2d(scale_factor=2.0, mode="nearest")(x, out) # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed if self.use_conv: diff --git a/AITemplate/ait/modeling/unet_2d_condition.py b/AITemplate/ait/modeling/unet_2d_condition.py index 5c55b83..6cc54a7 100644 --- a/AITemplate/ait/modeling/unet_2d_condition.py +++ b/AITemplate/ait/modeling/unet_2d_condition.py @@ -14,9 +14,10 @@ # from typing import Optional, Tuple, Union -from aitemplate.frontend import nn, Tensor from aitemplate.compiler import ops +from aitemplate.frontend import nn, Tensor + from .embeddings import TimestepEmbedding, Timesteps from .unet_blocks import get_down_block, get_up_block, UNetMidBlock2DCrossAttn @@ -86,19 +87,14 @@ def __init__( use_linear_projection: bool = False, class_embed_type: Optional[str] = None, num_class_embeds: Optional[int] = None, - only_cross_attention=[ - True, - True, - True, - False - ], - conv_in_kernel = 3, + only_cross_attention=[True, True, True, False], + conv_in_kernel=3, dtype="float16", - time_embedding_dim = None, - projection_class_embeddings_input_dim = None, - addition_embed_type = None, - addition_time_embed_dim = None, - transformer_layers_per_block = 1, + time_embedding_dim=None, + projection_class_embeddings_input_dim=None, + addition_embed_type=None, + addition_time_embed_dim=None, + transformer_layers_per_block=1, ): super().__init__() self.center_input_sample = center_input_sample @@ -108,31 +104,52 @@ def __init__( # input self.in_channels = in_channels - if in_channels >= 1 and in_channels <= 4: - in_channels = 4 - elif in_channels > 4 and in_channels <= 8: - in_channels = 8 - elif in_channels > 8 and in_channels <= 12: - in_channels = 12 + if self.in_channels % 4 != 0: + in_channels = self.in_channels + (4 - (self.in_channels % 4)) + else: + in_channels = self.in_channels conv_in_padding = (conv_in_kernel - 1) // 2 - self.conv_in = nn.Conv2dBias(in_channels, block_out_channels[0], 3, 1, conv_in_padding, dtype=dtype) + print("in_channels", in_channels) + if in_channels < 8: + self.conv_in = nn.Conv2dBiasFewChannels( + in_channels, block_out_channels[0], 3, 1, conv_in_padding, dtype=dtype + ) + else: + self.conv_in = 
nn.Conv2dBias( + in_channels, block_out_channels[0], 3, 1, conv_in_padding, dtype=dtype + ) # time - self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift, dtype=dtype, arange_name="arange") + self.time_proj = Timesteps( + block_out_channels[0], + flip_sin_to_cos, + freq_shift, + dtype=dtype, + arange_name="arange", + ) timestep_input_dim = block_out_channels[0] - self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, dtype=dtype) + self.time_embedding = TimestepEmbedding( + timestep_input_dim, time_embed_dim, dtype=dtype + ) self.class_embed_type = class_embed_type if class_embed_type is None and num_class_embeds is not None: - self.class_embedding = nn.Embedding([num_class_embeds, time_embed_dim], dtype=dtype) + self.class_embedding = nn.Embedding( + [num_class_embeds, time_embed_dim], dtype=dtype + ) elif class_embed_type == "timestep": - self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, dtype=dtype) + self.class_embedding = TimestepEmbedding( + timestep_input_dim, time_embed_dim, dtype=dtype + ) elif class_embed_type == "identity": self.class_embedding = nn.Identity(dtype=dtype) else: self.class_embedding = None if addition_embed_type == "text_time": - self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim, dtype=dtype) + # self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift, dtype=dtype, arange_name="add_arange") + self.add_embedding = TimestepEmbedding( + projection_class_embeddings_input_dim, time_embed_dim, dtype=dtype + ) self.down_blocks = nn.ModuleList([]) self.up_blocks = nn.ModuleList([]) @@ -184,7 +201,9 @@ def __init__( # up reversed_block_out_channels = list(reversed(block_out_channels)) reversed_attention_head_dim = list(reversed(attention_head_dim)) - reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block)) + reversed_transformer_layers_per_block = list( + reversed(transformer_layers_per_block) + ) output_channel = reversed_block_out_channels[0] for i, up_block_type in enumerate(up_block_types): prev_output_channel = output_channel @@ -224,26 +243,28 @@ def __init__( dtype=dtype, ) - self.conv_out = nn.Conv2dBias(block_out_channels[0], out_channels, 3, 1, 1, dtype=dtype) + self.conv_out = nn.Conv2dBias( + block_out_channels[0], out_channels, 3, 1, 1, dtype=dtype + ) def forward( self, sample, timesteps, encoder_hidden_states, - down_block_residual_0 = None, - down_block_residual_1 = None, - down_block_residual_2 = None, - down_block_residual_3 = None, - down_block_residual_4 = None, - down_block_residual_5 = None, - down_block_residual_6 = None, - down_block_residual_7 = None, - down_block_residual_8 = None, - down_block_residual_9 = None, - down_block_residual_10 = None, - down_block_residual_11 = None, - mid_block_residual = None, + down_block_residual_0=None, + down_block_residual_1=None, + down_block_residual_2=None, + down_block_residual_3=None, + down_block_residual_4=None, + down_block_residual_5=None, + down_block_residual_6=None, + down_block_residual_7=None, + down_block_residual_8=None, + down_block_residual_9=None, + down_block_residual_10=None, + down_block_residual_11=None, + mid_block_residual=None, class_labels: Optional[Tensor] = None, add_embeds: Optional[Tensor] = None, return_dict: bool = True, @@ -284,12 +305,16 @@ def forward( emb = self.time_embedding(t_emb) if self.class_embedding is not None: if class_labels is None: - raise ValueError("class_labels should be provided when 
num_class_embeds > 0") + raise ValueError( + "class_labels should be provided when num_class_embeds > 0" + ) if self.class_embed_type == "timestep": class_labels = self.time_proj(class_labels) - class_emb = ops.batch_gather()(self.class_embedding.weight.tensor(), class_labels) + class_emb = ops.batch_gather()( + self.class_embedding.weight.tensor(), class_labels + ) emb = emb + class_emb if add_embeds is not None: @@ -297,14 +322,10 @@ def forward( emb = emb + aug_emb # 2. pre-process - if self.in_channels < 4: - sample = ops.pad_last_dim(4, 4)(sample) - elif self.in_channels > 4 and self.in_channels < 8: - sample = ops.pad_last_dim(4, 8)(sample) - elif self.in_channels > 8 and self.in_channels < 12: - sample = ops.pad_last_dim(4, 12)(sample) - else: - sample = sample + if self.in_channels % 4 != 0: + channel_pad = self.in_channels + (4 - (self.in_channels % 4)) + sample = ops.pad_last_dim(4, channel_pad)(sample) + sample = self.conv_in(sample) # 3. down @@ -331,7 +352,9 @@ def forward( for down_block_res_sample, down_block_additional_residual in zip( down_block_res_samples, down_block_additional_residuals ): - down_block_additional_residual._attrs["shape"] = down_block_res_sample._attrs["shape"] + down_block_additional_residual._attrs[ + "shape" + ] = down_block_res_sample._attrs["shape"] down_block_res_sample += down_block_additional_residual new_down_block_res_samples += (down_block_res_sample,) @@ -345,13 +368,16 @@ def forward( if mid_block_additional_residual is not None: mid_block_additional_residual._attrs["shape"] = sample._attrs["shape"] sample += mid_block_additional_residual - + upsample_size = None # 5. up - for upsample_block in self.up_blocks: + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 res_samples = down_block_res_samples[-len(upsample_block.resnets) :] down_block_res_samples = down_block_res_samples[ : -len(upsample_block.resnets) ] + if not is_final_block: + upsample_size = ops.size()(down_block_res_samples[-1]) if ( hasattr(upsample_block, "attentions") @@ -362,10 +388,14 @@ def forward( temb=emb, res_hidden_states_tuple=res_samples, encoder_hidden_states=encoder_hidden_states, + upsample_size=upsample_size, ) else: sample = upsample_block( - hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + upsample_size=upsample_size, ) # 6. 
post-process @@ -373,4 +403,6 @@ def forward( # when running in half-precision sample = self.conv_norm_out(sample) sample = self.conv_out(sample) - return sample + sample._attrs["is_output"] = True + sample._attrs["name"] = "latent_output" + return sample \ No newline at end of file diff --git a/AITemplate/ait/modeling/unet_blocks.py b/AITemplate/ait/modeling/unet_blocks.py index f961652..33f7c06 100644 --- a/AITemplate/ait/modeling/unet_blocks.py +++ b/AITemplate/ait/modeling/unet_blocks.py @@ -585,6 +585,7 @@ def forward( res_hidden_states_tuple, temb=None, encoder_hidden_states=None, + upsample_size=None, ): for resnet, attn in zip(self.resnets, self.attentions): # pop res hidden states @@ -599,7 +600,7 @@ def forward( if self.upsamplers is not None: for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states) + hidden_states = upsampler(hidden_states, upsample_size) return hidden_states @@ -654,7 +655,7 @@ def __init__( else: self.upsamplers = None - def forward(self, hidden_states, res_hidden_states_tuple, temb=None): + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): for resnet in self.resnets: # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] @@ -667,7 +668,7 @@ def forward(self, hidden_states, res_hidden_states_tuple, temb=None): if self.upsamplers is not None: for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states) + hidden_states = upsampler(hidden_states, upsample_size) return hidden_states diff --git a/AITemplate/ait/modeling/vae.py b/AITemplate/ait/modeling/vae.py index 1fc4be4..5d1e931 100644 --- a/AITemplate/ait/modeling/vae.py +++ b/AITemplate/ait/modeling/vae.py @@ -268,6 +268,8 @@ def __init__( def decode(self, z: Tensor, return_dict: bool = True): z = self.post_quant_conv(z) dec = self.decoder(z) + dec._attrs["is_output"] = True + dec._attrs["name"] = "pixels" return dec def encode(self, x: Tensor, sample: Tensor = None, return_dict: bool = True, deterministic: bool = False): @@ -284,6 +286,8 @@ def encode(self, x: Tensor, sample: Tensor = None, return_dict: bool = True, det sample._attrs["shape"] = mean._attrs["shape"] std._attrs["shape"] = mean._attrs["shape"] z = mean + std * sample + z._attrs["is_output"] = True + z._attrs["name"] = "latent" return z diff --git a/AITemplate/ait/util/mapping/clip.py b/AITemplate/ait/util/mapping/clip.py index fb0a590..d8ba599 100644 --- a/AITemplate/ait/util/mapping/clip.py +++ b/AITemplate/ait/util/mapping/clip.py @@ -5,41 +5,12 @@ "Please install transformers with `pip install transformers` to use this script." 
) +import torch from ...util import torch_dtype_from_str def map_clip(pt_mod, device="cuda", dtype="float16"): - if isinstance(pt_mod, dict): - """ - TODO: investigate whether this dependency can be removed - possibly: - * position_ids could be created in another way - * check what is missing from state dict as received here vs .named_parameters() - * create the missing tensors another way if possible - """ - if "text_model.encoder.layers.22.layer_norm1.weight" in pt_mod.keys(): - clip_text_config = CLIPTextConfig( - hidden_size=1024, - intermediate_size=4096, - num_attention_heads=16, - num_hidden_layers=23, - projection_dim=512, - hidden_act="gelu" - ) - else: - clip_text_config = CLIPTextConfig( - hidden_size=768, - intermediate_size=3072, - num_attention_heads=12, - num_hidden_layers=12, - projection_dim=768, - ) - clip_text_model = CLIPTextModel(clip_text_config) - pt_mod["text_model.embeddings.position_ids"] = clip_text_model.text_model.embeddings.get_buffer("position_ids") - clip_text_model.load_state_dict(pt_mod) - pt_params = dict(clip_text_model.named_parameters()) - else: - pt_params = dict(pt_mod.named_parameters()) + pt_params = dict(pt_mod.named_parameters()) params_ait = {} for key, arr in pt_params.items(): arr = arr.to(device, dtype=torch_dtype_from_str(dtype)) @@ -56,5 +27,4 @@ def map_clip(pt_mod, device="cuda", dtype="float16"): elif "v_proj" in name: ait_name = ait_name.replace("v_proj", "proj_v") params_ait[ait_name] = arr - return params_ait diff --git a/AITemplate/ait/util/mapping/unet.py b/AITemplate/ait/util/mapping/unet.py index 710b816..8e77ba7 100644 --- a/AITemplate/ait/util/mapping/unet.py +++ b/AITemplate/ait/util/mapping/unet.py @@ -28,17 +28,9 @@ def map_unet(pt_mod, in_channels=None, conv_in_key=None, dim=320, device="cuda", params_ait[key.replace(".", "_")] = arr if conv_in_key is not None: - if in_channels > 0 and in_channels < 4: - pad_by = 4 - in_channels - elif in_channels > 4 and in_channels < 8: - pad_by = 8 - in_channels - elif in_channels > 8 and in_channels < 12: - pad_by = 12 - in_channels - else: - pad_by = 0 - params_ait[conv_in_key] = torch.functional.F.pad( - params_ait[conv_in_key], (0, pad_by, 0, 0, 0, 0, 0, 0) - ) + if in_channels % 4 != 0: + pad_by = 4 - (in_channels % 4) + params_ait[conv_in_key] = torch.functional.F.pad(params_ait[conv_in_key], (0, pad_by)) params_ait["arange"] = ( torch.arange(start=0, end=dim // 2, dtype=torch.float32).to(device, dtype=torch_dtype_from_str(dtype)) diff --git a/AITemplate/clip.py b/AITemplate/clip.py index 4abecbe..4a2fd3b 100644 --- a/AITemplate/clip.py +++ b/AITemplate/clip.py @@ -17,32 +17,45 @@ import click import torch from aitemplate.testing import detect_target -try: - from transformers import CLIPTextModel -except ImportError: - raise ImportError( - "Please install transformers with `pip install transformers` to use this script." - ) +from transformers import CLIPTextModel, CLIPTextModelWithProjection from ait.compile.clip import compile_clip @click.command() @click.option( "--hf-hub-or-path", - default="./tmp/diffusers-pipeline/runwayml/stable-diffusion-v1-5", + default=r"runwayml/stable-diffusion-v1-5", help="the local diffusers pipeline directory or hf hub path e.g. 
runwayml/stable-diffusion-v1-5", ) @click.option( "--batch-size", - default=(1, 4), + default=(1, 2), type=(int, int), nargs=2, help="Minimum and maximum batch size", ) +@click.option( + "--output-hidden-states", + default=False, + type=bool, + help="Output hidden states", +) +@click.option( + "--text-projection", + default=False, + type=bool, + help="use text projection", +) @click.option( "--include-constants", - default=None, + default=False, + type=bool, help="include constants (model weights) with compiled model", ) +@click.option( + "--subfolder", + default="text_encoder", + help="subfolder of hf repo or path. default `text_encoder`, this is `text_encoder_2` for SDXL.", +) @click.option("--use-fp16-acc", default=True, help="use fp16 accumulation") @click.option("--convert-conv-to-gemm", default=True, help="convert 1x1 conv to gemm") @click.option("--model-name", default="CLIPTextModel", help="module name") @@ -50,7 +63,10 @@ def compile_diffusers( hf_hub_or_path, batch_size, + output_hidden_states, + text_projection, include_constants, + subfolder="text_encoder", use_fp16_acc=True, convert_conv_to_gemm=True, model_name="CLIPTextModel", @@ -62,11 +78,22 @@ def compile_diffusers( if detect_target().name() == "rocm": convert_conv_to_gemm = False - pipe = CLIPTextModel.from_pretrained( - hf_hub_or_path, - subfolder="text_encoder" if not hf_hub_or_path.endswith("text_encoder") else None, - torch_dtype=torch.float16 - ).to("cuda") + if text_projection: + pipe = CLIPTextModelWithProjection.from_pretrained( + hf_hub_or_path, + subfolder=subfolder, + variant="fp16", + torch_dtype=torch.float16, + use_safetensors=True, + ).to("cuda") + else: + pipe = CLIPTextModel.from_pretrained( + hf_hub_or_path, + subfolder=subfolder, + variant="fp16", + torch_dtype=torch.float16, + use_safetensors=True, + ).to("cuda") compile_clip( pipe, @@ -74,11 +101,13 @@ def compile_diffusers( seqlen=pipe.config.max_position_embeddings, use_fp16_acc=use_fp16_acc, convert_conv_to_gemm=convert_conv_to_gemm, + output_hidden_states=output_hidden_states, + text_projection_dim=pipe.config.projection_dim if text_projection else None, depth=pipe.config.num_hidden_layers, num_heads=pipe.config.num_attention_heads, dim=pipe.config.hidden_size, act_layer=pipe.config.hidden_act, - constants=True if include_constants else False, + constants=include_constants, model_name=model_name, work_dir=work_dir, ) diff --git a/AITemplate/unet.py b/AITemplate/unet.py index 632b5ad..507a2be 100644 --- a/AITemplate/unet.py +++ b/AITemplate/unet.py @@ -28,31 +28,31 @@ @click.command() @click.option( "--hf-hub-or-path", - default="./tmp/diffusers-pipeline/runwayml/stable-diffusion-v1-5", + default="runwayml/stable-diffusion-v1-5", help="the local diffusers pipeline directory or hf hub path e.g. 
runwayml/stable-diffusion-v1-5", ) @click.option( "--width", - default=(64, 2048), + default=(64, 1024), type=(int, int), nargs=2, help="Minimum and maximum width", ) @click.option( "--height", - default=(64, 2048), + default=(64, 1024), type=(int, int), nargs=2, help="Minimum and maximum height", ) @click.option( "--batch-size", - default=(1, 4), + default=(1, 1), type=(int, int), nargs=2, help="Minimum and maximum batch size", ) -@click.option("--clip-chunks", default=6, help="Maximum number of clip chunks") +@click.option("--clip-chunks", default=10, help="Maximum number of clip chunks") @click.option( "--include-constants", default=None, @@ -91,16 +91,11 @@ def compile_diffusers( if detect_target().name() == "rocm": convert_conv_to_gemm = False - assert ( - width[0] % 64 == 0 and width[1] % 64 == 0 - ), "Minimum Width and Maximum Width must be multiples of 64, otherwise, the compilation process will fail." - assert ( - height[0] % 64 == 0 and height[1] % 64 == 0 - ), "Minimum Height and Maximum Height must be multiples of 64, otherwise, the compilation process will fail." - pipe = UNet2DConditionModel.from_pretrained( hf_hub_or_path, subfolder="unet" if not hf_hub_or_path.endswith("unet") else None, + variant="fp16", + use_safetensors=True, torch_dtype=torch.float16, ).to("cuda") diff --git a/AITemplate/vae.py b/AITemplate/vae.py index 81ec2f7..033ad88 100644 --- a/AITemplate/vae.py +++ b/AITemplate/vae.py @@ -28,26 +28,26 @@ @click.command() @click.option( "--hf-hub-or-path", - default="./tmp/diffusers-pipeline/runwayml/stable-diffusion-v1-5", + default="runwayml/stable-diffusion-v1-5", help="the local diffusers pipeline directory or hf hub path e.g. runwayml/stable-diffusion-v1-5", ) @click.option( "--width", - default=(64, 2048), + default=(64, 1024), type=(int, int), nargs=2, help="Minimum and maximum width", ) @click.option( "--height", - default=(64, 2048), + default=(64, 1024), type=(int, int), nargs=2, help="Minimum and maximum height", ) @click.option( "--batch-size", - default=(1, 4), + default=(1, 1), type=(int, int), nargs=2, help="Minimum and maximum batch size", @@ -101,13 +101,6 @@ def compile_diffusers( if detect_target().name() == "rocm": convert_conv_to_gemm = False - assert ( - width[0] % 64 == 0 and width[1] % 64 == 0 - ), "Minimum Width and Maximum Width must be multiples of 64, otherwise, the compilation process will fail." - assert ( - height[0] % 64 == 0 and height[1] % 64 == 0 - ), "Minimum Height and Maximum Height must be multiples of 64, otherwise, the compilation process will fail." - pipe = AutoencoderKL.from_pretrained( hf_hub_or_path, subfolder="vae" if not hf_hub_or_path.endswith("vae") else None,