In [1]:
import numpy as np
from scipy.stats import multinomial

In [2]:
activities = ['cycling', 'picnic', 'climbing', 'movie']

# Transition matrices: row i is the distribution over the next activity when the
# current activity (action) is activities[i]; columns follow the same order.
T_rain = np.vstack([
    [0.15, 0.05, 0.40, 0.40],  # from cycling
    [0.05, 0.15, 0.40, 0.40],  # from picnic
    [0.05, 0.05, 0.80, 0.10],  # from climbing
    [0.05, 0.05, 0.10, 0.80],  # from movie
])

T_sunny = np.vstack([
    [0.70, 0.20, 0.05, 0.05],  # from cycling
    [0.20, 0.70, 0.05, 0.05],  # from picnic
    [0.40, 0.40, 0.10, 0.10],  # from climbing
    [0.40, 0.40, 0.10, 0.10],  # from movie
])

# Every row must be a probability distribution over the next activity.
assert all(np.allclose(T.sum(axis=1), 1) for T in (T_rain, T_sunny))

In [3]:
# Probability that the user responds positively (r = 1) to each next activity,
# indexed as user_utility[weather, activity]; weather 0 = rainy, 1 = sunny.
user_utility = np.array(
    [
        [0.1, 0.1, 0.7, 0.8],  # rainy
        [0.8, 0.8, 0.2, 0.1],  # sunny
    ],
    dtype=float,
)

In [4]:
# Create the synthetic data: for every current activity (action) a, draw
# N_PER_ACTION next activities from row a of that weather's transition matrix.
N_PER_ACTION = 100


def simulate_transitions(T: np.ndarray, n_per_action: int = N_PER_ACTION):
    """Sample next activities for every action (row) of transition matrix `T`.

    Rows are sampled one after another with `scipy.stats.multinomial` (default
    random state), so the random stream is the same as drawing each row
    separately in row order.

    Args:
        T: transition matrix; row k is the next-activity distribution for action k.
        n_per_action: number of draws per action.

    Returns:
        (next_activities, actions): two int arrays of length n_per_action * len(T).
        Block k (of size n_per_action) holds draws from T[k]; `actions` records k.
    """
    draws = [multinomial.rvs(1, row, size=n_per_action).argmax(axis=1) for row in T]
    actions_for_draws = np.repeat(np.arange(len(T)), n_per_action)
    return np.concatenate(draws).astype(int), actions_for_draws


# Rainy days (weather = 0)
x_rain, actions_rain = simulate_transitions(T_rain)
weather_rain = np.zeros(len(x_rain), dtype=int)

# Sunny days (weather = 1)
x_sunny, actions_sunny = simulate_transitions(T_sunny)
weather_sunny = np.ones(len(x_sunny), dtype=int)

x = np.concatenate([x_rain, x_sunny]).astype(int)
weather = np.concatenate([weather_rain, weather_sunny]).astype(int)
actions = np.concatenate([actions_rain, actions_sunny]).astype(int)

In [5]:
def response(U: np.ndarray, weather: int, activity: int) -> int:
    """Simulate a binary user response to `activity` under `weather`.

    Returns 1 with probability U[weather, activity] and 0 otherwise, using one
    uniform draw from numpy's global RNG.
    """
    accept_threshold = U[weather, activity]
    return int(np.random.rand() < accept_threshold)


# Simulate one binary response per simulated next activity (rainy block first, then sunny).
r_rain, r_sunny = (
    np.array([response(user_utility, weather_code, activity) for activity in next_acts])
    for weather_code, next_acts in ((0, x_rain), (1, x_sunny))
)
r = np.concatenate([r_rain, r_sunny])

# 1. Batch Estimation

In [6]:
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from jax import random
from numpyro.infer import MCMC, NUTS, Predictive

rng_key = random.PRNGKey(12345)


def model(weather, action, x=None, r=None):
    """Joint model of activity transitions and user responses.

    Latent sites:
        p:       shape (2, 4, 4); p[w, a] is a probability vector over the next
                 activity given weather w and chosen action a.
        utility: shape (2, 4); utility[w, k] is the probability of a positive
                 response (r = 1) to next activity k under weather w.

    Args:
        weather: per-observation weather code (0 = rainy, 1 = sunny).
        action: per-observation chosen action, same length as `weather`.
        x: observed next activity per observation, or None to sample it
           (as `Predictive` does).
        r: observed 0/1 response per observation, or None to sample it.
    """
    weather = jnp.asarray(weather)
    action = jnp.asarray(action)
    n_weather, n_activities = 2, 4

    # NOTE(review): uniform Dirichlet/Beta priors are assumed; the original
    # priors are not visible in this notebook — adjust if they differed.
    with numpyro.plate('weather_plate', n_weather, dim=-2):
        with numpyro.plate('activity_plate', n_activities, dim=-1):
            p = numpyro.sample('p', dist.Dirichlet(jnp.ones(n_activities)))
            utility = numpyro.sample('utility', dist.Beta(1.0, 1.0))

    with numpyro.plate('observations', weather.shape[0]):
        # The next activity depends on (weather, action); the response depends on
        # the weather and the activity that actually followed.
        x = numpyro.sample('x', dist.Categorical(probs=p[weather, action]), obs=x)
        numpyro.sample('r', dist.Bernoulli(probs=utility[weather, x]), obs=r)


