-
Notifications
You must be signed in to change notification settings - Fork 136
/
sdxl_base_1024_compile.py
378 lines (290 loc) · 13 KB
/
sdxl_base_1024_compile.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_neuronx
import math
import copy
import diffusers
from diffusers import DiffusionPipeline
from diffusers.models.unet_2d_condition import UNet2DConditionOutput
from diffusers.models.attention_processor import Attention
from transformers.models.clip.modeling_clip import CLIPTextModelOutput
from packaging import version
def apply_neuron_attn_override(
    diffusers_pkg, get_attn_scores_func, neuron_scaled_dot_product_attention
):
    """Monkey-patch diffusers attention (and torch SDPA, when present) for Neuron.

    Args:
        diffusers_pkg: the imported ``diffusers`` module whose attention class
            receives the replacement ``get_attention_scores`` method.
        get_attn_scores_func: function installed as ``get_attention_scores`` on
            ``Attention`` (diffusers >= 0.18.0) or ``CrossAttention`` (older).
        neuron_scaled_dot_product_attention: function installed as
            ``torch.nn.functional.scaled_dot_product_attention`` when that
            attribute exists.
    """
    # diffusers 0.18.0 moved the attention class from `cross_attention`
    # (CrossAttention) to `attention_processor` (Attention).
    if version.parse(diffusers_pkg.__version__) >= version.parse("0.18.0"):
        attention_cls = diffusers_pkg.models.attention_processor.Attention
    else:
        attention_cls = diffusers_pkg.models.cross_attention.CrossAttention
    attention_cls.get_attention_scores = get_attn_scores_func

    # With PyTorch 2, attention goes through F.scaled_dot_product_attention,
    # so that entry point must be swapped for the Neuron-optimized version too.
    if hasattr(F, "scaled_dot_product_attention"):
        F.scaled_dot_product_attention = neuron_scaled_dot_product_attention
# Floating-point precision (fp32) used when loading models and building example inputs
DTYPE: torch.dtype = torch.float32
# Optimized attention
def get_attention_scores_neuron(self, query, key, attn_mask):
if query.size() == key.size():
attention_scores = custom_badbmm(
key,
query.transpose(-1, -2),
self.scale
)
attention_probs = attention_scores.softmax(dim=1).permute(0,2,1)
else:
attention_scores = custom_badbmm(
query,
key.transpose(-1, -2),
self.scale
)
attention_probs = attention_scores.softmax(dim=-1)
return attention_probs
def custom_badbmm(a, b, scale):
    """Return the batched matrix product ``a @ b`` multiplied by ``scale``.

    Despite the name, no additive input term is involved (unlike
    ``torch.baddbmm``); this is simply a scaled ``torch.bmm``.
    """
    return torch.bmm(a, b) * scale
def _neuron_sdpa_masked(query, key, value, attn_mask, is_causal, scale):
    """Plain matmul attention that honours ``attn_mask`` / ``is_causal``.

    Used only when a caller actually requests masking, so the unmasked
    Neuron fast path in ``neuron_scaled_dot_product_attention`` is unchanged.
    ``torch.matmul`` broadcasts over leading dims, so no 3-D reshape is needed.
    """
    scores = torch.matmul(query, key.transpose(-1, -2)) * scale
    if is_causal:
        # Top-left aligned lower-triangular mask, as in PyTorch's reference SDPA.
        keep = torch.ones(
            query.size(-2), key.size(-2), dtype=torch.bool, device=scores.device
        ).tril()
        scores = scores.masked_fill(~keep, float("-inf"))
    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            # Boolean masks mark the positions that may be attended to.
            scores = scores.masked_fill(~attn_mask, float("-inf"))
        else:
            # Float masks are added to the attention scores.
            scores = scores + attn_mask
    return torch.matmul(scores.softmax(dim=-1), value)


