In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from math import sqrt


# Short aliases for torch.nn layer classes.
Linear, ConvTranspose2d = nn.Linear, nn.ConvTranspose2d


def Conv1d(*args, **kwargs):
    """Build an ``nn.Conv1d`` from the same arguments, with its weight re-drawn by Kaiming-normal init."""
    conv = nn.Conv1d(*args, **kwargs)
    torch.nn.init.kaiming_normal_(conv.weight)
    return conv


@torch.jit.script
def silu(x: torch.Tensor) -> torch.Tensor:
    # Sigmoid-weighted linear unit: each element scaled by its own sigmoid.
    return torch.sigmoid(x) * x


class DiffusionEmbedding(nn.Module):
    """
    Map diffusion steps to 512-d feature vectors.

    A fixed sinusoidal table holds one 128-value row per integer step (64 sines followed by
    64 cosines). Integer steps index the table directly. Float steps are linearly
    interpolated between neighbouring rows. The result then goes through two Linear + SiLU
    layers.

    Args:
        max_steps: number of table rows; valid integer steps are 0 .. max_steps - 1.
    """
    def __init__(self, max_steps):
        super().__init__()
        # persistent=False keeps the fixed table out of state_dict.
        self.register_buffer('embedding', self._build_embedding(max_steps), persistent=False)
        self.projection1 = Linear(128, 512)
        self.projection2 = Linear(512, 512)

    def forward(self, diffusion_step):
        """
        Args:
            diffusion_step: tensor of steps. int32/int64 values index the table directly;
                any other dtype is interpolated between neighbouring rows.

        Returns:
            Tensor whose last dimension is 512; leading dimensions follow diffusion_step.
        """
        if diffusion_step.dtype in [torch.int32, torch.int64]:
            x = self.embedding[diffusion_step]
        else:
            x = self._lerp_embedding(diffusion_step)
        x = self.projection1(x)
        x = silu(x)
        x = self.projection2(x)
        x = silu(x)
        return x

    def _lerp_embedding(self, t):
        """Linearly interpolate table rows for (possibly batched) fractional steps ``t``."""
        low_idx = torch.floor(t).long()
        high_idx = torch.ceil(t).long()
        low = self.embedding[low_idx]    # [..., 128]
        high = self.embedding[high_idx]  # [..., 128]
        # The fractional part has t's shape. Add a trailing axis so it broadcasts over the
        # 128 table features rather than against them when t is batched.
        frac = (t - low_idx).unsqueeze(-1)
        return low + (high - low) * frac

    def _build_embedding(self, max_steps):
        """Return a [max_steps, 128] table: sin then cos of step * 10**(4 * d / 63), d = 0..63."""
        steps = torch.arange(max_steps).unsqueeze(1)  # [T,1]
        dims = torch.arange(64).unsqueeze(0)          # [1,64]
        table = steps * 10.0**(dims * 4.0 / 63.0)     # [T,64]
        table = torch.cat([torch.sin(table), torch.cos(table)], dim=1)
        return table


class SpectrogramUpsampler(nn.Module):
    """
    Stretch a spectrogram along its last axis with two transposed 2-D convolutions.

    Each ConvTranspose2d (kernel [3, 32], stride [1, 16], padding [1, 8]) keeps the length
    of the second-to-last axis and multiplies the last axis by 16, so the output is 256x
    longer on that axis. Each convolution is followed by LeakyReLU with slope 0.4.
    """

    def __init__(self):
        super().__init__()
        upsample_kwargs = dict(stride=[1, 16], padding=[1, 8])
        self.conv1 = ConvTranspose2d(1, 1, [3, 32], **upsample_kwargs)
        self.conv2 = ConvTranspose2d(1, 1, [3, 32], **upsample_kwargs)

    def forward(self, x):
        # assumes x is [batch, n_mels, frames]; a singleton channel axis is added for the 2-D convs
        out = x.unsqueeze(1)
        for conv in (self.conv1, self.conv2):
            out = F.leaky_relu(conv(out), 0.4)
        return out.squeeze(1)


class ResidualBlock(nn.Module):
    def __init__(self, n_mels, residual_channels, dilation, uncond=False):
        '''
        One gated, dilated residual layer.

        :param n_mels: input channels of the 1x1 conditioner convolution (spectrogram bins)
        :param residual_channels: channel count of the residual stream
        :param dilation: dilation of the kernel-3 convolution; padding equals dilation, so the length is kept
        :param uncond: if truthy, no conditioner projection is built (unconditional model)
        '''
        super().__init__()
        doubled = 2 * residual_channels
        self.dilated_conv = Conv1d(residual_channels, doubled, 3, padding=dilation, dilation=dilation)
        self.diffusion_projection = Linear(512, residual_channels)
        self.conditioner_projection = None if uncond else Conv1d(n_mels, doubled, 1)
        self.output_projection = Conv1d(residual_channels, doubled, 1)

    def forward(self, x, diffusion_step, conditioner=None):
        '''
        :return: (updated residual stream scaled by 1/sqrt(2), skip contribution)
        '''
        # A conditioner must be supplied exactly when this block owns a projection for it.
        assert (conditioner is None) == (self.conditioner_projection is None)

        step_bias = self.diffusion_projection(diffusion_step).unsqueeze(-1)
        y = self.dilated_conv(x + step_bias)
        if self.conditioner_projection is not None:
            y = y + self.conditioner_projection(conditioner)

        gate, filter_ = y.chunk(2, dim=1)
        gated = torch.sigmoid(gate) * torch.tanh(filter_)
        residual, skip = self.output_projection(gated).chunk(2, dim=1)
        return (x + residual) / sqrt(2.0), skip


