Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added normal and latent inferers for ControlNet. Added tests (copied … #439

Conversation

virginiafdez
Copy link
Contributor

…from the normal inferer tests, but with the addition of controlnet support).

…from the normal inferer tests, but with the addition of controlnet support).
@virginiafdez virginiafdez linked an issue Nov 27, 2023 that may be closed by this pull request
Comment on lines 686 to 697
if isinstance(diffusion_model, SPADEDiffusionModelUNet):
model_output = diffusion_model(
model_input, timesteps=torch.Tensor((t,)).to(input_noise.device), context=None, seg=seg,
down_block_additional_residuals=down_block_res_samples,
mid_block_additional_residual=mid_block_res_sample
)
else:
model_output = diffusion_model(
model_input, timesteps=torch.Tensor((t,)).to(input_noise.device), context=None,
down_block_additional_residuals=down_block_res_samples,
mid_block_additional_residual=mid_block_res_sample
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if isinstance(diffusion_model, SPADEDiffusionModelUNet):
model_output = diffusion_model(
model_input, timesteps=torch.Tensor((t,)).to(input_noise.device), context=None, seg=seg,
down_block_additional_residuals=down_block_res_samples,
mid_block_additional_residual=mid_block_res_sample
)
else:
model_output = diffusion_model(
model_input, timesteps=torch.Tensor((t,)).to(input_noise.device), context=None,
down_block_additional_residuals=down_block_res_samples,
mid_block_additional_residual=mid_block_res_sample
)
_seg = seg if isinstance(diffusion_model, SPADEDiffusionModelUNet) else None
model_output = diffusion_model(
model_input, timesteps=torch.Tensor((t,)).to(input_noise.device), context=None, seg=_seg,
down_block_additional_residuals=down_block_res_samples,
mid_block_additional_residual=mid_block_res_sample
)

We could reduce the amount of code duplication with this sort of pattern, here and elsewhere.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that exact formulation won't work as DiffusionModelUnet doesn't have the argument seg. But this would work:

    diffusion_model = partial(diffusion_model, seg=seg) if isinstance(diffusion_model, SPADEDiffusionModelUNet) else diffusion_model
    model_output = diffusion_model(
        model_input, timesteps=torch.Tensor((t,)).to(input_noise.device), context=None,
        down_block_additional_residuals=down_block_res_samples,
        mid_block_additional_residual=mid_block_res_sample

(would need from functools import partial up top)

Copy link
Collaborator

@marksgraham marksgraham left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey Virginina,

Some channels in the comments. I think we can tidy up the Inferers quite a lot by using partial.

Please also run ./runtests.sh --codeformat --autofix and commit when you're done with all the other changes.

"""
ControlNetDiffusionInferer takes a trained diffusion model and a scheduler and can be used to perform a signal
forward pass for a training iteration, and sample from the model, supporting ControlNet-based conditioning.
Args:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please add a line between the intro and Args

],
]

class CN_TestDiffusionSamplingInferer(unittest.TestCase):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This needs to use CapsWord convention, i.e. remove the underscore. This gets picked up when you run formatting tests with ./runtests.sh --codeformat --autofix

)
self.assertEqual(len(intermediates), 10)

class LCN_TestDiffusionSamplingInferer(unittest.TestCase):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CapsWords convention also required here

Comment on lines 686 to 697
if isinstance(diffusion_model, SPADEDiffusionModelUNet):
model_output = diffusion_model(
model_input, timesteps=torch.Tensor((t,)).to(input_noise.device), context=None, seg=seg,
down_block_additional_residuals=down_block_res_samples,
mid_block_additional_residual=mid_block_res_sample
)
else:
model_output = diffusion_model(
model_input, timesteps=torch.Tensor((t,)).to(input_noise.device), context=None,
down_block_additional_residuals=down_block_res_samples,
mid_block_additional_residual=mid_block_res_sample
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that exact formulation won't work as DiffusionModelUnet doesn't have the argument seg. But this would work:

    diffusion_model = partial(diffusion_model, seg=seg) if isinstance(diffusion_model, SPADEDiffusionModelUNet) else diffusion_model
    model_output = diffusion_model(
        model_input, timesteps=torch.Tensor((t,)).to(input_noise.device), context=None,
        down_block_additional_residuals=down_block_res_samples,
        mid_block_additional_residual=mid_block_res_sample

(would need from functools import partial up top)

else:
return total_kl

def _approx_standard_normal_cdf(self, x):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if we inherit from DiffusionModelInferer we can remove this function here

torch.sqrt(torch.Tensor([2.0 / math.pi]).to(x.device)) * (x + 0.044715 * torch.pow(x, 3)))
)

def _get_decoder_log_likelihood(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this can also be removed if we inherit from DiffusionModelInfererer

virginiafdez added 4 commits November 30, 2023 15:24
…E check.

Fixed some formatting typos.
Deleted two functions and changed inheritance of ControlNetDiffusionInferer.
Changed names of tests to agree with caps convention.
…E check.

Fixed some formatting typos.
Deleted two functions and changed inheritance of ControlNetDiffusionInferer.
Changed names of tests to agree with caps convention.
+ run autofix
@marksgraham marksgraham merged commit 18fef51 into main Dec 1, 2023
@marksgraham marksgraham deleted the 438-create-a-new-latentdiffusioninferer-compatible-with-controlnet branch December 1, 2023 11:45
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Create a new LatentDiffusionInferer compatible with ControlNet
3 participants