Skip to content
This repository has been archived by the owner on Oct 12, 2023. It is now read-only.

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
hlky committed Aug 10, 2023
1 parent 75ac16e commit fb84060
Show file tree
Hide file tree
Showing 18 changed files with 325 additions and 246 deletions.
4 changes: 3 additions & 1 deletion AITemplate/ait/ait.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def test_unet(
timesteps=timesteps_pt,
encoder_hidden_states=text_embeddings_pt,
benchmark=benchmark,
add_embeds=add_embeds
add_embeds=add_embeds if xl else None,
)
print(output.shape)
return output
Expand Down Expand Up @@ -158,6 +158,7 @@ def test_vae(
width: int = 64,
dtype="float16",
device="cuda",
benchmark: bool = False,
):
if "vae" not in self.modules:
raise ValueError("vae module not loaded")
Expand All @@ -167,6 +168,7 @@ def test_vae(
output = vae_inference(
self.modules["vae"],
vae_input=vae_input,
benchmark=benchmark,
)
print(output.shape)
return output
Expand Down
18 changes: 12 additions & 6 deletions AITemplate/ait/compile/clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#

import sys
from aitemplate.compiler import compile_model
from aitemplate.frontend import IntVar, Tensor
from aitemplate.testing import detect_target
Expand All @@ -30,8 +30,10 @@ def compile_clip(
dim=768,
num_heads=12,
depth=12,
use_fp16_acc=False,
convert_conv_to_gemm=False,
output_hidden_states=False,
text_projection_dim=None,
use_fp16_acc=True,
convert_conv_to_gemm=True,
act_layer="gelu",
constants=True,
model_name="CLIPTextModel",
Expand All @@ -49,6 +51,8 @@ def compile_clip(
causal=causal,
mask_seq=mask_seq,
act_layer=act_layer,
output_hidden_states=output_hidden_states,
text_projection_dim=text_projection_dim,
)
ait_mod.name_parameter_tensor()

Expand All @@ -62,17 +66,19 @@ def compile_clip(
batch_size = IntVar(values=list(batch_size), name="batch_size")

input_ids_ait = Tensor(
[batch_size, seqlen], name="input0", dtype="int64", is_input=True
[batch_size, seqlen], name="input_ids", dtype="int64", is_input=True
)
position_ids_ait = Tensor(
[batch_size, seqlen], name="input1", dtype="int64", is_input=True
[batch_size, seqlen], name="position_ids", dtype="int64", is_input=True
)

Y = ait_mod(input_ids=input_ids_ait, position_ids=position_ids_ait)
mark_output(Y)

target = detect_target(
use_fp16_acc=use_fp16_acc, convert_conv_to_gemm=convert_conv_to_gemm
)
dll_name = model_name + ".dll" if sys.platform == "win32" else model_name + ".so"
compile_model(
Y, target, work_dir, model_name, constants=params_ait if constants else None
Y, target, work_dir, model_name, constants=params_ait if constants else None, dll_name=dll_name
)
12 changes: 7 additions & 5 deletions AITemplate/ait/compile/controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import sys
import torch
from aitemplate.compiler import compile_model
from aitemplate.frontend import IntVar, Tensor
Expand Down Expand Up @@ -72,14 +73,14 @@ def compile_controlnet(
embedding_size = IntVar(values=list(clip_chunks), name="embedding_size")

latent_model_input_ait = Tensor(
[batch_size, height_d, width_d, 4], name="input0", is_input=True
[batch_size, height_d, width_d, 4], name="latent_model_input", is_input=True
)
timesteps_ait = Tensor([batch_size], name="input1", is_input=True)
timesteps_ait = Tensor([batch_size], name="timesteps", is_input=True)
text_embeddings_pt_ait = Tensor(
[batch_size, embedding_size, hidden_dim], name="input2", is_input=True
[batch_size, embedding_size, hidden_dim], name="encoder_hidden_states", is_input=True
)
controlnet_condition_ait = Tensor(
[batch_size, height_c, width_c, 3], name="input3", is_input=True
[batch_size, height_c, width_c, 3], name="control_hint", is_input=True
)

Y = ait_mod(
Expand All @@ -93,6 +94,7 @@ def compile_controlnet(
target = detect_target(
use_fp16_acc=use_fp16_acc, convert_conv_to_gemm=convert_conv_to_gemm
)
dll_name = model_name + ".dll" if sys.platform == "win32" else model_name + ".so"
compile_model(
Y, target, work_dir, model_name, constants=params_ait if constants else None
Y, target, work_dir, model_name, constants=params_ait if constants else None, dll_name=dll_name,
)
38 changes: 20 additions & 18 deletions AITemplate/ait/compile/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import sys
import torch
from aitemplate.compiler import compile_model
from aitemplate.frontend import IntVar, Tensor
from aitemplate.frontend import IntVar, Tensor, DynamicProfileStrategy
from aitemplate.testing import detect_target

from ..modeling.unet_2d_condition import (
Expand Down Expand Up @@ -109,10 +110,9 @@ def compile_unet(
pt_mod = pt_mod.eval()
params_ait = map_unet(pt_mod, dim=dim, in_channels=in_channels, conv_in_key="conv_in_weight", dtype=dtype)

static_shape = width[0] == width[1] and height[0] == height[1] and batch_size[0] == batch_size[1]
static_shape = width[0] == width[1] and height[0] == height[1]

if static_shape:
batch_size = batch_size[0] * 2 # double batch size for unet
height = height[0] // down_factor
width = width[0] // down_factor
height_d = height
Expand All @@ -132,27 +132,28 @@ def compile_unet(
height_8_d = height_8
width_8_d = width_8
else:
batch_size = batch_size[0], batch_size[1] * 2 # double batch size for unet
batch_size = IntVar(values=list(batch_size), name="batch_size")
height = height[0] // down_factor, height[1] // down_factor
width = width[0] // down_factor, width[1] // down_factor
height = [x // down_factor for x in height]
width = [x // down_factor for x in width]
height_d = IntVar(values=list(height), name="height_d")
width_d = IntVar(values=list(width), name="width_d")
height_1_d = IntVar(values=list(height), name="height_1_d")
width_1_d = IntVar(values=list(width), name="width_1_d")
height_2 = height[0] // 2, height[1] // 2
width_2 = width[0] // 2, width[1] // 2
height_4 = height[0] // 4, height[1] // 4
width_4 = width[0] // 4, width[1] // 4
height_8 = height[0] // 8, height[1] // 8
width_8 = width[0] // 8, width[1] // 8
height_2 = [x // 2 for x in height]
width_2 = [x // 2 for x in width]
height_4 = [x // 4 for x in height]
width_4 = [x // 4 for x in width]
height_8 = [x // 8 for x in height]
width_8 = [x // 8 for x in width]
height_2_d = IntVar(values=list(height_2), name="height_2_d")
width_2_d = IntVar(values=list(width_2), name="width_2_d")
height_4_d = IntVar(values=list(height_4), name="height_4_d")
width_4_d = IntVar(values=list(width_4), name="width_4_d")
height_8_d = IntVar(values=list(height_8), name="height_8_d")
width_8_d = IntVar(values=list(width_8), name="width_8_d")

batch_size = batch_size[0], batch_size[1] * 2 # double batch size for unet
batch_size = IntVar(values=list(batch_size), name="batch_size")

if static_shape:
embedding_size = 77
else:
Expand All @@ -161,18 +162,18 @@ def compile_unet(


latent_model_input_ait = Tensor(
[batch_size, height_d, width_d, in_channels], name="input0", is_input=True, dtype=dtype
[batch_size, height_d, width_d, in_channels], name="latent_model_input", is_input=True, dtype=dtype
)
timesteps_ait = Tensor([batch_size], name="input1", is_input=True, dtype=dtype)
timesteps_ait = Tensor([batch_size], name="timesteps", is_input=True, dtype=dtype)
text_embeddings_pt_ait = Tensor(
[batch_size, embedding_size, hidden_dim], name="input2", is_input=True, dtype=dtype
[batch_size, embedding_size, hidden_dim], name="encoder_hidden_states", is_input=True, dtype=dtype
)

class_labels = None
#TODO: better way to handle this, enables class_labels for x4-upscaler
if in_channels == 7:
class_labels = Tensor(
[batch_size], name="input3", dtype="int64", is_input=True
[batch_size], name="class_labels", dtype="int64", is_input=True
)

add_embeds = None
Expand Down Expand Up @@ -287,6 +288,7 @@ def compile_unet(
target = detect_target(
use_fp16_acc=use_fp16_acc, convert_conv_to_gemm=convert_conv_to_gemm
)
dll_name = model_name + ".dll" if sys.platform == "win32" else model_name + ".so"
compile_model(
Y, target, work_dir, model_name, constants=params_ait if constants else None, do_optimize_graph=False if xl else True
Y, target, work_dir, model_name, constants=params_ait if constants else None, dll_name=dll_name,
)
22 changes: 14 additions & 8 deletions AITemplate/ait/compile/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
def mark_output(y):
if type(y) is not tuple:
y = (y,)
for i in range(len(y)):
y[i]._attrs["is_output"] = True
y[i]._attrs["name"] = "output_%d" % (i)
y_shape = [d._attrs["values"] for d in y[i]._attrs["shape"]]
print("AIT output_{} shape: {}".format(i, y_shape))
def mark_output(ys):
if type(ys) != tuple:
ys = (ys, )
for i in range(len(ys)):
y = ys[i]
if type(y) == tuple:
for yy in y:
y_shape = [d._attrs["values"] for d in yy._attrs["shape"]]
y_name = yy._attrs["name"]
print("AIT {} shape: {}".format(y_name, y_shape))
else:
y_shape = [d._attrs["values"] for d in y._attrs["shape"]]
y_name = y._attrs["name"]
print("AIT {} shape: {}".format(y_name, y_shape))
16 changes: 7 additions & 9 deletions AITemplate/ait/compile/vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
#

import sys
import torch
from aitemplate.compiler import compile_model
from aitemplate.frontend import IntVar, Tensor
Expand All @@ -28,8 +29,8 @@ def compile_vae(
batch_size=(1, 8),
height=(64, 2048),
width=(64, 2048),
use_fp16_acc=False,
convert_conv_to_gemm=False,
use_fp16_acc=True,
convert_conv_to_gemm=True,
model_name="AutoencoderKL",
constants=True,
block_out_channels=[128, 256, 512, 512],
Expand Down Expand Up @@ -92,15 +93,15 @@ def compile_vae(

ait_input = Tensor(
shape=[batch_size, height_d, width_d, 3 if vae_encode else latent_channels],
name="vae_input",
name="pixels" if vae_encode else "latent",
is_input=True,
dtype=dtype
)
sample = None
if vae_encode:
sample = Tensor(
shape=[batch_size, height_d, width_d, latent_channels],
name="vae_sample",
name="random_sample",
is_input=True,
dtype=dtype,
)
Expand All @@ -116,10 +117,7 @@ def compile_vae(
target = detect_target(
use_fp16_acc=use_fp16_acc, convert_conv_to_gemm=convert_conv_to_gemm
)
dll_name = model_name + ".dll" if sys.platform == "win32" else model_name + ".so"
compile_model(
Y,
target,
work_dir,
model_name,
constants=params_ait if constants else None,
Y, target, work_dir, model_name, constants=params_ait if constants else None, dll_name=dll_name,
)
5 changes: 3 additions & 2 deletions AITemplate/ait/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,9 @@ def diffusers_unet(
):
return UNet2DConditionModel.from_pretrained(
hf_hub_or_path,
subfolder=subfolder,
revision=revision,
subfolder="unet" if not hf_hub_or_path.endswith("unet") else None,
variant="fp16",
use_safetensors=True,
torch_dtype=torch_dtype_from_str(dtype)
)

Expand Down

0 comments on commit fb84060

Please sign in to comment.