In [None]:
import glob
import os

import jax
import jax.numpy as jnp
import numpy as np
from IPython.display import display
from PIL import Image

import gp
import probabilistic_probe


In [None]:
def train_gp_classifier(X, y, params=None):
    """Bundle a beta-GP model over hidden states and score it on its training data.

    No hyperparameters are fitted here. The GP uses either the caller-supplied
    ``params`` or the fixed defaults below, and the only computation is the
    negative log likelihood of ``(X, y)`` under them (see ``gp.beta_gp_nll``).

    Args:
        X: training inputs, expected shape (n_samples, hidden_dim).
        y: training targets, expected shape (n_samples,). Per the original
            notes these are binarized semantic-entropy labels.
        params: optional dict of GP hyperparameters forwarded unchanged to
            ``gp``. The hard-coded defaults are used when ``None``.

    Returns:
        dict with ``'mean_func'``, ``'cov_func'``, ``'params'``, ``'nll'`` and
        the training data under ``'x_train'`` / ``'y_train'``, so prediction
        code has access to the data the model was built from.

    Raises:
        ValueError: if ``X`` and ``y`` disagree on the number of samples.
    """
    # TODO: can probably just tune these hyperparams using MAP estimate or something like that
    if params is None:
        params = {
            'constant': 0.0,
            'signal_variance': 6.0,
            'alpha_eps': 0.1,
            'strength': 5.0,
            'intercept_scaling': 1.0
        }

    # Fail fast on misaligned data rather than inside the GP likelihood.
    n_inputs, n_targets = len(X), len(y)
    if n_inputs != n_targets:
        raise ValueError(
            f"X has {n_inputs} samples but y has {n_targets}; they must match"
        )

    # GP prior: constant mean function with a cosine-similarity kernel.
    mean_func = gp.constant_mean
    cov_func = gp.cosine_kernel

    # Negative log likelihood of the training data under these hyperparameters.
    nll = gp.beta_gp_nll(
        mean_func=mean_func,
        cov_func=cov_func,
        params=params,
        x_train=X,
        y_train=y
    )

    return {
        'mean_func': mean_func,
        'cov_func': cov_func,
        'params': params,
        'nll': nll,
        # Keep the observations: a GP posterior is defined by the data it conditions on.
        'x_train': X,
        'y_train': y,
    }

def predict_entropy(model, X_test):
    """Query a beta-GP model at new inputs.

    Args:
        model: dict as built by ``train_gp_classifier``. Only its
            ``'mean_func'``, ``'cov_func'`` and ``'params'`` entries are read.
        X_test: query inputs, presumably (n_samples, hidden_dim) — TODO confirm.

    Returns:
        dict with ``'probabilities'`` and ``'confidence_intervals'`` (from
        ``gp.get_beta_quantiles`` at ``q=0.025``), plus ``'latent_mean'`` and
        ``'latent_var'`` (from ``gp.get_latent_gp``).
    """
    gp_kwargs = dict(
        mean_func=model['mean_func'],
        cov_func=model['cov_func'],
        params=model['params'],
    )
    # NOTE(review): no training inputs/labels are passed to the predictor, so
    # unless ``params`` carries them the prediction is not conditioned on any
    # observed data. Confirm against gp.beta_gp_predict's signature.
    gp_output = gp.beta_gp_predict(**gp_kwargs, x_query=X_test, var_only=True)

    # Latent GP moments (original note: used because queries may be out of distribution).
    latent_mean, latent_var = gp.get_latent_gp(gp_output)

    # q=0.025 presumably yields the bounds of a central 95% interval — confirm in gp.
    beta_probs, beta_bounds = gp.get_beta_quantiles(gp_output, q=0.025)

    return {
        'probabilities': beta_probs,
        'confidence_intervals': beta_bounds,
        'latent_mean': latent_mean,
        'latent_var': latent_var
    }

def train_all_action_dimensions(hidden_states_dict, entropies_dict, params=None):
    """Build one GP model per action dimension via ``train_gp_classifier``.

    Args:
        hidden_states_dict: maps each action dimension to its training inputs.
        entropies_dict: maps each action dimension to its training targets. It
            must contain every key of ``hidden_states_dict``; extra keys are ignored.
        params: optional GP hyperparameters shared by every dimension and
            forwarded to ``train_gp_classifier``. ``None`` keeps its defaults.

    Returns:
        dict mapping each action dimension to its model, in
        ``hidden_states_dict`` key order.

    Raises:
        KeyError: if any action dimension has hidden states but no targets.
            This is checked before any model is built.
    """
    missing = [dim for dim in hidden_states_dict if dim not in entropies_dict]
    if missing:
        raise KeyError(f"no entropy targets for action dimension(s) {missing}")

    return {
        action_dim: train_gp_classifier(X, entropies_dict[action_dim], params)
        for action_dim, X in hidden_states_dict.items()
    }


