In [1]:
import torch
from imagen_pytorch import Unet, BaseUnet64, SRUnet256, SRUnet1024, Imagen, ImagenTrainer

In [8]:
# Prefer the first CUDA GPU when one is visible; otherwise run on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# model

In [18]:
# Two-stage cascade for unconditional Imagen: a base U-Net followed by one
# super-resolution U-Net. Cross-attention layers are switched off in both.
unet1 = BaseUnet64(
    dim=64,
    dim_mults=(1, 2, 4, 8),
    num_resnet_blocks=1,
    layer_attns=(False, False, False, True),
    layer_cross_attns=False,
)
unet2 = SRUnet256(
    dim=32,
    dim_mults=(1, 2, 4),
    num_resnet_blocks=(2, 4, 8),
    layer_attns=(False, False, True),
    layer_cross_attns=False,
)

# Disabled third stage, kept for reference:
# unet3 = SRUnet1024(dim=32, dim_mults=(1, 2, 4), num_resnet_blocks=(2, 4, 8),
#                    layer_attns=(False, False, True), layer_cross_attns=False)

unets = (unet1, unet2)

# Wrap the U-Nets in Imagen; image_sizes has one entry per U-Net in `unets`.
imagen = Imagen(
    unets=unets,
    image_sizes=(64, 256),
    timesteps=1000,
    condition_on_text=False,  # must be False for unconditional Imagen
)
imagen = imagen.to(device)

# load weights

In [10]:
# Load the training checkpoint saved at the chosen step count.
# map_location="cpu" lets a checkpoint written on a GPU machine be read on a
# CPU-only one (see the `device` fallback above); load_state_dict later copies
# the tensors into the model's own parameters on whatever device it lives on.
# NOTE: torch.load unpickles the file, so only open checkpoints you trust.
runs = "800000"
state_dict: dict = torch.load(
    f"results/flowers/checkpoint.{runs}.pt",
    map_location="cpu",
)
print(state_dict.keys())  # ['model', 'version', 'steps', 'scaler0', 'optim0', 'scaler1', 'optim1', 'ema']

dict_keys(['model', 'version', 'steps', 'scaler0', 'optim0', 'scaler1', 'optim1', 'ema'])


# get model keys

In [11]:
# Weights stored under the checkpoint's "model" entry (kept separately from "ema").
model_dict: dict = state_dict["model"]

In [12]:
# Print how many entries the model state dict holds, then display their names.
print(len(model_dict))
model_keys = model_dict.keys()
model_keys

1117


