## Introduction

This notebook demonstrates the steps to sample models from hyper-representations, finetune the sampled models, and compare them to baselines.  
Before running it, install the `ghrp` package by running `pip3 install .` in the repository root, and download the data by running `bash download_data.sh` inside the repository's `data/` directory.

In [None]:
# import modules
from pathlib import Path
from ghrp.sampling_auxiliaries.sample_finetune_auxiliaries import *
import ray

In [None]:
# set experiment / data root path
# pick the hyper-representation to use: "mnist", "svhn", "cifar10" or "stl10"
HYPER_REPRESENTATION_DATASET = "svhn"
PATH_ROOT = Path("./../data/hyper_representations") / HYPER_REPRESENTATION_DATASET

In [None]:
# Parameterization: where sampled populations are stored, how many models to
# sample / finetune, and which zoos act as source and OOD target.

# experiment paths, all derived from PATH_ROOT
experiment_path = PATH_ROOT
population_root_path = experiment_path
path_to_samples = experiment_path / "samples"

# number of models to generate by sampling, and how many of them to finetune
no_samples_generation = 25
no_samples = 5

# offset passed to plot_populations for the in-distribution comparison
id_offset = 25

# finetuning epochs for in-distribution (ID) and out-of-distribution (OOD) runs
training_epochs_id = 15
training_epochs_ood = 15

# computation resources
cpus = 4
cpu_per_trial = 1
# forwarded to finetune as `model_config`
take_config_from = "target"

# source zoo, and the zoos used as OOD finetuning targets
source_zoo_path = Path("./../data/zoos/svhn/")
target_zoo_paths = [Path("./../data/zoos/mnist/")]



In [None]:
############ sample models
# Generate `no_samples_generation` new weight sets from the hyper-representation.
# (Previously this passed a hard-coded `no_samples=0`, which silently ignored the
# value configured in the parameterization cell.)
import json

print("#### Sample new weights")

ae_config_path, model_config_path = sample(
    experiment_path=experiment_path,
    path_to_samples=path_to_samples,
    no_samples=no_samples_generation,
)
# Read the autoencoder config. Its "trainset::layer_lst" entry is used by the
# plot cells. The `with` block closes the file handle deterministically.
with ae_config_path.open("r") as ae_config_file:
    ae_config = json.load(ae_config_file)

In [None]:
## find strs of source / target domain for naming directories
# in-distribution naming: both domain strings are read from the source zoo
source, target = (
    find_domain_from_path(population_path=zoo_path)
    for zoo_path in (source_zoo_path, source_zoo_path)
)
population_path = population_root_path / f"{source}_to_{target}"
population_path

In [None]:
## finetune sampled populations on original dataset (in distribution)
print("#### Finetune ID")
# in-distribution: the finetuning target zoo is the source zoo itself
source = find_domain_from_path(population_path=source_zoo_path)
target = find_domain_from_path(population_path=source_zoo_path)
population_path = population_root_path.joinpath(f"{source}_to_{target}")

# Reset any ray runtime left from an earlier run of this or another cell
# (presumably finetune initializes ray itself; verify in ghrp).
if ray.is_initialized():
    ray.shutdown()

finetune(
    project=f"{source}_to_{target}",
    population_path=population_path,
    path_to_samples=path_to_samples,
    path_target_zoo=source_zoo_path,
    model_config_path=model_config_path,
    model_config=take_config_from,
    no_samples=no_samples,
    training_epochs=training_epochs_id,
    cpus=cpus,
    # pass the configured per-trial CPUs, consistent with the OOD finetune call
    cpu_per_trial=cpu_per_trial,
    # sampling methods excluded from finetuning in this run
    skip=["uniform", "train", "kde_z_train"],
)


In [None]:
### call plot function to visualize results
# Sampling methods ("domains") to include in the figures.
# Also available but disabled here: "train", "kde_z_train", "gan".
# NOTE(review): "uniform" is plotted although the ID finetune cell skips it,
# so confirm that plot_populations tolerates missing results.
plot_domains = [
    "baseline",
    "direct",
    "uniform",
    "best",
    "kde_z_best",
    "gan_best",
]

## plot figures
print("#### Plot Finetune Figures")
id_plot_kwargs = dict(
    source=source,
    target=target,
    path_target_zoo=source_zoo_path,
    population_path=population_path,
    layer_lst=ae_config["trainset::layer_lst"],
    id_offset=id_offset,
    plot_domains=plot_domains,
)
plot_populations(**id_plot_kwargs)


In [None]:
# point the reader to the in-distribution population directory (typo "satistics" fixed)
print(f'Figures and population statistics can be found in {population_path}')

In [None]:
### finetune sampled populations on different image datasets (out of distribution)
print("#### Finetune OOD")
# only the first configured OOD target zoo is used in this cell
target_zoo_path = target_zoo_paths[0]
source = find_domain_from_path(population_path=source_zoo_path)
target = find_domain_from_path(population_path=target_zoo_path)
print(f"#### Finetune OOD on domain: {target}")
population_path = population_root_path.joinpath(f"{source}_to_{target}")

# Reset any ray runtime left over from an earlier finetune run (e.g. the ID
# cell), so this cell also works when executed right after it.
if ray.is_initialized():
    ray.shutdown()

finetune(
    project=f"{source}_to_{target}",
    population_path=population_path,
    path_to_samples=path_to_samples,
    path_target_zoo=target_zoo_path,
    model_config_path=model_config_path,
    model_config=take_config_from,
    no_samples=no_samples,
    training_epochs=training_epochs_ood,
    cpus=cpus,
    cpu_per_trial=cpu_per_trial,
    # sampling methods excluded from finetuning in this run
    skip=["uniform", "train", "kde_z_train"],
)


In [None]:
## plot figures
# OOD figures are plotted against the target zoo with id_offset=0
# (the ID plot uses the configured `id_offset` instead).
print(f"#### Plot OOD Figures {target}")
ood_plot_kwargs = dict(
    source=source,
    target=target,
    path_target_zoo=target_zoo_path,
    population_path=population_path,
    layer_lst=ae_config["trainset::layer_lst"],
    id_offset=0,
    plot_domains=plot_domains,
)
plot_populations(**ood_plot_kwargs)

In [None]:
# point the reader to the out-of-distribution population directory (typo "satistics" fixed)
print(f'Figures and population statistics can be found in {population_path}')