In [21]:
from dash import Dash, html, dcc
import numpy as np
import os
import plotly.graph_objs as go

# Run everything from the RADIANT project root so the relative results paths
# used later in the notebook resolve, whether the notebook was opened from
# RADIANT itself or from any directory nested inside it.
PROJECT_ROOT_NAME = "RADIANT"


def find_project_root(start=None, root_name=PROJECT_ROOT_NAME):
    """Return the closest directory named ``root_name`` at or above ``start``.

    Parameters
    ----------
    start : str or None
        Directory to start searching from; defaults to the current working
        directory.
    root_name : str
        Basename of the directory to look for.

    Returns
    -------
    str or None
        Absolute path of the first matching directory found while walking up
        from ``start``, or ``None`` if neither ``start`` nor any ancestor is
        named ``root_name``.
    """
    current = os.path.abspath(os.getcwd() if start is None else start)
    while True:
        if os.path.basename(current) == root_name:
            return current
        parent = os.path.dirname(current)
        if parent == current:  # reached the filesystem root without a match
            return None
        current = parent


project_root = find_project_root()
if project_root is None:
    # Don't blindly step up a directory and claim success: leave cwd alone and
    # say so, so any later "file not found" errors are easy to explain.
    print(f"WARNING: no '{PROJECT_ROOT_NAME}' directory found at or above {os.getcwd()}; staying put")
elif project_root != os.path.abspath(os.getcwd()):
    print("Changing directory to RADIANT")
    os.chdir(project_root)
print(os.getcwd())

/home/dylan/RADIANT


In [23]:
# Root folder of the saved training results, relative to the RADIANT project root.
directory = 'MachineLearning/results/'

# Training runs to compare; each entry is a sub-folder of `directory` holding
# the per-fold `loss_fold*` history files, and gets its own column of plots.
#
# Other runs that can be added to the comparison:
#   "older/FIRST_same_3000a_200e_2e-05lr"
#   "older/FIRST_same_3000a_200e_2e-05l_0.1r"
#   "lr1e-5_first_3000a_200e_same"
#   "FIRST_same_3000a_200e_5e-06lr"
#   "radcat_first_3000a_100e_same"
#   "radcat_first_3000a_100e_diff"
dirs = ["FIRST_same_3000a_200e_5e-05lr"]

def _load_fold_histories(run_dir):
    """Load every ``loss_fold*`` history file in ``run_dir``.

    Each file is read with ``np.load(..., allow_pickle=True).item()``, so it
    must hold a pickled dict. The code below reads the keys ``accuracy``,
    ``val_accuracy``, ``loss`` and ``val_loss`` from it (presumably a Keras
    ``History.history`` dict -- confirm against the training script).
    Files are read in sorted order so results do not depend on ``os.listdir``
    ordering.

    Raises
    ------
    FileNotFoundError
        If ``run_dir`` contains no ``loss_fold*`` file (otherwise the mean over
        zero folds would fail later with an unhelpful error).
    """
    filenames = sorted(f for f in os.listdir(run_dir) if f.startswith("loss_fold"))
    if not filenames:
        raise FileNotFoundError(f"no 'loss_fold*' files found in {run_dir!r}")
    # allow_pickle can execute arbitrary code while loading: only point this at
    # result directories you produced yourself.
    return [np.load(os.path.join(run_dir, f), allow_pickle=True).item() for f in filenames]


def _stack_metric(histories, key):
    """Stack one per-epoch metric across folds into a (n_folds, n_epochs) array.

    Folds that ran for different numbers of epochs (e.g. early stopping) are
    truncated to the shortest one, so every epoch's mean/std is over all folds.
    """
    curves = [np.asarray(h[key], dtype=float) for h in histories]
    n_epochs = min(len(c) for c in curves)
    return np.stack([c[:n_epochs] for c in curves])


def _mean_std_traces(curves, label, rgb):
    """Return Plotly traces for the fold-mean of ``curves`` and a +/-1 std band.

    ``rgb`` is an ``"R,G,B"`` string used for the mean line and (at 20% alpha)
    for the band. The band is an upper line followed by a lower line with
    ``fill='tonexty'``, which fills to the trace immediately before it -- so the
    two band traces must stay adjacent and in this order.
    """
    mean = curves.mean(axis=0)
    std = curves.std(axis=0)  # population std (ddof=0), numpy's default
    epochs = np.arange(len(mean))
    band_color = f"rgba({rgb},0.2)"
    return [
        go.Scatter(x=epochs, y=mean, mode='lines', name=label,
                   line=dict(color=f"rgb({rgb})")),
        go.Scatter(x=epochs, y=mean + std, mode='lines', fill=None,
                   line=dict(color=band_color), name=f"{label} Std"),
        go.Scatter(x=epochs, y=mean - std, mode='lines', fill='tonexty',
                   fillcolor=band_color, line=dict(color=band_color), name=f"{label} Std"),
    ]


def _metric_graph(histories, title, metric, label, yaxis_range=None):
    """Build a ``dcc.Graph`` of train vs. validation ``metric`` over epochs.

    Reads ``metric`` and ``"val_" + metric`` from each fold history and plots
    mean +/- std across folds: train in blue, validation in red.
    """
    data = (
        _mean_std_traces(_stack_metric(histories, metric), f"Train {label}", "0,0,255")
        + _mean_std_traces(_stack_metric(histories, f"val_{metric}"), f"Validation {label}", "255,0,0")
    )
    yaxis = {'title': label}
    if yaxis_range is not None:
        yaxis['range'] = yaxis_range
    return dcc.Graph(
        figure={
            'data': data,
            'layout': {
                'title': title,
                'xaxis': {'title': 'Epoch'},
                'yaxis': yaxis,
                'showlegend': False,
            },
        },
    )


# One accuracy graph and one loss graph per run in `dirs`.
graphs_acc = []
graphs_los = []
for d in dirs:
    losses = _load_fold_histories(os.path.join(directory, d))
    graphs_acc.append(_metric_graph(losses, d, 'accuracy', 'Accuracy', yaxis_range=[0, 1]))
    graphs_los.append(_metric_graph(losses, d, 'loss', 'Loss'))





def _graph_row(graphs):
    """Wrap a list of ``dcc.Graph`` components in a horizontal flex container."""
    # Dash passes style dicts to React, which expects camelCase CSS property
    # names (flexDirection, not flex-direction).
    return html.Div(style={'display': 'flex', 'flexDirection': 'row'}, children=graphs)


# Two rows: accuracy graphs on top, loss graphs underneath, one column per run.
app = Dash(__name__)
app.layout = html.Div(children=[_graph_row(graphs_acc), _graph_row(graphs_los)])

# `app.run_server` is deprecated (removed in Dash 3); use `app.run`, falling
# back to `run_server` only on Dash versions that predate `run`.
_run_app = getattr(app, "run", None) or app.run_server
_run_app(debug=True)
