In [None]:
## Log2Graphs.py


from parse_logs import parse_log_files, parse_batch_log_files
from make_graphs import plot_training_errors, plot_comparison_errors

def log2graphs(log_file, tr_label, te_label, max_epoch=None, show=True, save=True):
    """Parse a single log file and plot its training errors.

    The file is read with ``parse_log_files`` and the parsed data is handed
    to ``plot_training_errors`` together with the labels and plot options.

    Args:
        log_file: Path of the log file to parse.
        tr_label: Training-set label forwarded to the plotting function.
        te_label: Test-set label forwarded to the plotting function.
        max_epoch: Forwarded as ``max_epoch``.
        show: Forwarded as ``show``.
        save: Forwarded as ``save``.
    """
    plot_options = {"max_epoch": max_epoch, "show": show, "save": save}
    parsed = parse_log_files(log_file)
    plot_training_errors(parsed, tr_label, te_label, **plot_options)

def compare_log2graphs(log_files_path, graph_label, te_label, max_epoch=100, show=True, save=True, log_y=True, log_x=False):
    """Parse every log found under a directory and plot them against each other.

    The directory is read with ``parse_batch_log_files``; the resulting tags
    and parsed data sets are passed to ``plot_comparison_errors``.

    Args:
        log_files_path: Directory handed to ``parse_batch_log_files``.
        graph_label: Label passed as the third positional argument of
            ``plot_comparison_errors``.
        te_label: Test-set label forwarded to the plotting function.
        max_epoch: Forwarded as ``max_epoch``.
        show: Forwarded as ``show``.
        save: Forwarded as ``save``.
        log_y: Forwarded as ``log_y``.
        log_x: Forwarded as ``log_x``.
    """
    tags, parsed_logs = parse_batch_log_files(log_files_path)
    plot_comparison_errors(
        tags,
        parsed_logs,
        graph_label,
        te_label,
        max_epoch=max_epoch,
        show=show,
        save=save,
        log_y=log_y,
        log_x=log_x,
    )


if __name__ == "__main__":
    # Single-log variant, kept disabled:
    #log2graphs("/Users/bd20841/code/IMP2_make_graphs/logs/new_logs/300_dt5ab_F_ga2_iter_log.txt", tr_label="DT5AB_F_ga2", te_label="DT4_F_containing")

    # Compare all logs in one directory on a logarithmic y axis.
    ga_logs_dir = "/Users/bd20841/code/IMP2_make_graphs/logs/ga_logs"
    compare_log2graphs(
        ga_logs_dir,
        graph_label='DT5AB_gradient_accumulation',
        te_label='DT3',
        log_y=True,
        log_x=False,
    )

In [3]:
## Parse logs

import os
from collections import defaultdict
import matplotlib.pyplot as plt
import glob

def parse_log_files(log_file):
    """Parse a '|'-delimited log file into per-parameter lists of floats.

    The first line is used as a header: it is split on '|', the first and
    the last field are ignored, and the second whitespace-separated token of
    each remaining field becomes a parameter name. Every following non-blank
    line is split the same way, and for each parameter the second token of
    the field at the matching position is converted to ``float``.

    Args:
        log_file: Path of the log file to read.

    Returns:
        A ``defaultdict(list)`` mapping each parameter name to its values in
        file order.

    Raises:
        IndexError: If a header or data field has fewer than two tokens, or
            a data line has fewer fields than the header.
        ValueError: If a value token cannot be converted to ``float``.
    """
    values = defaultdict(list)

    with open(log_file, 'r') as f:
        header_fields = f.readline().split('|')[1:]
        # The last '|'-separated field of the header is not a parameter.
        nmr_params = [field.split()[1] for field in header_fields[:-1]]

        for line in f:
            # Blank lines (e.g. a trailing empty line) carry no values.
            if not line.strip():
                continue
            # Split once per line rather than once per parameter.
            val_block = line.split('|')[1:]
            for i, name in enumerate(nmr_params):
                values[name].append(float(val_block[i].split()[1]))

    return values