class DiffWave(nn.Module):
    """
    DiffWave denoising network configured from an attribute-style params object.

    Reads params.residual_channels, params.noise_schedule (its length sets the number of
    diffusion-embedding steps), params.unconditional, params.n_mels, params.residual_layers
    and params.dilation_cycle_length.
    """
    def __init__(
        self,
        params
    ):
        super().__init__()
        self.params = params
        self.input_projection = Conv1d(1, params.residual_channels, 1)
        self.diffusion_embedding = DiffusionEmbedding(len(params.noise_schedule))
        if self.params.unconditional:  # unconditional model: no spectrogram path
            self.spectrogram_upsampler = None
        else:
            # The SpectrogramUpsampler defined in this cell has a fixed single-channel
            # architecture and its __init__ takes no arguments (same call as DiffWave1D).
            self.spectrogram_upsampler = SpectrogramUpsampler()

        self.residual_layers = nn.ModuleList([
            ResidualBlock(params.n_mels, params.residual_channels, 2**(i % params.dilation_cycle_length), uncond=params.unconditional)
            for i in range(params.residual_layers)
        ])
        self.skip_projection = Conv1d(params.residual_channels, params.residual_channels, 1)
        self.output_projection = Conv1d(params.residual_channels, 1, 1)
        # Zero output weight: at initialisation the network output depends only on the bias.
        nn.init.zeros_(self.output_projection.weight)

    def forward(self, audio, diffusion_step, spectrogram=None):
        """
        Args:
            audio: waveform batch; a channel axis is inserted at dim 1 (assumes [batch, time]).
            diffusion_step: diffusion step tensor passed to DiffusionEmbedding.
            spectrogram: conditioning spectrogram; required iff the model is conditional.

        Returns:
            Output of the final 1x1 convolution, with one channel.
        """
        assert (spectrogram is None) == (self.spectrogram_upsampler is None)
        x = audio.unsqueeze(1)
        x = self.input_projection(x)
        x = F.relu(x)

        diffusion_step = self.diffusion_embedding(diffusion_step)
        if self.spectrogram_upsampler is not None:  # conditional model
            spectrogram = self.spectrogram_upsampler(spectrogram)

        skip = None
        for layer in self.residual_layers:
            x, skip_connection = layer(x, diffusion_step, spectrogram)
            skip = skip_connection if skip is None else skip_connection + skip

        # Normalise the summed skips by sqrt(number of layers).
        x = skip / sqrt(len(self.residual_layers))
        x = self.skip_projection(x)
        x = F.relu(x)
        x = self.output_projection(x)
        return x

In [None]:
from dataclasses import dataclass


class ResidualBlock(nn.Module):
    """
    Gated residual layer with an optional 1x1 spectrogram conditioner.

    Args:
        residual_channels: channel count of the residual stream.
        dilation: dilation (and padding) of the kernel-3 convolution, so the time length is kept.
        n_mels: input channels of the conditioner projection; None makes the block unconditional.
    """
    def __init__(self, residual_channels: int, dilation: int, n_mels: int = None):
        super().__init__()
        width = 2 * residual_channels
        self.dilated_conv = Conv1d(residual_channels, width, 3, padding=dilation, dilation=dilation)
        self.diffusion_projection = Linear(512, residual_channels)
        self.conditioner_projection = (
            Conv1d(n_mels, width, 1) if n_mels is not None else None
        )
        self.output_projection = Conv1d(residual_channels, width, 1)

    def forward(self, x, diffusion_step, conditioner=None):
        """Return (updated residual stream, skip contribution), each with residual_channels channels."""
        assert (conditioner is None) is (self.conditioner_projection is None)

        h = x + self.diffusion_projection(diffusion_step).unsqueeze(-1)
        h = self.dilated_conv(h)
        if self.conditioner_projection is not None:
            h = h + self.conditioner_projection(conditioner)

        gate, filt = torch.chunk(h, 2, dim=1)
        h = self.output_projection(torch.sigmoid(gate) * torch.tanh(filt))
        residual, skip = torch.chunk(h, 2, dim=1)
        return (x + residual) / sqrt(2.0), skip


