# Jeu du 421

Ce notebook montre quelques algorithmes d'apprentissage par renforcement appliqués au jeu du [421](https://fr.wikipedia.org/wiki/421_(jeu)) :
* Value Iteration
* Policy Iteration
* SARSA
* Q-learning

L'objectif est ici de maximiser son score moyen en trois coups au plus.

## Init

In [36]:
import numpy as np

In [37]:
# Possible faces of a die: 1 to 6.
dice = np.arange(6) + 1

In [40]:
# Number of rerolls still available: 0, 1 or 2.
throws = np.array([0, 1, 2])

In [4]:
def get_scores():
    '''
    Build the score table of the 421 game.

    Returns a dict mapping every combination of three dice, written as a
    tuple in decreasing order, to its score:
    421 -> 11, 111 -> 7, 221 -> 0, d11 -> d and ddd -> d (d > 1),
    straights d(d-1)(d-2) -> 2, any other combination -> 1.
    '''
    table = {(4,2,1): 11, (1,1,1): 7, (2,2,1): 0}
    # The three families below are inserted one after the other so that the
    # key order of the table stays: d11, then ddd, then straights.
    for d in dice:
        if d > 1:
            table[(d,1,1)] = d
    for d in dice:
        if d > 1:
            table[(d,d,d)] = d
    for d in dice:
        if d > 2:
            table[(d,d - 1,d - 2)] = 2
    special = set(table)
    # Every remaining decreasing combination is worth 1.
    for high in dice:
        for mid in dice:
            if mid > high:
                continue
            for low in dice:
                if low <= mid and (high,mid,low) not in special:
                    table[(high,mid,low)] = 1
    return table

In [5]:
# Score of every combination of three dice (keys are tuples in decreasing order).
scores = get_scores()

In [6]:
# States: the dice currently held, one entry per scored combination.
states = list(scores)

In [7]:
# Number of states, i.e. of distinct combinations of three dice.
len(states)

56

In [8]:
# Actions: which dice to reroll (1 = reroll that position, 0 = keep it).
# The 3 bits of k = 0..7 enumerate every subset, most significant bit first.
actions = [((k >> 2) & 1, (k >> 1) & 1, k & 1) for k in range(8)]

In [44]:
# Display the actions; a 1 marks a die that is rerolled.
actions

[(0, 0, 0),
 (0, 0, 1),
 (0, 1, 0),
 (0, 1, 1),
 (1, 0, 0),
 (1, 0, 1),
 (1, 1, 0),
 (1, 1, 1)]

In [9]:
# Number of actions: one per subset of the three dice.
len(actions)

8

In [10]:
def get_reward(state):
    '''
    Return the score of a combination of three dice.

    `state` may list the dice in any order: it is normalised to the
    decreasing order used for the keys of `scores`. A KeyError is raised
    for combinations that are not in `scores`.
    '''
    return scores[tuple(sorted(state, reverse = True))]

## Value iteration

In [11]:
def transition_prob(state, action):
    '''
    Exact distribution of the next dice after rerolling some of them.

    Parameters
    ----------
    state: tuple
        current dice
    action: tuple of 0/1
        1 marks a position to reroll

    Returns a dict mapping each reachable state (sorted in decreasing order)
    to its probability.
    '''
    n_rerolled = np.sum(action)
    weight = 1 / 6**n_rerolled
    outcome_prob = {}
    # Each code in [0, 6**n_rerolled) encodes, in base 6, one equally likely
    # outcome of the rerolled dice.
    for code in range(6**n_rerolled):
        dice_now = np.array(state)
        remaining = code
        for pos, reroll in enumerate(action):
            if not reroll:
                continue
            remaining, face = divmod(remaining, 6)
            dice_now[pos] = face + 1
        key = tuple(sorted(dice_now, reverse = True))
        outcome_prob[key] = outcome_prob.get(key, 0) + weight
    return outcome_prob

