In [1]:
import numpy as np
import random

In [15]:
class LinearChain:
    def __init__(self, n_states=10):
        self.n_states = n_states
        self.start = 0
        self.end = n_states - 1
        self.state = self.start
        self.actions = ['left', 'right']
    
    def reset(self):
        self.state = self.start
        return self.state
    
    def step(self, action):
        if action == 'left':
            self.state = max(0, self.state - 1)
        elif action == 'right':
            self.state = min(self.n_states - 1, self.state + 1)
    
        reward = 1 if self.state == self.end else 0
        done = self.state == self.end  
        return self.state, reward, done

In [16]:
#SARSA algorithme On-Policy
def sarsa(env, episodes=500, alpha=0.1, gamma=0.9, epsilon=0.1): 
    Q = {s: {a: 0 for a in env.actions} for s in range(env.n_states)}

    for episode in range(episodes): 
        state = env.reset()
        action = choose_action(state, Q, epsilon, env.actions)

        while True: 
            next_state, reward, done = env.step(action)
            next_action = choose_action(next_state, Q, epsilon, env.actions)

            # SARSA UPDATE
            Q[state][action] += alpha * (reward + gamma * Q[next_state][next_action] - Q[state][action])
            state, action = next_state, next_action

            if done: 
                break

    return Q

In [17]:
#epsilon-greedy policy
def choose_action(state, Q, epsilon, actions):
    if random.random() < epsilon:
        return random.choice(actions)
    else:
        return max(Q[state], key=Q[state].get)

In [24]:
env = LinearChain(n_states=6)
Q = sarsa(env, episodes=50)

In [25]:
for state, actions in Q.items():
    print(f"state {state} : {actions}")
state = env.reset()

print("\n optimal choice : ")
while state != env.end : 
    action = max(Q[state], key=Q[state].get)
    print(f"Etat {state}, Action {action}")
    state, _, _ = env.step(action)

state 0 : {'left': 0.018466953262478963, 'right': 0.32487462849253407}
state 1 : {'left': 0.029544628488536992, 'right': 0.4923427422517089}
state 2 : {'left': 0.030139855042335923, 'right': 0.658319833635631}
state 3 : {'left': 0.1370689627853191, 'right': 0.8690326987435225}
state 4 : {'left': 0.024858735844581897, 'right': 0.9948462247926799}
state 5 : {'left': 0, 'right': 0}

 optimal choice : 
Etat 0, Action right
Etat 1, Action right
Etat 2, Action right
Etat 3, Action right
Etat 4, Action right
