In [37]:
import jax.numpy as jnp
import gp
import jax
import glob 
import numpy as np
from PIL import Image
from IPython.display import display
import matplotlib.pyplot as plt
import polars as pl

In [2]:
vla_run_data = '/Users/jeremiahetiosaomeike/Desktop/Research_Projs/vla/LIBERO/probes/vla_run_data'
# glob.glob returns files in arbitrary (filesystem-dependent) order; sort so the
# trial index assigned to each file downstream is deterministic across runs/machines.
npys = sorted(glob.glob(vla_run_data + '/*.npy'))

In [9]:
def get_semantic_entropy(prob_vector, k):
    """Action-level approximation of trajectory-level semantic entropy.

    The bins of ``prob_vector`` are grouped into a hierarchy of contiguous
    clusters: 2 clusters at level 1, 4 at level 2, ..., ``2**k`` at level k.
    At each level the probability mass of every cluster is summed, giving a
    distribution over clusters. The entropy of that distribution is then taken
    (in nats), and the per-level entropies are averaged.

    Args:
        prob_vector: 1-D array of non-negative bin probabilities/weights. It is
            renormalized to sum to 1 when its total mass is positive. Its length
            need not be divisible by ``2**k``; chunks are as equal as possible.
        k: number of hierarchy levels; must be >= 1.

    Returns:
        Mean of the per-level cluster entropies.

    Raises:
        ValueError: if ``k < 1``.
    """
    # TODO: principled way of choosing k? Probably something to do with a Dirichlet process...
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")

    def entropy(p):
        # 1e-10 guards log(0) for empty clusters.
        return np.sum(-p * np.log(p + 1e-10))

    probs = np.asarray(prob_vector, dtype=float)
    total = probs.sum()
    if total > 0:
        probs = probs / total

    ses = []
    for level in range(1, k + 1):
        # SUM (not mean) each chunk so the cluster probabilities form a proper
        # distribution. array_split also tolerates lengths not divisible by
        # 2**level; surplus clusters are empty and contribute 0 entropy.
        clustered_prob_vector = np.array(
            [chunk.sum() for chunk in np.array_split(probs, 2 ** level)]
        )
        ses.append(entropy(clustered_prob_vector))  # entropy of this hierarchy level

    return np.mean(np.array(ses))

In [13]:
import os

test_dir = '/Users/jeremiahetiosaomeike/Desktop/Research_Projs/vla/LIBERO/probes/test_data'
# exist_ok=True so re-running the cell (e.g. Restart & Run All) does not raise FileExistsError.
os.makedirs(test_dir, exist_ok=True)
for num_trial, data_path in enumerate(npys):
    # allow_pickle is needed because each file holds an array of dict-like records.
    # Unpickling can execute arbitrary code -- only load files you produced yourself.
    data = np.load(data_path, allow_pickle=True)
    for dct in data:
        # One semantic-entropy value per probability vector in dct['probs'].
        dct['se'] = [get_semantic_entropy(prob_vec, k=1) for prob_vec in dct['probs']]
    # The records were updated in place, so saving `data` persists the new 'se' field.
    np.save(os.path.join(test_dir, f'episode_test_{num_trial}.npy'), data)

In [34]:
def find_best_split(s_entropy,
                    plot=True,
                    label="",
                    n_splits=100):
    """Find the threshold that best separates entropy values into a low and a high group.

    Scans ``n_splits`` evenly spaced candidate thresholds in ``[1e-10, max(s_entropy)]``.
    For each candidate, values ``< threshold`` form the low group and values ``>= threshold``
    form the high group. The score is the within-group mean squared error: the squared
    deviations from each group's mean, summed over both groups and divided by the total
    count. The candidate with the lowest score wins.

    Args:
        s_entropy: array-like of entropy values (flattened).
        plot: if True, plot the score against the candidate thresholds on the current axes.
        label: legend label for the plotted curve (legend shown only if non-empty).
        n_splits: number of candidate thresholds to evaluate.

    Returns:
        The best candidate threshold (the first one on ties). If no candidate yields two
        non-empty groups (e.g. all values equal), the first candidate (1e-10) is returned.

    Raises:
        ValueError: if ``s_entropy`` is empty.
    """
    values = np.asarray(s_entropy, dtype=float).ravel()
    if values.size == 0:
        raise ValueError("s_entropy must contain at least one value")

    splits = np.linspace(1e-10, values.max(), n_splits)
    # inf marks candidates that leave one group empty, so argmin never selects them
    # unless no valid candidate exists.
    split_mses = np.full(len(splits), np.inf)

    for idx, split in enumerate(splits):
        low_mask = values < split
        high_mask = values >= split
        if not low_mask.any() or not high_mask.any():
            continue
        low, high = values[low_mask], values[high_mask]
        sse = np.sum((low - low.mean()) ** 2) + np.sum((high - high.mean()) ** 2)
        # Dividing by the constant n makes this a true MSE without changing the argmin.
        split_mses[idx] = sse / values.size

    best_split = splits[np.argmin(split_mses)]

    if plot:
        plt.plot(splits, split_mses, label=label)
        plt.xlabel('Split Thresholds')
        plt.ylabel('Mean Squared Error')
        plt.title('MSE vs Split Threshold')
        if label:
            plt.legend()

    return best_split

