In [None]:
# argparse is not imported: settings are assigned inline while this runs as a notebook.
import jax
import dill
import jax.numpy as np
from jax.scipy.special import logsumexp
from tqdm import tqdm

In [None]:
# Settings for the single-dataset exploration cells.
RUNS = 1
input_dir = "../Datasets"
algos = ["optimistic", "softmax", "igw", "greedy", "ucb", "ts"]

# Load one logged trajectory.
# dill.load can execute code embedded in the file: only open trusted datasets.
filename = f'{input_dir}/dataset_1_ucb.dill'
with open(filename, 'rb') as f:
    data = dill.load(f)

    data_x = np.array(data['x'])  # contexts; 3-D, unpacked below as (T, A, K)
    data_a = np.array(data['a'])  # used below as data_x[t, data_a[t]], i.e. an index along axis 1
    # 'rhox', 'betas_mean' and 'betas_cov' were also read from `data` in an
    # earlier version of this cell; they are currently unused.

# T: time dimension, total number of steps
# A: size of the second axis (each action has its own context row)
# K: feature space dimension, context shape
T, A, K = data_x.shape

# A dictionary of hyperparameters of the simulation
############### ADJUST HYPER AND THE LOADING
hyper = {
    'n_samples': 100,
    'iter': 100,
}

alpha = 20   # exploration parameter in the policy definition
sigma = .10  # reward-noise scale: multiplies standard-normal draws, squared for covariances

In [None]:
# Belief calculation utilities

## cumulative sum of outer products of contexts selected by actions up to time t
## later used for covariance matrix updates

# Outer product of the chosen context with itself at each step, mapped over t.
__betas_N = jax.vmap(
    lambda t: np.einsum('i,j->ij', data_x[t, data_a[t]], data_x[t, data_a[t]])
)
# Row t holds the sum of outer products for steps 0..t-1 (row 0 is all zeros).
_betas_N = np.concatenate(
    (np.zeros((K, K))[None, ...], __betas_N(np.arange(T-1)).cumsum(axis=0))
)

# Chosen context scaled by the reward at each step, mapped jointly over (r, t).
__betas_y = jax.vmap(lambda r, t: r * data_x[t, data_a[t]])


@jax.jit
def _betas_y(rs):
    """Running sum of reward-weighted chosen contexts.

    `rs` is mapped together with `np.arange(T-1)`, so it must hold T-1
    rewards. The result has T rows: a zero row followed by the cumulative sums,
    so row t only involves rewards from steps before t.
    """
    weighted = __betas_y(rs, np.arange(T-1))
    return np.concatenate((np.zeros(K)[None, ...], weighted.cumsum(axis=0)))


# Same as _betas_y, mapped over a leading axis of reward sequences.
_BETAS_Y = jax.jit(jax.vmap(_betas_y))

@jax.jit
def decode(params):
    """Turn the raw parameter dict into the prior terms used by the belief update.

    Reads ``params['beta0']`` and maps it to a positive scale ``exp(20 * beta0)``.

    Returns:
        (beta0_y, beta0_N): a length-K vector with every entry ``-scale / K``,
        and the K x K identity multiplied by ``scale``.

    K is read from module scope when the function is traced.
    """
    scale = np.exp(20 * params['beta0'])
    prior_y = -np.ones(K) / K * scale  # vector
    prior_N = np.eye(K) * scale        # matrix
    return prior_y, prior_N

# Implement likelihood functions for each policy
def optimistic_pi(beta_mean, x, beta_cov, a):
    """Log-probability of action ``a`` under a softmax over optimistic scores.

    Each row ``x[i]`` is scored as ``alpha * x[i] @ beta_mean`` plus the
    quadratic form ``x[i] @ beta_cov @ x[i]``; the scores are normalised with
    logsumexp. ``alpha`` is read from module scope.
    """
    mean_term = np.einsum('ij,j->i', x, beta_mean)
    bonus = np.einsum('ij,jk,ki->i', x, beta_cov, x.T)
    scores = alpha * mean_term + bonus
    log_norm = logsumexp(scores)
    return scores[a] - log_norm

def softmax_pi(beta_mean, x, beta_cov, a):
    """Log-probability of action ``a`` under a softmax over expected rewards.

    Scores are ``alpha * x @ beta_mean`` with ``alpha`` read from module scope.
    ``beta_cov`` is accepted so all policies share one signature; it is not used.
    """
    scores = alpha * np.einsum('ij,j->i', x, beta_mean)
    log_norm = logsumexp(scores)
    return scores[a] - log_norm