class DiffWave1D(nn.Module):
    """
    DiffWave denoiser built purely from explicit constructor arguments.

    Args:
        residual_channels (int): Channel count of the residual stream.
        n_residual_layers (int): Number of ResidualBlock layers.
        dilation_cycle_length (int): Layer i uses dilation 2 ** (i % dilation_cycle_length).
        diffusion_embedding (DiffusionEmbedding): Module mapping diffusion steps to the vectors
            fed to every residual block.
        n_mels (int, optional): Mel bins of the conditioning spectrogram; None builds an
            unconditional model without a spectrogram upsampler.
    """
    def __init__(
        self,
        residual_channels: int,
        n_residual_layers: int,
        dilation_cycle_length: int,
        diffusion_embedding: DiffusionEmbedding,
        n_mels: int = None,
    ):
        super().__init__()
        self.residual_channels = residual_channels
        self.n_residual_layers = n_residual_layers
        self.dilation_cycle_length = dilation_cycle_length
        self.diffusion_embedding = diffusion_embedding
        self.n_mels = n_mels
        self.spectrogram_upsampler = SpectrogramUpsampler() if self.n_mels is not None else None

        self.input_projection = nn.Conv1d(1, self.residual_channels, 1)
        self.output_projection = nn.Conv1d(self.residual_channels, 1, 1)
        nn.init.zeros_(self.output_projection.weight) # zero init for output projection
        # All hyper-parameters come from the constructor arguments, not from global state.
        # Assumes the ResidualBlock signature (residual_channels, dilation, n_mels=None)
        # defined in this cell.
        self.residual_layers = nn.ModuleList([
            ResidualBlock(
                self.residual_channels,
                2 ** (i % self.dilation_cycle_length),
                n_mels=self.n_mels,
            )
            for i in range(self.n_residual_layers)
        ])
        self.skip_projection = Conv1d(self.residual_channels, self.residual_channels, 1)

    def forward(self, audio, diffusion_step, spectrogram=None):
        """
        Args:
            audio: waveform batch; a channel axis is inserted at dim 1 (assumes [batch, time]).
            diffusion_step: step(s) passed to the diffusion embedding.
            spectrogram: conditioning spectrogram; required iff n_mels was given.

        Returns:
            Output of the final 1x1 convolution, with one channel.
        """
        assert (spectrogram is None and self.spectrogram_upsampler is None) or \
            (spectrogram is not None and self.spectrogram_upsampler is not None)
        x = audio.unsqueeze(1)
        x = self.input_projection(x)
        x = F.relu(x)

        diffusion_step = self.diffusion_embedding(diffusion_step)
        if self.spectrogram_upsampler is not None:  # conditional model
            spectrogram = self.spectrogram_upsampler(spectrogram)

        skip = None
        for layer in self.residual_layers:
            x, skip_connection = layer(x, diffusion_step, spectrogram)
            skip = skip_connection if skip is None else skip_connection + skip

        x = skip / sqrt(len(self.residual_layers))
        x = self.skip_projection(x)
        x = F.relu(x)
        x = self.output_projection(x)
        return x

In [2]:

