Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[df-if II] add additional input checks to ensure the input is divisible by 8 #7844

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
Original file line number Diff line number Diff line change
Expand Up @@ -543,12 +543,27 @@ def check_inputs(

if isinstance(image, list):
image_batch_size = len(image)
# Check that each image is the same size:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it is better to do this in a separate code block:
So we keep this section as it is to check image_batch_size

and then

if isinstance(image, list):
    check_image_size = image[0]
else:
    check_image_size = image

if isinstance(check_image_size, PIL.Image.Image):
    image_size = check_image_size.size
elif isinstance(check_image_size, torch.Tensor):
   image_size = check_image_size.shape[2:]
elif isinstanc(..., np.ndarray):
    image_size = check_image.shape[:1]

if image_size ....:
    raise ValueError(...)

The current code does not work for list of array or tensors

if not all([i.size == image[0].size for i in image]):
bghira marked this conversation as resolved.
Show resolved Hide resolved
raise ValueError("All images must be the same size")
# Check that the size is divisible by 8:
if (image[0].size[0] % 8 != 0 or image[0].size[1] % 8 != 0):
raise ValueError("Image size must be divisible by 8")
elif isinstance(image, torch.Tensor):
image_batch_size = image.shape[0]
# Check that the size is divisible by 8:
if (image.shape[2] % 8 != 0 or image.shape[3] % 8 != 0):
raise ValueError("Image size must be divisible by 8")
elif isinstance(image, PIL.Image.Image):
image_batch_size = 1
# Check that the size is divisible by 8:
if (image.size[0] % 8 != 0 or image.size[1] % 8 != 0):
raise ValueError("Image size must be divisible by 8")
elif isinstance(image, np.ndarray):
image_batch_size = image.shape[0]
# Check that the size is divisible by 8:
if (image.shape[1] % 8 != 0 or image.shape[2] % 8 != 0):
raise ValueError("Image size must be divisible by 8")
else:
assert False

Expand Down
7 changes: 7 additions & 0 deletions tests/pipelines/deepfloyd_if/test_if_superresolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,13 @@ def get_dummy_inputs(self, device, seed=0):

return inputs

def test_incorrect_input_size(self):
# Put an image non-divisible by 8 into the pipeline and check that it throws an Exception.
image = floats_tensor((1, 3, 31, 31), rng=random.Random(0)).to(torch_device)
generator = torch.Generator(device="cpu").manual_seed(0)
with self.assertRaises(ValueError):
self.pipeline(prompt="elegant destruction", image=image, generator=generator, num_inference_steps=2, output_type="np")

@unittest.skipIf(
torch_device != "cuda" or not is_xformers_available(),
reason="XFormers attention is only available with CUDA and `xformers` installed",
Expand Down