In [37]:
import glob
import json
import logging
import os
import random

import torch
import torch.multiprocessing as mp
import torch.nn.functional as F
from tqdm import tqdm
import fairseq

from dataset import LLVCDataset as Dataset
from model import Net
from discriminators import MultiPeriodDiscriminator, discriminator_loss, generator_loss, feature_loss
import utils


In [38]:


class LLVCDataset(torch.utils.data.Dataset):
    """Placeholder dataset yielding random (input, target) audio pairs.

    Stands in for the real ``dataset.LLVCDataset`` so the training loop can be
    smoke-tested without audio on disk. Each item is a pair of independent
    uniform-random float tensors of shape ``[1, 240]``.

    Args:
        **kwargs: accepted and ignored, so the real dataset's arguments
            (e.g. ``**config['data'], dset='train'``) can be passed unchanged.
    """

    _LENGTH = 100
    _SAMPLE_SHAPE = (1, 240)

    def __init__(
        self, **kwargs
    ):
        pass

    def __len__(self):
        return self._LENGTH

    def __getitem__(self, idx):
        n = len(self)
        # torch's Dataset defines no __iter__, so plain iteration falls back to
        # the __getitem__ protocol and only stops on IndexError.
        if not -n <= idx < n:
            raise IndexError(
                f"index {idx} out of range for dataset of length {n}")
        return torch.rand(self._SAMPLE_SHAPE), torch.rand(self._SAMPLE_SHAPE)

In [39]:
# test_data=LLVCDataset()
# test_data[3]

In [40]:

def net_g_step(
    batch, net_g, device
):
    """Run the generator on one (input, target) batch.

    Args:
        batch: a pair ``(input_audio, target_audio)`` of tensors.
        net_g: the generator; called as ``net_g(input_audio)``.
        device: device the tensors are moved to (non-blocking copy).

    Returns:
        ``(generator_output, target_on_device, input_on_device)``.
    """
    source, target = batch
    transfer = dict(device=device, non_blocking=True)
    source = source.to(**transfer)
    target = target.to(**transfer)
    return net_g(source), target, source

In [1]:


def _random_aligned_slice(real, fake, segment_size):
    """Cut the same random window of ``segment_size`` samples from both tensors.

    The window is taken along the last axis, and its length bound is read from
    ``fake``. When ``segment_size`` is not smaller than that length, both
    tensors are returned unchanged.
    """
    length = fake.shape[-1]
    if segment_size >= length:
        return real, fake
    # randint is inclusive, so ``length - segment_size`` is the last valid start.
    start = random.randint(0, length - segment_size)
    window = slice(start, start + segment_size)
    return real[..., window], fake[..., window]


def _backward_and_step(loss, net, optim, config):
    """Backpropagate ``loss``, clip ``net``'s gradients and step ``optim``.

    Gradients are norm-clipped when ``config['grad_clip_threshold']`` is not
    None. They are then passed through ``utils.clip_grad_value_`` with
    ``config['grad_clip_value']``.

    Returns:
        Whatever ``utils.clip_grad_value_`` returns (logged as the grad norm).
    """
    optim.zero_grad()
    loss.backward()
    if config['grad_clip_threshold'] is not None:
        torch.nn.utils.clip_grad_norm_(
            net.parameters(), config['grad_clip_threshold'])
    grad_norm = utils.clip_grad_value_(
        net.parameters(), config['grad_clip_value'])
    optim.step()
    return grad_norm


def _audio_examples(prefix, gt, inp, pred, max_items=3):
    """Collect up to ``max_items`` target/input/predicted clips as numpy arrays.

    Keys have the form ``<prefix>_audio/{gt,in,pred}_<i>``.
    """
    audios = {}
    for tag, tensor in (("gt", gt), ("in", inp), ("pred", pred)):
        audios.update({
            f"{prefix}_audio/{tag}_{i}": tensor[i].data.cpu().numpy()
            for i in range(min(max_items, tensor.shape[0]))
        })
    return audios


def _run_test_wavs(net_g, config, device):
    """Convert every ``*.wav`` in ``config['test_dir']`` with ``net_g``.

    Inference runs under ``torch.no_grad()``. Returns the outputs keyed
    ``test_audio/<file name>``.
    """
    test_wavs = [
        (os.path.basename(p), utils.load_wav_to_torch(p, config['data']['sr']))
        for p in sorted(glob.glob(os.path.join(config['test_dir'], "*.wav")))
    ]
    logging.info("Testing...")
    audios = {}
    with torch.no_grad():
        for test_wav_name, test_wav in tqdm(test_wavs, total=len(test_wavs)):
            test_out = net_g(test_wav.unsqueeze(0).unsqueeze(0).to(device))
            audios[f"test_audio/{test_wav_name}"] = test_out[0].data.cpu().numpy()
    return audios


