In [1]:
import os
import torch
import yaml
import sys

MULTIPOLY_FOLDER = os.path.dirname(os.path.dirname(os.getcwd()))
POLYFFUSION_CKPT_PATH = os.path.join(MULTIPOLY_FOLDER, r"polyffusion_ckpts\ldm_chd8bar\sdf+pop909wm_mix16_chd8bar\01-11_102022\chkpts\weights_best.pt")
POLYFFUSION_PARAMS_PATH = os.path.join(MULTIPOLY_FOLDER, r"polyffusion_ckpts\ldm_chd8bar\sdf+pop909wm_mix16_chd8bar\01-11_102022\params.yaml")
CHORD_CKPT_PATH = os.path.join(MULTIPOLY_FOLDER, r"pretrained\chd8bar\weights.pt")

with open(POLYFFUSION_PARAMS_PATH, 'r') as f:
    params = yaml.safe_load(f)
for key,value in params.items():
    print(key,":",value)


polyffusion_checkpoint = torch.load(POLYFFUSION_CKPT_PATH)["model"]
chord_checkpoint = torch.load(CHORD_CKPT_PATH)["model"]

sys.path.append(MULTIPOLY_FOLDER)

model_name : sdf_chd8bar
batch_size : 16
max_epoch : 100
learning_rate : 5e-05
max_grad_norm : 10
fp16 : True
num_workers : 4
pin_memory : True
in_channels : 2
out_channels : 2
channels : 64
attention_levels : [2, 3]
n_res_blocks : 2
channel_multipliers : [1, 2, 4, 4]
n_heads : 4
tf_layers : 1
d_cond : 512
linear_start : 0.00085
linear_end : 0.012
n_steps : 1000
latent_scaling_factor : 0.18215
img_h : 128
img_w : 128
cond_type : chord
cond_mode : mix
use_enc : True
chd_n_step : 32
chd_input_dim : 36
chd_z_input_dim : 512
chd_hidden_dim : 512
chd_z_dim : 512


In [2]:
# Define models according to the settings in `polyffusion_ckpts\...\params.yaml`
from polyffusion.dl_modules import ChordEncoder, ChordDecoder
from polyffusion.stable_diffusion.model.unet import UNetModel as PolyffusionUNet
from src.models.unet import UNetModel as MultipolyUNet

import inspect

chord_enc_params = inspect.signature(ChordEncoder.__init__).parameters
chord_enc_params_dict = {key.removeprefix("chd_"):params[key] for key in params if key.removeprefix("chd_") in chord_enc_params}
chord_encoder = ChordEncoder(**chord_enc_params_dict)
CHORD_ENC_PREFIX = "chord_enc."
chord_enc_state_dict = {key.removeprefix(CHORD_ENC_PREFIX):value for key,value in chord_checkpoint.items() if key.startswith(CHORD_ENC_PREFIX)}
chord_encoder.load_state_dict(chord_enc_state_dict)


<All keys matched successfully>

In [3]:

chord_dec_params = inspect.signature(ChordDecoder.__init__).parameters
chord_dec_params_dict = {key.removeprefix("chd_"):params[key] for key in params if key.removeprefix("chd_") in chord_dec_params}
chord_decoder = ChordDecoder(**chord_dec_params_dict)
CHORD_DEC_PREFIX = "chord_dec."
chord_dec_state_dict = {key.removeprefix(CHORD_DEC_PREFIX):value for key,value in chord_checkpoint.items() if key.startswith(CHORD_DEC_PREFIX)}
chord_decoder.load_state_dict(chord_dec_state_dict)



<All keys matched successfully>

In [4]:


polyffusion_unet_params = inspect.signature(PolyffusionUNet.__init__).parameters
polyffusion_unet_params_dict = {key:params[key] for key in params if key in polyffusion_unet_params}
polyffusion_unet = PolyffusionUNet(**polyffusion_unet_params_dict)
UNET_PREFIX = "ldm.eps_model."
polyffusion_unet_state_dict = {key.removeprefix(UNET_PREFIX):value for key,value in polyffusion_checkpoint.items() if key.startswith(UNET_PREFIX)}
polyffusion_unet.load_state_dict(polyffusion_unet_state_dict)

for key in polyffusion_unet.state_dict().keys():
    print(key)

