
Commit

Add a check and explanation for tensor with all NaNs.
AUTOMATIC1111 committed Jan 16, 2023
1 parent 52f6e94 commit 9991967
Showing 3 changed files with 33 additions and 0 deletions.
28 changes: 28 additions & 0 deletions modules/devices.py
@@ -106,6 +106,33 @@ def autocast(disable=False):
    return torch.autocast("cuda")


class NansException(Exception):
    pass


def test_for_nans(x, where):
    from modules import shared

    if not torch.all(torch.isnan(x)).item():

mezotaken (Collaborator) commented on Jan 17, 2023:

I have no idea how to pass this check with the test model.
It shouldn't affect performance, because it would only check that the first element is not NaN, but can we add a command-line arg to disable this?

An alternative is to download the real SD 1.5 model with a Hugging Face token in GitHub secrets.

CCRcmcpe (Contributor) commented on Jan 17, 2023:

Why torch.all instead of torch.any?

AUTOMATIC1111 (Author, Owner) commented on Jan 20, 2023:

My idea was that if only some of the values are NaN, not all of the image is corrupted, and it may be possible to salvage some use out of it.

        return

if where == "unet":
message = "A tensor with all NaNs was produced in Unet."

if not shared.cmd_opts.no_half:
message += " This could be either because there's not enough precision to represent the picture, or because your video card does not support half type. Try using --no-half commandline argument to fix this."

elif where == "vae":
message = "A tensor with all NaNs was produced in VAE."

if not shared.cmd_opts.no_half and not shared.cmd_opts.no_half_vae:
message += " This could be because there's not enough precision to represent the picture. Try adding --no-half-vae commandline argument to fix this."
else:
message = "A tensor with all NaNs was produced."

raise NansException(message)


# MPS workaround for https://github.com/pytorch/pytorch/issues/79383
orig_tensor_to = torch.Tensor.to
def tensor_to_fix(self, *args, **kwargs):
@@ -156,3 +183,4 @@ def cumsum_fix(input, cumsum_func, *args, **kwargs):
torch.Tensor.cumsum = lambda self, *args, **kwargs: ( cumsum_fix(self, orig_Tensor_cumsum, *args, **kwargs) )
orig_narrow = torch.narrow
torch.narrow = lambda *args, **kwargs: ( orig_narrow(*args, **kwargs).clone() )

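To make the torch.all versus torch.any question from the review thread above concrete, here is a small standalone sketch (the toy tensors are illustrative and not part of the commit): with torch.all, a partially corrupted tensor passes the check and only a fully-NaN tensor raises, which matches the salvage-what-you-can reasoning in the author's reply.

import torch

partly_bad = torch.tensor([1.0, float("nan"), 3.0])  # only one element is NaN
fully_bad = torch.full((3,), float("nan"))           # every element is NaN

# The committed check only trips when *all* elements are NaN.
print(torch.all(torch.isnan(partly_bad)).item())  # False -> test_for_nans returns quietly
print(torch.all(torch.isnan(fully_bad)).item())   # True  -> test_for_nans raises NansException

# The alternative raised in review trips on *any* NaN, including partial corruption.
print(torch.any(torch.isnan(partly_bad)).item())  # True
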
3 changes: 3 additions & 0 deletions modules/processing.py
@@ -608,6 +608,9 @@ def get_conds_with_caching(function, required_prompts, steps, cache):
samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts)

x_samples_ddim = [decode_first_stage(p.sd_model, samples_ddim[i:i+1].to(dtype=devices.dtype_vae))[0].cpu() for i in range(samples_ddim.size(0))]
for x in x_samples_ddim:
    devices.test_for_nans(x, "vae")

x_samples_ddim = torch.stack(x_samples_ddim).float()
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

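A NansException raised during the VAE check above simply propagates out of processing. As a hedged sketch, a caller inside the web UI codebase (where modules.devices is importable) could catch it and report the hint instead of a traceback; the run_and_report wrapper below is hypothetical and not part of the commit.

from modules import devices

def run_and_report(generate_fn, *args, **kwargs):
    # Hypothetical wrapper: run a generation step and surface NaN failures as text.
    try:
        return generate_fn(*args, **kwargs), None
    except devices.NansException as e:
        # The message already carries the --no-half / --no-half-vae hint built in devices.py.
        return None, str(e)
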
2 changes: 2 additions & 0 deletions modules/sd_samplers.py
@@ -351,6 +351,8 @@ def forward(self, x, sigma, uncond, cond, cond_scale, image_cond):

x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]})

devices.test_for_nans(x_out, "unet")

if opts.live_preview_content == "Prompt":
store_latent(x_out[0:uncond.shape[0]])
elif opts.live_preview_content == "Negative prompt":
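For orientation, the unet-side check runs on the combined denoiser output at every sampling step, so a NaN blow-up is reported immediately with the unet hint instead of surfacing later as a black image. The loop below is a toy stand-in for the k-diffusion sampler, not the project's code; only the position of the check mirrors the diff.

import torch
from modules import devices

def toy_sampling_loop(denoiser, x, sigmas):
    # denoiser stands in for CFGDenoiser.forward, which is where the real check lives
    for sigma in sigmas:
        x = denoiser(x, sigma)
        devices.test_for_nans(x, "unet")  # fail fast, before the latents ever reach the VAE
    return x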
