In [1]:
%load_ext autoreload
%autoreload 2

import numpy as np
from custom_causal_inference import CustomCausalInference
import utils
import forward_models_causal_inference
import matplotlib.pyplot as plt
import pickle



In [2]:
%reload_ext autoreload

In [3]:
D = 250  # grid dimension 
NUM_SIM = 10000  # number of simulations
angle_gam_data_path = './base_bayesian_contour_1_circular_gam.pkl'
unif_fn_data_path='./uniform_model_base_inv_kappa_free.pkl'
p_commons = np.linspace(0, 1, num=20)
t_index = 0

causal_inference_estimator = forward_models_causal_inference.CausalEstimator(
    model=CustomCausalInference(decision_rule='mean'),
    angle_gam_data_path=angle_gam_data_path,
    unif_fn_data_path=unif_fn_data_path)
unif_map = causal_inference_estimator.unif_map

In [4]:
grid_sz = 100
optimal_kappa_pairs_filepath = f'./learned_data/optimal_kappa_pairs_{grid_sz}_t{t_index}.pkl'
min_error_for_idx_pc_filepath = f'./learned_data/min_error_for_idx_pc_{grid_sz}_t{t_index}.pkl'
min_error_for_idx_filepath = f'./learned_data/min_error_for_idx_{grid_sz}_t{t_index}.pkl'
s_ns_filepath = f'./learned_data/selected_s_n_{grid_sz}_t{t_index}.npy'
ts_filepath = f'./learned_data/selected_t_{grid_sz}_t{t_index}.npy'
r_ns_filepath = f'./learned_data/selected_r_n_{grid_sz}_t{t_index}.npy'

In [5]:
with open(optimal_kappa_pairs_filepath, 'rb') as f:
    optimal_kappa_pairs = pickle.load(f)
with open(min_error_for_idx_pc_filepath, 'rb') as f:
    min_error_for_idx_pc = pickle.load(f)
with open(min_error_for_idx_filepath, 'rb') as f:
    min_error_for_idx = pickle.load(f)
s_ns = np.load(s_ns_filepath)
ts = np.load(ts_filepath)
r_ns = np.load(r_ns_filepath)

In [6]:
grid_dim = int(np.sqrt(s_ns.shape[0]))
grid_dim

10

In [None]:
plt.scatter(s_ns, r_ns, label = 'r_n as fn of s_n')
#plt.scatter(ts, r_ns, label='r_n as fn of t')
plt.plot(s_ns, s_ns, label='s_n', c='r')
plt.legend()
plt.savefig(f'./figs/sn_vs_rn_{grid_sz}_t{t_index}.png')
plt.clf()
plt.scatter(np.linspace(np.min(s_ns), np.max(s_ns), num=grid_dim), s_ns[:grid_dim])
plt.title('Selected s_n as a function of uniform stimuli in angle space')
plt.savefig(f'./figs/selected_sn_{grid_sz}_t{t_index}.png')
plt.clf()

In [10]:
err_mat = {est: np.zeros_like(s_ns) for est in min_error_for_idx.keys()}
min_err_mat = np.zeros_like(s_ns)
optimal_kappa_pairs_arr = {est: np.zeros((*s_ns.shape, 2)) for est in optimal_kappa_pairs.keys()} # 2 kappa values
optimal_pc = {est: np.zeros_like(s_ns) for est in min_error_for_idx.keys()}
optimal_pc_min = np.zeros_like(s_ns)
err_threshold = np.deg2rad(5)