time_embed.0.weight
time_embed.0.bias
time_embed.2.weight
time_embed.2.bias
input_blocks.0.0.weight
input_blocks.0.0.bias
input_blocks.1.0.in_layers.0.weight
input_blocks.1.0.in_layers.0.bias
input_blocks.1.0.in_layers.2.weight
input_blocks.1.0.in_layers.2.bias
input_blocks.1.0.emb_layers.1.weight
input_blocks.1.0.emb_layers.1.bias
input_blocks.1.0.out_layers.0.weight
input_blocks.1.0.out_layers.0.bias
input_blocks.1.0.out_layers.3.weight
input_blocks.1.0.out_layers.3.bias
input_blocks.2.0.in_layers.0.weight
input_blocks.2.0.in_layers.0.bias
input_blocks.2.0.in_layers.2.weight
input_blocks.2.0.in_layers.2.bias
input_blocks.2.0.emb_layers.1.weight
input_blocks.2.0.emb_layers.1.bias
input_blocks.2.0.out_layers.0.weight
input_blocks.2.0.out_layers.0.bias
input_blocks.2.0.out_layers.3.weight
input_blocks.2.0.out_layers.3.bias
input_blocks.3.0.op.weight
input_blocks.3.0.op.bias
input_blocks.4.0.in_layers.0.weight
input_blocks.4.0.in_layers.0.bias
input_blocks.4.0.in_layers.2.weight
input_bl

In [5]:
multipoly_unet_params = inspect.signature(MultipolyUNet.__init__).parameters
multipoly_unet_params_dict = {key:params[key] for key in params if key in multipoly_unet_params}
multipoly_unet_params_dict["n_intertrack_head"] = 4
multipoly_unet_params_dict["num_intertrack_encoder_layers"] = 2
multipoly_unet_params_dict["intertrack_attention_levels"] = [2,3]
multipoly_unet = MultipolyUNet(**multipoly_unet_params_dict)

for key in multipoly_unet.state_dict().keys():
    print(key)

time_embed.0.weight
time_embed.0.bias
time_embed.2.weight
time_embed.2.bias
input_blocks.0.0.weight
input_blocks.0.0.bias
input_blocks.1.0.in_layers.0.weight
input_blocks.1.0.in_layers.0.bias
input_blocks.1.0.in_layers.2.weight
input_blocks.1.0.in_layers.2.bias
input_blocks.1.0.emb_layers.1.weight
input_blocks.1.0.emb_layers.1.bias
input_blocks.1.0.out_layers.0.weight
input_blocks.1.0.out_layers.0.bias
input_blocks.1.0.out_layers.3.weight
input_blocks.1.0.out_layers.3.bias
input_blocks.2.0.in_layers.0.weight
input_blocks.2.0.in_layers.0.bias
input_blocks.2.0.in_layers.2.weight
input_blocks.2.0.in_layers.2.bias
input_blocks.2.0.emb_layers.1.weight
input_blocks.2.0.emb_layers.1.bias
input_blocks.2.0.out_layers.0.weight
input_blocks.2.0.out_layers.0.bias
input_blocks.2.0.out_layers.3.weight
input_blocks.2.0.out_layers.3.bias
input_blocks.3.0.op.weight
input_blocks.3.0.op.bias
input_blocks.4.0.in_layers.0.weight
input_blocks.4.0.in_layers.0.bias
input_blocks.4.0.in_layers.2.weight
input_bl

In [6]:

multipoly_unet.load_polyffusion_checkpoints(polyffusion_checkpoint)


---------------loading polyffusion weights-------------------
input_blocks.7.2.attention.layers.0.self_attn.in_proj_weight
input_blocks.7.2.attention.layers.0.self_attn.in_proj_bias
input_blocks.7.2.attention.layers.0.self_attn.out_proj.weight
input_blocks.7.2.attention.layers.0.self_attn.out_proj.bias
input_blocks.7.2.attention.layers.0.linear1.weight
input_blocks.7.2.attention.layers.0.linear1.bias
input_blocks.7.2.attention.layers.0.linear2.weight
input_blocks.7.2.attention.layers.0.linear2.bias
input_blocks.7.2.attention.layers.0.norm1.weight
input_blocks.7.2.attention.layers.0.norm1.bias
input_blocks.7.2.attention.layers.0.norm2.weight
input_blocks.7.2.attention.layers.0.norm2.bias
input_blocks.7.2.attention.layers.1.self_attn.in_proj_weight
input_blocks.7.2.attention.layers.1.self_attn.in_proj_bias
input_blocks.7.2.attention.layers.1.self_attn.out_proj.weight
input_blocks.7.2.attention.layers.1.self_attn.out_proj.bias
input_blocks.7.2.attention.layers.1.linear1.weight
input_block

