Waveform generation conditioned on mel spectrograms.
The whole pipeline for the baseline models is in this file.
Some of the DiffWave model code is also pasted here for reference.

In [None]:
import os
import glob
import random
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import soundfile as sf

# Dataset definition sampling a piece of Mel and corresponding ground truth waveform from the dataset
class Mel2WaveDataset(Dataset):
    """Paired (mel spectrogram, waveform) dataset with optional random cropping.

    For every ``<name>.wav`` found in ``data_dir`` the spectrogram is read from
    ``<name>.wav.spec.npy`` in the same directory.  The spectrogram is expected
    to be 2-D, ``(n_mels, frames)``.

    Args:
        data_dir: Directory holding the ``.wav`` / ``.wav.spec.npy`` pairs.
        sr: Nominal sample rate; stored on the instance but not used to
            resample or validate the audio.
        hop_length: Number of waveform samples per mel frame.
        crop_mel_frames: When not ``None``, every item is cropped (or
            zero-padded) to exactly this many frames and
            ``crop_mel_frames * hop_length`` samples.
        max_silent_retries: Upper bound on how many replacement items are drawn
            when the selected waveform is (almost) silent.
    """

    # Waveforms whose mean absolute amplitude is below this are treated as silent.
    SILENCE_THRESHOLD = 1e-3

    def __init__(self, data_dir, sr=22050, hop_length=256, crop_mel_frames=None,
                 max_silent_retries=100):
        self.data_dir = data_dir
        self.sr = sr
        self.hop_length = hop_length
        self.crop_mel_frames = crop_mel_frames
        self.max_silent_retries = max_silent_retries
        # glob() returns files in arbitrary, OS-dependent order; sorting makes
        # the index -> file mapping reproducible between runs and machines.
        wav_paths = sorted(glob.glob(os.path.join(data_dir, "*.wav")))
        self.basenames = [os.path.splitext(os.path.basename(p))[0] for p in wav_paths]

    def __len__(self):
        return len(self.basenames)

    def __getitem__(self, idx):
        """Return ``(mel, wav)`` for ``idx``, replacing near-silent waveforms.

        A silent result is replaced by a randomly chosen item.  The search is a
        bounded loop rather than recursion, so a mostly silent dataset cannot
        overflow the call stack; once ``max_silent_retries`` replacements have
        been tried, the last candidate is returned as-is.
        """
        mel, wav = self._load_example(idx)
        retries = 0
        while (torch.mean(torch.abs(wav)) < self.SILENCE_THRESHOLD
               and retries < self.max_silent_retries):
            idx = random.randint(0, len(self.basenames) - 1)
            mel, wav = self._load_example(idx)
            retries += 1
        return mel, wav

    def _load_example(self, idx):
        """Load one mel/waveform pair and apply the random crop if configured."""
        name = self.basenames[idx]
        mel_full = np.load(os.path.join(self.data_dir, f"{name}.wav.spec.npy"))
        mel_full = torch.from_numpy(mel_full).float()
        _, n_frames = mel_full.shape

        wav_np, _ = sf.read(os.path.join(self.data_dir, f"{name}.wav"))
        wav_full = torch.from_numpy(wav_np).float()
        if wav_full.dim() > 1:
            # Multi-channel files are read as (frames, channels); down-mix to
            # mono so cropping and zero-padding below operate on a 1-D signal.
            wav_full = wav_full.mean(dim=1)

        if self.crop_mel_frames is None or n_frames < 1:
            return mel_full, wav_full

        K = self.crop_mel_frames
        wav_len = K * self.hop_length
        # Clips shorter than K frames are taken from the start and padded.
        start_frame = random.randint(0, n_frames - K) if n_frames >= K else 0
        start_sample = start_frame * self.hop_length

        mel = self._pad_to(mel_full[:, start_frame:start_frame + K], K, dim=1)
        # Slicing past the end of the audio simply yields fewer (possibly zero)
        # samples; the remainder is filled with zeros.
        wav = self._pad_to(wav_full[start_sample:start_sample + wav_len], wav_len, dim=0)
        return mel, wav

    @staticmethod
    def _pad_to(x, length, dim):
        """Right-pad ``x`` with zeros along ``dim`` until it has ``length`` entries."""
        missing = length - x.size(dim)
        if missing <= 0:
            return x
        pad_shape = list(x.shape)
        pad_shape[dim] = missing
        return torch.cat([x, torch.zeros(pad_shape, dtype=x.dtype)], dim=dim)

# A simple MLP model
class Mel2WaveMLP(nn.Module):
    """Two-layer perceptron that maps a flattened mel crop to waveform samples.

    The input crop of shape ``(batch, n_mels, K)`` is flattened to
    ``n_mels * K`` features and projected, through one hidden ReLU layer, to
    ``K * hop_length`` output samples.
    """

    def __init__(self, n_mels=80, K=64, hop_length=256, hidden_dim=1024):
        super().__init__()
        self.n_mels, self.K, self.hop_length = n_mels, K, hop_length
        self.wav_len = K * hop_length

        in_features = n_mels * K
        layers = [
            nn.Linear(in_features, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, self.wav_len),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, mel):
        # Unpacking three dimensions also rejects inputs that are not 3-D.
        batch_size, _, _ = mel.shape
        flat = mel.reshape(batch_size, -1)
        return self.net(flat)


