In [1]:
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
import numpy as np
import copy
from jax import random
from numpyro.infer import MCMC, NUTS, Predictive

def invlogit(x):
    """Inverse-logit (logistic sigmoid): map a real value or array into (0, 1).

    Computed element-wise with ``jnp.exp``, so ``x`` may be a scalar or an
    array.
    """
    return 1 / (1 + jnp.exp(-x))


class Environment:
    """Bandit environment holding the true per-arm reward probabilities.

    Attributes:
        thetas: array indexed by an arm tuple; ``react`` compares a uniform
            draw against ``thetas[arm]`` to decide whether a pull pays out.
    """

    def __init__(self, thetas: np.ndarray):
        # np.asarray passes an existing ndarray through without copying and
        # turns nested lists into an array, so tuple indexing (thetas[i, j])
        # works for either kind of input.
        self.thetas = np.asarray(thetas)


class State:
    """Reward history for every arm of a two-factor bandit.

    Args:
        n_arms: ``(number of levels of factor 1, number of levels of factor 2)``.

    Attributes:
        rewards: dict mapping each arm ``(i, j)`` to the list of rewards
            observed for it so far. Every list starts empty.
    """

    def __init__(self, n_arms: (int, int)):
        # Enumerate arms in row-major order with i over n_arms[0] and j over
        # n_arms[1]. The previous meshgrid-based construction swapped the two
        # axes, which only happened to be correct for square grids.
        self.rewards = {
            (i, j): []
            for i in range(n_arms[0])
            for j in range(n_arms[1])
        }


def model(f1, f2, obs=None):
    """Additive logistic model of the reward probability of each arm.

    The linear predictor is ``b0 + b1[f1] + b2[f2]``: an intercept plus one
    effect per factor, with each effect vector holding 2 entries. All
    coefficients get a Normal(0, 1) prior.

    Args:
        f1: indices into ``b1``, one per observation.
        f2: indices into ``b2``, one per observation.
        obs: observed outcomes for the Bernoulli site ``'o'``; ``None`` leaves
            the site unobserved.

    Sites: ``b0``, ``b1``, ``b2`` (sampled), ``y`` and ``theta``
    (deterministic), ``o`` (Bernoulli likelihood).
    """
    unit_normal = dist.Normal(0, 1)
    b0 = numpyro.sample('b0', unit_normal)
    b1 = numpyro.sample('b1', unit_normal.expand((2,)))
    b2 = numpyro.sample('b2', unit_normal.expand((2,)))

    linear_predictor = b0 + b1[f1] + b2[f2]
    y = numpyro.deterministic('y', linear_predictor)
    theta = numpyro.deterministic('theta', invlogit(y))

    # The sampled value itself is not needed here, only the site.
    numpyro.sample('o', dist.Bernoulli(theta), obs=obs)


def infer(state: State) -> Predictive:
    """Fit ``model`` to all rewards recorded in ``state`` and return a predictor.

    The per-arm reward lists are flattened into three aligned 1-D arrays (one
    entry per observed reward): the first arm index, the second arm index and
    the reward itself. NUTS is then run with 500 warmup and 500 sampling
    steps under a fixed ``PRNGKey(0)``.

    Returns:
        A ``Predictive`` built from the posterior samples that reports only
        the ``'theta'`` site.
    """
    history = list(state.rewards.items())

    # Repeat each arm's indices once per reward so that the three arrays line
    # up element by element.
    f1s = np.hstack([np.repeat(i, len(r)) for (i, _), r in history])
    f2s = np.hstack([np.repeat(j, len(r)) for (_, j), r in history])
    success = np.hstack([r for _, r in history])

    kernel = NUTS(model=model)
    mcmc = MCMC(kernel, num_warmup=500, num_samples=500, progress_bar=False)
    mcmc.run(random.PRNGKey(0), f1=f1s, f2=f2s, obs=success)

    posterior = mcmc.get_samples()
    return Predictive(model, posterior, return_sites=['theta'])


