In [1]:
import sys
sys.path.append('../../../src')
import copy
import numpy as np
import torch

from model.model import DecoderOnlyNetwork
from synth.parameters_normalizer import Normalizer
from model.loss import spectral_loss
from synth.synth_architecture import SynthModular
from main_hp_search_dec_only import configure_experiment
from dataset.ai_synth_dataset import AiSynthDataset
from synth.synth_constants import synth_constants
from tqdm import tqdm

%matplotlib notebook
import matplotlib
# matplotlib.use('TkAgg')
import matplotlib.pyplot as plt

import seaborn as sns
sns.set_style('whitegrid')

In [2]:
# Setup experiment

exp_name = 'training_visualization_fm_saw'
dataset_name = 'lfo_saw_single'
config_name = r'C:\Users\noamk\PycharmProjects\ai_synth\configs\optimization_analysis\lfo_saw_single_synth_config_hp_search.yaml'
device = 'cuda'

cfg = configure_experiment(exp_name, dataset_name, config_name, debug=True)

synth = SynthModular(preset_name=cfg.synth.preset,
                     synth_constants=synth_constants,
                     device=device)

decoder_net = DecoderOnlyNetwork(preset=cfg.synth.preset, device=device)
normalizer = Normalizer(cfg.synth.note_off_time, cfg.synth.signal_duration, synth_constants)

Deleting previous experiment...


In [3]:
def calc_loss_vs_param_range(param_name, cell_index, min_val, max_val, n_steps):
    param_range = np.linspace(min_val, max_val, n_steps)

    loss_vals = []
    for param_val in tqdm(param_range):
        update_params = copy.deepcopy(target_param_dict)

        update_params[cell_index]['parameters'].update({param_name: param_val})

        synth.update_cells_from_dict(update_params)
        signal, _ = synth.generate_signal(signal_duration=1)

        target_signal_unsqueezed = target_signal.unsqueeze(dim=0)
        loss_val, _, _ = loss_handler.call(target_signal_unsqueezed, signal, step=0, return_spectrogram=False)

        loss_vals.append(loss_val.detach().cpu().numpy().item())

    return loss_vals, param_range

def plot_signal(signal, title):
    plt.figure(figsize=(15, 5))
    plt.title(title)
    plt.plot(signal.detach().numpy().squeeze())
    plt.show()

def plot_loss_vs_param(param_range, loss_vals, title):
    plt.figure(figsize=(15, 5))
    plt.plot(param_range, loss_vals)
    plt.title(title)
    plt.show()

def param_dict_floats_to_tensors(param_dict):

    param_dict_as_tensors = {}
    for key in param_dict.keys():
        param_dict_as_tensors[key] = {'operation': param_dict[key]['operation'],
                                        'parameters': {}}
        for param, val in param_dict[key]['parameters'].items():
            param_dict_as_tensors[key]['parameters'][param] = torch.tensor([val], device=device)

    return param_dict_as_tensors

In [4]:
dataset = AiSynthDataset(fr'C:\Users\noamk\PycharmProjects\ai_synth\data\{dataset_name}\train', noise_std=0)
target_sample = dataset[0]
target_signal, target_param_dict, signal_index = target_sample

print(f"target parameters full range: \n{target_param_dict}")

target_param_dict = {(1, 1): {'operation': 'lfo',
                            'parameters': {'active': torch.tensor([[-1000.0]], device=device),
                                           'output': [[(-1, -1)]],
                                           'freq': torch.tensor([14.285357442943784], device=device),
                                           'waveform': torch.tensor([[0., 1000.0, 0.]], device=device)}},
                   (0, 2): {'operation': 'fm_saw',
                            'parameters': {'fm_active': torch.tensor([[-1000.0]], device=device),
                                           'active': torch.tensor([[-1000.0]], device=device),
                                           'amp_c': torch.tensor([0.6187255599871848], device=device),
                                           'freq_c': torch.tensor([349.22823143300377], device=device),
                                           'mod_index': torch.tensor([0.02403950683025824], device=device)}}}

