Skip to content
This repository was archived by the owner on Feb 7, 2025. It is now read-only.
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion generative/networks/nets/controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@

from generative.networks.nets.diffusion_model_unet import get_down_block, get_mid_block, get_timestep_embedding


class ControlNetConditioningEmbedding(nn.Module):
"""
Network to encode the conditioning into a latent space.
Expand Down Expand Up @@ -121,6 +120,26 @@ def zero_module(module):
nn.init.zeros_(p)
return module

def copy_weights_to_controlnet(controlnet : nn.Module,
diffusion_model: nn.Module,
verbose: bool = True) -> None:
'''
Copy the state dict from the input diffusion model to the ControlNet, printing, if user requires it, the output
keys that have matched and those that haven't.

Args:
controlnet: instance of ControlNet
diffusion_model: instance of DiffusionModelUnet or SPADEDiffusionModelUnet
verbose: if True, the matched and unmatched keys will be printed.
'''

output = controlnet.load_state_dict(diffusion_model.state_dict(), strict = False)
if verbose:
dm_keys = [p[0] for p in list(diffusion_model.named_parameters()) if p[0] not in output.unexpected_keys]
print(f"Copied weights from {len(dm_keys)} keys of the diffusion model into the ControlNet:"
f"\n{'; '.join(dm_keys)}\nControlNet missing keys: {len(output.missing_keys)}:"
f"\n{'; '.join(output.missing_keys)}\nDiffusion model incompatible keys: {len(output.unexpected_keys)}:"
f"\n{'; '.join(output.unexpected_keys)}")

class ControlNet(nn.Module):
"""
Expand Down