def neuron_scaled_dot_product_attention(
    query, key, value, attn_mask=None, dropout_p=None, is_causal=None, scale=None
):
    """Drop-in replacement for ``torch.nn.functional.scaled_dot_product_attention``.

    Unmasked calls are computed with 3-D ``torch.bmm``. For self-attention
    (query and key of equal size) the transposed score layout is used.
    Because this function is installed globally, other callers may pass
    ``attn_mask`` / ``is_causal`` / ``scale``. Those are honoured through a
    plain matmul fallback instead of being silently ignored (or, for
    ``scale``, raising TypeError).

    Args:
        query, key, value: attention inputs; the unmasked path requires 3-D
            or 4-D tensors.
        attn_mask: optional boolean mask (True = may attend) or additive
            float mask, broadcastable to the score shape.
        dropout_p: ignored -- no dropout is applied.
        is_causal: if truthy, apply a lower-triangular causal mask.
        scale: score scaling factor; defaults to ``1 / sqrt(query.size(-1))``.

    Returns:
        The attention output; 4-D inputs yield a 4-D output.
    """
    if scale is None:
        scale = 1 / math.sqrt(query.size(-1))

    if attn_mask is not None or is_causal:
        return _neuron_sdpa_masked(query, key, value, attn_mask, is_causal, scale)

    orig_shape = None
    if len(query.shape) == 4:
        orig_shape = query.shape

        def to3d(x):
            # Fold the two leading dims (presumably batch, heads) into one so
            # torch.bmm can be used.
            return x.reshape(-1, x.shape[2], x.shape[3])

        query, key, value = map(to3d, [query, key, value])

    if query.size() == key.size():
        # Self-attention: scores built transposed, softmax over dim=1 (keys),
        # then permuted back -- same result as the branch below.
        attention_scores = torch.bmm(key, query.transpose(-1, -2)) * scale
        attention_probs = attention_scores.softmax(dim=1).permute(0, 2, 1)
    else:
        attention_scores = torch.bmm(query, key.transpose(-1, -2)) * scale
        attention_probs = attention_scores.softmax(dim=-1)

    attn_out = torch.bmm(attention_probs, value)
    if orig_shape is not None:
        attn_out = attn_out.reshape(
            orig_shape[0], orig_shape[1], attn_out.shape[1], attn_out.shape[2]
        )
    return attn_out
# Install the Neuron-optimized attention (diffusers attention scores and, when
# available, torch SDPA) before any model below is loaded or traced.
apply_neuron_attn_override(
    diffusers_pkg=diffusers,
    get_attn_scores_func=get_attention_scores_neuron,
    neuron_scaled_dot_product_attention=neuron_scaled_dot_product_attention,
)
class UNetWrap(nn.Module):
    """Give the diffusers UNet a flat, positional tensor signature.

    The extra conditioning the UNet receives as an ``added_cond_kwargs`` dict
    is taken here as two separate arguments and re-packed before the call.
    This lets the model be traced with a flat tuple of tensors.
    """

    def __init__(self, unet):
        super().__init__()
        self.unet = unet

    def forward(
        self, sample, timestep, encoder_hidden_states, text_embeds=None, time_ids=None
    ):
        conditioning = dict(text_embeds=text_embeds, time_ids=time_ids)
        # return_dict=False asks the UNet for a tuple rather than an output object.
        return self.unet(
            sample,
            timestep,
            encoder_hidden_states,
            added_cond_kwargs=conditioning,
            return_dict=False,
        )
class NeuronUNet(nn.Module):
    """Adapter letting a ``UNetWrap`` stand in for the UNet of a diffusers pipeline.

    Mirrors ``config``, ``in_channels``, ``add_embedding`` and ``device`` from
    the wrapped UNet (presumably what the pipeline inspects on its UNet). It
    also translates the keyword-style UNet call into ``UNetWrap``'s flat
    positional one.
    """

    def __init__(self, unetwrap):
        super().__init__()
        self.unetwrap = unetwrap
        inner_unet = unetwrap.unet
        self.config = inner_unet.config
        self.in_channels = inner_unet.in_channels
        self.add_embedding = inner_unet.add_embedding
        self.device = inner_unet.device

    def forward(
        self,
        sample,
        timestep,
        encoder_hidden_states,
        added_cond_kwargs=None,
        return_dict=False,
        cross_attention_kwargs=None,
    ):
        # return_dict and cross_attention_kwargs are accepted for call
        # compatibility only; a UNet2DConditionOutput is always returned.
        # The timestep is cast to float and broadcast to one value per sample.
        batched_timestep = timestep.float().expand((sample.shape[0],))
        wrapped_outputs = self.unetwrap(
            sample,
            batched_timestep,
            encoder_hidden_states,
            added_cond_kwargs["text_embeds"],
            added_cond_kwargs["time_ids"],
        )
        return UNet2DConditionOutput(sample=wrapped_outputs[0])
