In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from tbparse import SummaryReader

In [2]:
# Root directory with the TensorBoard logs; each run lives in its own sub-directory.
# Download all logs manually before running the notebook.
TB_LOGS_ROOT = '../tb_logs'
# Directory that receives the rendered PDF plots.
PLOTS_DIR = '../plots'
# When True, figures are displayed in addition to being saved.
SHOW_PLOTS = True

# exist_ok avoids the check-then-create race and keeps the cell safe to re-run.
os.makedirs(PLOTS_DIR, exist_ok=True)


def read_tb_logs_to_dfs(log_dir_name: str) -> dict[str, pd.DataFrame]:
    """Load the scalar logs of one run and split them into one frame per tag.

    :param log_dir_name: Name of the run directory inside ``TB_LOGS_ROOT``.
    :return: Mapping from tag name to the rows logged under that tag. Tags are
        inserted in sorted order and every frame is re-indexed 1..N in row order.
    """
    log_path = os.path.join(TB_LOGS_ROOT, log_dir_name)
    scalars = SummaryReader(log_path).scalars

    per_tag = {}
    for tag in sorted(scalars['tag'].unique()):
        rows = scalars.loc[scalars['tag'] == tag]
        # Number the rows from 1 so the index counts the entries logged for this tag.
        rows.index = np.arange(1, len(rows) + 1)
        per_tag[tag] = rows

    return per_tag

def merge_into_one_df(df_names: list[str], dfs: dict[str, pd.DataFrame]) -> pd.DataFrame:
    """Combine the ``'value'`` columns of the selected frames into one frame.

    Columns are aligned on the union of the frames' indices, so a metric with
    fewer rows is padded with NaN instead of every later metric being cut down
    to the index of the first selected frame.

    :param df_names: Keys of ``dfs`` to include; each becomes a column name.
    :param dfs: Mapping from name to a frame that has a ``'value'`` column.
    :return: Frame with one column per requested name (empty if none requested).
    :raises KeyError: If a requested name is missing from ``dfs``.
    """
    value_columns = {name: dfs[name]['value'] for name in df_names}
    if not value_columns:
        # pd.concat rejects an empty collection; keep returning an empty frame.
        return pd.DataFrame()

    # Pass the keys explicitly so the column order follows ``df_names``.
    return pd.concat(
        list(value_columns.values()),
        axis=1,
        keys=list(value_columns),
    )

def plot_metrics(df: pd.DataFrame, metric_to_label: dict[str, str], plot_name: str):
    """Plot the given metric columns on one figure and save it as a PDF.

    :param df: Frame holding one column per metric; its index is the x axis.
    :param metric_to_label: Maps a column name in ``df`` to its legend label.
        The first key decides the y-axis label: 'loss' if that key contains
        'loss', otherwise 'accuracy'.
    :param plot_name: File name (without extension) of the PDF written to ``PLOTS_DIR``.
    """
    metric_name = 'loss' if 'loss' in next(iter(metric_to_label)) else 'accuracy'

    fig, ax = plt.subplots(figsize=(6.4, 4))
    for metric, label in metric_to_label.items():
        ax.plot(df.index, df[metric], label=label)
    ax.set(xlabel='epoch', ylabel=metric_name)
    # Legend goes top-right for loss plots and bottom-right for accuracy plots.
    ax.legend(loc='upper right' if metric_name == 'loss' else 'lower right')

    # Write the file first so saving never depends on the figure still being open.
    fig.savefig(os.path.join(PLOTS_DIR, f'{plot_name}.pdf'), bbox_inches='tight')

    if SHOW_PLOTS:
        # plt.show() rather than fig.show(): the latter only emits a warning on
        # non-GUI backends such as the notebook inline backend.
        plt.show()
    else:
        # Close the figure so the notebook does not render it anyway and so
        # figures do not accumulate when this is called in a loop.
        plt.close(fig)

