In [1]:
import os
import torch
import yaml
import sys

MULTIPOLY_FOLDER = os.path.dirname(os.path.dirname(os.getcwd()))
POLYFFUSION_CKPT_PATH = os.path.join(MULTIPOLY_FOLDER, r"polyffusion_ckpts\ldm_chd8bar\sdf+pop909wm_mix16_chd8bar\01-11_102022\chkpts\weights_best.pt")
POLYFFUSION_PARAMS_PATH = os.path.join(MULTIPOLY_FOLDER, r"polyffusion_ckpts\ldm_chd8bar\sdf+pop909wm_mix16_chd8bar\01-11_102022\params.yaml")
CHORD_CKPT_PATH = os.path.join(MULTIPOLY_FOLDER, r"pretrained\chd8bar\weights.pt")

with open(POLYFFUSION_PARAMS_PATH, 'r') as f:
    params = yaml.safe_load(f)
for key,value in params.items():
    print(key,":",value)


polyffusion_checkpoint = torch.load(POLYFFUSION_CKPT_PATH)["model"]
chord_checkpoint = torch.load(CHORD_CKPT_PATH)["model"]

sys.path.append(MULTIPOLY_FOLDER)
# Define models according to the settings in `polyffusion_ckpts\...\params.yaml`
from polyffusion.dl_modules import ChordEncoder, ChordDecoder
from polyffusion.stable_diffusion.model.unet import UNetModel as PolyffusionUNet
from src.models.unet import UNetModel as MultipolyUNet

import inspect

chord_enc_params = inspect.signature(ChordEncoder.__init__).parameters
chord_enc_params_dict = {key.removeprefix("chd_"):params[key] for key in params if key.removeprefix("chd_") in chord_enc_params}
chord_encoder = ChordEncoder(**chord_enc_params_dict)
CHORD_ENC_PREFIX = "chord_enc."
chord_enc_state_dict = {key.removeprefix(CHORD_ENC_PREFIX):value for key,value in chord_checkpoint.items() if key.startswith(CHORD_ENC_PREFIX)}
chord_encoder.load_state_dict(chord_enc_state_dict)

chord_dec_params = inspect.signature(ChordDecoder.__init__).parameters
chord_dec_params_dict = {key.removeprefix("chd_"):params[key] for key in params if key.removeprefix("chd_") in chord_dec_params}
chord_decoder = ChordDecoder(**chord_dec_params_dict)
CHORD_DEC_PREFIX = "chord_dec."
chord_dec_state_dict = {key.removeprefix(CHORD_DEC_PREFIX):value for key,value in chord_checkpoint.items() if key.startswith(CHORD_DEC_PREFIX)}
chord_decoder.load_state_dict(chord_dec_state_dict)



polyffusion_unet_params = inspect.signature(PolyffusionUNet.__init__).parameters
polyffusion_unet_params_dict = {key:params[key] for key in params if key in polyffusion_unet_params}
polyffusion_unet = PolyffusionUNet(**polyffusion_unet_params_dict)
UNET_PREFIX = "ldm.eps_model."
polyffusion_unet_state_dict = {key.removeprefix(UNET_PREFIX):value for key,value in polyffusion_checkpoint.items() if key.startswith(UNET_PREFIX)}
polyffusion_unet.load_state_dict(polyffusion_unet_state_dict)

multipoly_unet_params = inspect.signature(MultipolyUNet.__init__).parameters
multipoly_unet_params_dict = {key:params[key] for key in params if key in multipoly_unet_params}
multipoly_unet_params_dict["n_intertrack_head"] = 4
multipoly_unet_params_dict["num_intertrack_encoder_layers"] = 2
multipoly_unet_params_dict["intertrack_attention_levels"] = [3]
multipoly_unet = MultipolyUNet(**multipoly_unet_params_dict)

multipoly_unet.load_polyffusion_checkpoints(polyffusion_checkpoint)


