diff --git a/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py index 175f338b92..711bdf2d5e 100644 --- a/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +++ b/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py @@ -161,16 +161,16 @@ def set_timesteps( # standard deviation of the initial noise distribution self.init_noise_sigma = self.sigmas.max() - timesteps = torch.from_numpy(timesteps).to(device) - timesteps_interpol = self.sigma_to_t(sigmas_interpol).to(device) - interleaved_timesteps = torch.stack((timesteps_interpol[:-2, None], timesteps[1:, None]), dim=-1).flatten() - timesteps = torch.cat([timesteps[:1], interleaved_timesteps]) - if str(device).startswith("mps"): # mps does not support float64 - self.timesteps = timesteps.to(device, dtype=torch.float32) + timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32) else: - self.timesteps = timesteps + timesteps = torch.from_numpy(timesteps).to(device) + + timesteps_interpol = self.sigma_to_t(sigmas_interpol).to(device) + interleaved_timesteps = torch.stack((timesteps_interpol[:-2, None], timesteps[1:, None]), dim=-1).flatten() + + self.timesteps = torch.cat([timesteps[:1], interleaved_timesteps]) self.sample = None diff --git a/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py b/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py index 18dd976716..a46cc06052 100644 --- a/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py +++ b/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py @@ -149,18 +149,17 @@ def set_timesteps( # standard deviation of the initial noise distribution self.init_noise_sigma = self.sigmas.max() - timesteps = torch.from_numpy(timesteps).to(device) + if str(device).startswith("mps"): + # mps does not support float64 + timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32) + else: + timesteps = torch.from_numpy(timesteps).to(device) # interpolate timesteps timesteps_interpol = self.sigma_to_t(sigmas_interpol).to(device) interleaved_timesteps = torch.stack((timesteps_interpol[1:-1, None], timesteps[1:, None]), dim=-1).flatten() - timesteps = torch.cat([timesteps[:1], interleaved_timesteps]) - if str(device).startswith("mps"): - # mps does not support float64 - self.timesteps = timesteps.to(torch.float32) - else: - self.timesteps = timesteps + self.timesteps = torch.cat([timesteps[:1], interleaved_timesteps]) self.sample = None