Skip to content
This repository has been archived by the owner on Oct 12, 2023. It is now read-only.

Commit

Permalink
cross_attention fix
Browse files Browse the repository at this point in the history
  • Loading branch information
FizzleDorf committed Sep 1, 2023
1 parent 10f7510 commit ab84fac
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion AITemplate/AITemplate.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ def get_control(self, x_noisy, t, cond, batched_number):

with precision_scope(comfy.model_management.get_autocast_device(self.device)):
comfy.model_management.load_models_gpu([self.control_model_wrapped])
context = torch.cat(cond['c_crossattn'], 1)
context = c_crossattn
y = cond.get('c_adm', None)
control = self.control_model(x=x_noisy, hint=self.cond_hint, timesteps=t, context=context, y=y)
comfy.model_management.unload_model_clones(self.control_model_wrapped)
Expand Down
4 changes: 2 additions & 2 deletions AITemplate/ait/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def apply_model(
mid_block_residual = None
add_embeds = None
if c_crossattn is not None:
encoder_hidden_states = torch.cat(c_crossattn, dim=1)
encoder_hidden_states = c_crossattn
if c_concat is not None:
latent_model_input = torch.cat([x] + c_concat, dim=1)
if control is not None:
Expand Down Expand Up @@ -122,7 +122,7 @@ def controlnet_inference(
if controlnet_cond.shape[0] != latent_model_input.shape[0]:
controlnet_cond = controlnet_cond.expand(latent_model_input.shape[0], -1, -1, -1)
if type(encoder_hidden_states) == dict:
encoder_hidden_states = torch.cat(encoder_hidden_states['c_crossattn'], 1)
encoder_hidden_states = encoder_hidden_states['c_crossattn']
inputs = {
"latent_model_input": latent_model_input.permute((0, 2, 3, 1))
.contiguous()
Expand Down

0 comments on commit ab84fac

Please sign in to comment.