In [None]:
import argparse
import dill
import jax
import jax.numpy as np
from jax.scipy.special import logsumexp
import tqdm

In [None]:
RUNS = 10
# Directory holding the generated datasets, relative to the notebook so it runs
# on any machine (the run cells below read datasets from "../Datasets" too).
input_dir = "../Datasets"
algos = ["igw", "optimistic", "softmax", "greedy", "ucb", "ts"]

# Load one representative dataset; the BICB utility cell below builds its
# helpers from these arrays. The run cells reload data per dataset.
filename = f'{input_dir}/dataset_1_ucb.dill'
# NOTE: dill, like pickle, can execute arbitrary code while loading;
# only open dataset files from a trusted source.
with open(filename, 'rb') as f:
    data = dill.load(f)

data_x = np.array(data['x'])  # indexed as data_x[t, action] -> context vector
data_a = np.array(data['a'])  # chosen action index at each time step
T, A, K = data_x.shape        # time steps, actions, context dimension

# BICB UTILITIES

In [None]:
# --- Design-matrix helpers ---------------------------------------------------
# xs[t0, t1] is the context of the action chosen at step t1 when t1 <= t0 and a
# zero vector otherwise. Both indices run over 0..T-2, so xs has shape
# (T-1, T-1, K) and xs[t0] stacks the chosen contexts seen up to step t0.
def _xs(t0, t1):
    return jax.lax.select(t1 <= t0, data_x[t1, data_a[t1]], np.zeros(K))

_xs = jax.vmap(jax.vmap(_xs, in_axes=(None, 0)), in_axes=(0, None))
xs = _xs(np.arange(T-1), np.arange(T-1))

# --- Belief-update utilities -------------------------------------------------
# _betas_N[t] = sum_{s < t} x_s x_s^T for the chosen contexts x_s, with
# _betas_N[0] = 0; shape (T, K, K). Used for the precision/covariance updates.
def __betas_N(t):
    chosen = data_x[t, data_a[t]]
    return np.einsum('i,j->ij', chosen, chosen)

__betas_N = jax.vmap(__betas_N)
_betas_N = np.concatenate((np.zeros((K, K))[None, ...], __betas_N(np.arange(T-1)).cumsum(axis=0)))


# _betas_y(rs)[t] = sum_{s < t} r_s x_s, with a leading zero row; shape (T, K).
# Used for the posterior-mean updates. _BETAS_Y maps it over a batch of rs.
def __betas_y(r, t):
    return r * data_x[t, data_a[t]]

__betas_y = jax.vmap(__betas_y)

def _betas_y(rs):
    running = __betas_y(rs, np.arange(T-1)).cumsum(axis=0)
    return np.concatenate((np.zeros(K)[None, ...], running))

_betas_y = jax.jit(_betas_y)
_BETAS_Y = jax.jit(jax.vmap(_betas_y))

@jax.jit
def decode(params):
    '''
    Map the unconstrained parameter params['beta0'] to the prior terms.

    Returns (beta0_y, beta0_N) with beta0 = exp(20 * params['beta0']):
      beta0_y = beta0 * (-ones(K) / K)   (vector)
      beta0_N = beta0 * I_K              (matrix)
    so beta0_N^{-1} @ beta0_y = -ones(K)/K for every beta0; only the prior
    strength changes. K is read from the global at trace time.
    '''
    # exp keeps beta0 > 0. The factor 20 rescales the parameter so each
    # optimiser step on params['beta0'] moves log(beta0) 20x as far;
    # params['beta0'] = 0 gives beta0 = 1.
    beta0 = np.exp(20 * params['beta0'])
    beta0_N = beta0 * np.eye(K)
    beta0_y = beta0 * (-np.ones(K) / K)
    return beta0_y, beta0_N

