In [11]:
import os
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

In [19]:
def read_history(model_dir):
    """Load every history CSV under ``<model_dir>/history`` into one DataFrame.

    Files are read in sorted (lexicographic) file-name order so the row order
    of the result is reproducible; ``os.listdir`` on its own returns entries
    in arbitrary order. Entries that are not regular files (e.g. nested
    directories) are skipped because ``pd.read_csv`` cannot read them.

    NOTE(review): lexicographic order puts ``epoch_10`` before ``epoch_2``;
    confirm the history file names are zero-padded or otherwise sort
    chronologically.

    Parameters
    ----------
    model_dir : str
        Directory that contains a ``history`` sub-directory of CSV files.

    Returns
    -------
    pandas.DataFrame
        Rows of all history files stacked in file-name order, re-indexed
        from 0.

    Raises
    ------
    FileNotFoundError
        If ``<model_dir>/history`` does not exist (raised by ``os.listdir``).
    ValueError
        If the directory holds no files (raised by ``pd.concat``).
    """
    history_dir = os.path.join(model_dir, 'history')

    # Sort for a deterministic concatenation order, then keep regular files only.
    candidate_paths = (
        os.path.join(history_dir, name) for name in sorted(os.listdir(history_dir))
    )
    history_paths = [path for path in candidate_paths if os.path.isfile(path)]

    frames = [pd.read_csv(path) for path in history_paths]
    return pd.concat(frames, ignore_index=True)

In [20]:
# Load the training histories of the two runs that are compared further down
# (they are labelled 'units=128' and 'units=256' in the comparison call).
perf_u128 = read_history('model/20201209_koi/')
perf_u256 = read_history('model/20201210_koi/')

In [33]:
def check_perform(perf):
    """Plot a single run's history as a 2x2 Plotly subplot grid and show it.

    Layout:
      * row 1, col 1 - ``loss`` and ``val_loss``
      * row 1, col 2 - ``accuracy`` and ``val_accuracy``
      * row 2 (spanning both columns, titled 'F1 Score') - every column of
        ``perf`` from the fifth onward.

    NOTE(review): the bottom panel selects columns by position
    (``perf.columns[4:]``), so it presumably expects the four loss/accuracy
    columns to come first - confirm against the history CSV layout.

    Parameters
    ----------
    perf : pandas.DataFrame
        History frame with at least the columns ``loss``, ``val_loss``,
        ``accuracy`` and ``val_accuracy``.

    Returns
    -------
    None
        The figure is displayed via ``fig.show()`` and not returned.
    """
    figure = make_subplots(
        rows=2,
        cols=2,
        specs=[[{}, {}], [{'colspan': 2}, None]],
        subplot_titles=('Loss and Validate loss', 'Accuracy and Validate accuracy', 'F1 Score'),
    )

    # (source column, legend label, grid column) for the two top-row panels.
    top_row_traces = (
        ('loss', 'loss', 1),
        ('val_loss', 'validate loss', 1),
        ('accuracy', 'accuracy', 2),
        ('val_accuracy', 'validate accuracy', 2),
    )
    for column, label, grid_col in top_row_traces:
        figure.add_trace(go.Scatter(y=perf[column], name=label), row=1, col=grid_col)

    # Everything after the first four columns goes into the wide bottom panel.
    for column in perf.columns[4:]:
        figure.add_trace(go.Scatter(y=perf[column], name=column), row=2, col=1)

    figure.update_layout(height=800, width=1000)
    figure.show()
    
def compare_perform(perf1, perf2, perf1_name, perf2_name, props):
    """Plot two runs against each other, one subplot per metric, and show it.

    The subplots are laid out in a two-column grid, filled left to right and
    top to bottom in the order of ``props``. Each subplot holds one trace per
    run, named ``'<run name><br><prop>'``.

    Parameters
    ----------
    perf1, perf2 : mapping of column name -> sequence (e.g. pandas.DataFrame)
        Histories of the two runs; both are indexed with every entry of
        ``props``.
    perf1_name, perf2_name : str
        Legend labels for the two runs.
    props : iterable of str
        Column names to compare; also used as the subplot titles.

    Returns
    -------
    None
        The figure is displayed via ``fig.show()`` and not returned.
    """
    props = list(props)
    n_cols = 2

    # Ceiling division gives exactly the rows needed (no spare empty row when
    # len(props) is even). Keep at least one row so an empty `props` still
    # yields a blank figure, as it did before.
    n_rows = max(1, -(-len(props) // n_cols))

    fig = make_subplots(rows=n_rows, cols=n_cols, subplot_titles=tuple(props))

    for position, prop in enumerate(props):
        grid_row, grid_col = divmod(position, n_cols)
        for perf, perf_name in ((perf1, perf1_name), (perf2, perf2_name)):
            fig.add_trace(
                go.Scatter(y=perf[prop], name=f'{perf_name}<br>{prop}'),
                row=grid_row + 1,
                col=grid_col + 1,
            )

    # 400 px per subplot row. The previous expression
    # `400 * len(props)//2 + 1` was evaluated as `(400 * len(props)) // 2 + 1`.
    fig.update_layout(height=400 * n_rows)
    fig.show()

In [34]:
# Compare the two runs on every column of the units=128 history; the
# units=256 history is indexed with the same column names.
compare_perform(perf_u128, perf_u256, 'units=128', 'units=256', list(perf_u128.columns))

In [17]:
import chart_studio.tools as tls
import chart_studio.plotly as cplt

# Chart Studio credentials are read from the environment so that no secret is
# stored in the notebook.
# NOTE(review): the API key that used to be hard-coded in this cell has been
# exposed in the notebook source; treat it as compromised and regenerate it.
usr = os.environ.get('PLOTLY_USERNAME', 'SharpKoi')
api_key = os.environ.get('PLOTLY_API_KEY')
if not api_key:
    raise RuntimeError(
        'Set the PLOTLY_API_KEY environment variable before uploading to Chart Studio.'
    )
tls.set_credentials_file(username=usr, api_key=api_key)

# NOTE(review): `fig` is not assigned at notebook scope by any visible cell
# (the plotting helpers show their figure but do not return it), so this line
# depends on leftover kernel state - confirm where `fig` should come from.
# `filename` is the plot option that names the upload; the previous
# `file_name` keyword is not one of the documented plot options.
cplt.plot(fig, filename='20201210_u256_e80_bs16', auto_open=False)

'https://plotly.com/~SharpKoi/13/'