def ts_pi(beta_mean, x, beta_cov, a, key):
    """Monte-Carlo log-probability of action ``a`` under posterior sampling.

    Draws 10 parameter vectors from N(beta_mean, beta_cov), scores every row of
    ``x`` against each draw, and returns the log of the fraction of draws for
    which ``a`` had the highest score. The result is ``-inf`` when ``a`` never
    wins a draw.

    Args:
        beta_mean: mean vector of the sampling distribution.
        x: context matrix with one row per action.
        beta_cov: covariance matrix of the sampling distribution.
        a: index of the action whose log-frequency is returned.
        key: JAX PRNG key used for the draws.
    """
    num_samples = 10
    n_actions = x.shape[0]

    draws = jax.random.multivariate_normal(key, beta_mean, beta_cov, (num_samples,))
    winners = np.argmax(np.dot(draws, x.T), axis=1)

    # Count, for every action at once, how many draws it won.
    wins = np.sum(winners[:, None] == np.arange(n_actions)[None, :], axis=0)
    freq = wins / num_samples

    return np.log(freq[a])

def igw_pi(beta_mean, x, beta_cov, a):
    """Log-probability of action ``a`` under inverse-gap weighting.

    Every arm other than the one with the highest predicted reward gets
    probability ``1 / (A + alpha * gap)``, where ``gap`` is its shortfall from
    the best predicted reward and ``A`` is the number of rows in ``x``. The best
    arm receives the remaining mass, so the probabilities sum to one.

    Args:
        beta_mean: parameter vector used to predict rewards.
        x: context matrix of shape (A, K), one row per action.
        beta_cov: accepted so all policies share one signature; not used.
        a: index of the action whose log-probability is returned.
    """
    alpha = 20
    erewards = np.einsum('ij,j->i', x, beta_mean)  # predicted reward per arm
    best_arm = np.argmax(erewards)
    gaps = erewards[best_arm] - erewards

    A = x.shape[0]  # x is (A, K)
    pi = 1 / (A + alpha * gaps)

    # `.at[...].set` returns a new array, so the result has to be rebound.
    # Without this the best arm's own 1/A stays in the sum below and the
    # distribution no longer sums to one.
    pi = pi.at[best_arm].set(0)

    # The best arm takes whatever mass the other arms left over.
    pi = pi.at[best_arm].set(1 - np.sum(pi))

    return np.log(pi[a])

def _mh_reward_step(policy_loglik, rhox, x, a, t, key, beta0_N, beta0_y, _betas_N, rs):
    """One accept/reject step for the reward at time ``t``, shared by all policies.

    Two candidate rewards are drawn around ``rhox @ x[a]`` with noise scale
    ``sigma`` (module scope). Beliefs are recomputed with each candidate written
    into ``rs[t]``, the policy log-likelihood of ``a`` is evaluated on row ``t``
    of both, and the second candidate is kept when the log-likelihood
    difference exceeds the log of a uniform draw.

    Args:
        policy_loglik: callable ``(beta_mean, x, beta_cov, a) -> log-probability``.
        rhox: parameter vector used for the expected reward.
        x: context matrix for step ``t``, one row per action.
        a: action taken at step ``t``.
        t: index into ``rs`` of the reward being resampled.
        key: JAX PRNG key; split into three for the two candidates and the
            acceptance draw.
        beta0_N, beta0_y, _betas_N: forwarded to ``f_update``.
        rs: current reward sequence.

    Returns:
        (selected_r, (betas_mean, betas_cov)): the kept reward and the beliefs
        computed with it.

    NOTE(review): ``f_update`` builds its rows from ``rs[:-1]`` behind a
    prepended zero row, so row ``t`` looks independent of ``rs[t]``. The two
    log-likelihoods would then coincide and the second candidate would be kept
    whenever the difference is finite. Confirm whether the comparison was
    meant to cover the steps after ``t``.
    """
    keys = jax.random.split(key, 3)
    ereward = np.dot(rhox, x[a])

    r = jax.random.normal(keys[0]) * sigma + ereward
    _r = jax.random.normal(keys[1]) * sigma + ereward

    # Beliefs under each candidate reward.
    betas_mean, betas_cov = f_update(beta0_N, _betas_N, beta0_y, rs.at[t].set(r))
    betas_mean_, betas_cov_ = f_update(beta0_N, _betas_N, beta0_y, rs.at[t].set(_r))

    like = policy_loglik(betas_mean[t, :], x, betas_cov[t, :, :], a)
    _like = policy_loglik(betas_mean_[t, :], x, betas_cov_[t, :, :], a)

    # Acceptance test in log space.
    cond = _like - like > np.log(jax.random.uniform(keys[2]))
    selected_r = jax.lax.select(cond, _r, r)

    # The kept reward is one of the two candidates, so its beliefs are already
    # available; pick them instead of running f_update a third time.
    updated_beliefs = (
        jax.lax.select(cond, betas_mean_, betas_mean),
        jax.lax.select(cond, betas_cov_, betas_cov),
    )
    return selected_r, updated_beliefs