def _sample_rhos_like(rho, x, a):
    '''
    Log-probability of choosing action `a` under a softmax policy.

    The action values are x @ rho (one row of x per action), scaled by the
    global inverse temperature `alpha`. Returns logit[a] - logsumexp(logits).
    '''
    logits = alpha * np.einsum('ij,j->i', x, rho)
    log_normaliser = logsumexp(logits)
    return logits[a] - log_normaliser

def _sample_rhos(beta_mean, beta_cov, x, a, key):
    '''
    Approximate draw of one rho_t given the action a_t taken in context x.

    Two candidates, rho and _rho, are drawn independently from
    N(beta_mean, beta_cov). _rho replaces rho with probability
    min(1, exp(like(_rho) - like(rho))), where like is the softmax
    log-likelihood of action `a` computed by `_sample_rhos_like`.

    NOTE(review): this is a single Metropolis-Hastings step whose "current"
    state is a fresh Gaussian draw, not the previous chain value (the rhos
    carried through `_sample` are never passed in). The result therefore
    approximates the exact conditional; confirm that this is intended.

    After the `jax.vmap` below, every argument gains a leading time axis.
    '''
    keys = jax.random.split(key, 3) # independent keys: candidate 1, candidate 2, acceptance uniform
    
    rho  = jax.random.multivariate_normal(keys[0], beta_mean, beta_cov) # candidate rho' (acts as current state)
    _rho = jax.random.multivariate_normal(keys[1], beta_mean, beta_cov) # candidate rho'' (proposal)
    
    like  = _sample_rhos_like(rho, x, a) # log-likelihood of action a under rho'
    _like = _sample_rhos_like(_rho, x, a) # log-likelihood of action a under rho''
    
    cond  = _like - like > np.log(jax.random.uniform(keys[2])) # accept iff log u < log-likelihood ratio
    return jax.lax.select(cond, _rho, rho) # _rho if accepted, rho otherwise

_sample_rhos = jax.vmap(_sample_rhos)

def _sample_rs_init(rhox, key):
    '''
    Initial reward vector for the Gibbs chain.

    xs[-1] stacks the chosen-action contexts of steps 0..T-2, so the mean is
    the expected reward of every chosen action under rhox. Gaussian noise with
    standard deviation `sigma` is then added. Returns shape (T-1,).
    '''
    noise = sigma * jax.random.normal(key, shape=(T-1,))
    return np.einsum('ij,j->i', xs[-1], rhox) + noise

def _sample_rs(rhox, rhos, beta0_y, betas_invN, key):
    '''
    Lemma 1 sampling
    '''
    invcov = np.eye(T-1) # initialize for inverse covariance 1/σ2 I
        # outer product of xs, weighted by inverse covariance of believes?
        # sum over all time steps 'a' 
        # cross-product 'bc' and 'de' indices
        # inverse covariance weights `cd' indices
        # \sum^T_t=2 X^T_{t−1} Σ_t X_{t−1}
    invcov = invcov + np.einsum('abc,acd,aed->be', xs, betas_invN[1:], xs) 
        # dot product of the last context vectors xs[-1] with the parameters rhox
        # (X^T_t \rhox)
    invcov_at_mean = np.einsum('ij,j->i', xs[-1], rhox)
        # 1) product of each context vector with its corresponding parameter estimate adjustment
        # 2) weighted sum of these adjustments across all contexts xs and adjusted parameters rhos[1:] - ...
        # \sum^T_t=2 X^T_t-1(\rho_t - C_t C^{-1}_1 \mu_1)
    invcov_at_mean = invcov_at_mean + np.einsum('ijk,ik->j', xs, rhos[1:] - np.einsum('ijk,k->ij', betas_invN[1:], beta0_y))
        # \tilde{C}
    cov = np.linalg.inv(invcov)
        # \tilde{\mu}
    mean = cov @ invcov_at_mean
    # sample normal
    rs = jax.random.multivariate_normal(key, mean, cov * sigma**2) # sampling the reward from a multivariate normal
    return rs

