In [1]:
import math
import os
from os import path

import librosa
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import yaml
from effortless_config import Config
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from core import harmonic_synth, amp_to_impulse_response, fft_convolve
from core import mlp, gru, scale_function, remove_above_nyquist, upsample, get_scheduler, multiscale_fft, safe_log, mean_std_loudness

In [2]:
from models import DDSP_signal_only, DDSP_with_features

# Use the first CUDA GPU when one is visible; otherwise run on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)


# Command-line style run configuration; this cell only needs the path of
# the YAML config file.
class args(Config):
    CONFIG = "config.yaml"


args.parse_args("")

# Parse the YAML config into a plain dict. The stream gets its own name so it
# is not confused with the parsed `config` dict.
with open(args.CONFIG, "r") as config_stream:
    config = yaml.safe_load(config_stream)

# Build the feature-conditioned DDSP model from the "model" section of the
# config and move its parameters onto the selected device.
ddsp_model = DDSP_with_features(**config["model"]).to(device)

Using device: cpu


In [3]:
# from datasets.dataset_signal import Dataset

# dataset = Dataset(config)
# batch_size = config["hyperparams"]["batch_size"]
# dataloader = torch.utils.data.DataLoader(dataset,
#                                         batch_size,
#                                         shuffle=False,
#                                         drop_last=False,
#                                         )

# print("Size of dataset:", len(dataset), "\nSize of sig batch:", next(iter(dataloader)).size())

In [4]:
from datasets.dataset_all import Dataset, get_files
from effortless_config import Config
import yaml
import torch

# get_files()


# Only the config path is needed to locate the preprocessed dataset.
class args(Config):
    CONFIG = "config.yaml"


args.parse_args("")
with open(args.CONFIG, "r") as config_file:
    config = yaml.safe_load(config_file)

out_dir = config["preprocess"]["out_dir"]

dataset = Dataset(out_dir)
batch_size = config["hyperparams"]["batch_size"]
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size,
    shuffle=True,
    drop_last=True,
)

# With drop_last=True a dataset smaller than one batch yields no batches at
# all. Fail with a clear message instead of a bare StopIteration below.
if len(dataloader) == 0:
    raise ValueError(
        f"Dataset in {out_dir!r} has {len(dataset)} items, fewer than "
        f"batch_size={batch_size}; with drop_last=True the loader is empty."
    )

# Draw ONE batch so that all reported shapes describe the same samples.
# Each `next(iter(dataloader))` re-creates the shuffled iterator and loads a
# fresh batch.
sample_batch = next(iter(dataloader))
print(
    "Size of dataset:", len(dataset),
    "\nSize of sig batch:", sample_batch["signals"].size(),
    "\nSize of pitch batch:", sample_batch["pitches"].size(),
    "\nSize of loudness batch:", sample_batch["loudness"].size(),
)

Size of dataset: 81 
Size of sig batch: torch.Size([16, 64000]) 
Size of sig batch: torch.Size([16, 400]) 
Size of sig batch: torch.Size([16, 400])


In [5]:
# Run configuration: output location, step budget and learning-rate decay.
class args(Config):
    CONFIG = "config.yaml"
    NAME = "debug"
    ROOT = "runs"
    STEPS = 500000
    START_LR = 1e-3
    STOP_LR = 1e-4
    DECAY_OVER = 400000


# Same (empty) argument parsing as the other config cells, for consistency.
args.parse_args("")

# An empty loader (dataset smaller than one batch with drop_last=True) would
# otherwise surface later as a ZeroDivisionError when computing `epochs`.
if len(dataloader) == 0:
    raise ValueError(
        "dataloader yields no batches; cannot compute loudness statistics "
        "or the number of epochs"
    )

# Loudness normalisation statistics. The training loop standardises loudness
# with them, and they are stored in the config so they are saved with the run.
mean_loudness, std_loudness = mean_std_loudness(dataloader)
config["data"]["mean_loudness"] = mean_loudness
config["data"]["std_loudness"] = std_loudness

run_dir = path.join(args.ROOT, args.NAME)
# Create the run directory explicitly instead of relying on SummaryWriter to
# have created it before config.yaml is written into it.
os.makedirs(run_dir, exist_ok=True)

writer = SummaryWriter(run_dir, flush_secs=20)

with open(path.join(run_dir, "config.yaml"), "w") as out_config:
    yaml.safe_dump(config, out_config)

opt = torch.optim.Adam(ddsp_model.parameters(), lr=args.START_LR)

# Epoch -> learning-rate schedule between START_LR and STOP_LR over
# DECAY_OVER (exact shape defined in core.get_scheduler; confirm there).
schedule = get_scheduler(
    len(dataloader),
    args.START_LR,
    args.STOP_LR,
    args.DECAY_OVER,
)

