## Introduction

This notebook walks through the steps to train the hyper-representations and replicate the experiments from the paper.  
Before running it, install the `ghrp` package by running `pip3 install .` in the repository's root directory, and download the data by running `bash download_data.sh` inside the `data/` directory.

In [None]:
import torch
import ray
from ray import tune
from ray.tune.logger import DEFAULT_LOGGERS
from ray.tune.integration.wandb import WandbLogger

import json
import sys
from pathlib import Path

from ghrp.model_definitions.def_simclr_ae_trainable import SimCLR_AE_tune_trainable
from ghrp.checkpoints_to_datasets.dataset_simclr import SimCLRDataset

In [None]:
# Choose which hyper-representation to train by setting DATASET_NAME to one of
# "mnist", "svhn", "cifar10" or "stl10". Each one has its own directory under
# ../data/hyper_representations, which the following cells read config_ae.json
# and dataset.pt from.
DATASET_NAME = "mnist"
PATH_ROOT = Path("..") / "data" / "hyper_representations" / DATASET_NAME

In [None]:
# Load the autoencoder training configuration of the selected hyper-representation.
config_path = PATH_ROOT.joinpath('config_ae.json')
# Use a context manager so the file handle is closed as soon as the JSON is parsed.
with config_path.open('r', encoding='utf-8') as config_file:
    config = json.load(config_file)

# Store the dataset location as an absolute path. Presumably this is because
# Ray Tune trials may not run in this notebook's working directory, so a
# relative path could fail to resolve inside the trainable.
config['dataset::dump'] = PATH_ROOT.joinpath('dataset.pt').absolute()

In [None]:
# Directory where Ray Tune writes the experiment's trial results and checkpoints.
# It is made absolute, like the dataset path in the config, so it does not
# depend on the working directory of whoever reads it.
output_dir = PATH_ROOT.joinpath("tune").absolute()

# exist_ok=True keeps this cell idempotent on re-runs. Unlike catching
# FileExistsError, it still raises if "tune" exists but is a regular file,
# instead of failing later inside tune.run.
output_dir.mkdir(parents=True, exist_ok=True)


In [None]:
# Resources per trial: one GPU when CUDA is available, otherwise CPU only.
gpus = 1 if torch.cuda.is_available() else 0
cpus = 4  # CPU cores given to Ray and requested by each trial; adjust to your machine
resources_per_trial = {"cpu": cpus, "gpu": gpus}

# Device passed to the trainable through the config (presumably used to place the model).
config['device'] = torch.device('cuda' if gpus > 0 else 'cpu')

# Ray is started with exactly the resources one trial requests, so trials run
# one at a time. ignore_reinit_error=True lets this cell be re-run in a live
# kernel without Ray raising because it is already initialized. On such a
# re-run, the resources of the existing Ray instance are kept.
ray.init(
    num_cpus=cpus,
    num_gpus=gpus,
    ignore_reinit_error=True,
)


In [None]:
# Ray must already be running (see the previous cell) before launching the experiment.
assert ray.is_initialized(), "Ray is not initialized; run the ray.init() cell first"

# Train the SimCLR autoencoder hyper-representation with Ray Tune.
analysis = tune.run(
    run_or_experiment=SimCLR_AE_tune_trainable,
    name='reproduce_experiments',
    # Stop each trial after the configured number of training iterations
    # (presumably one iteration per training epoch).
    stop={
        "training_iteration": config["training::epochs_train"],
    },
    checkpoint_at_end=True,
    # Tune ranks checkpoints in increasing order (higher is better) unless the
    # attribute is prefixed with "min-"; a validation loss should be minimized.
    checkpoint_score_attr="min-loss_val",
    checkpoint_freq=config["training::output_epoch"],
    config=config,
    # Passed as str for compatibility with Ray versions that expect a string path.
    local_dir=str(output_dir),
    # Optional Weights & Biases logging. To enable it, uncomment the block below
    # and import the callback first. The import cell only imports the older
    # `WandbLogger`; in Ray 1.x the callback is available as
    # `from ray.tune.integration.wandb import WandbLoggerCallback`.
    # callbacks=[
    #     WandbLoggerCallback(
    #         api_key_file="/path/to/your/wandb.key",
    #         project="your project name",
    #     )
    # ],
    resources_per_trial=resources_per_trial,
    reuse_actors=False,
    max_failures=1,
    fail_fast=False,
    verbose=3,
    # resume=True,  # uncomment to resume an interrupted run of this experiment
)


In [None]:
# Shut Ray down now that training has finished, and confirm it is no longer running.
ray.shutdown()
assert not ray.is_initialized()