model_name : sdf_chd8bar
batch_size : 16
max_epoch : 100
learning_rate : 5e-05
max_grad_norm : 10
fp16 : True
num_workers : 4
pin_memory : True
in_channels : 2
out_channels : 2
channels : 64
attention_levels : [2, 3]
n_res_blocks : 2
channel_multipliers : [1, 2, 4, 4]
n_heads : 4
tf_layers : 1
d_cond : 512
linear_start : 0.00085
linear_end : 0.012
n_steps : 1000
latent_scaling_factor : 0.18215
img_h : 128
img_w : 128
cond_type : chord
cond_mode : mix
use_enc : True
chd_n_step : 32
chd_input_dim : 36
chd_z_input_dim : 512
chd_hidden_dim : 512
chd_z_dim : 512
---------------loading polyffusion weights-------------------
input_blocks.10.2.attention.layers.0.self_attn.in_proj_weight
input_blocks.10.2.attention.layers.0.self_attn.in_proj_bias
input_blocks.10.2.attention.layers.0.self_attn.out_proj.weight
input_blocks.10.2.attention.layers.0.self_attn.out_proj.bias
input_blocks.10.2.attention.layers.0.linear1.weight
input_blocks.10.2.attention.layers.0.linear1.bias
input_blocks.10.2.attentio

In [2]:
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

def gather(consts: torch.Tensor, t: torch.Tensor):
    """Gather consts for $t$ and reshape to feature map shape"""
    c = consts.gather(-1, t)
    return c.reshape(-1, 1, 1, 1)


class Diffusion(nn.Module):
    def __init__(
        self,
        unet_model: PolyffusionUNet,
        n_steps: int,
        linear_start: float,
        linear_end: float,
    ):
        super().__init__()
        self.eps_model = unet_model
        self.n_steps = n_steps
        beta = (
            torch.linspace(
                linear_start**0.5, linear_end**0.5, n_steps, dtype=torch.float64
            )
            ** 2
        )
        alpha = 1.0 - beta
        alpha_bar = torch.cumprod(alpha, dim=0)
        self.alpha = nn.Parameter(alpha.to(torch.float32), requires_grad=False)
        self.beta = nn.Parameter(beta.to(torch.float32), requires_grad=False)
        self.alpha_bar = nn.Parameter(alpha_bar.to(torch.float32), requires_grad=False)
        self.sigma2 = self.beta

    @property
    def device(self):
        """
        ### Get model device
        """
        return next(iter(self.eps_model.parameters())).device

    def forward(self, x: torch.Tensor, t: torch.Tensor, context: torch.Tensor):
        return self.eps_model(x, t, context)


    def p_sample(self, xt: torch.Tensor, t: torch.Tensor):
        eps_theta = self.eps_model(xt, t)
        alpha_bar = gather(self.alpha_bar, t)
        alpha = gather(self.alpha, t)
        eps_coef = (1 - alpha) / (1 - alpha_bar) ** 0.5
        mean = 1 / (alpha**0.5) * (xt - eps_coef * eps_theta)
        var = gather(self.sigma2, t)
        eps = torch.randn(xt.shape, device=xt.device)
        return mean + (var**0.5) * eps





In [3]:
import numpy as np
from typing import List, Optional
from labml import monit