def _sample_rs_softmax(rhox, x, a, t, key, beta0_N, beta0_y, _betas_N, rs):
    """Resample the reward at step ``t`` using the softmax policy likelihood."""
    return _mh_reward_step(softmax_pi, rhox, x, a, t, key, beta0_N, beta0_y, _betas_N, rs)


def _sample_rs_igw(rhox, x, a, t, key, beta0_N, beta0_y, _betas_N, rs):
    """Resample the reward at step ``t`` using the inverse-gap-weighting likelihood."""
    return _mh_reward_step(igw_pi, rhox, x, a, t, key, beta0_N, beta0_y, _betas_N, rs)


def _sample_rs_ts(rhox, x, a, t, key, beta0_N, beta0_y, _betas_N, rs):
    """Resample the reward at step ``t`` using the posterior-sampling likelihood.

    ``ts_pi`` needs its own PRNG key. It gets a dedicated one rather than the
    key that also draws the first candidate reward, so the two uses stay
    independent. Both likelihood evaluations share that key, so they are
    compared on the same random draws.
    """
    step_key, policy_key = jax.random.split(key)

    def ts_loglik(beta_mean, ctx, beta_cov, action):
        return ts_pi(beta_mean, ctx, beta_cov, action, policy_key)

    return _mh_reward_step(ts_loglik, rhox, x, a, t, step_key, beta0_N, beta0_y, _betas_N, rs)


def _sample_rs_optimistic(rhox, x, a, t, key, beta0_N, beta0_y, _betas_N, rs):
    """Resample the reward at step ``t`` using the optimistic policy likelihood."""
    return _mh_reward_step(optimistic_pi, rhox, x, a, t, key, beta0_N, beta0_y, _betas_N, rs)

# @jax.jit
def _run_reward_chains(step_fn, args, keys):
    """Run one reward-resampling sweep per key and collect the results.

    Each chain starts from the supplied reward sequence and walks through the
    time steps in order, replacing ``rs[t]`` with the reward returned by
    ``step_fn``. Later steps in a chain see the rewards already resampled.

    Args:
        step_fn: per-step sampler with the ``_sample_rs_*`` signature.
        args: tuple ``(rhox, data_x, data_a, beta0_N, beta0_y, rs)``.
        keys: array of PRNG keys; its leading dimension is the number of chains.

    Returns:
        Array of shape ``(keys.shape[0], T)`` with one resampled reward
        sequence per row, where ``T = data_x.shape[0]``.

    The cumulative statistics ``_betas_N`` are read from module scope. Kept as
    plain Python loops (not jitted), as in the original.
    """
    (rhox, data_x, data_a, beta0_N, beta0_y, rs_init) = args
    T = data_x.shape[0]

    n_chains = keys.shape[0]
    RS = np.zeros((n_chains, T))

    for chain_idx in range(n_chains):
        rs = np.copy(rs_init)  # every chain starts from the same rewards
        key_sample = keys[chain_idx]

        for t in range(T):
            key_timestep, key_sample = jax.random.split(key_sample)
            sampled_r, _ = step_fn(
                rhox, data_x[t], data_a[t], t, key_timestep,
                beta0_N, beta0_y, _betas_N, rs,
            )
            rs = rs.at[t].set(sampled_r)

        # Every entry of rs was resampled exactly once: store the whole row.
        RS = RS.at[chain_idx].set(rs)

    return RS


def sample_rs_softmax(args, keys):
    """Resampled reward sequences under the softmax policy likelihood."""
    return _run_reward_chains(_sample_rs_softmax, args, keys)


def sample_rs_optimistic(args, keys):
    """Resampled reward sequences under the optimistic policy likelihood."""
    # The original called `_sample_rs_optimistics`, which is not defined anywhere.
    return _run_reward_chains(_sample_rs_optimistic, args, keys)


