In [None]:
from pathlib import Path
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.pyplot import figure
import json
import os
import re
import pprint

In [None]:
# This notebook reads its entire input (the experiment record, as a JSON string) from the
# EXP_DATA environment variable. Fail early with an actionable message when it is missing
# instead of a bare KeyError: 'EXP_DATA'. The exception type stays KeyError.
if 'EXP_DATA' not in os.environ:
    raise KeyError("EXP_DATA is not set: export the experiment record (JSON) in this "
                   "environment variable before running the notebook.")
data = json.loads(os.environ['EXP_DATA'])
# Per-metric series used by the plotting cells below; validation series are looked up
# under a 'val_' prefix (presumably a Keras-style History.history dict — confirm).
history = data['history']



# Model and Benchmark Summary

In [None]:
# Print the experiment description, one sentence per paragraph.
# Split only on a period that is followed by whitespace, so that decimals such as "0.5"
# stay intact and a final sentence without a trailing period is still printed
# (the previous `split(".")[:-1]` silently dropped it).
description = data.get("description", "")
for sentence in re.split(r"(?<=\.)\s+", description.strip()):
    if sentence:  # an empty description yields [''] — print nothing in that case
        print(sentence + "\n")

## Extract and format metrics to be plotted

In [None]:
# If a metric was recorded in `history` under a different name than its default key,
# map it here as {"default_name": "new_name"}. The new name is used only when the
# default key is absent from `history`.
metric_custom_names = {"auc": "AUC_ROC"}

# Split camelCase benchmark metric names at lower/digit -> upper boundaries:
#   display names: "meanAbsoluteError" -> "mean Absolute Error"
#   history keys:  "meanAbsoluteError" -> "mean_absolute_error"
# Replacement strings are raw so that `\g<n>` is a regex group reference rather than
# an invalid Python string escape (DeprecationWarning / SyntaxWarning on 3.12+).
metric_names = [re.sub(r"([a-z0-9])([A-Z])", r"\g<1> \g<2>", name) for name in data["benchmark"]["metrics"]]
metric_keys = [re.sub(r"([a-z0-9])([A-Z])", r"\g<1>_\g<2>", name).lower() for name in data["benchmark"]["metrics"]]

for default_name, custom_name in metric_custom_names.items():
    if default_name not in history and default_name in metric_keys:
        # Swap in the renamed history key; the display name in metric_names is unchanged.
        metric_keys[metric_keys.index(default_name)] = custom_name


## Plot training & validation values for each benchmark metric

In [None]:
def print_or_plot_metric(metric_key, metric_name, figure_name):
    """Show one metric's history: print it if only one epoch exists, otherwise plot it.

    Args:
        metric_key: key of the training series in the global ``history`` dict. The
            validation series, if recorded, is looked up under ``'val_' + metric_key``.
        metric_name: human-readable metric name used in printed labels / the y-axis.
        figure_name: plot title forwarded to ``plot_epoch_metric``.
    """
    if metric_key not in history:
        # A benchmark metric may be absent from the recorded history (e.g. logged under
        # another name). Report it instead of aborting the rest of the notebook.
        print("No history recorded for {} (key '{}'); skipping.\n".format(metric_name, metric_key))
        return

    val_key = 'val_' + metric_key
    if len(history[metric_key]) == 1:
        # A single point makes an uninformative line plot, so print the values instead.
        print("Data for {m_name} only available for a single epoch. \nSkipping plot and printing data...".format(m_name=metric_name))
        print('Train {}: '.format(metric_name), history[metric_key])
        # Like plot_epoch_metric, only report validation data when it was recorded.
        if val_key in history:
            print('Validation {}: '.format(metric_name), history[val_key])
        print()
    else:
        plot_epoch_metric(metric_key, metric_name, figure_name)
        
def plot_epoch_metric(metric_key, metric_name, figure_name):
    """Draw a per-epoch line chart of a metric from the global ``history`` dict.

    ``history[metric_key]`` is always drawn. ``history['val_' + metric_key]`` is
    overlaid, together with a Train/Validation legend, only when that key exists.
    The figure is titled ``figure_name``, the y-axis is labelled ``metric_name``,
    and the x-axis is labelled 'Epoch'.
    """
    val_key = 'val_' + metric_key

    figure(num=None, figsize=(10, 6))
    plt.plot(history[metric_key])
    if val_key in history.keys():
        plt.plot(history[val_key])
        plt.legend(['Train', 'Validation'], loc='upper left')

    plt.title(figure_name)
    plt.ylabel(metric_name)
    plt.xlabel('Epoch')
    plt.show()

# Show every benchmark metric, titling each figure "Model <display name>".
for metric_key, metric_name in zip(metric_keys, metric_names):
    print_or_plot_metric(metric_key, metric_name, "Model " + metric_name)

## Plot training & validation loss values (and the learning rate, when recorded)

In [None]:
# Show the loss history through the same print-or-plot helper used for the benchmark metrics.
print_or_plot_metric(metric_key="loss", metric_name="Loss", figure_name="Model loss")

In [None]:
# The learning-rate series is optional; plot it only when history contains an "lr" entry.
if "lr" in history:
    plot_epoch_metric(metric_key="lr", metric_name="Learning Rate", figure_name="Learning Rate")

## Classification Report

In [None]:
# Print the classification report only when the key exists and holds a non-empty value.
if data.get('classification_report'):
    print(data['classification_report'])

## Test Scores

In [None]:
# Print each score stored under data['test']; a missing or empty entry prints nothing.
test_scores = data.get('test') or {}
for score_name, score in test_scores.items():
    print('Test {}: '.format(score_name), score)



## Benchmark Details

In [None]:
# Pretty-print the "benchmark" section of the run record (4 spaces per nesting level), if present.
pp = pprint.PrettyPrinter(indent=4)
if "benchmark" in data:
    pp.pprint(data["benchmark"])