def _sample(arg0, arg1):
    '''
    One Gibbs sweep, used as the body of `jax.lax.scan` in `sample`.

    carry = (rhos, rs, rhox, beta0_y, betas_invN). The steps are:
      1. recompute the per-step belief means from the current rewards rs;
      2. redraw every rho_t (`_sample_rhos`);
      3. redraw rs given the new rhos (`_sample_rs`).
    The incoming rhos are not read: step 2 redraws them from scratch.
    Returns (new carry, (rhos, rs)), so scan stacks every sweep.
    '''
    state, key = arg0, arg1
    _, rs, rhox, beta0_y, betas_invN = state
    key_rhos, key_rs = jax.random.split(key, 2)

    # Posterior belief at every step: mean = invN_t (beta0_y + sum_{s<t} r_s x_s),
    # covariance = invN_t * sigma**2.
    posterior_y = beta0_y + _betas_y(rs)
    betas_mean = np.einsum('ijk,ik->ij', betas_invN, posterior_y)
    betas_cov = betas_invN * sigma**2

    new_rhos = _sample_rhos(betas_mean, betas_cov, data_x, data_a, jax.random.split(key_rhos, T))
    new_rs = _sample_rs(rhox, new_rhos, beta0_y, betas_invN, key_rs)

    carry = (new_rhos, new_rs, rhox, beta0_y, betas_invN)
    return carry, (new_rhos, new_rs)

@jax.jit
def sample(rhox, beta0_y, beta0_N, key):
    '''
    Run the BICB Gibbs chain and return its kept samples (RHOS, RS).

    betas_invN[t] = (beta0_N + _betas_N[t])^{-1}. The chain runs for
    hyper['sample_stop'] sweeps. Sweeps before hyper['sample_start'] are
    discarded as burn-in, and the rest are thinned by hyper['sample_step'].
    `hyper` and the data globals are read only when this function is traced.
    '''
    key_init, key_chain = jax.random.split(key, 2)

    betas_invN = np.linalg.inv(beta0_N + _betas_N)
    rs0 = _sample_rs_init(rhox, key_init)

    init_state = (np.zeros((T, K)), rs0, rhox, beta0_y, betas_invN)
    sweep_keys = jax.random.split(key_chain, hyper['sample_stop'])
    _, (rhos_chain, rs_chain) = jax.lax.scan(_sample, init_state, sweep_keys)

    keep = slice(hyper['sample_start'], None, hyper['sample_step'])
    return rhos_chain[keep], rs_chain[keep]

def compute_rhox(RS):
    '''
    Least-squares estimate of the reward parameter rhox from sampled rewards.

    Solves   _betas_N[-1] @ rhox = mean_s _BETAS_Y(RS)[s, -1],
    i.e. (sum_t x_t x_t^T) rhox = E_samples[ sum_t r_t x_t ], where x_t is the
    context of the chosen action at step t (t = 0..T-2).

    Parameters
    ----------
    RS : sampled reward trajectories as returned by `sample`, shape (S, T-1).

    Returns
    -------
    rhox : (K,) array.
    '''
    # sum_t r_t x_t (last cumulative entry), averaged over the sampled paths
    mean_beta_y = _BETAS_Y(RS)[:, -1, :].mean(axis=0)
    # Gram matrix of all chosen contexts
    gram = _betas_N[-1]
    # Solve the normal equations directly instead of forming the inverse:
    # same result in exact arithmetic, better conditioned in floating point.
    return np.linalg.solve(gram, mean_beta_y)

def _likelihood0(rho, beta_mean, beta_invcov):
    '''
    Q + logP: calculating log likelihood  | beta_mean beta_invcov
    '''
    # negative 
    res = -np.einsum('i,ij,j->', rho-beta_mean, beta_invcov, rho-beta_mean)
    res = res + np.log(np.linalg.det(beta_invcov))
    return res

_likelihood0 = jax.vmap(_likelihood0)

