From 445b9b19ad4442a00418a728dca5fec1d6b8b654 Mon Sep 17 00:00:00 2001
From: Wil Kong
Date: Mon, 10 Jun 2024 17:49:11 +0800
Subject: [PATCH] [SD] Fix SD CUDA Graph Failure (#9319)

* [SD] Avoid redundant host & device sync breaks cuda graph.

* Apply isort and black reformatting

Signed-off-by: alpha0422

---------

Signed-off-by: alpha0422
Co-authored-by: Michal Futrega
Co-authored-by: Pablo Garay
---
 .../stable_diffusion/diffusionmodules/openaimodel.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/openaimodel.py b/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/openaimodel.py
index 30ff0e1a9ff3..7f8b2fb20bff 100644
--- a/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/openaimodel.py
+++ b/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/openaimodel.py
@@ -1342,9 +1342,10 @@ def _forward(self, x, timesteps=None, context=None, y=None, **kwargs):
         if context is not None:
             context = context.type(torch.float16)
 
-        t_emb = timestep_embedding(
-            timesteps, self.model_channels, cached_embedding=self.time_embeddings.to(timesteps.device)
-        )
+        if self.time_embeddings.device != timesteps.device:
+            self.time_embeddings = self.time_embeddings.to(timesteps.device)
+
+        t_emb = timestep_embedding(timesteps, self.model_channels, cached_embedding=self.time_embeddings)
         emb = self.time_embed(t_emb)
         if self.num_classes is not None:
             assert y.shape[0] == x.shape[0]
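
Context for the change above: the old code called self.time_embeddings.to(timesteps.device) on every forward pass. While the cached table still lives on the host, that is a host-to-device copy each step, and the synchronization it can trigger is what makes CUDA graph capture fail. The patch instead migrates the cache once, guarded by a device check, so the steady-state forward issues no host transfer. The snippet below is a minimal standalone sketch of that pattern, not NeMo code; the class name CachedTimeEmbedding and the tensor shapes are made up for illustration.

import torch


class CachedTimeEmbedding(torch.nn.Module):
    """Illustrative stand-in for a module holding a precomputed embedding table."""

    def __init__(self, num_timesteps: int = 1000, dim: int = 320):
        super().__init__()
        # Plain tensor attribute (not a registered buffer), so it is not moved
        # automatically when the module goes to the GPU -- the situation the
        # patch has to handle.
        self.time_embeddings = torch.randn(num_timesteps, dim)

    def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
        # Move the cache only when its device differs from the input's. After
        # the first call this branch does nothing, so no per-step copy or sync
        # is issued and the forward stays safe to capture in a CUDA graph.
        if self.time_embeddings.device != timesteps.device:
            self.time_embeddings = self.time_embeddings.to(timesteps.device)
        return self.time_embeddings[timesteps]

An alternative would be to register the table with register_buffer so that model.to(device) moves it along with the parameters; the lazy device check used here also covers callers that never move the whole module.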