### Results Analysis, Formally

In [37]:
#### import packages and configurations
import os
import re
import glob

import pandas as pd

pd.set_option("display.max_rows", 100)



#### method to parse a single output file
def parse_output(fpath, splits):
    '''
    splits = ("TRAIN", "VAL", "Blond_Hair0Male0", "Blond_Hair0Male1", "Blond_Hair1Male0", "Blond_Hair1Male1")
    splits = ("TRAIN", "VAL", "y0place0", "y0place1", "y1place0", "y1place1")
    '''
    arg_pattern = re.compile(r"^(\S+)\s*:\s+(\S+)$")

    splits_raw_str = '|'.join(splits)
    split_acc_pattern_raw_str = rf"({splits_raw_str})\s+(\d+(?:\.\d+)?)"
    split_acc_pattern = re.compile(split_acc_pattern_raw_str)

    epoch_idx_pattern = re.compile(r"^Epoch\s\[(\d+)/\d+\]")
    
    args_dict = dict()
    results_dict = {split: dict() for split in splits}
    with open(fpath) as f:
        lines = f.readlines()

    for line in lines:
        arg_match = arg_pattern.match(line)
        epoch_match = epoch_idx_pattern.match(line)
        
        if arg_match:
            assert not epoch_match
            k, v = arg_match.groups()
            args_dict[k] = v
        
        if epoch_match:
            assert not arg_match
            epoch_idx = int(epoch_match.group(1))
            if "TEST" in line:
                line = line.split("TEST")[0]
            results = re.findall(split_acc_pattern, line)
            assert results

            for result in results:
                results_dict[result[0]][epoch_idx] = float(result[1])
        
    return args_dict, results_dict



def res_select(results_dict, mode):
    '''
    modes:
        "by_val": select the highest val epoch
        "by_worst": select the highest worst acc epoch
        "by_second": select the highest traditional group acc epoch
    '''
    
    if mode == "by_val":
        best_epoch = max(results_dict["VAL"], key=results_dict["VAL"].get)
    elif mode == "by_worst":
        epochs = set(epoch for group_dict in results_dict.values() for epoch in group_dict)
        best_epoch = max(epochs, key=lambda epoch: min(results_dict[group][epoch] for group in results_dict))
    elif mode == "by_second":
        #best_epoch = max(results_dict["y1place0"], key=results_dict["y1place0"].get)
        best_epoch = max(results_dict["Blond_Hair1Male1"], key=results_dict["Blond_Hair1Male1"].get)
    else:
        raise Exception("Mode not in..")

    return best_epoch, {g: results_dict[g][best_epoch] for g in results_dict}



def collect_results(
    folder, 
    args_to_select,
    mode="by_worst",
    splits=("TRAIN", "VAL", "Blond_Hair0Male0", "Blond_Hair0Male1", "Blond_Hair1Male0", "Blond_Hair1Male1"), 
    #splits = ("TRAIN", "VAL", "y0place0", "y0place1", "y1place0", "y1place1"),
    filename_pattern="e_*.out"):
    '''
    args_to_select is a dict with keys being arg names and values being types
    '''
    type_map = {
        "str": str,
        "int": int,
        "float": float,
    }
    
    raw_df = []
    
    for fpath in glob.glob(os.path.join(folder, filename_pattern)):
        raw_row = dict()
        args, results = parse_output(fpath, splits)
        best_epoch, best_results = res_select(results, mode=mode)
        raw_row["best_epoch"] = best_epoch
        
        for arg, arg_type in args_to_select.items():
            raw_value = args.get(arg)
            raw_row[arg] = type_map[arg_type](raw_value)

        for group_name, acc in best_results.items():
            raw_row[group_name] = acc
        
        id_match = re.search(r"e_(\d+)\.out", fpath)
        if id_match:
            index = int(id_match.group(1))
        else:
            raise Exception("Error!")
        raw_row["id"] = index
        
        raw_df.append(raw_row)
    
    df = pd.DataFrame(raw_df)
    return df

In [38]:
folder = "/home/ym2380/mask_robust/logs/aug/celebA/array09"
args_to_select = {
    "mask_rate1": "float",
    #"batch_size": "int",
    "lr": "float",
    "weight_decay": "float",
    "optimizer": "str",
}
df = collect_results(folder=folder, args_to_select=args_to_select)

In [39]:
pd.set_option("display.max_rows", 100)
#df['worst_group_acc'] = df[["y0place0", "y0place1", "y1place0", "y1place1"]].min(axis=1)
df['worst_group_acc'] = df[["Blond_Hair0Male0", "Blond_Hair0Male1", "Blond_Hair1Male0", "Blond_Hair1Male1"]].min(axis=1)

In [40]:
#df = df.drop(columns=["TRAIN", "VAL", ])
#df = df[df["lr"] == 1e-4]
#df = df[(df["lr"] == 1e-3) & (df["weight_decay"] == 1e-4)]

df = df.sort_values("worst_group_acc")
df

Unnamed: 0,best_epoch,mask_rate1,lr,weight_decay,optimizer,TRAIN,VAL,Blond_Hair0Male0,Blond_Hair0Male1,Blond_Hair1Male0,Blond_Hair1Male1,id,worst_group_acc
12,49,0.9,1e-05,0.001,SGD,0.906,0.8834,0.9933,0.9996,0.2756,0.0385,22,0.0385
4,47,0.8,1e-05,0.0001,SGD,0.924,0.9186,0.9753,0.9973,0.5762,0.0879,21,0.0879
18,50,0.8,1e-05,0.001,SGD,0.925,0.9244,0.9712,0.9986,0.6249,0.0879,20,0.0879
21,45,0.9,1e-05,0.0001,SGD,0.906,0.9127,0.9761,0.9981,0.5289,0.1154,23,0.1154
6,48,0.8,3e-05,0.001,SGD,0.932,0.9378,0.9489,0.9947,0.7898,0.1703,12,0.1703
15,43,0.8,3e-05,0.0001,SGD,0.931,0.9394,0.9509,0.996,0.7905,0.1813,13,0.1813
8,46,0.9,3e-05,0.0001,SGD,0.913,0.9317,0.9511,0.9946,0.7404,0.1813,15,0.1813
0,50,0.9,3e-05,0.001,SGD,0.916,0.9351,0.9486,0.9938,0.7738,0.1813,14,0.1813
16,49,0.8,0.0001,0.0001,SGD,0.94,0.9412,0.917,0.9934,0.904,0.2857,5,0.2857
19,44,0.8,0.0001,0.001,SGD,0.939,0.9421,0.9215,0.9936,0.896,0.2912,4,0.2912


# 