# Introduction to Reinforcement Learning

## Exercise: introducing state (contextual bandits)

In the _contextual bandit_ problem, there is a different reward distribution over the
actions for each state.  For simplicity, the number of states equals
the number of arms, but in general the state space is often much larger than the
action space. Here, with $n$ states and $n$ arms, each state has its own reward
distribution over the $n$ actions, giving $n \times n$ state-action pairs in total.

Instead of storing the rewards for each state-action pair, we use a neural
network to learn the relation between state-action pairs and the rewards they yield.

In [None]:
from jax import random, vmap
import numpy as np
from flax import nnx
import optax
import jax.numpy as jnp
import matplotlib.pyplot as plt

In [None]:
def softmax(av, tau=1.12):
    """Return the temperature-scaled softmax of a set of action values.

    Args:
        av: Non-empty array-like of action values.
        tau: Temperature; the values are divided by it before exponentiation,
            so larger values give a flatter distribution.

    Returns:
        Array with the same shape as ``av`` whose entries sum to 1
        (normalisation is over all elements).
    """
    scaled = jnp.asarray(av) / tau
    # Softmax is invariant to adding a constant to every input, so shifting by
    # the maximum leaves the result unchanged while keeping exp() from
    # overflowing to inf (which would yield NaN probabilities).
    exps = jnp.exp(scaled - jnp.max(scaled))
    return exps / jnp.sum(exps)

def one_hot(N, pos, val=1):
    """Build a length-``N`` vector of zeros with ``val`` written at index ``pos``.

    Args:
        N: Length of the resulting vector.
        pos: Index that receives ``val``.
        val: Value stored at ``pos`` (defaults to 1).

    Returns:
        A 1-D JAX array of length ``N``.
    """
    # JAX arrays are immutable: .at[...].set(...) returns an updated copy.
    return jnp.zeros(N).at[pos].set(val)

def running_mean(x, N=50):
    """Return the moving average of ``x`` over windows of ``N`` samples.

    Entry ``i`` of the result is the mean of ``x[i:i+N]``. The result has
    ``len(x) - N`` entries, so the window starting at ``len(x) - N`` is not
    included.

    Args:
        x: 1-D array-like of samples.
        N: Window length; must be at least 1.

    Returns:
        1-D float NumPy array of length ``max(len(x) - N, 0)``.

    Raises:
        ValueError: If ``N`` is smaller than 1.
    """
    if N < 1:
        raise ValueError("N must be at least 1")
    x = np.asarray(x)
    c = x.shape[0] - N
    if c <= 0:
        # Not enough samples for a single window: nothing to average.
        return np.zeros(0)
    # Sliding-window sums computed in one vectorised call; "valid" yields
    # c + 1 windows, of which only the first c are kept.
    window_sums = np.convolve(x, np.ones(N), mode="valid")[:c]
    return window_sums / N

1. Complete the implementation of the `ContextBandit` environment.

In [None]:
class ContextBandit:
    """Contextual bandit environment backed by a matrix of reward probabilities.

    ``bandit_matrix[arm, state]`` is the per-trial success probability used
    when ``arm`` is pulled in ``state``. Every method that consumes randomness
    takes a PRNG key and returns a fresh one for the caller to use next.
    """

    def __init__(self, seed=42, arms=10):
        """Create the environment and draw its reward probabilities from ``seed``."""
        self.arms = arms
        self.init_distribution(random.key(seed), arms)

    def init_distribution(self, key, arms):
        """Fill ``bandit_matrix`` with an ``(arms, arms)`` array of uniform samples."""
        self.bandit_matrix = random.uniform(key, shape=(arms, arms))

    def reward(self, prob, key):
        """Count successes over ``self.arms`` independent trials of probability ``prob``.

        Returns:
            Tuple ``(number of successes, next key)``.
        """
        keys = random.split(key, self.arms + 1)
        # First key is handed back to the caller; the rest drive one trial each.
        next_key, trial_keys = keys[0], keys[1:]

        def trial(k):
            return random.uniform(k) < prob

        successes = vmap(trial)(trial_keys)
        return successes.sum(), next_key

    def update_state(self, key):
        """Sample a state uniformly from ``0 .. self.arms - 1``.

        Returns:
            Tuple ``(state, next key)``.
        """
        next_key, draw_key = random.split(key)
        new_state = random.randint(draw_key, shape=(), minval=0, maxval=self.arms)
        return new_state, next_key

    def get_reward(self, arm, state, key):
        """Sample the reward for pulling ``arm`` in ``state``."""
        success_prob = self.bandit_matrix[arm, state]
        return self.reward(success_prob, key)

    def choose_arm(self, arm, state, key):
        """Pull ``arm`` in ``state``, then move to a freshly sampled state.

        Returns:
            Tuple ``(reward, next state, next key)``.
        """
        payoff, key = self.get_reward(arm, state, key)
        next_state, key = self.update_state(key)
        return payoff, next_state, key