def sample_rs_igw(args, keys):
    """Resampled reward sequences under the inverse-gap-weighting likelihood."""
    return _run_reward_chains(_sample_rs_igw, args, keys)


def sample_rs_ts(args, keys):
    """Resampled reward sequences under the posterior-sampling likelihood."""
    return _run_reward_chains(_sample_rs_ts, args, keys)

def f_update(beta0_N, _betas_N, beta0_y, rs):
    """Belief mean and covariance at every step for a given reward sequence.

    Args:
        beta0_N: prior matrix, added to every entry of ``_betas_N``.
        _betas_N: stack of cumulative matrices, one per step.
        beta0_y: prior vector, added to every cumulative reward-weighted sum.
        rs: reward sequence; its last entry is dropped before ``_betas_y``
            (module scope) accumulates it.

    Returns:
        (betas_mean, betas_cov): per-step means ``inv(N_t) @ y_t`` and per-step
        covariances ``inv(N_t) * sigma**2``, with ``sigma`` from module scope.
    """
    inv_N = np.linalg.inv(beta0_N + _betas_N)
    y_total = beta0_y + _betas_y(rs[:-1])

    mean_per_step = np.einsum('ijk,ik->ij', inv_N, y_total)
    cov_per_step = inv_N * sigma**2
    return mean_per_step, cov_per_step

def compute_rhox(RS):
    """Estimate the reward parameter from sampled reward sequences.

    Takes the final cumulative reward-weighted context sum of every sequence in
    ``RS`` (last entry of each sequence dropped), averages over sequences, and
    multiplies by the inverse of the final cumulative matrix ``_betas_N[-1]``.

    Deliberately not wrapped in ``jax.jit``. A jitted function bakes in the
    values of ``_betas_N`` and ``_BETAS_Y`` seen at trace time and reuses that
    trace for every later call with the same input shape. The model-fitting
    cell rebinds both globals for each dataset, so a jitted version keeps
    returning estimates based on the first dataset. ``_BETAS_Y`` is already
    jitted, so the heavy part stays compiled.
    """
    _beta_y = _BETAS_Y(RS[:, :-1])[:, -1, :].mean(axis=0)
    _beta_N = _betas_N[-1]
    rhox = np.einsum('ij,j->i', np.linalg.inv(_beta_N), _beta_y)
    return rhox

In [None]:
def _mean_policy_loglik(policy_fn, beta0_N, beta0_y, data_x, data_a, rs, key=None):
    """Policy log-likelihood of the logged actions, averaged over reward samples.

    For every reward sequence in ``rs`` the beliefs are recomputed with
    ``f_update`` (using ``_betas_N`` from module scope). ``policy_fn`` is then
    evaluated at every time step on that step's belief mean and covariance, and
    the per-step log-probabilities are summed. The sums are averaged over
    sequences.

    Args:
        policy_fn: ``(beta_mean, x, beta_cov, a[, key]) -> log-probability``.
        beta0_N, beta0_y: prior terms forwarded to ``f_update``.
        data_x: contexts, indexed by time step along the first axis.
        data_a: logged actions, one per time step.
        rs: reward samples, one sequence per row.
        key: optional PRNG key. When given, every (sequence, step) pair gets
            its own key, passed as a fifth argument to ``policy_fn``.

    Returns:
        Scalar average log-likelihood.
    """
    n_chains = rs.shape[0]
    n_steps = data_x.shape[0]
    chain_keys = None if key is None else jax.random.split(key, n_chains)
    policy_over_time = jax.vmap(policy_fn)

    total = 0.0
    for chain_idx in range(n_chains):
        betas_mean, betas_cov = f_update(beta0_N, _betas_N, beta0_y, rs[chain_idx])
        # Each step gets its own covariance slice: mapping over the leading
        # axis hands policy_fn betas_cov[t], not the whole stack.
        step_args = (betas_mean, data_x, betas_cov, data_a)
        if chain_keys is not None:
            step_args = step_args + (jax.random.split(chain_keys[chain_idx], n_steps),)
        total = total + policy_over_time(*step_args).sum()

    return total / n_chains


def likelihood_rs_igw(beta0_N, beta0_y, data_x, data_a, rs):
    """Average log-likelihood of the logged actions under inverse-gap weighting."""
    return _mean_policy_loglik(igw_pi, beta0_N, beta0_y, data_x, data_a, rs)