def _validate(net_g, loader, loader_name, fairseq_model, config, device):
    """Evaluate ``net_g`` on every batch of ``loader`` without gradients.

    Returns:
        ``(scalars, audios)``:
        - ``scalars`` holds metrics averaged over all batches, keyed
          ``<loader_name>_metrics/...``.
        - ``audios`` holds up to three example clips from the last batch,
          keyed ``<loader_name>_audio/...``.
        Both are empty dicts when ``loader`` yields no batches.
    """
    logging.info(f"Validating on {loader_name} dataset...")
    mel_avg = utils.RunningAvg()
    fairseq_avg = utils.RunningAvg()
    mcd_avg = utils.RunningAvg()
    last = None
    with torch.no_grad():
        for v_batch in tqdm(loader, total=len(loader)):
            v_output, v_gt, v_og = net_g_step(v_batch, net_g, device)
            # Metrics are accumulated per batch, not only for the final one.
            if config['aux_mel']['c'] > 0:
                mel_avg.update(utils.aux_mel_loss(
                    v_output, v_gt, config) * config['aux_mel']['c'])
            if fairseq_model is not None:
                fairseq_avg.update(utils.fairseq_loss(
                    v_output, v_gt, fairseq_model) * config['aux_fairseq']['c'])
            mcd_avg.update(utils.mcd(v_output, v_gt, config['data']['sr']))
            last = (v_gt, v_og, v_output)
    if last is None:
        return {}, {}
    scalars = {f"{loader_name}_metrics/mcd": mcd_avg()}
    if config['aux_mel']['c'] > 0:
        scalars[f"{loader_name}_metrics/mel"] = mel_avg()
    if fairseq_model is not None:
        scalars[f"{loader_name}_metrics/fairseq"] = fairseq_avg()
    return scalars, _audio_examples(loader_name, *last)


def _save_checkpoints(net_g, optim_g, net_d, optim_d, lr, epoch,
                      global_step, checkpoint_dir):
    """Save G and D state to ``G_<step>.pth`` / ``D_<step>.pth`` in ``checkpoint_dir``."""
    g_checkpoint = os.path.join(checkpoint_dir, f"G_{global_step}.pth")
    d_checkpoint = os.path.join(checkpoint_dir, f"D_{global_step}.pth")
    utils.save_state(net_g, optim_g, lr, epoch, global_step, g_checkpoint)
    utils.save_state(net_d, optim_d, lr, epoch, global_step, d_checkpoint)
    logging.info(f"Saved checkpoints to {g_checkpoint} and {d_checkpoint}")


