In [49]:
import os
import pandas as pd
import plotly.express as px
from plotly.offline import download_plotlyjs, init_notebook_mode, plot, iplot

In [5]:
LOGS_PATH = 'surrogate_logs'

In [18]:
def parse_logs(filename):
    with open(filename) as file:
        rmse_scores = []
        new_scores = []
        for line in file:
            if "Starting" in line:
                if len(new_scores) > 0:
                    rmse_scores.append(new_scores)
                    new_scores = []
            if "RMSE" in line:  
                new_scores.append(float(line.split('RMSE: ')[-1]))
    return rmse_scores

In [43]:
METHOD_NAMES = ["Random Forest", "Gaussian Process", "SVR", "MLP"]
DATASET_NAMES = ['20NG', "banners", "HR", "AR", "lenta"]
COLUMN_NAMES = ['dataset', 'method', 'value']

def plot_avg_vals(method_logs):
    dfObj = pd.DataFrame(columns=COLUMN_NAMES)
    # loop through methods
    for method_idx, method_log in enumerate(method_logs):
        method = METHOD_NAMES[method_idx]
        # loop through datasets
        method_data = []
        for dataset_idx, fname in enumerate(method_log):
            dataset = DATASET_NAMES[dataset_idx]
            
            rmse = parse_logs(os.path.join(LOGS_PATH, fname))
            rmse_res = [res[4:] for res in rmse]
            flatten = [item for sublist in rmse_res for item in sublist]
            n = len(flatten)
            df = pd.DataFrame(list(zip([dataset for _ in range(n)], 
                               [method for _ in range(n)], 
                               flatten,
                              )), 
               columns=COLUMN_NAMES) 
            dfObj = pd.concat([dfObj, df])
    return dfObj

In [44]:
# 20NG    banners     HR        AR      lenta

rf_logs = [
    '20ng_rf_15_e1.txt', 
    'banners_rf_15.txt', 
    'hr_rf_15_e1.txt', 
    'af_rf_15_e1.txt',
    'lenta_ru_rf_15_e1.txt',
]

gp_logs = [
    '20ng_gp_15_e1.txt',
    'banners_gp_15_e1.txt',
    'hr_gp_15_e1.txt',
    'af_gp_15_e1.txt',
    'lenta_GP_15.txt',
]

svr_logs = [
    '20ng_svr_15.txt',
    'banners_svr_15_e1.txt',
    'hr_svr_15_e1.txt',
    'af_svr_15_e1.txt',
    'lenta_svr_20.txt'
]

mlp_logs = [
    '20ng_mlp_11_e1.txt',
    'banners_mlp_20.txt',
    'hr_mlp_20.txt',
    'ar_mlp_20.txt',
    'lenta_mlp_20.txt'
]

all_logs = [rf_logs, gp_logs, svr_logs, mlp_logs]

In [45]:
data = plot_avg_vals(all_logs)

In [46]:
data

Unnamed: 0,dataset,method,value
0,20NG,Random Forest,0.018430
1,20NG,Random Forest,0.240782
2,20NG,Random Forest,0.044336
3,20NG,Random Forest,0.078803
4,20NG,Random Forest,0.001023
5,20NG,Random Forest,0.002395
6,20NG,Random Forest,0.038619
7,20NG,Random Forest,0.048612
8,20NG,Random Forest,0.043531
9,20NG,Random Forest,0.062556


In [47]:
fig = px.box(data, x="dataset", y="value", color="method", template='plotly_white')
# fig.update_traces(quartilemethod="exclusive") # or "inclusive", or "linear" by default
fig.update_layout(
    xaxis_title="Datasets",
    yaxis_title="RMSE",
    font=dict(
        family="Courier New, monospace",
        size=20,
        color="#7f7f7f"
    )
)
fig.show()

In [50]:
iplot(fig, image='svg')

In [17]:
filename = os.path.join(LOGS_PATH, '20ng_rf_15_e1.txt')
parse_logs(filename)

[[0.07218888073513599, 0.041001357903649026, 0.03645490375615739, 0.02301017004449994, 0.01842958159943832], [0.11218318745557619, 0.1635433620845868, 0.20299430293483015, 0.1516081476608987, 0.24078235898666966, 0.04433604053317465], [0.26254498857446107, 0.2625369127617244, 0.0574300288020829, 0.028402144986502623, 0.07880323522180635, 0.0010230878861699746], [0.20100185430992965, 0.12480675139245008, 0.04222861113528999, 0.015468840674006983, 0.0023946251185042824], [0.1976480064291819, 0.063033095136001, 0.11886398720550008, 0.07224670795695742, 0.03861943353520743, 0.04861228739124941, 0.04353106494397179, 0.06255564922285238], [0.1676426318723059, 0.15577146325884053, 0.10785661212362548, 0.04316532106477621, 0.06652356548114935], [0.29163326724284816, 0.1638110924858603, 0.1675527165928218, 0.11063934876179235, 0.008306440167540952, 0.0074203107430519805, 0.0040414013083426206], [0.126956631813171, 0.041578479819218035, 0.17456258313006567, 0.2036347452182421, 0.0537080818525185