In [1]:
import os
import torch
import yaml
import sys

# Repository root: two directory levels above the current working directory.
MULTIPOLY_FOLDER = os.path.dirname(os.path.dirname(os.getcwd()))

# Path components are joined one by one (instead of embedding Windows-style
# backslashes in a single string) so the same locations resolve on any OS.
_POLYFFUSION_RUN_DIR = os.path.join(
    MULTIPOLY_FOLDER,
    "polyffusion_ckpts",
    "ldm_chd8bar",
    "sdf+pop909wm_mix16_chd8bar",
    "01-11_102022",
)
POLYFFUSION_CKPT_PATH = os.path.join(_POLYFFUSION_RUN_DIR, "chkpts", "weights_best.pt")
POLYFFUSION_PARAMS_PATH = os.path.join(_POLYFFUSION_RUN_DIR, "params.yaml")
CHORD_CKPT_PATH = os.path.join(MULTIPOLY_FOLDER, "pretrained", "chd8bar", "weights.pt")

# Read the training configuration saved next to the Polyffusion checkpoint
# and echo every setting so the values used below are visible in the output.
with open(POLYFFUSION_PARAMS_PATH, "r", encoding="utf-8") as f:
    params = yaml.safe_load(f)
for key, value in params.items():
    print(key, ":", value)


# Only the "model" entry of each checkpoint file is kept.
# map_location="cpu" lets checkpoints that were saved from a CUDA device be
# opened on machines without a GPU; the tensors are later copied into the
# modules by their own loading routines.
# NOTE(review): torch.load unpickles the file -- only open checkpoints that
# come from a trusted source.
polyffusion_checkpoint = torch.load(POLYFFUSION_CKPT_PATH, map_location="cpu")["model"]
chord_checkpoint = torch.load(CHORD_CKPT_PATH, map_location="cpu")["model"]

# Add MULTIPOLY_FOLDER to the import search path before importing the project
# packages below (assumes `polyffusion` and `src` live directly under that
# folder -- TODO confirm). This must run before the imports that follow.
sys.path.append(MULTIPOLY_FOLDER)
# Model classes to be instantiated according to the settings in `polyffusion_ckpts\...\params.yaml`
from polyffusion.dl_modules import ChordEncoder, ChordDecoder
from polyffusion.stable_diffusion.model.unet import UNetModel as PolyffusionUNet
from src.models.unet import UNetModel as MultipolyUNet

import inspect

def _matching_kwargs(config, accepted, prefix="chd_"):
    """Select the entries of *config* whose name, minus *prefix*, is in *accepted*.

    The returned dict is keyed by the prefix-stripped names, in the iteration
    order of *config*.
    """
    stripped = ((name.removeprefix(prefix), setting) for name, setting in config.items())
    return {arg: setting for arg, setting in stripped if arg in accepted}


def _sub_state_dict(checkpoint, prefix):
    """Return the entries of *checkpoint* stored under *prefix*, prefix removed."""
    return {
        name[len(prefix):]: tensor
        for name, tensor in checkpoint.items()
        if name.startswith(prefix)
    }


# Key prefixes under which the chord checkpoint stores each sub-module.
CHORD_ENC_PREFIX = "chord_enc."
CHORD_DEC_PREFIX = "chord_dec."

# Chord encoder: constructor arguments come from the `chd_*` settings in
# `params`, filtered by what ChordEncoder.__init__ accepts.
chord_enc_params = inspect.signature(ChordEncoder.__init__).parameters
chord_enc_params_dict = _matching_kwargs(params, chord_enc_params)
chord_encoder = ChordEncoder(**chord_enc_params_dict)
chord_enc_state_dict = _sub_state_dict(chord_checkpoint, CHORD_ENC_PREFIX)
chord_encoder.load_state_dict(chord_enc_state_dict)

# Chord decoder: same procedure with ChordDecoder's own signature and prefix.
chord_dec_params = inspect.signature(ChordDecoder.__init__).parameters
chord_dec_params_dict = _matching_kwargs(params, chord_dec_params)
chord_decoder = ChordDecoder(**chord_dec_params_dict)
chord_dec_state_dict = _sub_state_dict(chord_checkpoint, CHORD_DEC_PREFIX)
chord_decoder.load_state_dict(chord_dec_state_dict)



# Key prefix under which the Polyffusion checkpoint stores the UNet weights.
UNET_PREFIX = "ldm.eps_model."

# Polyffusion UNet: pass only the settings from `params` that its constructor
# accepts, then load the checkpoint entries found under UNET_PREFIX.
polyffusion_unet_params = inspect.signature(PolyffusionUNet.__init__).parameters
polyffusion_unet_params_dict = {
    name: setting
    for name, setting in params.items()
    if name in polyffusion_unet_params
}
polyffusion_unet = PolyffusionUNet(**polyffusion_unet_params_dict)
polyffusion_unet_state_dict = {
    name[len(UNET_PREFIX):]: tensor
    for name, tensor in polyffusion_checkpoint.items()
    if name.startswith(UNET_PREFIX)
}
polyffusion_unet.load_state_dict(polyffusion_unet_state_dict)

# Multipoly UNet: start from the settings in `params` that its constructor
# accepts, then set the inter-track attention options explicitly here.
multipoly_unet_params = inspect.signature(MultipolyUNet.__init__).parameters
multipoly_unet_params_dict = {
    name: setting
    for name, setting in params.items()
    if name in multipoly_unet_params
}
multipoly_unet_params_dict.update(
    n_intertrack_head=4,
    num_intertrack_encoder_layers=2,
    intertrack_attention_levels=[2, 3],
)
multipoly_unet = MultipolyUNet(**multipoly_unet_params_dict)

# The whole Polyffusion checkpoint is handed to the model's own loader
# (implemented in src.models.unet, not shown in this file).
multipoly_unet.load_polyffusion_checkpoints(polyffusion_checkpoint)


model_name : sdf_chd8bar
batch_size : 16
max_epoch : 100
learning_rate : 5e-05
max_grad_norm : 10
fp16 : True
num_workers : 4
pin_memory : True
in_channels : 2
out_channels : 2
channels : 64
attention_levels : [2, 3]
n_res_blocks : 2
channel_multipliers : [1, 2, 4, 4]
n_heads : 4
tf_layers : 1
d_cond : 512
linear_start : 0.00085
linear_end : 0.012
n_steps : 1000
latent_scaling_factor : 0.18215
img_h : 128
img_w : 128
cond_type : chord
cond_mode : mix
use_enc : True
chd_n_step : 32
chd_input_dim : 36
chd_z_input_dim : 512
chd_hidden_dim : 512
chd_z_dim : 512
---------------loading polyffusion weights-------------------
input_blocks.7.2.attention.layers.0.self_attn.in_proj_weight
input_blocks.7.2.attention.layers.0.self_attn.in_proj_bias
input_blocks.7.2.attention.layers.0.self_attn.out_proj.weight
input_blocks.7.2.attention.layers.0.self_attn.out_proj.bias
input_blocks.7.2.attention.layers.0.linear1.weight
input_blocks.7.2.attention.layers.0.linear1.bias
input_blocks.7.2.attention.layer