In [1]:
import os
import sys
# Make the project's `src` package importable. The entry is relative to the kernel's working
# directory, so this assumes the notebook runs two directory levels below the repository root.
project_name = 'ai_synth'
src_dir_relative_path = '../../src'
src_on_path = src_dir_relative_path in sys.path
if not src_on_path:
    sys.path.append(src_dir_relative_path)
import copy
import numpy as np
import torch

from pathlib import Path
from model.model import DecoderNetwork
from synth.parameters_normalizer import Normalizer
from model.loss import spectral_loss
from synth.synth_architecture import SynthModular
from main_hp_search_dec_only import configure_experiment
from dataset.ai_synth_dataset import AiSynthDataset
from synth.synth_constants import synth_constants
from utils.train_utils import to_torch_recursive
from utils.visualization_utils import calc_loss_vs_param_range
from tqdm import tqdm

import matplotlib.pyplot as plt
# NOTE(review): the animated loss plot at the end of this notebook redraws an already-shown
# figure, which needs an interactive backend such as this one. The classic "notebook" backend
# is reportedly unsupported in JupyterLab, where `%matplotlib widget` (ipympl) is the usual
# replacement — confirm for the target environment.
%matplotlib notebook

import seaborn as sns
# Seaborn "whitegrid" styling for the plots in this notebook.
sns.set_style('whitegrid')

# Resolve the repository root: truncate the kernel's resolved working directory at the first
# (outermost) path component named `project_name`, e.g. .../ai_synth/notebooks/x -> .../ai_synth.
notebook_path = Path('.').resolve()
dir_list = notebook_path.parts
if project_name not in dir_list:
    # Fail with an actionable message instead of tuple.index's bare "x not in tuple".
    raise ValueError(
        f"No '{project_name}' directory found in the working directory path {notebook_path}; "
        f"start the Jupyter kernel from inside the project checkout."
    )
root_index = dir_list.index(project_name)
abs_path = Path(*dir_list[:root_index + 1])

project_root = abs_path

In [2]:
# Experiment configuration; all paths are built relative to `project_root`.
exp_name, dataset_name = 'surrogate_osc_fm_sin', 'fm_sin_single'
device = 'cpu'

config_path = os.path.join(project_root, 'configs', 'optimization_analysis',
                           'lfo_sine_single_synth_config_hp_search.yaml')
# The trailing '' makes the joined directory path end with a separator.
data_path = os.path.join(project_root, 'data', dataset_name, 'train', '')

# NOTE(review): this cell printed "Deleting previous experiment..." (presumably from
# configure_experiment with debug=True), so re-running it likely clears earlier artifacts
# of `exp_name` — confirm before re-running.
cfg = configure_experiment(exp_name, dataset_name, config_path, debug=True)

synth = SynthModular(preset_name=cfg.synth.preset, synth_constants=synth_constants, device=device)
decoder_net = DecoderNetwork(preset=cfg.synth.preset, device=device)
normalizer = Normalizer(cfg.synth.note_off_time, cfg.synth.signal_duration, synth_constants)

Deleting previous experiment...


In [3]:
# Load the first training example, load its parameters into the decoder, freeze the parameters
# that stay fixed during optimization, and synthesize a signal from them.
dataset = AiSynthDataset(data_path, noise_std=0)
target_sample = dataset[0]
target_signal, target_param_dict, signal_index = target_sample

target_param_dict = to_torch_recursive(target_param_dict, device, ignore_dtypes=(str, tuple))
# Wrap the LFO (cell (1, 1)) waveform in a list.
# NOTE(review): presumably the synth expects a per-batch list of waveform names — confirm.
lfo_parameters = target_param_dict[(1, 1)]['parameters']
lfo_parameters['waveform'] = [lfo_parameters['waveform']]

print(f"target parameters full range: \n{target_param_dict}")

# Parameters held fixed during the optimization; freq_c of (0, 2) is deliberately left trainable.
parameters_to_freeze = {
    (1, 1): {'operation': 'lfo', 'parameters': ['freq', 'waveform', 'active']},
    (0, 2): {'operation': 'fm_sine', 'parameters': ['active', 'fm_active', 'amp_c', 'mod_index']},
}

target_params_01 = normalizer.normalize(target_param_dict)
decoder_net.apply_params(target_params_01)
decoder_net.freeze_params(parameters_to_freeze)

predicted_params_01 = decoder_net()
predicted_params_full_range = normalizer.denormalize(predicted_params_01)

synth.update_cells_from_dict(predicted_params_full_range)
# NOTE(review): generated_target_signal is not used by the later cells; the training loop
# compares against the dataset's `target_signal` instead.
generated_target_signal, _ = synth.generate_signal(signal_duration=1, batch_size=1)
target_signal = target_signal.to(device)

FileNotFoundError: [Errno 2] No such file or directory: 'C:\\Users\\noamk\\PycharmProjects\\ai_synth\\data\\fm_sin_single\\train\\params_dataset.pkl'

