
[SD] Fix SD CUDA Graph Failure (#9319)
* [SD] Avoid a redundant host-device sync that breaks CUDA graph capture.

* Apply isort and black reformatting

Signed-off-by: alpha0422 <alpha0422@users.noreply.github.com>

Co-authored-by: Michal Futrega <mfutrega@nvidia.com>
Co-authored-by: Pablo Garay <palenq@gmail.com>
3 people committed Jun 10, 2024
1 parent b7408dc commit 445b9b1
Showing 1 changed file with 4 additions and 3 deletions.
@@ -1342,9 +1342,10 @@ def _forward(self, x, timesteps=None, context=None, y=None, **kwargs):
         if context is not None:
             context = context.type(torch.float16)
 
-        t_emb = timestep_embedding(
-            timesteps, self.model_channels, cached_embedding=self.time_embeddings.to(timesteps.device)
-        )
+        if self.time_embeddings.device != timesteps.device:
+            self.time_embeddings = self.time_embeddings.to(timesteps.device)
+
+        t_emb = timestep_embedding(timesteps, self.model_channels, cached_embedding=self.time_embeddings)
         emb = self.time_embed(t_emb)
         if self.num_classes is not None:
             assert y.shape[0] == x.shape[0]
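The fix replaces an unconditional `.to(timesteps.device)` on every forward pass with a guard that moves the cached embedding table only when its device actually differs, so no per-call transfer request can interrupt CUDA graph capture. The pattern can be sketched without torch using a stub tensor that counts transfer requests; `FakeTensor` and `Model` below are hypothetical stand-ins for illustration, not NeMo code:

```python
# Minimal stdlib-only sketch of the caching pattern in the diff above.
# FakeTensor stands in for torch.Tensor; .to() just records how many
# device transfers were requested (hypothetical stub, not real torch).

class FakeTensor:
    def __init__(self, device):
        self.device = device
        self.to_calls = 0  # number of .to() requests seen so far

    def to(self, device):
        self.to_calls += 1
        if device == self.device:
            return self  # no move needed, but the request was still issued
        moved = FakeTensor(device)
        moved.to_calls = self.to_calls
        return moved


class Model:
    def __init__(self):
        self.time_embeddings = FakeTensor("cpu")

    def forward_unconditional(self, device):
        # Old behavior: request a transfer on every forward pass.
        self.time_embeddings = self.time_embeddings.to(device)

    def forward_guarded(self, device):
        # New behavior: transfer only when devices actually differ.
        if self.time_embeddings.device != device:
            self.time_embeddings = self.time_embeddings.to(device)


old = Model()
for _ in range(3):
    old.forward_unconditional("cuda:0")
print(old.time_embeddings.to_calls)  # 3 transfer requests

new = Model()
for _ in range(3):
    new.forward_guarded("cuda:0")
print(new.time_embeddings.to_calls)  # 1 transfer request
```

With the guard, only the first forward pass issues a transfer; every later pass reuses the cached on-device tensor, which is what keeps the captured CUDA graph replayable.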
