In [1]:
import os

import torch

from config import SynthConfig, Config
from dataset.ai_synth_dataset import AiSynthDataset, create_data_loader
from model import helper
from model.model import SimpleSynthNetwork
from run_scripts.inference.inference_helper import inference_loop, process_batch_inference
from synth.synth_architecture import SynthModular

In [2]:

# Notebook-wide setup: device selection, test dataset/loader, synth object,
# spectrogram transform, normalizer and the (still untrained) network.

# Use the first GPU when CUDA is usable; otherwise fall back to CPU so the
# notebook still runs on machines without a GPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
preset = 'BASIC_FLOW'

## Create dataset
dataset_to_visualize = 'basic_flow_new_toy'
split_to_visualize = 'test'
# The trailing '' makes os.path.join end the path with a separator.
data_dir = os.path.join('data', dataset_to_visualize, split_to_visualize, '')

wav_files_dir = os.path.join(data_dir, 'wav_files', '')
# NOTE: despite the `_csv_` in the name, this path points at a .pkl file.
params_csv_path = os.path.join(data_dir, 'params_dataset.pkl')

ai_synth_dataset = AiSynthDataset(params_csv_path, wav_files_dir, device)
# NOTE(review): the positional 10 and 0 are presumably batch size and number
# of workers -- confirm against create_data_loader's signature.
test_dataloader = create_data_loader(ai_synth_dataset, 10, 0, shuffle=False)


## init
synth_cfg = SynthConfig()
cfg = Config()

synth_obj = SynthModular(synth_cfg=synth_cfg,
                         sample_rate=cfg.sample_rate,
                         signal_duration_sec=cfg.signal_duration_sec,
                         device=device,
                         preset=preset)

transform = helper.mel_spectrogram_transform(cfg.sample_rate).to(device)
normalizer = helper.Normalizer(cfg.signal_duration_sec, synth_cfg)

## Load model
# Two candidate checkpoints; the weights themselves are loaded in a later cell.
model_ckpt_path1 = r'experiments/current/basic_flow_test/ckpts/synth_net_epoch10.pt'
model_ckpt_path2 = r'experiments/current/basic_flow_w_spec_loss_low_weight/ckpts/synth_net_epoch30.pt'
model = SimpleSynthNetwork(preset, synth_cfg, cfg, device, backbone='resnet').to(device)

In [6]:
# Read the second checkpoint onto the active device and copy its weights into
# the network. torch.load unpickles the file, so only point it at checkpoints
# you trust.
state_dict = torch.load(model_ckpt_path2, map_location=device)['model_state_dict']
model.load_state_dict(state_dict)
# Switch to evaluation mode; as the last expression this also displays the
# module structure.
model.eval()

SimpleSynthNetwork(
  (backbone): ResNet(
    (conv1): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=Tru

In [7]:
# Fetch the first batch from the test loader. This assignment used to be
# commented out, so the cell only worked when `audio_batch` was left over in
# the kernel from an earlier run. The loader is created with shuffle=False,
# so the first batch should be the same on every run.
audio_batch = next(iter(test_dataloader))
# Run inference on that batch; yields per-batch results and a metrics dict.
results2, metrics2 = process_batch_inference(audio_batch, transform, model, normalizer.denormalize, synth_obj, device, cfg)

  mfcc1 = mfcc(sound1, sr=sample_rate, n_mfcc=40)
 -0.00070597] as keyword args. From version 0.10 passing these as positional arguments will result in an error
  mfcc2 = mfcc(sound2, sr=sample_rate, n_mfcc=40)
 -0.00131699] as keyword args. From version 0.10 passing these as positional arguments will result in an error
  mfcc1 = mfcc(sound1, sr=sample_rate, n_mfcc=40)
 -0.00065546] as keyword args. From version 0.10 passing these as positional arguments will result in an error
  mfcc2 = mfcc(sound2, sr=sample_rate, n_mfcc=40)
 -6.85303462e-07  5.22776190e-07] as keyword args. From version 0.10 passing these as positional arguments will result in an error
  mfcc1 = mfcc(sound1, sr=sample_rate, n_mfcc=40)
 -8.014069e-04] as keyword args. From version 0.10 passing these as positional arguments will result in an error
  mfcc2 = mfcc(sound2, sr=sample_rate, n_mfcc=40)
 -3.6984619e-07  3.0417408e-07] as keyword args. From version 0.10 passing these as positional arguments will result in an 

In [8]:
# Divide every entry of metrics2 by 10 and store the result back under the
# same key.
# NOTE(review): 10 is presumably the number of samples in the batch -- confirm.
# The dict is rescaled in place, so running this cell again divides the
# already-scaled values by 10 once more.
for metric_name in list(metrics2.keys()):
    metrics2[metric_name] = metrics2[metric_name] / 10

In [9]:
# Display the metrics dict for comparison with metrics2.
# NOTE(review): `metrics` is not assigned in any cell visible in this notebook;
# presumably it is left over from an earlier run with another checkpoint --
# confirm, since a fresh kernel would raise NameError here.
metrics

defaultdict(float,
            {'lsd_value': 61.6885009765625,
             'pearson_stft': 0.6873675721354875,
             'pearson_fft': 0.481489135158127,
             'mean_average_error': 0.4900597095489502,
             'mfcc_mae': 5.976694107055664,
             'spectral_convergence_value': 0.7485891342163086})

In [10]:
# Display the metrics returned by process_batch_inference for the checkpoint
# at model_ckpt_path2, after the in-place division by 10.
metrics2

defaultdict(float,
            {'lsd_value': 70.4620849609375,
             'pearson_stft': 0.49289836366789547,
             'pearson_fft': 0.3709534455294915,
             'mean_average_error': 0.579149580001831,
             'mfcc_mae': 5.999608993530273,
             'spectral_convergence_value': 0.9838637351989746})