# Task A: Pitch detection
---

In [None]:
import dataset.aGPTset.ExpressiveGuitarTechniquesDataset as agptset
import os
import librosa
from sklearn.metrics import mean_absolute_error, mean_squared_error
import numpy as np
import matplotlib.pyplot as plt
import multiprocessing as mp
import pandas as pd
# import torchcrepe
import crepe

import am24utils
from am24utils import Run


# Load the dataset tables once; later cells read dataset['files_df'] and dataset['noteLabels_df'].
dataset = agptset.import_db()

DOTEST = False               # when True, only the first 5 filtered files are processed
VERBOSE = False              # gate for printVerbose and default verbosity of detect_pitch_from_filename
DB_PATH = 'dataset/aGPTset'  # root that the 'full_audiofile_path' entries are joined onto


def printVerbose(x):
    """Print ``x`` only when the module-level ``VERBOSE`` flag is true; always return None."""
    if VERBOSE:
        print(x)

In [None]:

print("Task 1: Performing Pitch Detection with dataset onset labels and perturbations")

# Filter the db to keep only pitched notes
def filter_by_pitched_notes(dataset, max_pitch_midi=108):
    """Keep only the pitched, non-improvised notes of the dataset.

    A file is kept when its name (an entry of ``dataset['files_df'].index``)
    contains ``'pitched'`` but not ``'impro'``. The notes of those files are
    then kept only when ``pitch_midi`` is not NaN and, cast to int, is strictly
    below ``max_pitch_midi``.

    Args:
        dataset: dict with a ``'files_df'`` DataFrame indexed by filename and a
            ``'noteLabels_df'`` DataFrame whose MultiIndex holds the filename
            on level 1 and which has a ``'pitch_midi'`` column.
        max_pitch_midi: exclusive upper bound on the MIDI pitch. The default of
            108 (MIDI C8, the 88th piano key) reproduces the original filter.

    Returns:
        ``(pitched_notes_df, filtered_files_df)``: the retained note rows, and
        the ``files_df`` rows of the files that still have at least one note,
        sorted by filename.

    Raises:
        AssertionError: if a level-0 index value ``'natural'`` survives the
            filtering.
    """
    files_df = dataset['files_df']
    notes_df = dataset['noteLabels_df']

    # All filenames that contain "pitched", excluding the improvisations.
    pitched_files = [name for name in files_df.index.tolist()
                     if 'pitched' in name and 'impro' not in name]

    # Level 1 of the note MultiIndex holds the filename.
    pitched_notes_df = notes_df.loc[notes_df.index.get_level_values(1).isin(pitched_files)]

    # Drop notes without a pitch label, then those at/above the pitch cap.
    pitched_notes_df = pitched_notes_df[pitched_notes_df['pitch_midi'].notna()]
    pitched_notes_df = pitched_notes_df[pitched_notes_df['pitch_midi'].astype(int) < max_pitch_midi]
    assert 'natural' not in pitched_notes_df.index.get_level_values(0).tolist(), "There are natural notes in the dataset"

    filtered_filenames = sorted(set(pitched_notes_df.index.get_level_values(1).tolist()))
    filtered_files_df = files_df.loc[filtered_filenames]

    return pitched_notes_df, filtered_files_df


print("Filtering the Dataset...")
# Both tables become module globals used by every later cell.
(filtered_notes_db,
 filtered_files_db) = filter_by_pitched_notes(dataset)
print("Done.")

In [None]:
# Perform pitch detection by loading a file and running multiple pitch detectors on windows of size window_size_samples after each onset label
def _pitch_window(y, onset, window_size_samples):
    """Return ``y[onset:onset + window_size_samples]`` with ``onset`` clamped at 0.

    A negative onset (presumably produced by onset perturbation) would
    otherwise be interpreted by Python slicing as an offset from the END of
    ``y``; for any signal longer than the window that yields an empty array.
    """
    start = max(int(onset), 0)
    return y[start:start + window_size_samples]


def _average_pitch_track(pitch, onset, detector):
    """Collapse a per-frame pitch track (Hz) into one Python float via ``np.nanmean``.

    Prints a warning when every frame is NaN (the returned value is then NaN).

    Raises:
        AssertionError: if the mean is not a NumPy floating scalar.
    """
    if np.isnan(pitch).all():
        print(f"Warning: all {len(pitch)} values of pitch are NaN for onset {onset}")
    rpitch = np.nanmean(pitch)
    assert isinstance(rpitch, np.floating), f"{detector} returned a value that is a {type(rpitch)} not a float"
    return rpitch.item()


