# Weights test

This notebook is used for testing whether we have loaded the weights of polyffusion successfully from the polyffusion checkpoints.

Code scripts here can be used for partially loading the weights of polyffusion modules.

# Configs


In [1]:
import os
import torch
import yaml
import sys

MULTIPOLY_FOLDER = os.path.dirname(os.path.dirname(os.getcwd()))
POLYFFUSION_CKPT_PATH = os.path.join(MULTIPOLY_FOLDER, r"polyffusion_ckpts\ldm_chd8bar\sdf+pop909wm_mix16_chd8bar\01-11_102022\chkpts\weights_best.pt")
POLYFFUSION_PARAMS_PATH = os.path.join(MULTIPOLY_FOLDER, r"polyffusion_ckpts\ldm_chd8bar\sdf+pop909wm_mix16_chd8bar\01-11_102022\params.yaml")
CHORD_CKPT_PATH = os.path.join(MULTIPOLY_FOLDER, r"pretrained\chd8bar\weights.pt")

with open(POLYFFUSION_PARAMS_PATH, 'r') as f:
    params = yaml.safe_load(f)
for key,value in params.items():
    print(key,":",value)

polyffusion_checkpoint = torch.load(POLYFFUSION_CKPT_PATH)["model"]
chord_checkpoint = torch.load(CHORD_CKPT_PATH)["model"]

sys.path.append(MULTIPOLY_FOLDER)

model_name : sdf_chd8bar
batch_size : 16
max_epoch : 100
learning_rate : 5e-05
max_grad_norm : 10
fp16 : True
num_workers : 4
pin_memory : True
in_channels : 2
out_channels : 2
channels : 64
attention_levels : [2, 3]
n_res_blocks : 2
channel_multipliers : [1, 2, 4, 4]
n_heads : 4
tf_layers : 1
d_cond : 512
linear_start : 0.00085
linear_end : 0.012
n_steps : 1000
latent_scaling_factor : 0.18215
img_h : 128
img_w : 128
cond_type : chord
cond_mode : mix
use_enc : True
chd_n_step : 32
chd_input_dim : 36
chd_z_input_dim : 512
chd_hidden_dim : 512
chd_z_dim : 512


## Vanilla

Load the weights into the vanilla polyffusion model

In [2]:
# Define models according to the settings in `polyffusion_ckpts\...\params.yaml`
from polyffusion.dl_modules import ChordEncoder, ChordDecoder
from polyffusion.stable_diffusion.model.unet import UNetModel
from polyffusion.stable_diffusion.latent_diffusion import LatentDiffusion
from polyffusion.models.model_sdf import Polyffusion_SDF

import inspect

chord_enc_params = inspect.signature(ChordEncoder.__init__).parameters
chord_enc_params_dict = {key.removeprefix("chd_"):params[key] for key in params if key.removeprefix("chd_") in chord_enc_params}
chord_encoder_poly = ChordEncoder(**chord_enc_params_dict)

chord_dec_params = inspect.signature(ChordDecoder.__init__).parameters
chord_dec_params_dict = {key.removeprefix("chd_"):params[key] for key in params if key.removeprefix("chd_") in chord_dec_params}
chord_decoder_poly = ChordDecoder(**chord_dec_params_dict)

unet_params = inspect.signature(UNetModel.__init__).parameters
unet_params_dict = {key:params[key] for key in params if key in unet_params}
unet_poly = UNetModel(**unet_params_dict)

ldm_params = inspect.signature(LatentDiffusion.__init__).parameters
ldm_params_dict = {key:params[key] for key in params if key in ldm_params}
ldm_poly = LatentDiffusion(unet_model=unet_poly, autoencoder=None, **ldm_params_dict)


polyffusion_params = inspect.signature(Polyffusion_SDF.__init__).parameters
polyffusion_params_dict = {key:params[key] for key in params if key in polyffusion_params}
polyffusion = Polyffusion_SDF(ldm=ldm_poly,chord_enc=chord_encoder_poly,chord_dec=chord_decoder_poly,**polyffusion_params_dict)

polyffusion.load_state_dict(polyffusion_checkpoint)


<All keys matched successfully>

# UNet only

Load the UNet weights only. Check if the UNet weights equal to the `eps_model` in the vanilla polyffusion model

In [3]:
unet_only = UNetModel(**unet_params_dict)
UNET_PREFIX = "ldm.eps_model."
unet_state_dict = {key.removeprefix(UNET_PREFIX):value for key,value in polyffusion_checkpoint.items() if key.startswith(UNET_PREFIX)}
unet_only.load_state_dict(unet_state_dict)

<All keys matched successfully>

In [4]:
for key in unet_only.state_dict().keys():
    assert torch.allclose(unet_only.state_dict()[key],unet_poly.state_dict()[key])
    print(key)

time_embed.0.weight
time_embed.0.bias
time_embed.2.weight
time_embed.2.bias
input_blocks.0.0.weight
input_blocks.0.0.bias
input_blocks.1.0.in_layers.0.weight
input_blocks.1.0.in_layers.0.bias
input_blocks.1.0.in_layers.2.weight
input_blocks.1.0.in_layers.2.bias
input_blocks.1.0.emb_layers.1.weight
input_blocks.1.0.emb_layers.1.bias
input_blocks.1.0.out_layers.0.weight
input_blocks.1.0.out_layers.0.bias
input_blocks.1.0.out_layers.3.weight
input_blocks.1.0.out_layers.3.bias
input_blocks.2.0.in_layers.0.weight
input_blocks.2.0.in_layers.0.bias
input_blocks.2.0.in_layers.2.weight
input_blocks.2.0.in_layers.2.bias
input_blocks.2.0.emb_layers.1.weight
input_blocks.2.0.emb_layers.1.bias
input_blocks.2.0.out_layers.0.weight
input_blocks.2.0.out_layers.0.bias
input_blocks.2.0.out_layers.3.weight
input_blocks.2.0.out_layers.3.bias
input_blocks.3.0.op.weight
input_blocks.3.0.op.bias
input_blocks.4.0.in_layers.0.weight
input_blocks.4.0.in_layers.0.bias
input_blocks.4.0.in_layers.2.weight
input_bl

# Chord VAE only

Load the Chord VAE only. Check if the Chord VAE weights equal to the pretrained Chord VAE in folder `pretrained/`

In [5]:
chord_encoder_only = ChordEncoder(**chord_enc_params_dict)
CHORD_ENC_PREFIX = "chord_enc."
chord_enc_state_dict = {key.removeprefix(CHORD_ENC_PREFIX):value for key,value in chord_checkpoint.items() if key.startswith(CHORD_ENC_PREFIX)}
chord_encoder_only.load_state_dict(chord_enc_state_dict)

<All keys matched successfully>

In [6]:
chord_decoder_only = ChordDecoder(**chord_dec_params_dict)
CHORD_DEC_PREFIX = "chord_dec."
chord_dec_state_dict = {key.removeprefix(CHORD_DEC_PREFIX):value for key,value in chord_checkpoint.items() if key.startswith(CHORD_DEC_PREFIX)}
chord_decoder_only.load_state_dict(chord_dec_state_dict)

<All keys matched successfully>

In [7]:
for key in chord_encoder_only.state_dict().keys():
    assert torch.allclose(chord_encoder_only.state_dict()[key],chord_encoder_poly.state_dict()[key])

In [8]:
for key in chord_decoder_only.state_dict().keys():
    assert torch.allclose(chord_decoder_only.state_dict()[key],chord_decoder_poly.state_dict()[key])