In [1]:
import pandas as pd
import re
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from IPython.display import clear_output
import os
import time

plt.style.use('material')

# Read File

In [2]:
def get_raw_runs():
    with open("logs/trainer.log") as f:
        data = f.read()
    runs = data.split('=== NEW FULL TRAIN ===')
    return [run for run in runs if 'Starting at' in run]

# Process

In [3]:
def accuracy(raw_line,prefix,run_details):
    top1 = "{} Top-1".format(prefix)
    top3 = "{} Top-3".format(prefix)
    
    
    if run_details.get(top1) is None:
        run_details[top1]=[]
    if 'Top-1' in raw_line:
        run_details[top1].append(float(raw_line.split(':',2)[-1].split("%")[0]))
    if 'Top-3' in raw_line:
        if trains.get(top3) is None:
            run_details[top3]=[]
        run_details[top3].append(float(raw_line.split(':',2)[-1].split("%")[0]))
    if 'Top' not in raw_line:
        run_details[top1].append(float(raw_line.split(':',1)[-1].split("%")[0]))
    return run_details

def proc_run(run):
    run_details = {}
    curr_epoch = -1
    deadhead_history = []
    param_history = []
    a_loss, e_loss, i_loss = [],[],[]
    
    prev_line = ""
    for raw_line in run.split("\n"):
        if raw_line == prev_line:
            prev_line = raw_line
            continue
            
        # add new values to histories
        if 'EPOCH' in raw_line:
            deadhead_history.append(0)
            param_history.append(0)
        
        # track model stats
        if 'Starting at' in raw_line:
            run_details['Start Time']=raw_line.split("Starting at")[1]
        if 'Dim' in raw_line:
            if 'torch' in raw_line:
                raw_line=raw_line.replace(")","").replace("torch.Size(","")
            detail_str = re.split(',(?=\s[A-Za-z])',raw_line)
            new_str = ''
            for detail in detail_str:
                k,v = detail.split(":")
                new_str += "'"+k.strip()+"':"+v
                if detail!=detail_str[-1]:
                    new_str += ", "
            locals_ = locals()
            if '<' in new_str and '>' in new_str:
                new_str=new_str.split("<")[0]+'None'+new_str.split(">")[-1]
            exec('details={'+new_str+"}",None,locals_)
            run_details.update(locals_['details'])
        
        # add accuracies
        if 'Train Corrects:' in raw_line:
            run_details = accuracy(raw_line,'Train',run_details)     
        elif ('Last Towers Test' in raw_line and 'Corrects' in raw_line) or ('Test' in raw_line and 'Towers' not in raw_line and 'Corrects' in raw_line) or 'test acc' in raw_line:
            run_details = accuracy(raw_line,'LT Test',run_details)
        elif ('All Towers Test' in raw_line and 'Corrects' in raw_line):
            run_details = accuracy(raw_line,'AT Test',run_details)
        
        # track deadheading
        if 'Deadheaded' in raw_line:
            deadhead_history[-1] = -int(raw_line.split('Deadheaded')[-1].split("operations")[0])
        if 'Param Delta' in raw_line:
            param_history[-1] = [int(raw_line.split('Param Delta:')[-1].split("->")[0].replace(",","")),
                                 int(raw_line.split('->')[-1].replace(",",""))]
        
        # track loss_comps
        if 'Train Loss Components' in raw_line:
            loss_comps = raw_line.split(':')
            a,e,i = [float(loss.split(",")[0]) for loss in loss_comps[2:]]
            a_loss, e_loss, i_loss = a_loss+[a], e_loss+[e], i_loss+[i]
        if raw_line != "":
            prev_line = raw_line
    run_details['Loss Accuracy']=a_loss
    run_details['Loss Edge']=e_loss
    run_details['Loss Input']=i_loss

    # flatten param hist
    if [x for x in param_history if type(x) is not int]:
        prev_val = [hist[0] for hist in param_history if hist!=0][0]
        new_param_hist = []
        for val in param_history:
            if val==0:
                new_param_hist.append(prev_val)
            else:
                new_param_hist.append(val[1])
                prev_val = val[1]
        run_details['Params'] = new_param_hist  
    run_details['Deadhead'] = deadhead_history
    return run_details
  