class TextEncoderOutputWrapper(nn.Module):
    """Present a traced text encoder to the pipeline like the original model.

    Copies ``config``, ``dtype`` and ``device`` from the original encoder and
    repackages the first three elements of the traced model's output as a
    ``CLIPTextModelOutput``.
    """

    def __init__(self, traceable_text_encoder, original_text_encoder):
        super().__init__()
        self.traceable_text_encoder = traceable_text_encoder
        self.config = original_text_encoder.config
        self.dtype = original_text_encoder.dtype
        self.device = original_text_encoder.device

    def forward(self, text_input_ids, output_hidden_states=True):
        # output_hidden_states is accepted for signature compatibility and ignored.
        # NOTE(review): element 0 is always labelled text_embeds, whichever
        # encoder produced it -- confirm this matches what the pipeline indexes.
        encoded = self.traceable_text_encoder(text_input_ids)
        return CLIPTextModelOutput(
            text_embeds=encoded[0],
            last_hidden_state=encoded[1],
            hidden_states=encoded[2],
        )
class TraceableTextEncoder(nn.Module):
    """Wrap a text encoder so a call with only token ids yields a tuple with hidden states."""

    def __init__(self, text_encoder):
        super().__init__()
        self.text_encoder = text_encoder

    def forward(self, text_input_ids):
        # Request all hidden states and a plain tuple (return_dict=False).
        return self.text_encoder(
            text_input_ids, output_hidden_states=True, return_dict=False
        )
# Root directory for compiler artifacts; each compiled model gets a subdirectory
COMPILER_WORKDIR_ROOT = "sdxl_base_compile_dir_1024"
# Hugging Face model id of the SDXL base 1.0 pipeline to compile
model_id = "stabilityai/stable-diffusion-xl-base-1.0"
# --- Compile Text Encoders and save ---
# Load the full pipeline only long enough to pull out both text encoders. The
# deep copies keep the wrapped encoders alive after `del pipe` frees the rest.
pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=DTYPE)
# Wrap the encoders so the traced graphs return a tuple that includes hidden states
traceable_text_encoder = copy.deepcopy(TraceableTextEncoder(pipe.text_encoder))
traceable_text_encoder_2 = copy.deepcopy(TraceableTextEncoder(pipe.text_encoder_2))
del pipe

# Example token ids (batch 1, sequence length 77) used only as tracing inputs.
# Both inputs share the same 4 leading ids and differ only in the padding id.
# NOTE(review): 49406 / 49407 look like CLIP's start- / end-of-text ids, and the
# padding presumably mirrors each tokenizer's pad token (49407 vs. 0) -- confirm.
_text_seq_len = 77
_prompt_ids = [49406, 736, 1615, 49407]
text_input_ids_1 = torch.tensor(
    [_prompt_ids + [49407] * (_text_seq_len - len(_prompt_ids))]
)
text_input_ids_2 = torch.tensor(
    [_prompt_ids + [0] * (_text_seq_len - len(_prompt_ids))]
)

# Text Encoder 1
neuron_text_encoder = torch_neuronx.trace(
    traceable_text_encoder,
    text_input_ids_1,
    compiler_workdir=os.path.join(COMPILER_WORKDIR_ROOT, 'text_encoder'),
)
text_encoder_filename = os.path.join(COMPILER_WORKDIR_ROOT, 'text_encoder/model.pt')
torch.jit.save(neuron_text_encoder, text_encoder_filename)
# Free encoder 1 before compiling encoder 2 to limit peak memory
del traceable_text_encoder
del neuron_text_encoder

