-
Notifications
You must be signed in to change notification settings - Fork 77
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Added normal and latent inferers for ControlNet. Added tests (copied … #439
Added normal and latent inferers for ControlNet. Added tests (copied … #439
Conversation
…from the normal inferer tests, but with the addition of controlnet support).
generative/inferers/inferer.py
Outdated
if isinstance(diffusion_model, SPADEDiffusionModelUNet): | ||
model_output = diffusion_model( | ||
model_input, timesteps=torch.Tensor((t,)).to(input_noise.device), context=None, seg=seg, | ||
down_block_additional_residuals=down_block_res_samples, | ||
mid_block_additional_residual=mid_block_res_sample | ||
) | ||
else: | ||
model_output = diffusion_model( | ||
model_input, timesteps=torch.Tensor((t,)).to(input_noise.device), context=None, | ||
down_block_additional_residuals=down_block_res_samples, | ||
mid_block_additional_residual=mid_block_res_sample | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if isinstance(diffusion_model, SPADEDiffusionModelUNet): | |
model_output = diffusion_model( | |
model_input, timesteps=torch.Tensor((t,)).to(input_noise.device), context=None, seg=seg, | |
down_block_additional_residuals=down_block_res_samples, | |
mid_block_additional_residual=mid_block_res_sample | |
) | |
else: | |
model_output = diffusion_model( | |
model_input, timesteps=torch.Tensor((t,)).to(input_noise.device), context=None, | |
down_block_additional_residuals=down_block_res_samples, | |
mid_block_additional_residual=mid_block_res_sample | |
) | |
_seg = seg if isinstance(diffusion_model, SPADEDiffusionModelUNet) else None | |
model_output = diffusion_model( | |
model_input, timesteps=torch.Tensor((t,)).to(input_noise.device), context=None, seg=_seg, | |
down_block_additional_residuals=down_block_res_samples, | |
mid_block_additional_residual=mid_block_res_sample | |
) |
We could reduce the amount of code duplication with this sort of pattern, here and elsewhere.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
that exact formulation won't work as DiffusionModelUnet
doesn't have the argument seg
. But this would work:
diffusion_model = partial(diffusion_model, seg=seg) if isinstance(diffusion_model, SPADEDiffusionModelUNet) else diffusion_model
model_output = diffusion_model(
model_input, timesteps=torch.Tensor((t,)).to(input_noise.device), context=None,
down_block_additional_residuals=down_block_res_samples,
mid_block_additional_residual=mid_block_res_sample
(would need from functools import partial
up top)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hey Virginina,
Some channels in the comments. I think we can tidy up the Inferers quite a lot by using partial
.
Please also run ./runtests.sh --codeformat --autofix
and commit when you're done with all the other changes.
generative/inferers/inferer.py
Outdated
""" | ||
ControlNetDiffusionInferer takes a trained diffusion model and a scheduler and can be used to perform a signal | ||
forward pass for a training iteration, and sample from the model, supporting ControlNet-based conditioning. | ||
Args: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
please add a line between the intro and Args
tests/test_controlnet_inferers.py
Outdated
], | ||
] | ||
|
||
class CN_TestDiffusionSamplingInferer(unittest.TestCase): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This needs to use CapsWord convention, i.e. remove the underscore. This gets picked up when you run formatting tests with ./runtests.sh --codeformat --autofix
tests/test_controlnet_inferers.py
Outdated
) | ||
self.assertEqual(len(intermediates), 10) | ||
|
||
class LCN_TestDiffusionSamplingInferer(unittest.TestCase): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
CapsWords convention also required here
generative/inferers/inferer.py
Outdated
if isinstance(diffusion_model, SPADEDiffusionModelUNet): | ||
model_output = diffusion_model( | ||
model_input, timesteps=torch.Tensor((t,)).to(input_noise.device), context=None, seg=seg, | ||
down_block_additional_residuals=down_block_res_samples, | ||
mid_block_additional_residual=mid_block_res_sample | ||
) | ||
else: | ||
model_output = diffusion_model( | ||
model_input, timesteps=torch.Tensor((t,)).to(input_noise.device), context=None, | ||
down_block_additional_residuals=down_block_res_samples, | ||
mid_block_additional_residual=mid_block_res_sample | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
that exact formulation won't work as DiffusionModelUnet
doesn't have the argument seg
. But this would work:
diffusion_model = partial(diffusion_model, seg=seg) if isinstance(diffusion_model, SPADEDiffusionModelUNet) else diffusion_model
model_output = diffusion_model(
model_input, timesteps=torch.Tensor((t,)).to(input_noise.device), context=None,
down_block_additional_residuals=down_block_res_samples,
mid_block_additional_residual=mid_block_res_sample
(would need from functools import partial
up top)
generative/inferers/inferer.py
Outdated
else: | ||
return total_kl | ||
|
||
def _approx_standard_normal_cdf(self, x): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if we inherit from DiffusionModelInferer
we can remove this function here
generative/inferers/inferer.py
Outdated
torch.sqrt(torch.Tensor([2.0 / math.pi]).to(x.device)) * (x + 0.044715 * torch.pow(x, 3))) | ||
) | ||
|
||
def _get_decoder_log_likelihood( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this can also be removed if we inherit from DiffusionModelInfererer
…E check. Fixed some formatting typos. Deleted two functions and changed inheritance of ControlNetDiffusionInferer. Changed names of tests to agree with caps convention.
…E check. Fixed some formatting typos. Deleted two functions and changed inheritance of ControlNetDiffusionInferer. Changed names of tests to agree with caps convention. + run autofix
…ials that had been forgotten last commit.
…from the normal inferer tests, but with the addition of controlnet support).