In [12]:
def value_iteration(tol = 1e-2, max_iter = 50):
    '''
    Compute the optimal value function by value iteration.

    Returns a dict mapping (state, throw) -- dice held + number of throws
    left -- to the best expected final score reachable from there.

    Parameters
    ----------
    tol: float
        tolerance for convergence: iteration stops once a full sweep changes
        no value by more than `tol`
    max_iter: int
        max number of sweeps (at least one sweep is always performed)
    '''
    # The dice dynamics never change, so every transition distribution is
    # computed once here rather than once per sweep.
    transitions = {(state, action): transition_prob(state, action)
                   for state in states for action in actions}
    V = {(state, throw): 0 for state in states for throw in throws}
    iteration = 0
    while True:
        # Largest change during this sweep. It must be reset every sweep,
        # otherwise the tolerance test can never succeed.
        error = 0
        for state, throw in V:
            if throw == 0:
                new_value = get_reward(state)
            else:
                # Bellman backup: best expected value after one more throw.
                new_value = max(
                    sum(p * V[(s, throw - 1)]
                        for s, p in transitions[(state, action)].items())
                    for action in actions)
            error = max(error, abs(new_value - V[(state, throw)]))
            V[(state, throw)] = new_value
        iteration += 1
        if error <= tol or iteration >= max_iter:
            break
    return V

In [13]:
# Value of every (state, throws left) pair, by value iteration.
value = value_iteration()

In [14]:
def best_policy(value):
    '''
    Return the greedy policy with respect to a value function.

    Parameters
    ----------
    value: dict
        maps (state, throw) to a value, e.g. the output of value_iteration
        or policy_evaluation

    Returns a dict mapping (state, throw), for throw > 0, to the action
    maximising the expected value with one throw less. Ties go to the first
    such action in `actions`.
    '''
    policy = {}
    for state, throw in value:
        if throw == 0:
            continue  # no decision left to take
        # Values are collected for this state only, so the chosen action can
        # never be left over from a previously visited state.
        action_values = {}
        for action in actions:
            prob = transition_prob(state, action)
            action_values[action] = sum(p * value[(s, throw - 1)]
                                        for s, p in prob.items())
        policy[(state, throw)] = max(action_values, key = action_values.get)
    return policy

In [15]:
# Greedy policy with respect to the value-iteration values.
policy = best_policy(value)

In [16]:
def get_action(state, throw, policy):
    '''
    Return the dice to reroll in `state` with `throw` throws left.

    With no throw left every die is kept, (0,0,0); otherwise the action is
    read from `policy`.
    '''
    keep_all = (0,0,0)
    return keep_all if throw == 0 else policy[(state, throw)]

In [17]:
def random_state():
    '''Roll three fair dice and return them as a tuple in decreasing order.'''
    rolls = [np.random.choice(6) + 1 for _ in range(3)]
    rolls.sort(reverse = True)
    return tuple(rolls)

In [18]:
def move(state, action):
    '''Simulate one reroll: sample the next dice according to transition_prob.'''
    distribution = transition_prob(state, action)
    outcomes = list(distribution.keys())
    weights = list(distribution.values())
    drawn = np.random.choice(np.arange(len(outcomes)), p = weights)
    return outcomes[drawn]

In [19]:
def score_policy(policy, runs = 1000):
    '''
    Estimate the mean final score of a policy by Monte Carlo simulation.

    Each game starts from random dice with 2 throws left and stops when the
    action is (0,0,0): either the policy keeps every die, or no throw is left.

    Parameters
    ----------
    policy: dict
        maps (state, throw), throw > 0, to the dice to reroll
    runs: int
        number of simulated games
    '''
    score = 0
    for _ in range(runs):
        throw = 2
        state = random_state()
        action = get_action(state, throw, policy)
        while action != (0,0,0):
            throw -= 1
            state = move(state, action)
            action = get_action(state, throw, policy)
        score += get_reward(state)
    return score / runs

In [60]:
# Monte Carlo estimate of the mean score of the value-iteration policy.
score_policy(policy)

3.314

## Policy iteration

In [21]:
def init_random_policy():
    '''Return a policy drawing a uniformly random action for every (state, throw).'''
    random_policy = {}
    for state in states:
        for throw in throws:
            random_policy[(state, throw)] = actions[np.random.choice(len(actions))]
    return random_policy

