# 🫰💵***Statistically* Rob the Casino with** <img src='https://upload.wikimedia.org/wikipedia/commons/8/86/Google_JAX_logo.svg' alt="JAX" width="60" />

<div align="center">
<img src="https://images.unsplash.com/photo-1518895312237-a9e23508077d?auto=format&fit=crop&q=80&w=1784&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D" width=1000>
</div>

## **1) Incremental update for K-armed Bandits**

In [31]:
import jax
import jax.numpy as jnp
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go

from jax import random, lax, jit, vmap, pmap
from functools import partial
from jax_tqdm import loop_tqdm

import sys

sys.path.append("../../")

from src import (K_armed_bandits, SimpleBandit, BanditEpsilonGreedy, 
                 bandits_rollout, bandits_parallel_rollout, 
                 bandits_multi_run_parallel_rollout)

In [32]:
# Configuration shared by every experiment below.
SEED = 1  # passed both to the JAX PRNG key and to the environment constructor
K = 10    # number of bandit arms given to K_armed_bandits

key = random.PRNGKey(SEED)

# Agent (value-update rule), action-selection policy and environment used by all rollouts.
agent = SimpleBandit()
policy = BanditEpsilonGreedy()
env = K_armed_bandits(K, SEED)

In [33]:
TIMESTEPS = 200
EPSILON = 0.1

# JAX keys must not be reused: feeding the same key to `random.normal` and to the
# rollout would correlate the true arm values with the rollout's randomness.
# `fold_in` derives a key specific to this cell, and `split` gives two independent keys.
q_key, rollout_key = random.split(random.fold_in(key, 0))
bandits_q = random.normal(q_key, (env.K,))
q_values, pulls, _, rewards = bandits_rollout(rollout_key, TIMESTEPS, bandits_q, EPSILON, env, agent, policy)

# Running average reward: cumulative reward divided by the number of steps taken so far.
px.line(
    rewards.cumsum() / (jnp.arange(len(rewards)) + 1),
    title=f"Single agent reward per step, epsilon = {EPSILON}",
    labels={"index": "Steps", "value": "Average reward"},
)

Running for 200 iterations: 100%|██████████| 200/200 [00:00<00:00, 200014.50it/s]


In [34]:
# One parallel environment per epsilon (0 = purely greedy). The rollout reuses
# `bandits_q` from the single-agent cell, so every epsilon faces the same arm values.
EPSILONS = jnp.array([0.0, 0.1, 0.01])
N_ENV = EPSILONS.shape[0]
q_values, pulls, _, rewards = bandits_parallel_rollout(
    key, TIMESTEPS, N_ENV, bandits_q, EPSILONS, env, agent, policy
)

Running for 200 iterations: 100%|██████████| 200/200 [00:00<00:00, 199871.53it/s]


In [35]:
def plot_cumulative_rewards(rewards, epsilons=None):
    """Show the running-average reward of each parallel environment.

    Args:
        rewards: 2-D array indexed as ``rewards[step, env]``, with one column per
            epsilon value.
        epsilons: legend labels, one per column of ``rewards``. When omitted, the
            notebook-level ``EPSILONS`` is used (looked up at call time), as in the
            original version; pass it explicitly to avoid relying on hidden state.

    Raises:
        ValueError: if the number of labels differs from the number of columns.
            JAX clamps out-of-range indices, so a mismatch would otherwise
            silently mislabel traces.
    """
    if epsilons is None:
        epsilons = EPSILONS
    if len(epsilons) != rewards.shape[1]:
        raise ValueError(
            f"got {len(epsilons)} epsilon labels for {rewards.shape[1]} reward columns"
        )

    steps = jnp.arange(rewards.shape[0])
    # Running mean for all columns at once: cumulative sum / number of steps so far.
    running_avg = jnp.cumsum(rewards, axis=0) / (steps + 1)[:, None]

    fig = go.Figure()
    for i in range(rewards.shape[1]):
        fig.add_trace(go.Scatter(x=steps, y=running_avg[:, i], mode='lines', name=f"{str(epsilons[i])}"))
    fig.update_xaxes(title_text="Steps")
    fig.update_yaxes(title_text="Rewards")
    fig.update_layout(title_text="Average rewards of Epsilon Values", legend_title_text="Epsilon Values")
    fig.show()

plot_cumulative_rewards(rewards, EPSILONS)

In [76]:
N_RUNS = 300
TIMESTEPS = 2500
EPSILONS = jnp.array([0, 0.2, 0.1, 0.05, 0.01])
N_ENV = len(EPSILONS)

# Derive keys specific to this experiment. `fold_in` separates them from the keys used
# by earlier cells. `split` keeps the true arm values independent of the rollout's
# randomness; reusing a single key for both would correlate them.
q_key, rollout_key = random.split(random.fold_in(key, 1))
bandits_q = random.normal(q_key, (env.K, N_RUNS))
q_values, pulls, _, rewards = bandits_multi_run_parallel_rollout(rollout_key, TIMESTEPS, N_ENV, N_RUNS, bandits_q, EPSILONS, env, agent, policy)

Running for 2,500 iterations: 100%|██████████| 2500/2500 [00:01<00:00, 1882.99it/s]


In [77]:
def plot_rewards(rewards, epsilons):
    """Build a figure of the running-average reward per epsilon, averaged over runs.

    Args:
        rewards: 3-D array indexed as ``rewards[step, epsilon_index, run]``.
        epsilons: one value per epsilon index, used for the legend labels.

    Returns:
        The plotly ``Figure``. It is displayed when it is the last expression of a cell.

    Raises:
        ValueError: if ``len(epsilons)`` differs from ``rewards.shape[1]``. JAX
            clamps out-of-range indices, so a mismatch would otherwise silently
            plot the wrong data under a label.
    """
    n_steps, n_epsilons, n_runs = rewards.shape
    if len(epsilons) != n_epsilons:
        raise ValueError(f"got {len(epsilons)} epsilons for {n_epsilons} reward columns")

    steps = jnp.arange(n_steps)
    # Mean over runs, then the running mean over time for every epsilon at once.
    mean_per_step = jnp.mean(rewards, axis=2)
    running_avg = jnp.cumsum(mean_per_step, axis=0) / (steps + 1)[:, None]

    fig = go.Figure()
    for idx in range(n_epsilons):
        fig.add_trace(go.Scatter(x=steps, y=running_avg[:, idx], mode='lines', name=f"Epsilon = {str(epsilons[idx])}"))
    fig.update_xaxes(title_text="Steps")
    fig.update_yaxes(title_text="Averaged Rewards")
    # The run count comes from the data, not the global N_RUNS, so the title cannot go stale.
    fig.update_layout(title_text=f"Rewards of Epsilon Values (averaged over {n_runs} runs per Epsilon)", legend_title_text="Epsilon Values")

    return fig

plot_rewards(rewards, EPSILONS)