# 🫰💵***Statistically* Rob the casino with** <img src='https://upload.wikimedia.org/wikipedia/commons/8/86/Google_JAX_logo.svg' alt="JAX logo" width="60" />

<div align="center">
<img src="https://images.unsplash.com/photo-1518895312237-a9e23508077d?auto=format&fit=crop&q=80&w=1784&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D" width=1000>
</div>

## **1) Q-learning for K-armed Bandits**

In [1]:
import jax
import jax.numpy as jnp
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go

from jax import random, lax, jit, vmap, pmap
from functools import partial
from jax_tqdm import loop_tqdm

import sys

sys.path.append("../../")

from src import K_armed_bandits, SimpleBandit, BanditEpsilonGreedy

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Notebook-wide configuration and shared objects used by the cells below.
SEED = 1  # seeds both the base PRNG key and the environment constructor
K = 10  # first argument to K_armed_bandits (presumably the number of arms -- confirm)

# Base JAX PRNG key; later cells hand it to the rollout helpers.
key = random.PRNGKey(SEED)

# Project-local components imported from `src`; their internals are not visible here.
agent = SimpleBandit()  # rollouts call it as agent(action, reward, pulls, q_values) -> (q_values, pulls)
policy = BanditEpsilonGreedy()  # rollouts call it as policy(key, env.K, q_values, epsilon) -> (action, key)
env = K_armed_bandits(K, SEED)  # rollouts read env.K and draw rewards through env.get_reward
# env.render()  # left disabled

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [3]:
def rollout(key: jnp.ndarray, timesteps: int, bandits_q: jnp.ndarray, epsilon: float = 0.1) -> tuple:
    """Run a single epsilon-greedy bandit episode inside one `lax.fori_loop`.

    Relies on the notebook-level `policy`, `env` and `agent` objects.

    Args:
        key: JAX PRNG key threaded through the policy and the environment.
        timesteps: number of loop iterations; also the length of the reward log.
        bandits_q: forwarded unchanged to `env.get_reward` (presumably the true
            value of each arm -- confirm against `K_armed_bandits`).
        epsilon: forwarded to `policy` (presumably the exploration
            probability -- confirm against `BanditEpsilonGreedy`).

    Returns:
        The final loop carry `(q_values, pulls, key, rewards)`; `rewards[i]` is
        the reward recorded at step `i`.
    """
    @jit
    @loop_tqdm(timesteps)
    def fori_body(i: int, val: tuple) -> tuple:
        q_values, pulls, key, rewards = val
        # The policy and the environment each return a fresh key, so the key is
        # re-bound after every call instead of being reused.
        action, key = policy(key, env.K, q_values, epsilon)
        reward, key = env.get_reward(key, action, bandits_q)
        q_values, pulls = agent(action, reward, pulls, q_values)
        rewards = rewards.at[i].set(reward)

        return (q_values, pulls, key, rewards)

    # One estimate and one pull counter per arm, plus a per-step reward log.
    q_values = jnp.zeros(env.K)
    pulls = jnp.zeros(env.K)
    rewards = jnp.zeros(timesteps)

    val_init = (q_values, pulls, key, rewards)
    return lax.fori_loop(0, timesteps, fori_body, val_init)

TIMESTEPS = 80
# Use separate subkeys for the arm values and for the rollout: feeding the same
# key to both would correlate the two random streams (JAX keys are single-use).
bandit_key, rollout_key = random.split(key)
bandits_q = random.normal(bandit_key, (env.K,))
q_values, pulls, _, rewards = rollout(rollout_key, TIMESTEPS, bandits_q, 0.1)

Running for 80 iterations: 100%|██████████| 80/80 [00:00<00:00, 79399.98it/s]


In [4]:
# Running average of the reward; the divisor is sized from the data so the cell
# keeps working when the number of timesteps changes.
px.line(rewards.cumsum() / jnp.arange(1, rewards.shape[0] + 1), title="Running average reward")

In [5]:

def parallel_rollout(key: jnp.ndarray, timesteps: int, n_env: int, bandits_q: jnp.ndarray, epsilons: jnp.ndarray) -> tuple:
    """Run `n_env` bandit episodes side by side inside one `lax.fori_loop`.

    Relies on the notebook-level `policy`, `env` and `agent` objects and their
    batched entry points.

    Args:
        key: JAX PRNG key; split into `n_env` keys before the loop starts.
        timesteps: number of loop iterations; also the first axis of the reward log.
        n_env: size of the environment axis of every carried array.
        bandits_q: forwarded unchanged to `env.get_batched_reward`.
        epsilons: forwarded unchanged to `policy.batched_call` (presumably one
            epsilon per environment, so its length should equal `n_env` -- confirm).

    Returns:
        The final loop carry `(q_values, pulls, keys, rewards)`, where `q_values`
        and `pulls` have shape `(env.K, n_env)` and `rewards` has shape
        `(timesteps, n_env)`.
    """
    @jit
    @loop_tqdm(timesteps)
    def fori_body(i: int, val: tuple) -> tuple:
        q_values, pulls, keys, rewards = val
        action, keys = policy.batched_call(keys, env.K, q_values, epsilons)
        reward, keys = env.get_batched_reward(keys, action, bandits_q)
        q_values, pulls = agent.batch_update(action, reward, pulls, q_values)
        # Row i of the log holds the step-i reward of every environment.
        rewards = rewards.at[i].set(reward)

        return (q_values, pulls, keys, rewards)

    keys = random.split(key, (n_env,))
    q_values = jnp.zeros([env.K, n_env])
    pulls = jnp.zeros([env.K, n_env])
    rewards = jnp.zeros([timesteps, n_env])

    val_init = (q_values, pulls, keys, rewards)
    return lax.fori_loop(0, timesteps, fori_body, val_init)

EPSILONS = jnp.array([0, 0.1, 0.01])
N_ENV = len(EPSILONS)  # environment count follows the epsilon list instead of a separate literal
# Separate subkeys for the arm values and the rollout; a JAX key must not be
# consumed twice.
bandit_key, rollout_key = random.split(key)
bandits_q = random.normal(bandit_key, (env.K,))

q_values, pulls, _, rewards = parallel_rollout(rollout_key, 2000, N_ENV, bandits_q, EPSILONS)

Running for 2,000 iterations: 100%|██████████| 2000/2000 [00:00<00:00, 658911.95it/s]


In [6]:
# Running average reward of the first environment column; the divisor is sized
# from the data rather than a hard-coded step count.
px.line(
    jnp.cumsum(rewards[:, 0], axis=0) / jnp.arange(1, rewards.shape[0] + 1),
    title="Running average reward (environment 0)",
)

In [56]:

def multi_run_parallel_rollout(key: jnp.ndarray, timesteps: int, n_env: int, n_runs: int, bandits_q: jnp.ndarray, epsilons: jnp.ndarray) -> tuple:
    """Run an `n_env` x `n_runs` grid of bandit episodes inside one `lax.fori_loop`.

    Relies on the notebook-level `policy`, `env` and `agent` objects and their
    multi-run batched entry points.

    Args:
        key: JAX PRNG key; split into an `(n_env, n_runs)` grid of keys.
        timesteps: number of loop iterations; also the first axis of the reward log.
        n_env: size of the environment axis of every carried array.
        n_runs: size of the run axis of every carried array.
        bandits_q: forwarded unchanged to `env.multi_run_batched_reward`.
        epsilons: forwarded unchanged to `policy.multi_run_batched_call`
            (presumably one epsilon per environment -- confirm).

    Returns:
        The final loop carry `(q_values, pulls, keys, rewards)`, where `q_values`
        and `pulls` have shape `(env.K, n_env, n_runs)` and `rewards` has shape
        `(timesteps, n_env, n_runs)`.
    """
    @jit
    @loop_tqdm(timesteps)
    def fori_body(i: int, val: tuple) -> tuple:
        q_values, pulls, keys, rewards = val
        action, keys = policy.multi_run_batched_call(keys, env.K, q_values, epsilons)
        reward, keys = env.multi_run_batched_reward(keys, action, bandits_q)
        q_values, pulls = agent.multi_run_batch_update(action, reward, pulls, q_values)
        # Slice i of the log holds the step-i reward of every (environment, run) pair.
        rewards = rewards.at[i].set(reward)

        return (q_values, pulls, keys, rewards)

    keys = random.split(key, (n_env, n_runs))
    q_values = jnp.zeros([env.K, n_env, n_runs])
    pulls = jnp.zeros([env.K, n_env, n_runs])
    rewards = jnp.zeros([timesteps, n_env, n_runs])

    val_init = (q_values, pulls, keys, rewards)
    return lax.fori_loop(0, timesteps, fori_body, val_init)