In [22]:
def policy_evaluation(policy, tol = 1e-2, max_iter = 50):
    '''
    Compute the value function of a fixed policy, using the known transition
    probabilities of the dice.

    Returns a dict mapping (state, throw) -- dice held + number of throws
    left -- to the expected final score obtained by following `policy`.

    Parameters
    ----------
    policy: dict
        maps (state, throw), throw > 0, to the dice to reroll; keeping every
        die, (0,0,0), ends the game (as in score_policy)
    tol: float
        tolerance for convergence: iteration stops once a full sweep changes
        no value by more than `tol`
    max_iter: int
        max number of sweeps (at least one sweep is always performed)
    '''
    # The policy is fixed, so each (state, throw) has a single transition
    # distribution; compute them once instead of once per sweep.
    transitions = {(state, throw): transition_prob(state, policy[(state, throw)])
                   for state in states for throw in throws if throw > 0}
    V = {(state, throw): 0 for state in states for throw in throws}
    iteration = 0
    while True:
        error = 0  # largest change during this sweep, reset every sweep
        for state, throw in V:
            if throw == 0 or policy[(state, throw)] == (0,0,0):
                # No throw left, or the policy stops: the dice held are scored.
                value = get_reward(state)
            else:
                # The only reward comes at the end of the game, so the value is
                # the expected next value over *all* possible rerolls.
                value = sum(p * V[(s, throw - 1)]
                            for s, p in transitions[(state, throw)].items())
            error = max(error, abs(value - V[(state, throw)]))
            V[(state, throw)] = value
        iteration += 1
        if error <= tol or iteration >= max_iter:
            break
    return V

In [52]:
# Policy iteration, step 0: start from a uniformly random policy.
policy = init_random_policy()

In [53]:
# Evaluate the current policy.
V = policy_evaluation(policy)

In [54]:
# Improve it: act greedily with respect to its value.
policy = best_policy(V)

In [55]:
# Second round: evaluate the improved policy.
V = policy_evaluation(policy)

In [56]:
# ...and improve it again.
policy = best_policy(V)

In [57]:
# Monte Carlo estimate of the mean score of the policy-iteration policy.
score_policy(policy)

3.146

## SARSA

From now on, the transition probabilities are assumed to be unknown: the policy is learned directly from simulated games (model-free), without estimating those probabilities.

In [29]:
def get_action_explore(state, throw, policy, eps = 0.1):
    '''
    Epsilon-greedy variant of get_action.

    eps: float
        Probability of returning a uniformly random action instead of the
        policy's one (eps = 0 means no exploration)
    '''
    if throw == 0:
        return (0,0,0)
    explore = np.random.random() < eps
    if explore:
        return actions[np.random.choice(len(actions))]
    return policy[(state, throw)]

