In [24]:
import torch
from denoising_diffusion_pytorch import Unet, GaussianDiffusion

# model

In [25]:
# Build the U-Net denoiser. These hyperparameters must match the checkpoint's
# architecture, otherwise the strict load_state_dict calls below will fail.
model = Unet(
    dim=64,
    dim_mults=(1, 2, 4, 8),
)

In [26]:
# Wrap the U-Net in the Gaussian diffusion process.
# sampling_timesteps (250) < timesteps (1000): sampling uses DDIM for faster
# inference (see the DDIM paper).
diffusion = GaussianDiffusion(
    model,
    image_size=128,
    timesteps=1000,          # number of diffusion steps
    sampling_timesteps=250,  # number of sampling steps
)

# Load weights

In [2]:
# Milestone number embedded in the checkpoint file name.
runs = "200"
ckpt_path = f"results/flowers/200000/model-{runs}.pt"
# map_location="cpu": nothing here moves `diffusion` to a GPU, and mapping to CPU
# keeps the load working on machines without CUDA even if the checkpoint was
# saved from GPU tensors (presumably the case for a training checkpoint).
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
state_dict: dict = torch.load(ckpt_path, map_location="cpu")
print(state_dict.keys())  # ['step', 'model', 'opt', 'ema', 'scaler', 'version']

dict_keys(['step', 'model', 'opt', 'ema', 'scaler', 'version'])


# get model keys

In [28]:
# Plain (non-EMA) weights: the checkpoint's "model" entry holds the noise-schedule
# tensors (betas, alphas_cumprod, ...) plus the U-Net weights (model.*).
model_dict: dict
model_dict = state_dict["model"]

In [29]:
# Number of entries in the plain checkpoint, then their names.
model_keys = model_dict.keys()
print(len(model_dict))
model_keys

287