def detect_pitch_from_filename(filename, onset_list, window_size_samples, detectors = ('librosa.yin','librosa.pyin','crepe'), verbose = VERBOSE):
    """Estimate one pitch (Hz) per onset of an audio file with one or more detectors.

    For every onset the window ``y[onset:onset + window_size_samples]`` is
    analysed (negative onsets are clamped to 0), and the detector's per-frame
    estimates are averaged, ignoring NaN, into a single value.

    Args:
        filename: audio file path; it is loaded at its native sample rate,
            which must be 48 kHz.
        onset_list: onset positions in samples.
        window_size_samples: analysis window length in samples.
        detectors: any of ``'librosa.yin'``, ``'librosa.pyin'``, ``'crepe'``,
            or ``'all'`` to run every detector.
        verbose: print per-onset debug output.

    Returns:
        dict mapping each requested detector name to a list of floats (Hz),
        one per onset in ``onset_list`` order. Values may be NaN.

    Raises:
        AssertionError: if the sample rate is not 48000, or yin/pyin return an
            unexpected number of frames (e.g. for a window truncated by the
            end of the file).
    """
    printVerbose = lambda x: print(x) if verbose else None
    # Load the file at its native sample rate
    y, sr = librosa.load(filename, sr=None)
    assert sr == 48000

    run_all = 'all' in detectors
    fmin, fmax = librosa.note_to_hz('C1'), librosa.note_to_hz('C7')
    # yin/pyin use frame_length = W//2, center=False and librosa's default
    # hop_length = frame_length // 4, which yields 5 frames per full window.
    frame_length = window_size_samples // 2
    totres = {}

    if run_all or 'librosa.yin' in detectors:
        res = []
        printVerbose("Running librosa.yin...")
        for onset in onset_list:
            window = _pitch_window(y, onset, window_size_samples)
            pitch = librosa.yin(window, fmin=fmin, fmax=fmax, sr=sr,
                                frame_length=frame_length, center=False)
            printVerbose(f"librosa.yin: onset = {onset}, pitch = {pitch}")
            assert len(pitch) == 5, f"librosa.yin was expected to return 5 frames, but returned {len(pitch)}"
            res.append(_average_pitch_track(pitch, onset, 'librosa.yin'))
        totres['librosa.yin'] = res

    if run_all or 'librosa.pyin' in detectors:
        res = []
        printVerbose("Running librosa.pyin...")
        for onset in onset_list:
            window = _pitch_window(y, onset, window_size_samples)
            pyin_out = librosa.pyin(window, fmin=fmin, fmax=fmax, sr=sr,
                                    frame_length=frame_length, center=False)
            printVerbose(f"librosa.pyin: onset = {onset}, pitch = {pyin_out}")
            assert len(pyin_out) == 3, f"librosa.pyin was expected to return 3 values, but returned {len(pyin_out)} values"
            # Only the first output (the f0 track) is used.
            pitch = pyin_out[0]
            assert len(pitch) == 5, f"librosa.pyin was expected to return 5 frames, but returned {len(pitch)}"
            res.append(_average_pitch_track(pitch, onset, 'librosa.pyin'))
        totres['librosa.pyin'] = res

    if run_all or 'crepe' in detectors:
        res = []
        printVerbose("Running crepe...")
        for onset in onset_list:
            window = _pitch_window(y, onset, window_size_samples)
            # Only the frequency track (2nd output) is used.
            _, pitch, _, _ = crepe.predict(window, sr, viterbi=True, verbose=0)
            printVerbose(f"crepe: onset = {onset}, pitch = {pitch}")
            res.append(_average_pitch_track(pitch, onset, 'crepe'))
        totres['crepe'] = res

    return totres