In [30]:
def sarsa(init_policy = 'random', alpha = 0.9, steps = 10**5, verbose = True, batch = 10**4, eps = 0.1):
    '''
    Learn a policy with SARSA (on-policy TD control) from simulated games,
    without using the transition probabilities.

    Parameters
    ----------
    init_policy: str
        initial policy: 'random', 'all' (reroll every die) or 'none' (keep
        every die)
    alpha: float
        learning rate
    steps: int
        number of simulated games
    verbose: bool
        print progress every `batch` games
    batch: int
        number of games between two progress messages
    eps: float
        exploration rate of the epsilon-greedy behaviour policy

    Returns the greedy policy w.r.t. the learned Q, as a dict
    (state, throw) -> action.

    Raises
    ------
    ValueError
        if init_policy is not one of the values above
    '''
    if init_policy == 'random':
        policy = init_random_policy()
    elif init_policy == 'all':
        policy = {(state, throw): (1,1,1)  for state in states for throw in throws}
    elif init_policy == 'none':
        policy = {(state, throw): (0,0,0)  for state in states for throw in throws}
    else:
        raise ValueError("init_policy must be 'random', 'all' or 'none', got %r" % (init_policy,))
    Q = {(state, throw, action): 0 for state in states for throw in throws for action in actions}
    for t in range(steps):
        if verbose:
            if t%batch == 0:
                print('Batch ',str(t // batch + 1),' over ',str(steps // batch))
        # play the game
        state = random_state()
        throw = 2
        action = get_action_explore(state, throw, policy, eps = eps)
        sequence = []
        while action != (0,0,0):
            sequence.append((state, throw))
            new_throw = throw - 1
            new_state = move(state, action)
            new_action = get_action_explore(new_state, new_throw, policy, eps = eps)
            # No intermediate reward: bootstrap on the next (state, action) actually taken.
            Q[(state, throw, action)] += alpha * (Q[(new_state, new_throw, new_action)]
                                                  - Q[(state, throw, action)])
            throw = new_throw
            state = new_state
            action = new_action
        # The game ends on (0,0,0): the final dice are scored.
        sequence.append((state, throw))
        Q[(state, throw, action)] += alpha * (get_reward(state) - Q[(state, throw, action)])
        # Make the policy greedy w.r.t. Q on the states visited in this game.
        for state, throw in sequence:
            qvalues = {a: Q[(state, throw, a)] for a in actions}
            policy[(state, throw)] = max(qvalues, key = qvalues.get)
    return policy

In [31]:
# Learn a policy with SARSA (default parameters, progress printed).
sarsa_policy = sarsa()

Batch  1  over  10
Batch  2  over  10
Batch  3  over  10
Batch  4  over  10
Batch  5  over  10
Batch  6  over  10
Batch  7  over  10
Batch  8  over  10
Batch  9  over  10
Batch  10  over  10


In [32]:
# Monte Carlo estimate of the mean score of the SARSA policy.
score_policy(sarsa_policy)

2.922

## Q-learning

In [33]:
def qlearning(init_policy = 'random', alpha = 0.9, steps = 10**5, verbose = True, batch = 10**4, eps = 0.1):
    '''
    Learn a policy with Q-learning (off-policy TD control) from simulated
    games, without using the transition probabilities.

    Parameters
    ----------
    init_policy: str
        initial policy: 'random', 'all' (reroll every die) or 'none' (keep
        every die)
    alpha: float
        learning rate
    steps: int
        number of simulated games
    verbose: bool
        print progress every `batch` games
    batch: int
        number of games between two progress messages
    eps: float
        exploration rate of the epsilon-greedy behaviour policy

    Returns the greedy policy w.r.t. the learned Q, as a dict
    (state, throw) -> action.

    Raises
    ------
    ValueError
        if init_policy is not one of the values above
    '''
    if init_policy == 'random':
        policy = init_random_policy()
    elif init_policy == 'all':
        policy = {(state, throw): (1,1,1)  for state in states for throw in throws}
    elif init_policy == 'none':
        policy = {(state, throw): (0,0,0)  for state in states for throw in throws}
    else:
        raise ValueError("init_policy must be 'random', 'all' or 'none', got %r" % (init_policy,))
    Q = {(state, throw, action): 0 for state in states for throw in throws for action in actions}
    for t in range(steps):
        if verbose:
            if t%batch == 0:
                print('Batch ',str(t // batch + 1),' over ',str(steps // batch))
        # play the game
        state = random_state()
        throw = 2
        action = get_action_explore(state, throw, policy, eps = eps)
        sequence = []
        while action != (0,0,0):
            sequence.append((state, throw))
            new_throw = throw - 1
            new_state = move(state, action)
            # No intermediate reward: bootstrap on the best action in the next state.
            best_next = max(Q[(new_state, new_throw, a)] for a in actions)
            Q[(state, throw, action)] += alpha * (best_next - Q[(state, throw, action)])
            throw = new_throw
            state = new_state
            action = get_action_explore(state, throw, policy, eps = eps)
        # The game ends on (0,0,0): the final dice are scored.
        sequence.append((state, throw))
        Q[(state, throw, action)] += alpha * (get_reward(state) - Q[(state, throw, action)])
        # Make the policy greedy w.r.t. Q on the states visited in this game.
        for state, throw in sequence:
            qvalues = {a: Q[(state, throw, a)] for a in actions}
            policy[(state, throw)] = max(qvalues, key = qvalues.get)
    return policy

In [34]:
# Learn a policy with Q-learning (default parameters, progress printed).
qlearning_policy = qlearning()

Batch  1  over  10
Batch  2  over  10
Batch  3  over  10
Batch  4  over  10
Batch  5  over  10
Batch  6  over  10
Batch  7  over  10
Batch  8  over  10
Batch  9  over  10
Batch  10  over  10


In [63]:
# Monte Carlo estimate of the mean score of the Q-learning policy.
score_policy(qlearning_policy)

2.886