# 🫰💵***Statistically* Rob the casino with** <img src='https://upload.wikimedia.org/wikipedia/commons/8/86/Google_JAX_logo.svg' alt="Environment" width="60" />

<div align="center">
<img src="https://images.unsplash.com/photo-1518895312237-a9e23508077d?auto=format&fit=crop&q=80&w=1784&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D" width=1000>
</div>

## **1) Q-learning for K-armed Bandits**

In [1]:
import jax
import jax.numpy as jnp
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go

from jax import random, lax, jit, vmap, pmap
from functools import partial
from jax_tqdm import loop_tqdm

import sys

sys.path.append("../../")

from src import K_armed_bandits, SimpleBandit, BanditEpsilonGreedy

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
SEED = 1
K = 10

key = random.PRNGKey(SEED)

agent = SimpleBandit()
policy = BanditEpsilonGreedy()
env = K_armed_bandits(K, SEED)
env.render()

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [3]:
def rollout(key:random.PRNGKey, timesteps:int, epsilon:int=0.1):
    @jit
    @loop_tqdm(timesteps)
    def fori_body(i:int, val:tuple):
        q_values, pulls, key, rewards = val
        action, key = policy(key, env.K, q_values, epsilon)
        reward, key = env.get_reward(key, action)
        q_values, pulls = agent(action, reward, pulls, q_values)
        rewards = rewards.at[i].set(reward)

        val = (q_values, pulls, key, rewards)

        return val

    q_values = jnp.zeros(env.K)
    pulls = jnp.zeros(env.K)  
    rewards = jnp.zeros(timesteps)

    val_init = (q_values, pulls, key, rewards)
    val = lax.fori_loop(0, timesteps, fori_body, val_init)

    return val


def parallel_rollout(key:random.PRNGKey, timesteps:int, n_env:int, epsilons:list):
    @jit
    @loop_tqdm(timesteps)
    def fori_body(i:int, val:tuple):
        q_values, pulls, keys, rewards = val
        action, keys = policy.batched_call(keys, env.K, q_values, epsilons)
        reward, keys = env.get_batched_reward(keys, action)
        q_values, pulls = agent.batch_update(action, reward, pulls, q_values)
        rewards = rewards.at[i].set(reward)

        val = (q_values, pulls, keys, rewards)

        return val

    keys = random.split(key, (n_env,))
    q_values = jnp.zeros([env.K, n_env])
    pulls = jnp.zeros([env.K, n_env])  
    rewards = jnp.zeros([timesteps, n_env])

    val_init = (q_values, pulls, keys, rewards)
    val = lax.fori_loop(0, timesteps, fori_body, val_init)

    return val

def multi_run_parallel_rollout(key:random.PRNGKey, timesteps:int, n_env:int, epsilons:list, n_runs:int):
    @jit
    @loop_tqdm(timesteps)
    def fori_body(i:int, val:tuple):
        q_values, pulls, keys, rewards = val
        action, keys = policy.multi_run_batched_call(keys, env.K, q_values, epsilons)
        reward, keys = env.multi_run_batched_reward(keys, action)
        q_values, pulls = agent.multi_run_batch_update(action, reward, pulls, q_values)
        rewards = rewards.at[i].set(reward)

        val = (q_values, pulls, keys, rewards)

        return val

    keys = random.split(key, (n_env, n_runs))
    q_values = jnp.zeros([env.K, n_env, n_runs])
    pulls = jnp.zeros([env.K, n_env, n_runs])
    rewards = jnp.zeros([timesteps, n_env, n_runs])

    val_init = (q_values, pulls, keys, rewards)
    val = lax.fori_loop(0, timesteps, fori_body, val_init)

    return val

In [53]:
q_values, pulls, _, rewards = multi_run_parallel_rollout(key, 2000, 3, jnp.array([0, 0.1, 0.01]), 400)

Running for 2,000 iterations: 100%|██████████| 2000/2000 [00:00<00:00, 2353.50it/s]


```python
TIMESTEPS = 80
q_values, pulls, _, rewards = rollout(key, TIMESTEPS, 0.1)
```

```python
N_ENV = 20
EPSILONS = jnp.array([0, 0.1, 0.01])
q_values, pulls, _, rewards = parallel_rollout(key, 2000, 3, EPSILONS)
```


In [56]:
rewards.shape

(2000, 3, 400)

In [63]:
def plot_rewards(rewards):
    steps = jnp.arange(rewards.shape[0])

    rewards_mean_df = pd.DataFrame({i:jnp.mean(rewards[:, i, :], axis=1).cumsum()/(steps+1) for i in range(3)})
    rewards_std_df = pd.DataFrame({i:jnp.std(rewards[:,i,:], axis=1) for i in range(3)})

    rewards_mean_df.columns = [0, 0.01, 0.1]
    rewards_std_df.columns = [0, 0.01, 0.1]

    fig = go.Figure()

    for i in rewards_mean_df.columns:
        fig.add_trace(go.Scatter(x=steps, y=rewards_mean_df[i], mode='lines', name='Mean Reward'))
        # fig.add_trace(go.Scatter(x=steps, y=rewards_mean_df[i]+rewards_std_df[i], name="Mean +- std", mode='lines', line=dict(width=0)))
        # fig.add_trace(go.Scatter(x=steps, y=rewards_mean_df[i]-rewards_std_df[i], mode='lines', line=dict(width=0), fill='tonexty', showlegend=False))
    fig.update_xaxes(title_text="Steps")
    fig.update_yaxes(title_text="Rewards")
    fig.show()