def proc_all_runs():
    runs = [proc_run(run) for run in get_raw_runs()]
    runs = pd.DataFrame(runs)
    for col in [col for col in list(runs) if 'Top' in col]:
        runs[col] = runs[col].apply(lambda x: x if type(x)==list else [])
    runs['LT Test Top-1 Max'] = runs['LT Test Top-1'].apply(lambda x: max(x, default=0))
    runs['AT Test Top-1 Max'] = runs['AT Test Top-1'].apply(lambda x: max(x, default=0))
    runs['Epoch']=runs['Train Top-1'].apply(len)
    return runs

df = proc_all_runs()

In [None]:
#df.iloc[-1]

In [25]:
def cum_max(l):
    return [max(l[:i+1]) for i in range(len(l))]

def roll_ave(l):
    n=10
    return [np.mean(l[max(0,i-n):i+n+1]) for i in range(len(l))]

# Visualize

In [29]:
def visualize(smooth_f,compare_id=None):
    while 1:
        file_mod = os.stat('logs/trainer.log').st_mtime
        clear_output()
        runs = proc_all_runs()

        try:
            plt.figure(figsize=(14,14),dpi=150)
            full_runs = runs[runs['LT Test Top-1'].apply(lambda x: len(x)>512)].sort_values(by='LT Test Top-1 Max',ascending=False)

            if compare_id:
                compare_str = compare_id
                compare = full_runs[full_runs['ID']==compare_id]['LT Test Top-1'].values[0]
            else:
                compare = full_runs['LT Test Top-1'].values[0]
                compare_str = 'PR'
            full_runs = full_runs['LT Test Top-1'].values
            
            at,lt = max(runs.iloc[-1]['AT Test Top-1']),max(runs.iloc[-1]['LT Test Top-1'])
            at_last,lt_last= runs.iloc[-1]['AT Test Top-1'][-1],runs.iloc[-1]['LT Test Top-1'][-1]
            if lt>at:
                curr_run = runs.iloc[-1]['LT Test Top-1']
                agg_type = "Last Tower"
            else:
                curr_run = runs.iloc[-1]['AT Test Top-1']
                agg_type = "All Towers"
            cm = plt.cm.Spectral
            for i,run in enumerate(full_runs):
                plt.plot(smooth_f(run),color=cm(i/len(full_runs)),alpha=.75 if i==0 else .5)

            epoch = len(curr_run)-1
            if epoch<0:
                print("No log yet...")
            else:
                curr,curr_max,curr_arg_max = curr_run[epoch],max(curr_run),np.argmax(curr_run)
                rec, rec_max, rec_arg_max = compare[epoch], max(compare[:epoch+1]),np.argmax(compare[:epoch+1])
                print("==== EPOCH {} ======================================".format(epoch))
                print("AT Max: {} LT Max: {}".format(at,lt))
                print("AT Last: {} LT Last: {}".format(at_last,lt_last))
                print("Current Delta to {}:     {:> 2.2f}% ({}% vs {}%)".format(compare_str,curr-rec,curr,rec))
                print("Current Delta to {} Max: {:> 2.2f}% ({}% @{} vs {}% @{})".format(compare_str,curr_max-rec_max,curr_max,curr_arg_max,rec_max,rec_arg_max))

                plt.plot(smooth_f(curr_run),color='k',linewidth=2)
                plt.plot(curr_run,color='k',alpha=.25,linewidth=1)
                plt.ylim(min(curr_run[-10:]),100)

                plt.title("CIFAR-10 Loss History, Bonsai Net")
                plt.xlabel("Epoch",fontsize=14)
                plt.xticks(fontsize=14)
                plt.ylabel("Accuracy",fontsize=14)
                plt.yticks(fontsize=14)
                plt.show()
        except KeyboardInterrupt:
            break
        except Exception as e: 
            raise e
        while file_mod == os.stat('logs/trainer.log').st_mtime or ('EPOCH' not in open('logs/trainer.log',"r").readlines()[-1]):
            time.sleep(1)
        
visualize(cum_max)

ValueError: max() arg is an empty sequence

<Figure size 2100x2100 with 0 Axes>