In [None]:
# Training pipeline for the MLP baseline.
data_dir = "train_diffwave_aug"
crop_mel_frames = 64
n_mels = 80
hop_length = 256
batch_size = 8
lr = 3e-5
total_epochs = 20
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

dataset = Mel2WaveDataset(
    data_dir=data_dir,
    sr=22050,
    hop_length=hop_length,
    crop_mel_frames=crop_mel_frames
)
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    drop_last=True
)


mlp_model = Mel2WaveMLP(n_mels=n_mels, K=crop_mel_frames, hop_length=hop_length, hidden_dim=1024).to(device)

optimizer = torch.optim.Adam(mlp_model.parameters(), lr=lr)


# range() excludes its stop value, so `+ 1` is required to actually run
# `total_epochs` epochs (and to make the "NN/total" progress line reach total).
for epoch in range(1, total_epochs + 1):
    mlp_model.train()
    run_loss = 0.0
    samples_seen = 0
    for mel, wav_gt in dataloader:
        mel = mel.to(device)
        wav_gt = wav_gt.to(device)

        wav_pred = mlp_model(mel)
        # NOTE(review): clamping the prediction before the loss yields zero
        # gradient for samples outside [-1, 1] — confirm this is intended.
        wav_pred = torch.clamp(wav_pred, -1.0, 1.0)
        wav_gt   = torch.clamp(wav_gt,   -1.0, 1.0)

        loss = torch.nn.functional.mse_loss(wav_pred, wav_gt)
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(mlp_model.parameters(), max_norm=1.0)
        optimizer.step()

        run_loss += loss.item() * mel.size(0)
        samples_seen += mel.size(0)
    # drop_last=True skips the final partial batch, so average over the samples
    # that were really processed instead of len(dataset).
    avg = run_loss / max(samples_seen, 1)
    print(f"[Warm-up {epoch:02d}/{total_epochs:02d}] MSE Loss: {avg:.6f}")


torch.save(mlp_model.state_dict(), "mel2wave_mlp_spectral_stable.pth")

In [None]:
class Mel2WaveConvNet(nn.Module):
    """Transposed-convolution baseline turning a mel crop into a waveform.

    A 3-tap convolution lifts the ``n_mels`` input channels to 512, four
    stride-4 transposed convolutions stretch the time axis by 4**4 = 256 while
    halving the channel count each time, and two final convolutions reduce to
    a single channel squashed by ``tanh``.  The up-sampling factor is fixed by
    the layer strides; ``hop_length`` is only stored and used to compute
    ``final_wav_len``.
    """

    def __init__(self, n_mels=80, K=64, hop_length=256):
        super().__init__()
        self.n_mels, self.K, self.hop_length = n_mels, K, hop_length
        self.upsample_factor = 4
        self.num_upsamples = 4
        self.final_wav_len = K * (hop_length)

        # Input projection: n_mels -> 512 channels, time length unchanged.
        self.conv_in = nn.Conv1d(n_mels, 512, kernel_size=3, padding=1)
        self.bn_in = nn.BatchNorm1d(512)

        # Up-sampling stack registered as deconv1..deconv4 / bn1..bn4.
        widths = (512, 256, 128, 64, 32)
        for stage, (c_in, c_out) in enumerate(zip(widths, widths[1:]), start=1):
            setattr(self, f"deconv{stage}",
                    nn.ConvTranspose1d(c_in, c_out, kernel_size=4, stride=4, padding=0))
            setattr(self, f"bn{stage}", nn.BatchNorm1d(c_out))

        # Channel reduction down to the mono waveform.
        self.conv_mid = nn.Conv1d(32, 16, kernel_size=3, padding=1)
        self.bn_mid = nn.BatchNorm1d(16)
        self.conv_out = nn.Conv1d(16, 1, kernel_size=7, padding=3)

        self.act_out = nn.Tanh()
        self.relu = nn.ReLU(inplace=True)

    def forward(self, mel):
        # (batch, n_mels, K) -> (batch, 512, K)
        h = self.relu(self.bn_in(self.conv_in(mel)))

        # Each stage multiplies the time length by 4: K -> 4K -> 16K -> 64K -> 256K.
        upsampling = (
            (self.deconv1, self.bn1),
            (self.deconv2, self.bn2),
            (self.deconv3, self.bn3),
            (self.deconv4, self.bn4),
        )
        for deconv, norm in upsampling:
            h = self.relu(norm(deconv(h)))

        # (batch, 32, 256K) -> (batch, 16, 256K) -> (batch, 1, 256K)
        h = self.relu(self.bn_mid(self.conv_mid(h)))
        wav = self.act_out(self.conv_out(h))

        # Drop the singleton channel axis -> (batch, 256K)
        return wav.squeeze(1)


In [None]:
# Training pipeline for the convolutional baseline: hyper-parameters and data.
n_mels = 80
hop_length = 256
crop_mel_frames = 64    # K, the number of mel frames per training crop
data_dir = "train_diffwave"
batch_size = 8
lr = 1e-4
epochs = 50
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Dataset / DataLoader
dataset = Mel2WaveDataset(
    data_dir=data_dir,
    sr=22050,
    hop_length=hop_length,
    crop_mel_frames=crop_mel_frames,
)
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    drop_last=True,
)


