In [23]:
import sys
sys.path.append('../../../src')
import copy
import numpy as np
import torch

from torchaudio.transforms import Spectrogram, MelSpectrogram, Resample
from model.loss import spectral_loss
from synth.synth_architecture import SynthModular
from main_hp_search_dec_only import configure_experiment
from dataset.ai_synth_dataset import AiSynthDataset
from synth.synth_constants import synth_constants
from tqdm import tqdm


%matplotlib notebook

import matplotlib.pyplot as plt

import seaborn as sns
sns.set_style('whitegrid')

In [24]:
def calc_loss_vs_param_range(param_name, cell_index, min_val, max_val, n_steps):
    """Sweep one synth parameter over a linear grid and record the loss per value.

    For each grid value the target parameter dict is deep-copied, ``param_name``
    in cell ``cell_index`` is overridden, the synth is updated and asked for a
    signal with ``signal_duration=1``, and the loss against the target signal
    is evaluated.

    Relies on the notebook-level globals ``target_param_dict``, ``synth``,
    ``target_signal`` and ``loss_handler``.

    Args:
        param_name: Key inside the cell's ``'parameters'`` dict to override.
        cell_index: Key of the cell in the parameter dict, e.g. ``(0, 2)``.
        min_val: First value of the sweep.
        max_val: Last value of the sweep (inclusive, as in ``np.linspace``).
        n_steps: Number of grid points.

    Returns:
        Tuple ``(loss_vals, param_range)``: a list with one Python scalar per
        grid point and the ``np.linspace`` grid itself.
    """
    param_range = np.linspace(min_val, max_val, n_steps)

    def _loss_at(value):
        # Work on a fresh copy so the target preset is never mutated.
        candidate = copy.deepcopy(target_param_dict)
        candidate[cell_index]['parameters'].update({param_name: value})

        synth.update_cells_from_dict(candidate)
        generated, _ = synth.generate_signal(signal_duration=1)

        target_batch = target_signal.unsqueeze(dim=0)
        loss, _, _ = loss_handler.call(target_batch, generated, step=0, return_spectrogram=False)
        return loss.detach().cpu().numpy().item()

    loss_vals = [_loss_at(value) for value in tqdm(param_range)]

    return loss_vals, param_range

def plot_loss_vs_param(param_range, loss_vals, title):
    """Draw a 15x5 line plot of loss values against the swept parameter values.

    Args:
        param_range: X values (the swept parameter grid).
        loss_vals: Y values (one loss per grid point).
        title: Figure title.
    """
    _, axis = plt.subplots(figsize=(15, 5))
    axis.plot(param_range, loss_vals)
    axis.set_title(title)
    plt.show()

def param_dict_floats_to_tensors(param_dict, tensor_device=None):
    """Return a copy of a synth parameter dict with values wrapped in 1-element tensors.

    The input maps a cell key to ``{'operation': ..., 'parameters': {...}}``.
    Every parameter value is converted with ``torch.tensor([val])``, except
    strings (including ``np.str_``, which subclasses ``str``): torch cannot
    build a tensor from text, so values such as ``'waveform': 'sawtooth'`` are
    passed through unchanged instead of raising.

    Args:
        param_dict: Mapping ``cell_key -> {'operation': ..., 'parameters': {...}}``.
        tensor_device: Device for the created tensors. When ``None`` the
            notebook-level ``device`` global is used (the original behaviour).

    Returns:
        A new dict with the same keys and ``'operation'`` entries; the input
        dict is not modified.
    """
    if tensor_device is None:
        # Backward-compatible default: the device chosen in the setup cell.
        tensor_device = device

    param_dict_as_tensors = {}
    for key, cell in param_dict.items():
        converted = {}
        for param, val in cell['parameters'].items():
            if isinstance(val, str):
                # Non-numeric setting: keep the raw value.
                converted[param] = val
            else:
                converted[param] = torch.tensor([val], device=tensor_device)

        param_dict_as_tensors[key] = {'operation': cell['operation'],
                                      'parameters': converted}

    return param_dict_as_tensors

In [25]:
# Setup experiment: names, config file, compute device, parsed config and synth.
exp_name = 'training_visualization_fm_saw'
dataset_name = 'lfo_saw_single'
# NOTE(review): machine-specific absolute path; adjust to the local checkout before running elsewhere.
config_name = r'C:\Users\noamk\PycharmProjects\ai_synth\configs\optimization_analysis\lfo_saw_single_synth_config_hp_search.yaml'
# Prefer the GPU, but fall back to CPU so the notebook still runs without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

cfg = configure_experiment(exp_name, dataset_name, config_name, debug=True)

# Modular synth built from the preset named in the experiment config.
synth = SynthModular(preset_name=cfg.synth.preset,
                     synth_constants=synth_constants,
                     device=device)


