In [1]:
import jax.numpy as jnp
from jax import lax
from pygmo import hypervolume # for computing hypervolume indicator
import math

In [None]:
class GammaFunction:
    """Callable that converts an observation count into a split size.

    Holds a fixed fraction ``gamma``. ``TPESampler.sample`` calls the instance
    with the number of observed values and forwards the result to
    ``split_observation`` as the requested size of the ``D_l`` group.
    """

    def __init__(self, gamma=0.10):
        # Fraction multiplied with the count handed to __call__.
        self.gamma = gamma

    def __call__(self, x):
        """Return ``floor(gamma * x)`` as a plain Python ``int``."""
        scaled_count = self.gamma * x
        floored = lax.floor(scaled_count)
        return int(floored)

## TPESampler
An implementation of the Tree-structured Parzen Estimator (TPE) sampler.

### NDSort
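The goal of NDSort is to assign a nondomination rank to each objective function result (each $y$ value).
Ranking uses strong domination over objective vectors $y \in \mathbb{R}^m$: a point strongly dominates another when it is strictly smaller in every objective.
Here is an example of how NDSort works:
- Assume the objective space has $m = 2$ dimensions. A sample $y_{\text{sample}}$ receives rank 0 if no other sample is strictly smaller in both objectives. The rank-0 samples are then set aside and the same test is repeated on the remaining samples to assign rank 1, and so on until every sample has a rank.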

In [None]:
class TPESampler:
    def __init__(self,
                 hyper_param,  # Hyperparameters
                 observations,  # Observations of the objective function
                 random_state,  # Random seed
                 n_EI_candidates=24,  # Number of candidates for EI maximization
                 rule='james',  # Rule to use for the next best hyperparameter to sample
                 gamma_func=GammaFunction(),  # A function to compute the gamma parameter
                 weights_func=default_weights,  # A function to compute the weights
                 split_cache=None):  # A cache of the splits
        """Store the sampler configuration on the instance.

        Every argument is kept under the same name with a trailing
        underscore. ``split_cache`` is shallow-copied when given, so entries
        added later by ``split_observation`` do not appear in the caller's
        dict; when omitted, an empty cache is created.
        """
        # Cache used by split_observation; never alias the caller's dict.
        self.split_cache_ = {} if split_cache is None else split_cache.copy()

        # Callables: gamma_func_ sizes the split, weights_func_ computes weights.
        self.gamma_func_ = gamma_func
        self.weights_func_ = weights_func

        # Plain configuration values.
        self.rule_ = rule
        self.n_EI_candidates_ = n_EI_candidates
        self.random_state_ = random_state

        # The hyperparameter being sampled and the observation history.
        self.observations_ = observations
        self.hyper_param_ = hyper_param

    def NDSort(self, y_val):
        '''
        Assign a nondomination rank to every objective vector.

        A point ``a`` strongly dominates ``b`` when ``a`` is strictly smaller
        than ``b`` in every objective. Points dominated by no remaining point
        form the current front and receive the current rank; they are then
        removed and the procedure repeats with the next rank.

        Args:
            y_val: array-like of shape ``(n_points, n_objectives)``. A 1-D
                input is treated as ``n_points`` single-objective values.

        Returns:
            ``int32`` array of shape ``(n_points,)`` holding the rank of each
            point, where 0 marks the nondominated front.
        '''
        points = jnp.asarray(y_val)
        if points.ndim == 1:
            points = points[:, None]
        if not jnp.issubdtype(points.dtype, jnp.floating):
            # +inf is used as a sentinel below, which needs a float dtype.
            points = points.astype(jnp.float32)

        n_points = points.shape[0]
        # JAX arrays are immutable: every update must rebind the name.
        nd_ranks = jnp.zeros(n_points, dtype=jnp.int32)
        remaining = jnp.ones(n_points, dtype=bool)

        current_rank = 0
        while bool(jnp.any(remaining)):
            # Already ranked rows become +inf so they can never dominate.
            active = jnp.where(remaining[:, None], points, jnp.inf)
            # dominated_by[i, j] is True when point j strongly dominates i.
            dominated_by = jnp.all(
                active[None, :, :] < active[:, None, :], axis=2)
            n_dominators = jnp.sum(dominated_by, axis=1)

            front = remaining & (n_dominators == 0)
            if not bool(jnp.any(front)):
                # Degenerate input (e.g. zero objectives): guarantee progress.
                front = remaining

            nd_ranks = jnp.where(front, current_rank, nd_ranks)
            remaining = remaining & ~front
            current_rank += 1

        return nd_ranks

    def split_observation(self, y_collection, D_l_sample_cnt):
        '''
        Algorithm 2 in Journal
        Split the observation indices into the groups D_l and D_g.

        D_l is filled front by front in increasing nondomination rank while a
        whole front still fits. If slots remain, they are filled from the
        next front by a greedy hypervolume subset selection (HSSP,
        Algorithm 3): the point with the largest remaining hypervolume
        contribution is picked repeatedly. D_g holds every other index.

        Args:
            y_collection: array-like of shape ``(n_points, n_objectives)``.
            D_l_sample_cnt: requested number of indices in D_l.

        Returns:
            Tuple ``(D_l_sample_indices, D_g_sample_indices)`` of integer
            index arrays into ``y_collection``.
        '''
        y_collection = jnp.asarray(y_collection)
        n_lower = int(D_l_sample_cnt)

        # The split depends on the requested size as well as on the values,
        # and raw bytes alone do not encode the array shape.
        cache_key = (y_collection.shape, n_lower, y_collection.tobytes())
        cached = self.split_cache_.get(cache_key)
        if cached is not None:
            return cached["l_x_i"], cached["g_x_i"]

        n_total = len(y_collection)
        y_ranks = self.NDSort(y_collection)
        indices = jnp.arange(n_total, dtype=jnp.int32)
        max_rank = int(jnp.max(y_ranks)) if n_total else -1

        # step 1: take complete fronts while they fit. Bounding by max_rank
        # keeps the loop finite when n_lower >= n_total.
        lower_indices = []
        rank = 0
        while rank <= max_rank:
            front_indices = indices[y_ranks == rank].tolist()
            if len(lower_indices) + len(front_indices) > n_lower:
                break
            lower_indices.extend(front_indices)
            rank += 1

        # step 2 (HSSP): fill the remaining slots from the first front that
        # did not fit, if there is one.
        n_missing = n_lower - len(lower_indices)
        if n_missing > 0 and rank <= max_rank:
            front_mask = y_ranks == rank
            y_refer_pt = y_collection[front_mask].tolist()
            y_refer_pt_indices = indices[front_mask].tolist()

            # Reference point: slightly worse than the worst observed value
            # in each objective. max(1.1 * w, 0.9 * w) handles negative
            # values, and eps keeps it strictly positive when w == 0.
            eps = float(jnp.finfo(jnp.float32).eps)
            y_worst = jnp.max(y_collection, axis=0).tolist()
            reference_pt = [max(1.1 * w, 0.9 * w, eps) for w in y_worst]

            # Algorithm 3, line 2-3: individual hypervolume of every point.
            hv_contribution = [
                hypervolume([pt]).compute(reference_pt) for pt in y_refer_pt
            ]
            already_selected = float("-inf")
            y_subset = []

            while n_missing > 0 and len(y_subset) < len(y_refer_pt):
                hv_subset = 0
                if y_subset:
                    hv_subset = hypervolume(y_subset).compute(reference_pt)

                best = max(range(len(hv_contribution)),
                           key=hv_contribution.__getitem__)
                best_pt = y_refer_pt[best]
                # Mark as used so it is never selected again.
                hv_contribution[best] = already_selected

                for j, candidate in enumerate(y_refer_pt):
                    if hv_contribution[j] == already_selected:
                        continue
                    # Element-wise worst of the two points: the region both
                    # of them dominate.
                    shared_pt = [max(a, b) for a, b in zip(best_pt, candidate)]
                    overlap = hypervolume(
                        y_subset + [shared_pt]).compute(reference_pt) - hv_subset
                    hv_contribution[j] -= overlap

                y_subset.append(best_pt)
                lower_indices.append(y_refer_pt_indices[best])
                n_missing -= 1

        D_l_sample_indices = jnp.asarray(lower_indices, dtype=jnp.int32)

        # step 3: gather the rest of the points as D_g_sample_indices
        D_g_sample_indices = jnp.setdiff1d(indices, D_l_sample_indices)

        # cache the results
        self.split_cache_[cache_key] = {
            "l_x_i": D_l_sample_indices,
            "g_x_i": D_g_sample_indices
        }

        return D_l_sample_indices, D_g_sample_indices
    

    def get_type(self):
        '''
        Return the Python type of the values of this sampler's hyperparameter.

        The decision is made from the class name of ``self.hyper_param_``.
        This is only compatible with configspace.
        TODO need to change

        Returns:
            ``int``, ``float``, ``str`` or ``bool``.

        Raises:
            ValueError: the first choice of a categorical hyperparameter is
                neither ``str`` nor ``bool``.
            NotImplementedError: the class name matches no known kind.
        '''
        # __init__ stores the hyperparameter as ``hyper_param_``; no ``hp``
        # attribute is ever set on the instance.
        hyper_param = self.hyper_param_
        cs_dist = str(type(hyper_param))

        if 'Integer' in cs_dist:
            return int
        if 'Float' in cs_dist:
            return float
        if 'Categorical' in cs_dist:
            var_type = type(hyper_param.choices[0])
            if var_type in (str, bool):
                return var_type
            raise ValueError('The type of categorical parameters must be "bool" or "str".')
        raise NotImplementedError('The distribution is not implemented.')

    def sample(self):
        '''
        Propose a new value for this sampler's hyperparameter.

        Loads the observed values, splits them into the groups D_l and D_g
        using the size given by ``gamma_func_``, dispatches to the numerical
        or categorical sampler according to ``get_type()``, and maps the
        result back through ``revert_h``.

        Returns:
            Whatever ``revert_h`` returns for the sampled value.

        Raises:
            NotImplementedError: ``get_type()`` returned an unsupported type.
        '''
        # Load the observed hyperparameter values and corresponding function values
        hyper_param, y_collection = self.load_hyper_param_and_y_collection()

        # use gamma to determine the split ratio
        D_l_cnt = self.gamma_func_(len(hyper_param))
        D_l_sample_indices, D_g_sample_indices = self.split_observation(
            y_collection, D_l_cnt)

        # Indexing with index arrays needs an array; a plain list would fail.
        hyper_param = jnp.asarray(hyper_param)
        D_l_sample = hyper_param[D_l_sample_indices]
        D_g_sample = hyper_param[D_g_sample_indices]

        # Determine the type of the hyperparameter
        var_type = self.get_type()

        # Sample a new hyperparameter. get_type() may return bool for
        # categorical parameters, so bool is handled with str.
        if var_type in (float, int):
            new_hyper_param = self.sample_continuous(D_l_sample, D_g_sample)
        elif var_type in (str, bool):
            new_hyper_param = self.sample_categorical(D_l_sample, D_g_sample)
        else:
            raise NotImplementedError(
                f'Sampling is not implemented for type {var_type!r}.')

        return self.revert_h(new_hyper_param)
    
    def sample_continuous(self, D_l_sample, D_g_sample):
        """Placeholder for the numerical sampler; not implemented, returns ``None``."""
        return None

    def sample_categorical(self, D_l_sample, D_g_sample):
        """Placeholder for the categorical sampler; not implemented, returns ``None``."""
        return None

    def revert_h(self, hyper_param_val):
        """Placeholder for mapping a sampled value back; not implemented, returns ``None``."""
        return None

    def load_hyper_param_and_y_collection(self):
        '''
        Collect the observed values of this hyperparameter and their objectives.

        Only observations whose ``'params'`` mapping contains
        ``self.hyper_param_.name`` are used. Each value is passed through
        ``convert_h``; the matching objective values are read from
        ``obs['f_value'].values()``.

        Returns:
            Tuple ``(hyper_param, y_collection)`` of JAX arrays with one
            entry or row per matching observation, in observation order.
        '''
        name = self.hyper_param_.name
        hyper_param = []
        y_collection = []

        # TODO: this data structure can be vectorized
        for obs in self.observations_:
            params = obs['params']
            if name in params:
                # TODO convert_h
                hyper_param.append(self.convert_h(params[name]))
                # Materialise the values view so it can become an array row.
                y_collection.append(list(obs['f_value'].values()))

        # jnp.asarray returns a new array and takes no ``axis`` argument, so
        # the converted arrays are what must be returned.
        return jnp.asarray(hyper_param), jnp.asarray(y_collection)

    # TODO: better name?
    def convert_h(self, hyper_param_val):
        '''
        Rescale a hyperparameter value relative to its bounds.

        Computes ``(value - lower) / (upper - lower)`` with the bounds from
        ``get_bound_and_q()``. For a ``'loguniform'`` hyperparameter the
        value is log-transformed first, matching the log-scaled bounds.

        Args:
            hyper_param_val: raw value of the hyperparameter.

        Returns:
            The rescaled value.

        Raises:
            NotImplementedError: re-raised, with the original error chained,
                when the bounds or type lookup raises it.
        '''
        try:
            lower_bound, upper_bound, _ = self.get_bound_and_q()
            if self.hyper_param_.type == 'loguniform':
                # jnp.log promotes integers to float; lax.log rejects them.
                hyper_param_val = jnp.log(hyper_param_val)
            return (hyper_param_val - lower_bound) / (upper_bound - lower_bound)
        except NotImplementedError as err:
            raise NotImplementedError(
                'Categorical parameters do not have lower and upper options.') from err

    def get_bound_and_q(self):
        '''
        Return ``(lower, upper, q)`` for this sampler's hyperparameter.

        For a ``'loguniform'`` hyperparameter both bounds are returned on
        the log scale; ``q`` is returned unchanged in either case.

        Raises:
            NotImplementedError: re-raised, with the original error chained,
                when reading the hyperparameter attributes raises it.
        '''
        try:
            hyper_param = self.hyper_param_
            if hyper_param.type == 'loguniform':
                # jnp.log promotes integer bounds to float; lax.log rejects
                # integer inputs.
                return jnp.log(hyper_param.lower), jnp.log(hyper_param.upper), hyper_param.q
            return hyper_param.lower, hyper_param.upper, hyper_param.q
        except NotImplementedError as err:
            raise NotImplementedError(
                'Categorical parameters do not have the log scale option.') from err