# Running state consumed and updated by the training-loop cell.
best_loss = float("inf")
mean_loss = 0
n_element = 0
step = 0
# Whole epochs needed so that epochs * len(dataloader) >= STEPS optimiser steps.
epochs = int(np.ceil(args.STEPS / len(dataloader)))

In [6]:
# losses = []

# def train(model, loader, optimizer):
#     model.train()
#     device = next(model.parameters()).device
#     total_loss = 0

#     for batch in loader:
#         batch = batch.to(device)
#         y = ddsp_model(batch).squeeze(-1)
                
#         ori_stft = multiscale_fft(
#             batch,
#             config["train"]["scales"],
#             config["train"]["overlap"],
#         )
#         rec_stft = multiscale_fft(
#             y,
#             config["train"]["scales"],
#             config["train"]["overlap"],
#         )

#         loss = 0
#         for s_x, s_y in zip(ori_stft, rec_stft):
#             lin_loss = (s_x - s_y).abs().mean()
#             log_loss = (safe_log(s_x) - safe_log(s_y)).abs().mean()
#             loss = loss + lin_loss + log_loss

#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()
#         total_loss += loss.item()

#     total_loss /= len(loader)
#     losses.append(total_loss)
    
#     return total_loss

# for e in tqdm(range(epochs)):
#     loss = train(ddsp_model, dataloader, opt)
#     print("Epoch {} -- Loss {:3E}".format(e, loss))

In [7]:
import soundfile as sf


def _spectral_loss(target, output, scales, overlap):
    """Multi-scale spectral reconstruction loss between two audio batches.

    For every element yielded by ``multiscale_fft``, adds the mean absolute
    difference of the two spectrograms (linear term) and of their
    ``safe_log`` (log term).

    Args:
        target: reference audio batch.
        output: reconstructed audio batch.
        scales: forwarded to ``multiscale_fft`` (``config["train"]["scales"]``).
        overlap: forwarded to ``multiscale_fft`` (``config["train"]["overlap"]``).

    Returns:
        The summed loss (a tensor; the int 0 if no scales are produced).
    """
    target_stft = multiscale_fft(target, scales, overlap)
    output_stft = multiscale_fft(output, scales, overlap)

    total = 0
    for s_x, s_y in zip(target_stft, output_stft):
        lin_loss = (s_x - s_y).abs().mean()
        log_loss = (safe_log(s_x) - safe_log(s_y)).abs().mean()
        total = total + lin_loss + log_loss
    return total


run_dir = path.join(args.ROOT, args.NAME)
fft_scales = config["train"]["scales"]
fft_overlap = config["train"]["overlap"]

for e in tqdm(range(epochs)):
    # Apply the learning-rate schedule. Previously schedule(e) was only logged
    # as "lr" while the optimiser stayed at START_LR for the whole run.
    # NOTE(review): assumes schedule() returns an absolute learning rate (it
    # is built from START_LR/STOP_LR); confirm against core.get_scheduler.
    lr = schedule(e)
    for param_group in opt.param_groups:
        param_group["lr"] = lr

    for batch in dataloader:
        s = batch['signals'].to(device)
        p = batch['pitches'].unsqueeze(-1).to(device)
        l = batch['loudness'].unsqueeze(-1).to(device)

        # Standardise loudness with the dataset statistics from the setup cell.
        l = (l - mean_loudness) / std_loudness

        y = ddsp_model(s, p, l).squeeze(-1)

        loss = _spectral_loss(s, y, fft_scales, fft_overlap)

        opt.zero_grad()
        loss.backward()
        opt.step()

        loss_value = loss.item()
        writer.add_scalar("loss", loss_value, step)
        step += 1

        # Running mean of the loss since the last 10-epoch checkpoint check.
        n_element += 1
        mean_loss += (loss_value - mean_loss) / n_element

    # Every 10 epochs: log, checkpoint on improvement, dump an audio sample.
    if not e % 10:
        writer.add_scalar("lr", lr, e)
        writer.add_scalar("reverb_decay", ddsp_model.reverb.decay.item(), e)
        writer.add_scalar("reverb_wet", ddsp_model.reverb.wet.item(), e)
        if mean_loss < best_loss:
            best_loss = mean_loss
            torch.save(
                ddsp_model.state_dict(),
                path.join(run_dir, "state.pth"),
            )

        mean_loss = 0
        n_element = 0

        # Last batch: each original followed by its reconstruction along the
        # last axis, then flattened into one long signal.
        audio = torch.cat([s, y], -1).reshape(-1).detach().cpu().numpy()

        sf.write(
            path.join(run_dir, f"eval_{e:06d}.wav"),
            audio,
            config["preprocess"]["sample_rate"],
        )


  0%|          | 2/100000 [02:22<1974:56:42, 71.10s/it]


KeyboardInterrupt: 