## Train your own models
To train you simply need to call `train.train`.
We give all necessary code. The most important bits are in the `priors` dir, e.g. `hebo_prior`, it stores the priors
with which we train our models.

### Training the HEBO+ model, `model_hebo_morebudget_9_unused_features_3.pt`
You can train this model on 8 GPUs using `torchrun` or `submitit`

In [6]:
import torch
from pfns import priors, encoders, utils, bar_distribution, train
from ConfigSpace import hyperparameters as CSH

config_heboplus = {
    "priordataloader_class_or_get_batch": priors.get_batch_to_dataloader(
        priors.get_batch_sequence(
            priors.hebo_prior.get_batch,
            priors.utils.sample_num_feaetures_get_batch,
        )
    ),
    "encoder_generator": encoders.get_normalized_uniform_encoder(
        encoders.get_variable_num_features_encoder(encoders.Linear)
    ),
    "emsize": 512,
    "nhead": 4,
    "warmup_epochs": 5,
    "y_encoder_generator": encoders.Linear,
    "batch_size": 128,
    "scheduler": utils.get_cosine_schedule_with_warmup,
    "extra_prior_kwargs_dict": {
        "num_features": 18,
        "hyperparameters": {
            "lengthscale_concentration": 1.2106559584074301,
            "lengthscale_rate": 1.5212245992840594,
            "outputscale_concentration": 0.8452312502679863,
            "outputscale_rate": 0.3993553245745406,
            "add_linear_kernel": False,
            "power_normalization": False,
            "hebo_warping": False,
            "unused_feature_likelihood": 0.3,
            "observation_noise": True,
        },
    },
    "epochs": 50,
    "lr": 0.0001,
    "seq_len": 60,
    "single_eval_pos_gen": utils.get_uniform_single_eval_pos_sampler(
        50, min_len=1
    ),  # <function utils.get_uniform_single_eval_pos_sampler.<locals>.<lambda>()>,
    "aggregate_k_gradients": 2,
    "nhid": 1024,
    "steps_per_epoch": 1024,
    "weight_decay": 0.0,
    "train_mixed_precision": False,
    "efficient_eval_masking": True,
    "nlayers": 12,
}


config_heboplus_userpriors = {
    **config_heboplus,
    "priordataloader_class_or_get_batch": priors.get_batch_to_dataloader(
        priors.get_batch_sequence(
            priors.hebo_prior.get_batch,
            priors.condition_on_area_of_opt.get_batch,
            priors.utils.sample_num_feaetures_get_batch,
        )
    ),
    "style_encoder_generator": encoders.get_normalized_uniform_encoder(
        encoders.get_variable_num_features_encoder(encoders.Linear)
    ),
}

config_bnn = {
    "priordataloader_class_or_get_batch": priors.get_batch_to_dataloader(
        priors.get_batch_sequence(
            priors.simple_mlp.get_batch,
            priors.input_warping.get_batch,
            priors.utils.sample_num_feaetures_get_batch,
        )
    ),
    "encoder_generator": encoders.get_normalized_uniform_encoder(
        encoders.get_variable_num_features_encoder(encoders.Linear)
    ),
    "emsize": 512,
    "nhead": 4,
    "warmup_epochs": 5,
    "y_encoder_generator": encoders.Linear,
    "batch_size": 128,
    "scheduler": utils.get_cosine_schedule_with_warmup,
    "extra_prior_kwargs_dict": {
        "num_features": 18,
        "hyperparameters": {
            "mlp_num_layers": CSH.UniformIntegerHyperparameter("mlp_num_layers", 8, 15),
            "mlp_num_hidden": CSH.UniformIntegerHyperparameter(
                "mlp_num_hidden", 36, 150
            ),
            "mlp_init_std": CSH.UniformFloatHyperparameter(
                "mlp_init_std", 0.08896049884896237, 0.1928554813280186
            ),
            "mlp_sparseness": 0.1449806273312999,
            "mlp_input_sampling": "uniform",
            "mlp_output_noise": CSH.UniformFloatHyperparameter(
                "mlp_output_noise", 0.00035983014290491186, 0.0013416342770574585
            ),
            "mlp_noisy_targets": True,
            "mlp_preactivation_noise_std": CSH.UniformFloatHyperparameter(
                "mlp_preactivation_noise_std",
                0.0003145707276259681,
                0.0013753183831259406,
            ),
            "input_warping_c1_std": 0.9759720822120248,
            "input_warping_c0_std": 0.8002534583197192,
            "num_hyperparameter_samples_per_batch": 16,
        },
    },
    "epochs": 50,
    "lr": 0.0001,
    "seq_len": 60,
    "single_eval_pos_gen": utils.get_uniform_single_eval_pos_sampler(50, min_len=1),
    "aggregate_k_gradients": 1,
    "nhid": 1024,
    "steps_per_epoch": 1024,
    "weight_decay": 0.0,
    "train_mixed_precision": True,
    "efficient_eval_masking": True,
}