In [None]:
# NOTE(review): `cfg.multi_spectral_loss_spec_type` is not passed to SpectralLoss below (only
# cfg.loss.spec_loss_type and cfg.loss.preset are), so this assignment appears to have no
# effect on `loss_handler` — confirm whether cfg.loss.spec_loss_type was intended.
cfg.multi_spectral_loss_spec_type = 'SPECTROGRAM'

# Spectral loss between target and predicted signals, used by the loss sweep and training loop.
loss_handler = spectral_loss.SpectralLoss(
    loss_type=cfg.loss.spec_loss_type,
    loss_preset=cfg.loss.preset,
    synth_constants=synth_constants,
    device=device,
)

# NOTE(review): params_loss_handler is not used by any later cell in this notebook.
params_loss_handler = torch.nn.MSELoss()

In [None]:
# Evaluate the loss as a function of freq_c of cell (0, 2) over [min_val, max_val]; the result
# is plotted as the "loss surface" that the training trajectory is drawn on.
param_to_visualize = dict(param_name='freq_c', cell_index=(0, 2),
                          min_val=0, max_val=2000, n_steps=2000)

loss_vals, param_range = calc_loss_vs_param_range(
    synth, target_param_dict, target_signal, loss_handler, **param_to_visualize
)

In [None]:
# Re-initialize freq_c of cell (0, 2) to a fixed starting point and optimize the decoder's
# trainable parameters against `target_signal`. Of the (0, 2) fm_sine parameters, freq_c is
# the one not listed in `parameters_to_freeze`.
# NOTE(review): re-running this cell resets only freq_c; any other trainable decoder parameters
# keep their previously optimized values — re-run the target-setup cell for a clean start.
num_epochs = 200
# Passed to apply_params_partial; presumably a normalized [0, 1] value like the *_01 params
# given to apply_params — TODO confirm.
starting_frequency = [[0.5]]
decoder_net.apply_params_partial({(0, 2): {'operation': 'fm_sine',
                                           'parameters': {'freq_c': starting_frequency}}})

base_lr = 6e-3
optimizer = torch.optim.Adamax(decoder_net.parameters(), lr=base_lr)

# Loop invariants: the batched target and the parameter key are the same every epoch.
target_signal_unsqueezed = target_signal.to(device).unsqueeze(dim=0)
freq_c_key = decoder_net.get_key((0, 2), 'fm_sine', 'freq_c')

# One record per epoch: (epoch, predicted freq_c in full range, loss, snapshot of freq_c grad).
train_res = []
for e in tqdm(range(num_epochs), desc='optimizing freq_c'):

    predicted_params_01 = decoder_net.forward()

    predicted_params_full_range = normalizer.denormalize(predicted_params_01)
    predicted_freq = predicted_params_full_range[(0, 2)]['parameters']['freq_c']

    synth.update_cells_from_dict(predicted_params_full_range)
    # NOTE(review): duration hard-coded to 1 while the normalizer was built with
    # cfg.synth.signal_duration — confirm they agree.
    predicted_signal, _ = synth.generate_signal(signal_duration=1)

    loss, _, _ = loss_handler.call(target_signal_unsqueezed, predicted_signal, step=e)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # `.grad` is a live tensor. With zero_grad(set_to_none=False) (the default before
    # PyTorch 2.0) it is zeroed and refilled in place every epoch, so storing the reference
    # would make every record alias the latest gradient; keep a detached copy instead.
    freq_c_grad = decoder_net.parameters_dict[freq_c_key].weight.grad
    grad_snapshot = None if freq_c_grad is None else freq_c_grad.detach().clone()
    train_res.append((e, predicted_freq.item(), loss.item(), grad_snapshot))

In [None]:
from matplotlib import animation

# Static loss surface from the parameter sweep, plus an initially empty line that the
# animation fills with the training trajectory (predicted freq_c vs. loss per epoch).
fig, ax = plt.subplots(figsize=(15, 5))
l1, = ax.plot(param_range, loss_vals, 'o-', label='loss surface', markevery=[-1])
l2, = ax.plot([], [], 'o-', label='training progress')
ax.set_xlabel(f"{param_to_visualize['param_name']} (full range)")
ax.set_ylabel('loss')
ax.set_title(f"Loss vs {param_to_visualize['param_name']} of cell {param_to_visualize['cell_index']}")
ax.legend(loc='center right')

def animate(i):
    """Update the 'training progress' line to show epochs 0..i.

    Reads the module-level `train_res` records (epoch, predicted value, loss, grad) and draws
    predicted value vs. loss on the module-level line `l2`.

    :param i: frame index supplied by FuncAnimation.
    :return: tuple of updated artists, as FuncAnimation requires when blitting.
    """
    # Include epoch i itself so the final frame shows the last epoch; slicing also tolerates
    # frame indices beyond the recorded epochs (e.g. an interrupted training run).
    shown = train_res[:i + 1]
    xi = [record[1] for record in shown]
    yi = [record[2] for record in shown]
    l2.set_data(xi, yi)
    return (l2,)

# Keep a reference: matplotlib stops an animation whose object gets garbage-collected.
# One frame per recorded epoch, so a partially completed training run still animates.
loss_animation = animation.FuncAnimation(fig, animate, frames=len(train_res), interval=50)