test_fwd_sim = True
fwd_errors = np.zeros_like(s_ns)
for est in optimal_kappa_pairs.keys():
    for key in min_error_for_idx_pc[est]:
        if min_error_for_idx_pc[est][key] == min_error_for_idx[est][key[0]]:
            # key[0] is the index of means s.t. causal inference is performed on (s_n[key[0]], t[key[0]])
            # key[1] is the optimal p_common for the pair
            optimal_pc[est][key[0]] = key[1]
            err_mat[est][int(key[0])] = min_error_for_idx[est][key[0]]
            optimal_kappa_pairs_arr[est][int(key[0]), :] = np.array(optimal_kappa_pairs[est][key]).reshape(2,)
            if min_error_for_idx_pc[est][key] >= err_threshold:
                print(f'{key[0]}: t, s_n, r_n = {np.round(s_ns[key[0]], 3), np.round(ts[key[0]], 3), np.round(r_ns[key[0]], 4)}', 
                    f'd/pi={np.round(utils.circular_dist(s_ns[key[0]], ts[key[0]]) / np.pi, 4)}', 
                    f'min_err={min_error_for_idx_pc[est][key]}',  
                    f'p_c={np.round(key[1], 3)}', 
                    f'optimal_kappa={optimal_kappa_pairs[est][key]}')
            if test_fwd_sim:
                mean_ests = []
                for _ in range(30):
                    usn, ut = unif_map.angle_space_to_unif_space([s_ns[key[0]]]), unif_map.angle_space_to_unif_space([ts[key[0]]])
                    t_samples, s_n_samples = causal_inference_estimator.get_vm_samples(
                                num_sim=10000,
                                mu_t=ut,
                                mu_s_n=usn,
                                kappa1=optimal_kappa_pairs[est][key][0],
                                kappa2=optimal_kappa_pairs[est][key][1])
                    _, _, mean_t_est, mean_sn_est = causal_inference_estimator.forward(t_samples=t_samples,
                        s_n_samples=s_n_samples, 
                                                                            kappa1=optimal_kappa_pairs[est][key][0], 
                                                                            kappa2=optimal_kappa_pairs[est][key][1], 
                                                                            p_common=key[1])
                    if est == 'sn':
                        mean_ests.append(mean_sn_est[0,0])
                    else:
                        mean_ests.append(mean_t_est[0,0])
                    assert mean_sn_est.shape == (1, 1)
                del t_samples, s_n_samples, mean_t_est, mean_sn_est
                fwd_errors[key[0]] = sum(map(lambda x: utils.circular_dist(x, r_ns[key[0]]), mean_ests)) / len(mean_ests)
                if (np.rad2deg(utils.circular_dist(max(mean_ests), min(mean_ests))) > 2):
                    print('Max dif across 10 estimates for index', key, np.rad2deg(utils.circular_dist(max(mean_ests), min(mean_ests))))
                if abs(fwd_errors[key[0]] - min_error_for_idx_pc[est][key]) > err_threshold:
                    print(f'Error forward vs fits {key[0]}', fwd_errors[key[0]], np.rad2deg(abs(fwd_errors[key[0]] - min_error_for_idx_pc[est][key])), np.rad2deg(utils.circular_dist(max(mean_ests), min(mean_ests))))


31: t, s_n, r_n = (-1.274, -0.694, -0.5054) d/pi=0.1847 min_err=0.16620621476297304 p_c=0.895 optimal_kappa=(array([400.]), array([1.1]))
32: t, s_n, r_n = (-0.971, -0.694, -0.3778) d/pi=0.0884 min_err=0.28996344957419495 p_c=0.895 optimal_kappa=(array([400.]), array([1.1]))
33: t, s_n, r_n = (-0.694, -0.694, -0.3649) d/pi=0.0 min_err=0.29013045386387315 p_c=0.842 optimal_kappa=(array([291.81624545]), array([1.1]))
34: t, s_n, r_n = (-0.517, -0.694, -0.3803) d/pi=0.0562 min_err=0.1519011371940664 p_c=0.0 optimal_kappa=(array([32.09636014]), array([200.2]))
Max dif across 10 estimates for index (39, 0.2631578947368421) 2.4006810588808443
41: t, s_n, r_n = (-1.274, -0.517, -0.2964) d/pi=0.241 min_err=0.2231142337559735 p_c=0.947 optimal_kappa=(array([400.]), array([1.1]))
42: t, s_n, r_n = (-0.971, -0.517, -0.1256) d/pi=0.1446 min_err=0.3905482690815374 p_c=0.947 optimal_kappa=(array([400.]), array([1.1]))
43: t, s_n, r_n = (-0.694, -0.517, -0.173) d/pi=0.0562 min_err=0.3346169394168941 

In [12]:
for est in optimal_kappa_pairs_arr.keys():
    print(f'Indices for {est} where optimal kappa1 > kappa2: {np.sum(optimal_kappa_pairs_arr[est][:, 0] >= optimal_kappa_pairs_arr[est][:, 1])} {optimal_kappa_pairs_arr[est][:, 0] >= optimal_kappa_pairs_arr[est][:, 1]}')

Indices for sn where optimal kappa1 > kappa2: 91 [False False  True  True  True  True  True  True  True  True  True  True
  True  True  True  True  True  True  True  True  True  True  True  True
  True  True  True  True  True  True  True  True  True  True False False
 False  True  True  True  True  True  True  True  True False False  True
  True  True  True  True  True  True  True  True  True  True  True  True
  True  True  True  True  True  True False  True  True  True  True  True
  True  True  True  True  True  True  True  True  True  True  True  True
  True  True  True  True False  True  True  True  True  True  True  True
  True  True  True  True]