class ContextualThompsonSampling:
    """Thompson sampling over the arms tracked by a ``State``."""

    @classmethod
    def get_arm(cls, state: "State", predictive: "Predictive") -> (int, int):
        """Pick the arm that looks best under one random posterior draw.

        Args:
            state: object whose ``rewards`` dict keys are the ``(i, j)`` arms.
            predictive: callable invoked as ``predictive(key, f1=..., f2=...)``
                and returning a mapping whose ``'theta'`` entry has one row
                per posterior sample and one column per arm.

        Returns:
            The ``(i, j)`` arm with the largest theta in the selected draw.
        """
        f1f2 = np.array(list(state.rewards.keys()))
        f1, f2 = f1f2[:, 0], f1f2[:, 1]

        theta = predictive(random.PRNGKey(1), f1=f1, f2=f2)['theta']

        # np.random.randint excludes its upper bound, so the valid sample
        # indices 0 .. num_sample - 1 are obtained with high=num_sample. The
        # previous `num_sample + 1` could yield an index one past the last
        # posterior sample.
        num_sample = theta.shape[0]
        sidx = np.random.randint(0, num_sample)
        sampled_theta = theta[sidx]
        idx = int(np.argmax(sampled_theta))

        print(f"Thetas: {sampled_theta}")

        # Return plain Python ints, matching the annotated return type.
        return int(f1[idx]), int(f2[idx])


def react(env: "Environment", arm: (int, int)) -> float:
    """Simulate one pull of ``arm`` and return its reward.

    A uniform draw from ``np.random.random()`` is compared against
    ``env.thetas[arm]``.

    Returns:
        ``1.0`` when the draw falls below ``env.thetas[arm]``, otherwise
        ``0.0``. Both outcomes are floats, as the return annotation states.
    """
    return 1.0 if np.random.random() < env.thetas[arm] else 0.0


def update(state: State, arm: (int, int), reward: float) -> State:
    """Return a copy of ``state`` with ``reward`` appended to ``arm``'s history.

    The input ``state`` is deep-copied first and is never modified.
    """
    next_state = copy.deepcopy(state)
    arm_history = next_state.rewards[arm]
    arm_history.append(reward)
    return next_state


In [7]:
# Seed NumPy's global RNG so arm selection (get_arm) and simulated rewards
# (react), which both draw from np.random, are reproducible across runs.
np.random.seed(0)

# True reward probability of each (f1, f2) arm.
thetas = np.array([[0.05, 0.08],
                   [0.05, 0.12]])

env = Environment(thetas)
state = State(thetas.shape)
# Initial fit on the empty history so `predictive` is defined even if
# num_trials is 0 (a later cell reads it).
predictive = infer(state)
states = [state]
arms = []

num_trials = 1000

for i in range(num_trials):
    # Refit the posterior on everything observed so far, then pick an arm.
    # (This line used to be indented deeper than the rest of the loop body,
    # which is an IndentationError.)
    predictive = infer(state)
    arm = ContextualThompsonSampling.get_arm(state, predictive)
    arms.append(arm)

    reward = react(env, arm)

    # update() returns a new State; every intermediate state is kept.
    state = update(state, arm, reward)
    states.append(state)

    if i % 10 == 0:
        print(f"Iteration {i}")
        print(arm)

print("Done")