class AttrDict(dict):
    """A dict whose entries can also be read and written as attributes (``d.key``)."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Point the instance namespace at the dict itself, so attribute access and item
        # access share one storage.
        self.__dict__ = self

    def override(self, attrs):
        """
        Merge ``attrs`` into this dict and return ``self``.

        ``attrs`` may be a dict (merged in), a list/tuple/set of such values (applied in
        iteration order, recursively) or None (ignored). Anything else raises
        NotImplementedError.
        """
        if attrs is None:
            return self
        if isinstance(attrs, dict):
            self.__dict__.update(**attrs)
        elif isinstance(attrs, (list, tuple, set)):
            for item in attrs:
                self.override(item)
        else:
            raise NotImplementedError
        return self

In [3]:
# Configuration read through attribute access (e.g. params.n_mels).
params = AttrDict({
    # Training params
    'batch_size': 16,
    'learning_rate': 2e-4,
    'max_grad_norm': None,

    # Data params
    'sample_rate': 22050,
    'n_mels': 80,
    'n_fft': 1024,
    'hop_samples': 256,
    'crop_mel_frames': 62,  # Probably an error in paper.

    # Model params
    'residual_layers': 30,
    'residual_channels': 64,
    'dilation_cycle_length': 10,
    'unconditional': False,
    'noise_schedule': np.linspace(1e-4, 0.05, 50).tolist(),
    'inference_noise_schedule': [0.0001, 0.001, 0.01, 0.05, 0.2, 0.5],

    # Unconditional sample length in samples: 5 seconds at 22050 Hz
    'audio_len': 22050 * 5,
})

In [4]:
# Show the resolved configuration (AttrDict subclasses dict, so it displays like a plain dict).
params

{'batch_size': 16,
 'learning_rate': 0.0002,
 'max_grad_norm': None,
 'sample_rate': 22050,
 'n_mels': 80,
 'n_fft': 1024,
 'hop_samples': 256,
 'crop_mel_frames': 62,
 'residual_layers': 30,
 'residual_channels': 64,
 'dilation_cycle_length': 10,
 'unconditional': False,
 'noise_schedule': [0.0001,
  0.0011183673469387756,
  0.002136734693877551,
  0.0031551020408163264,
  0.004173469387755102,
  0.005191836734693878,
  0.006210204081632653,
  0.007228571428571429,
  0.008246938775510203,
  0.009265306122448979,
  0.010283673469387754,
  0.01130204081632653,
  0.012320408163265305,
  0.013338775510204081,
  0.014357142857142857,
  0.015375510204081632,
  0.016393877551020408,
  0.017412244897959183,
  0.01843061224489796,
  0.019448979591836734,
  0.02046734693877551,
  0.021485714285714285,
  0.02250408163265306,
  0.023522448979591836,
  0.02454081632653061,
  0.025559183673469387,
  0.026577551020408163,
  0.027595918367346938,
  0.028614285714285714,
  0.02963265306122449,
  0.030

In [56]:
class DiffusionEmbedding(nn.Module):
    """
    PyTorch Module for performing Diffusion Embedding based on specified parameters.

    Implements an architecture that utilizes a learned embedding for diffusion steps.

    Args:
        max_steps (int): The maximum number of diffusion steps.
        embedding_dim (int): The dimension of the diffusion embedding (default is 64).
        diffusion_dim (int): The dimension of the diffusion process (default is 512).

    Attributes:
        max_steps (int): The maximum number of diffusion steps.
        embedding_dim (int): The dimension of the diffusion embedding.
        diffusion_dim (int): The dimension of the diffusion process.
        embedding (Tensor): The precomputed embedding for diffusion steps.
        input_projection (Linear): Input projection layer.
        output_projection (Linear): Output projection layer.

    Methods:
        _build_embedding(): Builds the diffusion embedding table based on the max_steps and embedding_dim.
        _lerp_embedding(t: float): Performs linear interpolation for the diffusion embedding at a specific time step.
        forward(diffusion_step: int | float): Forward pass through the diffusion embedding model.

    References:
        Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., Kaiser, Ł., & Polosukhin, I. (2017).
        "Attention is All you Need." arXiv (Cornell University), 30, 5998–6008.
        [https://arxiv.org/pdf/1706.03762v5](https://arxiv.org/pdf/1706.03762v5)
    """
    def __init__(
        self,
        max_steps: int,
        embedding_dim: int = 64,
        diffusion_dim: int = 512,
    ):
        super().__init__()
        self.max_steps = max_steps
        self.embedding_dim = embedding_dim
        self.diffusion_dim = diffusion_dim
        self.register_buffer('embedding', self._build_embedding(), persistent=False)
        self.input_projection = Linear(self.embedding_dim*2, self.diffusion_dim)
        self.output_projection = Linear(self.diffusion_dim, self.diffusion_dim)

    def _build_embedding(self):
        """
        Builds the diffusion embedding based on the maximum steps and embedding dimension.
        Generates the embedding table using sinusoidal functions.
        """
        steps = torch.arange(self.max_steps).unsqueeze(1)  # [T,1]
        dims = torch.arange(self.embedding_dim).unsqueeze(0)          # [1,D_e]
        table = steps * 10.0**(dims * 4.0 / (self.embedding_dim - 1))     # [T,D_e]
        table = torch.cat([torch.sin(table), torch.cos(table)], dim=1)
        return table

    def _lerp_embedding(self, t: float):
        """
        Performs linear interpolation for the diffusion embedding at a specific time step (t).
        Utilizes the precomputed embedding table to interpolate values for non-integer time steps.
        """
        low_idx = torch.floor(t).long()
        high_idx = torch.ceil(t).long()
        low = self.embedding[low_idx]
        high = self.embedding[high_idx]
        return low + (high - low) * (t - low_idx)

    def forward(self, diffusion_step: int | float):
        """
        Forward pass through the diffusion embedding model.

        Args:
            diffusion_step (int | float): The diffusion step for which to compute the embedding.

        Returns:
            Tensor: The output tensor resulting from the diffusion embedding process.
        """
        if isinstance(diffusion_step, (int, torch.int32, torch.int64)):
            x = self.embedding[diffusion_step]
        else:
            x = self._lerp_embedding(diffusion_step)
        x = self.input_projection(x)
        x = nn.functional.silu(x)
        x = self.output_projection(x)
        x = nn.functional.silu(x)
        return x


In [63]:
embedding = DiffusionEmbedding(max_steps=100)
embedding(10).shape  # a single integer step yields one diffusion_dim (512) vector

torch.Size([512])

In [177]:
class SpectrogramUpsampler(nn.Module):
    """
    PyTorch Module that upsamples spectrograms along their last (frame) axis with transposed convolutions.

    Each layer is a ConvTranspose2d followed by LeakyReLU. With the default kernel (3, 32),
    stride (1, 16) and padding (1, 8), a layer keeps the mel axis length and multiplies the
    frame axis by 16. The default two layers therefore stretch the frames by 256
    (e.g. 80 frames -> 20480).

    Args:
        n_channels (int): Number of input and output channels (default is 1).
        kernel_size (tuple[int, int]): Size of the convolutional kernel (default is (3, 32)).
        stride (tuple[int, int]): Stride value for the convolutional operation (default is (1, 16)).
        padding (tuple[int, int]): Padding applied to the input tensor (default is (1, 8)).
        negative_slope (float): Slope value for the LeakyReLU activation (default is 0.4).
        n_layers (int): Number of upsampling layers (default is 2).

    Attributes:
        upsampler (Sequential): n_layers blocks, each Sequential(ConvTranspose2d, LeakyReLU).

    Methods:
        forward(x): Performs a forward pass through the upsampler module.
    """
    def __init__(
        self,
        n_channels: int = 1,
        kernel_size: tuple[int, int] = (3, 32),
        stride: tuple[int, int] = (1, 16),
        padding: tuple[int, int] = (1, 8),
        negative_slope: float = 0.4,
        n_layers: int = 2,
    ):
        super().__init__()
        # Nested Sequentials keep the state_dict keys as upsampler.<layer>.0.weight/bias.
        self.upsampler = nn.Sequential(
            *(
                nn.Sequential(
                    ConvTranspose2d(n_channels, n_channels, kernel_size, stride, padding),
                    nn.LeakyReLU(negative_slope),
                ) for _ in range(n_layers)
            )
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Performs an upsampling operation on the input tensor using transpose convolutional layers.

        Args:
            x (Tensor): Spectrogram of shape n_batch x n_channels x n_mels x n_frames.

        Returns:
            Tensor: n_batch x n_channels x n_mels x n_upsampled_frames
            (n_frames * 256 with the default settings).
        """
        return self.upsampler(x)

In [178]:
upsampler = SpectrogramUpsampler(n_channels=2)
spectrogram = torch.randn(10, 2, 16, 80)  # [batch, channels, mel bins, frames]
upsampled = upsampler(spectrogram)
upsampled.shape  # only the last (frame) axis grows: 80 -> 80 * 256

torch.Size([10, 2, 16, 20480])

In [346]:
class ResidualBlock(nn.Module):
    """
    Gated, dilated residual block of a DiffWave model.

    Args:
        n_residual_channels (int): Channel count of the residual stream.
        dilation (int): Dilation of the 'same'-padded convolution, so the time length is kept.
        kernel (int, optional): Kernel size of the dilated convolution (default is 3).
        diffusion_dim (int, optional): Width of the incoming diffusion embedding (default is 512).
        n_mels (int, optional): Mel bins of the conditioner; None builds an unconditional block (default is None).
        n_conditional_channels (int, optional): Height of the Conv2d conditioner kernel, i.e. the
            number of spectrogram channels it collapses (default is 1).

    Attributes:
        dilated_conv (Conv1d): Dilated convolution producing 2 * n_residual_channels channels.
        diffusion_projection (Linear): Maps the diffusion embedding to a per-channel bias.
        conditioner_projection (Conv2d or None): Conditioner projection (None when unconditional).
        output_projection (Conv1d): 1x1 convolution producing the residual and skip halves.

    Methods:
        forward(x, diffusion_step, conditioner=None): Runs the block.
    """
    def __init__(
        self,
        n_residual_channels: int,
        dilation: int,
        kernel: int = 3,
        diffusion_dim: int = 512,
        n_mels: int = None,
        n_conditional_channels: int = 1,
    ):
        super().__init__()
        gated_channels = 2 * n_residual_channels
        self.dilated_conv = nn.Conv1d(
            n_residual_channels, gated_channels, kernel, padding='same', dilation=dilation
        )
        self.diffusion_projection = Linear(diffusion_dim, n_residual_channels)
        self.conditioner_projection = (
            None if n_mels is None
            else nn.Conv2d(n_mels, gated_channels, (n_conditional_channels, 1))
        )
        self.output_projection = nn.Conv1d(n_residual_channels, gated_channels, 1)

    def forward(
        self,
        x: torch.Tensor,
        diffusion_step: torch.Tensor,
        conditioner: torch.Tensor = None
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Run the block.

        Args:
            x (torch.Tensor): Residual stream with n_residual_channels channels.
            diffusion_step (torch.Tensor): Diffusion embedding (last dimension diffusion_dim).
            conditioner (torch.Tensor, optional): Conditioning tensor; required iff the block
                was built with n_mels (default is None).

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: (updated residual stream scaled by 1/sqrt(2), skip).

        Raises:
            ValueError: If exactly one of conditioner / conditioner_projection is None.
        """
        has_conditioner = conditioner is not None
        has_projection = self.conditioner_projection is not None
        if has_conditioner != has_projection:
            raise ValueError('Conditioner and projection should be both None or not None')

        step_bias = self.diffusion_projection(diffusion_step).unsqueeze(-1)
        y = self.dilated_conv(x + step_bias)
        if has_projection:
            # assumes conditioner is [B, n_mels, n_conditional_channels, T]; the
            # (n_conditional_channels, 1) kernel then leaves axis 2 with size 1, which is dropped.
            y = y + self.conditioner_projection(conditioner).squeeze(2)

        gate, filter_ = y.chunk(2, dim=1)
        activated = torch.sigmoid(gate) * torch.tanh(filter_)
        residual, skip = self.output_projection(activated).chunk(2, dim=1)
        return (x + residual) / sqrt(2.0), skip

In [249]:
X = torch.randn(10, 20, 100)  # [batch, residual channels, time]
embedding = DiffusionEmbedding(100)
res_layer = ResidualBlock(n_residual_channels=20, dilation=1)
x, skip = res_layer(X, diffusion_step=embedding(10))
x.shape, skip.shape

(torch.Size([10, 20, 100]), torch.Size([10, 20, 100]))

In [250]:
X = torch.randn(10, 20, 101)  # odd length: 'same' padding must still preserve it
x, skip = res_layer(X, diffusion_step=embedding(10))
x.shape, skip.shape

(torch.Size([10, 20, 101]), torch.Size([10, 20, 101]))

In [348]:
upsampler = SpectrogramUpsampler(n_channels=1)
spectrogram = torch.randn(10, 1, 16, 100)  # [batch, channels, mel bins, frames]
upsampled = upsampler(spectrogram)
X = torch.randn(10, 20, 25600)
# Move mel bins into the channel axis, as ResidualBlock's Conv2d conditioner expects:
# [B, 1, n_mels, T] -> [B, n_mels, 1, T]
conditioner = upsampled.permute(0, 2, 1, 3)
print(conditioner.shape)

embedding = DiffusionEmbedding(100)
x, skip = ResidualBlock(20, 1, n_mels=16, n_conditional_channels=1)(X, embedding(10), conditioner)
x.shape, skip.shape

torch.Size([10, 16, 1, 25600])


(torch.Size([10, 20, 25600]), torch.Size([10, 20, 25600]))

In [342]:
X = torch.randn(10, 20, 100)
embedding = DiffusionEmbedding(100)
# Already-upsampled conditioner laid out as [batch, n_mels, n_conditional_channels, time].
spectrogram = torch.randn(10, 16, 1, 100)

x, skip = ResidualBlock(
    n_residual_channels=20, dilation=1, n_mels=16, n_conditional_channels=1
)(X, embedding(10), spectrogram)
x.shape, skip.shape

torch.Size([10, 40, 100]) torch.Size([10, 40, 100])


(torch.Size([10, 20, 100]), torch.Size([10, 20, 100]))

In [297]:
from dataclasses import dataclass

@dataclass
class ConditionerParams:
    """
    Data class defining parameters for the conditioner model.

    Args:
        n_mels (int): Number of mel-frequency bands.
        upsampler (SpectrogramUpsampler): Factory (by default the SpectrogramUpsampler class)
            that __post_init__ calls with the settings below. Afterwards this field holds the
            built upsampler module.
        n_channels (int): Number of spectrogram channels, forwarded to the upsampler (default is 1).
        kernel (tuple[int, int]): Size of the convolutional kernel (default is (3, 32)).
        stride (tuple[int, int]): Stride value for the convolutional operation (default is (1, 16)).
        padding (tuple[int, int]): Padding applied to the input tensor (default is (1, 8)).
        negative_slope (float): Slope value for the LeakyReLU activation (default is 0.4).
        n_layers (int): Number of upsampling layers (default is 2).

    Attributes:
        Same names as the arguments; ``upsampler`` is the constructed module instance.

    Methods:
        __post_init__(): Builds the upsampler from the fields above.
    """
    n_mels: int
    upsampler: SpectrogramUpsampler = SpectrogramUpsampler
    n_channels: int = 1
    kernel: tuple[int, int] = (3, 32)
    stride: tuple[int, int] = (1, 16)
    padding: tuple[int, int] = (1, 8)
    negative_slope: float = 0.4
    n_layers: int = 2

    def __post_init__(self):
        # n_channels must reach the upsampler. DiffWave also uses it as the height of each
        # ResidualBlock's Conv2d conditioner kernel, so a mismatch breaks multi-channel input.
        self.upsampler = self.upsampler(
            n_channels=self.n_channels,
            kernel_size=self.kernel,
            stride=self.stride,
            padding=self.padding,
            negative_slope=self.negative_slope,
            n_layers=self.n_layers,
        )

@dataclass
class ResidualParams:
    """
    Data class defining parameters for the residual blocks.

    Args:
        n_residual_layers (int): Number of residual layers (must be >= 1).
        n_residual_channels (int): Number of channels for the residual blocks (must be >= 1).
        dilation_cycle_length (int): Length of the dilation cycle (must be >= 1).
        kernel (int): Size of the convolutional kernel (default is 3, must be >= 1).
        conditioner (ConditionerParams | None): Conditioner parameters; None means an
            unconditional model (default is None).

    Attributes:
        n_residual_layers (int): Number of residual layers.
        n_residual_channels (int): Number of channels for the residual blocks.
        dilation_cycle_length (int): Length of the dilation cycle.
        kernel (int): Size of the convolutional kernel.
        conditioner (ConditionerParams | None): Conditioner parameters.

    Raises:
        ValueError: If any of the integer sizes above is below 1.
    """
    n_residual_layers: int
    n_residual_channels: int
    dilation_cycle_length: int
    kernel: int = 3
    # String (forward-reference) annotation: says "optional" without evaluating the name here.
    conditioner: 'ConditionerParams | None' = None

    def __post_init__(self):
        # Fail at configuration time rather than deep inside model construction. DiffWave
        # assigns dilations with `i % dilation_cycle_length` and averages skips over the layers.
        for name in ('n_residual_layers', 'n_residual_channels', 'dilation_cycle_length', 'kernel'):
            value = getattr(self, name)
            if value < 1:
                raise ValueError(f'{name} must be >= 1, got {value}')


@dataclass
class DiffusionParams:
    """
    Settings for the diffusion-step embedding; the embedding module is built on construction.

    Args:
        max_steps (int): Number of integer diffusion steps covered by the embedding table.
        embedding_dim (int): Second argument passed to the embedding factory (default is 64).
        diffusion_dim (int): Third argument passed to the embedding factory (default is 512).
        embedding (DiffusionEmbedding): Factory called as
            embedding(max_steps, embedding_dim, diffusion_dim) in __post_init__
            (default is the DiffusionEmbedding class).

    Attributes:
        Same names as the arguments; after construction ``embedding`` holds the built module.
    """
    max_steps: int
    embedding_dim: int = 64
    diffusion_dim: int = 512
    embedding: DiffusionEmbedding = DiffusionEmbedding

    def __post_init__(self):
        factory = self.embedding
        sizes = (self.max_steps, self.embedding_dim, self.diffusion_dim)
        self.embedding = factory(*sizes)

In [349]:

class DiffWave(nn.Module):
    """
    PyTorch Module for a DiffWave model implementing diffusion probabilistic models.

    Args:
        input_channels (int): Number of waveform channels; the output has the same count.
        diffusion_params (DiffusionParams): Parameters for the diffusion process; its
            already-built ``embedding`` module is used directly.
        residual_params (ResidualParams): Parameters for the residual blocks; when
            ``conditioner`` is set, its already-built ``upsampler`` module is used directly.

    Attributes:
        input_channels (int): Number of input channels for the model.
        diffusion_params (DiffusionParams): Parameters for the diffusion process.
        residual_params (ResidualParams): Parameters for the residual blocks.
        diffusion_embedding: The diffusion embedding to be used in the model.
        spectrogram_upsampler: Upsampler for the input spectrogram (None if unconditional).
        input_projection (Conv1d): Input projection layer.
        output_projection (Conv1d): Output projection layer (zero-initialised weight).
        residual_layers (ModuleList): List of residual blocks in the model.
        skip_projection (Conv1d): Projection for skip connection.

    Note:
        The embedding and upsampler modules are taken from the params objects by reference,
        so two models built from the same params instances share those weights.

    Methods:
        forward(input_, diffusion_step, spectrogram=None): Performs a forward pass through the DiffWave model.
    """
    def __init__(
        self,
        input_channels: int,
        diffusion_params: DiffusionParams,
        residual_params: ResidualParams
    ):
        super().__init__()
        self.input_channels = input_channels
        self.diffusion_params = diffusion_params
        self.residual_params = residual_params

        self.diffusion_embedding = self.diffusion_params.embedding
        self.spectrogram_upsampler = None if self.residual_params.conditioner is None\
            else self.residual_params.conditioner.upsampler

        self.input_projection = nn.Conv1d(
            self.input_channels,
            self.residual_params.n_residual_channels,
            1
        )
        self.output_projection = nn.Conv1d(
            self.residual_params.n_residual_channels,
            self.input_channels,
            1
        )
        nn.init.zeros_(self.output_projection.weight) # zero init for output projection

        if self.residual_params.conditioner is not None:
            n_mels = self.residual_params.conditioner.n_mels
            n_channels = self.residual_params.conditioner.n_channels
        else:
            n_mels = None
            n_channels = 1

        self.residual_layers = nn.ModuleList([
            ResidualBlock(
                self.residual_params.n_residual_channels,
                2**(i % self.residual_params.dilation_cycle_length),
                self.residual_params.kernel,
                self.diffusion_params.diffusion_dim,
                n_mels, n_channels
            )
            for i in range(self.residual_params.n_residual_layers)
        ])

        self.skip_projection = Conv1d(
            self.residual_params.n_residual_channels,
            self.residual_params.n_residual_channels,
            1
        )

    def forward(
        self,
        input_: torch.Tensor,
        diffusion_step: int | float,
        spectrogram: torch.Tensor = None
    ):
        """
        Performs a forward pass through the DiffWave model.

        Args:
            input_ (torch.Tensor): Waveform batch, either [batch, time] (a channel axis is
                added) or [batch, input_channels, time] (used as is).
            diffusion_step (int | float): Diffusion step(s) passed to the diffusion embedding.
            spectrogram (torch.Tensor, optional): Spectrogram, expected as
                [batch, n_channels, n_mels, frames]; required iff the model is conditional
                (default is None).

        Returns:
            torch.Tensor: [batch, input_channels, time] output of the final projection.

        Raises:
            ValueError: If a spectrogram is missing for a conditional model, or is given to an
                unconditional one.
        """
        if (spectrogram is None) != (self.spectrogram_upsampler is None):
            if (spectrogram is None):
                raise ValueError('Conditioner is required for conditional model')
            else:
                raise ValueError('Conditioner is provided, but the model is unconditional')

        # A 2-D [batch, time] waveform gets a singleton channel axis. A 3-D
        # [batch, input_channels, time] input is already in Conv1d layout, which lets
        # input_channels > 1 work.
        x = input_.unsqueeze(1) if input_.dim() == 2 else input_
        x = self.input_projection(x)
        x = F.relu(x)

        diffusion_step = self.diffusion_embedding(diffusion_step)
        if self.spectrogram_upsampler is not None:  # conditional model
            spectrogram = self.spectrogram_upsampler(spectrogram)
            # [B, C, n_mels, T] -> [B, n_mels, C, T]: mel bins become the Conv2d channel axis
            # of each ResidualBlock's conditioner projection.
            spectrogram = torch.permute(spectrogram, (0, 2, 1, 3))

        skip = None
        for layer in self.residual_layers:
            x, skip_connection = layer(x, diffusion_step, spectrogram)
            skip = skip_connection if skip is None else skip_connection + skip

        x = skip / sqrt(len(self.residual_layers))
        x = self.skip_projection(x)
        x = F.relu(x)
        x = self.output_projection(x)
        return x

In [310]:
# Unconditional model: no ConditionerParams, so no spectrogram upsampler is built.
residual_params = ResidualParams(n_residual_layers=10, n_residual_channels=64, dilation_cycle_length=10)
diffusion_params = DiffusionParams(max_steps=100)

diffwave = DiffWave(input_channels=1, diffusion_params=diffusion_params, residual_params=residual_params)

In [311]:
# Random input batch for the unconditional model: [batch=10, time=100].
x = torch.randn(10, 100)

In [315]:
diffwave(x, diffusion_step=3)  # unconditional forward pass, no spectrogram

tensor([[[0.0780, 0.0913, 0.0657, 0.0949, 0.0802, 0.0968, 0.0743, 0.0636,
          0.0756, 0.0757, 0.0717, 0.0552, 0.0718, 0.0663, 0.0656, 0.0658,
          0.0582, 0.0669, 0.0754, 0.0788, 0.0751, 0.0847, 0.0687, 0.0923,
          0.0809, 0.0889, 0.0803, 0.0700, 0.0859, 0.0811, 0.0905, 0.0793,
          0.0812, 0.0905, 0.0919, 0.0794, 0.1245, 0.0820, 0.0775, 0.1109,
          0.0716, 0.1059, 0.0487, 0.0697, 0.0848, 0.0799, 0.0781, 0.0915,
          0.0811, 0.0986, 0.1290, 0.0772, 0.1087, 0.0916, 0.0741, 0.1024,
          0.1104, 0.0714, 0.0790, 0.0946, 0.1059, 0.0745, 0.0863, 0.0847,
          0.0713, 0.0595, 0.0999, 0.0965, 0.0465, 0.0986, 0.0991, 0.0727,
          0.0872, 0.1201, 0.0880, 0.1083, 0.0810, 0.0809, 0.1102, 0.0857,
          0.0738, 0.0729, 0.1044, 0.0624, 0.0710, 0.0834, 0.0667, 0.0712,
          0.0843, 0.0818, 0.0607, 0.0847, 0.0973, 0.0640, 0.0793, 0.1169,
          0.1027, 0.0708, 0.0921, 0.0875]],

        [[0.0497, 0.0673, 0.0689, 0.0638, 0.1037, 0.0857, 0.0851, 0

In [350]:
# Conditional model: a 16-mel conditioner builds a spectrogram upsampler.
residual_params = ResidualParams(
    n_residual_layers=10,
    n_residual_channels=64,
    dilation_cycle_length=10,
    conditioner=ConditionerParams(n_mels=16),
)
diffusion_params = DiffusionParams(max_steps=100)
diffwave = DiffWave(input_channels=1, diffusion_params=diffusion_params, residual_params=residual_params)

In [351]:
x = torch.randn(10, 2560)  # [batch, time]
# [batch, channels, n_mels=16, frames]: 10 frames * 256 (two x16 upsampling layers) = 2560
# samples, matching the length of x.
spec = torch.randn(10, 1, 16, 10)

diffwave(x, 3, spec)

tensor([[[0.0772, 0.0605, 0.0637,  ..., 0.0398, 0.0566, 0.0611]],

        [[0.0732, 0.0595, 0.0742,  ..., 0.0498, 0.0558, 0.0656]],

        [[0.0710, 0.0831, 0.0719,  ..., 0.0495, 0.0514, 0.0697]],

        ...,

        [[0.0839, 0.0621, 0.0638,  ..., 0.0611, 0.0614, 0.0684]],

        [[0.0718, 0.0688, 0.0551,  ..., 0.0688, 0.0648, 0.0644]],

        [[0.0718, 0.0729, 0.0724,  ..., 0.0516, 0.0503, 0.0763]]],
       grad_fn=<ConvolutionBackward0>)