def _likelihood1(params, rhos, rs):
    '''
    Summed `_likelihood0` score of one sampled rho trajectory under the
    Gaussian beliefs implied by `params` and one sampled reward vector `rs`.

    Belief at step t: mean = N_t^{-1} (beta0_y + sum_{s<t} r_s x_s) and
    precision = N_t / sigma**2, with N_t = beta0_N + _betas_N[t].
    The vmap below broadcasts params and maps over (rhos, rs) sample pairs.
    '''
    beta0_y, beta0_N = decode(params)
    betas_N = beta0_N + _betas_N
    betas_y = beta0_y + _betas_y(rs)
    betas_mean = np.einsum('ijk,ik->ij', np.linalg.inv(betas_N), betas_y)
    return _likelihood0(rhos, betas_mean, betas_N / sigma**2).sum()

_likelihood1 = jax.vmap(_likelihood1, in_axes=(None, 0, 0))

def likelihood(params, RHOS, RS):
    '''
    Monte-Carlo objective in `params`: the average of `_likelihood1` over the
    sampled pairs (RHOS[s], RS[s]). `grad_likelihood` is its jitted gradient.
    '''
    per_sample = _likelihood1(params, RHOS, RS)
    return np.mean(per_sample)

grad_likelihood = jax.jit(jax.grad(likelihood))

In [None]:
# NOTE: the directory names are inverted relative to their use.
# `output_dir` holds the generated datasets that the run cells READ;
# `input_dir` is where the fitted BICB/NBICB results are WRITTEN.
output_dir = "../Datasets"
input_dir = "../Results"
algos = ["igw", "optimistic", "softmax", "greedy", "ucb", "ts"]
RUNS = 10

# Hyperparameters of the simulation; the run cells below fill this in.
# TODO: adjust the hyperparameters and the loading per experiment.
hyper = {}

In [None]:
# Smoke test of the tqdm progress bar: prints the run indices 0..RUNS-1.
# It computes nothing used later; it only leaves `i` bound to RUNS-1.
for i in tqdm.tqdm(range(RUNS)):
    print(i)

# BICB RUN

In [None]:
key = jax.random.PRNGKey(0)
alpha = 20   # inverse temperature of the softmax action model (_sample_rhos_like)
sigma = .10  # standard deviation of the reward noise (enters the model as sigma**2)