In [None]:
def detect_pitch_on_dataset(filtered_files_db, filtered_notes_db, window_size_samples=2048, detectors_to_use = ('crepe',), multiprocessing = True, perturbation_distribution=None, max_perturbation_samples=0, min_perturbation_samples=None):
    """Detect one pitch per labelled note in every file and pair it with its ground truth.

    Args:
        filtered_files_db: file table indexed by filename, with a
            ``'full_audiofile_path'`` column.
        filtered_notes_db: note table with the filename on index level 1 and
            ``'onset_label_samples'`` / ``'pitch_midi'`` columns.
        window_size_samples: analysis window forwarded to
            ``detect_pitch_from_filename``.
        detectors_to_use: detector names forwarded as ``detectors``.
        multiprocessing: when True, files are processed in parallel in an
            ``mp.Pool``; otherwise sequentially in this process.
        perturbation_distribution: when not None, each file's onset list is
            passed through ``am24utils.apply_onset_perturbation`` with this
            distribution before detection.
        max_perturbation_samples: forwarded to ``apply_onset_perturbation``.
        min_perturbation_samples: forwarded to ``apply_onset_perturbation``;
            None (default) reuses ``max_perturbation_samples``, which is the
            value this function always passed before the parameter existed.

    Returns:
        dict with
          - ``'y_true'``: {detector: ground-truth MIDI pitch of every note of every file}
          - ``'y_pred'``: {detector: detected pitch converted Hz -> MIDI, same order}
          - ``'window_size_samples'`` and ``'perturbation_distribution'``
          - ``'max_perturbation_samples'``, only when a perturbation was applied
    """
    # Per-file audio paths, onset lists (samples) and ground-truth MIDI pitches.
    onset_lists, filenames, ground_truth_pitches = get_onsetlist_filenames_ground_truth_pitches(filtered_notes_db, filtered_files_db)

    if perturbation_distribution is not None:
        if min_perturbation_samples is None:
            min_perturbation_samples = max_perturbation_samples
        onset_lists = [
            am24utils.apply_onset_perturbation(onset_list, perturbation_distribution, max_perturbation_samples, min_perturbation_samples)
            for onset_list in onset_lists
        ]

    if multiprocessing:
        # The context manager terminates the pool even if a worker raises;
        # every result is collected before leaving the block.
        with mp.Pool(mp.cpu_count()) as pool:
            async_results = [
                pool.apply_async(
                    detect_pitch_from_filename,
                    args=(filename, onset_list, window_size_samples),
                    kwds={'detectors': detectors_to_use})
                for filename, onset_list in zip(filenames, onset_lists)
            ]
            detected_pitches = [r.get() for r in async_results]
    else:
        detected_pitches = [
            detect_pitch_from_filename(filename, onset_list, window_size_samples, detectors=detectors_to_use)
            for filename, onset_list in zip(filenames, onset_lists)
        ]

    assert len(detected_pitches) == len(filenames)
    assert len(detected_pitches) == len(ground_truth_pitches)

    y_true = {}
    y_pred = {}
    for filename, cur_detected_pitches, cur_ground_truth in zip(filenames, detected_pitches, ground_truth_pitches):
        assert isinstance(cur_detected_pitches, dict), f"detected_pitches is not a dict, but a {type(cur_detected_pitches)} instead"
        for key, pitches_hz in cur_detected_pitches.items():
            assert isinstance(key, str), "detected_pitches key is not a string"
            assert isinstance(pitches_hz, list), f"detected_pitches[{key}] is not a list"
            assert len(pitches_hz) > 0, f"detected_pitches[{key}] has len == {len(pitches_hz)}"
            assert len(pitches_hz) == len(cur_ground_truth), (
                f"{key} returned {len(pitches_hz)} pitches for {len(cur_ground_truth)} notes in {filename}")
            # Append (never reset) so the lists cover the notes of ALL files,
            # not just the last one processed.
            y_true.setdefault(key, []).extend(cur_ground_truth)
            y_pred.setdefault(key, []).extend(librosa.hz_to_midi(hz) for hz in pitches_hz)

    func_res = {'y_true':y_true, 'y_pred':y_pred, 'window_size_samples':window_size_samples, 'perturbation_distribution':perturbation_distribution}
    if perturbation_distribution is not None:
        func_res['max_perturbation_samples'] = max_perturbation_samples
    return func_res

In [None]:
# Files that still have at least one note after filtering, in sorted order.
filtered_filenames = sorted(set(filtered_notes_db.index.get_level_values(1).tolist()))
print(f"Unique files filtered for this task = {len(filtered_filenames)}")

