# Gambler's Problem

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import seaborn as sns
sns.set_theme(style="darkgrid", palette="deep")
%matplotlib inline

![Value Iteration Pseudocode](images/value_iteration_pseudocode.jpeg)

In [None]:
# Capital runs from 0 to NUM_STATES (the goal); value and reward arrays are
# allocated with NUM_STATES + NUM_TERMINALS entries so both ends are indexable.
NUM_STATES: int = 100
NUM_TERMINALS: int = 1

# Probability the coin comes up heads, i.e. the gambler wins the stake.
TRANSITION_PROBABILITY: float = 0.4

# 1 means future rewards are not discounted.
DISCOUNT_FACTOR: int = 1

# Stop when the largest per-state change in a sweep is <= this value.
# 1e-50 is far below float64 precision for values in [0, 1], so in practice
# iteration continues until a sweep leaves V (nearly) unchanged.
# NOTE(review): consider a larger threshold (e.g. 1e-9) if runtime matters.
THRESHOLD: float = 1e-50

In [None]:
def one_step_lookahead(p_h, s, V, rewards, gamma):
    """Return the expected one-step value of every possible stake from capital ``s``.

    Args:
        p_h: probability of winning the flip (capital moves to ``s + stake``);
            with probability ``1 - p_h`` capital moves to ``s - stake``.
        s: current capital.
        V: state-value array indexed by capital ``0..goal``, where
            ``goal = len(V) - 1``.
        rewards: reward for landing in each state, aligned with ``V``.
        gamma: discount factor applied to the successor state's value.

    Returns:
        Array ``A`` with ``len(V)`` entries. ``A[a]`` is the expected value of
        staking ``a`` for ``a`` in ``1..min(s, goal - s)``; every other entry
        (including ``A[0]``) is 0.
    """
    # Derive the goal from V itself instead of module-level constants, so the
    # array size and the stake limit can never disagree.
    goal = len(V) - 1
    A = np.zeros(len(V))

    # A stake can neither exceed current capital nor overshoot the goal.
    stakes = np.arange(1, min(s, goal - s) + 1)
    win = s + stakes
    lose = s - stakes
    A[stakes] = (p_h * (rewards[win] + V[win] * gamma)
                 + (1 - p_h) * (rewards[lose] + V[lose] * gamma))

    return A


def value_iteration_for_gamblers(p_h, gamma=0.99, theta=1e-5):
    """Run value iteration for the gambler's problem and extract a greedy policy.

    States are capital levels ``0..NUM_STATES``. States 0 and ``NUM_STATES``
    (the goal) are never updated, so their values stay 0; the only reward is
    1 for landing in the goal state.

    Values are updated in place, so states later in a sweep already see the
    values computed earlier in that same sweep.

    Args:
        p_h: probability that the coin comes up heads (the gambler wins the stake).
        gamma: discount factor.
        theta: stop once the largest per-state change in a sweep is <= theta.

    Returns:
        policy: array of length ``NUM_STATES``; ``policy[s]`` is the greedy
            stake for capital ``s`` (``policy[0]`` is unused and left at 0).
        V: final state values, length ``NUM_STATES + NUM_TERMINALS``.
        history: dict with ``"state_values"`` (V before the first sweep and
            after every sweep) and ``"delta"`` (max change in each sweep).
    """
    size = NUM_STATES + NUM_TERMINALS
    rewards = np.zeros(size)
    # Use the constant rather than a literal index so changing NUM_STATES
    # keeps the goal reward on the goal state.
    rewards[NUM_STATES] = 1

    V = np.zeros(size)
    num_iterations = 0

    history = {
        "state_values": [V.copy()],
        "delta": [],
    }

    while True:
        delta = 0

        for s in range(1, NUM_STATES):
            old_value = V[s]
            V[s] = np.max(one_step_lookahead(p_h, s, V, rewards, gamma))
            delta = max(delta, np.abs(old_value - V[s]))

        num_iterations += 1
        history["state_values"].append(V.copy())
        history["delta"].append(delta)

        if delta <= theta:
            break

    print(f"Num Iterations: {num_iterations}")

    # Greedy policy with respect to the final V. np.argmax returns the first
    # (smallest) stake among exact ties; near-ties caused by floating-point
    # noise may make the plotted policy look jagged.
    policy = np.zeros(NUM_STATES)
    for s in range(1, NUM_STATES):
        policy[s] = np.argmax(one_step_lookahead(p_h, s, V, rewards, gamma))

    return policy, V, history

In [None]:
# Solve the problem with the notebook-level settings.
policy, V, history = value_iteration_for_gamblers(
    TRANSITION_PROBABILITY,
    gamma=DISCOUNT_FACTOR,
    theta=THRESHOLD,
)

In [None]:
# Plot the final value estimate V(s) against capital s.
# When DISCOUNT_FACTOR is 1, V(s) is the estimated probability of eventually
# reaching the goal from capital s, since the only reward is 1 at the goal.

plt.figure(figsize=(30, 5))

plt.title('Value Estimate vs Capital')
plt.xlabel('Capital')
plt.ylabel('Value Estimate')

sns.lineplot(x=np.arange(NUM_STATES), y=V[:NUM_STATES])

In [None]:
# Plot the value estimate after each sweep to show how value iteration
# converges (sweep 0 is the initial all-zero V).
# NOTE(review): with a very small THRESHOLD there may be many sweeps, making
# the legend large; consider plotting a subset if it becomes unreadable.
plt.figure(figsize=(30, 10))

plt.title('Value Estimate per Sweep')
plt.xlabel('Capital')
plt.ylabel('Value Estimate')

for idx, sweep in enumerate(history["state_values"]):
    sns.lineplot(x=np.arange(NUM_STATES), y=sweep[:NUM_STATES], label=f"Sweep {idx}")

plt.legend()

In [None]:
# Plot the greedy stake chosen for each capital level. Capital 0 is never
# assigned a stake, so its bar shows the default 0.
plt.figure(figsize=(30, 10))

plt.title('Capital vs Final Policy')
plt.xlabel('Capital')
plt.ylabel('Final policy (stake)')

# x must have one entry per policy element; policy has NUM_STATES entries.
sns.barplot(x=np.arange(NUM_STATES), y=policy)