In [1]:
import argparse
import dill
import jax
import jax.numpy as np
from jax import lax
from jax.scipy.special import logsumexp
import tqdm

In [2]:
RUNS=1
input_dir = "../Datasets"
algos = ["optimistic", "softmax", "igw", "greedy", "ucb", "ts"]

filename = f'{input_dir}/dataset_1_ucb.dill'
with open(filename, 'rb') as f:
    data = dill.load(f)
            
    data_x = np.array(data['x'])  
    data_a = np.array(data['a'])  
#    rhox = data['rhox']           
#    betas_mean = data['betas_mean']
#    betas_cov=data['betas_cov']
    T,A,K = data_x.shape    

# A dictionary of hyperparameters of the simulation
hyper = dict()
 
############### ADJUST HYPER AND THE LOADING
hyper['sample_start'] = 1_000
hyper['sample_stop'] = 2_000
hyper['sample_step'] = 1
hyper['iter'] = 100

T = data_x.shape[0] # time dimension, total number of steps
A = data_x.shape[1] # this is not used anywhere !!! Each action has it's own context, this is unusual, interesting setup
K = data_x.shape[2] # feature space dimension, context shape
alpha = 20 # exploration parameter in the policy definition
sigma = .10 # variance of the rewards

# PC-BICB UTILITIES

In [3]:
## xs ~ constructs a matrix where each row represents the context vector selected by the action at each time step, 
## padded with zeros for time steps beyond the current one, this is reflected with t0, t1
_xs = lambda t0, t1: jax.lax.select(t1 <= t0, data_x[t1,data_a[t1]], np.zeros(K)) 
_xs = jax.vmap(jax.vmap(_xs, in_axes=(None,0)), in_axes=(0,None)) 
xs = _xs(np.arange(T-1), np.arange(T-1)) 


# Belief updates utilities

## cumulative sum of outer products of contexts selected by actions up to time t
## later used for covariance matrix updates
__betas_N = lambda t: np.einsum('i,j->ij', data_x[t,data_a[t]], data_x[t,data_a[t]])
__betas_N = jax.vmap(__betas_N)
_betas_N = __betas_N(np.arange(T-1)).cumsum(axis=0)
_betas_N = np.concatenate((np.zeros((K,K))[None,...], _betas_N))


## cumulative sum of dot product of rewards and context for each time step
## later used for posterior mean updates
__betas_y = lambda r, t: r * data_x[t,data_a[t]]
__betas_y = jax.vmap(__betas_y)
_betas_y = lambda rs: np.concatenate((np.zeros(K)[None,...], __betas_y(rs, np.arange(T-1)).cumsum(axis=0)))
_betas_y = jax.jit(_betas_y)
_BETAS_Y = jax.jit(jax.vmap(_betas_y))

@jax.jit
def decode(params):
    '''
    initialize parameters 
    '''
    beta0 = np.exp(20 * params['beta0']) # WHY 20 HERE for? (it is 1 in the end)
    beta0_y = -np.ones(K)/K * beta0  # vector
    beta0_N = np.eye(K) * beta0      # matrix
    return beta0_y, beta0_N


# Implement likelihood functions for each policy
def optimistic_like(rho, x, beta_cov, a):
    alpha=20
    q = alpha * np.einsum('ij,j->i', x, rho) + np.einsum('ij,jk,ki->i', x, beta_cov, x.T)
    
    return q[a] - logsumexp(q)    

def softmax_like(rho, x,  beta_cov, a):
    alpha=20
    q = alpha * np.einsum('ij,j->i', x, rho)
    
    return q[a] - logsumexp(q) 

def ts_like(rho, x, beta_cov, a, key):
    num_samples=10
    alpha=20
    K = rho.shape[0]  # Number of features
    A = x.shape[0]  # Number of actions

    # Result(num_samples, K) #sampled_rhos = np.random.multivariate_normal(rho, beta_cov, size=num_samples)
    sampled_rhos = jax.random.multivariate_normal(key, rho, beta_cov, (num_samples,))
        
    # x to (1, A, K) and broadcasted 
    # Result shape (num_samples, A)
    scores = np.dot(sampled_rhos, x.T)

    # Result shape (num_samples,)
    best_actions = np.argmax(scores, axis=1)
    
#    counts = np.bincount(best_actions, minlength=A)
#    freq = best_action_counts / num_samples
    freq=np.zeros(A)
    for action in range(A):
        freq=freq.at[action].set(np.sum(best_actions == action) / num_samples)
        
    return np.log(freq[a])
    

def igw_like(rho, x, beta_cov, a):
    alpha=20
    erewards = np.einsum('ij,j->i', x, rho)  # prediction
    best_arm = np.argmax(erewards)
    gaps = erewards[best_arm] - erewards  # Gaps

    A = x.shape[0]  # x is (A, K)
    # Compute the prob for non-best 
    pi = 1 / (A + alpha * gaps)
    pi=pi.at[best_arm].set(0)  # temp

    # Adjust the best arm
    pi_best = 1 - np.sum(pi)
    pi=pi.at[best_arm].set(pi_best)
    
    return np.log(pi[a])

