In [1]:
import pickle
from scipy import stats
from functools import partial

import matplotlib.pyplot as plt
import seaborn as sns
sns.set_theme();

import numpy as np
import nibabel as nib

import torch
import torch.nn.functional as F

from sage.visualization.vistool import plot_vismap
from utils.analysis import Result, FileSelector, check_existence, cherry_picker, transform



## Check which of the expected 400 runs are missing

In [10]:
from pathlib import Path
from glob import glob

from sage.config import load_config

In [18]:
def check_seeds(header):
    """Print which training seeds are present or missing under a checkpoint root.

    Every immediate sub-directory of ``header`` is treated as one run. Its
    ``config.yml`` is read with ``load_config`` and the ``seed`` attribute is
    collected. The seeds are expected to cover the contiguous integer range
    ``[min_seed, max_seed]``; any gaps in that range are reported as missing.

    Args:
        header: Checkpoint root directory (str or Path), e.g.
            "../resnet256_naive_checkpoints/".

    Returns:
        None. The report is printed to stdout.
    """
    print(f"Checking {header} ...")

    # glob.glob (unlike pathlib's glob) skips dot-entries. Keep only directories
    # so stray files (logs, notes) next to the runs are not treated as runs.
    run_dirs = [
        Path(run) for run in sorted(glob(str(Path(header, "*")))) if Path(run).is_dir()
    ]

    # A run without config.yml cannot tell us its seed; count and report it
    # instead of letting load_config fail on a missing file.
    config_paths = [run / "config.yml" for run in run_dirs]
    readable = [path for path in config_paths if path.is_file()]
    num_skipped = len(config_paths) - len(readable)
    if num_skipped:
        print(f"Skipped {num_skipped} run(s) without config.yml")

    seed_list = [load_config(path).seed for path in readable]
    if not seed_list:
        # min()/max() below would raise on an empty collection.
        print("No runs with a config.yml found")
        print()
        return

    seeds = set(seed_list)
    num_seeds = len(seeds)
    min_seed, max_seed = min(seeds), max(seeds)
    expected = max_seed - min_seed + 1

    print(f"Total {num_seeds} seeds")
    print(f"Min {min_seed} | Max {max_seed}")
    print(f"There should be {expected} seeds and you have {num_seeds}")
    if len(seed_list) != num_seeds:
        # Several runs share a seed; the set above would otherwise hide this.
        print(f"Duplicate seeds: {len(seed_list) - num_seeds} extra run(s) reuse a seed")
    if expected != num_seeds:
        print(f"Missing {set(range(min_seed, max_seed + 1)) - seeds}")
    print()

In [19]:
# Verify that every checkpoint root contains a contiguous range of seeds.
for checkpoint_root in (
    "../resnet256_augmentation_checkpoints/",
    "../resnet256_augmentation_nonreg_checkpoints/",
    "../resnet256_naive_checkpoints/",
    "../resnet256_naive_nonreg_checkpoints/",
):
    check_seeds(checkpoint_root)

Checking ../resnet256_augmentation_checkpoints/ ...
Total 100 seeds
Min 43 | Max 142
There should be 100 seeds and you have 100

Checking ../resnet256_augmentation_nonreg_checkpoints/ ...
Total 99 seeds
Min 42 | Max 141
There should be 100 seeds and you have 99
Missing {77}

Checking ../resnet256_naive_checkpoints/ ...
Total 101 seeds
Min 42 | Max 143
There should be 102 seeds and you have 101
Missing {88}

Checking ../resnet256_naive_nonreg_checkpoints/ ...
Total 98 seeds
Min 43 | Max 141
There should be 99 seeds and you have 98
Missing {92}



## Run Grad-CAM on the non-registered (`nonreg`) checkpoints using the dataset

## Choose the best results by MAE

In [21]:
import yaml

# Checkpoint root directories, keyed by "<reg|nonreg>_<naive|aug>".
HEADER_DICT = {
    "reg_naive": "../resnet256_naive_checkpoints/",
    "reg_aug": "../resnet256_augmentation_checkpoints/",
    "nonreg_naive": "../resnet256_naive_nonreg_checkpoints/",
    "nonreg_aug": "../resnet256_augmentation_nonreg_checkpoints/",
}

# Ground-truth ages for the test set (presumably one per test subject — confirm).
# NOTE: yaml.Loader can construct arbitrary Python objects; only use it on
# trusted files. It is kept instead of yaml.safe_load in case the file holds
# Python-tagged values that safe_load would reject.
with open("data/test_gt_age.yml", "r", encoding="utf-8") as f:
    GT_TEST = yaml.load(f, Loader=yaml.Loader)

In [None]:
def get_test_result(path):
    """Load the raw test predictions of a single run (one seed).

    Reads ``<path>/test.yml``, which is expected to be of the form
    ``{epoch: [list of predictions], ...}``, and returns the parsed object
    unchanged. No metric (e.g. MAE) is computed here; that must be done by
    the caller.

    Args:
        path: Run directory (str or Path) that contains ``test.yml``.

    Returns:
        The object parsed from ``test.yml`` (per the format above, a mapping
        from epoch to a list of predictions).
    """
    test_yml_path = Path(path, "test.yml")
    # NOTE: yaml.Loader can construct arbitrary Python objects; only use it on
    # trusted files. It is kept in case the predictions were dumped with
    # Python-tagged values that yaml.safe_load would reject.
    with open(test_yml_path, "r", encoding="utf-8") as f:
        test_prediction = yaml.load(f, Loader=yaml.Loader)
    return test_prediction