In [1]:
import numpy as np
from matplotlib import pyplot as plt
import pdb
import os
import git

# Locate the git working tree that contains the current directory; with
# search_parent_directories=True, git.Repo walks upward until it finds it.
repo = git.Repo(path='.', search_parent_directories=True)
BASE_DIR = repo.working_tree_dir

In [2]:
def parse_doc(path, val_only=False, debug=False):
    """Parse a training log into per-epoch train and validation metric arrays.

    The log is scanned for two kinds of header line:

    * a training header, whose first space-separated token is ``Epoch``; the
      text between ``Epoch `` and the first comma is parsed as a float epoch;
    * a validation header, i.e. any other line containing ``Validation``.

    The line immediately after each header must hold the metrics as
    ``, <name> <value>`` fields; ``ade1``, ``fde1``, ``ade``, ``fde`` and
    ``brier_fde`` are extracted from it. A header on the very last line (no
    metrics line after it, e.g. a log still being written) is ignored.

    Args:
        path: Path of the log file to read.
        val_only: If True, training entries are not collected and both
            training arrays are returned empty. Training headers are still
            read to track the epoch used as x-value for validation entries.
        debug: If True, print every parsed training entry.

    Returns:
        Tuple ``(epoch_values_train, log_metrics_train, epoch_values_val,
        log_metrics_val)`` of numpy arrays. Each metrics row is
        ``[ade1, fde1, ade-k, fde-k, brier_fde]``. A validation entry's epoch
        is ``int()`` of the latest training epoch seen before it (0 if none).
        When no entries of a kind are found, its arrays are empty 1-D arrays.

    Raises:
        ValueError, IndexError: if a header or metrics line does not follow
            the expected format.
    """
    # Matched as ", <name> " so that e.g. "ade" never matches "ade1"; the tuple
    # order fixes the column order of every returned metrics row.
    metric_fields = ('ade1', 'fde1', 'ade', 'fde', 'brier_fde')

    with open(path) as f:
        lines = f.readlines()

    log_metrics_train = []
    epoch_values_train = []
    log_metrics_val = []
    epoch_values_val = []
    current_epoch_val = 0

    # Each header needs a following metrics line, so the last line is never
    # treated as a header (previously lines[cnt + 1] raised IndexError there).
    for cnt in range(len(lines) - 1):
        line = lines[cnt]
        is_train = line.split(" ")[0] == "Epoch"
        if not is_train and "Validation" not in line:
            continue

        if is_train:
            epoch_value = float(line.split(",")[0].split(" ")[1])
            # x-value for subsequent validation entries: the integer part of
            # the latest training epoch (int() truncates, it does not round).
            current_epoch_val = int(epoch_value)
            if val_only:
                continue

        metrics = lines[cnt + 1]
        ade1, fde1, adek, fdek, brier = (
            float(metrics.split(f', {name} ')[1].split(',')[0].strip())
            for name in metric_fields
        )
        row = [ade1, fde1, adek, fdek, brier]

        if is_train:
            epoch_values_train.append(epoch_value)
            log_metrics_train.append(row)
            if debug:
                print(f'Epoch {epoch_value} ADE1 {ade1} FDE1 {fde1} ADE-k {adek} FDE-k {fdek} BRIER_FDE {brier}')
                print(10*'---')
        else:
            epoch_values_val.append(current_epoch_val)
            log_metrics_val.append(row)

    return (
        np.array(epoch_values_train),
        np.array(log_metrics_train),
        np.array(epoch_values_val),
        np.array(log_metrics_val),
    )

In [3]:
# Experiment logs to compare; the paths are opened relative to the current
# working directory by parse_doc.
logs = [
    "results/GANet/log",  # teacher (full GANet)
    "results_student/GANet/log",  # student (GANet without map)
    "results_ganet_without_mapindecoder/GANet/log",
    "results_student_lanegcn/GANet/log",
    "exp_wo_actornet/GANet/log",
    "exp_agent_gnn_dim_6/GANet/log",
    "exp_agent_gnn_dim_6_latent_64/GANet/log",
    "exp_agent_gnn_dim_6_latent_128/GANet/log",
    "exp_agent_gnn_dim_6_latent_128_again/GANet/log",
    "exp_agent_gnn_dim_6_latent_128_aug/GANet/log",
]

# Output folder for the figures, under the git working-tree root.
RESULTS_PATH = os.path.join(BASE_DIR, "metrics")
if not os.path.exists(RESULTS_PATH):
    print("Create results path folder: ", RESULTS_PATH)
    # Unlike os.mkdir, os.makedirs also creates missing parent directories.
    os.makedirs(RESULTS_PATH)

# When True, dmetrics below contains only validation metrics.
VAL_ONLY = False

# Column of each metric in the rows returned by parse_doc
# ([ade1, fde1, ade-k, fde-k, brier_fde]).
_METRIC_COLUMNS = {'ade1': 0, 'fde1': 1, 'adek': 2, 'fdek': 3, 'brierFDE': 4}
_SPLITS = ('val',) if VAL_ONLY else ('train', 'val')

# "<metric>_<split>" -> column index, all train entries before all val entries.
dmetrics = {
    f'{name}_{split}': column
    for split in _SPLITS
    for name, column in _METRIC_COLUMNS.items()
}

In [4]:
# Parse every log once up front instead of re-reading each file per metric.
parsed_logs = {
    experiment: parse_doc(experiment, val_only=VAL_ONLY, debug=False)
    for experiment in logs
}

# One figure per metric, with one curve per experiment, saved to RESULTS_PATH.
for metric, column in dmetrics.items():
    print(metric)
    is_train = "train" in metric
    fig, ax = plt.subplots()

    for experiment, (epochs_train, metrics_train, epochs_val, metrics_val) in parsed_logs.items():
        if is_train:
            epochs, values, split_name = epochs_train, metrics_train, 'training'
        elif "val" in metric:
            epochs, values, split_name = epochs_val, metrics_val, 'validation'
        else:
            continue

        # A log with no entries of this kind comes back from parse_doc as an
        # empty 1-D array; report it explicitly instead of catching every error.
        if values.ndim != 2 or len(values) == 0:
            print(f'{experiment} has no {split_name} metrics')
            continue
        ax.plot(epochs, values[:, column], label=experiment.replace('.txt', ''))

    if is_train:
        # Cap the y-axis at 5 for training curves (presumably so large early
        # values do not flatten the rest of the curves).
        ax.set_ylim(top=5)
    ax.set_xlabel('epoch')
    ax.set_ylabel(metric)
    ax.set_title(metric)
    if ax.get_legend_handles_labels()[0]:
        ax.legend()
    fig.savefig(os.path.join(RESULTS_PATH, f'{metric}.png'))
    plt.close(fig)

ade1_train
fde1_train
adek_train
fdek_train
brierFDE_train
ade1_val
fde1_val
adek_val
fdek_val
brierFDE_val