# Fit BICB to every dataset/algorithm pair and save the results.
for i in tqdm.tqdm(range(1, 10)):  # dataset indices 1..9
    for algo in algos:
        print(i, algo)
        filename = f'{output_dir}/dataset_{i}_{algo}.dill'
        # NOTE: dill, like pickle, can execute arbitrary code while loading;
        # only open dataset files from a trusted source.
        with open(filename, 'rb') as f:
            data = dill.load(f)  # file is closed before the long computation below

        data_x = np.array(data['x'])
        data_a = np.array(data['a'])
        T, A, K = data_x.shape

        # Hyperparameters read by `sample` (at trace time).
        hyper = dict()
        hyper['sample_start'] = 1_000   # burn-in sweeps discarded
        hyper['sample_stop'] = 2_000    # total Gibbs sweeps per chain
        hyper['sample_step'] = 1        # thinning of the kept sweeps
        hyper['iter'] = 20              # outer iterations (rhox update + one beta0 step)

        # Rebuild the dataset-dependent globals read by the BICB utilities.
        _xs = lambda t0, t1: jax.lax.select(t1 <= t0, data_x[t1,data_a[t1]], np.zeros(K))
        _xs = jax.vmap(jax.vmap(_xs, in_axes=(None,0)), in_axes=(0,None))
        xs = _xs(np.arange(T-1), np.arange(T-1))

        __betas_N = lambda t: np.einsum('i,j->ij', data_x[t,data_a[t]], data_x[t,data_a[t]])
        __betas_N = jax.vmap(__betas_N)
        _betas_N = __betas_N(np.arange(T-1)).cumsum(axis=0)
        _betas_N = np.concatenate((np.zeros((K,K))[None,...], _betas_N))

        __betas_y = lambda r, t: r * data_x[t,data_a[t]]
        __betas_y = jax.vmap(__betas_y)
        _betas_y = lambda rs: np.concatenate((np.zeros(K)[None,...], __betas_y(rs, np.arange(T-1)).cumsum(axis=0)))
        _betas_y = jax.jit(_betas_y)
        _BETAS_Y = jax.jit(jax.vmap(_betas_y))

        # `sample`, `decode` and `grad_likelihood` are jitted and read the
        # globals above only when traced. JAX caches the trace by argument
        # shape/dtype, so without this a later dataset with the same shapes
        # would silently reuse the FIRST dataset's xs/_betas_N/_betas_y.
        jax.clear_caches()

        rhox = -np.ones(K)/K
        params = {'beta0': 0.}
        grad_mnsq = {'beta0': 0.}
        beta0_y, beta0_N = decode(params)

        for j in range(hyper['iter']):
            key, subkey = jax.random.split(key)
            RHOS, RS = sample(rhox, beta0_y, beta0_N, subkey)

            rhox = compute_rhox(RS)

            # RMSProp-style ascent step on params['beta0'] (maximises `likelihood`).
            grad = grad_likelihood(params, RHOS, RS)
            grad_mnsq['beta0'] = .1 * grad['beta0']**2 + .9 * grad_mnsq['beta0']
            params['beta0'] += .001 * grad['beta0'] / (np.sqrt(grad_mnsq['beta0']) + 1e-8)
            beta0_y, beta0_N = decode(params)

        # Normalise rhox to unit L1 norm.
        rhox = rhox / np.abs(rhox).sum()

        res = dict()
        res['rhox'] = rhox
        res['beta0_y'] = beta0_y
        res['beta0_N'] = beta0_N

        # One more chain with the fitted parameters to estimate the beliefs.
        key, subkey = jax.random.split(key)
        _, RS = sample(rhox, beta0_y, beta0_N, subkey)

        BETAS_Y = beta0_y + _BETAS_Y(RS)
        betas_invN = np.linalg.inv(beta0_N + _betas_N)
        betas_mean = np.einsum('ijk,lik->lij', betas_invN, BETAS_Y).mean(axis=0)
        betas_cov = betas_invN * sigma**2

        res['betas_mean'] = betas_mean
        # Store the belief covariance, i.e. betas_invN scaled by sigma**2 as in
        # `_sample`, rather than the unscaled inverse.
        res['betas_cov'] = betas_cov

        print(res['betas_mean'])

        # Save the results
        filename = f'{input_dir}/dataset_{i}_{algo}_BICB.dill'
        with open(filename, 'wb') as f:
            dill.dump(res, f)


# NBICB

In [None]:
import argparse
import dill
import jax
import jax.numpy as np
import numpy as np1
from jax.scipy.special import logsumexp
import tqdm

In [None]:
# Uncomment to force JAX onto the CPU:
# jax.config.update('jax_platform_name', 'cpu')

# NBICB hyperparameters (this replaces the BICB `hyper` dict).
hyper = {
    'sample_start': 10_000,   # burn-in: sweeps discarded before averaging
    'sample_stop': 20_000,    # total number of Gibbs sweeps
    'sample_step': 10,        # thinning of the kept sweeps
    'variance_rho': 5e-4,     # scale of the covariance of rho_t around beta_t
    'variance_beta': 5e-5,    # scale of the random-walk prior over the betas
    'offset': -np.ones(K)/K,  # prior mean of every beta_t
}

In [None]:
# --- NBICB Gaussian prior over the belief trajectory ------------------------
# cov_T[t0, t1] = min(t0, t1) + 1 over steps 0..T-1 (shape (T, T)): the
# covariance structure of a random walk in time.
_cov_T = lambda t0, t1: np.minimum(t0, t1) + 1
_cov_T = jax.vmap(jax.vmap(_cov_T, in_axes=(None,0)), in_axes=(0,None))
cov_T = _cov_T(np.arange(T), np.arange(T))

