In [1]:
import os
import torch
import yaml
import sys

MULTIPOLY_FOLDER = os.path.dirname(os.path.dirname(os.getcwd()))
POLYFFUSION_CKPT_PATH = os.path.join(MULTIPOLY_FOLDER, r"polyffusion_ckpts\ldm_chd8bar\sdf+pop909wm_mix16_chd8bar\01-11_102022\chkpts\weights_best.pt")
POLYFFUSION_PARAMS_PATH = os.path.join(MULTIPOLY_FOLDER, r"polyffusion_ckpts\ldm_chd8bar\sdf+pop909wm_mix16_chd8bar\01-11_102022\params.yaml")
CHORD_CKPT_PATH = os.path.join(MULTIPOLY_FOLDER, r"pretrained\chd8bar\weights.pt")

with open(POLYFFUSION_PARAMS_PATH, 'r') as f:
    params = yaml.safe_load(f)
for key,value in params.items():
    print(key,":",value)


polyffusion_checkpoint = torch.load(POLYFFUSION_CKPT_PATH)["model"]
chord_checkpoint = torch.load(CHORD_CKPT_PATH)["model"]

sys.path.append(MULTIPOLY_FOLDER)

model_name : sdf_chd8bar
batch_size : 16
max_epoch : 100
learning_rate : 5e-05
max_grad_norm : 10
fp16 : True
num_workers : 4
pin_memory : True
in_channels : 2
out_channels : 2
channels : 64
attention_levels : [2, 3]
n_res_blocks : 2
channel_multipliers : [1, 2, 4, 4]
n_heads : 4
tf_layers : 1
d_cond : 512
linear_start : 0.00085
linear_end : 0.012
n_steps : 1000
latent_scaling_factor : 0.18215
img_h : 128
img_w : 128
cond_type : chord
cond_mode : mix
use_enc : True
chd_n_step : 32
chd_input_dim : 36
chd_z_input_dim : 512
chd_hidden_dim : 512
chd_z_dim : 512


In [2]:
# Define models according to the settings in `polyffusion_ckpts\...\params.yaml`
from polyffusion.dl_modules import ChordEncoder, ChordDecoder
from polyffusion.stable_diffusion.model.unet import UNetModel as PolyffusionUNet
from src.models.unet import UNetModel as MultipolyUNet

import inspect

chord_enc_params = inspect.signature(ChordEncoder.__init__).parameters
chord_enc_params_dict = {key.removeprefix("chd_"):params[key] for key in params if key.removeprefix("chd_") in chord_enc_params}
chord_encoder = ChordEncoder(**chord_enc_params_dict)
CHORD_ENC_PREFIX = "chord_enc."
chord_enc_state_dict = {key.removeprefix(CHORD_ENC_PREFIX):value for key,value in chord_checkpoint.items() if key.startswith(CHORD_ENC_PREFIX)}
chord_encoder.load_state_dict(chord_enc_state_dict)


<All keys matched successfully>

In [3]:

chord_dec_params = inspect.signature(ChordDecoder.__init__).parameters
chord_dec_params_dict = {key.removeprefix("chd_"):params[key] for key in params if key.removeprefix("chd_") in chord_dec_params}
chord_decoder = ChordDecoder(**chord_dec_params_dict)
CHORD_DEC_PREFIX = "chord_dec."
chord_dec_state_dict = {key.removeprefix(CHORD_DEC_PREFIX):value for key,value in chord_checkpoint.items() if key.startswith(CHORD_DEC_PREFIX)}
chord_decoder.load_state_dict(chord_dec_state_dict)



<All keys matched successfully>

In [4]:


polyffusion_unet_params = inspect.signature(PolyffusionUNet.__init__).parameters
polyffusion_unet_params_dict = {key:params[key] for key in params if key in polyffusion_unet_params}
polyffusion_unet = PolyffusionUNet(**polyffusion_unet_params_dict)
UNET_PREFIX = "ldm.eps_model."
polyffusion_unet_state_dict = {key.removeprefix(UNET_PREFIX):value for key,value in polyffusion_checkpoint.items() if key.startswith(UNET_PREFIX)}
polyffusion_unet.load_state_dict(polyffusion_unet_state_dict)

for key in polyffusion_unet.state_dict().keys():
    print(key)