def training_runner(
    config,
    training_dir,
    writer=None,
    max_epochs=10000,
):
    """Adversarially train the LLVC generator ``Net`` against a discriminator.

    Each training step does two updates:
    1. The discriminator is trained on a random ``config['segment_size']``
       window of real vs. generated audio.
    2. The generator is trained with adversarial + feature-matching losses
       (plus the optional auxiliary losses).

    Every ``config['log_interval']`` steps, scalars and audio (train,
    ``config['test_dir']`` wavs, validation) go to ``utils.summarize``.
    Every ``config['checkpoint_interval']`` steps, G and D are checkpointed.

    Args:
        config: parsed experiment config dict. Keys read here include
            'seed', 'data', 'model_params', 'discriminator', 'periods',
            'optim', 'lr_sched', 'checkpoint_interval', 'segment_size',
            'grad_clip_threshold', 'grad_clip_value', 'feature_loss_c',
            'disc_loss_c', 'log_interval', 'aux_mel', 'aux_fairseq' and
            'test_dir'.
        training_dir: output directory. Checkpoints go in
            ``<training_dir>/checkpoints``.
        writer: summary writer handed to ``utils.summarize``. When None, a
            ``torch.utils.tensorboard.SummaryWriter`` logging to
            ``training_dir`` is created.
        max_epochs: number of passes over the training set.
    """
    checkpoint_dir = os.path.join(training_dir, "checkpoints")
    # Checkpoint files are written here, so the directory must exist.
    os.makedirs(checkpoint_dir, exist_ok=True)
    if writer is None:
        # Imported lazily: only needed when the caller supplies no writer.
        from torch.utils.tensorboard import SummaryWriter
        writer = SummaryWriter(log_dir=training_dir)

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Seed both RNGs: torch for init/shuffling, ``random`` for slice offsets.
    torch.manual_seed(config['seed'])
    random.seed(config['seed'])

    data_train = LLVCDataset(**config['data'], dset='train')
    data_val = LLVCDataset(**config['data'], dset='val')
    for ds in (data_train, data_val):
        print(f"Loaded dataset containing {len(ds)} elements")

    train_loader = torch.utils.data.DataLoader(
        data_train, batch_size=1, shuffle=True)
    val_loader = torch.utils.data.DataLoader(data_val, batch_size=1)

    net_g = Net(**config['model_params']).to(device=device)
    if config['discriminator'] == 'hfg':
        # NOTE(review): ComboDisc is neither defined nor imported in this
        # notebook. This branch raises NameError unless it is provided
        # elsewhere. TODO: import it.
        net_d = ComboDisc()
    else:
        net_d = MultiPeriodDiscriminator(periods=config['periods'])
    net_d = net_d.to(device=device)

    def make_optimizer(params):
        # G and D share a single optimizer configuration.
        optim_cfg = config['optim']
        return torch.optim.AdamW(
            params,
            optim_cfg['lr'],
            betas=optim_cfg['betas'],
            eps=optim_cfg['eps'],
            weight_decay=optim_cfg['weight_decay'],
        )

    optim_g = make_optimizer(net_g.parameters())
    optim_d = make_optimizer(net_d.parameters())
    scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
        optim_g, gamma=config['lr_sched']['lr_decay'])
    scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
        optim_d, gamma=config['lr_sched']['lr_decay'])

    # The auxiliary fairseq model is not loaded (that code path is disabled),
    # so every fairseq branch below is skipped.
    fairseq_model = None

    segment_size = config['segment_size']
    ckpt_interval = config['checkpoint_interval']
    loss_mel_avg = utils.RunningAvg()
    loss_fairseq_avg = utils.RunningAvg()
    global_step = 0

    for epoch in range(max_epochs):
        net_g.train()
        net_d.train()

        # Counts the steps taken towards the next checkpoint.
        progress_bar = tqdm(total=ckpt_interval)
        progress_bar.update(global_step % ckpt_interval)
        try:
            for batch in train_loader:
                output, gt, og = net_g_step(batch, net_g, device)

                # ---- Discriminator step ----
                # Real audio first, generated audio second: the same order
                # the generator step below uses with net_d(gt, output).
                gt_sliced, output_sliced = _random_aligned_slice(
                    gt, output.detach(), segment_size)
                y_d_hat_r, y_d_hat_g, _, _ = net_d(gt_sliced, output_sliced)
                loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(
                    y_d_hat_r, y_d_hat_g)
                grad_norm_d = _backward_and_step(
                    loss_disc, net_d, optim_d, config)

                # ---- Generator step ----
                y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(gt, output)
                if fairseq_model is not None:
                    loss_fairseq = utils.fairseq_loss(
                        output, gt, fairseq_model) * config['aux_fairseq']['c']
                else:
                    loss_fairseq = torch.tensor(0.0)
                loss_fairseq_avg.update(loss_fairseq)
                # The auxiliary mel loss is currently disabled (held at zero).
                loss_mel = torch.tensor(0.0)
                loss_mel_avg.update(loss_mel)
                loss_fm = feature_loss(fmap_r, fmap_g) * config['feature_loss_c']
                loss_gen, losses_gen = generator_loss(y_d_hat_g)
                loss_gen = loss_gen * config['disc_loss_c']
                loss_gen_all = (loss_gen + loss_fm) + loss_mel + loss_fairseq
                grad_norm_g = _backward_and_step(
                    loss_gen_all, net_g, optim_g, config)

                global_step += 1
                progress_bar.update(1)

                if global_step % config['log_interval'] == 0:
                    lr = optim_g.param_groups[0]["lr"]
                    # Clamp only the logged mel value for TensorBoard display.
                    loss_mel_display = 50 if loss_mel > 50 else loss_mel
                    scalar_dict = {
                        "loss/g/total": loss_gen_all,
                        "loss/d/total": loss_disc,
                        "learning_rate": lr,
                        "grad_norm_d": grad_norm_d,
                        "grad_norm_g": grad_norm_g,
                        "loss/g/fm": loss_fm,
                        "loss/g/mel": loss_mel_display,
                    }
                    if config['aux_mel']['c'] > 0:
                        scalar_dict["train_metrics/mel"] = loss_mel_avg()
                        loss_mel_avg.reset()
                    if fairseq_model is not None:
                        scalar_dict["loss/g/fairseq"] = loss_fairseq
                        scalar_dict["train_metrics/fairseq"] = loss_fairseq_avg()
                        loss_fairseq_avg.reset()
                    scalar_dict.update(
                        {f"loss/g/{i}": v for i, v in enumerate(losses_gen)})
                    scalar_dict.update(
                        {f"loss/d_r/{i}": v for i, v in enumerate(losses_disc_r)})
                    scalar_dict.update(
                        {f"loss/d_g/{i}": v for i, v in enumerate(losses_disc_g)})
                    audio_dict = _audio_examples("train", gt, og, output)

                    net_g.eval()
                    audio_dict.update(_run_test_wavs(net_g, config, device))
                    # The validation set is not cached; it is re-run each time.
                    val_scalars, val_audio = _validate(
                        net_g, val_loader, "val", fairseq_model, config, device)
                    scalar_dict.update(val_scalars)
                    audio_dict.update(val_audio)
                    net_g.train()

                    utils.summarize(
                        writer=writer,
                        global_step=global_step,
                        scalars=scalar_dict,
                        audios=audio_dict,
                        audio_sampling_rate=config['data']['sr'],
                    )
                    torch.cuda.empty_cache()

                # Checked independently of log_interval so checkpoints are not
                # skipped when the two intervals do not line up.
                if global_step % ckpt_interval == 0:
                    _save_checkpoints(
                        net_g, optim_g, net_d, optim_d,
                        optim_g.param_groups[0]["lr"], epoch, global_step,
                        checkpoint_dir)
                    progress_bar.reset()
        finally:
            # One bar per epoch; close it so bars do not pile up in the output.
            progress_bar.close()

        scheduler_g.step()
        scheduler_d.step()