In [4]:
def _sample_rhos_optimistic(beta_mean, beta_cov, x, a, key): #alpha,
    '''
    Sample two rhos, do a round of Metropolis Hastings
    '''
    keys = jax.random.split(key, 3) # pseud random seed controlling
    
    rho  = jax.random.multivariate_normal(keys[0], beta_mean, beta_cov) # sample rho'
    _rho = jax.random.multivariate_normal(keys[1], beta_mean, beta_cov) # sample rho''
    
    like  = optimistic_like(rho, x, beta_cov, a) # calc like scores rho'  #, alpha
    _like = optimistic_like(_rho, x, beta_cov, a) # calc like scores rho'' #, alpha
    
    cond  = _like - like > np.log(jax.random.uniform(keys[2])) # Metropolis Hastings
    return jax.lax.select(cond, _rho, rho) # _rho if cond true rho otherwise

_sample_rhos_optimistic = jax.vmap(_sample_rhos_optimistic)


def _sample_rhos_softmax(beta_mean, beta_cov, x, a, key): #alpha,
    '''
    Sample two rhos, do a round of Metropolis Hastings
    '''
    keys = jax.random.split(key, 3) # pseud random seed controlling
    
    rho  = jax.random.multivariate_normal(keys[0], beta_mean, beta_cov) # sample rho'
    _rho = jax.random.multivariate_normal(keys[1], beta_mean, beta_cov) # sample rho''
    
    like  = softmax_like(rho, x, beta_cov, a) # calc like scores rho'  #, alpha
    _like = softmax_like(_rho, x, beta_cov, a) # calc like scores rho'' #, alpha
    
    cond  = _like - like > np.log(jax.random.uniform(keys[2])) # Metropolis Hastings
    return jax.lax.select(cond, _rho, rho) # _rho if cond true rho otherwise

_sample_rhos_softmax = jax.vmap(_sample_rhos_softmax)


def _sample_rhos_ts(beta_mean, beta_cov, x, a, key): #alpha
    '''
    Sample two rhos, do a round of Metropolis Hastings
    '''
    keys = jax.random.split(key, 3) # pseud random seed controlling
    
    rho  = jax.random.multivariate_normal(keys[0], beta_mean, beta_cov) # sample rho'
    _rho = jax.random.multivariate_normal(keys[1], beta_mean, beta_cov) # sample rho''
    
    like  = ts_like(rho, x, beta_cov, a, keys[0]) # calc like scores rho' #, alpha
    _like = ts_like(_rho, x, beta_cov, a, keys[1]) # calc like scores rho'' #, alpha
    
    cond  = _like - like > np.log(jax.random.uniform(keys[2])) # Metropolis Hastings
    return jax.lax.select(cond, _rho, rho) # _rho if cond true rho otherwise

_sample_rhos_ts = jax.vmap(_sample_rhos_ts)


def _sample_rhos_igw(beta_mean, beta_cov, x, a, key): #, alpha
    '''
    Sample two rhos, do a round of Metropolis Hastings
    '''
    print(key, key.shape)
    keys = jax.random.split(key, 3) # pseud random seed controlling
    
    print(keys[0], keys[0].shape)
    rho  = jax.random.multivariate_normal(keys[0], beta_mean, beta_cov) # sample rho'
    _rho = jax.random.multivariate_normal(keys[1], beta_mean, beta_cov) # sample rho''
    
    like  = igw_like(rho, x, beta_cov, a) # calc like scores rho' #alpha, 
    _like = igw_like(_rho, x, beta_cov, a) # calc like scores rho'' #, alpha
    
    cond  = _like - like > np.log(jax.random.uniform(keys[2])) # Metropolis Hastings
    return jax.lax.select(cond, _rho, rho) # _rho if cond true rho otherwise

_sample_rhos_igw = jax.vmap(_sample_rhos_igw)

def _sample_rs_init(rhox, key):
    '''
    expected reward for each action at the last time step T-1 + add gaussian noise and return
    '''
    mean = np.einsum('ij,j->i', xs[-1], rhox)
    rs = mean + sigma * jax.random.normal(key, shape=(T-1,))
    return rs

def _sample_rs(rhox, rhos, beta0_y, betas_invN, key):
    '''
    Lemma 1 sampling
    '''
    invcov = np.eye(T-1) # initialize for inverse covariance 1/σ2 I
    
        # outer product of xs, weighted by inverse covariance of believes?  # sum over all time steps 'a' 
        # cross-product 'bc' and 'de' indices  # inverse covariance weights `cd' indices # \sum^T_t=2 X^T_{t−1} Σ_t X_{t−1}
    invcov = invcov + np.einsum('abc,acd,aed->be', xs, betas_invN[1:], xs) 
    
        # dot product of the last context vectors xs[-1] with the parameters rhox  # (X^T_t \rhox)
    invcov_at_mean = np.einsum('ij,j->i', xs[-1], rhox)
    
        # 1) product of each context vector with its corresponding parameter estimate adjustment
        # 2) weighted sum of these adjustments across all contexts xs and adjusted parameters rhos[1:] - ...
        # \sum^T_t=2 X^T_t-1(\rho_t - C_t C^{-1}_1 \mu_1)
    invcov_at_mean = invcov_at_mean + np.einsum('ijk,ik->j', xs, rhos[1:] - np.einsum('ijk,k->ij', betas_invN[1:], beta0_y))
    cov = np.linalg.inv(invcov)
    mean = cov @ invcov_at_mean

    rs = jax.random.multivariate_normal(key, mean, cov * sigma**2) # sampling the reward from a multivariate normal
    return rs