def likelihood_rs_ts(beta0_N, beta0_y, data_x, data_a, rs, key=None):
    """Average log-likelihood of the logged actions under posterior sampling.

    ``ts_pi`` requires a PRNG key, which the original call omitted. ``key``
    defaults to ``jax.random.PRNGKey(0)`` so existing five-argument calls keep
    working.

    NOTE(review): ``ts_pi`` derives its output from argmax comparisons, so the
    gradient through it is expected to be zero, and the value is ``-inf``
    whenever a logged action never wins a draw. Confirm this policy is meant
    to be fitted by gradient ascent.
    """
    if key is None:
        key = jax.random.PRNGKey(0)
    return _mean_policy_loglik(ts_pi, beta0_N, beta0_y, data_x, data_a, rs, key=key)


def likelihood_rs_optimistic(beta0_N, beta0_y, data_x, data_a, rs):
    """Average log-likelihood of the logged actions under the optimistic policy."""
    return _mean_policy_loglik(optimistic_pi, beta0_N, beta0_y, data_x, data_a, rs)


def likelihood_rs_softmax(beta0_N, beta0_y, data_x, data_a, rs):
    """Average log-likelihood of the logged actions under the softmax policy."""
    return _mean_policy_loglik(softmax_pi, beta0_N, beta0_y, data_x, data_a, rs)


# Gradients with respect to the first positional argument (beta0_N) only.
# Left un-jitted on purpose: the likelihoods read `_betas_N` / `_betas_y` from
# module scope, and a cached jit trace would keep using the values from the
# first dataset after the model-fitting cell rebinds them.
grad_likelihood_igw = jax.grad(likelihood_rs_igw)
grad_likelihood_ts = jax.grad(likelihood_rs_ts)
grad_likelihood_optimistic = jax.grad(likelihood_rs_optimistic)
grad_likelihood_softmax = jax.grad(likelihood_rs_softmax)

In [None]:
# Scratch check: softmax log-likelihood of the logged actions, averaged over
# the reward samples in RS.
# NOTE: RS, beta0_N and beta0_y are only assigned by the model-fitting cell
# further down, so this cell fails on a fresh kernel until that cell has run.
likelihood_sum = 0.0
for sample_rs in RS:
    betas_mean, betas_cov = f_update(beta0_N, _betas_N, beta0_y, sample_rs)

    for t in range(T):
        # Use this step's covariance slice, matching the per-step mean.
        policy_output = softmax_pi(betas_mean[t], data_x[t], betas_cov[t], data_a[t])
        likelihood_sum += policy_output

# Normalise by the number of reward samples iterated over (rows of RS);
# the original divided by rs.shape[0], the length of a different array.
likelihood_sum / RS.shape[0]

In [None]:
# Scratch check: display the gradient of the softmax log-likelihood.
# grad_likelihood_softmax is built with jax.grad's default argument selection,
# so this is the gradient with respect to beta0_N.
# NOTE: beta0_N, beta0_y and RS are only assigned by the model-fitting cell
# further down; run that cell first.
grad_likelihood_softmax(beta0_N, beta0_y, data_x, data_a, RS)

In [None]:
# Number of dataset indices processed by the model-fitting loop.
RUNS = 10

# NOTE(review): the names read backwards relative to their use. The fitting
# loop *reads* datasets from `output_dir` and *writes* results to `input_dir`.
# Renaming would have to be done together with that loop.
output_dir = "../Datasets"
input_dir = "../Results"

import tqdm
# Algorithm tags that appear in the dataset file names read by the fitting loop.
algos = [
    "ts", "optimistic", "softmax",
    "igw", "ucb", "greedy",
]
# Policy likelihoods fitted to every dataset, in processing order.
policies = [
    "softmax", "optimistic",
    "ts", "igw",
]
# Integer code of each policy, used by the fitting loop to pick a sampler.
policy_map = dict(optimistic=0, softmax=1, ts=2, igw=3)

# Model ICB run

In [None]:
key = jax.random.PRNGKey(0)
alpha = 20 # exploration parameter in the policy definition
sigma = .10 # reward-noise scale: multiplies standard-normal draws, squared for covariances

# Policy name -> (reward sampler, average log-likelihood of the logged actions).
policy_fns = {
    "optimistic": (sample_rs_optimistic, likelihood_rs_optimistic),
    "softmax": (sample_rs_softmax, likelihood_rs_softmax),
    "ts": (sample_rs_ts, likelihood_rs_ts),
    "igw": (sample_rs_igw, likelihood_rs_igw),
}