if DOTEST:
    # Smoke-test mode: keep only the first few files.
    filtered_filenames = filtered_filenames[:5]
    print(f"REDUCED TO {len(filtered_filenames)} FOR TESTING PURPOSES")

filtered_files_db = dataset['files_df'].loc[filtered_filenames]

## Perform the pitch detection
print("Performing Pitch Detection...")



In [None]:
def get_onsetlist_filenames_ground_truth_pitches(filtered_notes_db:pd.DataFrame, filtered_files_db:pd.DataFrame):
    """Collect, per audio file, its path, its note onsets and their ground-truth MIDI pitches.

    Files are visited in the row order of ``filtered_files_db``. For each file,
    its notes are the rows of ``filtered_notes_db`` whose index level 1 equals
    the file's index label, in their order within ``filtered_notes_db``.

    Args:
        filtered_notes_db: note table with the filename on index level 1 and
            ``'onset_label_samples'`` and ``'pitch_midi'`` columns.
        filtered_files_db: file table indexed by filename, with a
            ``'full_audiofile_path'`` column relative to ``DB_PATH``.

    Returns:
        ``(onset_lists, filenames, ground_truth_pitches)``: three parallel lists
        with one entry per file. These are a list of int onsets (samples), the
        audio path joined onto ``DB_PATH``, and a list of float MIDI pitches.

    Raises:
        AssertionError: if an audio file does not exist on disk, or a file has
            no rows in ``filtered_notes_db``.
    """
    note_filenames = filtered_notes_db.index.get_level_values(1)
    # Built once so the per-file membership check is O(1) rather than a list scan.
    known_filenames = set(note_filenames.tolist())

    filenames = []
    onset_lists = []
    ground_truth_pitches = []
    for index, row in filtered_files_db.iterrows():
        filename = os.path.join(DB_PATH, row['full_audiofile_path'])
        printVerbose(f"filename = {filename}")
        assert os.path.exists(filename), f"File {filename} does not exist"
        assert index in known_filenames, f"File {index} not in filtered_notes_db"

        # Select this file's notes once and read both columns from that slice.
        file_notes = filtered_notes_db.loc[note_filenames == index]
        onset_list = [int(onset) for onset in file_notes['onset_label_samples'].tolist()]
        printVerbose(f"onset_list ({len(onset_list)}) = {onset_list}")
        ground_truth_pitches_cur = [float(el) for el in file_notes['pitch_midi'].tolist()]

        filenames.append(filename)
        onset_lists.append(onset_list)
        ground_truth_pitches.append(ground_truth_pitches_cur)
    return onset_lists, filenames, ground_truth_pitches

# Per-file onset lists, audio paths and ground-truth pitches, bundled for run_taskA.
onset_lists, filenames, ground_truth_pitches = get_onsetlist_filenames_ground_truth_pitches(
    filtered_notes_db,
    filtered_files_db,
)
assert len(onset_lists) == len(filenames) == len(ground_truth_pitches), "The lists have different lengths"
packedData = onset_lists, filenames, ground_truth_pitches

In [None]:
# def run_taskA(runs, packedData, classifier='librosa.yin'):
#     onsetlist_list,filenames_list,pitch_list = packedData
#     assert len(onsetlist_list) == len(filenames_list) == len(pitch_list), f"Different length of lists: onsetlist_list={len(onsetlist_list)}, filenames_list={len(filenames_list)}, pitch_list={len(pitch_list)}"
#     for ridx,run in enumerate(runs):
#         print('Running task A for Run:%s [%i,%i]'%(run.name,ridx+1,len(runs)), end='\r')
#         print('+--%s--Arguments--------------+'%(run.name))
#         print('| Window size: %i'%run.window_size_samples)
#         print('| Onset perturbation distribution: %s'%run.onset_perturbation_distribution)
#         print('| Onset perturbation max samples: %i'%run.onset_perturbation_max_samples)
#         print('| Onset perturbation min samples: %i'%run.onset_perturbation_min_samples)
#         print('+-------------------------------------+')
        
#         y_true, y_pred = [],[]
#         errors_list = []
#         for idx, file in enumerate(filenames_list):
#             cur_onsetlist = onsetlist_list[idx]
#             cur_y_true = pitch_list[idx]
            