def _sample_optimistic(arg0, arg1):
    (rhos, rs, rhox, beta0_y, betas_invN), key = arg0, arg1 # , alpha
    keys = jax.random.split(key, 2)
        
    # Calculate beta mean and covariance
    betas_mean = np.einsum('ijk,ik->ij', betas_invN, beta0_y + _betas_y(rs))
    betas_cov = betas_invN * sigma**2
        
    # Sample rhos and rewards using the specified policy
    rhos = _sample_rhos_optimistic(betas_mean, betas_cov, data_x, data_a, jax.random.split(keys[0], T)) #policy alpha
    rs = _sample_rs(rhox, rhos, beta0_y, betas_invN, keys[1])
        
    return (rhos, rs, rhox, beta0_y, betas_invN), (rhos, rs)

def _sample_softmax(arg0, arg1):
    (rhos, rs, rhox, beta0_y, betas_invN), key = arg0, arg1 # , alpha
    keys = jax.random.split(key, 2)
        
    # Calculate beta mean and covariance
    betas_mean = np.einsum('ijk,ik->ij', betas_invN, beta0_y + _betas_y(rs))
    betas_cov = betas_invN * sigma**2
        
    # Sample rhos and rewards using the specified policy
    rhos = _sample_rhos_softmax(betas_mean, betas_cov, data_x, data_a, jax.random.split(keys[0], T)) #policy alpha
    rs = _sample_rs(rhox, rhos, beta0_y, betas_invN, keys[1])
    
    return (rhos, rs, rhox, beta0_y, betas_invN), (rhos, rs)
    

def _sample_ts(arg0, arg1):
    (rhos, rs, rhox, beta0_y, betas_invN), key = arg0, arg1 # , alpha
    keys = jax.random.split(key, 2)
        
    # Calculate beta mean and covariance
    betas_mean = np.einsum('ijk,ik->ij', betas_invN, beta0_y + _betas_y(rs))
    betas_cov = betas_invN * sigma**2
        
    # Sample rhos and rewards using the specified policy
    rhos = _sample_rhos_ts(betas_mean, betas_cov, data_x, data_a, jax.random.split(keys[0], T)) #policy alpha
    rs = _sample_rs(rhox, rhos, beta0_y, betas_invN, keys[1])
        
    return (rhos, rs, rhox, beta0_y, betas_invN), (rhos, rs)
    

def _sample_igw(arg0, arg1):
    (rhos, rs, rhox, beta0_y, betas_invN), key = arg0, arg1 # , alpha
    keys = jax.random.split(key, 2)
        
    # Calculate beta mean and covariance
    betas_mean = np.einsum('ijk,ik->ij', betas_invN, beta0_y + _betas_y(rs))
    betas_cov = betas_invN * sigma**2
        
    # Sample rhos and rewards using the specified policy
    rhos = _sample_rhos_igw(betas_mean, betas_cov, data_x, data_a, jax.random.split(keys[0], T)) #policy alpha
    rs = _sample_rs(rhox, rhos, beta0_y, betas_invN, keys[1])
        
    return (rhos, rs, rhox, beta0_y, betas_invN), (rhos, rs)


In [6]:
@jax.jit
def sample_optimistic(rhox, beta0_y, beta0_N, key): #policy , alpha=20
    '''
    Orchestration of sampling that leads to sequences of rhos and rewards
    '''
    keys = jax.random.split(key, 2) # pseudorandom seed keys: 2
    
    betas_invN = np.linalg.inv(beta0_N + _betas_N) 
    
    rs = _sample_rs_init(rhox, keys[0]) #
    
    initial =(np.zeros((T,K)), rs, rhox, beta0_y, betas_invN) #, alpha
    _, (_RHOS, _RS) = jax.lax.scan(_sample_optimistic, initial, jax.random.split(keys[1], hyper['sample_stop'])) 
    
                
    RHOS = _RHOS[hyper['sample_start']::hyper['sample_step']]
    RS = _RS[hyper['sample_start']::hyper['sample_step']]
    
    return RHOS, RS

@jax.jit
def sample_softmax(rhox, beta0_y, beta0_N, key): #policy , alpha=20
    '''
    Orchestration of sampling that leads to sequences of rhos and rewards
    '''
    keys = jax.random.split(key, 2) # pseudorandom seed keys: 2
    
    betas_invN = np.linalg.inv(beta0_N + _betas_N) 
    
    rs = _sample_rs_init(rhox, keys[0]) #
    
    initial =(np.zeros((T,K)), rs, rhox, beta0_y, betas_invN) #, alpha
    _, (_RHOS, _RS) = jax.lax.scan(_sample_softmax, initial, jax.random.split(keys[1], hyper['sample_stop'])) 
    
                
    RHOS = _RHOS[hyper['sample_start']::hyper['sample_step']]
    RS = _RS[hyper['sample_start']::hyper['sample_step']]
    
    return RHOS, RS