# Build an orthonormal basis Q of R^K whose first column is ones/sqrt(K):
# start from the identity, replace column 0 by the normalised all-ones vector,
# then Gram-Schmidt each remaining column against the earlier ones.
_cov_K = np.eye(K)
_cov_K = _cov_K.at[:,0].set(np.ones(K))
_cov_K = _cov_K.at[:,0].set(_cov_K[:,0] / np.sum(_cov_K[:,0]**2)**.5)
for i in range(1,K):
    for j in range(i):
        _cov_K = _cov_K.at[:,i].add(-np.sum(_cov_K[:,i] * _cov_K[:,j]) * _cov_K[:,j])
    _cov_K = _cov_K.at[:,i].set(_cov_K[:,i] / np.sum(_cov_K[:,i]**2)**.5)
# S = diag(0.1, 1, ..., 1). Since Q is orthonormal, Q S Q^{-1} = Q S Q^T, so
# cov_K = Q S^2 Q^T: variance 0.01 along ones/sqrt(K), 1 in every orthogonal direction.
_scale = np.eye(K)
_scale = _scale.at[0,0].set(.1)
_cov_K = _cov_K @ _scale @ np.linalg.inv(_cov_K)
cov_K = _cov_K @ _cov_K.T

# rho_t ~ N(beta_t, cov_rho) independently per step (block-diagonal cov_rhos).
# betas ~ N(mean_betas, kron(cov_T, cov_K) * variance_beta).
# The flattening is row-major (index t*K + k), matching both Kronecker products.
cov_rho = cov_K * hyper['variance_rho']
cov_rhos = np.kron(np.eye(T), cov_K) * hyper['variance_rho']
cov_betas = np.kron(cov_T, cov_K) * hyper['variance_beta']
invcov_rhos = np.linalg.inv(cov_rhos)
invcov_betas = np.linalg.inv(cov_betas)
mean_betas = hyper['offset'][None,...].repeat(T, axis=0).reshape(-1)
# Conditional of betas given rhos: covariance (P_rhos + P_betas)^{-1} and mean
# cov @ P_betas @ mean_betas + cov @ P_rhos @ vec(rhos). The constant pieces are
# precomputed here for `_sample_betas`.
cov = np.linalg.inv(invcov_rhos + invcov_betas)
cov_at_invcov_betas_at_mean_betas = cov @ invcov_betas @ mean_betas
cov_at_invcov_rhos = cov @ invcov_rhos

# Cholesky factors (via plain numpy) used to draw correlated Gaussian noise as L @ z.
cov_rho_L = np1.linalg.cholesky(cov_rho)
cov_L = np1.linalg.cholesky(cov)

def _sample_rhos_like(rho, x, a):
    '''
    NBICB copy (shadows the BICB definition of the same name): the
    log-probability of action `a` under a softmax over alpha * (x @ rho).
    '''
    scores = alpha * np.einsum('ij,j->i', x, rho)
    return scores[a] - logsumexp(scores)

def _sample_rhos(beta, x, a, key):
    '''
    NBICB rho_t update (shadows the BICB version). Two candidates are drawn
    from N(beta, cov_rho) as beta + L @ z. The second replaces the first with
    probability min(1, exp(log-likelihood ratio)) of action `a`.
    Mapped over the time axis by the vmap below.
    '''
    k_current, k_proposal, k_accept = jax.random.split(key, 3)
    current = beta + cov_rho_L @ jax.random.normal(k_current, shape=(K,))
    proposal = beta + cov_rho_L @ jax.random.normal(k_proposal, shape=(K,))
    log_ratio = _sample_rhos_like(proposal, x, a) - _sample_rhos_like(current, x, a)
    accept = log_ratio > np.log(jax.random.uniform(k_accept))
    return jax.lax.select(accept, proposal, current)
_sample_rhos = jax.vmap(_sample_rhos)

