# In[1]:
import jax.numpy as jnp
import gp
import jax

# In[None]:
def train_gp_classifier(X, y, params=None):
    """Build a beta-GP model over hidden states and evaluate its training NLL.

    Despite the name, no hyperparameter optimisation happens here: the model is
    defined by fixed hyperparameters, and ``gp.beta_gp_nll`` is evaluated once
    on the training data. The result is returned alongside the model pieces.

    Args:
        X: training inputs, presumably (n_samples, hidden_dim) hidden states —
            TODO confirm against ``gp.beta_gp_nll``.
        y: training targets, one per row of ``X``. The original notes describe
            them as binarized semantic-entropy labels — confirm what
            ``gp.beta_gp_nll`` expects.
        params: optional hyperparameter overrides. Any key missing from
            ``params`` falls back to the default below, so a partial dict such
            as ``{'strength': 2.0}`` is accepted. Extra keys are passed through
            unchanged.

    Returns:
        dict with keys:
            'mean_func': ``gp.constant_mean``
            'cov_func': ``gp.cosine_kernel``
            'params': the merged hyperparameters (a fresh dict, not the
                caller's object)
            'nll': the value returned by ``gp.beta_gp_nll``
    """
    # TODO: tune these hyperparameters (e.g. via a MAP estimate) instead of
    # using fixed values.
    default_params = {
        'constant': 0.0,
        'signal_variance': 6.0,
        'alpha_eps': 0.1,
        'strength': 5.0,
        'intercept_scaling': 1.0,
    }

    # Layer caller overrides on top of the defaults. Building a new dict also
    # keeps the returned model from aliasing (and being mutated through) the
    # caller's dict.
    merged_params = dict(default_params)
    if params is not None:
        merged_params.update(params)

    mean_func = gp.constant_mean
    cov_func = gp.cosine_kernel

    nll = gp.beta_gp_nll(
        mean_func=mean_func,
        cov_func=cov_func,
        params=merged_params,
        x_train=X,
        y_train=y,
    )

    return {
        'mean_func': mean_func,
        'cov_func': cov_func,
        'params': merged_params,
        'nll': nll,
    }

def predict_entropy(model, X_test, q=0.025):
    """Run beta-GP prediction for query inputs using a model from ``train_gp_classifier``.

    Args:
        model: dict as returned by ``train_gp_classifier``; its 'mean_func',
            'cov_func' and 'params' entries are read.
        X_test: query inputs, presumably (n_samples, hidden_dim) like the
            training inputs — TODO confirm.
        q: quantile level forwarded to ``gp.get_beta_quantiles``. Defaults to
            0.025, the value previously hard-coded.
            NOTE(review): if that helper returns the q and 1-q quantiles, the
            default corresponds to a 95% two-sided interval — confirm.

    Returns:
        dict with keys:
            'probabilities', 'confidence_intervals': the two values returned by
                ``gp.get_beta_quantiles``
            'latent_mean', 'latent_var': the two values returned by
                ``gp.get_latent_gp``
    """
    predictions = gp.beta_gp_predict(
        mean_func=model['mean_func'],
        cov_func=model['cov_func'],
        params=model['params'],
        x_query=X_test,
        var_only=True,
    )

    # Latent GP mean/variance alongside the beta outputs.
    # NOTE(review): the original comment said this is "since out of
    # distribution"; the intended use of the latent values is not shown here.
    mu, var = gp.get_latent_gp(predictions)

    probs, quantiles = gp.get_beta_quantiles(predictions, q=q)

    return {
        'probabilities': probs,
        'confidence_intervals': quantiles,
        'latent_mean': mu,
        'latent_var': var,
    }

def train_all_action_dimensions(hidden_states_dict, entropies_dict):
    """Fit one GP classifier per action dimension.

    Iterates over the keys of ``hidden_states_dict`` in insertion order. For
    each key, it pairs the hidden states with ``entropies_dict[key]``, which
    raises KeyError if that key is absent. Keys present only in
    ``entropies_dict`` are ignored.

    Returns:
        dict mapping each action dimension to the model dict produced by
        ``train_gp_classifier``.
    """
    return {
        dim: train_gp_classifier(states, entropies_dict[dim])
        for dim, states in hidden_states_dict.items()
    }

"""
# Example usage. Both inputs are assumed to be dictionaries keyed by action dimension, e.g.:
hidden_states_dict = {
    0: X_0,  # hidden states for action dimension 0
    1: X_1,  # hidden states for action dimension 1
    ...
    6: X_6   # hidden states for action dimension 6
}

entropies_dict = {
    0: y_0,  # semantic entropy or probability values for action dimension 0
    1: y_1,  # semantic entropy or probability values for action dimension 1
    ...
    6: y_6   # semantic entropy or probability values for action dimension 6
}

# train models for all action dimensions
results = train_all_action_dimensions(hidden_states_dict, entropies_dict)

# TODO: add evaluation metrics for the trained per-dimension models
"""