@jax.jit
def sample_ts(rhox, beta0_y, beta0_N, key): #policy , alpha=20
    '''
    Orchestration of sampling that leads to sequences of rhos and rewards
    '''
    keys = jax.random.split(key, 2) # pseudorandom seed keys: 2
    
    betas_invN = np.linalg.inv(beta0_N + _betas_N) 
    
    rs = _sample_rs_init(rhox, keys[0]) #
    
    initial =(np.zeros((T,K)), rs, rhox, beta0_y, betas_invN) #, alpha
    _, (_RHOS, _RS) = jax.lax.scan(_sample_ts, initial, jax.random.split(keys[1], hyper['sample_stop'])) 
    
                
    RHOS = _RHOS[hyper['sample_start']::hyper['sample_step']]
    RS = _RS[hyper['sample_start']::hyper['sample_step']]
    
    return RHOS, RS

@jax.jit    
def sample_igw(rhox, beta0_y, beta0_N, key): #policy , alpha=20
    '''
    Orchestration of sampling that leads to sequences of rhos and rewards
    '''
    keys = jax.random.split(key, 2) # pseudorandom seed keys: 2
    
    betas_invN = np.linalg.inv(beta0_N + _betas_N) 
    
    rs = _sample_rs_init(rhox, keys[0]) #
    
    initial =(np.zeros((T,K)), rs, rhox, beta0_y, betas_invN) #, alpha
    _, (_RHOS, _RS) = jax.lax.scan(_sample_igw, initial, jax.random.split(keys[1], hyper['sample_stop'])) 
    
                
    RHOS = _RHOS[hyper['sample_start']::hyper['sample_step']]
    RS = _RS[hyper['sample_start']::hyper['sample_step']]
    
    return RHOS, RS

def compute_rhox(RS):
    '''
    the mean of the product of rewards and contexts over all sampled paths
    '''
    # updated estimate of the expected rewards conditioned on the actions taken
    _beta_y = _BETAS_Y(RS)[:,-1,:].mean(axis=0)
    # the last cumulative sum of outer products of contexts
    _beta_N = _betas_N[-1]
    # solving a linear system 
    # _beta_N coefficient matrix 
    # _beta_y vector term 
    rhox = np.einsum('ij,j->i', np.linalg.inv(_beta_N), _beta_y)
    return rhox

def _likelihood0(rho, beta_mean, beta_invcov):
    '''
    Q + logP: calculating log likelihood  | beta_mean beta_invcov
    '''
    # negative 
    res = -np.einsum('i,ij,j->', rho-beta_mean, beta_invcov, rho-beta_mean)
    res = res + np.log(np.linalg.det(beta_invcov))
    return res

_likelihood0 = jax.vmap(_likelihood0)

def _likelihood1(params, rhos, rs):
    '''
    Updating posterior beliefs, mean covariance then _likelihood0
    '''
    beta0_y, beta0_N = decode(params) # initialize parameters
    betas_y = beta0_y + _betas_y(rs)
    betas_N = beta0_N + _betas_N
    betas_invN = np.linalg.inv(betas_N)
    betas_mean = np.einsum('ijk,ik->ij', betas_invN, betas_y)
    betas_invcov = betas_N / sigma**2
    
    return _likelihood0(rhos, betas_mean, betas_invcov).sum()

_likelihood1 = jax.vmap(_likelihood1, in_axes=(None,0,0))

def likelihood(params, RHOS, RS):
    '''
    Q mean
    '''
    return _likelihood1(params, RHOS, RS).mean()

grad_likelihood = jax.grad(likelihood)
grad_likelihood = jax.jit(grad_likelihood)

# PC-BICB RUN

In [7]:
output_dir = "C:/Users/huber/Dropbox/ADULTERY/PHD/Cambridge_task/Datasets"
input_dir = "C:/Users/huber/Dropbox/ADULTERY/PHD/Cambridge_task/Results"

RUNS=10

algos = ["ts", "optimistic", "softmax", "igw", "ucb", "greedy"]
policies = ["ts","igw"] #"softmax",  "optimistic",
policy_map = {"optimistic": 0, "softmax": 1, "ts": 2,  "igw": 3}

In [9]:
key = jax.random.PRNGKey(0)
alpha = 20 # exploration parameter in the policy definition
sigma = .10 # variance of the rewards