for i in tqdm.tqdm(range(RUNS)):
    for algo in algos:
        # Load each dataset once and close the file before any computation.
        # dill.load can execute code embedded in the file: trusted datasets only.
        filename = f'{output_dir}/dataset_{i}_{algo}.dill'
        with open(filename, 'rb') as f:
            data = dill.load(f)

        data_x = np.array(data['x'])
        data_a = np.array(data['a'])
        T,A,K = data_x.shape

        # f_update, compute_rhox and the samplers read these statistics from
        # module scope, so they are rebound here for the dataset just loaded.
        __betas_N = lambda t: np.einsum('i,j->ij', data_x[t,data_a[t]], data_x[t,data_a[t]])
        __betas_N = jax.vmap(__betas_N)
        _betas_N = __betas_N(np.arange(T-1)).cumsum(axis=0)
        _betas_N = np.concatenate((np.zeros((K,K))[None,...], _betas_N))

        __betas_y = lambda r, t: r * data_x[t,data_a[t]]
        __betas_y = jax.vmap(__betas_y)

        _betas_y = lambda rs: np.concatenate((np.zeros(K)[None,...], __betas_y(rs, np.arange(T-1)).cumsum(axis=0)))
        _betas_y = jax.jit(_betas_y)
        _BETAS_Y = jax.jit(jax.vmap(_betas_y))

        for policy in policies:
            print(i, algo, policy)
            sample_rewards, log_likelihood = policy_fns[policy]

            # A dictionary of hyperparameters of the simulation
            hyper = dict()
            hyper['n_samples'] = 2
            hyper['iter'] = 2

            # initialize beta, rewards and rho_env (rhox)
            params = {'beta0': 10e-4}
            grad_mnsq = {'beta0': 10e-4}
            key = jax.random.PRNGKey(0)  # same seed for every (dataset, policy) fit
            beta0_y, beta0_N = decode(params)
            rs = jax.random.normal(key, shape=(T,))
            rhox = -np.ones(K)/K

            def mean_loglik(p, reward_samples):
                """Average log-likelihood as a function of the raw parameter dict."""
                prior_y, prior_N = decode(p)
                return log_likelihood(prior_N, prior_y, data_x, data_a, reward_samples)

            # Differentiate with respect to `params`, so the result is a dict
            # with a 'beta0' entry, which is what the update below indexes.
            # The module-level grad_likelihood_* helpers differentiate with
            # respect to beta0_N instead and cannot be indexed that way.
            grad_fn = jax.grad(mean_loglik)

            for j in range(hyper['iter']):
                key, subkey = jax.random.split(key)

                # sample rewards
                RS = sample_rewards(
                    (rhox, data_x, data_a, beta0_N, beta0_y, rs),
                    jax.random.split(subkey, hyper['n_samples']),
                )

                rs = RS.mean(axis=0)
                rhox = compute_rhox(RS)

                # gradient step on beta0, scaled by a running mean of squared gradients
                grad = grad_fn(params, RS)
                grad_mnsq['beta0'] = .1 * grad['beta0']**2 + .9 * grad_mnsq['beta0']
                params['beta0'] += .001 * grad['beta0'] / (np.sqrt(grad_mnsq['beta0']) + 1e-8)
                beta0_y, beta0_N = decode(params)

            rhox = rhox / np.abs(rhox).sum()

            res = dict()
            res['rhox'] = rhox
            res['beta0_y'] = beta0_y
            res['beta0_N'] = beta0_N

            # Final reward samples with the fitted parameters. The original
            # called an undefined `sample_rs`; use this policy's sampler with
            # the argument order the samplers unpack.
            key, subkey = jax.random.split(key)
            RS = sample_rewards(
                (rhox, data_x, data_a, beta0_N, beta0_y, rs),
                jax.random.split(subkey, hyper['n_samples']),
            )
            # _BETAS_Y is a function: apply it to the samples (last reward
            # dropped, as in compute_rhox) before adding the prior vector.
            BETAS_Y = beta0_y + _BETAS_Y(RS[:, :-1])
            betas_invN = np.linalg.inv(beta0_N + _betas_N)
            betas_mean = np.einsum('ijk,lik->lij', betas_invN, BETAS_Y).mean(axis=0)
            betas_cov = betas_invN * sigma**2

            res['betas_mean'] = betas_mean
            # Store the covariance computed above, matching f_update; the
            # original saved the unscaled inverse under this key.
            res['betas_cov'] = betas_cov

            print(betas_mean)

            # Save the results
            filename = f'{input_dir}/dataset_{i}_{algo}_{policy}_MODEL.dill'
            with open(filename, 'wb') as f:
                dill.dump(res, f)