In [None]:
# Build the convolutional baseline and move it to the selected device.
model = Mel2WaveConvNet(n_mels=n_mels, K=crop_mel_frames, hop_length=hop_length)
model = model.to(device)

# Adam over all parameters, trained with a time-domain mean squared error.
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
criterion = nn.MSELoss()

In [None]:
# Train the convolutional baseline and report the mean per-sample MSE per epoch.
for epoch in range(1, epochs + 1):
    model.train()
    epoch_loss = 0.0
    samples_seen = 0

    for mel, wav_gt in dataloader:

        mel = mel.to(device)
        wav_gt = wav_gt.to(device)

        wav_pred = model(mel)

        loss = criterion(wav_pred, wav_gt)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item() * mel.size(0)
        samples_seen += mel.size(0)

    # The loader uses drop_last=True, so fewer than len(dataset) samples are
    # processed per epoch; divide by the number actually seen.
    avg_loss = epoch_loss / max(samples_seen, 1)
    print(f"Epoch {epoch:02d}/{epochs:02d}, Loss: {avg_loss:.6f}")

torch.save(model.state_dict(), "mel2wave_convnet.pth")

In [None]:
# Inference block for the baseline models.
import math

# Define the device and feature geometry before they are used, so this cell
# also runs in a fresh kernel where no training cell has set `device` yet.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_mels = 80
K = 64
hop_length = 256
sr = 22050

# Rebuild both baselines with the training-time geometry and load their weights.
mlp_model = Mel2WaveMLP(n_mels=n_mels, K=K, hop_length=hop_length, hidden_dim=1024).to(device)
mlp_model.load_state_dict(torch.load("mel2wave_mlp_spectral_stable.pth", map_location=device))
mlp_model.eval()

conv_model = Mel2WaveConvNet(n_mels=n_mels, K=K, hop_length=hop_length).to(device)
conv_model.load_state_dict(torch.load("mel2wave_convnet.pth", map_location=device))
conv_model.eval()