#             # Apply onset perturbation
#             if run.onset_perturbation_distribution is not None:
#                 # print('Applying onset perturbation to file %s'%cur_filename)
#                 cur_onsetlist = am24utils.apply_onset_perturbation(cur_onsetlist, run.onset_perturbation_distribution, run.onset_perturbation_max_samples, run.onset_perturbation_min_samples)

#             # Perform the pitch detection
#             cur_y_pred = detect_pitch_from_filename(file, cur_onsetlist, run.window_size_samples, detectors=[classifier], verbose=False)[classifier]
            
#             # Convert the detected pitches to MIDI
#             cur_y_pred = librosa.hz_to_midi(cur_y_pred)
            
#             y_true_cur = np.array(cur_y_true)
#             y_pred_cur = np.array(cur_y_pred)
        
#             y_pred_cur = np.nan_to_num(y_pred_cur, nan=100000) # replace Nan with inf

#             # Compute absolute errors
#             cur_errors = np.abs(y_true_cur - y_pred_cur)
                                
#             # Add to the list of errors
#             errors_list.extend(cur_errors)
#             # Append to the list of true and predicted values
#             y_true.extend(y_true_cur)
#             y_pred.extend(y_pred_cur)
            
#         # Compute the metrics
#         y_true = np.array(y_true)
#         y_pred = np.array(y_pred)
#         mae = mean_absolute_error(y_true, y_pred)
#         print('-MEAN ABSOLUTE ERROR %f-\n'%(mae))
        
#         # Append run results
#         run.results = {'mae':mae, 'errors':errors_list}
        

In [None]:
def run_taskA(runs, packedData,classifier='librosa.yin'):
    """Evaluate ``classifier`` on every Run configuration and attach the metrics to each Run.

    For every run, ``detect_pitch_on_dataset`` is called on the module-level
    ``filtered_files_db`` / ``filtered_notes_db`` tables with the run's window
    size and onset-perturbation settings. MAE and MSE between ground-truth and
    detected MIDI pitches are printed and stored in ``run.results``, together
    with ``y_true``, ``y_pred`` and the per-note absolute ``errors``.

    NOTE(review): ``packedData`` is unpacked only to check that its three lists
    have equal length; the detection itself re-derives onsets from the global
    tables. Also, ``run.results`` is rebuilt for each detector key, so with
    several keys only the last would survive. Only one key is produced here,
    because ``detectors_to_use=[classifier]``.
    """
    onsetlist_list, filenames_list, pitch_list = packedData
    assert len(onsetlist_list) == len(filenames_list) == len(pitch_list), f"Different length of lists: onsetlist_list={len(onsetlist_list)}, filenames_list={len(filenames_list)}, pitch_list={len(pitch_list)}"

    total_runs = len(runs)
    for run_number, run in enumerate(runs, start=1):
        # Progress line; end='\r' lets the header below overwrite it in a terminal.
        print('Running task A for Run:%s [%i,%i]'%(run.name, run_number, total_runs), end='\r')
        print('+--%s--Arguments--------------+'%(run.name))
        print('| Window size: %i'%run.window_size_samples)
        print('| Onset perturbation distribution: %s'%run.onset_perturbation_distribution)
        print('| Onset perturbation max samples: %i'%run.onset_perturbation_max_samples)
        print('| Onset perturbation min samples: %i'%run.onset_perturbation_min_samples)
        print('+-------------------------------------+')

        results = detect_pitch_on_dataset(
            filtered_files_db,
            filtered_notes_db,
            window_size_samples=run.window_size_samples,
            detectors_to_use=[classifier],
            multiprocessing=True,
            perturbation_distribution=run.onset_perturbation_distribution,
            max_perturbation_samples=run.onset_perturbation_max_samples,
        )

        true_by_detector = results['y_true']
        pred_by_detector = results['y_pred']
        for key in true_by_detector:
            y_true = true_by_detector[key]
            y_pred = pred_by_detector[key]
            errors = np.abs(np.array(y_true) - np.array(y_pred))
            mae = mean_absolute_error(y_true, y_pred)
            mse = mean_squared_error(y_true, y_pred)
            print(f"MAE for {key} = {mae}")
            print(f"MSE for {key} = {mse}")
            run.results = {'y_true': y_true, 'y_pred': y_pred, 'errors': errors, 'mae': mae, 'mse': mse}
                
            
            

