# Imports

In [None]:
import os
import pandas as pd

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from matplotlib.ticker import StrMethodFormatter
import seaborn as sns
# Seaborn styling: default theme with all fonts scaled 2x so text stays
# legible in the saved figures. `sns.set` is an alias of `sns.set_theme`,
# so this single call yields the same state as resetting the theme and
# then re-applying it with font_scale=2.
sns.set_theme(font_scale=2)

# Thicker default line width for every subsequent plot.
plt.rcParams.update({'lines.linewidth': 2})

# Resolution (dots per inch) used when figures are saved to disk.
DPI = 200

# Arguments

In [None]:
from dataclasses import dataclass

@dataclass
class args:
    """Hard-coded input/output paths for the plotting cells of this notebook.

    NOTE(review): none of the attributes below carry type annotations, so
    ``@dataclass`` generates no fields for them; the class is only ever read
    as a namespace of class attributes (``args.outputPath`` etc.), never
    instantiated in the visible code.
    """
    # Directory containing the loss CSVs; the plotting cells also save their
    # .jpg figures directly into this directory.
    outputPath = '../total_target_cleaned_scaled_old'
    # 'figures' sub-directory of outputPath.
    # NOTE(review): not referenced anywhere in the visible code — figures are
    # currently written to outputPath instead; confirm which is intended.
    figPath = os.path.join(outputPath, 'figures')
    # CSV plotted by the "Train" cell; plot_loss reads its 'Step' and 'Value' columns.
    train_loss_filePath = os.path.join(outputPath, 'run-version_0-tag-train_loss_epoch.csv')
    # CSV plotted by the "Validation" cell; same 'Step'/'Value' layout expected.
    val_loss_filePath = os.path.join(outputPath, 'run-version_0-tag-val_loss.csv')

# Utils

In [None]:
def plot_loss(df, figsize=(18, 8), title: str = None, figure_name: str = None, base=5, annotate=True):
    """Plot a loss curve against epoch number, optionally annotating its minimum and saving it.

    The epoch of each row is derived as ``int(Step / min(Step))``. This
    presumes the smallest logged step corresponds to epoch 1 and that later
    steps are (near) integer multiples of it — TODO confirm against the
    logging setup.

    Args:
        df: DataFrame with at least ``Step`` and ``Value`` columns. It is NOT
            modified; the derived ``Epoch`` column lives on an internal copy.
        figsize: Matplotlib figure size ``(width, height)`` in inches.
        title: Optional title for the axes.
        figure_name: Optional output path; when given, the figure is saved
            there at the module-level ``DPI``.
        base: Spacing of the major x-axis ticks, in epochs.
        annotate: When True, draw an arrow and label at the minimum ``Value``.

    Raises:
        ValueError: if ``df`` is empty, or its smallest ``Step`` is not
            positive (epochs cannot be derived by division).
    """
    if df.empty:
        raise ValueError('plot_loss: cannot plot an empty DataFrame')
    first_step = df['Step'].min()
    # `not x > 0` also rejects NaN; a zero step would otherwise cause a
    # division by zero and a failed int cast further down.
    if not first_step > 0:
        raise ValueError(
            f'plot_loss: smallest Step must be positive to derive epochs, got {first_step}'
        )

    # reset_index returns a new frame, so the caller's df never gains an
    # 'Epoch' column; it also makes row labels 0..n-1 for idxmin below.
    df = df.reset_index(drop=True)
    # Truncating cast, as before: fractional ratios round toward zero.
    df['Epoch'] = (df['Step'] / first_step).astype(int)

    fig, ax = plt.subplots(figsize=figsize)
    ax.yaxis.set_major_formatter(StrMethodFormatter('{x:,.5g}'))
    ax.xaxis.set_major_locator(ticker.MultipleLocator(base=base))

    ax.plot(df['Epoch'], df['Value'])
    ax.set_xlabel('Number of epochs')
    ax.set_ylabel('Mean squared loss')
    # One epoch of padding on each side so the curve's end points are not clipped.
    ax.set_xlim(df['Epoch'].min() - 1, df['Epoch'].max() + 1)
    if title is not None:
        ax.set_title(title)

    if annotate:
        # First row holding the minimum Value (NaNs are skipped).
        best_row = df['Value'].idxmin()
        minimum_loss = df.at[best_row, 'Value']
        epoch = df.at[best_row, 'Epoch']

        arrowprops = dict(
            arrowstyle="fancy, head_length=0.8, head_width=0.5, tail_width=0.2",
            color='red',
        )
        ax.annotate(
            f'Minimum point ({epoch}, {minimum_loss:0.5g})',
            xy=(epoch, minimum_loss),
            # Horizontal offset (points) grows for minima in the first few
            # epochs, presumably to keep the label inside the axes.
            xytext=(max((6 - epoch) * 20, 0), 150),
            textcoords='offset points',
            size=20,
            ha='center', va="center",
            arrowprops=arrowprops,
        )

    fig.tight_layout()
    if figure_name is not None:
        fig.savefig(figure_name, dpi=DPI)
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures
    # under non-inline backends.
    plt.close(fig)

# Plotting

## Train

In [None]:
# Training loss: plotted and saved without the minimum-point marker.
df = pd.read_csv(args.train_loss_filePath)
plot_loss(
    df,
    title='Train loss with time',
    figure_name=os.path.join(args.outputPath, 'train_loss.jpg'),
    annotate=False,
)

## Validation

In [None]:
# Validation loss: plotted and saved with the minimum point annotated
# (plot_loss annotates by default).
df = pd.read_csv(args.val_loss_filePath)
plot_loss(
    df,
    title='Validation loss with time',
    figure_name=os.path.join(args.outputPath, 'validation_loss.jpg'),
)