In [1]:
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
import numpy as np
import copy
from jax import random
from numpyro.infer import MCMC, NUTS, Predictive

invlogit = lambda x: 1 / (1 + jnp.exp(-x))


class Environment:
    def __init__(self, thetas: np.array):
        self.thetas = thetas


class State:
    def __init__(self, n_arms: (int, int)):
        xv, yv = np.meshgrid(range(n_arms[0]), range(n_arms[1]))
        dim = n_arms[0] * n_arms[1]
        f1, f2 = yv.reshape((dim,)), xv.reshape((dim,))
        self.rewards = {(i, j): [] for i, j in zip(f1, f2)}


def model(f1, f2, obs=None):
    b0 = numpyro.sample('b0', dist.Normal(0, 1))
    b1 = numpyro.sample('b1', dist.Normal(0, 1).expand((2,)))
    b2 = numpyro.sample('b2', dist.Normal(0, 1).expand((2,)))

    y = numpyro.deterministic('y', b0 + b1[f1] + b2[f2])
    theta = numpyro.deterministic('theta', invlogit(y))
    o = numpyro.sample('o', dist.Bernoulli(theta), obs=obs)


def infer(state: State) -> Predictive:
    # Obtain the number of successes for each combination
    f1s = []
    f2s = []
    success = []
    for i, j in state.rewards:
        r = state.rewards[i, j]
        f1s.append(np.repeat(i, len(r)))
        f2s.append(np.repeat(j, len(r)))
        success.append(r)

    f1s = np.hstack(f1s)
    f2s = np.hstack(f2s)
    success = np.hstack(success)

    mcmc = MCMC(NUTS(model=model),
                num_warmup=500,
                num_samples=500,
                progress_bar=False)
    mcmc.run(random.PRNGKey(0), f1=f1s, f2=f2s, obs=success)
    return Predictive(model, mcmc.get_samples(), return_sites=['theta'])


class ContextualThompsonSampling:
    @classmethod
    def get_arm(cls, state: State, predictive: Predictive) -> (int, int):
        f1f2 = np.array(list(state.rewards.keys()))
        f1, f2 = f1f2[:,0], f1f2[:,1]

        pred = predictive(random.PRNGKey(1), f1=f1, f2=f2)

        num_sample = pred['theta'].shape[0]
        sidx = np.random.randint(0, num_sample + 1)
        idx = np.argmax(pred['theta'][sidx])

        print(f"Thetas: {pred['theta'][sidx]}")

        arms = list(zip(f1, f2))
        return arms[idx]


def react(env: Environment, arm: (int, int)) -> float:
    if np.random.random() < env.thetas[arm]:
        return 1.0
    else:
        return 0


def update(state: State, arm: (int, int), reward: float) -> State:
    s = copy.deepcopy(state)
    s.rewards[arm].append(reward)
    return s


In [15]:
thetas = np.array([[0.05, 0.08],
                   [0.05, 0.12]])

env = Environment(thetas)
state = State(thetas.shape)
predictive = infer(state)
states = [state]
arms = []

num_trials = 100

for i in range(num_trials):
    predictive = infer(state)
    arm = ContextualThompsonSampling.get_arm(state, predictive)
    arms.append(arm)

    reward = react(env, arm)

    state = update(state, arm, reward)
    states.append(state)

    if i % 10 == 0:
        print(f"Iteration {i}")
        print(arm)

print("Done")

Thetas: [0.9325334 0.8801061 0.7110399 0.5665043]
Iteration 0
(0, 0)
Thetas: [0.30019128 0.06904013 0.1606878  0.0320383 ]
Thetas: [0.18947504 0.11568389 0.4623024  0.3248425 ]
Thetas: [0.52698886 0.12759778 0.71693903 0.24953386]
Thetas: [0.6599344  0.7032388  0.6542885  0.69798285]
Thetas: [0.24007818 0.01231242 0.31800884 0.01806685]
Thetas: [0.01321099 0.0137184  0.00978718 0.01016444]
Thetas: [0.19562158 0.11958963 0.14790834 0.08838337]
Thetas: [0.16256325 0.05029975 0.38378498 0.14524643]
Thetas: [0.01289794 0.00739373 0.11342165 0.06797287]
Thetas: [0.05647401 0.0410629  0.16427289 0.12328855]
Iteration 10
(1, 0)
Thetas: [0.04669187 0.06634442 0.03623958 0.05173151]
Thetas: [0.09772017 0.2143871  0.20124239 0.38831237]
Thetas: [0.03196008 0.08344987 0.16918679 0.35962626]
Thetas: [0.03952746 0.10640036 0.13470155 0.31053227]
Thetas: [0.01663862 0.02964344 0.13855177 0.22503816]
Thetas: [0.07570603 0.03312276 0.00827    0.00347564]
Thetas: [0.26084852 0.2579002  0.1898985  0.187

In [16]:
rs = states[-1].rewards
print(f"Num (0, 0): {len(rs[0, 0])}")
print(f"Num (0, 1): {len(rs[0, 1])}")
print(f"Num (1, 0): {len(rs[1, 0])}")
print(f"Num (1, 1): {len(rs[1, 1])}")

Num (0, 0): 12
Num (0, 1): 17
Num (1, 0): 18
Num (1, 1): 53


In [38]:
!pip install altair

Collecting vega_datasets
  Using cached vega_datasets-0.9.0-py3-none-any.whl (210 kB)
Installing collected packages: vega_datasets
Successfully installed vega_datasets-0.9.0


In [47]:
import altair as alt
import pandas as pd

points = []
for i, arm in enumerate(arms, 1):
    if arm == (0, 0):
        points.append((i, "(0, 0)"))
    elif arm == (0, 1):
        points.append((i, "(0, 1)"))
    elif arm == (1, 0):
        points.append((i, "(1, 0)"))
    elif arm == (1, 1):
        points.append((i, "(1, 1)"))

df = pd.DataFrame(points, columns=('Trial', 'Arm'))
alt.Chart(df).mark_point().encode(
    x='Trial:Q',
    y='Arm:O'
)

In [55]:
pred = predictive(random.PRNGKey(1), f1=np.array([0, 0, 1, 1]), f2=np.array([0, 1, 0, 1]))['theta']

pd.DataFrame(pred).describe()
# thetas' medians: 0.026, 0.083, 0.041, 0.124

Unnamed: 0,0,1,2,3
count,500.0,500.0,500.0,500.0
mean,0.036674,0.095659,0.050601,0.127215
std,0.035155,0.059443,0.037836,0.043267
min,0.000778,0.006732,0.003853,0.032018
25%,0.012971,0.050888,0.022766,0.096846
50%,0.02588,0.08278,0.041129,0.124309
75%,0.046706,0.129465,0.066263,0.153934
max,0.221464,0.401146,0.215675,0.311705
