In [1]:
import logging
import pickle
from glob import glob

logging.basicConfig(
    format="[%(asctime)s] %(levelname)s - %(name)s: %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

In [2]:
ROOTS = {
    "naive": "../resnet256_naive_nonreg_checkpoints",
    "augment": "../resnet256_augmentation_nonreg_checkpoints",
}

### #TODO
1. Retrieve learning curve for each seeds
2. -


In [3]:
RUNS_LIST = {
    _type: glob(ROOTS[_type] + "/*") for _type in ROOTS
}

In [4]:
run_data = sorted(glob(RUNS_LIST["naive"][0] + "/encoder/*.pt"))

In [5]:
def get_runs(idx: int = 0, which: str = "naive", paths: list = RUNS_LIST):

    """
    paths:
        List of runs containing ROOT/runs
    idx:
        Specify which run you will return.
    """
    
    runs = sorted(glob(paths[which][idx] + "/encoder/*.pt"))
    return runs

def epoch_parser(run_name: str):

    """
    run_name: will look like as follows
        ep000_mae57.33.pt
        epXXX_maeXX.XX.pt
    """

    try:
        epoch = run_name[2:5]
        return int(epoch)

    except:
        logger.info(f"Received {run_name} and couldn't parse into EPOCH with a given run name.")
        raise

def mae_parser(run_name: str):

    """
    run_name: will look like as follows
        ep000_mae57.33.pt
        epXXX_maeXX.XX.pt
    """

    try:
        mae = run_name.split("mae")[-1].rstrip(".pt")
        return float(mae)
    except:
        logger.info(f"Received {run_name} and couldn't parse into MAE with a given run name.")
        raise

def to_epoch_mae_tuple(path: str):

    run_name = path.split("/")[-1]
    return epoch_parser(run_name), mae_parser(run_name)
    

In [15]:
naive_results = {
    path.split("/")[-1]: list(map(to_epoch_mae_tuple, get_runs(idx)))
    for idx, path in enumerate(RUNS_LIST["naive"])
}

augment_results = {
    path.split("/")[-1]: list(map(to_epoch_mae_tuple, get_runs(idx, "augment")))
     for idx, path in enumerate(RUNS_LIST["augment"])
}

In [16]:
with open("./data/naive_nonreg_results.pkl", "wb") as f:
    pickle.dump(naive_results, f)
with open("./data/augment_nonreg_results.pkl", "wb") as f:
    pickle.dump(augment_results, f)

In [17]:
naive_results

{'20220120-0342_resnet': [(0, 53.82),
  (1, 47.84),
  (2, 51.99),
  (3, 12.69),
  (4, 24.78),
  (5, 42.89),
  (6, 24.59),
  (7, 11.59),
  (8, 34.98),
  (9, 26.34),
  (10, 40.49),
  (11, 8.14),
  (12, 8.67),
  (13, 7.28),
  (14, 9.44),
  (15, 7.31),
  (16, 9.42),
  (17, 8.38),
  (18, 8.17),
  (19, 14.16),
  (20, 15.87),
  (21, 8.32),
  (22, 15.77),
  (23, 7.8),
  (24, 9.27),
  (25, 7.93),
  (26, 18.2),
  (27, 16.65),
  (28, 7.85),
  (29, 10.28),
  (30, 8.93),
  (31, 7.54),
  (32, 8.56),
  (33, 10.35)],
 '20220120-1318_resnet': [(0, 55.8),
  (1, 27.79),
  (2, 17.63),
  (3, 13.78),
  (4, 9.86),
  (5, 12.02),
  (6, 43.39),
  (7, 9.34),
  (8, 20.76),
  (9, 9.89),
  (10, 16.31),
  (11, 9.78),
  (12, 10.18),
  (13, 24.81),
  (14, 10.58),
  (15, 14.06),
  (16, 8.77),
  (17, 8.62),
  (18, 8.43),
  (19, 9.04),
  (20, 10.04),
  (21, 13.47),
  (22, 14.62),
  (23, 18.04),
  (24, 15.11),
  (25, 6.93),
  (26, 7.81),
  (27, 8.08),
  (28, 11.6),
  (29, 7.65),
  (30, 11.9),
  (31, 13.9),
  (32, 10.59),


In [18]:
len(augment_results)

99

In [19]:
with open("./data/naive_results.pkl", "rb") as f:
    data = pickle.load(f)