# Example usage. Inputs are dictionaries keyed by action dimension, e.g.:
#
# hidden_states_dict = {
#     0: X_0,  # hidden states for action dimension 0
#     1: X_1,  # hidden states for action dimension 1
#     ...
#     6: X_6   # hidden states for action dimension 6
# }
#
# entropies_dict = {
#     0: y_0,  # semantic entropy or probability values for action dimension 0
#     1: y_1,  # semantic entropy or probability values for action dimension 1
#     ...
#     6: y_6   # semantic entropy or probability values for action dimension 6
# }
#
# # train models for all action dimensions
# results = train_all_action_dimensions(hidden_states_dict, entropies_dict)
#
# TODO: metrics?

In [None]:
# Directory holding one .npy file per trial. Set the VLA_RUN_DATA environment
# variable to point elsewhere instead of editing this absolute path.
vla_run_data = os.environ.get(
    'VLA_RUN_DATA',
    '/Users/jeremiahetiosaomeike/Desktop/Research_Projs/vla/LIBERO/probes/vla_run_data',
)
# glob returns files in arbitrary filesystem order. Sort so trial indices, and the
# hand-assigned labels downstream that depend on them, are reproducible.
# NOTE(review): this may reorder trials relative to the run the labels were
# written against; re-check them.
npys = sorted(glob.glob(os.path.join(vla_run_data, '*.npy')))

In [None]:
# Group trial file paths by the language instruction stored in each trial.
# Only the first record (dat[0]) is inspected; presumably every record in a
# trial shares one instruction — TODO confirm.
# NOTE(review): allow_pickle=True can execute arbitrary code from the file.
# Only load trusted, locally generated trial dumps.
grouped_trial_paths = {}
for data_path in npys:
    dat = np.load(data_path, allow_pickle=True)
    prompt = dat[0]["language_instruction"]
    grouped_trial_paths.setdefault(prompt, []).append(data_path)

In [None]:
# Inspect the mapping from language instruction to its list of trial .npy paths.
grouped_trial_paths

In [None]:
# Preview the first frame (record 0's "image") of every trial, in `npys` order.
for trial_path in npys:
    first_record = np.load(trial_path, allow_pickle=True)[0]
    display(Image.fromarray(first_record["image"]))

In [None]:
# Focus on the "lift the white can" trials (originally noted as trials 0, 1, 2 and 3).
paths_white_can = grouped_trial_paths['lift the white can']
images = [
    Image.fromarray(np.load(f_path, allow_pickle=True)[0]['image'])
    for f_path in paths_white_can
]

for img in images:
    display(img)

In [None]:
# Preview the first frame of every "lift the red can" trial.
paths = grouped_trial_paths['lift the red can']
images = [
    Image.fromarray(np.load(f_path, allow_pickle=True)[0]['image'])
    for f_path in paths
]

for img in images:
    display(img)

In [None]:
# Preview the first frame of every "lift the dr pepper can" trial; these same
# trials feed the probe in the next cell.
lift_dr_pep_can_paths = grouped_trial_paths['lift the dr pepper can']
images = []
for f_path in lift_dr_pep_can_paths:
    dat = np.load(f_path, allow_pickle=True)
    images.append(Image.fromarray(dat[0]['image']))

print(lift_dr_pep_can_paths)
for img in images:
    display(img)

# NOTE(review): the observation below was written while this cell looped over
# `paths` (the red-can trials) by mistake; re-check it against the images above.
# first path in this has no coca cola can, second path does

In [None]:
# Run probabilistic_probe.gpp on the pre-generation hidden states of the
# "lift the dr pepper can" trials: trial 0 is the query and the remaining
# trials are the labelled observations.
datas = [np.load(path, allow_pickle=True) for path in lift_dr_pep_can_paths]
hs_bg = [data[0]['hidden_state_before_gen'] for data in datas]
hs_pg = [data[0]['hidden_state_post_gen'] for data in datas]
# NOTE(review): `params` is not passed to probabilistic_probe.gpp below, so it
# currently has no effect on `measures`. Wire it in or drop it.
params = gp.set_default_params({'alpha_eps': .1, 'strength': 5}, warp_func=None)

query_bg = hs_bg[0]
# squeeze(1) assumes each stored hidden state is (1, hidden_dim) — TODO confirm.
x_obs = np.asarray(hs_bg[1:]).squeeze(1)
# Hand-assigned labels, one per observed trial, in file order. The original note
# here was truncated ("the second obs"); confirm which trial each label describes.
y_obs = np.array([0, 1, 1, 1])
assert len(y_obs) == len(x_obs), f"{len(y_obs)} labels for {len(x_obs)} observed trials"

measures = probabilistic_probe.gpp(query_bg, x_obs, y_obs)
measures

In [None]:
# Rebuild the probe's observation matrix so this cell is self-contained, then
# check its shape. Presumably (n_observed_trials, hidden_dim), given that each
# hidden state is stored as (1, hidden_dim) — TODO confirm.
x_obs = np.asarray(hs_bg[1:]).squeeze(1)
x_obs.shape