In [10]:
import os
import sys

import copy
import numpy as np
import torch

from pathlib import Path
from tqdm import tqdm

import matplotlib.pyplot as plt
%matplotlib notebook

import seaborn as sns
sns.set_style('whitegrid')

notebook_path = Path('.').resolve()
dir_list = notebook_path.parts
root_index = dir_list.index('diffmoog')
abs_path = Path(*dir_list[:root_index+1])

project_root = abs_path
sys.path.append(str(project_root.joinpath('src')))

from model.model import DecoderNetwork
from synth.parameters_normalizer import Normalizer
from model.loss import spectral_loss
from synth.synth_architecture import SynthModular
from synth.synth_constants import synth_constants
from utils.train_utils import to_torch_recursive
from utils.visualization_utils import calc_loss_vs_param_range


In [11]:
# Setup experiment
device = 'cuda:0'
chain = 'LFO_SAW'

signal_duration = 1.0
note_off_time = 0.75


synth = SynthModular(chain_name=chain,
                     synth_constants=synth_constants,
                     device=device)

decoder_net = DecoderNetwork(preset=chain, device=device)
normalizer = Normalizer(note_off_time, signal_duration, synth_constants)

AssertionError: Torch not compiled with CUDA enabled

In [None]:
%matplotlib inline

# Generate some signal to start from
sample_params = {(0, 0): {'operation': 'lfo', 'parameters': {'freq': 8, 
                                                             'waveform': 'sine',
                                                             'active': 1}},
                 (0, 1): {'operation': 'fm_sine', 'parameters': {'freq_c': 200,
                                                                 'amp_c': 0.5, 
                                                                 'mod_index': 0.15,
                                                                 'active': 1,
                                                                 'fm_active': 1}},
}

synth.update_cells_from_dict(sample_params)
synth.generate_signal(signal_duration=signal_duration, batch_size=1)
target_signal = synth.get_final_signal().to(device).squeeze()

_, ax = plt.subplots(1, 1, figsize=(20, 5))
ax.plot(target_signal.detach().cpu().numpy().squeeze())
plt.show()

In [None]:
target_param_dict = to_torch_recursive(sample_params, device, ignore_dtypes=(str, tuple))
target_params_01 = normalizer.normalize(target_param_dict)

parameters_to_freeze = {(0, 0): {'operation': 'lfo',
                                      'parameters': ['freq', 'waveform', 'active']},
                        (0, 1): {'operation': 'fm_sine',
                                 'parameters': ['active', 'fm_active', 'amp_c', 'mod_index']}}

# Decoder net will try to approximate non frozen params (here carrier frequency) by SGD on DDSP loss
decoder_net.apply_params(target_params_01)
decoder_net.freeze_params(parameters_to_freeze)

In [None]:
spec_loss_type = 'SPECTROGRAM'
loss_handler = spectral_loss.SpectralLoss(loss_type=spec_loss_type,
                                          loss_preset='cumsum_time',
                                          synth_constants=synth_constants,
                                          device=device)

params_loss_handler = torch.nn.MSELoss()

In [None]:
param_to_visualize = {'param_name': 'freq_c', 'cell_index': (0, 1), 'min_val': 0, 'max_val': 2000, 'n_steps': 2000}

loss_vals, param_range = calc_loss_vs_param_range(synth, target_param_dict, target_signal, loss_handler, **param_to_visualize)

In [None]:
# Run gradient descent. Try to play with different starting values, optimizers and learning rates to see the effect

num_epochs = 200
starting_frequency = [[0]]    # pre sigmoid value
decoder_net.apply_params_partial({(0, 1):
                                     {'operation': 'fm_sine',
                                      'parameters': {'freq_c': starting_frequency}
                                     }
                                 })

base_lr = 6e-3
optimizer = torch.optim.Adam(decoder_net.parameters(), lr=base_lr)

target_signal_unsqueezed = target_signal.unsqueeze(dim=0)

train_res = []
for e in range(num_epochs):
    
    predicted_params_01 = decoder_net.forward()

    predicted_params_full_range = normalizer.denormalize(predicted_params_01)
    
    synth.update_cells_from_dict(predicted_params_full_range)
    predicted_signal, _ = synth.generate_signal(signal_duration=1)

    loss, _, _= loss_handler.call(target_signal_unsqueezed, predicted_signal, step=0, return_spectrogram=False)

    optimizer.zero_grad()
    loss.backward(retain_graph=True)
    optimizer.step()
    
    predicted_freq = predicted_params_full_range[(0, 1)]['parameters']['freq_c']    
    train_res.append((e, predicted_freq.detach().item(), loss.detach().item()))

In [None]:
from matplotlib import animation
from IPython.display import HTML


fig, ax = plt.subplots(figsize=(15, 5))
l1, = ax.plot(param_range, loss_vals, 'o-', label='loss surface', markevery=[-1])
l2, = ax.plot([], [], 'o-', label='training progress')
ax.legend(loc='center right')
# ax.set_xlim(0,100)
# ax.set_ylim(0,1)

def animate(i):
    xi = [train_res[j][1] for j in range(i)]
    yi = [train_res[j][2] for j in range(i)]
    l2.set_data(xi, yi)
    return (l2)

a = animation.FuncAnimation(fig, animate, frames=num_epochs, interval=50)
HTML(a.to_jshtml())