In [45]:
# Expose only GPU 0 to this kernel.
# NOTE(review): CUDA_VISIBLE_DEVICES is generally only honored if it is set before CUDA is
# initialized (i.e. before the first `.cuda()` call below). In the saved notebook this cell
# ran last (In [45]), so it should be run first in a fresh kernel for it to take effect.
%env CUDA_VISIBLE_DEVICES=0

env: CUDA_VISIBLE_DEVICES=0


# Inference process of WaveGrad

In [2]:
import sys
sys.path.insert(0, '..')

import json
import IPython.display as ipd

import torch

from tqdm import tqdm

import utils
import benchmark
from model import WaveGrad
from data import AudioDataset, MelSpectrogramFixed

**Load configuration**

In [3]:
CONFIG_PATH = '../configs/default.json'

with open(CONFIG_PATH) as f:
    config = utils.ConfigWrapper(**json.load(f))

# The config stores paths relative to the repository root. This notebook presumably runs
# one directory below it (see CONFIG_PATH and the sys.path tweak), so prefix each with '../'.
for path_attr in ('logdir', 'train_filelist_path', 'test_filelist_path'):
    setattr(
        config.training_config,
        path_attr,
        f'../{getattr(config.training_config, path_attr)}',
    )
config

{'model_config': {'factors': [5, 5, 3, 2, 2], 'upsampling_preconv_out_channels': 768, 'upsampling_out_channels': [512, 512, 256, 128, 128], 'upsampling_dilations': [[1, 2, 1, 2], [1, 2, 1, 2], [1, 2, 4, 8], [1, 2, 4, 8], [1, 2, 4, 8]], 'downsampling_preconv_out_channels': 32, 'downsampling_out_channels': [128, 128, 256, 512], 'downsampling_dilations': [[1, 2, 4], [1, 2, 4], [1, 2, 4], [1, 2, 4]]}, 'data_config': {'sample_rate': 22050, 'n_fft': 1024, 'win_length': 1024, 'hop_length': 300, 'f_min': 80.0, 'f_max': 8000, 'n_mels': 80}, 'training_config': {'logdir': '../logs/default', 'continue_training': False, 'train_filelist_path': '../filelists/train.txt', 'test_filelist_path': '../filelists/test.txt', 'batch_size': 48, 'segment_length': 7200, 'lr': 0.001, 'grad_clip_threshold': 1, 'scheduler_step_size': 1, 'scheduler_gamma': 0.95, 'n_epoch': 100000000, 'n_samples_to_test': 4, 'test_interval': 1, 'training_noise_schedule': {'n_iter': 1000, 'betas_range': [1e-06, 0.01]}, 'test_noise_sche

**Initialize the model**

In [4]:
# Build the WaveGrad model from the config and move it to the GPU.
model = WaveGrad(config)
model = model.cuda()
print('Number of parameters: {}'.format(model.nparams))

Number of parameters: 15810401


In [5]:
# To load a specific checkpoint instead of the latest one, use:
# model.load_state_dict(torch.load('../logs/default/checkpoint_180630.pt')['model'], strict=False)

# Restore weights from the latest checkpoint in the log directory; the other two
# returned values are not needed for inference and are discarded.
model, _, _ = utils.load_latest_checkpoint(config.training_config.logdir, model)

Latest checkpoint: ../logs/default/checkpoint_180630.pt


**Initialize the dataset**

In [6]:
# Test-split dataset used to draw ground-truth samples.
dataset = AudioDataset(config, training=False)

# Mel-spectrogram front-end; every STFT/mel parameter is taken from config.data_config.
mel_fn = MelSpectrogramFixed(
    window_fn=torch.hann_window,
    **{
        param: getattr(config.data_config, param)
        for param in (
            'sample_rate', 'n_fft', 'win_length', 'hop_length',
            'f_min', 'f_max', 'n_mels',
        )
    }
).cuda()

In [7]:
TEST_BATCH_SIZE = 1

# Sample a test batch from the test set and play back the ground-truth audio.
test_batch = dataset.sample_test_batch(TEST_BATCH_SIZE)

for test_sample in test_batch:
    # Use the configured sample rate (as the playback cell for predictions does) rather
    # than a hard-coded 22050, so playback stays correct if the data config changes.
    ipd.display(ipd.Audio(test_sample.squeeze(), rate=config.data_config.sample_rate))

**Grid search for the best noise schedule (optional; otherwise set the betas by hand in the next section)**

Note: the lower the `step` argument, the finer the grid and the more accurate the search.

In [8]:
# Number of denoising iterations to grid-search a noise schedule for. It is also used in
# the stats filename so the iteration count and the saved file name cannot drift apart.
GS_N_ITER = 6

iters_best_schedule, stats = benchmark.iters_schedule_grid_search(
    model=model, n_iter=GS_N_ITER, config=config, step=1, test_batch_size=2,
    path_to_store_stats=f'schedules/gs_stats_{GS_N_ITER}iters.pt',
    verbose=True
)

Initializing betas grid...


  0%|          | 0/5670 [00:00<?, ?it/s]

Grid size: 5670
Initializing utils...
Starting search...


                                                   

Saving stats to schedules/gs_stats_7iters.pt...
Best betas on 6 iterations: [8.e-06 6.e-05 9.e-04 7.e-03 4.e-02 3.e-02 5.e-02]


In [9]:
# Persist the grid-searched schedule; this is the file registered under key 6 in
# ITERS_SCHEDULE_PATHS when loading schedules below.
torch.save(obj=iters_best_schedule, f='schedules/iters6_best_schedule.pt')

**Set noise schedule**

Note: `init_kwargs` should always contain the key `steps`.

In [35]:
LOAD_GS_ITERS_BEST_SCHEDULE = True

# Saved grid-searched schedules, keyed by number of iterations.
ITERS_SCHEDULE_PATHS = {
    6: 'schedules/iters6_best_schedule.pt',
    7: 'schedules/iters7_best_schedule.pt',
    # 8: 'schedules/iters8_best_schedule.pt',
    12: 'schedules/iters12_best_schedule.pt',
    25: 'schedules/iters25_best_schedule.pt',
    # 50: 'schedules/iters50_best_schedule.pt',
    # 100: 'schedules/iters100_best_schedule.pt',
    # 1000: 'schedules/iters1000_best_schedule.pt',
}

# Each SCHEDULES entry maps a schedule name to an `init` callable plus the `init_kwargs`
# it is called with; `init_kwargs` must always contain `steps`.
if not LOAD_GS_ITERS_BEST_SCHEDULE:
    # Hand-picked betas for the 6-iteration schedule.
    iters6_init = lambda **kwargs: torch.FloatTensor([1e-6, 1e-5, 1e-4, 1e-3, 1e-3, 1e-2])
    # Linearly spaced betas in [1e-6, 1e-2] for every other schedule length.
    SCHEDULES = {
        str(n_steps): {
            'init': torch.linspace,
            'init_kwargs': {'steps': n_steps, 'start': 1e-6, 'end': 1e-2},
        }
        for n_steps in (1000, 100, 50, 25, 12, 8, 7)
    }
    SCHEDULES['6'] = {'init': iters6_init, 'init_kwargs': {'steps': 6}}
else:
    # Load betas from disk; `path` travels in init_kwargs, so each lambda reads its own file.
    SCHEDULES = {
        str(n_iters): {
            'init': lambda **kwargs: torch.FloatTensor(torch.load(kwargs['path'])),
            'init_kwargs': {'steps': n_iters, 'path': schedule_path},
        }
        for n_iters, schedule_path in ITERS_SCHEDULE_PATHS.items()
    }

In [42]:
SCHEDULE_TYPE_TO_SET = '6'

# Every SCHEDULES entry holds exactly the `init` and `init_kwargs` arguments of
# `set_new_noise_schedule`, so the chosen entry can be unpacked straight into the call.
model.set_new_noise_schedule(**SCHEDULES[SCHEDULE_TYPE_TO_SET])
model.noise_schedule_kwargs

{'init': <function __main__.<dictcomp>.<lambda>(**kwargs)>,
 'init_kwargs': {'steps': 6, 'path': 'schedules/iters6_best_schedule.pt'}}

**Inference**

In [43]:
STORE_INTERMEDIATE_STATES = False

# Vocode each test sample: compute its mel-spectrogram, then run the model on it.
# This is pure inference, so disable autograd to avoid building a graph and holding
# activations in GPU memory.
test_preds = []
with torch.no_grad():
    for test_sample in tqdm(test_batch):
        mel = mel_fn(test_sample[None].cuda())  # [None] adds a leading batch dimension
        outputs = model.forward(
            mel, store_intermediate_states=STORE_INTERMEDIATE_STATES
        )
        test_preds.append(outputs)

100%|██████████| 1/1 [00:00<00:00,  2.12it/s]


In [44]:
# Play back each generated waveform at the configured sample rate.
for prediction in test_preds:
    ipd.display(
        ipd.Audio(prediction.squeeze().cpu(), rate=config.data_config.sample_rate)
    )

**Compute real-time factor (RTF)**

In [20]:
# Estimate the average real-time factor over the test filelist. The path comes from the
# config (already prefixed with '../' above) rather than being hard-coded a second time.
rtf_stats = benchmark.estimate_average_rtf_on_filelist(
    config.training_config.test_filelist_path, config, model, verbose=True
)
rtf_stats

100%|██████████| 100/100 [00:35<00:00,  2.83it/s]

DEVICE: cuda:0. average_rtf=0.04834762980697031, std=0.0031231637153867078





{'rtfs': [0.06185184782608696,
  0.0467216590909091,
  0.046799224137931035,
  0.046884507968127494,
  0.045916382845188286,
  0.04581568245125348,
  0.04654731818181818,
  0.04628940909090909,
  0.046570252865329516,
  0.04632195295902883,
  0.05597049337748344,
  0.046341164978292326,
  0.0461812530949106,
  0.04597791509433962,
  0.046297382727272726,
  0.04682762078651685,
  0.05076252409638555,
  0.04609193676222597,
  0.0480396,
  0.04604682068965517,
  0.04714714873417722,
  0.04715321370967741,
  0.04639231497797357,
  0.04723374539877301,
  0.04688595569620253,
  0.05317611340206186,
  0.047303415032679735,
  0.052982562500000004,
  0.04753064900662251,
  0.04703285607476636,
  0.047200249525616696,
  0.04652045334507042,
  0.047389933078393884,
  0.046650975,
  0.05139360148514852,
  0.04665593507462687,
  0.04688634448818898,
  0.04615239303482587,
  0.04628166350067842,
  0.04742788686131386,
  0.048680717252396165,
  0.04671087397708674,
  0.05095476411290323,
  0.04619326