for i in tqdm.tqdm(range(1,10)):  # Assuming 100 datasets
    for algo in algos:
        for policy in policies:
            policy_index = policy_map[policy]
            filename = f'{output_dir}/dataset_{i}_{algo}.dill'
            with open(filename, 'rb') as f:
                data = dill.load(f)
            
                data_x = np.array(data['x'])  
                data_a = np.array(data['a'])  
                # rhox = data['rhox']           
                # betas_mean = data['betas_mean']
                # betas_cov=data['betas_cov']

                T,A,K = data_x.shape    
            
                # A dictionary of hyperparameters of the simulation
                hyper = dict()
                hyper['sample_start'] = 1000
                hyper['sample_stop'] = 2000
                hyper['sample_step'] = 1
                hyper['iter'] = 20

                _xs = lambda t0, t1: jax.lax.select(t1 <= t0, data_x[t1,data_a[t1]], np.zeros(K)) 
                _xs = jax.vmap(jax.vmap(_xs, in_axes=(None,0)), in_axes=(0,None)) 
                xs = _xs(np.arange(T-1), np.arange(T-1)) 
                
                __betas_N = lambda t: np.einsum('i,j->ij', data_x[t,data_a[t]], data_x[t,data_a[t]])
                __betas_N = jax.vmap(__betas_N)
                _betas_N = __betas_N(np.arange(T-1)).cumsum(axis=0)
                _betas_N = np.concatenate((np.zeros((K,K))[None,...], _betas_N))

                __betas_y = lambda r, t: r * data_x[t,data_a[t]]
                __betas_y = jax.vmap(__betas_y)
                _betas_y = lambda rs: np.concatenate((np.zeros(K)[None,...], __betas_y(rs, np.arange(T-1)).cumsum(axis=0)))
                _betas_y = jax.jit(_betas_y)
                _BETAS_Y = jax.jit(jax.vmap(_betas_y))
            
                rhox = -np.ones(K)/K
                params = {'beta0': 0.}
                grad_mnsq = {'beta0': 0.}
                beta0_y, beta0_N = decode(params)

                for j in range(hyper['iter']): #tqdm.tqdm(
    
                    key, subkey = jax.random.split(key)
                
                    if policy_index == 0:
                        RHOS, RS = sample_optimistic(rhox, beta0_y, beta0_N, subkey)
                    elif policy_index == 1:
                        RHOS, RS = sample_softmax(rhox, beta0_y, beta0_N, subkey)
                    elif policy_index == 2:
                        RHOS, RS = sample_ts(rhox, beta0_y, beta0_N, subkey)
                    elif policy_index == 3:
                        RHOS, _RS = sample_igw(rhox, beta0_y, beta0_N, subkey)
                    
                    rhox = compute_rhox(RS)

                    grad = grad_likelihood(params, RHOS, RS)
                    grad_mnsq['beta0'] = .1 * grad['beta0']**2 + .9 * grad_mnsq['beta0']
                    params['beta0'] += .001 * grad['beta0'] / (np.sqrt(grad_mnsq['beta0']) + 1e-8)
                    beta0_y, beta0_N = decode(params)

                # print(rhox, beta0_N[0,0])
    
                rhox = rhox / np.abs(rhox).sum()

                res = dict()
                res['rhox'] = rhox
                res['beta0_y'] = beta0_y
                res['beta0_N'] = beta0_N

                key, subkey = jax.random.split(key)
                if policy_index == 0:
                    _, RS = sample_optimistic(rhox, beta0_y, beta0_N, subkey)
                elif policy_index == 1:
                    _, RS = sample_softmax(rhox, beta0_y, beta0_N, subkey)
                elif policy_index == 2:
                    _, RS = sample_ts(rhox, beta0_y, beta0_N, subkey)
                elif policy_index == 3:
                    _, _RS = sample_igw(rhox, beta0_y, beta0_N, subkey)

                BETAS_Y = beta0_y + _BETAS_Y(RS)
                betas_invN = np.linalg.inv(beta0_N + _betas_N)
                betas_mean = np.einsum('ijk,lik->lij', betas_invN, BETAS_Y).mean(axis=0)
                betas_cov = betas_invN * sigma**2

                res['betas_mean'] = betas_mean
                res['betas_cov'] = betas_invN

                # Save the results
                filename = f'{input_dir}/dataset_{i}_{algo}_{policy}_PCICB.dill'
                with open(filename, 'wb') as f:
                    dill.dump(res, f)

100%|████████████████████████████████████████████████████████████████████████████████| 9/9 [2:56:26<00:00, 1176.24s/it]


# PC-NBICB UTILITIES

In [10]:
import dill
import jax
import jax.numpy as np
import numpy as np1
from jax.scipy.special import logsumexp
import tqdm

In [11]:
#jax.config.update('jax_platform_name', 'cpu')
# parser = argparse.ArgumentParser()
# parser.add_argument('-i', '--input', required=True)
# parser.add_argument('-o', '--output', default='res/general.obj')
# args = parser.parse_args()

hyper = dict()
hyper['sample_start'] = 10_000
hyper['sample_stop'] = 20_000
hyper['sample_step'] = 10
hyper['variance_rho'] = 5e-4
hyper['variance_beta'] = 5e-5
hyper['offset'] = -np.ones(K)/K

In [12]:
_cov_T = lambda t0, t1: np.minimum(t0, t1) + 1
_cov_T = jax.vmap(jax.vmap(_cov_T, in_axes=(None,0)), in_axes=(0,None))
cov_T = _cov_T(np.arange(T), np.arange(T))

_cov_K = np.eye(K)
_cov_K = _cov_K.at[:,0].set(np.ones(K))
_cov_K = _cov_K.at[:,0].set(_cov_K[:,0] / np.sum(_cov_K[:,0]**2)**.5)
for i in range(1,K):
    for j in range(i):
        _cov_K = _cov_K.at[:,i].add(-np.sum(_cov_K[:,i] * _cov_K[:,j]) * _cov_K[:,j])
    _cov_K = _cov_K.at[:,i].set(_cov_K[:,i] / np.sum(_cov_K[:,i]**2)**.5)
_scale = np.eye(K)
_scale = _scale.at[0,0].set(.1)
_cov_K = _cov_K @ _scale @ np.linalg.inv(_cov_K)
cov_K = _cov_K @ _cov_K.T