odict_keys(['unets.0.null_text_embed', 'unets.0.null_text_hidden', 'unets.0.init_conv.convs.0.weight', 'unets.0.init_conv.convs.0.bias', 'unets.0.init_conv.convs.1.weight', 'unets.0.init_conv.convs.1.bias', 'unets.0.init_conv.convs.2.weight', 'unets.0.init_conv.convs.2.bias', 'unets.0.to_time_hiddens.0.weights', 'unets.0.to_time_hiddens.1.weight', 'unets.0.to_time_hiddens.1.bias', 'unets.0.to_time_cond.0.weight', 'unets.0.to_time_cond.0.bias', 'unets.0.to_time_tokens.0.weight', 'unets.0.to_time_tokens.0.bias', 'unets.0.norm_cond.weight', 'unets.0.norm_cond.bias', 'unets.0.attn_pool.latents', 'unets.0.attn_pool.pos_emb.weight', 'unets.0.attn_pool.to_latents_from_mean_pooled_seq.0.g', 'unets.0.attn_pool.to_latents_from_mean_pooled_seq.1.weight', 'unets.0.attn_pool.to_latents_from_mean_pooled_seq.1.bias', 'unets.0.attn_pool.layers.0.0.q_scale', 'unets.0.attn_pool.layers.0.0.k_scale', 'unets.0.attn_pool.layers.0.0.norm.weight', 'unets.0.attn_pool.layers.0.0.norm.bias', 'unets.0.attn_pool.l

In [13]:
# Load the checkpoint's "model" weights into the Imagen instance built above.
imagen.load_state_dict(model_dict)

<All keys matched successfully>

# get ema keys

In [14]:
# EMA state saved in the same checkpoint; its keys start with a U-Net index
# (e.g. "0.initted", "0.step", "0.online_model...").
ema_dict: dict = state_dict["ema"]

In [19]:
# Print how many entries the EMA state dict holds, then display their names.
print(len(ema_dict))
ema_keys = ema_dict.keys()
ema_keys

2238


odict_keys(['0.initted', '0.step', '0.online_model.null_text_embed', '0.online_model.null_text_hidden', '0.online_model.init_conv.convs.0.weight', '0.online_model.init_conv.convs.0.bias', '0.online_model.init_conv.convs.1.weight', '0.online_model.init_conv.convs.1.bias', '0.online_model.init_conv.convs.2.weight', '0.online_model.init_conv.convs.2.bias', '0.online_model.to_time_hiddens.0.weights', '0.online_model.to_time_hiddens.1.weight', '0.online_model.to_time_hiddens.1.bias', '0.online_model.to_time_cond.0.weight', '0.online_model.to_time_cond.0.bias', '0.online_model.to_time_tokens.0.weight', '0.online_model.to_time_tokens.0.bias', '0.online_model.norm_cond.weight', '0.online_model.norm_cond.bias', '0.online_model.attn_pool.latents', '0.online_model.attn_pool.pos_emb.weight', '0.online_model.attn_pool.to_latents_from_mean_pooled_seq.0.g', '0.online_model.attn_pool.to_latents_from_mean_pooled_seq.1.weight', '0.online_model.attn_pool.to_latents_from_mean_pooled_seq.1.bias', '0.online

# convert ema to model

In [20]:
# Number of U-Nets in the cascade; drives the per-U-Net key conversion below.
unet_number = len(unets)
unet_number

2

In [30]:
# Re-key the EMA weights so they can be loaded straight into Imagen:
#   "<unet index>.ema_model.<name>"  ->  "unets.<unet index>.<name>"
# Other EMA entries (e.g. "<index>.online_model...", "<index>.step") are skipped.
new_ema_dict = {}
for number in range(unet_number):
    # Built once per U-Net rather than once per key. The trailing dot and the
    # startswith() test anchor the match to the start of the key: a plain
    # substring test would let "0.ema_model" also match "10.ema_model...", and
    # str.replace would rewrite every occurrence instead of only the prefix.
    ema_key_prefix = f"{number}.ema_model."
    dst_key_prefix = f"unets.{number}."
    for ema_key, ema_value in ema_dict.items():
        if ema_key.startswith(ema_key_prefix):
            dst_key = dst_key_prefix + ema_key[len(ema_key_prefix):]
            new_ema_dict[dst_key] = ema_value
new_ema_dict.keys()

dict_keys(['unets.0.null_text_embed', 'unets.0.null_text_hidden', 'unets.0.init_conv.convs.0.weight', 'unets.0.init_conv.convs.0.bias', 'unets.0.init_conv.convs.1.weight', 'unets.0.init_conv.convs.1.bias', 'unets.0.init_conv.convs.2.weight', 'unets.0.init_conv.convs.2.bias', 'unets.0.to_time_hiddens.0.weights', 'unets.0.to_time_hiddens.1.weight', 'unets.0.to_time_hiddens.1.bias', 'unets.0.to_time_cond.0.weight', 'unets.0.to_time_cond.0.bias', 'unets.0.to_time_tokens.0.weight', 'unets.0.to_time_tokens.0.bias', 'unets.0.norm_cond.weight', 'unets.0.norm_cond.bias', 'unets.0.attn_pool.latents', 'unets.0.attn_pool.pos_emb.weight', 'unets.0.attn_pool.to_latents_from_mean_pooled_seq.0.g', 'unets.0.attn_pool.to_latents_from_mean_pooled_seq.1.weight', 'unets.0.attn_pool.to_latents_from_mean_pooled_seq.1.bias', 'unets.0.attn_pool.layers.0.0.q_scale', 'unets.0.attn_pool.layers.0.0.k_scale', 'unets.0.attn_pool.layers.0.0.norm.weight', 'unets.0.attn_pool.layers.0.0.norm.bias', 'unets.0.attn_pool.la

In [31]:
# Overwrite the model's weights with the re-keyed EMA weights.
imagen.load_state_dict(new_ema_dict)

<All keys matched successfully>