Deleting previous experiment...


In [None]:
# Load the training split and take its first sample as the optimisation target.
dataset = AiSynthDataset(fr'C:\Users\noamk\PycharmProjects\ai_synth\data\{dataset_name}\train', noise_std=0)

target_sample = dataset[0]
target_signal, target_param_dict, signal_index = target_sample

print(f"target parameters full range: \n{target_param_dict}")

param_dict = target_param_dict

# Tensor version of the target parameters. The synth below is still driven by
# the raw ``target_param_dict``; this dict is only built here.
param_dict_as_tensors = {}
for key, cell in param_dict.items():
    tensor_params = {}
    for param, val in cell['parameters'].items():
        if isinstance(val, list):
            # Already a sequence: convert as-is.
            tensor_params[param] = torch.tensor(val, device=device)
        elif isinstance(val, str):
            # The original ``if isinstance(val, np.str_):`` had no body, so the cell
            # did not compile. ``np.str_`` subclasses ``str``; text such as a waveform
            # name cannot become a tensor.
            # NOTE(review): keeping the raw string is an assumption — confirm intent.
            tensor_params[param] = val
        else:
            tensor_params[param] = torch.tensor([val], device=device)

    param_dict_as_tensors[key] = {'operation': cell['operation'],
                                  'parameters': tensor_params}

target_signal = target_signal.to(device)

# Put the synth into the target configuration.
synth.update_cells_from_dict(target_param_dict)

target parameters full range: 
{(1, 1): {'operation': 'lfo', 'parameters': {'active': True, 'output': [(0, 2)], 'freq': 14.630004869328621, 'waveform': 'sawtooth'}}, (0, 2): {'operation': 'fm_saw', 'parameters': {'fm_active': True, 'active': True, 'amp_c': 0.38389441559806636, 'freq_c': 1479.9776908465387, 'mod_index': 0.08196441414940485}}}


  param_dict_as_tensors[key]['parameters'][param] = torch.tensor([val], device=device)


In [6]:
# Set the spectrogram type on the config, then build the spectral loss used by all sweeps.
cfg.multi_spectral_loss_spec_type = 'SPECTROGRAM'

spectral_loss_kwargs = {
    'loss_type': cfg.loss.spec_loss_type,
    'loss_preset': "mag_logmag",
    'synth_constants': synth_constants,
    'device': device,
}
loss_handler = spectral_loss.SpectralLoss(**spectral_loss_kwargs)

In [None]:
# Sweep the carrier frequency of the fm_saw cell and plot the resulting loss curve.
param_name = 'freq_c'
cell_index = (0, 2)
min_val = 0
max_val = 2000
n_steps = 2000

# Use the shared sweep helper instead of an inline copy of its body; it performs
# the same deepcopy -> override -> generate -> loss steps for every grid value.
loss_vals, param_range = calc_loss_vs_param_range(param_name, cell_index,
                                                   min_val=min_val, max_val=max_val, n_steps=n_steps)

plot_loss_vs_param(param_range, loss_vals, title='Loss vs Carrier Frequency')

KeyboardInterrupt: 

In [10]:
# Sweep the carrier frequency ('freq_c' of cell (0, 2)) with both cells forced
# active, collecting one loss value per frequency.
fm_freq_vals = np.linspace(0, 2000, 2000)

losses1 = []
for freq_val in fm_freq_vals:
    update_params = copy.deepcopy(target_param_dict)

    lfo_params = update_params[(1, 1)]['parameters']
    carrier_params = update_params[(0, 2)]['parameters']
    carrier_params['freq_c'] = freq_val
    lfo_params['active'] = [True]
    carrier_params['active'] = [True]
    carrier_params['fm_active'] = [True]

    synth.update_cells_from_dict(update_params)
    signal, _ = synth.generate_signal(signal_duration=1)

    target_signal_unsqueezed = target_signal.unsqueeze(dim=0)
    loss_val, _, _ = loss_handler.call(target_signal_unsqueezed, signal,
                                       step=0, return_spectrogram=False)

    loss_scalar = loss_val.detach().cpu().numpy().item()
    losses1.append(loss_scalar)


In [None]:
# matplotlib.use('Qt5Agg')



In [None]:
# Loss curve of the carrier-frequency sweep.
fig, ax = plt.subplots(figsize=(15, 5))
ax.plot(fm_freq_vals, losses1)
ax.set_title("Loss vs carrier frequency")
plt.show()

In [None]:
# Sweep the LFO frequency ('freq' of cell (1, 1)) with both cells forced
# active, collecting one loss value per frequency.
lfo_freq_vals = np.linspace(0, 20, 100)

