From 84338f189db5d66f1e323652b38affd943c306c8 Mon Sep 17 00:00:00 2001 From: virginiafdez Date: Wed, 20 Mar 2024 09:18:38 +0000 Subject: [PATCH 1/3] Added function to load the state_dict from the diffusion model into the controlnet, informing the user - if required - of matched and unmatched layers. --- generative/networks/nets/controlnet.py | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/generative/networks/nets/controlnet.py b/generative/networks/nets/controlnet.py index ebe2459c..2810fac7 100644 --- a/generative/networks/nets/controlnet.py +++ b/generative/networks/nets/controlnet.py @@ -41,7 +41,6 @@ from generative.networks.nets.diffusion_model_unet import get_down_block, get_mid_block, get_timestep_embedding - class ControlNetConditioningEmbedding(nn.Module): """ Network to encode the conditioning into a latent space. @@ -121,6 +120,29 @@ def zero_module(module): nn.init.zeros_(p) return module +def copy_weights_to_controlnet(controlnet : nn.Module, + diffusion_model: nn.Module, + verbose: bool = True): + ''' + Copy the state dict from the input diffusion model to the ControlNet, printing, if user requires it, the output + keys that have matched and those that haven't. + Args: + controlnet: instance of ControlNet + diffusion_model: instance of DiffusionModelUnet or SPADEDiffusionModelUnet + verbose: if True, the matched and unmatched keys will be printed. + + Returns: + ''' + + output = controlnet.load_state_dict(diffusion_model.state_dict(), strict = False) + if verbose: + dm_keys = [p[0] for p in list(diffusion_model.named_parameters()) if p[0] not in output.unexpected_keys] + print("Copied weights from %d keys of the diffusion model into the ControlNet: \n%s\n" + "ControlNet incompatible keys: %d.\n%s\n" + "DiffusionModel incompatible keys: %d.\n%s\n" %(len(dm_keys), "; ".join(dm_keys), + len(output.missing_keys), "; ".join(output.missing_keys), + len(output.unexpected_keys), "; ".join(output.unexpected_keys))) + return class ControlNet(nn.Module): """ From 2e80439806dc932449ad54498cb20209c235dab2 Mon Sep 17 00:00:00 2001 From: virginiafdez Date: Wed, 20 Mar 2024 14:22:02 +0000 Subject: [PATCH 2/3] Modify formatting: removed return statement, return in args description, and formatted the print with f-Strings. --- generative/networks/nets/controlnet.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/generative/networks/nets/controlnet.py b/generative/networks/nets/controlnet.py index 2810fac7..864e4a80 100644 --- a/generative/networks/nets/controlnet.py +++ b/generative/networks/nets/controlnet.py @@ -130,19 +130,15 @@ def copy_weights_to_controlnet(controlnet : nn.Module, controlnet: instance of ControlNet diffusion_model: instance of DiffusionModelUnet or SPADEDiffusionModelUnet verbose: if True, the matched and unmatched keys will be printed. - - Returns: ''' output = controlnet.load_state_dict(diffusion_model.state_dict(), strict = False) if verbose: dm_keys = [p[0] for p in list(diffusion_model.named_parameters()) if p[0] not in output.unexpected_keys] - print("Copied weights from %d keys of the diffusion model into the ControlNet: \n%s\n" - "ControlNet incompatible keys: %d.\n%s\n" - "DiffusionModel incompatible keys: %d.\n%s\n" %(len(dm_keys), "; ".join(dm_keys), - len(output.missing_keys), "; ".join(output.missing_keys), - len(output.unexpected_keys), "; ".join(output.unexpected_keys))) - return + print(f"Copied weights from {len(dm_keys)} keys of the diffusion model into the ControlNet:" + f"\n{'; '.join(dm_keys)}\nControlNet missing keys: {len(output.missing_keys)}:" + f"\n{'; '.join(output.missing_keys)}\nDiffusion model incompatible keys: {len(output.unexpected_keys)}:" + f"\n{'; '.join(output.unexpected_keys)}") class ControlNet(nn.Module): """ From 1719cbd1bc25ce67ab4cdb84088b63f8c14f34c6 Mon Sep 17 00:00:00 2001 From: virginiafdez Date: Thu, 21 Mar 2024 11:25:04 +0000 Subject: [PATCH 3/3] Formatting of the function --- generative/networks/nets/controlnet.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/generative/networks/nets/controlnet.py b/generative/networks/nets/controlnet.py index 864e4a80..caedf736 100644 --- a/generative/networks/nets/controlnet.py +++ b/generative/networks/nets/controlnet.py @@ -122,10 +122,11 @@ def zero_module(module): def copy_weights_to_controlnet(controlnet : nn.Module, diffusion_model: nn.Module, - verbose: bool = True): + verbose: bool = True) -> None: ''' Copy the state dict from the input diffusion model to the ControlNet, printing, if user requires it, the output keys that have matched and those that haven't. + Args: controlnet: instance of ControlNet diffusion_model: instance of DiffusionModelUnet or SPADEDiffusionModelUnet