Thetas: [0.79934865 0.4428238  0.92118174 0.69984806]
Iteration 0
(1, 0)
Thetas: [0.96761316 0.94044346 0.7905579  0.6661086 ]
Thetas: [0.64162725 0.454119   0.24509847 0.13108462]
Thetas: [0.62017095 0.23771104 0.2665544  0.06490526]
Thetas: [0.25681046 0.04745442 0.842295   0.4350315 ]
Thetas: [0.02095823 0.06386057 0.03490371 0.10333978]
Thetas: [0.37545097 0.28691334 0.5792936  0.47959974]
Thetas: [0.2509966  0.6370945  0.16942658 0.51658964]
Thetas: [0.21970473 0.51228154 0.06758771 0.21285103]
Thetas: [0.0474296  0.05478333 0.01066106 0.01238814]
Thetas: [0.47552246 0.21417098 0.13767917 0.04579615]
Iteration 10
(0, 0)
Thetas: [0.00555131 0.03376404 0.01636935 0.09434529]
Thetas: [0.02469887 0.08260234 0.15166093 0.38861272]
Thetas: [0.04632898 0.06569528 0.03571064 0.05087508]
Thetas: [0.0107466  0.07223204 0.01958844 0.12525621]
Thetas: [0.13336803 0.1004483  0.18698248 0.14301276]
Thetas: [0.23056997 0.0636043  0.23882985 0.06639901]
Thetas: [0.12428176 0.12261176 0.08649027 0

Thetas: [0.09786756 0.1247609  0.06079866 0.07839084]
Iteration 150
(0, 1)
Thetas: [0.1199282  0.14803317 0.10212747 0.12666124]
Thetas: [0.1317752  0.10768233 0.19144174 0.15843049]
Thetas: [0.10104877 0.09971213 0.1616108  0.15961526]
Thetas: [0.17746446 0.25176683 0.20472766 0.28646934]
Thetas: [0.10144088 0.13966173 0.1200622  0.1640185 ]
Thetas: [0.06930909 0.1263509  0.09393381 0.16759248]
Thetas: [0.06251808 0.10657991 0.17119828 0.26981157]
Thetas: [0.05796236 0.10870859 0.13360545 0.23411876]
Thetas: [0.08769003 0.07019671 0.1432729  0.11610245]
Thetas: [0.07924014 0.10593612 0.14320694 0.18707484]
Iteration 160
(1, 1)
Thetas: [0.05171628 0.12564473 0.08021811 0.18686137]
Thetas: [0.08963526 0.08230884 0.15119846 0.13961181]
Thetas: [0.11683512 0.06748455 0.1520882  0.08935346]
Thetas: [0.14176625 0.13098879 0.1709982  0.15840864]
Thetas: [0.07189699 0.06385522 0.09765116 0.08699877]
Thetas: [0.08519051 0.28291148 0.08569981 0.28423554]
Thetas: [0.1528246  0.1044727  0.1493679

In [8]:
# Number of pulls recorded for each arm in the final state. Looping over the
# reward dict covers every arm, whatever the grid size, instead of listing
# the four arms by hand.
rs = states[-1].rewards
for (f1_idx, f2_idx), arm_rewards in rs.items():
    print(f"Num ({f1_idx}, {f2_idx}): {len(arm_rewards)}")

Num (0, 0): 45
Num (0, 1): 51
Num (1, 0): 83
Num (1, 1): 121


In [None]:
!pip install altair

In [6]:
import altair as alt
import pandas as pd

# Display label for each arm; pulls of any arm not listed here are skipped.
arm_labels = {
    (0, 0): "(0, 0)",
    (0, 1): "(0, 1)",
    (1, 0): "(1, 0)",
    (1, 1): "(1, 1)",
}

# One (trial number, arm label) row per pull, with trials numbered from 1.
points = [
    (trial, arm_labels[chosen])
    for trial, chosen in enumerate(arms, 1)
    if chosen in arm_labels
]

df = pd.DataFrame(points, columns=('Trial', 'Arm'))

# Trial on a quantitative x axis, chosen arm on an ordinal y axis.
chart = alt.Chart(df).mark_point()
chart.encode(x='Trial:Q', y='Arm:O')

In [None]:
# Evaluate theta for the four arms in the order (0, 0), (0, 1), (1, 0), (1, 1).
f1_grid = np.array([0, 0, 1, 1])
f2_grid = np.array([0, 1, 0, 1])
pred = predictive(random.PRNGKey(1), f1=f1_grid, f2=f2_grid)['theta']

# Summary statistics with one column per arm.
pd.DataFrame(pred).describe()
# thetas' medians: 0.026, 0.083, 0.041, 0.124
# NOTE(review): the medians above look like values noted from an earlier run; confirm.