# Text Encoder 2
neuron_text_encoder_2 = torch_neuronx.trace(
    traceable_text_encoder_2,
    text_input_ids_2,
    compiler_workdir=os.path.join(COMPILER_WORKDIR_ROOT, 'text_encoder_2'),
)
text_encoder_2_filename = os.path.join(COMPILER_WORKDIR_ROOT, 'text_encoder_2/model.pt')
torch.jit.save(neuron_text_encoder_2, text_encoder_2_filename)
# delete unused objects
del traceable_text_encoder_2
del neuron_text_encoder_2
# --- Compile UNet and save ---
pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=DTYPE)
# The Neuron attention override was already installed by the
# apply_neuron_attn_override(...) call near the top of the file, so the
# Attention class does not need to be patched again here.
# Wrap the UNet so it takes flat positional tensors and returns a tuple, and
# deep-copy it so only the model being compiled stays in RAM after `del pipe`
# (minimizes memory pressure).
unet = copy.deepcopy(UNetWrap(pipe.unet))
del pipe

# Example inputs for tracing (batch size 1), in DTYPE precision.
# NOTE(review): 128x128 latents presumably correspond to 1024x1024 images -- confirm.
sample_1b = torch.randn([1, 4, 128, 128], dtype=DTYPE)
timestep_1b = torch.tensor(999, dtype=DTYPE).expand((1,))
encoder_hidden_states_1b = torch.randn([1, 77, 2048], dtype=DTYPE)
added_cond_kwargs_1b = {
    "text_embeds": torch.randn([1, 1280], dtype=DTYPE),
    "time_ids": torch.randn([1, 6], dtype=DTYPE),
}
example_inputs = (
    sample_1b,
    timestep_1b,
    encoder_hidden_states_1b,
    added_cond_kwargs_1b["text_embeds"],
    added_cond_kwargs_1b["time_ids"],
)

# Compile unet
unet_neuron = torch_neuronx.trace(
    unet,
    example_inputs,
    compiler_workdir=os.path.join(COMPILER_WORKDIR_ROOT, 'unet'),
    compiler_args=["--model-type=unet-inference"],
)
# Enable asynchronous and lazy loading to speed up model load
torch_neuronx.async_load(unet_neuron)
torch_neuronx.lazy_load(unet_neuron)
# save compiled unet
unet_filename = os.path.join(COMPILER_WORKDIR_ROOT, 'unet/model.pt')
torch.jit.save(unet_neuron, unet_filename)
# delete unused objects
del unet
del unet_neuron
# --- Compile VAE decoder and post_quant_conv and save ---
# Load the pipeline once and keep only the two VAE pieces being compiled, so
# the rest of the pipeline can be freed (minimizes memory pressure).
pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=DTYPE)
decoder = copy.deepcopy(pipe.vae.decoder)
post_quant_conv = copy.deepcopy(pipe.vae.post_quant_conv)
del pipe

# Compile vae decoder
decoder_in = torch.randn([1, 4, 128, 128], dtype=DTYPE)
decoder_neuron = torch_neuronx.trace(
    decoder,
    decoder_in,
    compiler_workdir=os.path.join(COMPILER_WORKDIR_ROOT, 'vae_decoder'),
)
# Enable asynchronous loading to speed up model load
torch_neuronx.async_load(decoder_neuron)
# Save the compiled vae decoder
decoder_filename = os.path.join(COMPILER_WORKDIR_ROOT, 'vae_decoder/model.pt')
torch.jit.save(decoder_neuron, decoder_filename)
# delete unused objects
del decoder
del decoder_neuron

# Compile vae post_quant_conv
post_quant_conv_in = torch.randn([1, 4, 128, 128], dtype=DTYPE)
post_quant_conv_neuron = torch_neuronx.trace(
    post_quant_conv,
    post_quant_conv_in,
    compiler_workdir=os.path.join(COMPILER_WORKDIR_ROOT, 'vae_post_quant_conv'),
)
# Enable asynchronous loading to speed up model load
torch_neuronx.async_load(post_quant_conv_neuron)
# Save the compiled vae post_quant_conv
post_quant_conv_filename = os.path.join(COMPILER_WORKDIR_ROOT, 'vae_post_quant_conv/model.pt')
torch.jit.save(post_quant_conv_neuron, post_quant_conv_filename)
# delete unused objects
del post_quant_conv
del post_quant_conv_neuron