In [None]:
import am24utils
from am24utils import Run

# Run configurations evaluated with CREPE; run_taskA stores each run's
# metrics on the Run object itself (run.results).
to_run_NN = am24utils.get_run_list()
for run in to_run_NN:
    print(run.name)
run_taskA(to_run_NN, packedData=packedData, classifier='crepe')


In [None]:
# A fresh set of Run configurations, this time evaluated with librosa.yin;
# metrics land in each run's .results.
to_run = am24utils.get_run_list()
for run in to_run:
    print(run.name)
run_taskA(to_run, packedData=packedData, classifier='librosa.yin')

In [None]:
# allruns = {}

# # import am24utils
# # from am24utils import Run

# # to_run = am24utils.get_run_list()

# # for ridx,run in enumerate(to_run):
# #     print(run.name)

# # run_taskA(to_run, packedData=packedData)

# # # allruns['512--noP'] = detect_pitch_on_dataset(filtered_files_db, filtered_notes_db, window_size_samples=512, detectors_to_use=['librosa.yin', 'crepe'])
# allruns['1024-noP'] = detect_pitch_on_dataset(filtered_files_db, filtered_notes_db, window_size_samples=1024, detectors_to_use=['librosa.yin'])
# allruns['2048-noP'] = detect_pitch_on_dataset(filtered_files_db, filtered_notes_db, window_size_samples=2048, detectors_to_use=['librosa.yin'])
# allruns['4096-noP'] = detect_pitch_on_dataset(filtered_files_db, filtered_notes_db, window_size_samples=4096, detectors_to_use=['librosa.yin'])

# # # allruns['512--unip1024'] = detect_pitch_on_dataset(filtered_files_db, filtered_notes_db, window_size_samples=512, detectors_to_use=['librosa.yin'], perturbation_distribution='uniform', max_perturbation_samples=1024)
# allruns['1024-unip1024'] = detect_pitch_on_dataset(filtered_files_db, filtered_notes_db, window_size_samples=1024, detectors_to_use=['librosa.yin'], perturbation_distribution='normal', max_perturbation_samples=1024)
# allruns['2048-unip1024'] = detect_pitch_on_dataset(filtered_files_db, filtered_notes_db, window_size_samples=2048, detectors_to_use=['librosa.yin'], perturbation_distribution='normal', max_perturbation_samples=1024)
# allruns['4096-unip1024'] = detect_pitch_on_dataset(filtered_files_db, filtered_notes_db, window_size_samples=4096, detectors_to_use=['librosa.yin'], perturbation_distribution='normal', max_perturbation_samples=1024)

# # # allruns['512--unip2048'] = detect_pitch_on_dataset(filtered_files_db, filtered_notes_db, window_size_samples=512, detectors_to_use=['librosa.yin, 'crepe''], perturbation_distribution='uniform', max_perturbation_samples=2048)
# # allruns['1024-unip2048'] = detect_pitch_on_dataset(filtered_files_db, filtered_notes_db, window_size_samples=1024, detectors_to_use=['librosa.yin', 'crepe'], perturbation_distribution='uniform', max_perturbation_samples=2048)
# # allruns['2048-unip2048'] = detect_pitch_on_dataset(filtered_files_db, filtered_notes_db, window_size_samples=2048, detectors_to_use=['librosa.yin', 'crepe'], perturbation_distribution='uniform', max_perturbation_samples=2048)
# # allruns['4096-unip2048'] = detect_pitch_on_dataset(filtered_files_db, filtered_notes_db, window_size_samples=4096, detectors_to_use=['librosa.yin', 'crepe'], perturbation_distribution='uniform', max_perturbation_samples=2048)

# # # allruns['512--unip512'] = detect_pitch_on_dataset(filtered_files_db, filtered_notes_db, window_size_samples=512, detectors_to_use=['librosa.yin', 'crepe'], perturbation_distribution='uniform', max_perturbation_samples=512)
# # allruns['1024-unip0512'] = detect_pitch_on_dataset(filtered_files_db, filtered_notes_db, window_size_samples=1024, detectors_to_use=['librosa.yin', 'crepe'], perturbation_distribution='uniform', max_perturbation_samples=512)
# # allruns['2048-unip0512'] = detect_pitch_on_dataset(filtered_files_db, filtered_notes_db, window_size_samples=2048, detectors_to_use=['librosa.yin', 'crepe'], perturbation_distribution='uniform', max_perturbation_samples=512)
# # allruns['4096-unip0512'] = detect_pitch_on_dataset(filtered_files_db, filtered_notes_db, window_size_samples=4096, detectors_to_use=['librosa.yin', 'crepe'], perturbation_distribution='uniform', max_perturbation_samples=512)