N_RUNS = 1000
TIMESTEPS = 1000
EPSILONS = jnp.array([0, 0.1, 0.01])
N_ENV = len(EPSILONS)
# Separate subkeys for the arm values and the rollouts; a JAX key must not be
# consumed twice.
bandit_key, rollout_key = random.split(key)
bandits_q = random.normal(bandit_key, (env.K, N_RUNS))
q_values, pulls, _, rewards = multi_run_parallel_rollout(rollout_key, TIMESTEPS, N_ENV, N_RUNS, bandits_q, EPSILONS)

Running for 1,000 iterations: 100%|██████████| 1000/1000 [00:00<00:00, 1021.56it/s]


In [57]:
def plot_rewards(rewards, epsilons):
    """Plot one running-average reward curve per epsilon value.

    Args:
        rewards: 3-D array indexed as `[step, epsilon index, run]`.
        epsilons: sequence of epsilon values; entry `i` labels `rewards[:, i, :]`.

    The figure is displayed with `fig.show()`; nothing is returned.
    """
    steps = jnp.arange(rewards.shape[0])

    # Average over the run axis first, then turn the per-step means into a
    # running average over time.
    mean_over_runs = jnp.mean(rewards, axis=2)
    running_avg = jnp.cumsum(mean_over_runs, axis=0) / (steps + 1)[:, None]

    fig = go.Figure()

    # Curves are selected by position, not by epsilon value, so repeated
    # epsilon values still produce exactly one trace each.
    for idx in range(len(epsilons)):
        fig.add_trace(go.Scatter(x=steps, y=running_avg[:, idx], mode='lines', name=f"Epsilon = {str(epsilons[idx])}"))
    fig.update_xaxes(title_text="Steps")
    fig.update_yaxes(title_text="Averaged Rewards")

    fig.update_layout(title_text="Average rewards of Epsilon Values", legend_title_text="Epsilon Values")

    fig.show()

plot_rewards(rewards, jnp.asarray(EPSILONS))

In [None]:
import plotly.subplots as sp
import plotly.graph_objects as go

def plot_rewards(rewards):
    """Build the three traces of a mean +/- std band of cumulative reward.

    Note: this one-argument function reuses the name of the two-argument
    `plot_rewards(rewards, epsilons)` defined in an earlier cell.

    Args:
        rewards: 2-D array; rewards are accumulated along axis 0 and the
            mean / standard deviation are taken along axis 1.

    Returns:
        `(mean_trace, upper_trace, lower_trace)`; the lower trace is filled
        up to the trace added before it (`fill='tonexty'`).
    """
    cumulative = rewards.cumsum(axis=0)
    center = jnp.mean(cumulative, axis=1)
    spread = jnp.std(cumulative, axis=1)
    x_axis = jnp.arange(rewards.shape[0])

    # The band edges are invisible zero-width lines; only the fill between them shows.
    return (
        go.Scatter(x=x_axis, y=center, mode='lines', name='Mean Reward'),
        go.Scatter(x=x_axis, y=center + spread, mode='lines', line=dict(width=0), showlegend=False),
        go.Scatter(x=x_axis, y=center - spread, mode='lines', line=dict(width=0), fill='tonexty', showlegend=False),
    )

# Simulate some example rewards data for illustration.
# NumPy is not imported in this notebook, so the uniform samples are drawn with
# jax.random from two independent, seeded keys (which also makes the demo reproducible).
demo_key_1, demo_key_2 = random.split(random.PRNGKey(SEED + 1))
rewards1 = random.uniform(demo_key_1, (1000, 20))
rewards2 = random.uniform(demo_key_2, (1000, 20))

# Create a 1x2 subplot grid
fig = sp.make_subplots(rows=1, cols=2, subplot_titles=('Plot 1', 'Plot 2'))

# Plot the first set of rewards on the left subplot
for trace in plot_rewards(rewards1):
    fig.add_trace(trace, row=1, col=1)

# Plot the second set of rewards on the right subplot
for trace in plot_rewards(rewards2):
    fig.add_trace(trace, row=1, col=2)

# Update layout and show the combined subplot
fig.update_layout(title_text="Comparison of Rewards")
fig.update_xaxes(title_text="Steps")
fig.update_yaxes(title_text="Rewards")
fig.show()