def parse_batch_log_files(log_files_path):
    """Parse every ``*.txt`` log in a directory and tag each by its file name.

    The tag of a log is the third '_'-separated component of its file name.

    Args:
        log_files_path: Directory that is searched (non-recursively) for
            ``*.txt`` files.

    Returns:
        A ``(data_tag, data_list)`` tuple of two parallel lists: the tag of
        each file and the result of ``parse_log_files`` for that file.
        Both lists are empty when no file matches.

    Raises:
        IndexError: If a file name has fewer than three '_'-separated parts.
    """
    data_list = []
    data_tag = []
    # glob returns matches in arbitrary order; sort them so the order of the
    # tags (and of anything plotted from them) is reproducible.
    for file in sorted(glob.glob(f"{log_files_path}/*.txt")):
        # basename copes with the platform's path separator, which a plain
        # split on '/' does not.
        tag = os.path.basename(file).split('_')[2]
        data_list.append(parse_log_files(file))
        data_tag.append(tag)

    return data_tag, data_list

In [None]:
## Make graphs

import sys
import os
from pathlib import Path
sys.path.append(os.path.realpath(os.path.dirname("TROY_PROJECT_eval_iter_log.txt"))+'/../')

import matplotlib.pyplot as plt

def plot_training_errors(data, tr_label='training_set', te_label='test_set', max_epoch=None, show=True, save=True):
    """Plot one error-versus-epoch graph for every key in *data*.

    Args:
        data: Mapping from a key (used in the legend, title and file name)
            to a sequence of error values; the position of a value in the
            sequence is used as its epoch number.
        tr_label: Label used in the title and in the output path.
        te_label: Label used in the title.
        max_epoch: Index of the last epoch to plot (inclusive). ``None``
            plots every value.
        show: When true, call ``plt.show()`` for each graph.
        save: When true, write each graph to
            ``graphs/{tr_label}/{tr_label}_{key}_training_error.png``,
            creating the directory if needed.
    """
    for key, values in data.items():
        # Start every graph from an empty figure; otherwise, whenever the
        # figure survives the previous iteration, curves of earlier keys
        # would also be drawn on (and saved with) this one.
        plt.clf()

        errors = []
        for i, error in enumerate(values):
            errors.append(error)
            if i == max_epoch:
                break
        epochs = list(range(len(errors)))

        plt.plot(epochs, errors, label=key)
        plt.title(f'{tr_label} errors for {key} prediction against {te_label}')
        plt.xlabel('Epochs')
        plt.ylabel('Loss (MAE)')
        plt.legend()
        # Keep a handle on the figure before showing it so it can still be
        # saved afterwards.
        fig = plt.gcf()
        if show:
            plt.show()
        if save:
            out_dir = f'graphs/{tr_label}'
            os.makedirs(out_dir, exist_ok=True)
            fig.savefig(f'{out_dir}/{tr_label}_{key}_training_error.png', format='png')

def plot_comparison_errors(data_tag, data_list, tr_label='training_set', te_label='test_set', max_epoch=None, show=True, save=True, log_y=False, log_x=False):
    """Plot, for every key, the error curves of several data sets on one graph.

    The keys are taken from the first entry of *data_list*; each graph holds
    one curve per ``(tag, data)`` pair, labelled with the tag.

    Args:
        data_tag: Legend labels, parallel to *data_list*.
        data_list: Mappings from key to a sequence of error values; the
            position of a value in the sequence is used as its epoch number.
        tr_label: Label used in the title and in the output path.
        te_label: Label used in the title.
        max_epoch: Index of the last epoch to plot (inclusive). ``None``
            plots every value.
        show: When true, call ``plt.show()`` for each graph.
        save: When true, write each graph to
            ``graphs/{tr_label}/{tr_label}_{key}_training_error.png``,
            creating the directory if needed.
        log_y: When true, use a logarithmic y axis.
        log_x: When true, use a logarithmic x axis.
    """
    # Nothing to compare (e.g. no log file was found): there is no first
    # entry to take the keys from, so return instead of failing on it.
    if not data_list:
        return

    for key in data_list[0].keys():
        plt.clf()
        for tag, data in zip(data_tag, data_list):
            errors = []
            for i, error in enumerate(data[key]):
                errors.append(error)
                if i == max_epoch:
                    break
            epochs = list(range(len(errors)))

            plt.plot(epochs, errors, label=tag)
        plt.title(f'{tr_label} errors for {key} prediction against {te_label}')
        plt.xlabel('Epochs')
        plt.ylabel('Loss (MAE)')
        if log_y:
            plt.yscale('log')
        if log_x:
            plt.xscale('log')
        plt.legend()
        # Keep a handle on the figure before showing it so it can still be
        # saved afterwards.
        fig = plt.gcf()
        if show:
            plt.show()
        if save:
            out_dir = f'graphs/{tr_label}'
            os.makedirs(out_dir, exist_ok=True)
            fig.savefig(f'{out_dir}/{tr_label}_{key}_training_error.png', format='png')