# print("Done.")

In [None]:
# import pickle, datetime

# resdir_path = os.path.join('results','task-A','date_%s'%(datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")))
# os.makedirs(resdir_path)

# bakfilename  = 'taskA_Yin_results.pickle'

# with open(os.path.join(resdir_path,bakfilename), 'wb') as f:
#     pickle.dump(to_run, f)

In [None]:
# for run in to_run:
#     run.results['error'] = run.results['errors']

In [None]:
# am24utils.plot_runs(to_run, arg_metric = 'error', arg_plottype = 'box')
# plt.savefig(os.path.join(resdir_path,'error_yin.png'))
# plt.savefig(os.path.join(resdir_path,'error_yin.pdf'))

In [None]:
# import am24utils
# from am24utils import Run

# to_run_crepe = am24utils.get_run_list()

# for ridx,run in enumerate(to_run_crepe):
#     print(run.name)

# run_taskA(to_run_crepe, packedData=packedData, classifier='crepe')
# print("Done.")

In [None]:
import pickle, datetime


# Output folder named after the current local time. os.makedirs raises if it
# already exists (e.g. when this cell is re-run within the same second).
resdir_path = os.path.join('results','task-A','date_%s'%(datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")))
os.makedirs(resdir_path)

# resdir_path = "results/task-A/date_2024-05-02_16-33-13"

bakfilename  = 'taskA_CREPE_results.pickle'

# Persist the CREPE Run objects; their .results were filled by run_taskA.
with open(os.path.join(resdir_path,bakfilename), 'wb') as f:
    pickle.dump(to_run_NN, f)
    
    
# Box plot of each run's 'errors' entry. plt.savefig writes the *current*
# matplotlib figure, so this assumes plot_runs draws on it (TODO confirm in am24utils).
am24utils.plot_runs(to_run_NN, 'errors', 'box')
plt.savefig(os.path.join(resdir_path,'errors_crepe.png'))
plt.savefig(os.path.join(resdir_path,'errors_crepe.pdf'))

In [None]:
# Save the YIN runs into the results folder created in the previous cell (resdir_path).
bakfilename  = 'taskA_YIN_results.pickle'

with open(os.path.join(resdir_path,bakfilename), 'wb') as f:
    pickle.dump(to_run, f)
    
    
# NOTE(review): if plot_runs does not open a new figure, these boxes may be
# drawn onto the figure left over from the previous cell — confirm in am24utils.
am24utils.plot_runs(to_run, 'errors', 'box')
plt.savefig(os.path.join(resdir_path,'errors_yin.png'))
plt.savefig(os.path.join(resdir_path,'errors_yin.pdf'))

In [None]:
# # %matplotlib qt
# alltoplot = []
# alltoplotlabels = []
# for run in to_run:
#     y_true = run.results['y_true']
#     y_pred = run.results['y_pred']
#     window_size_samples = run.window_size_samples

#     errors = {}
#     error_labels = []
#     for key in y_true:
#         y_true_cur = np.array(y_true)
#         y_pred_cur = np.array(y_pred)
        
#         y_pred_cur = np.nan_to_num(y_pred_cur, nan=100000) # replace Nan with inf

#         print("Mean Absolute Error(",run.name,"",key,") = %.1f"% mean_absolute_error(y_true_cur, y_pred_cur))
#         # print("Mean Squared Error(",run,") = ", mean_squared_error(y_true_cur, y_pred_cur))

#         # Plot error with boxplot
#         errors[key] = abs(y_true_cur - y_pred_cur)
#         error_labels.append(run.name)
#         #error_labels.append(run.name+' '+key+' '+str(window_size_samples)+' ~%.2fms'%(window_size_samples/48))

#     error_lists = [errors[key] for key in errors]



#     alltoplot.append(error_lists)
#     alltoplotlabels.append(error_labels)

# # # Group alltoplot and alltoplotlabels and sort by alltoplotlabels
# # alltoplot = [x for _, x in sorted(zip(alltoplotlabels, alltoplot))]
# # alltoplotlabels = sorted(alltoplotlabels)

# plt.figure(figsize=(3*len(alltoplot), 7))
# # Grouped boxplot since alltoplot is a list of lists
# xpositions = np.arange(len(alltoplotlabels))
# for idx in range(len(alltoplotlabels)):
#     for idx_method in range(len(alltoplot[idx])):
#         assert len(alltoplot[idx]) == len(alltoplotlabels[idx]), f"{len(alltoplot[idx])} {len(alltoplotlabels[idx])}"
#         plt.boxplot(alltoplot[idx][idx_method], positions=[xpositions[idx]+idx_method*0.5],labels=[alltoplotlabels[idx][idx_method].replace(' ', '\n')])
#     # plt.boxplot(alltoplot[idx], positions=[xpositions[idx]], labels=[s.replace(' ', '\n') for s in alltoplotlabels[idx]])
        
#         # plt.show()  
# plt.ylabel('Error (MIDI note)')
# plt.title('Error in MIDI note')
# #plt.ylim(0, 2)

# plt.show()  

# # plt.boxplot(alltoplot,labels=alltoplotlabels)
# # plt.ylabel('Error (MIDI note)')
# # plt.title('Error in MIDI note (run: %s, window size = %i)'%(run, allruns[run]['window_size_samples']))

# # plt.show()

In [None]:
# for run in to_run:
#     print(run.name)
#     print(run.results['y_pred'])

In [None]:
# # %matplotlib qt
# alltoplot = []
# alltoplotlabels = []
# for run in sorted(list(allruns.keys())):
#     y_true = allruns[run]['y_true']
#     y_pred = allruns[run]['y_pred']
#     window_size_samples = allruns[run]['window_size_samples']

#     errors = {}
#     error_labels = []
#     for key in y_true:
#         y_true_cur = np.array(y_true[key])
#         y_pred_cur = np.array(y_pred[key])
        
#         y_pred_cur = np.nan_to_num(y_pred_cur, nan=100000) # replace Nan with inf

#         print("Mean Absolute Error(",run,"",key,") = %.1f"% mean_absolute_error(y_true_cur, y_pred_cur))
#         # print("Mean Squared Error(",run,") = ", mean_squared_error(y_true_cur, y_pred_cur))

#         # Plot error with boxplot
#         errors[key] = abs(y_true_cur - y_pred_cur)
#         error_labels.append(run+' '+key+' '+str(window_size_samples)+' ~%.2fms'%(window_size_samples/48))

#     error_lists = [errors[key] for key in errors]



#     alltoplot.append(error_lists)
#     alltoplotlabels.append(error_labels)

# # # Group alltoplot and alltoplotlabels and sort by alltoplotlabels
# # alltoplot = [x for _, x in sorted(zip(alltoplotlabels, alltoplot))]
# # alltoplotlabels = sorted(alltoplotlabels)

# plt.figure(figsize=(3*len(alltoplot), 7))
# # Grouped boxplot since alltoplot is a list of lists
# xpositions = np.arange(len(alltoplotlabels))
# for idx in range(len(alltoplotlabels)):
#     for idx_method in range(len(alltoplot[idx])):
#         assert len(alltoplot[idx]) == len(alltoplotlabels[idx]), f"{len(alltoplot[idx])} {len(alltoplotlabels[idx])}"
#         plt.boxplot(alltoplot[idx][idx_method], positions=[xpositions[idx]+idx_method*0.5],labels=[alltoplotlabels[idx][idx_method].replace(' ', '\n')])
#     # plt.boxplot(alltoplot[idx], positions=[xpositions[idx]], labels=[s.replace(' ', '\n') for s in alltoplotlabels[idx]])
        
#         # plt.show()  
# plt.ylabel('Error (MIDI note)')
# plt.title('Error in MIDI note')
# #plt.ylim(0, 2)

# plt.show()  

# # plt.boxplot(alltoplot,labels=alltoplotlabels)
# # plt.ylabel('Error (MIDI note)')
# # plt.title('Error in MIDI note (run: %s, window size = %i)'%(run, allruns[run]['window_size_samples']))

# # plt.show()
