In [21]:
import sys
sys.path.append("..")

from glob import glob
import matplotlib.pyplot as plt
import ipywidgets as ipw
from IPython.display import Audio
import numpy as np
import pickle
import yaml
import pandas as pd

from lib.notebooks import plot_groups_metrics
from synthesizer import Synthesizer

In [2]:
# Run directories to compare. glob() returns entries in arbitrary order,
# so sort them to make every downstream table and plot deterministic.
synthesizers_path = sorted(glob("../out/synthesizer/*/"))

In [24]:
def _hidden_layers_label(hidden_layers):
    """Summarise a hidden-layer width list as "<depth>x<first layer width>".

    Only the first layer's width is shown, so [512, 256] renders as "2x512".
    An empty list (no hidden layers) renders as "0x0" instead of raising
    IndexError on ``hidden_layers[0]``.
    """
    if not hidden_layers:
        return "0x0"
    return f"{len(hidden_layers)}x{hidden_layers[0]}"


# group label -> {run directory -> metrics dict}; consumed by plot_groups_metrics.
groups_metrics = {}
# One record per run; turned into the `synthesizers_loss` DataFrame below.
loss_records = []

for synthesizer_path in synthesizers_path:
    synthesizer = Synthesizer.reload(synthesizer_path, load_nn=False)
    config = synthesizer.config
    # NOTE: pickle.load can execute arbitrary code; only point this at
    # trusted, locally produced training outputs.
    with open("%s/metrics.pickle" % synthesizer_path, "rb") as f:
        metrics = pickle.load(f)
    # NOTE(review): despite the "final_loss" column name, this is the *best*
    # (minimum) validation loss seen during training, not the last epoch's
    # value. The name is kept because later cells sort on it.
    final_loss = min(metrics["validation"]["total"])

    dataset_name = config['dataset']['name']
    hidden_layers = _hidden_layers_label(config['model']['hidden_layers'])

    loss_records.append({
        "dataset": dataset_name,
        "hidden_layers": hidden_layers,
        "dropout_p": config['model']['dropout_p'],
        "learning_rate": config['training']['learning_rate'],
        "final_loss": final_loss,
    })

    # Runs sharing a dataset and architecture are plotted together.
    group_name = f"{dataset_name}\nhidden_layers={hidden_layers}"
    groups_metrics.setdefault(group_name, {})[synthesizer_path] = metrics

# Explicit columns keep the schema (and the "final_loss" column sorted on
# later) even when no runs were found.
synthesizers_loss = pd.DataFrame(
    loss_records,
    columns=["dataset", "hidden_layers", "dropout_p", "learning_rate", "final_loss"],
)

In [29]:
# Ten runs with the lowest recorded validation loss (best first).
synthesizers_loss.sort_values(by="final_loss", ascending=True).iloc[:10]

Unnamed: 0,dataset,hidden_layers,dropout_p,learning_rate,final_loss
76,pb2007,4x512,0.074519,0.00162,0.50051
60,pb2007,4x512,0.050272,0.000462,0.500915
24,pb2007,4x512,0.11974,0.000419,0.502189
93,pb2007,4x512,0.124299,0.000413,0.502256
13,pb2007,4x512,0.113911,0.0005,0.502425
96,pb2007,4x512,0.010576,0.000728,0.502916
43,pb2007,4x512,0.018956,0.000691,0.503541
70,pb2007,4x512,0.155062,0.000254,0.503624
55,pb2007,4x512,0.076532,0.001144,0.503779
75,pb2007,4x512,0.123428,0.000297,0.50386


In [15]:
# Metric keys handed to plot_groups_metrics.
metrics_name = ["total"]


def show_metrics(split_name="validation"):
    """Render the grouped metric plots for one data split.

    Thin wrapper over ``plot_groups_metrics`` so it can drive an ipywidgets
    dropdown; ``split_name`` is expected to be "train" or "validation".
    """
    plot_groups_metrics(groups_metrics, metrics_name, split_name)


display(
    ipw.interactive(
        show_metrics,
        split_name=["train", "validation"],
    )
)

interactive(children=(Dropdown(description='split_name', index=1, options=('train', 'validation'), value='vali…