# Fit the model to the full batch of simulated (weather, action, x, r) data with NUTS.
mcmc_config = dict(num_warmup=1000, num_samples=2000, thinning=1)
kernel = NUTS(model)
mcmc = MCMC(kernel, **mcmc_config)

rng_key, rng_key_ = random.split(rng_key, 2)
mcmc.run(rng_key_, weather=weather, action=actions, x=x, r=r)
mcmc.print_summary()

sample: 100%|██████████| 3000/3000 [00:01<00:00, 1598.22it/s, 7 steps of size 4.71e-01. acc. prob=0.90]



                  mean       std    median      5.0%     95.0%     n_eff     r_hat
    p[0,0,0]      0.16      0.04      0.16      0.11      0.22   3257.30      1.00
    p[0,0,1]      0.08      0.03      0.07      0.04      0.12   3310.87      1.00
    p[0,0,2]      0.34      0.05      0.34      0.26      0.41   3360.48      1.00
    p[0,0,3]      0.42      0.05      0.42      0.34      0.50   3847.00      1.00
    p[0,1,0]      0.05      0.02      0.04      0.01      0.08   2895.75      1.00
    p[0,1,1]      0.15      0.03      0.15      0.09      0.21   3524.14      1.00
    p[0,1,2]      0.35      0.05      0.35      0.28      0.43   3450.34      1.00
    p[0,1,3]      0.45      0.05      0.45      0.37      0.53   3375.54      1.00
    p[0,2,0]      0.05      0.02      0.05      0.02      0.08   2846.07      1.00
    p[0,2,1]      0.05      0.02      0.05      0.02      0.08   3231.80      1.00
    p[0,2,2]      0.80      0.04      0.80      0.73      0.86   3121.99      1.00
   

In [154]:
# Posterior predictive over the next activity `x`, also returning the `utility` draws.
predictive = Predictive(
    model,
    posterior_samples=mcmc.get_samples(),
    return_sites=['x', 'utility'],
)


def expected_utility(rng_key, predictive, weather: int, action: int):
    """Posterior-predictive expected utility of taking `action` under `weather`.

    EU = sum_k P(next activity = k | weather, action) * mean(utility[weather, k]),
    where the transition probabilities are the empirical frequencies of the
    predictive draws of `x` and the utilities are posterior means.

    Args:
        rng_key: JAX PRNG key forwarded to `predictive`.
        predictive: callable like `numpyro.infer.Predictive`, returning a dict
            with 'x' (next-activity draws, one per posterior sample) and
            'utility' (indexed as [sample, weather, activity]).
        weather: weather code used to index `utility`.
        action: chosen (current) activity.

    Returns:
        Scalar JAX array with the expected utility.
    """
    pred = predictive(rng_key, weather=[weather], action=[action], x=None, r=None)
    utility_draws = pred['utility']
    n_activities = utility_draws.shape[-1]

    # Count each activity index exactly. jnp.histogram(bins=4) spans only the
    # observed min..max, so a missing extreme category shifts every bin.
    next_activity = jnp.ravel(pred['x']).astype(int)
    # Normalise by the actual number of draws rather than a hard-coded 2000.
    trans_prob = jnp.bincount(next_activity, length=n_activities) / next_activity.shape[0]

    util_est = jnp.mean(utility_draws[:, weather, :], axis=0)
    return jnp.dot(trans_prob, util_est)


def list_utilities(rng_key, weather: int):
    """Return the expected utility of each action 0..3 under `weather`, in action order.

    Uses the module-level `predictive` and passes the same `rng_key` to every action.
    """
    utilities = []
    for candidate_action in range(4):
        utilities.append(expected_utility(rng_key, predictive, weather, candidate_action))
    return utilities


# For each weather condition, report every action's expected utility and
# recommend the one with the highest value.
WEATHER_LABELS = ('rainy', 'sunny')  # index = weather code (0 = rainy, 1 = sunny)

for weather_code, weather_label in enumerate(WEATHER_LABELS):
    # Fresh key per weather condition, split in the same order as before.
    rng_key, rng_key_ = random.split(rng_key, 2)
    expected_utilities = list_utilities(rng_key_, weather_code)
    for activity_name, eu in zip(activities, expected_utilities):
        print(f'EU for {activity_name} on a {weather_label} day: {eu}')

    activity_to_recommend = activities[np.argmax(expected_utilities)]
    print(f'Activity to recommend on a {weather_label} day: {activity_to_recommend}')

EU for cycling on a rainy day: 0.6251339912414551
EU for picnic on a rainy day: 0.6124654412269592
EU for climbing on a rainy day: 0.675921618938446
EU for movie on a rainy day: 0.7231666445732117
Activity to recommend on a rainy day: movie
EU for cycling on a sunny day: 0.728053867816925
EU for picnic on a sunny day: 0.7407978773117065
EU for climbing on a sunny day: 0.6686433553695679
EU for movie on a sunny day: 0.6209140419960022
Activity to recommend on a sunny day: picnic


# 2. Online Estimation

In [7]:
# Random weather generator
def weather_generator():
    """Draw a weather state with equal probability: 0 (rainy) or 1 (sunny).

    Consumes one uniform draw from numpy's global RNG per call.
    """
    is_sunny = np.random.rand() >= 0.5
    return 1 if is_sunny else 0