In [1]:
import os
import torch
import yaml
import sys

MULTIPOLY_FOLDER = os.path.dirname(os.path.dirname(os.getcwd()))
POLYFFUSION_CKPT_PATH = os.path.join(MULTIPOLY_FOLDER, r"polyffusion_ckpts\ldm_chd8bar\sdf+pop909wm_mix16_chd8bar\01-11_102022\chkpts\weights_best.pt")
POLYFFUSION_PARAMS_PATH = os.path.join(MULTIPOLY_FOLDER, r"polyffusion_ckpts\ldm_chd8bar\sdf+pop909wm_mix16_chd8bar\01-11_102022\params.yaml")
CHORD_CKPT_PATH = os.path.join(MULTIPOLY_FOLDER, r"pretrained\chd8bar\weights.pt")

with open(POLYFFUSION_PARAMS_PATH, 'r') as f:
    params = yaml.safe_load(f)
for key,value in params.items():
    print(key,":",value)

polyffusion_checkpoint = torch.load(POLYFFUSION_CKPT_PATH)["model"]
chord_checkpoint = torch.load(CHORD_CKPT_PATH)["model"]

sys.path.append(MULTIPOLY_FOLDER)

model_name : sdf_chd8bar
batch_size : 16
max_epoch : 100
learning_rate : 5e-05
max_grad_norm : 10
fp16 : True
num_workers : 4
pin_memory : True
in_channels : 2
out_channels : 2
channels : 64
attention_levels : [2, 3]
n_res_blocks : 2
channel_multipliers : [1, 2, 4, 4]
n_heads : 4
tf_layers : 1
d_cond : 512
linear_start : 0.00085
linear_end : 0.012
n_steps : 1000
latent_scaling_factor : 0.18215
img_h : 128
img_w : 128
cond_type : chord
cond_mode : mix
use_enc : True
chd_n_step : 32
chd_input_dim : 36
chd_z_input_dim : 512
chd_hidden_dim : 512
chd_z_dim : 512


In [4]:
import inspect

from src.models.unet import UNetModel
unet_params = inspect.signature(UNetModel.__init__).parameters
unet_params_dict = {key:params[key] for key in params if key in unet_params}
unet_params_dict["n_intertrack_head"] = 4
unet_params_dict["num_intertrack_encoder_layers"] = 2
unet_params_dict["intertrack_attention_levels"] = [2,3]
unet = UNetModel(**unet_params_dict)

In [10]:
# zero out all
for param in unet.parameters():
    param.data.fill_(0.)

In [13]:
unet_from_polyffusion_state_dict = {k.removeprefix("ldm.eps_model."):v for k,v in polyffusion_checkpoint.items() if k.removeprefix("ldm.eps_model.") in unet.state_dict()}

for key,value in unet_from_polyffusion_state_dict.items():
    print(key,value)

time_embed.0.weight tensor([[-0.0717, -0.0725,  0.1073,  ...,  0.0771, -0.0501,  0.0828],
        [ 0.0240,  0.0322,  0.1113,  ..., -0.1082, -0.1121, -0.1238],
        [-0.0072,  0.0863,  0.0209,  ..., -0.1100, -0.1254, -0.0038],
        ...,
        [ 0.0061, -0.0068, -0.0696,  ..., -0.0303, -0.1490,  0.0173],
        [-0.0471, -0.0032, -0.0486,  ...,  0.0377,  0.0782, -0.0618],
        [-0.0856, -0.0943,  0.0224,  ...,  0.0376, -0.0388,  0.0787]])
time_embed.0.bias tensor([-3.3317e-02, -8.0416e-03, -7.8428e-02, -4.3651e-02,  9.3959e-02,
        -4.7979e-02, -3.2156e-02, -8.4114e-02, -2.8934e-02,  2.0685e-02,
        -6.1823e-02,  6.6695e-02,  1.0160e-01,  4.8214e-02, -7.1400e-02,
        -9.2358e-02,  2.3931e-02, -1.0669e-01,  2.1136e-02,  7.7611e-03,
        -7.4752e-02,  1.1354e-01,  6.9470e-02, -3.3611e-02,  8.9758e-02,
        -2.2095e-02,  2.8513e-02,  8.7625e-02,  4.4205e-02,  9.3050e-02,
         7.7091e-02,  8.5656e-03,  1.9089e-02, -9.6568e-02, -1.2663e-01,
         5.4565e-

In [None]:
missing_keys,_= unet.load_state_dict(unet_from_polyffusion_state_dict, strict=False)
new_state_dict = unet.state_dict()
for key in missing_keys:
    print(new_state_dict[key])
    

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0.

In [22]:
for key in unet.state_dict().keys():
    if key not in missing_keys:
        print(unet.state_dict()[key])
        print(polyffusion_checkpoint["ldm.eps_model."+key])
        break

tensor([[-0.0717, -0.0725,  0.1073,  ...,  0.0771, -0.0501,  0.0828],
        [ 0.0240,  0.0322,  0.1113,  ..., -0.1082, -0.1121, -0.1238],
        [-0.0072,  0.0863,  0.0209,  ..., -0.1100, -0.1254, -0.0038],
        ...,
        [ 0.0061, -0.0068, -0.0696,  ..., -0.0303, -0.1490,  0.0173],
        [-0.0471, -0.0032, -0.0486,  ...,  0.0377,  0.0782, -0.0618],
        [-0.0856, -0.0943,  0.0224,  ...,  0.0376, -0.0388,  0.0787]])
tensor([[-0.0717, -0.0725,  0.1073,  ...,  0.0771, -0.0501,  0.0828],
        [ 0.0240,  0.0322,  0.1113,  ..., -0.1082, -0.1121, -0.1238],
        [-0.0072,  0.0863,  0.0209,  ..., -0.1100, -0.1254, -0.0038],
        ...,
        [ 0.0061, -0.0068, -0.0696,  ..., -0.0303, -0.1490,  0.0173],
        [-0.0471, -0.0032, -0.0486,  ...,  0.0377,  0.0782, -0.0618],
        [-0.0856, -0.0943,  0.0224,  ...,  0.0376, -0.0388,  0.0787]])
