In [4]:
import logging
import os
import pickle
import re
from glob import glob

# Configure the root logger: records at INFO level and above are emitted as
# "[time] LEVEL - logger name: message".
logging.basicConfig(
    level=logging.INFO,
    datefmt="%m/%d/%Y %H:%M:%S",
    format="[%(asctime)s] %(levelname)s - %(name)s: %(message)s",
)

# Logger named after this module; the parsing helpers below report through it.
logger = logging.getLogger(__name__)

In [5]:
# Checkpoint root directory of each experiment type, keyed by type name.
ROOTS = dict(
    naive="../resnet256_naive_checkpoints",
    augment="../resnet256_augmentation_checkpoints",
)

### TODO
1. Retrieve the learning curve for each seed
2. -


In [6]:
# Map each experiment type to the entries found directly under its root.
# glob() gives no ordering guarantee, so the matches are sorted: a positional
# index into one of these lists then selects the same run on every execution.
RUNS_LIST = {
    _type: sorted(glob(ROOTS[_type] + "/*")) for _type in ROOTS
}

In [7]:
# "*.pt" files inside the "encoder" folder of the first listed naive run,
# ordered lexicographically by path.
run_data = glob(RUNS_LIST["naive"][0] + "/encoder/*.pt")
run_data.sort()

In [8]:
def get_runs(idx: int = 0, which: str = "naive", paths: dict = RUNS_LIST) -> list:

    """
    Return the sorted checkpoint files of a single run.

    idx:
        Specify which run you will return: the position of the run inside
        ``paths[which]``.
    which:
        Key selecting one list of runs in ``paths`` (e.g. "naive").
    paths:
        Mapping from experiment type to the list of runs found under
        ROOT/runs. The default is the ``RUNS_LIST`` object that exists at the
        moment this function is defined; pass the mapping explicitly if
        ``RUNS_LIST`` has been rebuilt since.

    Returns the paths matching ``<run>/encoder/*.pt`` in lexicographic order
    (an empty list when nothing matches).

    Raises IndexError when ``idx`` is out of range and, for a dict ``paths``,
    KeyError when ``which`` is not one of its keys.
    """

    run_dir = paths[which][idx]
    return sorted(glob(run_dir + "/encoder/*.pt"))

def epoch_parser(run_name: str) -> int:

    """
    Extract the epoch number from a checkpoint file name.

    run_name: will look like as follows
        ep000_mae57.33.pt
        epXXX_maeXX.XX.pt

    Every digit that directly follows the leading "ep" is read, so counters
    that are not exactly three digits wide (ep7_..., ep1000_...) are parsed
    correctly instead of being truncated or rejected.

    Returns the epoch as an int.
    Raises ValueError when the name does not start with "ep" followed by at
    least one digit. Any parsing failure is logged and then re-raised.
    """

    try:
        matched = re.match(r"ep(\d+)", run_name)
        if matched is None:
            raise ValueError(f"no 'ep<digits>' prefix in {run_name!r}")
        return int(matched.group(1))

    except Exception:
        # Record which name was rejected, then let the caller see the error.
        logger.error(f"Received {run_name} and couldn't parse into EPOCH with a given run name.")
        raise

def mae_parser(run_name: str) -> float:

    """
    Extract the MAE value from a checkpoint file name.

    run_name: will look like as follows
        ep000_mae57.33.pt
        epXXX_maeXX.XX.pt

    The text after the last "mae" is taken, a trailing ".pt" extension is
    removed if present, and the remainder is converted with float().

    Returns the MAE as a float.
    Raises ValueError when the remainder is not a valid float literal. Any
    parsing failure is logged and then re-raised.
    """

    try:
        mae = run_name.split("mae")[-1]
        # str.rstrip(".pt") would strip any trailing '.', 'p' or 't'
        # characters rather than the extension, so cut the literal suffix.
        if mae.endswith(".pt"):
            mae = mae[: -len(".pt")]
        return float(mae)
    except Exception:
        # Record which name was rejected, then let the caller see the error.
        logger.error(f"Received {run_name} and couldn't parse into MAE with a given run name.")
        raise

def to_epoch_mae_tuple(path: str) -> tuple:

    """
    Turn a checkpoint path into an ``(epoch, mae)`` pair.

    path:
        Path to a checkpoint file; only its final component is parsed, and
        that component is expected to look like ``epXXX_maeXX.XX.pt``.

    Returns ``(epoch_parser(name), mae_parser(name))`` for that component.
    Errors raised by either parser propagate to the caller.
    """

    # os.path.basename honours the platform's path separators, which a
    # hard-coded split("/") does not.
    run_name = os.path.basename(path)
    return epoch_parser(run_name), mae_parser(run_name)
    

In [9]:
# For every run, collect the (epoch, mae) pairs parsed from its checkpoint
# file names, keyed by the base name of the run's path.
# RUNS_LIST is handed to get_runs explicitly so the lookup uses the same
# mapping that is being enumerated here, not the object captured as the
# function's default when it was defined.
naive_results = {
    os.path.basename(path): [
        to_epoch_mae_tuple(ckpt) for ckpt in get_runs(idx, "naive", RUNS_LIST)
    ]
    for idx, path in enumerate(RUNS_LIST["naive"])
}

augment_results = {
    os.path.basename(path): [
        to_epoch_mae_tuple(ckpt) for ckpt in get_runs(idx, "augment", RUNS_LIST)
    ]
    for idx, path in enumerate(RUNS_LIST["augment"])
}

In [11]:
# Persist both result dictionaries as pickles. The output folder is created
# first so that a fresh checkout without ./data does not fail with
# FileNotFoundError when the files are opened for writing.
os.makedirs("./data", exist_ok=True)
with open("./data/naive_nonreg_results_renewal.pkl", "wb") as f:
    pickle.dump(naive_results, f)
with open("./data/augment_nonreg_results_renewal.pkl", "wb") as f:
    pickle.dump(augment_results, f)

### Gather Test Dataset Results