In [None]:
from pprint import pprint
import hydra
import omegaconf
from experiment_util.experiment_data import ExperimentData
from experiment_util.plotting import plot_predictions, plot_metrics_for_one_run
from experiment_util.eval_single_run import eval_one_model
from metalearning_model_gmm_np.mm_gmm_np import MetaLearningModelGMMNP
%load_ext autoreload
%autoreload 2


In [None]:
## CHOOSE EXPERIMENT HERE ##
# Options used with this notebook: "Sinusoid1D", "LineSine1D".
# The name picks both the experiment config and the matching GMMNP model
# config (see the hydra overrides in the config cell).
experiment = "Sinusoid1D"

## SET TO FALSE FOR FULL RUN ##
# True requests a reduced smoke-test run; passed to hydra as `do_smoke_test`.
smoke_test = False

In [None]:
## load config
# Compose the hydra config for the chosen experiment. The experiment and model
# config groups are selected by name, e.g. "Sinusoid1D_64_16" and
# "GMMNP-Sinusoid1D_64_16".
show_cfg = False  # set True to print the fully resolved config
cfg_overrides = [
    f"+experiment={experiment}_64_16",
    f"+model=GMMNP-{experiment}_64_16",
    f"do_smoke_test={smoke_test}",
]
with hydra.initialize(version_base=None, config_path="../config"):
    cfg_composed = hydra.compose(config_name="config", overrides=cfg_overrides)
# Downstream cells index `cfg` as a plain container. Keep the composed OmegaConf
# object under its own name rather than rebinding `cfg` to a different type.
# resolve=True expands interpolations; throw_on_missing=True fails loudly on
# mandatory values that were left unset.
cfg = omegaconf.OmegaConf.to_container(
    cfg_composed, resolve=True, throw_on_missing=True
)
if show_cfg:
    pprint(cfg)

In [None]:
## generate model
# Build the GMM-NP meta-learning model. Model hyperparameters come from the
# model config; d_x / d_y (presumably input/output dimensions) come from the
# experiment config.
print("\n****** Generating model ******")
model_cfg = cfg["model"]
exp_cfg = cfg["experiment"]
model = MetaLearningModelGMMNP(
    cfg=model_cfg,
    d_x=exp_cfg["d_x"],
    d_y=exp_cfg["d_y"],
)


In [None]:
## generate data
# Build the experiment data from the experiment section of the config. It
# exposes the `benchmark_meta_train` / `benchmark_test` used in later cells.
print("\n****** Generating data ******")
data_cfg = cfg["experiment"]
exp_data = ExperimentData(data_cfg)

In [None]:
## meta train
# Meta-train on the meta-training benchmark. If the model config has no
# "n_epochs_meta_train" entry, n_epochs=None is passed through (presumably the
# model then falls back to its own default).
print("\n****** Meta-Training ******")
n_epochs_meta_train = cfg["model"].get("n_epochs_meta_train")
model.meta_train(
    benchmark=exp_data.benchmark_meta_train,
    n_epochs=n_epochs_meta_train,
)

In [None]:
## plot some predictions on test benchmark
# Plot predictions for selected test tasks at several context sizes. The
# plotting knobs come from the experiment config; the adaptation epochs come
# from the model config (None if absent).
print("\n****** Plotting on test set ******")
exp_cfg = cfg["experiment"]
figs = plot_predictions(
    model=model,
    benchmark=exp_data.benchmark_test,
    task_ids=exp_cfg["task_ids_plot"],
    context_sizes=exp_cfg["context_sizes_plot"],
    n_samples=exp_cfg["n_samples_plot"],
    n_epochs_adapt=cfg["model"].get("n_epochs_adapt"),
    plot_std_y=exp_cfg["plot_std_y"],
)


In [None]:
## evaluate model on test benchmark
# Evaluate on the test benchmark over the configured context sizes, then plot
# the resulting metrics, reduced across tasks with the configured op.
print("\n****** Evaluating on test set ******")
exp_cfg = cfg["experiment"]
test_metrics = eval_one_model(
    benchmark=exp_data.benchmark_test,
    model=model,
    context_sizes=exp_cfg["context_sizes_eval"],
    n_samples=exp_cfg["n_samples_eval"],
    context_size_proposal=exp_cfg["context_size_proposal_test"],
    n_epochs_adapt=cfg["model"].get("n_epochs_adapt"),
    batch_size_eval=exp_cfg["batch_size_eval"],
)
# Kept as the last expression so the cell displays its result as before.
plot_metrics_for_one_run(
    metrics=test_metrics,
    kind="line",
    task_aggregate_op=exp_cfg["metrics_reduce_mode"],
)