In [49]:
import re


def _trial_index(path):
    """Sort key: numeric trial index from '..._<n>.npy'; unmatched files sort last by name."""
    match = re.search(r'_(\d+)\.npy$', path)
    return (0, int(match.group(1))) if match else (1, path)


# Numeric sort (episode_test_2 before episode_test_10) so the flattened per-step
# arrays below are in a deterministic, trial-ordered sequence.
new_npys = sorted(glob.glob(test_dir + '/*.npy'), key=_trial_index)
ses = []
hs_bgs = []
hs_pgs = []
for data_path in new_npys:
    # allow_pickle: files hold dict-like records written by this notebook (trusted source only).
    data = np.load(data_path, allow_pickle=True)
    for dct in data:
        hs_bgs.append(dct['hidden_state_before_gen'])
        hs_pgs.append(dct['hidden_state_post_gen'])
        ses.append(dct['se'])

In [None]:
# TODO: confusion-matrix approach

# Per-trial outcome labels (trial index -> 1 for success, 0 for failure).
SUCCESS_TRIALS = {
    0: 1, 1: 0, 2: 0, 3: 1, 4: 0, 5: 0, 6: 0, 7: 1, 
    8: 0, 9: 1, 10: 0, 11: 0, 12: 0, 13: 0, 14: 0, 15: 0,
    16: 0, 17: 1, 18: 0, 19: 0, 20: 0, 21: 0, 22: 0, 23: 0,
    24: 0, 25: 0, 26: 0, 27: 0, 28: 0, 29: 0, 30: 1, 31: 1,
    32: 1, 33: 1, 34: 1, 35: 1, 36: 1, 37: 1, 38: 1, 39: 1
}
# Timestep per successful trial (presumably the step at which success occurred -- confirm).
SUCCESS_TIMESTEPS = {
    0: 23, 3: 21, 7: 23, 9: 25, 17: 27, 30: 32, 31: 41,
    32: 25, 33: 38, 34: 24, 35: 31, 36: 22, 37: 21, 38: 16, 39: 16
}
# The two tables must agree: timesteps exist exactly for the successful trials.
assert set(SUCCESS_TIMESTEPS) == {t for t, ok in SUCCESS_TRIALS.items() if ok}

# Action dimensions, in the column order of `ses`.
ACTION_DIMS = ('x', 'y', 'z', 'rot_x', 'rot_y', 'rot_z', 'done')

ses = np.asarray(ses)
hs_bgs = np.asarray(hs_bgs)
hs_pgs = np.asarray(hs_pgs)
# Unpacking ses.T into 7 names requires ses to have exactly 7 columns (one per action dim).
x_ses, y_ses, z_ses, rot_x_ses, rot_y_ses, rot_z_ses, done_ses = ses.T
# One best-split threshold per action dimension (replaces seven copy-pasted calls).
thresholds = {
    dim: find_best_split(dim_ses, plot=False)
    for dim, dim_ses in zip(ACTION_DIMS, ses.T)
}

In [56]:
def binarize_s_entropy(s_entropy, threshold):
    """Convert semantic-entropy values into integer 0/1 labels.

    An element becomes 1 when it is at or above ``threshold`` (high entropy)
    and 0 otherwise; the result has integer dtype.
    """
    is_high_entropy = s_entropy >= threshold
    return is_high_entropy.astype(int)

In [None]:
# Seeded generator so the train/eval sample is reproducible under Restart & Run All.
SEED = 0
rng = np.random.default_rng(SEED)

num_trials = 10  # NOTE(review): not used in the visible cells
num_train = 128
num_eval = 64
random_idxs = rng.choice(len(x_ses), num_train + num_eval, replace=False)

sampled_x_ses = x_ses[random_idxs]
# squeeze(1) drops the singleton axis 1 of the stacked hidden states
# (giving (N, 4096) in the recorded run).
sampled_hs_bgs = np.asarray(hs_bgs).squeeze(1)[random_idxs]
sampled_hs_pgs = np.asarray(hs_pgs).squeeze(1)[random_idxs]

# Training set: pre-generation hidden states, labelled by binarized x-dimension entropy.
y_obs = binarize_s_entropy(sampled_x_ses[:num_train], thresholds['x'])
x_obs_bgs = sampled_hs_bgs[:num_train]
# Held-out query points and their labels (for evaluating the probe).
x_query = sampled_hs_bgs[num_train:]
y_query = binarize_s_entropy(sampled_x_ses[num_train:], thresholds['x'])

(x_query.shape, x_obs_bgs.shape, y_obs.shape)

(64, 4096)
(128, 4096)
(128,)


In [None]:
import probabilistic_probe

# Pass the held-out hidden states (x_query), the training hidden states (x_obs_bgs)
# and their binarized x-entropy labels (y_obs) to the project's `gpp` probe.
# NOTE(review): `gpp` is defined in probabilistic_probe (not visible here); presumably a
# Gaussian-process probe -- confirm its argument order and what `measures` contains.
measures = probabilistic_probe.gpp(x_query, x_obs_bgs, y_obs)
measures