<a href="https://colab.research.google.com/github/JiujiaZ/restless_bandit_basics/blob/main/Whittle_index.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# A toy example of the Whittle index.

In [1]:
import numpy as np

In [2]:
# Toy problem used to compare the approximated Whittle index (WI) with the
# exact one. The example is taken from:
#   https://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=8945748

# Reward of each state; the same vector is used for both actions.
R = np.array([-1, 0, 0, 1])

# Exact WI of states 0..3, kept as the reference values.
WI_exact = [-0.5, 0.5, 1, -1]

# Row s of p_1 is the next-state distribution p(s' | s, a = 1).
p_1 = np.array([
    [0.5, 0.5, 0.0, 0.0],
    [0.0, 0.5, 0.5, 0.0],
    [0.0, 0.0, 0.5, 0.5],
    [0.5, 0.0, 0.0, 0.5],
])

# Row s of p_0 is the next-state distribution p(s' | s, a = 0).
p_0 = np.array([
    [0.5, 0.0, 0.0, 0.5],
    [0.5, 0.5, 0.0, 0.0],
    [0.0, 0.5, 0.5, 0.0],
    [0.0, 0.0, 0.5, 0.5],
])

# Full kernel indexed as transitions[a, s, s'].
transitions = np.stack((p_0, p_1))

# The two leading axes give the number of actions and the number of states.
n_actions, n_states = transitions.shape[:2]

In [3]:
def value_iteration(transitions, R, lamb_val, gamma, epsilon=1e-2):

    """
    Value iteration for an MDP in which taking action a is charged a * lamb_val.

    Each sweep computes, state by state,
        Q[s, a] = -a * lamb_val + R[s] + gamma * sum_s' p[a, s, s'] * V[s']
    and immediately overwrites V[s] with max_a Q[s, a], so states later in the
    same sweep already see the updated values. V is initialised with uniform
    random numbers from np.random.rand, so the result is only reproducible up
    to the tolerance unless the NumPy global seed is fixed.

        @param transitions: transition matrix p[a, s, s']
        @param R:           reward for all states R[s]
        @param lamb_val:    Lagrangian multiplier associated with formulation of WI
        @param gamma:       discount factor for reward
        @param epsilon:     tolerance to terminate value iteration; sweeps stop
                            once the largest change of V in a sweep is below it

        @return Q_func:     Q-values of the last sweep for each state and action [s, a]

    """

    n_actions, n_states = transitions.shape[:-1]
    value_func = np.random.rand(n_states)

    # Charge applied to each action index: 0 for a = 0, lamb_val for a = 1, ...
    action_cost = lamb_val * np.arange(n_actions)

    # Always run at least one sweep so Q_func is defined even when epsilon is
    # larger than any possible first-sweep change.
    while True:
        orig_value_func = np.copy(value_func)

        Q_func = np.zeros((n_states, n_actions))
        for s in range(n_states):
            # transitions[:, s, :] @ value_func is the expected next value of
            # every action taken from state s.
            expected_next = transitions[:, s, :] @ value_func
            Q_func[s, :] = R[s] - action_cost + gamma * expected_next
            value_func[s] = np.max(Q_func[s, :])

        delta = np.abs(orig_value_func - value_func)
        # Written as "not >=" so a NaN delta terminates the loop rather than
        # spinning forever.
        if not (np.max(delta) >= epsilon):
            break

    return Q_func


def whittle_index(transitions, state, R, gamma, lb, ub, subsidy_break, epsilon=1e-4):
    """
    Whittle index for a specified state using binary search: https://arxiv.org/pdf/2205.15372.pdf

    At each step the midpoint of [lb, ub] is used as the charge lamb_val for
    action 1, value_iteration is solved for that charge, and the interval is
    halved according to which action has the larger Q-value in `state`
    (ties go to action 0 because np.argmax returns the first maximum).

        @param transitions:     transition matrix p[a, s, s']
        @param state:           a single specified state in [S]
        @param R:               reward for all states R[s]
        @param gamma:           discount factor for reward
        @param lb, ub:          initial lower / upper bound of WI
        @param subsidy_break:   lower tolerance of WI; once ub drops below it
                                the search stops early and returns -10
                                (if returned, decrease lb)
        @param epsilon:         tolerance to terminate binary search

        @return subsidy:        approximated whittle index for specified state,
                                or the sentinel -10 on early break

        @raise ValueError:      if the greedy action is neither 0 nor 1
    """

    while abs(ub - lb) > epsilon:
        # need to adjust initial lb
        if ub < subsidy_break:
            return -10

        lamb_val = (lb + ub) / 2
        Q_func = value_iteration(transitions, R, lamb_val, gamma)

        # binary search on the greedy action at the queried state:
        action = np.argmax(Q_func[state, :])

        # Q(s, 0) >= Q(s, 1) at lamb_val: the index lies in the lower half
        if action == 0:
            ub = lamb_val
        # Q(s, 0) < Q(s, 1) at lamb_val: the index lies in the upper half
        elif action == 1:
            lb = lamb_val
        else:
            raise ValueError(f'action not binary: {action}')

    subsidy = (ub + lb) / 2
    return subsidy

