In [2]:
import ast
import json
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib import font_manager

# Seaborn theme first: set_theme (re)applies its style/context rcParams, so the
# manual rcParams tweaks below must come after it. One call replaces the former
# sns.set(rc=...) + sns.set_theme(style='white') pair with the same end state.
sns.set_theme(style='white', rc={'figure.figsize': (6.5, 4.5)})

# Computer Modern Sans Serif, loaded from a local .ttf file. The default is the
# original author's path; set CMU_SANS_FONT to point at the file elsewhere.
font_path = os.environ.get('CMU_SANS_FONT', '/home/ericwallace/computer-modern/cmunss.ttf')
font_manager.fontManager.addfont(font_path)
prop = font_manager.FontProperties(fname=font_path)
plt.rcParams['font.family'] = 'sans-serif'
plt.rcParams['font.sans-serif'] = [prop.get_name()]

# Pin the axes frame color (0.15 gray) and line width explicitly.
plt.rcParams['axes.edgecolor'] = '0.15'
plt.rcParams['axes.linewidth'] = 1.25
# Large axis labels and titles.
plt.rcParams.update({'axes.labelsize': 30, 'axes.labelpad': 8.0, 'axes.titlesize': 30})
# Render negative tick labels with an ASCII hyphen rather than the Unicode minus sign.
plt.rcParams['axes.unicode_minus'] = False

plt.rc('xtick', labelsize=14)
plt.rc('ytick', labelsize=14)

In [3]:
import json
import pandas as pd

def _load_log_history(text):
    """Decode a log-history dump written either as JSON or as a Python literal.

    Tries, in order:
      1. ``json.loads`` - the file is already valid JSON;
      2. ``ast.literal_eval`` - the file is a Python ``repr`` of a list of dicts
         (single-quoted strings, ``True``/``False``/``None``), which JSON rejects;
      3. the legacy heuristic of swapping every ``'`` for ``"`` and parsing as JSON,
         kept only so inputs that worked before still work.

    Raises:
        json.JSONDecodeError: if none of the strategies can decode ``text``.
    """
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        pass
    try:
        # Safe on data: literal_eval only accepts Python literals, never code.
        return ast.literal_eval(text)
    except (ValueError, SyntaxError, TypeError):
        pass
    return json.loads(text.replace('\'', '\"'))


def parse_path(path, name):
    """Read a training log-history file into a tidy per-step loss table.

    The file is expected to hold a list of dicts (presumably a dumped
    ``Trainer.state.log_history`` - confirm against the training script).
    Entries with a ``'loss'`` key become ``'train'`` rows, and entries with an
    ``'eval_loss'`` key become ``'validation'`` rows. All other entries, such as
    end-of-training summaries, are skipped.

    Args:
        path: path to the log-history text file (read as UTF-8).
        name: label stored in the ``name`` column of every row.

    Returns:
        DataFrame with columns ``loss``, ``step``, ``epoch``, ``type``
        (``'train'``/``'validation'``) and ``name``, in file order.

    Raises:
        KeyError: if a loss entry lacks ``'step'`` or ``'epoch'``.
        json.JSONDecodeError: if the file cannot be decoded.
    """
    with open(path, 'r', encoding='utf-8') as f:
        data = _load_log_history(f.read())

    rows = []
    for record in data:
        if 'loss' in record:
            rows.append([record['loss'], record['step'], record['epoch'], 'train'])
        elif 'eval_loss' in record:
            rows.append([record['eval_loss'], record['step'], record['epoch'], 'validation'])

    df = pd.DataFrame(rows, columns=['loss', 'step', 'epoch', 'type'])
    # A scalar broadcasts to every row (and works for an empty frame too).
    df['name'] = name
    return df

In [4]:
# Log-history file for each training run, keyed by the display name that
# parse_path stores in the `name` column.
paths = {
    'News': 'runs/gpt2_NewsData_16_8_10_8_0.0005/log_history.txt',
    'Reddit': 'runs/gpt2_RedditData_16_8_10_8_0.0005/log_history.txt',
    'Tweets': 'runs/gpt2_TweetData_16_8_10_8_0.0005/log_history.txt',
}

# Every per-run frame is indexed from 0, so plain concat would produce duplicate
# row labels. ignore_index=True gives the combined frame a unique 0..n-1 index,
# which keeps label-based indexing and reindexing downstream well-defined.
df = pd.concat(
    [parse_path(path, name) for name, path in paths.items()],
    ignore_index=True,
)

In [7]:
import seaborn as sns
from matplotlib import pyplot as plt

# One panel per run (column = `name`). Each panel gets its own x/y range, since
# the runs need not share step counts or loss scales.
grid = sns.FacetGrid(
    df,
    col='name',
    height=5,
    aspect=1.5,
    sharex=False,
    sharey=False,
)

# map_dataframe passes each facet's subset as `data`, so no data= argument here
# (a data=df argument would just be overwritten per facet).
grid.map_dataframe(
    sns.lineplot,
    x='step',
    y='loss',
    hue='type',
    linewidth=4,
)

grid.set_axis_labels('Step', 'Loss')
# Title each panel with its `name` value instead of hard-coding titles by
# column position; panel order follows first appearance in df.
grid.set_titles('{col_name}')

# A single legend call. Two consecutive plt.legend() calls replace each other,
# which silently dropped loc='upper right'. plt.legend() drew on the current
# (last) panel, so that panel is addressed explicitly. The other panels keep
# the legend that sns.lineplot draws for the `hue` variable.
grid.axes[0, -1].legend(loc='upper right', fontsize=20)

for ax in grid.axes.flat:
    ax.tick_params(axis='both', which='major', labelsize=20)

# savefig does not create missing directories.
os.makedirs('figs', exist_ok=True)
grid.savefig('figs/lossplots.pdf', bbox_inches='tight')