In [34]:
# !pip install plotly

In [36]:
import glob
import numpy as np
import pandas as pd
import plotly.express as px

In [13]:
files = glob.glob('./grid_search/*.txt')
accuracies_list = []

# Go through all the files
for file in files:
    with open(file) as f:
        data_point = {}
        for line in f:

            # stop if we run into the line that starts the predicted states print
            if '~' not in line:
                if line == '\n':
                    pass
                else:
                    # save the params for this point
                    info = line.strip('\n').split(': ')
                    data_point[f'{info[0]}']= f'{info[1]}'
            else:
                break
        accuracies_list.append(data_point)

print(len(accuracies_list))
print(accuracies_list[0])

75
{'Accuracy': '0.4491618947086279', 'n_chords': '1', 'n_melody': '0', 't_prior': '1037.5', 'e_prior': '1037.5'}


In [52]:
accuracies_df = pd.DataFrame(accuracies_list)
accuracies_df['Accuracy'] = accuracies_df['Accuracy'].astype(float).round(4)
accuracies_df['n_chords'] = accuracies_df['n_chords']
accuracies_df['n_melody'] = accuracies_df['n_melody']
accuracies_df['t_prior'] = accuracies_df['t_prior']
accuracies_df['e_prior'] = accuracies_df['e_prior']
accuracies_df.style.background_gradient('viridis', 
                                        vmin=accuracies_df['Accuracy'].min(),
                                        vmax=accuracies_df['Accuracy'].max())

Unnamed: 0,Accuracy,n_chords,n_melody,t_prior,e_prior
0,0.4492,1,0,1037.5,1037.5
1,0.4198,2,0,300.0,0.0
2,0.4429,1,0,1037.5,2025.0
3,0.4304,1,1,30.0,7.5
4,0.4304,2,0,160.0,150.0
5,0.4205,1,0,4000.0,3012.5
6,0.4132,2,0,300.0,225.0
7,0.4356,2,0,160.0,0.0
8,0.4351,2,0,20.0,300.0
9,0.4263,1,0,3012.5,3012.5


In [49]:
fig = px.scatter_3d(accuracies_df, 
              x='n_chords', 
              y='n_melody', 
              z='Accuracy',
              height=800,)
fig.update_xaxes(type='category')
fig.update_yaxes(type='category')
fig.show()

In [68]:
t_prior_labels = accuracies_df['t_prior'].astype(float).value_counts().sort_index().astype(str).index.astype(str).to_list()
e_prior_labels = accuracies_df['e_prior'].astype(float).value_counts().sort_index().astype(str).index.astype(str).to_list()
print(t_prior_labels)
print(type(t_prior_labels[0]))
print(e_prior_labels)
print(type(e_prior_labels[0]))

['0.0', '7.5', '15.0', '20.0', '22.5', '30.0', '50.0', '90.0', '160.0', '230.0', '300.0', '1037.5', '2025.0', '3012.5', '4000.0']
<class 'str'>
['0.0', '7.5', '15.0', '22.5', '30.0', '50.0', '75.0', '150.0', '225.0', '300.0', '1037.5', '2025.0', '3012.5', '4000.0']
<class 'str'>


In [74]:
px.scatter_3d(accuracies_df, 
              x='t_prior', 
              y='e_prior', 
              z='Accuracy',
              hover_name='Accuracy',
              hover_data=['t_prior', 
                          'e_prior', 
                          'n_chords',
                          'n_melody'],
              category_orders={'t_prior': t_prior_labels,
                               'e_prior':e_prior_labels},
              height=800)

In [75]:
px.scatter_3d(accuracies_df[accuracies_df['Accuracy']>0.32], 
              x='t_prior', 
              y='e_prior', 
              z='Accuracy',
              hover_name='Accuracy',
              hover_data=['t_prior', 
                          'e_prior', 
                          'n_chords',
                          'n_melody'],
              category_orders={'t_prior': t_prior_labels,
                'e_prior':e_prior_labels},
              height=800)

In [72]:
accuracies_df.loc[accuracies_df['Accuracy'].idxmax()]

Accuracy    0.4597
n_chords         1
n_melody         0
t_prior     3012.5
e_prior       50.0
Name: 53, dtype: object

In [None]:
print(accuracies_df['Accuracy'].max().round(2))
print(np.floor(accuracies_df['Accuracy'].min()))

0.46
0.0


In [101]:
percs = list(range(int(np.floor(accuracies_df['Accuracy'].min())),
           int(accuracies_df['Accuracy'].max().round(2)*100),
           1
           )
    )
total_len = len(accuracies_df)

previous_perc_above = -1
print('Grid Search Accuracies')
print('---------------------')
for perc in percs[::-1]:
    perc = perc/100.0
    # print(perc)
    perc_len = len(accuracies_df[accuracies_df['Accuracy'] >= perc])
    perc_above = round(perc_len/total_len, 3)

    if perc_above != previous_perc_above:
        print(f"% of grid points >= ({perc:.2f}):\t {round(perc_above, 3)}")
        previous_perc_above = perc_above


Grid Search Accuracies
---------------------
% of grid points >= (0.45):	 0.147
% of grid points >= (0.44):	 0.36
% of grid points >= (0.43):	 0.6
% of grid points >= (0.42):	 0.787
% of grid points >= (0.41):	 0.907
% of grid points >= (0.38):	 0.92
% of grid points >= (0.31):	 0.933
% of grid points >= (0.01):	 0.987
% of grid points >= (0.00):	 1.0
