In [43]:
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
import numpy as np
import copy
from jax import random
from numpyro.infer import MCMC, NUTS, Predictive

def invlogit(x):
    """Inverse-logit (logistic) transform: map log-odds `x` to a probability.

    Works element-wise on scalars and arrays via `jnp.exp`.
    """
    return 1 / (1 + jnp.exp(-x))


class Environment:
    """Simulated bandit environment holding the true per-arm success rates."""

    def __init__(self, thetas: np.array):
        """Store the table of success probabilities.

        Args:
            thetas: array indexed by an arm tuple; `react` compares a uniform
                draw against `thetas[arm]` to decide the reward.
        """
        # Kept by reference (no copy is made).
        self.thetas = thetas


class State:
    """Reward history for every arm of a two-factor bandit."""

    def __init__(self, n_arms: (int, int)):
        """Create an empty reward history for an n_arms[0] x n_arms[1] grid.

        Args:
            n_arms: number of levels of the first and of the second factor,
                in that order (e.g. the shape of the true-probability table).
        """
        # rewards maps (first-factor level, second-factor level) -> list of
        # observed rewards. Arms are enumerated in row-major order so the
        # first key component always ranges over n_arms[0] and the second
        # over n_arms[1], including for non-square grids. Each arm gets its
        # own independent list.
        self.rewards = {
            (i, j): [] for i, j in np.ndindex(n_arms[0], n_arms[1])
        }


def model(f1, f2, obs=None):
    """Additive logistic model of the success probability of each pull.

    The log-odds are an intercept plus one effect per level of each factor;
    every coefficient has a standard-normal prior, and each factor has
    exactly two levels.

    Args:
        f1: integer level of the first factor for each observation.
        f2: integer level of the second factor for each observation.
        obs: observed 0/1 outcomes passed to the 'o' site, or None.

    Sites: 'b0', 'b1', 'b2' (coefficients), 'y' (log-odds),
    'theta' (success probability), 'o' (outcome).
    """
    intercept = numpyro.sample('b0', dist.Normal(0, 1))
    f1_effects = numpyro.sample('b1', dist.Normal(0, 1).expand((2,)))
    f2_effects = numpyro.sample('b2', dist.Normal(0, 1).expand((2,)))

    # Look up the effect of each observation's level, then map to (0, 1).
    log_odds = numpyro.deterministic(
        'y', intercept + f1_effects[f1] + f2_effects[f2])
    prob = numpyro.deterministic('theta', invlogit(log_odds))

    numpyro.sample('o', dist.Bernoulli(prob), obs=obs)


def infer(state: State) -> Predictive:
    """Fit `model` to every reward recorded in `state`.

    Args:
        state: object whose `rewards` dict maps an arm (i, j) to the list of
            rewards observed for that arm.

    Returns:
        A `Predictive` built from the MCMC samples that returns only the
        'theta' site.
    """
    # Flatten the per-arm histories into one row per pull: the arm's two
    # factor levels are repeated once for each recorded reward.
    f1_parts, f2_parts, outcome_parts = [], [], []
    for (i, j), arm_rewards in state.rewards.items():
        n_pulls = len(arm_rewards)
        f1_parts.append(np.repeat(i, n_pulls))
        f2_parts.append(np.repeat(j, n_pulls))
        outcome_parts.append(arm_rewards)

    f1_obs = np.hstack(f1_parts)
    f2_obs = np.hstack(f2_parts)
    outcomes = np.hstack(outcome_parts)

    sampler = MCMC(
        NUTS(model=model),
        num_warmup=500,
        num_samples=500,
        progress_bar=False,
    )
    # Fixed key: the sampler's own randomness is identical on every call.
    sampler.run(random.PRNGKey(0), f1=f1_obs, f2=f2_obs, obs=outcomes)

    posterior = sampler.get_samples()
    return Predictive(model, posterior, return_sites=['theta'])