cov_rho = cov_K * hyper['variance_rho']
cov_rhos = np.kron(np.eye(T), cov_K) * hyper['variance_rho']
cov_betas = np.kron(cov_T, cov_K) * hyper['variance_beta']
invcov_rhos = np.linalg.inv(cov_rhos)
invcov_betas = np.linalg.inv(cov_betas)
mean_betas = hyper['offset'][None,...].repeat(T, axis=0).reshape(-1)
cov = np.linalg.inv(invcov_rhos + invcov_betas)
cov_at_invcov_betas_at_mean_betas = cov @ invcov_betas @ mean_betas
cov_at_invcov_rhos = cov @ invcov_rhos

cov_rho_L = np1.linalg.cholesky(cov_rho)
cov_L = np1.linalg.cholesky(cov)


In [13]:
# Implement likelihood functions for each policy
def softmax_like(rho, x,  a): #beta_cov,
    alpha=20
    q = alpha * np.einsum('ij,j->i', x, rho)
    
    return q[a] - logsumexp(q) 
    
def igw_like(rho, x, a): #beta_cov,
    alpha=20
    erewards = np.einsum('ij,j->i', x, rho)  # prediction
    best_arm = np.argmax(erewards)
    gaps = erewards[best_arm] - erewards  # Gaps

    A = x.shape[0] 
    pi = 1 / (A + alpha * gaps)
    pi=pi.at[best_arm].set(0) 

    # Adjust the best arm
    pi_best = 1 - np.sum(pi)
    pi=pi.at[best_arm].set(pi_best)
    
    return np.log(pi[a])


def _sample_rhos_softmax(beta, x, a, key): #beta_cov,
    keys = jax.random.split(key, 3)
    rho = beta + cov_rho_L @ jax.random.normal(keys[0], shape=(K,))
    _rho = beta + cov_rho_L @ jax.random.normal(keys[1], shape=(K,))
    like = softmax_like(rho, x, a) #beta_cov,
    _like = softmax_like(_rho, x, a) #beta_cov,
    cond = _like - like > np.log(jax.random.uniform(keys[2]))
    return jax.lax.select(cond, _rho, rho)
_sample_rhos_softmax = jax.vmap(_sample_rhos_softmax)


def _sample_rhos_igw(beta, x, a, key): #beta_cov,
    keys = jax.random.split(key, 3)
    rho = beta + cov_rho_L @ jax.random.normal(keys[0], shape=(K,))
    _rho = beta + cov_rho_L @ jax.random.normal(keys[1], shape=(K,))
    like = igw_like(rho, x, a) #beta_cov,
    _like = igw_like(_rho, x,  a) #beta_cov,
    cond = _like - like > np.log(jax.random.uniform(keys[2]))
    return jax.lax.select(cond, _rho, rho)
_sample_rhos_igw = jax.vmap(_sample_rhos_igw)


def _sample_betas(rhos, key):
    mean = cov_at_invcov_betas_at_mean_betas + cov_at_invcov_rhos @ rhos.reshape(-1)
    _betas = mean + cov_L @ jax.random.normal(key, shape=(T*K,))
    return _betas.reshape(-1,K)

def _sample_softmax(arg0, arg1):
    (rhos, betas), key = arg0, arg1
    keys = jax.random.split(key, 2)
    rhos = _sample_rhos_softmax(betas, data_x, data_a, jax.random.split(keys[0], T))
    betas = _sample_betas(rhos, keys[1])
    return (rhos, betas), (rhos, betas)

def _sample_igw(arg0, arg1):
    (rhos, betas), key = arg0, arg1
    keys = jax.random.split(key, 2)
    rhos = _sample_rhos_igw(betas, data_x, data_a, jax.random.split(keys[0], T))
    betas = _sample_betas(rhos, keys[1])
    return (rhos, betas), (rhos, betas)

def sample_softmax(rhos, betas, key, count):
    (rhos, betas), (RHOS, BETAS) = jax.lax.scan(_sample_softmax, (rhos, betas), jax.random.split(key, count))
    return rhos, betas, RHOS, BETAS
sample_softmax = jax.jit(sample_softmax, static_argnums=3)

def sample_igw(rhos, betas, key, count):
    (rhos, betas), (RHOS, BETAS) = jax.lax.scan(_sample_igw, (rhos, betas), jax.random.split(key, count))
    return rhos, betas, RHOS, BETAS
sample_igw = jax.jit(sample_igw, static_argnums=3)


# PC-NBICB RUN

In [16]:
#algos = ["softmax", "igw", "optimistic", "ucb", "ts"]
algos = ["greedy"]
policies = ["igw"] #"softmax"
policy_map = {"optimistic": 0, "softmax": 1, "ts": 2,  "igw": 3}

In [17]:
key = jax.random.PRNGKey(0)
alpha = 20 # exploration parameter in the policy definition
sigma = .10 # variance of the rewards