losses2 = []
for freq_val in lfo_freq_vals:
    update_params = copy.deepcopy(target_param_dict)

    lfo_params = update_params[(1, 1)]['parameters']
    carrier_params = update_params[(0, 2)]['parameters']
    lfo_params['freq'] = freq_val
    lfo_params['active'] = [True]
    carrier_params['active'] = [True]
    carrier_params['fm_active'] = [True]

    synth.update_cells_from_dict(update_params)
    signal, _ = synth.generate_signal(signal_duration=1)

    target_signal_unsqueezed = target_signal.unsqueeze(dim=0)
    loss_val, _, _ = loss_handler.call(target_signal_unsqueezed, signal,
                                       step=0, return_spectrogram=False)

    loss_scalar = loss_val.detach().cpu().numpy().item()
    losses2.append(loss_scalar)

In [None]:
# Loss curve of the LFO-frequency sweep.
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(lfo_freq_vals, losses2)
ax.set_title("Loss vs LFO frequency")

plt.show()

In [None]:
# Sweep the carrier amplitude ('amp_c' of cell (0, 2)) with both cells forced
# active, collecting one loss value per amplitude.
amp_vals = np.linspace(0, 1, 100)

losses3 = []
for amp_val in amp_vals:
    update_params = copy.deepcopy(target_param_dict)

    lfo_params = update_params[(1, 1)]['parameters']
    carrier_params = update_params[(0, 2)]['parameters']
    carrier_params['amp_c'] = amp_val
    lfo_params['active'] = [True]
    carrier_params['active'] = [True]
    carrier_params['fm_active'] = [True]

    synth.update_cells_from_dict(update_params)
    signal, _ = synth.generate_signal(signal_duration=1)

    target_signal_unsqueezed = target_signal.unsqueeze(dim=0)
    loss_val, _, _ = loss_handler.call(target_signal_unsqueezed, signal,
                                       step=0, return_spectrogram=False)

    loss_scalar = loss_val.detach().cpu().numpy().item()
    losses3.append(loss_scalar)


In [None]:
# Loss curve of the carrier-amplitude sweep.
fig, ax = plt.subplots(figsize=(12, 5))
ax.plot(amp_vals, losses3)
ax.set_title("Loss vs carrier amplitude")
plt.show()

In [None]:
# Sweep the FM modulation index ('mod_index' of cell (0, 2)) with both cells
# forced active, collecting one loss value per index.
mod_index_vals = np.linspace(0, 0.3, 1000)

losses4 = []
for mod_index_val in mod_index_vals:
    update_params = copy.deepcopy(target_param_dict)

    lfo_params = update_params[(1, 1)]['parameters']
    carrier_params = update_params[(0, 2)]['parameters']
    carrier_params['mod_index'] = mod_index_val
    lfo_params['active'] = [True]
    carrier_params['active'] = [True]
    carrier_params['fm_active'] = [True]

    synth.update_cells_from_dict(update_params)
    signal, _ = synth.generate_signal(signal_duration=1)

    target_signal_unsqueezed = target_signal.unsqueeze(dim=0)
    loss_val, _, _ = loss_handler.call(target_signal_unsqueezed, signal,
                                       step=0, return_spectrogram=False)

    loss_scalar = loss_val.detach().cpu().numpy().item()
    losses4.append(loss_scalar)

In [None]:
# Loss curve of the modulation-index sweep.
fig, ax = plt.subplots(figsize=(20, 5))
ax.plot(mod_index_vals, losses4)
ax.set_title("Loss vs FM mod_index")
plt.show()

In [None]:
# One-element set holding the parameter name 'freq_c'; no other visible cell reads it.
sim_target_params = {'freq_c'}

In [None]:
# Side-by-side spectrograms: the target signal vs. the synth output with the
# carrier frequency overridden to 200.
spec_op = Spectrogram(n_fft=512)


fig, ax = plt.subplots(1, 2, figsize=(15, 5))

target_spec = spec_op(target_signal.cpu())

update_params = copy.deepcopy(target_param_dict)

# 'freq_c' belongs to the fm_saw cell at (0, 2). The target dict only contains
# cells (1, 1) and (0, 2), so the previous (0, 1) index raised KeyError.
update_params[(0, 2)]['parameters'].update({'freq_c': 200})
synth.update_cells_from_dict(update_params)
# NOTE(review): the sweep cells pass signal_duration=1 here; confirm the default
# duration matches the target signal length.
signal, _ = synth.generate_signal()

pred_spec = spec_op(signal.cpu().detach())

ax[0].imshow(target_spec.squeeze(), origin='lower')
ax[0].set_title('Target spectrogram')
ax[1].imshow(pred_spec.squeeze(), origin='lower')
ax[1].set_title('Predicted spectrogram (freq_c = 200)')