def plot_lightning_logs(log_dir_name: str, model_name: str):
    """Plot the loss and accuracy curves of one run that logs ``*_epoch`` tags.

    Two plots are produced, named ``<log_dir_name>_loss`` and
    ``<log_dir_name>_accuracy``.

    :param log_dir_name: Run directory name; also the prefix of the plot names.
    :param model_name: Model name used as the prefix of every legend label.
    """
    tag_frames = read_tb_logs_to_dfs(log_dir_name)

    # Plot-name suffix -> {logged tag: legend label}.
    labels_by_suffix = {
        '_loss': {
            'train_loss_epoch': f'{model_name} training loss',
            'val_loss_epoch': f'{model_name} validation loss',
        },
        '_accuracy': {
            'train_accuracy_epoch': f'{model_name} training accuracy',
            'val_accuracy_epoch': f'{model_name} validation accuracy',
        },
    }

    wanted_tags = [tag for labels in labels_by_suffix.values() for tag in labels]
    merged = merge_into_one_df(wanted_tags, tag_frames)
    for suffix, labels in labels_by_suffix.items():
        plot_metrics(merged, labels, log_dir_name + suffix)

def plot_transformers_logs(log_dir_name: str, model_name: str):
    """Plot the training and validation loss of one run that logs ``train/loss`` and ``eval/loss``.

    The plot is named ``<log_dir_name>_loss``.

    :param log_dir_name: Run directory name; also the prefix of the plot name.
    :param model_name: Model name used as the prefix of both legend labels.
    """
    tag_frames = read_tb_logs_to_dfs(log_dir_name)

    labels = {
        'train/loss': f'{model_name} training loss',
        'eval/loss': f'{model_name} validation loss',
    }

    merged = merge_into_one_df([*labels], tag_frames)
    plot_metrics(merged, labels, f'{log_dir_name}_loss')

def plot_comparison(
    log_dir_names: list[str],
    model_names: list[str],
    split: str,
    metric_name: str,
    suffix: str=''
):
    """
    Plot one metric of several runs on a single figure, saved under the name
    ``comparison_<training|validation>_<metric_name>_<suffix>``.

    Runs whose logs contain a 'train/loss' tag are read from the
    ``<split>/<metric_name>`` tag; every other run is read from the
    ``<train|val>_<metric_name>_epoch`` tag.

    Curves are aligned on the union of their indices, so runs of different
    lengths are each plotted in full instead of being cut to the first run.

    :param log_dir_names: Run directories, one per curve.
    :param model_names: Legend names, parallel to ``log_dir_names``.
    :param split: 'train' or 'eval'
    :param metric_name: 'loss' or 'accuracy'
    :param suffix: Use it to distinguish between plots and to avoid overwriting files.
    :raises ValueError: If no run is given or the two lists differ in length.
    """
    if len(log_dir_names) != len(model_names):
        # zip() would silently drop the unmatched entries.
        raise ValueError(
            f'log_dir_names has {len(log_dir_names)} entries '
            f'but model_names has {len(model_names)}'
        )
    if not log_dir_names:
        raise ValueError('at least one log directory is required')

    split_label = 'training' if split == 'train' else 'validation'
    # The ``*_epoch`` tags use 'val' where the slash-separated tags use 'eval'.
    epoch_split = 'val' if split == 'eval' else split

    columns = []
    curves = []
    metrics = {}
    for position, (log_dir_name, model_name) in enumerate(zip(log_dir_names, model_names)):
        dataframes = read_tb_logs_to_dfs(log_dir_name)
        if 'train/loss' in dataframes:
            tag = f'{split}/{metric_name}'
        else:
            tag = f'{epoch_split}_{metric_name}_epoch'

        # The position keeps column names unique when two runs share a model
        # name; otherwise one curve would overwrite the other.
        column = f'{position}_{model_name}_{metric_name}'
        columns.append(column)
        curves.append(dataframes[tag]['value'])
        metrics[column] = f'{model_name} {split_label} {metric_name}'

    # Outer-align the curves; shorter runs are padded with NaN at the end.
    df = pd.concat(curves, axis=1, keys=columns)

    plot_metrics(df, metrics, f'comparison_{split_label}_{metric_name}_{suffix}')

