Description
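When updating a perturbation in the diffusion model's latent space, turning the latent vectors back into images requires the following step: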
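```python
import torch

def diffusion_step(model, latents, context, t, guidance_scale):
    # Batch the unconditional and text passes into one UNet call (context holds both embeddings).
    latents_input = torch.cat([latents] * 2)
    noise_pred = model.unet(latents_input, t, encoder_hidden_states=context)["sample"]
    noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
    # Classifier-free guidance: push the prediction toward the text condition.
    noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
    # One scheduler denoising step: x_t -> x_{t-1}.
    latents = model.scheduler.step(noise_pred, t, latents)["prev_sample"]
    return latents
```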
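The guidance line in `diffusion_step` is standard classifier-free guidance. Writing $\epsilon_\theta$ for the UNet's noise prediction, $z_t$ for `latents`, $c$ for the text embedding, $\varnothing$ for the unconditional (empty-prompt) embedding and $s$ for `guidance_scale` (symbols chosen here for illustration, assuming `context` stacks the unconditional embedding before the text embedding), each step computes:

$$
\hat{\epsilon} = \epsilon_\theta(z_t, t, \varnothing) + s\,\bigl(\epsilon_\theta(z_t, t, c) - \epsilon_\theta(z_t, t, \varnothing)\bigr)
$$

Both terms come from a single UNet call on `torch.cat([latents] * 2)`, so every step runs the UNet on twice the batch size and stores activations for both halves.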
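Calling `diffusion_step` at this point reports insufficient GPU memory (CUDA out of memory), even though my GPU has 24 GB of VRAM.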
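A plausible cause (an assumption; the issue does not confirm it) is that optimizing the perturbation keeps the autograd graph of every UNet call alive, so activation memory grows with the number of denoising steps and is doubled by the concatenated batch. If no gradient is needed through sampling, wrapping the loop in `torch.no_grad()` avoids storing activations entirely. When gradients are needed, one option is sketched below: run the unconditional and text passes as two separate calls and wrap each in `torch.utils.checkpoint.checkpoint`, which recomputes activations during backward instead of keeping them. `diffusion_step_lowmem`, `ToyUNet`, `ToyScheduler` and `make_toy_model` are hypothetical names; the toy classes only mimic the `["sample"]` / `["prev_sample"]` interface used above so the sketch runs on CPU.

```python
import torch
from types import SimpleNamespace
from torch.utils.checkpoint import checkpoint


def diffusion_step_lowmem(model, latents, context, t, guidance_scale):
    """Hypothetical lower-memory variant of diffusion_step (a sketch).

    Assumes `context` stacks [unconditional; text] embeddings along dim 0,
    in the same order that diffusion_step's noise_pred.chunk(2) expects.
    """
    context_uncond, context_text = context.chunk(2)

    def predict_noise(x, ctx):
        return model.unet(x, t, encoder_hidden_states=ctx)["sample"]

    # Two half-size UNet calls instead of one call on a doubled batch.
    # checkpoint() drops each call's intermediate activations and recomputes
    # them during backward, trading extra compute for lower memory.
    noise_pred_uncond = checkpoint(predict_noise, latents, context_uncond, use_reentrant=False)
    noise_pred_text = checkpoint(predict_noise, latents, context_text, use_reentrant=False)

    # Same classifier-free guidance and scheduler update as diffusion_step.
    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
    return model.scheduler.step(noise_pred, t, latents)["prev_sample"]


# --- Hypothetical toy stand-ins, used only to exercise the sketch on CPU ---
class ToyUNet(torch.nn.Module):
    def __init__(self, channels=4, ctx_dim=8):
        super().__init__()
        self.conv = torch.nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.proj = torch.nn.Linear(ctx_dim, channels)

    def forward(self, x, t, encoder_hidden_states):
        # Pool the context tokens into a per-channel bias; t is ignored here.
        bias = self.proj(encoder_hidden_states.mean(dim=1))[:, :, None, None]
        return {"sample": self.conv(x) + bias}


class ToyScheduler:
    def step(self, noise_pred, t, latents):
        # Simple Euler-like update standing in for a real scheduler.
        return {"prev_sample": latents - 0.1 * noise_pred}


def make_toy_model():
    return SimpleNamespace(unet=ToyUNet(), scheduler=ToyScheduler())


if __name__ == "__main__":
    torch.manual_seed(0)
    model = make_toy_model()
    latents = torch.randn(1, 4, 8, 8)
    delta = torch.zeros_like(latents, requires_grad=True)  # the perturbation
    context = torch.randn(2, 3, 8)  # [uncond; text] embeddings
    out = diffusion_step_lowmem(model, latents + delta, context, t=10, guidance_scale=7.5)
    out.sum().backward()  # gradients still reach the perturbation
    print(out.shape, delta.grad.shape)
```

With a real model, the result should match `diffusion_step` up to floating-point error, at the cost of roughly one extra UNet forward pass per step during backward. Checkpointing is what keeps memory from accumulating across many steps; splitting the batch only lowers the peak inside a single step. Diffusers UNet models also provide `enable_gradient_checkpointing()`, which applies a similar idea inside the network, if your version supports it.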