class ContextualThompsonSampling:
    """Thompson-sampling arm selection for the two-factor bandit."""

    @classmethod
    def get_arm(cls, state: State, predictive: Predictive) -> (int, int):
        """Choose the arm that is best under one random posterior draw.

        Args:
            state: object whose `rewards` dict keys are the (f1, f2) arms.
            predictive: callable invoked as `predictive(key, f1=..., f2=...)`
                that returns a mapping whose 'theta' entry has one row per
                posterior draw and one column per arm.

        Returns:
            The (f1, f2) pair of the arm with the largest theta in the
            selected draw.
        """
        f1f2 = np.array(list(state.rewards.keys()))
        f1, f2 = f1f2[:, 0], f1f2[:, 1]

        pred = predictive(random.PRNGKey(1), f1=f1, f2=f2)
        thetas = pred['theta']

        # np.random.randint excludes its upper bound, so passing the number
        # of draws gives a valid row index in [0, num_sample - 1].
        num_sample = thetas.shape[0]
        sidx = np.random.randint(0, num_sample)
        draw = thetas[sidx]
        idx = np.argmax(draw)

        print(f"Thetas: {draw}")

        # Columns of theta follow the key order used to build f1/f2 above.
        arms = list(zip(f1, f2))
        return arms[idx]


def react(env: Environment, arm: (int, int)) -> float:
    """Simulate one Bernoulli pull of `arm`.

    Args:
        env: object whose `thetas[arm]` is the arm's success probability.
        arm: index tuple into `env.thetas`.

    Returns:
        1.0 on success, 0.0 on failure.
    """
    success = np.random.random() < env.thetas[arm]
    # Both outcomes are floats, matching the declared return type.
    return 1.0 if success else 0.0


def update(state: State, arm: (int, int), reward: float) -> State:
    """Return a new state with `reward` appended to the history of `arm`.

    The input `state` is left untouched: the result is a deep copy, so every
    previously returned state keeps its own snapshot of the reward lists.
    """
    successor = copy.deepcopy(state)
    arm_history = successor.rewards[arm]
    arm_history.append(reward)
    return successor


In [44]:
# True success probability of each (first-factor, second-factor) arm.
thetas = np.array([[0.05, 0.08],
                   [0.05, 0.12]])

env = Environment(thetas)
state = State(thetas.shape)
states = [state]   # snapshot of the state before each trial, plus the final one
rewards = []       # reward obtained on each trial, in order

num_trials = 50

for i in range(num_trials):
    # The posterior is refit at the top of every trial on everything observed
    # so far, so no separate fit is needed before the loop.
    predictive = infer(state)
    arm = ContextualThompsonSampling.get_arm(state, predictive)
    reward = react(env, arm)

    rewards.append(reward)

    # update() returns a fresh copy, so earlier snapshots stay unchanged.
    state = update(state, arm, reward)
    states.append(state)

    if i % 10 == 0:
        print(f"Iteration {i}")
        print(arm)

print("Done")

Thetas: [0.45756787 0.20883837 0.6165215  0.3347021 ]
Iteration 0
(1, 0)
Thetas: [0.8020169  0.73338985 0.48370364 0.38882196]
Thetas: [0.03643825 0.0975935  0.08815541 0.21659763]
Thetas: [0.15478359 0.60812896 0.01976643 0.14594238]
Thetas: [0.58660907 0.808774   0.02170291 0.0620202 ]
Thetas: [0.09176698 0.12421461 0.3605927  0.4418507 ]
Thetas: [0.0967627  0.19786192 0.0789149  0.16476813]
Thetas: [0.21657933 0.05197893 0.09266794 0.01985371]
Thetas: [0.07968867 0.07043399 0.10047285 0.08903791]
Thetas: [0.04550152 0.2974365  0.21025273 0.70276564]
Thetas: [0.0130841  0.17081854 0.00990127 0.13449442]
Iteration 10
(0, 1)
Thetas: [0.06257033 0.3004934  0.0346259  0.18754938]
Thetas: [0.12392066 0.02383256 0.15533337 0.03076484]
Thetas: [0.03695567 0.03853837 0.0840483  0.08746465]
Thetas: [0.07400339 0.24444419 0.05131935 0.17965136]
Thetas: [0.07459759 0.15434635 0.18801771 0.34395182]
Thetas: [0.3061482  0.3189017  0.12130103 0.12777218]
Thetas: [0.07920347 0.1967578  0.01592388 0