odict_keys(['betas', 'alphas_cumprod', 'alphas_cumprod_prev', 'sqrt_alphas_cumprod', 'sqrt_one_minus_alphas_cumprod', 'log_one_minus_alphas_cumprod', 'sqrt_recip_alphas_cumprod', 'sqrt_recipm1_alphas_cumprod', 'posterior_variance', 'posterior_log_variance_clipped', 'posterior_mean_coef1', 'posterior_mean_coef2', 'loss_weight', 'model.init_conv.weight', 'model.init_conv.bias', 'model.time_mlp.1.weight', 'model.time_mlp.1.bias', 'model.time_mlp.3.weight', 'model.time_mlp.3.bias', 'model.downs.0.0.mlp.1.weight', 'model.downs.0.0.mlp.1.bias', 'model.downs.0.0.block1.proj.weight', 'model.downs.0.0.block1.proj.bias', 'model.downs.0.0.block1.norm.weight', 'model.downs.0.0.block1.norm.bias', 'model.downs.0.0.block2.proj.weight', 'model.downs.0.0.block2.proj.bias', 'model.downs.0.0.block2.norm.weight', 'model.downs.0.0.block2.norm.bias', 'model.downs.0.1.mlp.1.weight', 'model.downs.0.1.mlp.1.bias', 'model.downs.0.1.block1.proj.weight', 'model.downs.0.1.block1.proj.bias', 'model.downs.0.1.block1

In [30]:
# Sanity check: the plain weights load strictly into the freshly built diffusion model.
diffusion.load_state_dict(model_dict, strict=True)

<All keys matched successfully>

# get ema keys

In [31]:
# EMA state: contains bookkeeping entries ("initted", "step") and prefixed copies
# of the diffusion weights, so it cannot be loaded into `diffusion` as-is.
ema_dict: dict
ema_dict = state_dict["ema"]

In [32]:
# Number of entries in the EMA state, then their names.
ema_keys = ema_dict.keys()
print(len(ema_dict))
ema_keys

576


odict_keys(['initted', 'step', 'online_model.betas', 'online_model.alphas_cumprod', 'online_model.alphas_cumprod_prev', 'online_model.sqrt_alphas_cumprod', 'online_model.sqrt_one_minus_alphas_cumprod', 'online_model.log_one_minus_alphas_cumprod', 'online_model.sqrt_recip_alphas_cumprod', 'online_model.sqrt_recipm1_alphas_cumprod', 'online_model.posterior_variance', 'online_model.posterior_log_variance_clipped', 'online_model.posterior_mean_coef1', 'online_model.posterior_mean_coef2', 'online_model.loss_weight', 'online_model.model.init_conv.weight', 'online_model.model.init_conv.bias', 'online_model.model.time_mlp.1.weight', 'online_model.model.time_mlp.1.bias', 'online_model.model.time_mlp.3.weight', 'online_model.model.time_mlp.3.bias', 'online_model.model.downs.0.0.mlp.1.weight', 'online_model.model.downs.0.0.mlp.1.bias', 'online_model.model.downs.0.0.block1.proj.weight', 'online_model.model.downs.0.0.block1.proj.bias', 'online_model.model.downs.0.0.block1.norm.weight', 'online_mode

In [33]:
# Loading the raw EMA state directly fails: its keys carry wrapper prefixes
# ("online_model.", "ema_model.") and extra entries ("initted", "step"), so a
# strict load_state_dict raises RuntimeError. Catch it and show a truncated
# message so this demonstration cell does not halt "Restart & Run All".
try:
    diffusion.load_state_dict(ema_dict)
except RuntimeError as err:
    # The full message lists hundreds of missing/unexpected keys; keep it short.
    print(f"load_state_dict failed as expected ({type(err).__name__}):\n{str(err)[:500]} ...")

# Convert EMA weights to a model state dict

In [23]:
# Convert the EMA state into a plain GaussianDiffusion state dict.
# The EMA wrapper stores prefixed weights ("online_model.*" -- presumably the
# live weights -- and "ema_model.*"); keep only the "ema_model." entries and
# strip that prefix so the keys line up with `diffusion.state_dict()`.
EMA_PREFIX = "ema_model."
ema_model = {
    # Prefix match (not substring) + len(EMA_PREFIX) instead of a magic slice index.
    key[len(EMA_PREFIX):]: value
    for key, value in ema_dict.items()
    if key.startswith(EMA_PREFIX)
}
assert ema_model, f"no keys starting with {EMA_PREFIX!r} found in the EMA state"
print(len(ema_model.keys()))
ema_model.keys()

287


dict_keys(['betas', 'alphas_cumprod', 'alphas_cumprod_prev', 'sqrt_alphas_cumprod', 'sqrt_one_minus_alphas_cumprod', 'log_one_minus_alphas_cumprod', 'sqrt_recip_alphas_cumprod', 'sqrt_recipm1_alphas_cumprod', 'posterior_variance', 'posterior_log_variance_clipped', 'posterior_mean_coef1', 'posterior_mean_coef2', 'loss_weight', 'model.init_conv.weight', 'model.init_conv.bias', 'model.time_mlp.1.weight', 'model.time_mlp.1.bias', 'model.time_mlp.3.weight', 'model.time_mlp.3.bias', 'model.downs.0.0.mlp.1.weight', 'model.downs.0.0.mlp.1.bias', 'model.downs.0.0.block1.proj.weight', 'model.downs.0.0.block1.proj.bias', 'model.downs.0.0.block1.norm.weight', 'model.downs.0.0.block1.norm.bias', 'model.downs.0.0.block2.proj.weight', 'model.downs.0.0.block2.proj.bias', 'model.downs.0.0.block2.norm.weight', 'model.downs.0.0.block2.norm.bias', 'model.downs.0.1.mlp.1.weight', 'model.downs.0.1.mlp.1.bias', 'model.downs.0.1.block1.proj.weight', 'model.downs.0.1.block1.proj.bias', 'model.downs.0.1.block1.

In [34]:
# Load the prefix-stripped EMA weights; a strict load confirms every key lines up.
diffusion.load_state_dict(ema_model, strict=True)

<All keys matched successfully>