In [7]:
batch_size = 2
track_num = 5
image_width = 128
image_height = 128
channels = 2

device = "cuda"

input_tensors = torch.randn(batch_size, track_num, channels, image_width, image_height).to(device)
cond = torch.randn(batch_size*track_num, 1, 512).to(device)
t = torch.randn(batch_size*track_num, ).to(device)

polyffusion_unet = polyffusion_unet.to(device)
polyffusion_unet.eval()
with torch.no_grad():
    poly_output = polyffusion_unet(input_tensors.reshape(batch_size*track_num, channels, image_width, image_height),t,  cond)
print(poly_output)


tensor([[[[-4.2153e-01,  2.2796e+00,  8.9296e-01,  ...,  1.6508e+00,
            1.0649e+00,  7.7144e-02],
          [-1.2095e+00,  2.2932e+00,  3.1827e+00,  ..., -7.6543e-01,
           -2.3157e+00, -5.4130e-02],
          [ 3.3175e-01, -2.8370e+00, -2.8006e-01,  ..., -1.3234e+00,
            1.4200e+00,  1.0681e+00],
          ...,
          [-2.2961e+00,  1.3349e+00,  1.2207e+00,  ..., -4.4104e-01,
           -2.3516e+00,  8.0565e-01],
          [ 8.2355e-02, -5.0008e-01,  3.7398e-01,  ..., -1.5104e+00,
            2.0714e+00,  1.8730e+00],
          [ 5.5025e-01,  2.6173e-02,  6.0537e-01,  ..., -1.1773e+00,
           -5.3282e-01,  8.9610e-01]],

         [[-3.1474e-01, -2.5832e+00,  9.3191e-01,  ...,  1.3512e-01,
           -1.4860e+00, -5.6623e-01],
          [-1.1148e+00,  1.5074e+00,  4.2294e-01,  ...,  2.1198e-01,
            2.6903e-01, -1.3249e+00],
          [ 6.1299e-01, -2.0538e+00, -3.5259e-03,  ...,  2.6435e+00,
            2.3889e+00, -2.4689e-01],
          ...,
     

In [8]:
multipoly_unet = multipoly_unet.to(device)
multipoly_unet.eval()
with torch.no_grad():
    multi_output = multipoly_unet(input_tensors, t, cond)
print(multi_output)

tensor([[[[[-4.2145e-01,  2.2796e+00,  8.9285e-01,  ...,  1.6507e+00,
             1.0647e+00,  7.7178e-02],
           [-1.2095e+00,  2.2932e+00,  3.1828e+00,  ..., -7.6538e-01,
            -2.3158e+00, -5.4117e-02],
           [ 3.3190e-01, -2.8370e+00, -2.8012e-01,  ..., -1.3235e+00,
             1.4201e+00,  1.0681e+00],
           ...,
           [-2.2961e+00,  1.3349e+00,  1.2207e+00,  ..., -4.4104e-01,
            -2.3515e+00,  8.0562e-01],
           [ 8.2380e-02, -5.0005e-01,  3.7395e-01,  ..., -1.5103e+00,
             2.0713e+00,  1.8730e+00],
           [ 5.5017e-01,  2.6166e-02,  6.0540e-01,  ..., -1.1774e+00,
            -5.3280e-01,  8.9613e-01]],

          [[-3.1481e-01, -2.5831e+00,  9.3184e-01,  ...,  1.3507e-01,
            -1.4860e+00, -5.6623e-01],
           [-1.1148e+00,  1.5074e+00,  4.2303e-01,  ...,  2.1185e-01,
             2.6898e-01, -1.3248e+00],
           [ 6.1301e-01, -2.0538e+00, -3.5293e-03,  ...,  2.6434e+00,
             2.3890e+00, -2.4687e-01],
 

In [9]:
poly_output = poly_output.flatten()
multi_output = multi_output.flatten()
avg_diff = (poly_output-multi_output).abs().mean()
print(avg_diff)

tensor(5.8036e-05, device='cuda:0')