time_embed.0.weight
time_embed.0.bias
time_embed.2.weight
time_embed.2.bias
input_blocks.0.0.weight
input_blocks.0.0.bias
input_blocks.1.0.in_layers.0.weight
input_blocks.1.0.in_layers.0.bias
input_blocks.1.0.in_layers.2.weight
input_blocks.1.0.in_layers.2.bias
input_blocks.1.0.emb_layers.1.weight
input_blocks.1.0.emb_layers.1.bias
input_blocks.1.0.out_layers.0.weight
input_blocks.1.0.out_layers.0.bias
input_blocks.1.0.out_layers.3.weight
input_blocks.1.0.out_layers.3.bias
input_blocks.2.0.in_layers.0.weight
input_blocks.2.0.in_layers.0.bias
input_blocks.2.0.in_layers.2.weight
input_blocks.2.0.in_layers.2.bias
input_blocks.2.0.emb_layers.1.weight
input_blocks.2.0.emb_layers.1.bias
input_blocks.2.0.out_layers.0.weight
input_blocks.2.0.out_layers.0.bias
input_blocks.2.0.out_layers.3.weight
input_blocks.2.0.out_layers.3.bias
input_blocks.3.0.op.weight
input_blocks.3.0.op.bias
input_blocks.4.0.in_layers.0.weight
input_blocks.4.0.in_layers.0.bias
input_blocks.4.0.in_layers.2.weight
input_bl

In [5]:
multipoly_unet_params = inspect.signature(MultipolyUNet.__init__).parameters
multipoly_unet_params_dict = {key:params[key] for key in params if key in multipoly_unet_params}
multipoly_unet_params_dict["n_intertrack_head"] = 4
multipoly_unet_params_dict["num_intertrack_encoder_layers"] = 2
multipoly_unet_params_dict["intertrack_attention_levels"] = [2,3]
multipoly_unet = MultipolyUNet(**multipoly_unet_params_dict)

for key in multipoly_unet.state_dict().keys():
    print(key)

time_embed.0.weight
time_embed.0.bias
time_embed.2.weight
time_embed.2.bias
input_blocks.0.0.weight
input_blocks.0.0.bias
input_blocks.1.0.in_layers.0.weight
input_blocks.1.0.in_layers.0.bias
input_blocks.1.0.in_layers.2.weight
input_blocks.1.0.in_layers.2.bias
input_blocks.1.0.emb_layers.1.weight
input_blocks.1.0.emb_layers.1.bias
input_blocks.1.0.out_layers.0.weight
input_blocks.1.0.out_layers.0.bias
input_blocks.1.0.out_layers.3.weight
input_blocks.1.0.out_layers.3.bias
input_blocks.2.0.in_layers.0.weight
input_blocks.2.0.in_layers.0.bias
input_blocks.2.0.in_layers.2.weight
input_blocks.2.0.in_layers.2.bias
input_blocks.2.0.emb_layers.1.weight
input_blocks.2.0.emb_layers.1.bias
input_blocks.2.0.out_layers.0.weight
input_blocks.2.0.out_layers.0.bias
input_blocks.2.0.out_layers.3.weight
input_blocks.2.0.out_layers.3.bias
input_blocks.3.0.op.weight
input_blocks.3.0.op.bias
input_blocks.4.0.in_layers.0.weight
input_blocks.4.0.in_layers.0.bias
input_blocks.4.0.in_layers.2.weight
input_bl

In [6]:

multipoly_unet.load_polyffusion_checkpoints(polyffusion_checkpoint)


---------------loading polyffusion weights-------------------
input_blocks.7.2.attention.layers.0.self_attn.in_proj_weight
input_blocks.7.2.attention.layers.0.self_attn.in_proj_bias
input_blocks.7.2.attention.layers.0.self_attn.out_proj.weight
input_blocks.7.2.attention.layers.0.self_attn.out_proj.bias
input_blocks.7.2.attention.layers.0.linear1.weight
input_blocks.7.2.attention.layers.0.linear1.bias
input_blocks.7.2.attention.layers.0.linear2.weight
input_blocks.7.2.attention.layers.0.linear2.bias
input_blocks.7.2.attention.layers.0.norm1.weight
input_blocks.7.2.attention.layers.0.norm1.bias
input_blocks.7.2.attention.layers.0.norm2.weight
input_blocks.7.2.attention.layers.0.norm2.bias
input_blocks.7.2.attention.layers.1.self_attn.in_proj_weight
input_blocks.7.2.attention.layers.1.self_attn.in_proj_bias
input_blocks.7.2.attention.layers.1.self_attn.out_proj.weight
input_blocks.7.2.attention.layers.1.self_attn.out_proj.bias
input_blocks.7.2.attention.layers.1.linear1.weight
input_block

