In [19]:
import torch
import os
import librosa
import numpy as np
import matplotlib.pyplot as plt

from model import helper
from model.model import SimpleSynthNetwork
from config import SynthConfig, Config, DatasetConfig
from dataset.ai_synth_dataset import AiSynthDataset, NSynthDataset, create_data_loader
from synth.synth_architecture import SynthModular, SynthModularCell
from run_scripts.inference.inference import visualize_signal_prediction
from run_scripts.train_helper import *

from tqdm import tqdm

dataset_to_visualize = 'modular_synth50k'

cfg = Config()
synth_cfg = SynthConfig()
dataset_cfg = DatasetConfig(dataset_to_visualize)

device = 'cuda:0'



In [6]:
model_ckpt = r'/home/almogelharar/almog/ai_synth/experiments/current/modular_synth_120e/ckpts/trained_synth_net.pt'
model = SimpleSynthNetwork('MODULAR', synth_cfg, cfg, device, backbone='resnet').to(device)
model.load_state_dict(torch.load(model_ckpt))
model.eval()



SimpleSynthNetwork(
  (backbone): ResNet(
    (conv1): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=Tru

In [7]:
cwd = r'/home/almogelharar/almog/ai_synth/'
split_to_visualize = 'val'
data_dir = os.path.join(cwd, 'data', dataset_to_visualize, split_to_visualize, '')

wav_files_dir = os.path.join(data_dir, 'wav_files', '')
params_csv_path = os.path.join(data_dir, 'params_dataset.pkl')

ai_synth_dataset = AiSynthDataset(params_csv_path, wav_files_dir, device)
test_dataloader = create_data_loader(ai_synth_dataset, 32, 4, shuffle=False)

wav_files_dir = os.path.join(data_dir, 'wav_files', '')

split_to_visualize = 'val_nsynth'
data_dir = os.path.join('data', dataset_to_visualize, split_to_visualize, '')

nsynth_dataset = NSynthDataset(wav_files_dir, device)
nsynth_dataloader = create_data_loader(nsynth_dataset, 32, 4, shuffle=False)

NSynth dataloader found 4992 wav files in /home/almogelharar/almog/ai_synth/data/modular_synth50k/val/wav_files/


In [10]:
synth_cfg = SynthConfig()
cfg = Config()

transform = helper.mel_spectrogram_transform(cfg.sample_rate).to(device)
normalizer = helper.Normalizer(cfg.signal_duration_sec, synth_cfg)

modular_synth = SynthModular(synth_cfg=synth_cfg,
                             sample_rate=cfg.sample_rate,
                             signal_duration_sec=cfg.signal_duration_sec,
                             num_sounds_=1,
                             device=device,
                             preset='MODULAR')

In [22]:
def infer_and_compare(signals, target_params_dic, signals_indices):
    signals = helper.move_to(signals, device)
    normalizer = helper.Normalizer(cfg.signal_duration_sec, synth_cfg)

    transformed_signal = transform(signals)

    output_dic = model(transformed_signal)

    # Infer predictions
    denormalized_output_dict = normalizer.denormalize(output_dic)
    predicted_param_dict = helper.clamp_regression_params(denormalized_output_dict, synth_cfg, cfg)

    update_params = []
    for index, operation_dict in predicted_param_dict.items():
        synth_modular_cell = SynthModularCell(index=index, parameters=operation_dict['params'])
        update_params.append(synth_modular_cell)

    modular_synth.update_cells(update_params)
    modular_synth.generate_signal(num_sounds_=len(transformed_signal))

    # for i in range(len(signals)):
    for i in range(5):

        sample_params_orig, sample_params_pred = parse_synth_params(target_params_dic, predicted_param_dict, i)
        signal_index = signals_indices[i]

        orig_audio = signals[i]
        pred_audio = modular_synth.signal[i]
        orig_audio_np = orig_audio.detach().cpu().numpy()
        pred_audio_np = pred_audio.detach().cpu().numpy()

        orig_audio_transformed = librosa.feature.melspectrogram(y=orig_audio_np,
                                                                sr=cfg.sample_rate,
                                                                n_fft=1024,
                                                                hop_length=512,
                                                                n_mels=64)
        orig_audio_transformed_db = librosa.power_to_db(orig_audio_transformed, ref=np.max)
        pred_audio_transformed = librosa.feature.melspectrogram(y=pred_audio_np,
                                                                sr=cfg.sample_rate,
                                                                n_fft=1024,
                                                                hop_length=512,
                                                                n_mels=64)
        pred_audio_transformed_db = librosa.power_to_db(pred_audio_transformed, ref=np.max)

        # plot original vs predicted signal
        plt.figure(figsize=[30, 20])
        plt.ion()
        plt.subplot(2, 2, 1)
        plt.title(f"original audio")
        plt.ylim([-1, 1])
        plt.plot(orig_audio_np)
        plt.subplot(2, 2, 2)
        plt.ylim([-1, 1])
        plt.title("predicted audio")
        plt.plot(pred_audio_np)
        plt.subplot(2, 2, 3)
        librosa.display.specshow(orig_audio_transformed_db, sr=cfg.sample_rate, hop_length=512,
                                 x_axis='time', y_axis='mel')
        plt.colorbar(format='%+2.0f dB')
        plt.subplot(2, 2, 4)
        librosa.display.specshow(pred_audio_transformed_db, sr=cfg.sample_rate, hop_length=512,
                                 x_axis='time', y_axis='mel')
        plt.colorbar(format='%+2.0f dB')
        plt.ioff()
        plots_path = dataset_cfg.inference_plots_dir.joinpath(f"sound{signal_index}_plots.png")
        plt.savefig(plots_path)

        signal_vis = visualize_signal_prediction(orig_audio[i], pred_audio[i], sample_params_orig, sample_params_pred, db=True)
        break

In [23]:
for signals, target_params_dic, signals_indices in test_dataloader:
    infer_and_compare(signals, target_params_dic, signals_indices)

for signals, target_params_dic, signals_indices in nsynth_dataloader:
    infer_and_compare(signals, target_params_dic, signals_indices)

RuntimeError: DataLoader worker (pid(s) 33227, 33267, 33307, 33347) exited unexpectedly