class DiffusionSampler:
    model: Diffusion

    def __init__(self, model: Diffusion):
        """
        :param model: is the model to predict noise $\epsilon_\text{cond}(x_t, c)$
        """
        super().__init__()
        # Set the model $\epsilon_\text{cond}(x_t, c)$
        self.model = model
        # Get number of steps the model was trained with $T$
        self.n_steps = model.n_steps        
        # Sampling steps $1, 2, \dots, T$
        self.time_steps = np.asarray(list(range(self.n_steps)), dtype=np.int32)

        with torch.no_grad():
            # $\bar\alpha_t$
            alpha_bar = self.model.alpha_bar
            # $\beta_t$ schedule
            beta = self.model.beta
            #  $\bar\alpha_{t-1}$
            alpha_bar_prev = torch.cat([alpha_bar.new_tensor([1.0]), alpha_bar[:-1]])

            # $\sqrt{\bar\alpha}$
            self.sqrt_alpha_bar = alpha_bar**0.5
            # $\sqrt{1 - \bar\alpha}$
            self.sqrt_1m_alpha_bar = (1.0 - alpha_bar) ** 0.5
            # $\frac{1}{\sqrt{\bar\alpha_t}}$
            self.sqrt_recip_alpha_bar = alpha_bar**-0.5
            # $\sqrt{\frac{1}{\bar\alpha_t} - 1}$
            self.sqrt_recip_m1_alpha_bar = (1 / alpha_bar - 1) ** 0.5

            # $\frac{1 - \bar\alpha_{t-1}}{1 - \bar\alpha_t} \beta_t$
            variance = beta * (1.0 - alpha_bar_prev) / (1.0 - alpha_bar)
            # Clamped log of $\tilde\beta_t$
            self.log_var = torch.log(torch.clamp(variance, min=1e-20))
            # $\frac{\sqrt{\bar\alpha_{t-1}}\beta_t}{1 - \bar\alpha_t}$
            self.mean_x0_coef = beta * (alpha_bar_prev**0.5) / (1.0 - alpha_bar)
            # $\frac{\sqrt{\alpha_t}(1 - \bar\alpha_{t-1})}{1-\bar\alpha_t}$
            self.mean_xt_coef = (
                (1.0 - alpha_bar_prev) * ((1 - beta) ** 0.5) / (1.0 - alpha_bar)
            )

    def get_eps(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        c: torch.Tensor,
        *,
        uncond_scale: float,
        uncond_cond: Optional[torch.Tensor],
    ):
        """
        ## Get $\epsilon(x_t, c)$

        :param x: is $x_t$ of shape `[batch_size, channels, height, width]`
        :param t: is $t$ of shape `[batch_size]`
        :param c: is the conditional embeddings $c$ of shape `[batch_size, emb_size]`
        :param uncond_scale: is the unconditional guidance scale $s$. This is used for
            $\epsilon_\theta(x_t, c) = s\epsilon_\text{cond}(x_t, c) + (s - 1)\epsilon_\text{cond}(x_t, c_u)$
        :param uncond_cond: is the conditional embedding for empty prompt $c_u$
        """
        # When the scale $s = 1$
        # $$\epsilon_\theta(x_t, c) = \epsilon_\text{cond}(x_t, c)$$
        if uncond_cond is None or uncond_scale == 1.0:
            return self.model(x, t, c)
        elif uncond_scale == 0.0:  # unconditional
            return self.model(x, t, uncond_cond)

        # Duplicate $x_t$ and $t$
        x_in = torch.cat([x] * 2)
        t_in = torch.cat([t] * 2)
        # Concatenated $c$ and $c_u$
        c_in = torch.cat([uncond_cond, c])
        # Get $\epsilon_\text{cond}(x_t, c)$ and $\epsilon_\text{cond}(x_t, c_u)$
        e_t_uncond, e_t_cond = self.model(x_in, t_in, c_in).chunk(2)
        # Calculate
        # $$\epsilon_\theta(x_t, c) = s\epsilon_\text{cond}(x_t, c) + (s - 1)\epsilon_\text{cond}(x_t, c_u)$$
        e_t = e_t_uncond + uncond_scale * (e_t_cond - e_t_uncond)
        return e_t


    @torch.no_grad()
    def p_sample(
        self,
        x: torch.Tensor,
        c: torch.Tensor,
        t: torch.Tensor,
        step: int,
        repeat_noise: bool = False,
        temperature: float = 1.0,
        uncond_scale: float = 1.0,
        uncond_cond: Optional[torch.Tensor] = None,
        
    ):
        """
        ### Sample $x_{t-1}$ from $p_\theta(x_{t-1} | x_t)$

        :param x: is $x_t$ of shape `[batch_size, channels, height, width]`
        :param c: is the conditional embeddings $c$ of shape `[batch_size, emb_size]`
        :param t: is $t$ of shape `[batch_size]`
        :param step: is the step $t$ as an integer
        :repeat_noise: specified whether the noise should be same for all samples in the batch
        :param temperature: is the noise temperature (random noise gets multiplied by this)
        :param uncond_scale: is the unconditional guidance scale $s$. This is used for
            $\epsilon_\theta(x_t, c) = s\epsilon_\text{cond}(x_t, c) + (s - 1)\epsilon_\text{cond}(x_t, c_u)$
        :param uncond_cond: is the conditional embedding for empty prompt $c_u$
        """

        # Get $\epsilon_\theta$
        
        e_t = self.get_eps(
                x, t, c, uncond_scale=uncond_scale, uncond_cond=uncond_cond
            )

        # Get batch size
        bs = x.shape[0]

        # $\frac{1}{\sqrt{\bar\alpha_t}}$
        sqrt_recip_alpha_bar = x.new_full(
            (bs, 1, 1, 1), self.sqrt_recip_alpha_bar[step]
        )
        # $\sqrt{\frac{1}{\bar\alpha_t} - 1}$
        sqrt_recip_m1_alpha_bar = x.new_full(
            (bs, 1, 1, 1), self.sqrt_recip_m1_alpha_bar[step]
        )

        # Calculate $x_0$ with current $\epsilon_\theta$
        #
        # $$x_0 = \frac{1}{\sqrt{\bar\alpha_t}} x_t -  \Big(\sqrt{\frac{1}{\bar\alpha_t} - 1}\Big)\epsilon_\theta$$
        x0 = sqrt_recip_alpha_bar * x - sqrt_recip_m1_alpha_bar * e_t

        # $\frac{\sqrt{\bar\alpha_{t-1}}\beta_t}{1 - \bar\alpha_t}$
        mean_x0_coef = x.new_full((bs, 1, 1, 1), self.mean_x0_coef[step])
        # $\frac{\sqrt{\alpha_t}(1 - \bar\alpha_{t-1})}{1-\bar\alpha_t}$
        mean_xt_coef = x.new_full((bs, 1, 1, 1), self.mean_xt_coef[step])

        # Calculate $\mu_t(x_t, t)$
        #
        # $$\mu_t(x_t, t) = \frac{\sqrt{\bar\alpha_{t-1}}\beta_t}{1 - \bar\alpha_t}x_0
        #    + \frac{\sqrt{\alpha_t}(1 - \bar\alpha_{t-1})}{1-\bar\alpha_t}x_t$$
        mean = mean_x0_coef * x0 + mean_xt_coef * x
        # $\log \tilde\beta_t$
        log_var = x.new_full((bs, 1, 1, 1), self.log_var[step])

        # Do not add noise when $t = 1$ (final step sampling process).
        # Note that `step` is `0` when $t = 1$)
        if step == 0:
            noise = 0
        # If same noise is used for all samples in the batch
        elif repeat_noise:
            noise = torch.randn((1, *x.shape[1:]), device=x.device)
        # Different noise for each sample
        else:
            noise = torch.randn(x.shape, device=x.device)

        # Multiply noise by the temperature
        noise = noise * temperature

        # Sample from,
        #
        # $$p_\theta(x_{t-1} | x_t) = \mathcal{N}\big(x_{t-1}; \mu_\theta(x_t, t), \tilde\beta_t \mathbf{I} \big)$$
        x_prev = mean + (0.5 * log_var).exp() * noise
        return x_prev, x0, e_t

    @torch.no_grad()
    def q_sample(
        self, x0: torch.Tensor, index: int, noise: Optional[torch.Tensor] = None
    ):
        """
        ### Sample from $q(x_t|x_0)$

        $$q(x_t|x_0) = \mathcal{N} \Big(x_t; \sqrt{\bar\alpha_t} x_0, (1-\bar\alpha_t) \mathbf{I} \Big)$$

        :param x0: is $x_0$ of shape `[batch_size, channels, height, width]`
        :param index: is the time step $t$ index
        :param noise: is the noise, $\epsilon$
        """

        # Random noise, if noise is not specified
        if noise is None:
            noise = torch.randn_like(x0, device=x0.device)

        # Sample from $\mathcal{N} \Big(x_t; \sqrt{\bar\alpha_t} x_0, (1-\bar\alpha_t) \mathbf{I} \Big)$
        return self.sqrt_alpha_bar[index] * x0 + self.sqrt_1m_alpha_bar[index] * noise

    @torch.no_grad()
    def sample(
        self,
        shape: List[int],
        cond: torch.Tensor,
        repeat_noise: bool = False,
        temperature: float = 1.0,
        x_last: Optional[torch.Tensor] = None,
        uncond_scale: float = 1.0,
        uncond_cond: Optional[torch.Tensor] = None,
        t_start: int = 0,
    ):
        """
        ### Sampling Loop

        :param shape: is the shape of the generated images in the
            form `[batch_size, channels, height, width]`
        :param cond: is the conditional embeddings $c$
        :param temperature: is the noise temperature (random noise gets multiplied by this)
        :param x_last: is $x_T$. If not provided random noise will be used.
        :param uncond_scale: is the unconditional guidance scale $s$. This is used for
            $\epsilon_\theta(x_t, c) = s\epsilon_\text{cond}(x_t, c) + (s - 1)\epsilon_\text{cond}(x_t, c_u)$
        :param uncond_cond: is the conditional embedding for empty prompt $c_u$
        :param skip_steps: is the number of time steps to skip $t'$. We start sampling from $T - t'$.
            And `x_last` is then $x_{T - t'}$.
        """

        # Get device and batch size
        bs = shape[0]

        # Get $x_T$
        x = x_last if x_last is not None else torch.randn(shape, device=cond.device)

        # Time steps to sample at $T - t', T - t' - 1, \dots, 1$
        time_steps = np.flip(self.time_steps)[t_start:]

        # Sampling loop
        from tqdm import tqdm
        for step in tqdm( time_steps):
            # Time step $t$
            ts = x.new_full((bs,), step, dtype=torch.long)

            # Sample $x_{t-1}$
            x, pred_x0, e_t = self.p_sample(
                x,
                cond,
                ts,
                step,
                repeat_noise=repeat_noise,
                temperature=temperature,
                uncond_scale=uncond_scale,
                uncond_cond=uncond_cond,
            )
            
        return x

   