def _sample_betas(rhos, key):
    '''
    Draw all betas jointly from their Gaussian conditional given `rhos`.

    The mean uses the precomputed pieces cov @ P_betas @ mean_betas and
    cov @ P_rhos; the noise is cov_L @ z. Returns shape (T, K).
    '''
    z = jax.random.normal(key, shape=(T*K,))
    cond_mean = cov_at_invcov_betas_at_mean_betas + cov_at_invcov_rhos @ rhos.reshape(-1)
    draw = cond_mean + cov_L @ z
    return draw.reshape(-1, K)

def _sample(arg0, arg1):
    '''
    One NBICB Gibbs sweep (scan body): redraw rhos | betas, then betas | rhos.
    The incoming rhos are not read. The new (rhos, betas) is both the carry and
    the per-sweep output.
    '''
    _, betas = arg0
    key_rhos, key_betas = jax.random.split(arg1, 2)
    rhos = _sample_rhos(betas, data_x, data_a, jax.random.split(key_rhos, T))
    betas = _sample_betas(rhos, key_betas)
    state = (rhos, betas)
    return state, state

def sample(rhos, betas, key, count):
    '''
    Run `count` NBICB Gibbs sweeps from (rhos, betas).

    Returns (last rhos, last betas, stacked RHOS, stacked BETAS). `count` is
    static for jit. This shadows the BICB `sample`, which has a different
    signature.
    '''
    init = (rhos, betas)
    final, history = jax.lax.scan(_sample, init, jax.random.split(key, count))
    return (*final, *history)
sample = jax.jit(sample, static_argnums=3)

In [None]:
# Fit NBICB to every dataset/algorithm pair and save the averaged beliefs.
# Relies on `key`, `alpha`, `algos`, `output_dir` and `input_dir` from earlier cells.
chunk_size = 200  # Gibbs sweeps per jitted `sample` call
for i in tqdm.tqdm(range(1, 10)):  # dataset indices 1..9
    for algo in algos:
        filename = f'{output_dir}/dataset_{i}_{algo}.dill'
        # NOTE: dill, like pickle, can execute arbitrary code while loading;
        # only open dataset files from a trusted source.
        with open(filename, 'rb') as f:
            data = dill.load(f)

        data_x = np.array(data['x'])
        data_a = np.array(data['a'])
        T, A, K = data_x.shape

        # The prior matrices (cov_L, cov_at_invcov_rhos, ...) and hyper['offset']
        # were built once for a fixed (T, K); fail clearly on a mismatch instead
        # of with an opaque shape error inside the jitted sampler.
        if cov_L.shape[0] != T * K or hyper['offset'].shape[0] != K:
            raise ValueError(
                f'{filename}: data has (T, K) = {(T, K)}, which does not match '
                f'the precomputed NBICB prior (T*K = {cov_L.shape[0]})')

        # `sample` is jitted and reads data_x/data_a only when traced. JAX caches
        # that trace by argument shape, so clear it to pick up this dataset.
        jax.clear_caches()

        rhos = np.zeros((T, K))
        betas = np.zeros((T, K)) + hyper['offset']

        # Collect the chunks and concatenate once (repeated concatenation is quadratic).
        chunks = [np.zeros((0, T, K))]
        for j in range(hyper['sample_stop'] // chunk_size):
            key, subkey = jax.random.split(key)
            rhos, betas, _RHOS, _BETAS = sample(rhos, betas, subkey, chunk_size)
            chunks.append(_BETAS)
        BETAS = np.concatenate(chunks)

        # Mean of the betas after burn-in and thinning, L1-normalised per step.
        betas = BETAS[hyper['sample_start']::hyper['sample_step']].mean(axis=0)
        betas = betas / np.abs(betas).sum(axis=-1, keepdims=True)

        res = dict()
        res['betas'] = betas

        filename = f'{input_dir}/dataset_{i}_{algo}_NBICB.dill'
        with open(filename, 'wb') as f:
            dill.dump(res, f)