In [None]:
arms = 10
# D_in: width of the one-hot state input, H: hidden-layer width,
# D_out: one predicted reward per arm. N is not used in the cells shown here
# (presumably a batch size).
N, D_in, H, D_out = 1, arms, 100, arms
# Pass the arm count by keyword: the first positional parameter of
# ContextBandit is `seed`, so a positional `arms` would set the seed and
# leave the number of arms at its default.
env = ContextBandit(arms=arms)

2. Create a two-layer neural network that takes as input a one-hot encoded vector of the state and
   returns the predicted values (rewards) associated with choosing each arm in that state.

In [None]:
class Model(nnx.Module):
    """Two-layer fully connected network mapping a state encoding to per-arm values.

    The hidden width is taken from the module-level constant ``H``.
    """

    def __init__(self, D_in, D_out, rngs: nnx.Rngs):
        """Create the layers.

        Args:
            D_in: Size of the input vector.
            D_out: Size of the output vector.
            rngs: Random number streams used to initialise the layer parameters.
        """
        self.linear1 = nnx.Linear(D_in, H, rngs=rngs)
        self.linear2 = nnx.Linear(H, D_out, rngs=rngs)

    def __call__(self, x):
        """Apply both layers, each followed by a ReLU, and return the result."""
        hidden = nnx.relu(self.linear1(x))
        # The output also goes through a ReLU, so predictions are never negative.
        return nnx.relu(self.linear2(hidden))


3. Define the training loop for the NN model, where at each epoch an action is chosen _probabilistically_
   based on the predicted rewards from the model and the model weights are updated based
   on the _actual_ reward obtained by taking that action.

In [None]:
@nnx.jit
def _train_epoch(model, optimizer, cur_state, key):
    """Run one training step: act in ``cur_state``, observe a reward, update the model.

    Uses the module-level ``env`` and ``arms``.

    Args:
        model: Network that maps a one-hot state to one predicted reward per arm.
        optimizer: Optimizer whose ``update`` is called with the gradients.
        cur_state: Index of the current state.
        key: PRNG key for action sampling and the environment.

    Returns:
        Tuple ``(reward obtained, next state, next key)``.
    """
    state_vec = one_hot(arms, cur_state)

    def loss(model, cur_state, key):
        predicted = model(state_vec)

        # Turn predicted rewards into a sampling distribution over the arms.
        action_probs = softmax(predicted, tau=2.0)

        key, sample_key = random.split(key)
        action = random.choice(key=sample_key, a=arms, p=action_probs)
        observed, next_state, key = env.choose_arm(action, cur_state, key)

        # The target equals the prediction everywhere except the chosen arm,
        # so only that arm's error contributes to the loss.
        target = predicted.at[action].set(observed)
        mse = optax.losses.squared_error(predicted, target).mean()
        return mse, (observed, next_state, key)

    grad_fn = nnx.grad(loss, has_aux=True)
    grads, step_result = grad_fn(model, cur_state, key)
    optimizer.update(grads)

    return step_result

def train(env, model, optimizer, epochs=5000, seed=6):
    """Run the training loop and collect the reward obtained at every epoch.

    Args:
        env: Environment providing ``update_state(key)``; used here to draw
            the initial state.
        model: Network passed to ``_train_epoch``.
        optimizer: Optimizer passed to ``_train_epoch``.
        epochs: Number of training steps to run.
        seed: Seed of the PRNG key that drives the initial state and every
            subsequent training step.

    Returns:
        NumPy array with one reward per epoch.
    """
    key = random.key(seed)
    cur_state, key = env.update_state(key)

    rewards = []

    for _ in range(epochs):
        # NOTE: _train_epoch interacts with the module-level `env`, not with
        # the `env` argument of this function; keep the two identical.
        cur_reward, cur_state, key = _train_epoch(model, optimizer, cur_state, key)
        rewards.append(cur_reward)

    return np.array(rewards)

4. Train the model for 5000 epochs and plot the running mean (use the auxiliary function
   defined above) of the rewards.

In [None]:
# Build the network; nnx.Rngs(0) seeds the parameter initialisation.
model = Model(D_in, D_out, rngs=nnx.Rngs(0))

lr = 0.01 # learning rate
# Adam optimizer wrapped together with the model it updates.
optimizer = nnx.ModelAndOptimizer(model, optax.adam(lr))

# Train with the default number of epochs, then plot the per-epoch rewards
# smoothed with a 500-sample running mean.
rewards = train(env, model, optimizer)
plt.plot(running_mean(rewards,N=500))
plt.show()