In [57]:
import json

# JSON text is UTF-8 by specification. Pin the encoding so a non-UTF-8 locale
# default (e.g. cp1252 on Windows) cannot mis-decode non-ASCII values.
CONFIG_PATH = 'experiments/llvc/config.json'
with open(CONFIG_PATH, encoding='utf-8') as config_file:
    config = json.load(config_file)
# config

In [58]:
# Launch training, writing outputs under ./test_train.
training_runner(config=config, training_dir='test_train')


Loaded dataset containing 100 elements
Loaded dataset containing 100 elements




[A[A                                                                                                                                                                                     | 0/5000 [00:00<?, ?it/s]

[A[A                                                                                                                                                                           | 1/5000 [00:02<3:04:03,  2.21s/it]

[A[A                                                                                                                                                                           | 2/5000 [00:03<2:40:51,  1.93s/it]

[A[A                                                                                                                                                                           | 3/5000 [00:05<2:31:23,  1.82s/it]

[A[A                                                                                                                                        

KeyboardInterrupt: 

In [8]:
import os
import torch
from scipy.io.wavfile import read
import os
import glob


def get_dataset(dir):
    """Pair every ``*_original.wav`` in ``dir`` with its ``*_converted.wav`` path.

    Only the originals are globbed. The converted paths are derived from them
    and are not checked for existence.

    Args:
        dir: directory to search (non-recursive).

    Returns:
        ``(original_files, converted_files)``: two equal-length lists, sorted
        by original path, where ``converted_files[i]`` belongs to
        ``original_files[i]``.
    """
    original_suffix = "_original.wav"
    converted_suffix = "_converted.wav"
    # glob returns files in arbitrary order; sort for a reproducible dataset.
    original_files = sorted(glob.glob(os.path.join(dir, "*" + original_suffix)))
    # Swap only the trailing suffix. str.replace would also rewrite any
    # directory component that happens to contain "_original.wav".
    converted_files = [
        path[: -len(original_suffix)] + converted_suffix
        for path in original_files
    ]
    return original_files, converted_files


def load_wav(full_path):
    """Read a WAV file with ``scipy.io.wavfile.read``.

    Returns:
        ``(samples, sampling_rate)``. The order is reversed relative to
        ``read``, which returns ``(rate, data)``.
    """
    return tuple(reversed(read(full_path)))




In [9]:
# IPython shell escape: list the sample WAV files; one of them is loaded below.
!ls test_wavs

174-50561-0000.wav    2902-9006-0000.wav   7850-73752-0000.wav
1919-142785-0000.wav  5895-34615-0000.wav  8842-302196-0000.wav
2086-149214-0000.wav  652-129742-0000.wav
2412-153947-0000.wav  777-126732-0000.wav


In [10]:
# Load one sample utterance (raw samples plus its sampling rate).
SAMPLE_WAV_PATH = "test_wavs/174-50561-0000.wav"
original_data, o_sr = load_wav(SAMPLE_WAV_PATH)


In [12]:
# Dtype of the raw samples as read from disk (int16 here); the conversion cell scales by 1/32768 accordingly.
original_data.dtype

dtype('int16')

In [13]:
# Scale int16 PCM to float in [-1, 1) and force exactly 240 samples:
# zero-pad on the right when shorter, truncate when longer.
TARGET_LEN = 240
converted = torch.from_numpy(original_data).unsqueeze(0).to(torch.float) / 32768
shortfall = TARGET_LEN - converted.shape[-1]
if shortfall > 0:
    converted = torch.cat((converted, torch.zeros(1, shortfall)), dim=1)
else:
    converted = converted[:, :TARGET_LEN]


In [14]:
# Expect (1, 240) after padding/truncation.
converted.shape

torch.Size([1, 240])

In [15]:
# Expect float32 after the `.to(torch.float)` conversion.
converted.dtype

torch.float32