parameters_to_freeze = {(1, 1): {'operation': 'lfo',
                                      'parameters': ['freq', 'waveform', 'active']},
                        (0, 2): {'operation': 'fm_saw',
                                 'parameters': ['active', 'fm_active', 'amp_c', 'mod_index']}}

target_params_01 = normalizer.normalize(target_param_dict)
print(f"target parameters 0-1: \n{target_params_01}")

# plot target signal
# plot_signal(target_signal, title="Target Signal")

decoder_net.apply_params(target_params_01)
decoder_net.freeze_params(parameters_to_freeze)

predicted_params_01 = decoder_net()
# print(f"decoder_net output: \n {predicted_params_01}")

predicted_params_full_range = normalizer.denormalize(predicted_params_01)
# print(f"decoder_net output full range: \n {predicted_params_full_range}")

synth.update_cells_from_dict(predicted_params_full_range)
generated_target_signal, _ = synth.generate_signal(signal_duration=1, batch_size=1)

# print predicted signal
# plot_signal(target_signal, title="DecoderNet Output Signal")
target_signal = target_signal.to(device)

target parameters full range: 
{(1, 1): {'operation': 'lfo', 'parameters': {'active': True, 'output': [(0, 2)], 'freq': 6.431552842983031, 'waveform': 'sine'}}, (0, 2): {'operation': 'fm_saw', 'parameters': {'fm_active': True, 'active': True, 'amp_c': 0.7246405936607198, 'freq_c': 1975.5332050244983, 'mod_index': 0.05109589980450436}}}
target parameters 0-1: 
{(1, 1): {'operation': 'lfo', 'parameters': {'active': tensor([[-1000.]], device='cuda:0'), 'output': [[(-1, -1)]], 'freq': tensor([0.9507], device='cuda:0'), 'waveform': tensor([[   0., 1000.,    0.]], device='cuda:0')}}, (0, 2): {'operation': 'fm_saw', 'parameters': {'fm_active': tensor([[-1000.]], device='cuda:0'), 'active': tensor([[-1000.]], device='cuda:0'), 'amp_c': tensor([0.6187], device='cuda:0'), 'freq_c': tensor([0.1523], device='cuda:0'), 'mod_index': tensor([0.1560], device='cuda:0')}}}
Missing amp param in Oscillator module lfo. Assuming fixed amp. Please check Synth structure if this is unexpected.


  SimpleWeightLayer(torch.tensor(init_values['freq'],
  SimpleWeightLayer(torch.tensor(init_values['waveform'],
  SimpleWeightLayer(torch.tensor(init_values[fm_param], dtype=torch.float, device=self.device,


In [5]:
cfg.multi_spectral_loss_spec_type = 'SPECTROGRAM'
loss_handler = spectral_loss.SpectralLoss(loss_type=cfg.loss.spec_loss_type,
                                          loss_preset=cfg.loss.preset,
                                          synth_constants=synth_constants,
                                          device=device)

params_loss_handler = torch.nn.MSELoss()

In [6]:
from tqdm import tqdm

fm_freq_vals = np.linspace(0, 2000, 20000)

losses = []
for freq_val in tqdm(fm_freq_vals):
    update_params = copy.deepcopy(target_param_dict)

    update_params[(0, 2)]['parameters'].update({'freq_c': freq_val})

    synth.update_cells_from_dict(update_params)
    generated_signal, _ = synth.generate_signal(signal_duration=1)

    target_signal_unsqueezed = target_signal.to(device).unsqueeze(dim=0)
    loss_val, _, _ = loss_handler.call(target_signal_unsqueezed, generated_signal, step=0, return_spectrogram=False)

    losses.append(loss_val.detach().cpu().numpy().item())

plot_loss_vs_param(fm_freq_vals, losses, 'Loss vs Frequency')

100%|██████████| 20000/20000 [02:17<00:00, 145.04it/s]


<IPython.core.display.Javascript object>

In [7]:
num_epochs = 200
starting_frequency = 0.1
decoder_net.apply_params_partial({(0, 2):
                                     {'operation': 'fm_saw',
                                      'parameters': {'freq_c': starting_frequency}
                                     }
                                 })

base_lr = 1e-3
optimizer = torch.optim.Adamax(decoder_net.parameters())

train_res = []
for e in range(num_epochs):

    predicted_params_01 = decoder_net.forward()
    # print(f"decoder_net output: \n {predicted_params_01}")

    predicted_params_full_range = normalizer.denormalize(predicted_params_01)
    predicted_freq = predicted_params_full_range[(0, 2)]['parameters']['freq_c']
    # print(f"decoder_net output full range: \n {predicted_params_full_range}")

    synth.update_cells_from_dict(predicted_params_full_range)
    predicted_signal, _ = synth.generate_signal(signal_duration=1)

    # plt.figure(figsize=(25, 10))
    # plt.plot(target_signal.detach().numpy().squeeze())
    # plt.plot(predicted_signal.cpu().detach().numpy().squeeze())
    # plt.show()

    target_signal_unsqueezed = target_signal.to(device).unsqueeze(dim=0)
    loss, _, _ = loss_handler.call(target_signal_unsqueezed, predicted_signal, step=e)
    # print(f"loss: {loss.item()}")

    # loss = params_loss_handler(predicted_params_01[(0, 2)]['parameters']['freq_c'], torch.tensor(0.15).cuda())

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    print('freq', predicted_freq.item())

    train_res.append((e, predicted_freq.item(), loss.item(), decoder_net.parameters_dict[decoder_net.get_key((0, 2), 'fm_saw', 'freq_c')].weight.grad))
    print(decoder_net.parameters_dict[decoder_net.get_key((0, 2), 'fm_saw', 'freq_c')].weight.grad)


freq 229.3004608154297
tensor(39.5935, device='cuda:0')
freq 224.71446228027344
tensor(1.5746, device='cuda:0')
freq 222.4438934326172
tensor(-78.8555, device='cuda:0')
freq 223.41748046875
tensor(92.3737, device='cuda:0')
freq 222.67340087890625
tensor(-8.8997, device='cuda:0')
freq 222.21844482421875
tensor(-26.5227, device='cuda:0')
freq 222.14181518554688
tensor(-68.6458, device='cuda:0')
freq 222.7350311279297
tensor(-27.7829, device='cuda:0')
freq 223.4677276611328
tensor(-45.0405, device='cuda:0')
freq 224.4482879638672
tensor(-134.7128, device='cuda:0')
freq 225.7187042236328
tensor(-28.6752, device='cuda:0')
freq 226.94747924804688
tensor(84.6857, device='cuda:0')
freq 227.60348510742188
tensor(-73.3760, device='cuda:0')
freq 228.508056640625
tensor(113.7407, device='cuda:0')
freq 228.7920379638672
tensor(57.3952, device='cuda:0')
freq 228.79324340820312
tensor(45.3267, device='cuda:0')
freq 228.60374450683594
tensor(-83.5941, device='cuda:0')
freq 228.78074645996094
tensor(18

In [9]:
%matplotlib notebook
from matplotlib import animation

fig, ax = plt.subplots(figsize=(25, 5))
l1, = ax.plot(fm_freq_vals, losses, 'o-', label='loss surface', markevery=[-1])
l2, = ax.plot([], [], 'o-', label='training progress')
ax.legend(loc='center right')
# ax.set_xlim(0,100)
# ax.set_ylim(0,1)

def animate(i):
    xi = [train_res[j][1] for j in range(i)]
    yi = [train_res[j][2] for j in range(i)]
    l2.set_data(xi, yi)
    return (l2)

animation.FuncAnimation(fig, animate, frames=num_epochs, interval=500)

<IPython.core.display.Javascript object>

<matplotlib.animation.FuncAnimation at 0x281ace491c0>