# now let's add the criterions, where we decide the border positions based on the prior
def get_ys(config, device="cuda:0"):
    bs = 128
    all_targets = []
    for num_hps in [
        2,
        8,
        12,
    ]:  # a few different samples in case the number of features makes a difference in y dist
        b = config["priordataloader_class_or_get_batch"].get_batch_method(
            bs,
            1000,
            num_hps,
            epoch=0,
            device=device,
            hyperparameters={
                **config["extra_prior_kwargs_dict"]["hyperparameters"],
                "num_hyperparameter_samples_per_batch": -1,
            },
        )
        all_targets.append(b.target_y.flatten())
    return torch.cat(all_targets, 0)


def add_criterion(config, device="cuda:0"):
    return {
        **config,
        "criterion": bar_distribution.FullSupportBarDistribution(
            bar_distribution.get_bucket_limits(1000, ys=get_ys(config, device).cpu())
        ),
    }

In [9]:
import warnings

warnings.filterwarnings("ignore", category=UserWarning)

In [10]:
# Now let's train either with
train.train(**add_criterion(config_heboplus, device="cpu:0"))
# or
# train.train(**add_criterion(config_heboplus_userpriors))
# or
# train.train(**add_criterion(config_bnn))

Using 384000 y evals to estimate 1000 buckets. Cut off the last 0 ys.
Using cpu:0 device
init dist
Not using distributed
DataLoader.__dict__ {'num_steps': 1024, 'get_batch_kwargs': {'batch_size': 128, 'eval_pos_seq_len_sampler': <function train.<locals>.eval_pos_seq_len_sampler at 0x7f817c751ee0>, 'seq_len_maximum': 60, 'device': 'cpu:0', 'num_features': 18, 'hyperparameters': {'lengthscale_concentration': 1.2106559584074301, 'lengthscale_rate': 1.5212245992840594, 'outputscale_concentration': 0.8452312502679863, 'outputscale_rate': 0.3993553245745406, 'add_linear_kernel': False, 'power_normalization': False, 'hebo_warping': False, 'unused_feature_likelihood': 0.3, 'observation_noise': True}}, 'num_features': 18, 'epoch_count': 0}
Style definition of first 3 examples: None
Initialized decoder for standard with (None, 1000)  and nout 1000
Using a Transformer with 26.79 M parameters


(inf,
 inf,
 TransformerModel(
   (transformer_encoder): TransformerEncoderDiffInit(
     (layers): ModuleList(
       (0): TransformerEncoderLayer(
         (self_attn): MultiheadAttention(
           (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
         )
         (linear1): Linear(in_features=512, out_features=1024, bias=True)
         (dropout): Dropout(p=0.0, inplace=False)
         (linear2): Linear(in_features=1024, out_features=512, bias=True)
         (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
         (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
         (dropout1): Dropout(p=0.0, inplace=False)
         (dropout2): Dropout(p=0.0, inplace=False)
       )
       (1): TransformerEncoderLayer(
         (self_attn): MultiheadAttention(
           (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
         )
         (linear1): Linear(in_features=512, out_featu