In [1]:
import json
import os, glob

from itertools import product

import numpy as np
import pandas as pd

import torch

import matplotlib.pyplot as plt
import torchaudio

from torchaudio.functional.filtering import lowpass_biquad, highpass_biquad

from torchaudio.transforms import Spectrogram, SpectralCentroid

from matplotlib import rcParams

from collections import defaultdict

from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

import seaborn as sns

from model import helper
from model.model import SimpleSynthNetwork
from synth.synth_architecture import SynthModular
from config import SynthConfig, Config
from dataset.ai_synth_dataset import AiSynthDataset, create_data_loader

from synth.synth_modular_chains import synth_chains_dict

sns.set_style('whitegrid')

spectrogram = Spectrogram(n_fft=128)


In [2]:
device = 'cuda:0'

dataset_to_visualize = 'fm_filter_dataset'
split_to_visualize = 'train'
data_dir = os.path.join('data', dataset_to_visualize, split_to_visualize, '')

wav_files_dir = os.path.join(data_dir, 'wav_files', '')
params_csv_path = os.path.join(data_dir, 'params_dataset.pkl')

ai_synth_dataset = AiSynthDataset(params_csv_path, wav_files_dir, device)
test_dataloader = create_data_loader(ai_synth_dataset, 32, 4, shuffle=False)


FileNotFoundError: [Errno 2] No such file or directory: 'data\\fm_filter_dataset\\train\\params_dataset.pkl'

In [3]:
synth_cfg = SynthConfig()
cfg = Config()

transform = helper.mel_spectrogram_transform(cfg.sample_rate).to(device)
normalizer = helper.Normalizer(cfg.signal_duration_sec, synth_cfg)

synth_obj = SynthModular(synth_cfg=synth_cfg,
                         sample_rate=cfg.sample_rate,
                         signal_duration_sec=cfg.signal_duration_sec,
                         num_sounds=1,
                         device=device,
                         chain='FM_FILTER')

In [48]:
lfo_model_ckpt = r'E:\Users\elhara2\ai_synth\experiments\current\lfo_only_model_fm_filter_data\ckpts\trained_synth_net.pt'
fm_model_ckpt = r'E:\Users\elhara2\ai_synth\experiments\current\fm_only_model_fm_filter_data_2\ckpts\trained_synth_net.pt'
filter_model_ckpt = r'E:\Users\elhara2\ai_synth\experiments\current\filter_only_model_fm_filter_data\ckpts\trained_synth_net.pt'

lfo_model = SimpleSynthNetwork('LFO', synth_cfg, device, backbone='resnet').to(device)
fm_model = SimpleSynthNetwork('FM_ONLY', synth_cfg, device, backbone='resnet').to(device)
filter_model = SimpleSynthNetwork('FILTER_ONLY', synth_cfg, device, backbone='resnet').to(device)

lfo_model.load_state_dict(torch.load(lfo_model_ckpt))
fm_model.load_state_dict(torch.load(fm_model_ckpt))
filter_model.load_state_dict(torch.load(filter_model_ckpt))

lfo_model.eval()
fm_model.eval()
filter_model.eval()

SimpleSynthNetwork(
  (backbone): ResNet(
    (conv1): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=Tru

In [18]:
def discretize_params(operation: str, input_params: dict, synth_cfg):

    params_chain = synth_cfg.all_params_chains.get(operation, {})

    res = {}
    for param_name, param_values in input_params.items():

        if isinstance(param_values, torch.Tensor):
            param_values = param_values.detach().cpu().numpy()
        else:
            param_values = np.asarray(param_values)

        if param_name in ['waveform', 'filter_type']:

            if isinstance(param_values[0], str):
                res[param_name] = param_values
                continue

            idx = np.argmax(param_values, axis=1)
            if param_name == 'waveform':
                res[param_name] = [synth_cfg.wave_type_dic_inv[i] for i in idx]
            else:
                res[param_name] = [synth_cfg.filter_type_dic_inv[i] for i in idx]
            continue

        possible_values = params_chain.get(param_name, None)

        if possible_values is None:
            res[param_name] = param_values
            continue

        idx = np.searchsorted(possible_values, param_values, side="left")
        idx[idx == len(possible_values)] = len(possible_values) - 1
        idx[idx == 0] = 1

        if operation == 'fm' and param_name == 'freq_c':
            below_distance = (param_values / possible_values[idx - 1])
            above_distance = (possible_values[idx] / param_values)
        else:
            below_distance = np.abs(param_values - possible_values[idx - 1])
            above_distance = np.abs(param_values - possible_values[idx])

        idx = idx - (below_distance < above_distance)
        res[param_name] = possible_values[idx]

    return res

In [43]:
def compare_params(target_params, predicted_params):
    res = defaultdict(dict)
    for cell_idx, target_cell_data in target_params.items():

        if target_cell_data['operation'][0] == 'None':
            continue

        target_cell_params = target_cell_data['parameters']
        predicted_cell_params = predicted_params[cell_idx]['parameters']

        for param_name, target_param_values in target_cell_params.items():
            pred_param_values = np.asarray(predicted_cell_params[param_name]).squeeze()
            target_param_values = np.asarray(target_param_values).squeeze()

            assert len(target_param_values) == len(pred_param_values)

            correct_preds = np.sum(target_param_values == pred_param_values)

            res[cell_idx][param_name] = correct_preds

    return res

In [49]:
step, n_samples = 0, 0
results = defaultdict(int)
for target_signal, target_param_dict, signal_index in test_dataloader:

    step += 1

    target_signal = target_signal.to(device)
    transformed_signal = transform(target_signal)

    # -----------Run Model-----------------
    output_lfo_params = lfo_model(transformed_signal)
    output_fm_params = fm_model(transformed_signal)
    output_filter_params = filter_model(transformed_signal)

    output_params = {**output_lfo_params, **output_filter_params, **output_fm_params}

    denormalized_output_params = normalizer.denormalize(output_params)

    denormalized_discrete_output_params = {cell_idx: {'operation': cell_params['operation'], 'parameters': discretize_params(cell_params['operation'], cell_params['parameters'], synth_cfg)}
                                           for cell_idx, cell_params in denormalized_output_params.items()}

    discrete_target_params = {cell_idx: {'operation': cell_params['operation'], 'parameters': discretize_params(cell_params['operation'][0], cell_params['parameters'], synth_cfg)}
                              for cell_idx, cell_params in target_param_dict.items() if cell_params['operation'][0] != 'None'}

    correct_preds = compare_params(discrete_target_params, denormalized_discrete_output_params)

    for cell_idx, cell_data in correct_preds.items():
        for param_name, correct_preds in cell_data.items():
            results[f'{cell_idx}_{param_name}'] += correct_preds

    n_samples += len(target_signal)

    if step % 100 == 0:
        print(step)

100
200
300
400
500
600
700
800
900
1000
1100
1200
1300
1400
1500


In [50]:
for k, v in results.items():
    results[k] = v / n_samples

In [53]:
filter_model.train()
ps = filter_model(transformed_signal)

In [54]:
p = ps[(0, 2)]['parameters']['filter_freq'][0]

In [56]:
p = torch.floor(p * 16000)
p_s = p.squeeze()

In [62]:
dude = torch.tensor(8.0, dtype=torch.float, requires_grad=True)

In [59]:
p_lin = torch.linspace(0, 1, dude)

TypeError: linspace(): argument 'steps' (position 3) must be int, not Tensor