diff --git a/generative/networks/nets/controlnet.py b/generative/networks/nets/controlnet.py index ebe2459c..caedf736 100644 --- a/generative/networks/nets/controlnet.py +++ b/generative/networks/nets/controlnet.py @@ -41,7 +41,6 @@ from generative.networks.nets.diffusion_model_unet import get_down_block, get_mid_block, get_timestep_embedding - class ControlNetConditioningEmbedding(nn.Module): """ Network to encode the conditioning into a latent space. @@ -121,6 +120,26 @@ def zero_module(module): nn.init.zeros_(p) return module +def copy_weights_to_controlnet(controlnet : nn.Module, + diffusion_model: nn.Module, + verbose: bool = True) -> None: + ''' + Copy the state dict from the input diffusion model to the ControlNet, printing, if user requires it, the output + keys that have matched and those that haven't. + + Args: + controlnet: instance of ControlNet + diffusion_model: instance of DiffusionModelUnet or SPADEDiffusionModelUnet + verbose: if True, the matched and unmatched keys will be printed. + ''' + + output = controlnet.load_state_dict(diffusion_model.state_dict(), strict = False) + if verbose: + dm_keys = [p[0] for p in list(diffusion_model.named_parameters()) if p[0] not in output.unexpected_keys] + print(f"Copied weights from {len(dm_keys)} keys of the diffusion model into the ControlNet:" + f"\n{'; '.join(dm_keys)}\nControlNet missing keys: {len(output.missing_keys)}:" + f"\n{'; '.join(output.missing_keys)}\nDiffusion model incompatible keys: {len(output.unexpected_keys)}:" + f"\n{'; '.join(output.unexpected_keys)}") class ControlNet(nn.Module): """