In [7]:
import jax.numpy as jnp
import gp
import jax
import glob 
import numpy as np
from PIL import Image
from IPython.display import display
import matplotlib.pyplot as plt
import polars as pl
import os
import probabilistic_probe

In [8]:
vla_run_data = '/Users/jeremiahetiosaomeike/Downloads/vla_finetuned_run_data'
npys = glob.glob(vla_run_data + '/*.npy')
print(npys)

['/Users/jeremiahetiosaomeike/Downloads/vla_finetuned_run_data/episode_test_33.npy', '/Users/jeremiahetiosaomeike/Downloads/vla_finetuned_run_data/episode_test_27.npy', '/Users/jeremiahetiosaomeike/Downloads/vla_finetuned_run_data/episode_test_26.npy', '/Users/jeremiahetiosaomeike/Downloads/vla_finetuned_run_data/episode_test_32.npy', '/Users/jeremiahetiosaomeike/Downloads/vla_finetuned_run_data/episode_test_24.npy', '/Users/jeremiahetiosaomeike/Downloads/vla_finetuned_run_data/episode_test_30.npy', '/Users/jeremiahetiosaomeike/Downloads/vla_finetuned_run_data/episode_test_18.npy', '/Users/jeremiahetiosaomeike/Downloads/vla_finetuned_run_data/episode_test_19.npy', '/Users/jeremiahetiosaomeike/Downloads/vla_finetuned_run_data/episode_test_31.npy', '/Users/jeremiahetiosaomeike/Downloads/vla_finetuned_run_data/episode_test_25.npy', '/Users/jeremiahetiosaomeike/Downloads/vla_finetuned_run_data/episode_test_21.npy', '/Users/jeremiahetiosaomeike/Downloads/vla_finetuned_run_data/episode_test_

In [10]:
def compute_semantic_entropy(prob_vector, center_bin_half_size=5):
    # Compute the semantic entropy by splitting the prob_vector into three arrays
    def entropy(p):
        return np.sum(-p * np.log(p + 1e-10))
    
    center_index = len(prob_vector) // 2
    center_bin = prob_vector[center_index - center_bin_half_size:center_index + center_bin_half_size]
    left_bin = prob_vector[:center_index - center_bin_half_size]
    right_bin = prob_vector[center_index + center_bin_half_size:]
    
    center_entropy = entropy(center_bin)
    left_entropy = entropy(left_bin)
    right_entropy = entropy(right_bin)
    
    return np.mean([center_entropy, left_entropy, right_entropy])

def find_best_split(s_entropy, 
                    plot=True, 
                    label="",
                    n_splits=100):
    # find the best split threshold for the semantic entropy

    splits = np.linspace(1e-10, s_entropy.max(), n_splits)
    split_mses = []
    
    for split in splits:
        low_idxs, high_idxs = s_entropy < split, s_entropy >= split
        
        if not any(low_idxs) or not any(high_idxs):
            split_mses.append(float('inf'))
            continue
            
        low_mean = np.mean(s_entropy[low_idxs])
        high_mean = np.mean(s_entropy[high_idxs])
        
        mse = np.sum((s_entropy[low_idxs] - low_mean)**2) + np.sum((s_entropy[high_idxs] - high_mean)**2)
        split_mses.append(mse)
    
    split_mses = np.array(split_mses)
    best_split = splits[np.argmin(split_mses)]
    
    if plot:
        plt.plot(splits, split_mses, label=label)
        plt.xlabel('Split Thresholds')
        plt.ylabel('Mean Squared Error')
        plt.title('MSE vs Split Threshold')
        if label:
            plt.legend()
    
    return best_split

def binarize_s_entropy(s_entropy, threshold):
    return (s_entropy >= threshold).astype(int)

In [11]:
def get_se_into_data(save_dir, npy_paths, verbose=False):

    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
        
    for num_trial in range(len(npy_paths)):
        data_path = npy_paths[num_trial]
        name = data_path.split('/')[-1]
        if verbose:
            print(name, save_dir)
        data = np.load(data_path, allow_pickle=True)
        for dct in data:
            ses = []
            probs = dct['probs']
            for prob_vec in probs:
                se = compute_semantic_entropy(prob_vec)
                ses.append(se)
            dct['se'] = ses
        s_path = save_dir + '/' + name
        if verbose:
            print(s_path)
        np.save(s_path, data)

In [14]:
save_dir = '/Users/jeremiahetiosaomeike/research_projects/vla/LIBERO/probes/finetune_vla_data_se'
vla_run_data = '/Users/jeremiahetiosaomeike/Downloads/vla_finetuned_run_data'
npy_paths = glob.glob(vla_run_data + '/*.npy')
get_se_into_data(save_dir=save_dir, npy_paths=npy_paths, verbose=True)

episode_test_33.npy /Users/jeremiahetiosaomeike/research_projects/vla/LIBERO/probes/finetune_vla_data_se
/Users/jeremiahetiosaomeike/research_projects/vla/LIBERO/probes/finetune_vla_data_se/episode_test_33.npy
episode_test_27.npy /Users/jeremiahetiosaomeike/research_projects/vla/LIBERO/probes/finetune_vla_data_se
/Users/jeremiahetiosaomeike/research_projects/vla/LIBERO/probes/finetune_vla_data_se/episode_test_27.npy
episode_test_26.npy /Users/jeremiahetiosaomeike/research_projects/vla/LIBERO/probes/finetune_vla_data_se
/Users/jeremiahetiosaomeike/research_projects/vla/LIBERO/probes/finetune_vla_data_se/episode_test_26.npy
episode_test_32.npy /Users/jeremiahetiosaomeike/research_projects/vla/LIBERO/probes/finetune_vla_data_se
/Users/jeremiahetiosaomeike/research_projects/vla/LIBERO/probes/finetune_vla_data_se/episode_test_32.npy
episode_test_24.npy /Users/jeremiahetiosaomeike/research_projects/vla/LIBERO/probes/finetune_vla_data_se
/Users/jeremiahetiosaomeike/research_projects/vla/LIBER

In [15]:
save_dir_val = '/Users/jeremiahetiosaomeike/research_projects/vla/LIBERO/probes/val_finetune_vla_data_se'
npy_paths_val = glob.glob('/Users/jeremiahetiosaomeike/Downloads/vla_finetuned_run_data/validate' + '/*.npy')
get_se_into_data(save_dir=save_dir_val, npy_paths=npy_paths_val, verbose=True)

episode_validate_8.npy /Users/jeremiahetiosaomeike/research_projects/vla/LIBERO/probes/val_finetune_vla_data_se
/Users/jeremiahetiosaomeike/research_projects/vla/LIBERO/probes/val_finetune_vla_data_se/episode_validate_8.npy
episode_validate_9.npy /Users/jeremiahetiosaomeike/research_projects/vla/LIBERO/probes/val_finetune_vla_data_se
/Users/jeremiahetiosaomeike/research_projects/vla/LIBERO/probes/val_finetune_vla_data_se/episode_validate_9.npy
episode_validate_2.npy /Users/jeremiahetiosaomeike/research_projects/vla/LIBERO/probes/val_finetune_vla_data_se
/Users/jeremiahetiosaomeike/research_projects/vla/LIBERO/probes/val_finetune_vla_data_se/episode_validate_2.npy
episode_validate_3.npy /Users/jeremiahetiosaomeike/research_projects/vla/LIBERO/probes/val_finetune_vla_data_se
/Users/jeremiahetiosaomeike/research_projects/vla/LIBERO/probes/val_finetune_vla_data_se/episode_validate_3.npy
episode_validate_1.npy /Users/jeremiahetiosaomeike/research_projects/vla/LIBERO/probes/val_finetune_vla_

In [None]:
def get_probe_vars(data_dir):
    new_npys = glob.glob(data_dir + '/*.npy')
    ses = []
    hs_bgs = []
    hs_pgs = []
    for num_trial in range(len(new_npys)):
        data_path = new_npys[num_trial]
        data = np.load(data_path, allow_pickle=True)
        for dct in data:
            hs_bgs.append(dct['hidden_state_before_gen'])
            hs_pgs.append(dct['hidden_state_post_gen'])
            ses.append(dct['se'])
    
    return np.asarray(ses), np.asarray(hs_bgs), np.asarray(hs_pgs)

data_dir = '/Users/jeremiahetiosaomeike/research_projects/vla/LIBERO/probes/finetune_vla_data_se'
data_dir_val = '/Users/jeremiahetiosaomeike/research_projects/vla/LIBERO/probes/val_finetune_vla_data_se'
ses, hs_bgs, hs_pgs = get_probe_vars(data_dir)
# ses_val, hs_bgs_val, hs_pgs_val = get_probe_vars(data_dir)

In [17]:
print(f'Semantic Entropy Shape: {ses.shape} \n Hidden States Before Generation Shape: {hs_bgs.shape} \n Hidden States Post Generation Shape: {hs_pgs.shape}')

Semantic Entropy Shape: (1438, 7) 
 Hidden States Before Generation Shape: (1438, 1, 4096) 
 Hidden States Post Generation Shape: (1438, 1, 4096)


In [None]:
hs_bgs = hs_bgs.squeeze(axis=1)
hs_pgs = hs_pgs.squeeze(axis=1)

# hs_bgs_val = hs_bgs_val.squeeze(axis=1)
# hs_pgs_val = hs_pgs_val.squeeze(axis=1)

In [27]:
# SUCCESS_TRIALS_TRAIN = {
#     0: 1, 1: 0, 2: 0, 3: 1, 4: 0, 5: 0, 6: 0, 7: 1, 
#     8: 0, 9: 1, 10: 0, 11: 0, 12: 0, 13: 0, 14: 0, 15: 0,
#     16: 0, 17: 1, 18: 0, 19: 0, 20: 0, 21: 0, 22: 0, 23: 0,
#     24: 0, 25: 0, 26: 0, 27: 0, 28: 0, 29: 0, 30: 1, 31: 1,
#     32: 1, 33: 1, 34: 1, 35: 1, 36: 1, 37: 1, 38: 1, 39: 1
# }

# SUCCESS_TIMESTEPS_TRAIN = {
#     0: 23, 3: 21, 7: 23, 9: 25, 17: 27, 30: 32, 31: 41,
#     32: 25, 33: 38, 34: 24, 35: 31, 36: 22, 37: 21, 38: 16, 39: 16
# }

ses = np.asarray(ses)
hs_bgs = np.asarray(hs_bgs)
hs_pgs = np.asarray(hs_pgs)
thresholds = {}
avg_ses_pos = ses[:, :3].mean(axis=1)
best_thresh_avg_pos_ses = find_best_split(avg_ses_pos, plot=False)

In [None]:
SUCCESS_TRIALS_VAL = {
    0: 1, 1: 0, 2: 0, 3: 0, 4: 0, 5: 0, 6: 1, 7: 1, 
    8: 1, 9: 1}

average_queries_dict_bgs = {0: [], 1: [], 2: [], 3: [], 4: [], 5: [], 6: [], 7: [], 8: [], 9: []}
average_queries_dict_pgs = {0: [], 1: [], 2: [], 3: [], 4: [], 5: [], 6: [], 7: [], 8: [], 9: []}

for npy_path in glob.glob(data_dir_val + '/*.npy'):
    npy_path_num = int(npy_path.split('.')[0][-1])
    data = np.load(npy_path, allow_pickle=True) 
    ses_val = []
    hs_bgs_val = []
    hs_pgs_val = []

    for step in data:
        ses_val.append(step['se'])
        hs_bgs_val.append(step['hidden_state_before_gen'])
        hs_pgs_val.append(step['hidden_state_post_gen'])

    ses_val = np.asarray(ses_val)
    query_hs_bgs_val = np.asarray(hs_bgs_val).squeeze(axis=1)
    query_hs_pgs_val = np.asarray(hs_pgs_val).squeeze(axis=1)

    avg_ses_pos_val = ses_val[:, :3].mean(axis=1)
    gt_labels_val = binarize_s_entropy(avg_ses_pos_val, best_thresh_avg_pos_ses)

    num_iters = 10
    num_data = 256

    for _ in range(num_iters):
        # Sample random indexes 
        random_idxs = np.random.choice(len(avg_ses_pos), num_data, replace=False)
        sampled_avg_ses = avg_ses_pos[random_idxs] # semantic entropies here 

        sampled_hs_bgs = np.asarray(hs_bgs)[random_idxs] # hidden states before generation
        sampled_hs_pgs = np.asarray(hs_pgs)[random_idxs] # hidden states post generation
        
        label_obs = binarize_s_entropy(sampled_avg_ses, best_thresh_avg_pos_ses) # binarized labels 

        measures_hs_bgs = probabilistic_probe.gpp(query_hs_bgs_val, sampled_hs_bgs, label_obs)
        measures_hs_pgs = probabilistic_probe.gpp(query_hs_pgs_val, sampled_hs_pgs, label_obs)

        mean_label_query_pred_bgs = measures_hs_bgs['bernoulli_mu']
        mean_label_query_pred_pgs = measures_hs_pgs['bernoulli_mu']

        binarized_mean_label_query_pred_bgs = (mean_label_query_pred_bgs >= .5).astype(int)
        binarized_mean_label_query_pred_pgs = (mean_label_query_pred_pgs >= .5).astype(int)

        num_helps_bgs = np.count_nonzero(binarized_mean_label_query_pred_bgs)
        num_helps_pgs = np.count_nonzero(binarized_mean_label_query_pred_pgs)
        average_queries_dict_bgs[npy_path_num].append(num_helps_bgs)
        average_queries_dict_pgs[npy_path_num].append(num_helps_pgs)




(256, 4096)
(256, 4096)


KeyError: '8'