In [7]:
batch_size = 2
track_num = 5
image_width = 128
image_height = 128
channels = 2

device = "cuda"

input_tensors = torch.randn(batch_size, track_num, channels, image_width, image_height).to(device)
cond = torch.randn(batch_size*track_num, 1, 512).to(device)
t = torch.randn(batch_size*track_num, ).to(device)

polyffusion_unet = polyffusion_unet.to(device)
polyffusion_unet.eval()
with torch.no_grad():
    poly_output = polyffusion_unet(input_tensors.reshape(batch_size*track_num, channels, image_width, image_height),t,  cond)
print(poly_output)


tensor([[[[-0.7392,  2.2486,  2.7797,  ..., -1.5133,  0.3401,  1.0701],
          [ 1.1331, -0.0945,  0.6112,  ...,  0.5868,  0.1905, -1.2726],
          [ 1.2748,  1.2093,  0.6495,  ...,  0.7117, -1.8169, -0.5004],
          ...,
          [-0.7669, -0.4255,  0.2116,  ..., -1.8107,  0.5774, -0.0803],
          [ 0.2443,  1.1582, -1.0554,  ..., -1.1977,  0.3561, -1.0011],
          [ 0.4833, -0.3195,  0.5699,  ...,  0.2710,  0.6038, -0.4301]],

         [[ 1.6031, -2.6465,  1.9635,  ...,  1.0061,  0.3744,  0.8301],
          [ 1.2215, -1.8282, -1.3217,  ..., -0.7328, -2.2711,  0.5652],
          [-0.8916, -0.8180,  1.2046,  ...,  3.6916, -3.4847,  1.3160],
          ...,
          [ 2.9825,  0.0461,  0.2895,  ..., -0.2570,  0.0445, -0.8112],
          [ 0.3294, -1.1261,  0.1702,  ..., -2.1505, -3.0178, -1.6419],
          [ 0.3454, -0.5245, -0.0265,  ..., -0.9442, -1.3108,  0.7070]]],


        [[[ 2.0909, -2.1719,  1.1893,  ..., -0.6459,  0.2964,  1.1713],
          [ 0.6948,  1.9504,

In [8]:
multipoly_unet = multipoly_unet.to(device)
multipoly_unet.eval()
with torch.no_grad():
    multi_output = multipoly_unet(input_tensors, t, cond)
print(multi_output)

tensor([[[[[-0.7391,  2.2486,  2.7801,  ..., -1.5134,  0.3401,  1.0701],
           [ 1.1331, -0.0945,  0.6112,  ...,  0.5868,  0.1905, -1.2726],
           [ 1.2747,  1.2093,  0.6495,  ...,  0.7117, -1.8170, -0.5004],
           ...,
           [-0.7669, -0.4255,  0.2115,  ..., -1.8109,  0.5775, -0.0803],
           [ 0.2443,  1.1580, -1.0554,  ..., -1.1976,  0.3561, -1.0011],
           [ 0.4834, -0.3196,  0.5699,  ...,  0.2710,  0.6037, -0.4301]],

          [[ 1.6031, -2.6464,  1.9635,  ...,  1.0061,  0.3743,  0.8301],
           [ 1.2215, -1.8284, -1.3217,  ..., -0.7329, -2.2711,  0.5652],
           [-0.8916, -0.8181,  1.2046,  ...,  3.6917, -3.4847,  1.3160],
           ...,
           [ 2.9826,  0.0460,  0.2895,  ..., -0.2569,  0.0445, -0.8111],
           [ 0.3295, -1.1262,  0.1702,  ..., -2.1505, -3.0178, -1.6420],
           [ 0.3454, -0.5245, -0.0265,  ..., -0.9442, -1.3108,  0.7070]]],


         [[[ 2.0909, -2.1720,  1.1893,  ..., -0.6460,  0.2965,  1.1714],
           [ 

In [12]:
poly_output = poly_output.flatten()
multi_output = multi_output.flatten()
avg_diff = (poly_output-multi_output).abs().mean()
print(avg_diff)

tensor(5.8938e-05, device='cuda:0')
