In [2]:
import copy
import json
import os, glob

from itertools import product

import numpy as np
import pandas as pd

import torch
# %matplotlib inline

import torchaudio

from torchaudio.functional.filtering import lowpass_biquad, highpass_biquad
from torchaudio.transforms import Spectrogram, MelSpectrogram, Resample

from matplotlib import rcParams

from collections import defaultdict

from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

from src.model.loss import spectral_loss
from synth.synth_architecture import SynthModular
# from config import SynthConfig
from src.main_hp_search_dec_only import configure_experiment

# from config import configure_experiment
from dataset.ai_synth_dataset import AiSynthDataset
from synth.synth_constants import synth_constants

%matplotlib inline
import matplotlib
# matplotlib.use('TkAgg')
import matplotlib.pyplot as plt


import seaborn as sns
sns.set_style('whitegrid')

ModuleNotFoundError: No module named 'src'

In [None]:
# Setup experiment: build the experiment config, the synth and the training dataset.

exp_name = 'del_visualization'
# NOTE(review): the config file and the data directory below are both named
# 'lfo_saw_single' while this says 'fm_saw_single' - confirm which one is intended.
dataset_name = 'fm_saw_single'

# Root of the local ai_synth checkout. Defaults to the original author's path; set the
# AI_SYNTH_ROOT environment variable to run the notebook from a different checkout.
project_root = os.environ.get('AI_SYNTH_ROOT', r'C:\Users\noamk\PycharmProjects\ai_synth')
config_name = os.path.join(project_root, 'configs', 'lfo_saw_single_synth_config_hp_search.yaml')
train_data_dir = os.path.join(project_root, 'data', 'lfo_saw_single', 'train')

# Use the GPU when one is available, otherwise fall back to the CPU so the notebook
# still runs on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

cfg = configure_experiment(exp_name, dataset_name, config_name, debug=True)

synth = SynthModular(preset_name=cfg.synth.preset,
                     synth_constants=synth_constants,
                     device=device)

dataset = AiSynthDataset(train_data_dir, noise_std=0)



In [None]:
# Use the first dataset example as the target. `dataset`, `synth` and `device` come
# from the setup cell; the dataset is not rebuilt here because the setup cell already
# constructs it with the same arguments.
target_sample = dataset[0]
target_signal, target_param_dict, signal_index = target_sample

target_signal = target_signal.to(device)

# Load the target's parameters into the synth and print them for reference.
synth.update_cells_from_dict(target_param_dict)
print(target_param_dict)

In [None]:
# NOTE(review): this attribute is not passed to SpectralLoss below, which is configured
# from cfg.loss.* - whether anything else reads it cannot be told from this notebook.
cfg.multi_spectral_loss_spec_type = 'SPECTROGRAM'

# Build the spectral loss from the loss section of the experiment config.
loss_cfg = cfg.loss
loss_handler = spectral_loss.SpectralLoss(
    loss_type=loss_cfg.spec_loss_type,
    loss_preset=loss_cfg.preset,
    synth_constants=synth_constants,
    device=device,
)

In [None]:
# Inspect the loss handler's `spectrogram_ops` attribute (shown as the cell output).
loss_handler.spectrogram_ops

In [None]:
fm_freq_vals = np.linspace(0, 1200, 1200)

losses1 = []
for freq_val in fm_freq_vals:
    update_params = copy.deepcopy(target_param_dict)

    update_params[(0, 2)]['parameters'].update({'freq_c': freq_val})
    update_params[(1, 1)]['parameters'].update({'active': [True]})
    update_params[(0, 2)]['parameters'].update({'active': [True]})
    update_params[(0, 2)]['parameters'].update({'fm_active': [True]})
    synth.update_cells_from_dict(update_params)
    signal, _ = synth.generate_signal(signal_duration=1)

    # resampled_target_signal = resample_op(target_signal.cpu())
    # resampled_pred_signal = resample_op(signal.cpu())

    # plt.plot(resampled_target_signal.detach().numpy().squeeze())
    # plt.plot(resampled_pred_signal.detach().numpy().squeeze())
    #
    # plt.show()
    target_signal_unsqueezed = target_signal.unsqueeze(dim=0)
    loss_val, _, _ = loss_handler.call(target_signal_unsqueezed, signal, step=0, return_spectrogram=False)

    # loss_val = loss_handler.call(resampled_target_signal, resampled_pred_signal.unsqueeze(0), signal_chain_index=0, global_step=0, summary_writer=None, log=False)

    losses1.append(loss_val.detach().cpu().numpy().item())

print(losses1)


In [None]:
# matplotlib.use('Qt5Agg')



In [None]:
plt.figure(figsize=(20, 5))
plt.plot(fm_freq_vals, losses1)
plt.title("Loss vs carrier frequency")
plt.show()

In [None]:
lfo_freq_vals = np.linspace(0, 15, 100)

