In [1]:
import sys
sys.path.append('/content/drive/My Drive/Alice codes/Askotch related/fast_krr_alisver//plotting')
import os
os.chdir('/content/drive/My Drive/Alice codes/Askotch related/fast_krr_alisver/')
# Now you can import your module
# import your_module

In [2]:
from functools import partial

from tqdm import tqdm

from constants import (
    FONTSIZE,
    X_AXIS,

    BASE_SAVE_DIR,
    EXTENSION,
)
from constants import  PROJECT_FULL_KRR, PROJECT_INDUCING_KRR
from constants import PERFORMANCE_DATASETS_CFG
from base_utils import set_fontsize, render_in_latex
HPARAMS_TO_LABEL = {
    "askotchv2": ["precond", "r", "sampling_method"],
    "skotchv2": ["precond", "r", "sampling_method"],
    "askotchv3": ["precond", "r", "sampling_method"],
    "sap": ["b"],
    "nsap": ["b"],
    "eigenpro2": [],
    "eigenpro3": ["m"],
    "pcg": ["precond", "r"],
    "falkon": ["m"],
    "mimosa": ["precond", "r", "m"],
}
ENTITY_NAME = "fengx8086-stanford-university"
PROJECT_FULL_KRRv3 = "performance_full_krr_v3_"
USE_LATEX= False

In [3]:
#new cfg_utils.py file

from base_utils import (
    get_project_runs,
    filter_runs_union,
    plot_runs_grid,
    keep_largest_m,
)


def _get_grid_shape(datasets_cfg):
    n_rows = datasets_cfg["grid"]["n_rows"]
    n_cols = datasets_cfg["grid"]["n_cols"]
    return n_rows, n_cols


def _get_save_name(name_stem, datasets_cfg, extension):
    return name_stem + datasets_cfg["name_ext"] + "." + extension


def get_save_dir(base_save_dir, name):
    return os.path.join(base_save_dir, name)


def create_krr_config(proj_name, base_criteria):
    config = {
        "proj_name": proj_name,
        "criteria_list": base_criteria,
    }
    return config


def _get_filtered_runs(krr_cfg, ds, entity_name):
    if krr_cfg is None:
        return []

    project_name = krr_cfg["proj_name"] + ds
    runs = get_project_runs(entity_name, project_name)
    runs = filter_runs_union(runs, krr_cfg["criteria_list"])
    return runs


def plot_runs_dataset_grid(
    entity_name,
    full_krr_cfg,
    full_krr_cfgv3,
    datasets_cfg,
    hparams_to_label,
    x_axis,
    name_stem,
    save_dir,
    extension,
    keep_largest_m_runs=True,
):
    run_lists = []
    metrics = []
    plot_fns = []
    xlims = []
    ylims = []
    titles = []

    n_rows, n_cols = _get_grid_shape(datasets_cfg)
    save_name = _get_save_name(name_stem, datasets_cfg, extension)

    for ds, config in datasets_cfg["datasets"].items():
        runs_full_krr = _get_filtered_runs(full_krr_cfg, ds, entity_name)
        runs_full_krrv3 = _get_filtered_runs(full_krr_cfgv3, ds, entity_name)
        run_lists.append(runs_full_krr + runs_full_krrv3)
        metrics.append(config["metric"])
        plot_fns.append(config.get("plot_fn", None))
        xlims.append(config.get("xlim", None))
        ylims.append(config["ylim"])
        titles.append(ds)

    plot_runs_grid(
        run_lists,
        hparams_to_label,
        metrics,
        plot_fns,
        x_axis,
        xlims,
        ylims,
        titles,
        n_cols,
        n_rows,
        save_dir,
        save_name,
    )

In [4]:

os.environ['WANDB_API_KEY'] = 'e72f7d1831a6bf60e7648f04dbd6da67ce2bcc11'
# save directory
SAVE_DIR = "performance_comparison"
SKOTCH_FILTER = {
    "optimizer": lambda run: run.config["opt"] == "askotchv2",
    "accelerated": lambda run: run.config["accelerated"]==False,
    "preconditioned": lambda run: run.config["precond_params"] is not None,
    "rho_damped": lambda run: run.config.get("precond_params", {}).get("rho", None)
    == "damped",
    "sampling": lambda run: run.config["sampling_method"] == "uniform",
    "block_sz_frac": lambda run: run.config["block_sz_frac"] == 0.01,
   # "finished": lambda run: run.state == "finished",
}
# filters for runs
ASKOTCH_FILTER = {
    "optimizer": lambda run: run.config["opt"] == "askotchv2",
    "accelerated": lambda run: run.config["accelerated"],
    "preconditioned": lambda run: run.config["precond_params"] is not None,
    "rho_damped": lambda run: run.config.get("precond_params", {}).get("rho", None)
    == "damped",
    "sampling": lambda run: run.config["sampling_method"] == "uniform",
    "block_sz_frac": lambda run: run.config["block_sz_frac"] == 0.01,
   # "finished": lambda run: run.state == "finished",
}

# filters for runs
ASKOTCHv3_FILTER = {
    "optimizer": lambda run: run.config["opt"] == "askotchv3",
    "accelerated": lambda run: run.config["accelerated"],
    "preconditioned": lambda run: run.config["precond_params"] is not None,
    "rho_damped": lambda run: run.config.get("precond_params", {}).get("rho", None)
    == "damped",
    "sampling": lambda run: run.config["sampling_method"] == "uniform",
    "block_sz_frac": lambda run: run.config["block_sz_frac"] == 0.01,
    "name": lambda run: "dec16" in run.name,
  #  "finished": lambda run: run.state == "finished",
}

if __name__ == "__main__":
    set_fontsize(FONTSIZE)
    if USE_LATEX:
        render_in_latex()

    plot_fn = partial(
        plot_runs_dataset_grid,
        entity_name=ENTITY_NAME,
        hparams_to_label=HPARAMS_TO_LABEL,
        x_axis=X_AXIS,
        save_dir=get_save_dir(BASE_SAVE_DIR, SAVE_DIR),
        extension=EXTENSION,
    )

    full_krr_cfg_float32 = create_krr_config(
        PROJECT_FULL_KRR, [SKOTCH_FILTER,ASKOTCH_FILTER]
    )
    full_krr_cfg_float32v3 = create_krr_config(
        PROJECT_FULL_KRRv3, [ASKOTCHv3_FILTER]
    )

    with tqdm(
        total=2 * len(PERFORMANCE_DATASETS_CFG), desc="Performance comparison"
    ) as pbar:
        for datasets_cfg in PERFORMANCE_DATASETS_CFG:
            plot_fn(
                full_krr_cfg=full_krr_cfg_float32,
                full_krr_cfgv3=full_krr_cfg_float32v3,
                datasets_cfg=datasets_cfg,
                name_stem="float32_",
                keep_largest_m_runs=False,
            )
            pbar.update(1)


  | |_| | '_ \/ _` / _` |  _/ -_)
[34m[1mwandb[0m: Currently logged in as: [33mfengx8086[0m ([33mfengx8086-stanford-university[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
  ax.set_ylim(ylim)
Performance comparison:  50%|█████     | 5/10 [05:32<05:32, 66.55s/it]