In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

diffusion = Diffusion(
    unet_model=polyffusion_unet,n_steps=params["n_steps"],linear_start=params["linear_start"],linear_end=params["linear_end"]
).to(device)
sampler = DiffusionSampler(diffusion)

uncond_cond = -torch.ones([10, 1, 512]).to(device)



In [5]:
from polyffusion.data.midi_to_data import get_data_for_single_midi
from polyffusion.data.datasample import DataSample
from polyffusion.inference_sdf import get_data_preprocessed
data = get_data_for_single_midi("test_data/pop909.mid","test_data/chord.out")
data_sample = DataSample(data)
print(data)

prmat2c, pnotree, chd, prmat = get_data_preprocessed(
                    data_sample, "cond"
                )



{'notes': array([[  14,   66,    2,  121,    0],
       [  16,   47,    6,   65,    0],
       [  16,   75,    2,  121,    0],
       ...,
       [1152,   58,   11,  104,    0],
       [1153,   61,   10,   87,    0],
       [1153,   66,   10,   73,    0]]), 'start_table': array({0: 0, 16: 1, 32: 16, 48: 28, 64: 44, 80: 61, 96: 81, 112: 100, 128: 116, 144: 134, 160: 146, 176: 158, 192: 168, 208: 197, 224: 209, 240: 222, 256: 231, 272: 261, 288: 281, 304: 300, 320: 316, 336: 345, 352: 365, 368: 384, 384: 400, 400: 422, 416: 437, 432: 454, 448: 478, 464: 493, 480: 505, 496: 517, 512: 527, 528: 556, 544: 568, 560: 581, 576: 590, 592: 620, 608: 640, 624: 659, 640: 675, 656: 704, 672: 724, 688: 743, 704: 759, 720: 781, 736: 799, 752: 823, 768: 844, 784: 870, 800: 888, 816: 912, 832: 933, 848: 948, 864: 962, 880: 978, 896: 992, 912: 1014, 928: 1028, 944: 1044, 960: 1060, 976: 1082, 992: 1100, 1008: 1124, 1024: 1147, 1040: 1165, 1056: 1179, 1072: 1195, 1088: 1211, 1104: 1230, 1120: 1250, 1136:

In [6]:
print(chd.shape)
chord_encoder = chord_encoder.to(device)
z = chord_encoder(chd).mean
z = z.unsqueeze(1)
print(z.shape)


torch.Size([10, 32, 36])
torch.Size([10, 1, 512])


In [7]:

cond = z

# gen = sampler.sample([cond.shape[0],2,128,128],cond,uncond_scale=2.0,uncond_cond=uncond_cond)

In [8]:
import pretty_midi as pm

def custom_round(x):
    if x > 0.95 and x < 1.05:
        return 1
    else:
        return 0

def prmat2c_to_midi_file(
    prmat2c, fpath, is_custom_round=True
):
    print(f"prmat2c : {prmat2c.shape}")
    midi = pm.PrettyMIDI()
    piano_program = pm.instrument_name_to_program("Acoustic Grand Piano")
    origin = pm.Instrument(program=piano_program)
    t = 0
    n_step = prmat2c.shape[2]
    t_bar = int(n_step / 8)
    for bar_ind, bars in enumerate(prmat2c):
        onset = bars[0]
        sustain = bars[1]
        for step_ind, step in enumerate(onset):
            for key, on in enumerate(step):
                if is_custom_round:
                    on = int(custom_round(on))
                else:
                    on = int(round(on))
                if on > 0:
                    dur = 1
                    while step_ind + dur < n_step:
                        if not (int(round(sustain[step_ind + dur, key])) > 0):
                            break
                        dur += 1
                    note = pm.Note(
                        velocity=80,
                        pitch=key,
                        start=t + step_ind * 1 / 8,
                        end=min(t + (step_ind + dur) * 1 / 8, t + t_bar),
                    )
                    
                    origin.notes.append(note)
        t += t_bar
    midi.instruments.append(origin)
    midi.write(fpath)

In [9]:
# prmat2c = gen.cpu().numpy()
# prmat2c_to_midi_file(prmat2c, "exp/test1.mid")

In [None]:



class MultiDiffusion(nn.Module):
    def __init__(
        self,
        unet_model: MultipolyUNet,
        n_steps: int,
        linear_start: float,
        linear_end: float,
    ):
        super().__init__()
        self.eps_model = unet_model
        self.n_steps = n_steps
        beta = (
            torch.linspace(
                linear_start**0.5, linear_end**0.5, n_steps, dtype=torch.float64
            )
            ** 2
        )
        alpha = 1.0 - beta
        alpha_bar = torch.cumprod(alpha, dim=0)
        self.alpha = nn.Parameter(alpha.to(torch.float32), requires_grad=False)
        self.beta = nn.Parameter(beta.to(torch.float32), requires_grad=False)
        self.alpha_bar = nn.Parameter(alpha_bar.to(torch.float32), requires_grad=False)
        self.sigma2 = self.beta

    @property
    def device(self):
        """
        ### Get model device
        """
        return next(iter(self.eps_model.parameters())).device

    def forward(self, x: torch.Tensor, t: torch.Tensor, context: torch.Tensor):
        return self.eps_model(x, t, context)





class MultiDiffusionSampler:
    model: MultiDiffusion

    def __init__(self, model: MultiDiffusion):
        super().__init__()
        self.model = model
        self.n_steps = model.n_steps
        self.time_steps = np.asarray(list(range(self.n_steps)), dtype=np.int32)
        with torch.no_grad():
            alpha_bar = self.model.alpha_bar
            beta = self.model.beta
            alpha_bar_prev = torch.cat([alpha_bar.new_tensor([1.0]), alpha_bar[:-1]])
            self.sqrt_alpha_bar = alpha_bar**0.5
            self.sqrt_1m_alpha_bar = (1.0 - alpha_bar) ** 0.5
            self.sqrt_recip_alpha_bar = alpha_bar**-0.5
            self.sqrt_recip_m1_alpha_bar = (1 / alpha_bar - 1) ** 0.5
            variance = beta * (1.0 - alpha_bar_prev) / (1.0 - alpha_bar)
            self.log_var = torch.log(torch.clamp(variance, min=1e-20))
            self.mean_x0_coef = beta * (alpha_bar_prev**0.5) / (1.0 - alpha_bar)
            self.mean_xt_coef = (
                (1.0 - alpha_bar_prev) * ((1 - beta) ** 0.5) / (1.0 - alpha_bar)
            )

    def get_eps(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        c: torch.Tensor,
        *,
        uncond_scale: float,
        uncond_cond: Optional[torch.Tensor],
    ):
        if uncond_cond is None or uncond_scale == 1.0:
            return self.model(x, t, c)
        elif uncond_scale == 0.0:  # unconditional
            return self.model(x, t, uncond_cond)

        x_in = torch.cat([x] * 2)
        t_in = torch.cat([t] * 2)
        c_in = torch.cat([uncond_cond, c])
        e_t_uncond, e_t_cond = self.model(x_in, t_in, c_in).chunk(2)
        e_t = e_t_uncond + uncond_scale * (e_t_cond - e_t_uncond)
        return e_t


    @torch.no_grad()
    def p_sample(
        self,
        x: torch.Tensor,
        c: torch.Tensor,
        t: torch.Tensor,
        step: int,
        repeat_noise: bool = False,
        temperature: float = 1.0,
        uncond_scale: float = 1.0,
        uncond_cond: Optional[torch.Tensor] = None,
        
    ):
        e_t = self.get_eps(
                x, t, c, uncond_scale=uncond_scale, uncond_cond=uncond_cond
            )

        bs,track_num = x.shape[0],x.shape[1]

        sqrt_recip_alpha_bar = x.new_full(
            (bs,track_num, 1, 1, 1), self.sqrt_recip_alpha_bar[step]
        )
        sqrt_recip_m1_alpha_bar = x.new_full(
            (bs,track_num, 1, 1, 1), self.sqrt_recip_m1_alpha_bar[step]
        )

        x0 = sqrt_recip_alpha_bar * x - sqrt_recip_m1_alpha_bar * e_t

        mean_x0_coef = x.new_full((bs,track_num, 1, 1, 1), self.mean_x0_coef[step])
        mean_xt_coef = x.new_full((bs,track_num, 1, 1, 1), self.mean_xt_coef[step])
        mean = mean_x0_coef * x0 + mean_xt_coef * x
        log_var = x.new_full((bs,track_num, 1, 1, 1), self.log_var[step])
        if step == 0:
            noise = 0
        elif repeat_noise:
            noise = torch.randn((1, *x.shape[1:]), device=x.device)
        else:
            noise = torch.randn(x.shape, device=x.device)
        noise = noise * temperature
        x_prev = mean + (0.5 * log_var).exp() * noise
        return x_prev, x0, e_t

    @torch.no_grad()
    def q_sample(
        self, x0: torch.Tensor, index: int, noise: Optional[torch.Tensor] = None
    ):
        """
        ### Sample from $q(x_t|x_0)$

        $$q(x_t|x_0) = \mathcal{N} \Big(x_t; \sqrt{\bar\alpha_t} x_0, (1-\bar\alpha_t) \mathbf{I} \Big)$$

        :param x0: is $x_0$ of shape `[batch_size, channels, height, width]`
        :param index: is the time step $t$ index
        :param noise: is the noise, $\epsilon$
        """

        # Random noise, if noise is not specified
        if noise is None:
            noise = torch.randn_like(x0, device=x0.device)

        # Sample from $\mathcal{N} \Big(x_t; \sqrt{\bar\alpha_t} x_0, (1-\bar\alpha_t) \mathbf{I} \Big)$
        return self.sqrt_alpha_bar[index] * x0 + self.sqrt_1m_alpha_bar[index] * noise

    @torch.no_grad()
    def sample(
        self,
        shape: List[int],
        cond: torch.Tensor,
        repeat_noise: bool = False,
        temperature: float = 1.0,
        x_last: Optional[torch.Tensor] = None,
        uncond_scale: float = 1.0,
        uncond_cond: Optional[torch.Tensor] = None,
        t_start: int = 0,
    ):
        """
        ### Sampling Loop

        :param shape: is the shape of the generated images in the
            form `[batch_size, channels, height, width]`
        :param cond: is the conditional embeddings $c$
        :param temperature: is the noise temperature (random noise gets multiplied by this)
        :param x_last: is $x_T$. If not provided random noise will be used.
        :param uncond_scale: is the unconditional guidance scale $s$. This is used for
            $\epsilon_\theta(x_t, c) = s\epsilon_\text{cond}(x_t, c) + (s - 1)\epsilon_\text{cond}(x_t, c_u)$
        :param uncond_cond: is the conditional embedding for empty prompt $c_u$
        :param skip_steps: is the number of time steps to skip $t'$. We start sampling from $T - t'$.
            And `x_last` is then $x_{T - t'}$.
        """

        # Get device and batch size
        bs, track_num = shape[0],shape[1]

        # Get $x_T$
        x = x_last if x_last is not None else torch.randn(shape, device=cond.device)

        # Time steps to sample at $T - t', T - t' - 1, \dots, 1$
        time_steps = np.flip(self.time_steps)[t_start:]

        # Sampling loop
        from tqdm import tqdm
        for step in tqdm( time_steps):
            # Time step $t$
            ts = x.new_full((bs*track_num,), step, dtype=torch.long)

            # Sample $x_{t-1}$
            x, pred_x0, e_t = self.p_sample(
                x,
                cond,
                ts,
                step,
                repeat_noise=repeat_noise,
                temperature=temperature,
                uncond_scale=uncond_scale,
                uncond_cond=uncond_cond,
            )
            
        return x

 

In [11]:
multidiffusion = MultiDiffusion(
    unet_model=multipoly_unet,n_steps=params["n_steps"],linear_start=params["linear_start"],linear_end=params["linear_end"]
).to(device)
multisampler = MultiDiffusionSampler(multidiffusion)
multidiffusion.eval()
uncond_cond = -torch.ones([10*4, 1, 512]).to(device)
multi_cond = torch.cat([cond]*4, dim=1)
print(multi_cond.shape)
multi_cond = multi_cond.reshape(40,1,512).to(device)


torch.Size([10, 4, 512])


In [12]:
with torch.no_grad():
    gen = multisampler.sample([10, 4, 2, 128, 128], multi_cond, uncond_cond=uncond_cond, uncond_scale=2.0)

100%|██████████| 1000/1000 [41:18<00:00,  2.48s/it]


In [16]:
print(gen.shape)
for i in range(4):
    track_gen = gen[:,i].cpu().numpy()
    prmat2c_to_midi_file(track_gen, f"exp/track{i}.mid")
    

torch.Size([10, 4, 2, 128, 128])
prmat2c : (10, 2, 128, 128)
prmat2c : (10, 2, 128, 128)
prmat2c : (10, 2, 128, 128)
prmat2c : (10, 2, 128, 128)