losses2 = []
for freq_val in lfo_freq_vals:
    update_params = copy.deepcopy(target_param_dict)

    update_params[(1, 1)]['parameters'].update({'freq': freq_val})
    update_params[(1, 1)]['parameters'].update({'active': [True]})
    update_params[(0, 2)]['parameters'].update({'active': [True]})
    update_params[(0, 2)]['parameters'].update({'fm_active': [True]})
    synth.update_cells_from_dict(update_params)
    signal, _ = synth.generate_signal(signal_duration=1)

    # resampled_target_signal = resample_op(target_signal.cpu())
    # resampled_pred_signal = resample_op(signal.cpu())

    # plt.plot(resampled_target_signal.detach().numpy().squeeze())
    # plt.plot(resampled_pred_signal.detach().numpy().squeeze())
    #
    # plt.show()
    target_signal_unsqueezed = target_signal.unsqueeze(dim=0)
    loss_val, _, _ = loss_handler.call(target_signal_unsqueezed, signal, step=0, return_spectrogram=False)

    # loss_val = loss_handler.call(resampled_target_signal, resampled_pred_signal.unsqueeze(0), signal_chain_index=0, global_step=0, summary_writer=None, log=False)

    losses2.append(loss_val.detach().cpu().numpy().item())

print(losses2)

In [None]:
plt.figure(figsize=(20, 5))
plt.plot(lfo_freq_vals, losses2)
plt.title("Loss vs LFO frequency")

plt.show()

In [None]:
amp_vals = np.linspace(0, 1, 100)

losses3 = []
for amp_val in amp_vals:
    update_params = copy.deepcopy(target_param_dict)

    update_params[(0, 2)]['parameters'].update({'amp_c': amp_val})
    update_params[(1, 1)]['parameters'].update({'active': [True]})
    update_params[(0, 2)]['parameters'].update({'active': [True]})
    update_params[(0, 2)]['parameters'].update({'fm_active': [True]})
    synth.update_cells_from_dict(update_params)
    signal, _ = synth.generate_signal(signal_duration=1)

    # resampled_target_signal = resample_op(target_signal.cpu())
    # resampled_pred_signal = resample_op(signal.cpu())

    # plt.plot(resampled_target_signal.detach().numpy().squeeze())
    # plt.plot(resampled_pred_signal.detach().numpy().squeeze())
    #
    # plt.show()
    target_signal_unsqueezed = target_signal.unsqueeze(dim=0)
    loss_val, _, _ = loss_handler.call(target_signal_unsqueezed, signal, step=0, return_spectrogram=False)

    # loss_val = loss_handler.call(resampled_target_signal, resampled_pred_signal.unsqueeze(0), signal_chain_index=0, global_step=0, summary_writer=None, log=False)

    losses3.append(loss_val.detach().cpu().numpy().item())

print(losses3)


In [None]:
plt.figure(figsize=(20, 5))
plt.plot(amp_vals, losses3)
plt.title("Loss vs carrier amplitude")
plt.show()

In [None]:
mod_index_vals = np.linspace(0, 0.3, 1000)

losses4 = []
for mod_index_val in mod_index_vals:
    update_params = copy.deepcopy(target_param_dict)

    update_params[(0, 2)]['parameters'].update({'mod_index': mod_index_val})
    update_params[(1, 1)]['parameters'].update({'active': [True]})
    update_params[(0, 2)]['parameters'].update({'active': [True]})
    update_params[(0, 2)]['parameters'].update({'fm_active': [True]})
    synth.update_cells_from_dict(update_params)
    signal, _ = synth.generate_signal(signal_duration=1)

    # resampled_target_signal = resample_op(target_signal.cpu())
    # resampled_pred_signal = resample_op(signal.cpu())

    # plt.plot(resampled_target_signal.detach().numpy().squeeze())
    # plt.plot(resampled_pred_signal.detach().numpy().squeeze())
    #
    # plt.show()
    target_signal_unsqueezed = target_signal.unsqueeze(dim=0)
    loss_val, _, _ = loss_handler.call(target_signal_unsqueezed, signal, step=0, return_spectrogram=False)

    # loss_val = loss_handler.call(resampled_target_signal, resampled_pred_signal.unsqueeze(0), signal_chain_index=0, global_step=0, summary_writer=None, log=False)

    losses4.append(loss_val.detach().cpu().numpy().item())

print(losses4)

In [None]:
plt.figure(figsize=(20, 5))
plt.plot(mod_index_vals, losses4)
plt.title("Loss vs FM mod_index")
plt.show()

In [None]:
# NOTE(review): not referenced by any cell visible in this notebook - confirm it is still needed.
sim_target_params = {'freq_c'}

In [None]:
# Side-by-side spectrograms of the target signal and of a synth output rendered
# from the target parameters with 'freq_c' overridden.
spec_op = Spectrogram(n_fft=512)

update_params = copy.deepcopy(target_param_dict)

# NOTE(review): the sweep cells set 'freq_c' on cell (0, 2) while this cell uses (0, 1),
# and generate_signal() is called here without signal_duration - confirm both are intended.
update_params[(0,1)]['parameters'].update({'freq_c': 200})
synth.update_cells_from_dict(update_params)
signal, _ = synth.generate_signal()

# Spectrogram runs on CPU tensors here; the prediction is detached from autograd first.
target_spec = spec_op(target_signal.cpu())
pred_spec = spec_op(signal.cpu().detach())

# Create the figure only after the signals exist, so a failure above does not leave an
# empty figure behind. The size matches the 20-inch width of the other plots.
fig, ax = plt.subplots(1, 2, figsize=(20, 10))

ax[0].imshow(target_spec.squeeze(), origin='lower')
ax[0].set_title("Target spectrogram")
ax[1].imshow(pred_spec.squeeze(), origin='lower')
ax[1].set_title("Predicted spectrogram (freq_c = 200)")
plt.show()