def _infer_chunked(model, mel_npy_path, output_wav_path, n_mels, K, hop_length, sr, device):
    """Run a fixed-window model over a mel file of any length and save the audio.

    The spectrogram is cut into consecutive windows of ``K`` frames.  The last
    window is zero-padded up to ``K`` frames and its prediction is trimmed back
    to the samples that correspond to real frames.

    Returns:
        The concatenated waveform as a 1-D NumPy array.

    Raises:
        ValueError: If the file is not a ``(n_mels, T)`` array or has no frames.
    """
    mel_np = np.load(mel_npy_path)  # shape (n_mels, T_mel_total)
    if mel_np.ndim != 2 or mel_np.shape[0] != n_mels:
        raise ValueError(f"mel.npy should be ({n_mels}, T), but got {mel_np.shape}")

    T_mel_total = mel_np.shape[1]
    if T_mel_total == 0:
        # Without this guard np.concatenate below fails on an empty list with
        # a message that does not mention the input file.
        raise ValueError(f"mel.npy should contain at least one frame, but got {mel_np.shape}")

    wav_segments = []
    with torch.no_grad():
        for start in range(0, T_mel_total, K):
            mel_seg = mel_np[:, start:start + K]
            valid_frames = mel_seg.shape[1]
            is_partial = valid_frames < K
            if is_partial:
                padding = np.zeros((n_mels, K - valid_frames), dtype=mel_np.dtype)
                mel_seg = np.concatenate([mel_seg, padding], axis=1)

            mel_tensor = torch.from_numpy(mel_seg).float().unsqueeze(0).to(device)  # (1, n_mels, K)
            wav_pred = model(mel_tensor).squeeze(0).cpu().numpy()

            if is_partial:
                # Keep only the samples that belong to real (unpadded) frames.
                wav_pred = wav_pred[:valid_frames * hop_length]

            wav_segments.append(wav_pred)

    wav_full = np.concatenate(wav_segments, axis=0)

    # os.path.dirname() is "" for a bare file name, and os.makedirs("") raises.
    out_dir = os.path.dirname(output_wav_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    sf.write(output_wav_path, wav_full, sr)
    return wav_full


def infer_mlp_arbitrary_length(
    mlp_model,
    mel_npy_path,
    output_wav_path,
    n_mels=80,
    K=64,
    hop_length=256,
    sr=22050,
    device=torch.device("cpu"),
):
    """Synthesise a waveform of arbitrary length with the MLP baseline.

    Args:
        mlp_model: Model called on ``(1, n_mels, K)`` mel windows.
        mel_npy_path: Path of a ``.npy`` file holding a ``(n_mels, T)`` array.
        output_wav_path: Destination audio file; its directory is created.
        n_mels: Expected number of mel bins in the file.
        K: Window length in mel frames.
        hop_length: Waveform samples per mel frame.
        sr: Sample rate written to the audio file.
        device: Device the input windows are moved to.

    Raises:
        ValueError: If the mel file has the wrong shape or no frames.
    """
    wav_full = _infer_chunked(
        mlp_model, mel_npy_path, output_wav_path, n_mels, K, hop_length, sr, device
    )
    print(f"generated {output_wav_path} with {wav_full.shape[0]} points")


def infer_conv_arbitrary_length(
    conv_model,
    mel_npy_path,
    output_wav_path,
    n_mels=80,
    K=64,
    hop_length=256,
    sr=22050,
    device=torch.device("cpu"),
):
    """Synthesise a waveform of arbitrary length with the convolutional baseline.

    Takes the same arguments as the MLP variant, with ``conv_model`` being the
    model called on ``(1, n_mels, K)`` mel windows.

    Raises:
        ValueError: If the mel file has the wrong shape or no frames.
    """
    wav_full = _infer_chunked(
        conv_model, mel_npy_path, output_wav_path, n_mels, K, hop_length, sr, device
    )
    print(f"generated{output_wav_path} {wav_full.shape[0]} points")


# Run both baselines on the same spectrogram with identical settings.
_shared_infer_kwargs = dict(
    mel_npy_path="yequ.npy",
    n_mels=n_mels,
    K=K,
    hop_length=hop_length,
    sr=sr,
    device=device,
)

infer_mlp_arbitrary_length(
    mlp_model=mlp_model,
    output_wav_path="outputs/mlp_output.wav",
    **_shared_infer_kwargs,
)

infer_conv_arbitrary_length(
    conv_model=conv_model,
    output_wav_path="outputs/conv_output.wav",
    **_shared_infer_kwargs,
)

In [None]:
#evaluation block
import os
import numpy as np
import librosa
import soundfile as sf
import math


def load_wav(path, sr=None):
    """Read an audio file as a mono float32 array, optionally resampled.

    Args:
        path: Audio file readable by ``soundfile``.
        sr: Target sample rate.  ``None`` keeps the file's own rate.

    Returns:
        A 1-D ``float32`` NumPy array.
    """
    wav, orig_sr = sf.read(path)
    wav = wav.astype(np.float32)
    if wav.ndim > 1:
        # Multi-channel files come back as (frames, channels).  Down-mix so the
        # metrics compare 1-D signals and resampling runs along the time axis.
        wav = wav.mean(axis=1)
    if sr is not None and orig_sr != sr:
        # The sample rates are keyword-only in librosa >= 0.10; the keyword
        # form is also accepted by older releases.
        wav = librosa.resample(wav, orig_sr=orig_sr, target_sr=sr)
    return wav


def pad_or_trim(a, b):
    """Cut both sequences down to the length of the shorter one.

    Despite the name no padding is performed; the longer input is truncated.
    """
    common = len(a) if len(a) <= len(b) else len(b)
    return a[:common], b[:common]


def compute_time_domain_metrics(wav_ref, wav_gen):
    """Compare two waveforms sample by sample.

    Both signals are first truncated to a common length.

    Returns:
        A dict with ``mse`` (mean squared error), ``mae`` (mean absolute
        error) and ``snr_db`` (10*log10 of reference energy over error energy,
        each stabilised by a small epsilon).
    """
    ref, gen = pad_or_trim(wav_ref, wav_gen)

    error = ref - gen
    squared_error = error ** 2

    eps = 1e-8
    signal_power = np.sum(ref ** 2) + eps
    noise_power = np.sum(squared_error) + eps

    return {
        'mse': np.mean(squared_error),
        'mae': np.mean(np.abs(error)),
        'snr_db': 10 * math.log10(signal_power / noise_power),
    }


def compute_spectral_metrics(wav_ref, wav_gen, sr, n_fft=1024, hop_length=256):
    """Compare the STFT magnitudes of two waveforms.

    ``sr`` is accepted for a uniform signature but is not used here.

    Returns:
        A dict with ``spectral_convergence`` (Frobenius norm of the magnitude
        difference relative to the reference magnitude) and
        ``log_stft_distance`` (mean over frames of the L2 norm of the
        log-magnitude difference).
    """
    ref, gen = pad_or_trim(wav_ref, wav_gen)

    def magnitude(signal):
        return np.abs(librosa.stft(signal, n_fft=n_fft, hop_length=hop_length))

    mag_ref = magnitude(ref)
    mag_gen = magnitude(gen)

    # Crop both spectrograms to their shared (frequency, time) extent.
    n_bins = min(mag_ref.shape[0], mag_gen.shape[0])
    n_frames = min(mag_ref.shape[1], mag_gen.shape[1])
    mag_ref = mag_ref[:n_bins, :n_frames]
    mag_gen = mag_gen[:n_bins, :n_frames]

    numerator = np.linalg.norm(mag_ref - mag_gen, ord='fro')
    denominator = np.linalg.norm(mag_ref, ord='fro') + 1e-8

    log_gap = np.log(mag_ref + 1e-8) - np.log(mag_gen + 1e-8)
    per_frame_distance = np.linalg.norm(log_gap, axis=0)

    return {
        'spectral_convergence': numerator / denominator,
        'log_stft_distance': np.mean(per_frame_distance)
    }


def compute_mcd(wav_ref, wav_gen, sr, n_mels=80, n_mfcc=13, hop_length=256):
    """Return a mel-cepstral-distortion style score between two waveforms.

    MFCCs of both (length-matched) signals are compared frame by frame; the
    mean Euclidean distance is scaled by ``10 / ln(10) * sqrt(2)``.

    NOTE(review): all ``n_mfcc`` coefficients, including coefficient 0, enter
    the distance — confirm this matches the MCD definition you report against.
    """
    ref, gen = pad_or_trim(wav_ref, wav_gen)

    def mfcc_of(signal):
        # Result has shape (n_mfcc, frames).
        return librosa.feature.mfcc(
            y=signal, sr=sr, n_mfcc=n_mfcc, n_mels=n_mels, hop_length=hop_length
        )

    mfcc_ref = mfcc_of(ref)
    mfcc_gen = mfcc_of(gen)

    shared_frames = min(mfcc_ref.shape[1], mfcc_gen.shape[1])
    delta = mfcc_ref[:, :shared_frames] - mfcc_gen[:, :shared_frames]
    dist_per_frame = np.linalg.norm(delta, axis=0)  # (shared_frames,)

    scale = (10.0 / math.log(10.0)) * math.sqrt(2.0)
    return scale * np.mean(dist_per_frame + 1e-8)


def evaluate_wav_pair(
    wav_ref_path,
    wav_gen_path,
    sr=22050,
    n_fft=1024,
    hop_length=256,
    n_mels=80,
    n_mfcc=13
):
    """Load a reference/generated pair and compute every metric in this cell.

    Returns:
        A flat dict holding the time-domain metrics, the spectral metrics and
        the ``mcd`` value, in that order.
    """
    wav_ref = load_wav(wav_ref_path, sr=sr)
    wav_gen = load_wav(wav_gen_path, sr=sr)

    results = dict(compute_time_domain_metrics(wav_ref, wav_gen))
    results.update(
        compute_spectral_metrics(wav_ref, wav_gen, sr, n_fft=n_fft, hop_length=hop_length)
    )
    results['mcd'] = compute_mcd(
        wav_ref, wav_gen, sr, n_mels=n_mels, n_mfcc=n_mfcc, hop_length=hop_length
    )
    return results



# Reference recording the generated audio is compared against.
wav_ref_path = "yequ.wav"

# The inference cell writes its results under "outputs/" (plural); read the
# generated file from that same directory.
wav_gen_path = "outputs/mlp_output.wav"

metrics = evaluate_wav_pair(
    wav_ref_path=wav_ref_path,
    wav_gen_path=wav_gen_path,
    sr=22050,
    n_fft=1024,
    hop_length=256,
    n_mels=80,
    n_mfcc=13
)

print("===== result =====")
for k, v in metrics.items():
    print(f"{k}: {v:.6f}")


convolution 
===== result =====
mse: 0.093086
mae: 0.242965
snr_db: -0.018113
spectral_convergence: 0.995097
log_stft_distance: 64.287231
mcd: 1572.881348

mlp
===== result =====
mse: 0.092785
mae: 0.242588
snr_db: -0.004060
spectral_convergence: 0.991616
log_stft_distance: 48.453880
mcd: 1165.654419

diffwave_pretrained
===== result =====
mse: 0.109099
mae: 0.261110
snr_db: -0.707505
spectral_convergence: 0.781583
log_stft_distance: 28.555260
mcd: 375.424744

diffwave_finetuned
===== result =====
mse: 0.151037
mae: 0.301934
snr_db: -2.120124
spectral_convergence: 0.402407
log_stft_distance: 25.203579
mcd: 217.240768

diffwave_finetuned
===== result =====
mse: 0.160350
mae: 0.310657
snr_db: -2.379991
spectral_convergence: 0.352338
log_stft_distance: 24.570955
mcd: 170.337021

In [None]:
#diffwave model body
# Copyright 2020 LMNT, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from math import sqrt


# Short module-level aliases for the torch.nn layer classes used by the model
# classes defined in this cell.
Linear = nn.Linear
ConvTranspose2d = nn.ConvTranspose2d


def Conv1d(*args, **kwargs):
  """Build an ``nn.Conv1d`` whose weight is re-initialised with Kaiming-normal.

  All arguments are forwarded unchanged to ``nn.Conv1d``.
  """
  conv = nn.Conv1d(*args, **kwargs)
  nn.init.kaiming_normal_(conv.weight)
  return conv


@torch.jit.script
def silu(x):
  """Sigmoid-weighted linear unit: ``x * sigmoid(x)``, element-wise."""
  gate = torch.sigmoid(x)
  return x * gate


class DiffusionEmbedding(nn.Module):
  '''
  Sinusoidal embedding of the diffusion step followed by two projections.

  A fixed table of shape [max_steps, 128] (64 sines and 64 cosines per step)
  is stored as a non-persistent buffer.  Integer steps index the table
  directly; floating-point steps are linearly interpolated between the two
  neighbouring rows.

  :param max_steps: number of rows in the embedding table
  '''
  def __init__(self, max_steps):
    super().__init__()
    self.register_buffer('embedding', self._build_embedding(max_steps), persistent=False)
    self.projection1 = Linear(128, 512)
    self.projection2 = Linear(512, 512)

  def forward(self, diffusion_step):
    # Integer steps are exact table look-ups; anything else is interpolated.
    if diffusion_step.dtype in [torch.int32, torch.int64]:
      x = self.embedding[diffusion_step]
    else:
      x = self._lerp_embedding(diffusion_step)
    x = self.projection1(x)
    x = silu(x)
    x = self.projection2(x)
    x = silu(x)
    return x

  def _lerp_embedding(self, t):
    '''
    Linearly interpolate table rows for fractional steps.

    :param t: floating-point step tensor of any shape
    :return: tensor of shape t.shape + [128]
    '''
    low_idx = torch.floor(t).long()
    high_idx = torch.ceil(t).long()
    low = self.embedding[low_idx]
    high = self.embedding[high_idx]
    # The looked-up rows carry a trailing feature axis that the weight lacks.
    # Add it explicitly so a batch of N steps scales each row by its own
    # fraction instead of failing (or mis-broadcasting) on [N, 128] * [N].
    weight = (t - low_idx).unsqueeze(-1)
    return low + (high - low) * weight

  def _build_embedding(self, max_steps):
    steps = torch.arange(max_steps).unsqueeze(1)  # [T,1]
    dims = torch.arange(64).unsqueeze(0)          # [1,64]
    table = steps * 10.0**(dims * 4.0 / 63.0)     # [T,64]
    table = torch.cat([torch.sin(table), torch.cos(table)], dim=1)
    return table


class SpectrogramUpsampler(nn.Module):
  '''
  Stretch a spectrogram along its last axis with two transposed convolutions.

  Each stage uses kernel [3, 32], stride [1, 16] and padding [1, 8], i.e. it
  keeps the second-to-last axis and multiplies the last one by 16, followed by
  a leaky ReLU with slope 0.4.

  :param n_mels: accepted for interface symmetry; not used by the layers
  '''
  def __init__(self, n_mels):
    super().__init__()
    self.conv1 = ConvTranspose2d(1, 1, [3, 32], stride=[1, 16], padding=[1, 8])
    self.conv2 = ConvTranspose2d(1, 1, [3, 32], stride=[1, 16], padding=[1, 8])

  def forward(self, x):
    # Insert a singleton channel axis so the 2-D transposed convs can be used.
    out = x.unsqueeze(1)
    for stage in (self.conv1, self.conv2):
      out = F.leaky_relu(stage(out), 0.4)
    return out.squeeze(1)


class ResidualBlock(nn.Module):
  '''
  Gated dilated-convolution block with residual and skip outputs.

  :param n_mels: inplanes of conv1x1 for spectrogram conditional
  :param residual_channels: audio conv
  :param dilation: audio conv dilation
  :param uncond: disable spectrogram conditional
  '''
  def __init__(self, n_mels, residual_channels, dilation, uncond=False):
    super().__init__()
    doubled = 2 * residual_channels
    self.dilated_conv = Conv1d(residual_channels, doubled, 3, padding=dilation, dilation=dilation)
    self.diffusion_projection = Linear(512, residual_channels)
    # The conditioner projection only exists for the conditional model.
    self.conditioner_projection = None if uncond else Conv1d(n_mels, doubled, 1)
    self.output_projection = Conv1d(residual_channels, doubled, 1)

  def forward(self, x, diffusion_step, conditioner=None):
    # A conditioner must be supplied exactly when the block was built with one.
    assert (conditioner is None) == (self.conditioner_projection is None)

    step_bias = self.diffusion_projection(diffusion_step).unsqueeze(-1)
    hidden = self.dilated_conv(x + step_bias)
    if self.conditioner_projection is not None:
      hidden = hidden + self.conditioner_projection(conditioner)

    # Split the channels in two halves: sigmoid gate and tanh filter.
    gate, filt = torch.chunk(hidden, 2, dim=1)
    hidden = torch.sigmoid(gate) * torch.tanh(filt)

    residual, skip = torch.chunk(self.output_projection(hidden), 2, dim=1)
    return (x + residual) / sqrt(2.0), skip


class DiffWave(nn.Module):
  '''
  DiffWave network: a stack of residual blocks whose skip outputs are summed
  and projected to a single output channel.

  :param params: hyper-parameter object providing residual_channels,
                 noise_schedule, unconditional, n_mels, dilation_cycle_length
                 and residual_layers
  '''
  def __init__(self, params):
    super().__init__()
    self.params = params
    channels = params.residual_channels

    self.input_projection = Conv1d(1, channels, 1)
    self.diffusion_embedding = DiffusionEmbedding(len(params.noise_schedule))
    # The upsampler is only built for the spectrogram-conditioned model.
    self.spectrogram_upsampler = (
        None if self.params.unconditional else SpectrogramUpsampler(params.n_mels)
    )

    blocks = []
    for i in range(params.residual_layers):
      dilation = 2**(i % params.dilation_cycle_length)
      blocks.append(ResidualBlock(params.n_mels, channels, dilation, uncond=params.unconditional))
    self.residual_layers = nn.ModuleList(blocks)

    self.skip_projection = Conv1d(channels, channels, 1)
    self.output_projection = Conv1d(channels, 1, 1)
    # Zero output weights after the default initialisation.
    nn.init.zeros_(self.output_projection.weight)

  def forward(self, audio, diffusion_step, spectrogram=None):
    # A spectrogram must be given exactly when the model is conditional.
    assert (spectrogram is None) == (self.spectrogram_upsampler is None)

    x = F.relu(self.input_projection(audio.unsqueeze(1)))

    step_embedding = self.diffusion_embedding(diffusion_step)
    if self.spectrogram_upsampler is not None:
      spectrogram = self.spectrogram_upsampler(spectrogram)

    skip_total = None
    for block in self.residual_layers:
      x, skip_connection = block(x, step_embedding, spectrogram)
      if skip_total is None:
        skip_total = skip_connection
      else:
        skip_total = skip_connection + skip_total

    # Normalise the summed skips by sqrt(number of blocks) before the head.
    x = skip_total / sqrt(len(self.residual_layers))
    x = F.relu(self.skip_projection(x))
    return self.output_projection(x)


In [None]:
#diffwave training body
# Copyright 2020 LMNT, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import numpy as np
import os
import torch
import torch.nn as nn

from torch.nn.parallel import DistributedDataParallel
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from diffwave.dataset import from_path, from_gtzan
from diffwave.model import DiffWave
from diffwave.params import AttrDict


def _nested_map(struct, map_fn):
  if isinstance(struct, tuple):
    return tuple(_nested_map(x, map_fn) for x in struct)
  if isinstance(struct, list):
    return [_nested_map(x, map_fn) for x in struct]
  if isinstance(struct, dict):
    return { k: _nested_map(v, map_fn) for k, v in struct.items() }
  return map_fn(struct)


class DiffWaveLearner:
  """Training driver bundling model, data, optimizer and checkpoint handling.

  Holds the step counter, the cumulative noise levels derived from
  ``params.noise_schedule``, an L1 loss, and the autocast / GradScaler objects
  used for optional mixed-precision training.
  """

  def __init__(self, model_dir, model, dataset, optimizer, params, *args, **kwargs):
    """Create ``model_dir`` if needed and set up the training state.

    The keyword argument ``fp16`` (default ``False``) enables autocast and
    gradient scaling.
    """
    os.makedirs(model_dir, exist_ok=True)
    self.model_dir = model_dir
    self.model = model
    self.dataset = dataset
    self.optimizer = optimizer
    self.params = params
    self.autocast = torch.cuda.amp.autocast(enabled=kwargs.get('fp16', False))
    self.scaler = torch.cuda.amp.GradScaler(enabled=kwargs.get('fp16', False))
    self.step = 0
    self.is_master = True

    # Cumulative product of (1 - beta) over the schedule, one entry per step.
    beta = np.array(self.params.noise_schedule)
    noise_level = np.cumprod(1 - beta)
    self.noise_level = torch.tensor(noise_level.astype(np.float32))
    self.loss_fn = nn.L1Loss()
    # Created lazily on the first call to _write_summary.
    self.summary_writer = None

  def state_dict(self):
    """Return a dict with step, model, optimizer, params and scaler state."""
    # Unwrap a wrapper that exposes the real network as `.module`.
    if hasattr(self.model, 'module') and isinstance(self.model.module, nn.Module):
      model_state = self.model.module.state_dict()
    else:
      model_state = self.model.state_dict()
    # Only top-level tensor values are moved to the CPU; nested containers are
    # stored as they are.
    return {
        'step': self.step,
        'model': { k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in model_state.items() },
        'optimizer': { k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in self.optimizer.state_dict().items() },
        'params': dict(self.params),
        'scaler': self.scaler.state_dict(),
    }

  def load_state_dict(self, state_dict):
    """Restore model, optimizer, scaler and step from a dict made by state_dict()."""
    if hasattr(self.model, 'module') and isinstance(self.model.module, nn.Module):
      self.model.module.load_state_dict(state_dict['model'])
    else:
      self.model.load_state_dict(state_dict['model'])
    self.optimizer.load_state_dict(state_dict['optimizer'])
    self.scaler.load_state_dict(state_dict['scaler'])
    self.step = state_dict['step']

  def save_to_checkpoint(self, filename='weights'):
    """Write ``<filename>-<step>.pt`` and make ``<filename>.pt`` refer to it."""
    save_basename = f'{filename}-{self.step}.pt'
    save_name = f'{self.model_dir}/{save_basename}'
    link_name = f'{self.model_dir}/{filename}.pt'
    torch.save(self.state_dict(), save_name)
    if os.name == 'nt':
      # On Windows a second full copy is saved instead of a symlink.
      torch.save(self.state_dict(), link_name)
    else:
      # Replace an existing symlink so it points at the newest checkpoint.
      if os.path.islink(link_name):
        os.unlink(link_name)
      os.symlink(save_basename, link_name)

  def restore_from_checkpoint(self, filename='weights'):
    """Load ``<model_dir>/<filename>.pt``; return False if the file is missing."""
    try:
      checkpoint = torch.load(f'{self.model_dir}/{filename}.pt')
      self.load_state_dict(checkpoint)
      return True
    except FileNotFoundError:
      return False

  def train(self, max_steps=None):
    """Iterate over the dataset repeatedly until ``max_steps`` is reached.

    With ``max_steps=None`` the loop never returns on its own.

    Raises:
      RuntimeError: If a training step produces a NaN loss.
    """
    device = next(self.model.parameters()).device
    while True:
      # Only the master replica shows a progress bar.
      for features in tqdm(self.dataset, desc=f'Epoch {self.step // len(self.dataset)}') if self.is_master else self.dataset:
        if max_steps is not None and self.step >= max_steps:
          return
        # Move every tensor in the (possibly nested) batch to the model device.
        features = _nested_map(features, lambda x: x.to(device) if isinstance(x, torch.Tensor) else x)
        loss = self.train_step(features)
        if torch.isnan(loss).any():
          raise RuntimeError(f'Detected NaN loss at step {self.step}.')
        if self.is_master:
          # Summaries every 50 steps; a checkpoint whenever the step count is a
          # multiple of the dataset length (including step 0).
          if self.step % 50 == 0:
            self._write_summary(self.step, features, loss)
          if self.step % len(self.dataset) == 0:
            self.save_to_checkpoint()
        self.step += 1

  def train_step(self, features):
    """Run one optimisation step on a batch and return its loss tensor.

    ``features`` must provide the keys ``'audio'`` and ``'spectrogram'``.
    """
    # Clear gradients by dropping them rather than zero-filling.
    for param in self.model.parameters():
      param.grad = None

    audio = features['audio']
    spectrogram = features['spectrogram']

    # audio is unpacked as two-dimensional: N items of T samples.
    N, T = audio.shape
    device = audio.device
    self.noise_level = self.noise_level.to(device)

    with self.autocast:
      # Draw one random diffusion step per item and mix audio with Gaussian
      # noise according to that step's cumulative noise level.
      t = torch.randint(0, len(self.params.noise_schedule), [N], device=audio.device)
      noise_scale = self.noise_level[t].unsqueeze(1)
      noise_scale_sqrt = noise_scale**0.5
      noise = torch.randn_like(audio)
      noisy_audio = noise_scale_sqrt * audio + (1.0 - noise_scale)**0.5 * noise

      # The model is trained to predict the noise that was added.
      predicted = self.model(noisy_audio, t, spectrogram)
      loss = self.loss_fn(noise, predicted.squeeze(1))

    self.scaler.scale(loss).backward()
    # Unscale before clipping so the norm is measured on the true gradients.
    self.scaler.unscale_(self.optimizer)
    # A falsy max_grad_norm falls back to 1e9 as the clipping threshold.
    self.grad_norm = nn.utils.clip_grad_norm_(self.model.parameters(), self.params.max_grad_norm or 1e9)
    self.scaler.step(self.optimizer)
    self.scaler.update()
    return loss

  def _write_summary(self, step, features, loss):
    """Log the first audio item, its spectrogram, the loss and the grad norm."""
    writer = self.summary_writer or SummaryWriter(self.model_dir, purge_step=step)
    writer.add_audio('feature/audio', features['audio'][0], step, sample_rate=self.params.sample_rate)
    if not self.params.unconditional:
      # Flip along axis 1 before logging the spectrogram as an image.
      writer.add_image('feature/spectrogram', torch.flip(features['spectrogram'][:1], [1]), step)
    writer.add_scalar('train/loss', loss, step)
    writer.add_scalar('train/grad_norm', self.grad_norm, step)
    writer.flush()
    self.summary_writer = writer


def _train_impl(replica_id, model, dataset, args, params):
  """Build an Adam-driven learner, resume from a checkpoint if any, and train.

  Only replica 0 is marked as master.
  """
  torch.backends.cudnn.benchmark = True

  optimizer = torch.optim.Adam(model.parameters(), lr=params.learning_rate)
  learner = DiffWaveLearner(
      args.model_dir, model, dataset, optimizer, params, fp16=args.fp16
  )
  learner.is_master = replica_id == 0

  # The return value is ignored: training starts fresh when no file exists.
  learner.restore_from_checkpoint()
  learner.train(max_steps=args.max_steps)


def train(args, params):
  """Single-process training entry point; the model is placed on CUDA."""
  use_gtzan = args.data_dirs[0] == 'gtzan'
  dataset = from_gtzan(params) if use_gtzan else from_path(args.data_dirs, params)
  model = DiffWave(params).cuda()
  _train_impl(0, model, dataset, args, params)


def train_distributed(replica_id, replica_count, port, args, params):
  """Training entry point for one replica of a multi-process NCCL job.

  The rendezvous address is localhost on ``port``.  The replica id is used
  both as the process rank and as the CUDA device index.
  """
  os.environ['MASTER_ADDR'] = 'localhost'
  os.environ['MASTER_PORT'] = str(port)
  torch.distributed.init_process_group('nccl', rank=replica_id, world_size=replica_count)

  use_gtzan = args.data_dirs[0] == 'gtzan'
  dataset = (
      from_gtzan(params, is_distributed=True)
      if use_gtzan
      else from_path(args.data_dirs, params, is_distributed=True)
  )

  device = torch.device('cuda', replica_id)
  torch.cuda.set_device(device)
  wrapped = DistributedDataParallel(DiffWave(params).to(device), device_ids=[replica_id])
  _train_impl(replica_id, wrapped, dataset, args, params)