In [None]:
# (log directory, model name) pairs for the runs plotted with plot_lightning_logs.
lightning_runs = [
    ('T5_3e-4', 'T5'),
    ('T5_3e-4_polish_labels', 'T5'),
    ('mT5-base_5e-4', 'mT5'),
    ('mT5-small_1e-3', 'mT5-small'),
    ('mT5-small_3langs_1e-3', 'mT5-small'),
]
log_dir_names = [log_dir for log_dir, _ in lightning_runs]
model_names = [label for _, label in lightning_runs]

for log_dir_name, model_name in lightning_runs:
    plot_lightning_logs(log_dir_name, model_name)

In [None]:
# (log directory, model name) pairs for the runs plotted with plot_transformers_logs.
transformers_runs = [
    ('xlm-roberta-lr5e5', 'XLM-RoBERTa'),
    ('xlm-roberta-lr5e5-combined', 'XLM-RoBERTa'),
    ('xlm-v-lr5e5', 'XLM-V'),
    ('xlm-v-lr5e5-combined', 'XLM-V'),
    ('herbert_pl_lr5e-5', 'HerBERT'),
    ('herbert_combined_lr5e-5', 'HerBERT'),
]
log_dir_names = [log_dir for log_dir, _ in transformers_runs]
model_names = [label for _, label in transformers_runs]

for log_dir_name, model_name in transformers_runs:
    plot_transformers_logs(log_dir_name, model_name)

In [None]:
# Validation-accuracy comparison of three runs, saved with the 'pl-all_models' suffix.
comparison_runs = [
    ('mT5-small_1e-3', 'mT5'),
    ('xlm-v-lr5e5', 'XLM-V'),
    ('herbert_pl_lr5e-5', 'HerBERT'),
]
log_dir_names = [log_dir for log_dir, _ in comparison_runs]
model_names = [label for _, label in comparison_runs]

plot_comparison(log_dir_names, model_names, split='eval', metric_name='accuracy', suffix='pl-all_models')

In [None]:
# Validation-accuracy comparison of three runs, saved with the 'combined-all_models' suffix.
comparison_runs = [
    ('mT5-small_3langs_1e-3', 'mT5'),
    ('xlm-v-lr5e5-combined', 'XLM-V'),
    ('herbert_combined_lr5e-5', 'HerBERT'),
]
log_dir_names = [log_dir for log_dir, _ in comparison_runs]
model_names = [label for _, label in comparison_runs]

plot_comparison(log_dir_names, model_names, split='eval', metric_name='accuracy', suffix='combined-all_models')

In [None]:
# Validation-accuracy comparison of three runs, saved with the 'pl-T5' suffix.
comparison_runs = [
    ('T5_3e-4', 'T5'),
    ('mT5-base_5e-4', 'mT5-base'),
    ('mT5-small_1e-3', 'mT5-small'),
]
log_dir_names = [log_dir for log_dir, _ in comparison_runs]
model_names = [label for _, label in comparison_runs]

plot_comparison(log_dir_names, model_names, split='eval', metric_name='accuracy', suffix='pl-T5')

In [None]:
# Validation-accuracy comparison of two runs, saved with the 'polish-labels' suffix.
comparison_runs = [
    ('T5_3e-4', 'T5 with English labels'),
    ('T5_3e-4_polish_labels', 'T5 with Polish labels'),
]
log_dir_names = [log_dir for log_dir, _ in comparison_runs]
model_names = [label for _, label in comparison_runs]

plot_comparison(log_dir_names, model_names, split='eval', metric_name='accuracy', suffix='polish-labels')