for i in tqdm.tqdm(range(RUNS)):  # Assuming 100 datasets
    for algo in algos:
        for policy in policies:
            policy_index = policy_map[policy]
            print(algo, policy_index)
            filename = f'{output_dir}/dataset_{i}_{algo}.dill'
            with open(filename, 'rb') as f:
                data = dill.load(f)
            
                data_x = np.array(data['x'])  
                data_a = np.array(data['a'])  
                rhox = data['rhox']           
                betas_mean = data['betas_mean']
                betas_cov=data['betas_cov']

                T,A,K = data_x.shape    
                
                _cov_T = lambda t0, t1: np.minimum(t0, t1) + 1
                _cov_T = jax.vmap(jax.vmap(_cov_T, in_axes=(None,0)), in_axes=(0,None))
                cov_T = _cov_T(np.arange(T), np.arange(T))

                _cov_K = np.eye(K)
                _cov_K = _cov_K.at[:,0].set(np.ones(K))
                _cov_K = _cov_K.at[:,0].set(_cov_K[:,0] / np.sum(_cov_K[:,0]**2)**.5)
                for s in range(1,K):
                    for k in range(s):
                        _cov_K = _cov_K.at[:,s].add(-np.sum(_cov_K[:,s] * _cov_K[:,k]) * _cov_K[:,k])
                        _cov_K = _cov_K.at[:,s].set(_cov_K[:,s] / np.sum(_cov_K[:,s]**2)**.5)
                _scale = np.eye(K)
                _scale = _scale.at[0,0].set(.1)
                _cov_K = _cov_K @ _scale @ np.linalg.inv(_cov_K)
                cov_K = _cov_K @ _cov_K.T

                cov_rho = cov_K * hyper['variance_rho']
                cov_rhos = np.kron(np.eye(T), cov_K) * hyper['variance_rho']
                cov_betas = np.kron(cov_T, cov_K) * hyper['variance_beta']
                invcov_rhos = np.linalg.inv(cov_rhos)
                invcov_betas = np.linalg.inv(cov_betas)
                mean_betas = hyper['offset'][None,...].repeat(T, axis=0).reshape(-1)
                cov = np.linalg.inv(invcov_rhos + invcov_betas)
                cov_at_invcov_betas_at_mean_betas = cov @ invcov_betas @ mean_betas
                cov_at_invcov_rhos = cov @ invcov_rhos

                cov_rho_L = np1.linalg.cholesky(cov_rho)
                cov_L = np1.linalg.cholesky(cov)
                
                rhos = np.zeros((T,K))
                betas = np.zeros((T,K)) + hyper['offset']

                BETAS = np.zeros((0,T,K))
                for z in range(hyper['sample_stop'] // 200): #tqdm.tqdm( , unit_scale=200)
                    key, subkey = jax.random.split(key)
                    if policy_index == 1:
                        rhos, betas, _RHOS, _BETAS  = sample_softmax(rhos, betas, subkey, 200)
                    elif policy_index == 3:
                        rhos, betas, _RHOS, _BETAS  = sample_igw(rhos, betas, subkey, 200)                
                    
#                rhos, betas, _RHOS, _BETAS = sample(rhos, betas, subkey, 200)
                    BETAS = np.concatenate((BETAS, _BETAS))

                betas = BETAS[hyper['sample_start']::hyper['sample_step']].mean(axis=0)
                betas = betas / np.abs(betas).sum(axis=-1, keepdims=True)

                res = dict()
                res['betas'] = betas
                
                print(betas)
                
                output = f'{input_dir}/dataset_{i}_{algo}_{policy}_PCNBICB.dill'
                with open(output, 'wb') as f:
                    dill.dump(res, f)            

  0%|                                                                                           | 0/10 [00:00<?, ?it/s]

greedy 3


 10%|████████▎                                                                          | 1/10 [00:47<07:11, 47.97s/it]

[[-0.23941611 -0.26060665 -0.25816903 -0.24180819]
 [-0.22404097 -0.26649195 -0.2616108  -0.2478563 ]
 [-0.20828402 -0.271469   -0.26413536 -0.25611162]
 ...
 [ 0.25112474 -0.33944878 -0.08910514 -0.32032132]
 [ 0.25095668 -0.33958814 -0.0888503  -0.3206049 ]
 [ 0.250948   -0.3395229  -0.08881148 -0.32071754]]
greedy 3


 20%|████████████████▌                                                                  | 2/10 [01:38<06:35, 49.45s/it]

[[-0.24130087 -0.2607189  -0.25996354 -0.23801665]
 [-0.22542223 -0.26359338 -0.2623589  -0.24862546]
 [-0.20995812 -0.26745006 -0.26532215 -0.2572697 ]
 ...
 [ 0.26957262 -0.2008628  -0.15042719 -0.3791373 ]
 [ 0.2695559  -0.20093472 -0.15043432 -0.3790751 ]
 [ 0.2695856  -0.2011073  -0.15022251 -0.37908465]]
greedy 3


 30%|████████████████████████▉                                                          | 3/10 [02:29<05:52, 50.34s/it]

[[-0.24069317 -0.2620011  -0.258943   -0.23836279]
 [-0.22524428 -0.26767987 -0.26144502 -0.24563082]
 [-0.2094872  -0.27210027 -0.26272002 -0.25569248]
 ...
 [ 0.2859512  -0.32933033 -0.10572312 -0.27899536]
 [ 0.28594074 -0.32934213 -0.10565412 -0.27906302]
 [ 0.28590712 -0.32928005 -0.10572741 -0.27908543]]
greedy 3


 40%|█████████████████████████████████▏                                                 | 4/10 [03:19<05:01, 50.20s/it]

[[-0.23976469 -0.26068446 -0.2572771  -0.24227378]
 [-0.225028   -0.26710796 -0.2606368  -0.24722724]
 [-0.20868438 -0.27131972 -0.2621783  -0.25781757]
 ...
 [ 0.24602117 -0.46581617 -0.05001359 -0.238149  ]
 [ 0.24600615 -0.4658581  -0.04999042 -0.23814537]
 [ 0.24598584 -0.46613285 -0.04995407 -0.23792727]]
greedy 3


 50%|█████████████████████████████████████████▌                                         | 5/10 [04:12<04:15, 51.13s/it]

[[-0.24011941 -0.2602844  -0.25829187 -0.24130434]
 [-0.2245218  -0.26532164 -0.26132318 -0.2488334 ]
 [-0.20878279 -0.26917517 -0.26389876 -0.25814325]
 ...
 [ 0.25039747 -0.29715514 -0.13681884 -0.31562862]
 [ 0.25030965 -0.297239   -0.1368036  -0.31564772]
 [ 0.25034314 -0.2973754  -0.1367659  -0.31551552]]
greedy 3


 60%|█████████████████████████████████████████████████▊                                 | 6/10 [05:02<03:22, 50.61s/it]

[[-0.24047603 -0.26128483 -0.25848058 -0.2397586 ]
 [-0.2251692  -0.26710272 -0.2612894  -0.24643873]
 [-0.2089535  -0.27115464 -0.26306906 -0.25682276]
 ...
 [ 0.263715   -0.33751833 -0.11441308 -0.2843536 ]
 [ 0.2636678  -0.3376225  -0.1143425  -0.2843672 ]
 [ 0.26369575 -0.3376672  -0.11426371 -0.28437337]]
greedy 3


 70%|██████████████████████████████████████████████████████████                         | 7/10 [05:49<02:28, 49.62s/it]

[[-0.23995888 -0.26132095 -0.25734863 -0.24137154]
 [-0.22493526 -0.26824108 -0.26049522 -0.2463284 ]
 [-0.20855261 -0.27291042 -0.26143408 -0.25710294]
 ...
 [ 0.25350836 -0.43805707 -0.0792273  -0.22920725]
 [ 0.25349596 -0.43800965 -0.0793397  -0.22915465]
 [ 0.25349778 -0.43809935 -0.07933308 -0.22906977]]
greedy 3


 80%|██████████████████████████████████████████████████████████████████▍                | 8/10 [06:40<01:39, 49.82s/it]

[[-0.23959437 -0.2601576  -0.25650966 -0.2437384 ]
 [-0.22489874 -0.26640117 -0.26001135 -0.24868874]
 [-0.20871043 -0.2707454  -0.2617703  -0.2587739 ]
 ...
 [ 0.22272862 -0.47086045 -0.06089786 -0.24551305]
 [ 0.22268246 -0.4708727  -0.06071983 -0.24572502]
 [ 0.22266011 -0.47109696 -0.06062872 -0.24561416]]
greedy 3


 90%|██████████████████████████████████████████████████████████████████████████▋        | 9/10 [07:29<00:49, 49.75s/it]

[[-0.23991638 -0.26023716 -0.2576037  -0.2422428 ]
 [-0.2248919  -0.2661596  -0.26105693 -0.24789162]
 [-0.2093277  -0.2702507  -0.26303324 -0.25738838]
 ...
 [ 0.244986   -0.35565627 -0.11614183 -0.28321594]
 [ 0.24497396 -0.35595146 -0.11604501 -0.28302953]
 [ 0.24497415 -0.35597414 -0.11601222 -0.28303945]]
greedy 3


100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [08:18<00:00, 49.85s/it]

[[-0.23942623 -0.26006606 -0.25710058 -0.24340713]
 [-0.22486173 -0.26658848 -0.26052797 -0.24802186]
 [-0.20865345 -0.27107632 -0.26215595 -0.25811428]
 ...
 [ 0.23945545 -0.42183226 -0.10007455 -0.23863778]
 [ 0.23940662 -0.4217413  -0.10014167 -0.23871039]
 [ 0.23933457 -0.4219046  -0.09998736 -0.23877355]]





{'betas': Array([[-0.23648252, -0.2549834 , -0.25364617, -0.25488794],
        [-0.22289003, -0.2602487 , -0.25705463, -0.2598066 ],
        [-0.20957328, -0.2651861 , -0.2606942 , -0.26454642],
        [-0.19630934, -0.2698891 , -0.26354352, -0.2702581 ],
        [-0.1838874 , -0.27484503, -0.2673433 , -0.27392417],
        [-0.17134437, -0.27950135, -0.27049285, -0.27866143],
        [-0.15903208, -0.28454483, -0.27336025, -0.28306282],
        [-0.1464307 , -0.28899592, -0.27699545, -0.28757796],
        [-0.13425006, -0.29396236, -0.27991995, -0.29186764],
        [-0.12192728, -0.29852554, -0.28316343, -0.2963838 ],
        [-0.1110523 , -0.3029864 , -0.28615808, -0.2998032 ],
        [-0.09864198, -0.30718237, -0.289412  , -0.30476367],
        [-0.08795324, -0.3114569 , -0.2922193 , -0.30837053],
        [-0.07632723, -0.31552696, -0.29481173, -0.31333405],
        [-0.06705999, -0.31896666, -0.29736364, -0.3166097 ],
        [-0.05656556, -0.3230236 , -0.30024797, -0.32016286],