plot_rewards(rewards)

In [34]:
data = {f"{i}":rewards[:,i,:] for i in range(3)}
pd.DataFrame(data['0'])

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,190,191,192,193,194,195,196,197,198,199
0,2.295416,1.638701,-2.158695,-0.726157,2.209134,0.734351,-1.688863,-1.741005,1.867153,0.446583,...,0.656266,1.986437,-0.218949,1.287741,1.477795,-0.818079,0.945033,2.064794,0.706929,2.284793
1,2.460150,0.352738,-0.053828,1.962221,-2.438420,-0.265484,2.115399,-1.205811,1.576765,0.194605,...,-0.368875,-1.025028,0.975543,1.788510,0.780696,0.679322,-0.205291,0.475319,-0.223128,2.249285
2,0.354897,0.779834,-0.402342,2.107031,-0.917160,0.028617,-0.095542,-1.018531,1.047869,-0.221299,...,-0.875024,0.269902,0.811034,1.057284,1.944880,-2.090024,0.217941,1.207711,-0.452977,2.631902
3,-0.169179,-0.109817,0.030765,0.712843,1.228589,-0.439902,0.125113,1.180972,2.168112,-0.442277,...,-1.394918,-0.451027,0.544867,1.558881,-0.050963,3.238135,0.104749,0.316205,1.119125,3.112452
4,0.733048,0.545624,0.817147,1.617034,1.043554,1.197874,2.573858,0.381002,0.512875,-0.807244,...,0.566142,0.760257,0.445508,0.815203,1.710527,2.118739,0.748613,-0.612048,-1.438865,4.836023
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
395,1.625942,-0.255183,1.142612,1.354964,2.427616,0.247349,3.093328,1.489966,1.907073,2.088707,...,-1.771759,3.060344,-0.931466,2.120260,-0.123870,0.509750,-1.526762,1.073409,-0.319103,2.916904
396,1.693455,-0.635093,3.369476,2.385907,2.127216,-1.120102,2.707844,1.254314,1.321576,-0.127467,...,1.032604,1.260469,1.115018,0.126437,1.776048,2.846959,0.775088,1.075865,0.326973,2.592893
397,1.738696,-0.669271,-0.806540,1.549094,1.279935,1.906394,0.378714,1.440222,0.315691,1.629780,...,1.019222,1.267558,-0.383829,2.646544,0.918643,-0.086406,0.692025,0.469836,-1.125751,2.592271
398,1.571340,1.501030,2.514209,0.526958,2.250140,0.004726,0.756870,0.984717,-0.189486,0.721694,...,0.430073,1.477166,0.574082,1.281561,0.687700,2.947315,1.381651,0.535880,1.131509,2.066348


In [None]:
rewards_df = pd.DataFrame(rewards, columns=[str(x) for x in EPSILONS])
rewards_df = rewards_df.apply(lambda x: x.cumsum()/jnp.arange(rewards_df.shape[0]))
rewards_df = rewards_df.iloc[1:]

px.line(rewards_df)

In [None]:
def plot_rewards(rewards):
    avg = jnp.mean(rewards.cumsum(axis=0), axis=1)
    std = jnp.std(rewards.cumsum(axis=0), axis=1)
    steps = jnp.arange(rewards.shape[0])

    fig = go.Figure()
    fig.add_trace(go.Scatter(x=steps, y=avg, mode='lines', name='Mean Reward'))
    fig.add_trace(go.Scatter(x=steps, y=avg+std, name="Mean +- std", mode='lines', line=dict(width=0)))
    fig.add_trace(go.Scatter(x=steps, y=avg-std, mode='lines', line=dict(width=0), fill='tonexty', showlegend=False))
    fig.update_xaxes(title_text="Steps")
    fig.update_yaxes(title_text="Rewards")
    fig.show()

plot_rewards(rewards)

In [None]:
import plotly.subplots as sp
import plotly.graph_objects as go

def plot_rewards(rewards):
    avg = jnp.mean(rewards.cumsum(axis=0), axis=1)
    std = jnp.std(rewards.cumsum(axis=0), axis=1)
    steps = jnp.arange(rewards.shape[0])

    # Create individual traces for the mean and std deviation
    mean_trace = go.Scatter(x=steps, y=avg, mode='lines', name='Mean Reward')
    upper_trace = go.Scatter(x=steps, y=avg+std, mode='lines', line=dict(width=0), showlegend=False)
    lower_trace = go.Scatter(x=steps, y=avg-std, mode='lines', line=dict(width=0), fill='tonexty', showlegend=False)
    
    return mean_trace, upper_trace, lower_trace

# Simulate some example rewards data for illustration
rewards1 = np.random.rand(1000, 20)
rewards2 = np.random.rand(1000, 20)

# Create a 1x2 subplot grid
fig = sp.make_subplots(rows=1, cols=2, subplot_titles=('Plot 1', 'Plot 2'))

# Plot the first set of rewards on the left subplot
for trace in plot_rewards(rewards1):
    fig.add_trace(trace, row=1, col=1)

# Plot the second set of rewards on the right subplot
for trace in plot_rewards(rewards2):
    fig.add_trace(trace, row=1, col=2)

# Update layout and show the combined subplot
fig.update_layout(title_text="Comparison of Rewards")
fig.update_xaxes(title_text="Steps")
fig.update_yaxes(title_text="Rewards")
fig.show()