Indices for t where optimal kappa1 > kappa2: 42 [ True False False  True  True  True False  True False  True False False
  True False False  True  True  True False  True  True  True False False
 False False False False False False False  True  True False False False
 False False False False False  True  True  True False False False False
 

In [13]:
for est in optimal_kappa_pairs_arr.keys():
    print(f'Kappa values for estimate {est} where optimal kappa1 > kappa2: {optimal_kappa_pairs_arr[est][optimal_kappa_pairs_arr[est][:, 0] >= optimal_kappa_pairs_arr[est][:, 1]]}')

Kappa values for estimate sn where optimal kappa1 > kappa2: [[  6.6328955    2.50176241]
 [113.30729922   7.48256615]
 [212.89180278  50.89890563]
 [291.81624545   4.32661562]
 [212.89180278   3.29000977]
 [400.           1.90236977]
 [ 23.41559827   1.1       ]
 [ 60.305488     1.90236977]
 [400.           1.1       ]
 [113.30729922   1.1       ]
 [400.           4.32661562]
 [  4.83896665   1.1       ]
 [  4.83896665   1.1       ]
 [ 43.99530272   1.90236977]
 [113.30729922   1.44658451]
 [ 32.09636014   2.50176241]
 [  2.57544084   1.44658451]
 [291.81624545 200.2       ]
 [291.81624545   1.1       ]
 [291.81624545   1.1       ]
 [ 43.99530272   1.1       ]
 [  6.6328955    5.68983195]
 [  9.09187971   4.32661562]
 [ 43.99530272   5.68983195]
 [ 23.41559827   1.44658451]
 [ 82.6622766    1.44658451]
 [155.31321644   7.48256615]
 [ 23.41559827   3.29000977]
 [291.81624545   1.44658451]
 [400.           1.1       ]
 [400.           1.1       ]
 [291.81624545   1.1       ]
 [ 43.995302

In [15]:
usns = unif_map.angle_space_to_unif_space(s_ns)
uts = unif_map.angle_space_to_unif_space(ts)
for est in ['sn', 't']:
    plt.figure(figsize=(10, 8))
    plt.pcolormesh(usns.reshape((grid_dim, grid_dim)), uts.reshape((grid_dim, grid_dim)), 
                err_mat[est].reshape((grid_dim, grid_dim)), shading='auto', cmap='magma')
    plt.colorbar(label='circular distance to r_n(t, s_n)')
    plt.xlabel('s_n')
    plt.ylabel('t')
    plt.title(f'Heatmap of error for r_n(t, s_n) using estimate for {est}')
    plt.savefig(f'./figs/error_heatmap_uspace_{grid_sz}_t{t_index}_{est}.png')
    plt.clf()

<Figure size 720x576 with 0 Axes>

<Figure size 720x576 with 0 Axes>

In [16]:
plt.figure(figsize=(10, 8))
plt.pcolormesh(usns.reshape((grid_dim, grid_dim)), 
               uts.reshape((grid_dim, grid_dim)), 
               r_ns.reshape((grid_dim, grid_dim)), shading='auto', cmap='magma')
plt.colorbar(label='r_n(t, s_n)')
plt.xlabel('s_n')
plt.ylabel('t')
plt.title('Heatmap of r_n(t, s_n)')
plt.savefig(f'./figs/r_n_axis_u_space_{grid_sz}_t{t_index}.png')
plt.clf()

<Figure size 720x576 with 0 Axes>

In [19]:
for est in ['sn', 't']:
    optimal_kappa_ratio = optimal_kappa_pairs_arr[est][:, 1] / optimal_kappa_pairs_arr[est][:, 0]
    plt.figure(figsize=(10, 8))
    plt.pcolormesh(usns.reshape((grid_dim, grid_dim)), uts.reshape((grid_dim, grid_dim)), 
                optimal_kappa_ratio.reshape((grid_dim, grid_dim)), shading='auto', cmap='RdBu')
    plt.colorbar(label='optimal_kappa_ratio(t, s_n)')
    plt.xlabel('us_n')
    plt.ylabel('ut')
    plt.title(f'Heatmap of kappa ratio using estimate for {est}')
    plt.savefig(f'./figs/optimal_kappa_ratio_{grid_sz}_t{t_index}_{est}.png')
    plt.clf()

    for k_index in [0, 1]:
        plt.figure(figsize=(10, 8))
        plt.pcolormesh(usns.reshape((grid_dim, grid_dim)), uts.reshape((grid_dim, grid_dim)), 
                    optimal_kappa_pairs_arr[est][:, k_index].reshape((grid_dim, grid_dim)), shading='auto', cmap='RdBu')
        plt.colorbar(label='optimal_kappa_ratio(t, s_n)')
        plt.xlabel('us_n')
        plt.ylabel('ut')
        if k_index == 0:
            # plot latex doesn't support variables :(
            plt.title(f'Heatmap of $\kappa_1(t, s_n)$ for {est}')
        else:
            plt.title(f'Heatmap of $\kappa_2(t, s_n)$ for {est}')
        plt.savefig(f'./figs/kappa{k_index+1}_{grid_sz}_t{t_index}_{est}.png')
        plt.clf()

<Figure size 720x576 with 0 Axes>

<Figure size 720x576 with 0 Axes>

<Figure size 720x576 with 0 Axes>

<Figure size 720x576 with 0 Axes>

<Figure size 720x576 with 0 Axes>

<Figure size 720x576 with 0 Axes>

In [21]:
for est in ['sn', 't']:
    plt.figure(figsize=(10, 8))
    plt.pcolormesh(usns.reshape((grid_dim, grid_dim)), uts.reshape((grid_dim, grid_dim)), 
                optimal_pc[est].reshape((grid_dim, grid_dim)), shading='auto', cmap='RdBu')
    plt.colorbar(label='p_common(ut, us_n)')
    plt.xlabel('s_n')
    plt.ylabel('t')
    plt.title('Heatmap of optimal p_common(ut, us_n) for {est} estimate')
    plt.savefig(f'./figs/p_common_{grid_sz}_t{t_index}_{est}.png')
    plt.clf()

<Figure size 720x576 with 0 Axes>

<Figure size 720x576 with 0 Axes>

## Looking at optimal $\kappa$ pairs

In [24]:
import os
print(os.getcwd())

d:\AK_Q2_2024\causal_inference


In [None]:
import pickle

with open(f'./learned_data/task_metadata_{grid_sz}_t{t_index}.pkl', 'rb') as f:
    tasks_metadata = pickle.load(f)
num_tasks = len(tasks_metadata)
print(f'Number of tasks: {num_tasks}')
errors_dicts = {}
err_for_idx = {est: {} for est in ['sn', 't']}
nums_bad = 0
max_error = 0
high_err_idx = []
for task_idx in range(num_tasks):
    try:
        with open(f'./learned_data/optimal_kappa_errors/errors_dict_{task_idx}_{grid_sz}_t{t_index}.pkl', 'rb') as f:
            errors_dicts[task_idx] = pickle.load(f)
            mean_idx = tasks_metadata[task_idx]['mean_indices'][0]
            min_error_across_pc = {est: min([errors_dicts[task_idx][p_c][f'errors_{est}'].min() for p_c in errors_dicts[task_idx].keys()]) for est in ['sn', 't']}
            for est in ['sn', 't']:
                if mean_idx not in err_for_idx[est]:
                    err_for_idx[est][mean_idx] = min_error_across_pc[est]
                else:   
                    err_for_idx[est][mean_idx] = min(err_for_idx[est][mean_idx], min_error_across_pc[est])
    except:
        nums_bad += 1
        print(f'Error for task {task_idx}')
print(f'Number of bad tasks: {nums_bad}')
for mean_idx in err_for_idx:
    if max(err_for_idx['sn'][mean_idx], err_for_idx['t'][mean_idx]) > .1:
        high_err_idx.append(mean_idx)
        print(mean_idx, err_for_idx['sn'][mean_idx], err_for_idx['t'][mean_idx])
    max_error = max(min(err_for_idx['sn'][mean_idx], err_for_idx['t'][mean_idx]), max_error)
print(f'Max error: {max_error}')
print(high_err_idx, '\n', list(zip(zip(s_ns[high_err_idx], ts[high_err_idx]), r_ns[high_err_idx])))

Number of tasks: 100
> [1;32mc:\users\ana\appdata\local\temp\ipykernel_3688\1803793318.py[0m(18)[0;36m<module>[1;34m()[0m

{0.0: {'errors_t': array([0.00038898, 0.00050798, 0.00057388, 0.00058591, 0.00058645,
       0.00087152, 0.00095575, 0.00111948, 0.00136583, 0.00144638,
       0.00145817, 0.00147318, 0.00154525, 0.00159554, 0.00167705,
       0.00178248, 0.00190052, 0.00191101, 0.00191308, 0.00197094,
       0.00199955, 0.00204998, 0.00221001, 0.00227272, 0.00228538,
       0.00230836, 0.00231887, 0.00234016, 0.00237506, 0.00247469,
       0.00255754, 0.00256296, 0.00263723, 0.00270272, 0.00283274,
       0.00285688, 0.00290814, 0.00295934, 0.00298432, 0.00299032,
       0.00302353, 0.00322   , 0.00335554, 0.00335833, 0.00337418,
       0.00339052, 0.00341889, 0.00342883, 0.00346428, 0.00348363]), 'optimal_kappa1_t': array([  4.839 ,   4.839 ,  12.4625,   4.839 ,  12.4625,  43.9953,
         3.5302,   6.6329,   3.5302,   4.839 ,   3.5302,  43.9953,
        23.4156,  23.4156, 