## check against exact WI

In [4]:
# check with exact WI: approximate the index of every state by binary search
WI_approx = []
for state in range(n_states):
    WI_approx.append(whittle_index(transitions, state, R, gamma=0.99, lb=-1, ub=1, subsidy_break=-1))
print(f'exact WI {WI_exact}, approximated WI {WI_approx}')

# Ranking states by decreasing WI, we prefer to pull the arm in state 2, then 1, 0, 3.
# eg: assume we have 4 arms, each in a distinct state.
#     for a budget constraint of 2, we will pull the arms in state 2 and in state 1

exact WI [-0.5, 0.5, 1, -1], approximated WI [-0.497467041015625, 0.492523193359375, 0.997589111328125, -0.992401123046875]


## check against random policy

In [5]:
# N arms, each represented by an MDP; one-step simulation
def Arms_step(init_states, actions, transitions, rewards=None):
    """
    Advance every arm by one random transition and score the pulled arms.

        @param init_states:    current state index of each arm, one entry per arm
        @param actions:        vector with arms that should be pulled as 1, otherwise 0
        @param transitions:    transition matrix p[a, s, s']
        @param rewards:        reward for all states; when None, the
                               module-level vector R is used

        @return current_states: next state of each arm, sampled from p[a, s, :]
        @return reward:         sum of rewards[next state] over the pulled arms
    """

    if rewards is None:
        # Fall back to the notebook-level reward vector.
        rewards = R
    rewards = np.asarray(rewards)
    actions = np.asarray(actions)

    n_states = transitions.shape[1]
    current_states = np.zeros_like(init_states)

    for i, (s, a) in enumerate(zip(init_states, actions)):
        # sample through p[s'| s, a]; a scalar draw is used because storing a
        # size-1 array into a single element is deprecated in recent NumPy.
        current_states[i] = np.random.choice(n_states, p=transitions[a, s, :])

    # reward is collected at the sampled next state, from pulled arms only:
    reward = (rewards[current_states] * actions).sum()

    return current_states, reward



In [6]:
# state of each of the 4 arms (arm i starts in init_states[i])
init_states = np.array([2, 0, 1, 3])

# WI policy tells us to pull arms 0 and 2 (the arms in states 2 and 1):
WI_actions = np.array([1, 0, 1, 0])
WI_currant_state, WI_reward = Arms_step(init_states, WI_actions, transitions)

# compare with random selection of 2 arms; indices are drawn over the arms,
# not over the states (the two counts only coincide in this toy example):
indx = np.random.choice(len(init_states), size=2, replace=False)
random_actions = np.zeros_like(WI_actions)
random_actions[indx] = 1
random_currant_state, random_reward = Arms_step(init_states, random_actions, transitions)

print(f'WI actions {WI_actions}, random actions {random_actions}')
print(f'WI reward {WI_reward}, random reward {random_reward}')


WI actions [1 0 1 0], random actions [1 0 1 0]
WI reward 1, random reward 1


In [7]:
# multiple simulations, average one-step reward over T independent runs
T = 1000

WI_reward = 0
random_reward = 0

for t in range(T):
    _, reward = Arms_step(init_states, WI_actions, transitions)
    WI_reward += reward

    # compare with random selection of 2 arms (indices range over the arms):
    indx = np.random.choice(len(init_states), size=2, replace=False)
    random_actions = np.zeros_like(WI_actions)
    random_actions[indx] = 1
    _, reward = Arms_step(init_states, random_actions, transitions)
    random_reward += reward

# normalise by the number of runs so changing T keeps the averages correct
WI_reward /= T
random_reward /= T

print(f'WI reward {WI_reward}, random reward {random_reward}')


WI reward 0.498, random reward 0.012
