In [1]:
from typing import List, Optional

import numpy as np
from tqdm import tqdm

In [2]:
class LineWorld:
    """Deterministic one-dimensional corridor environment.

    The agent starts on the middle cell and moves one cell left (action 0)
    or right (action 1) per step. The episode ends when it reaches either
    end of the corridor: the leftmost cell scores -1.0, the rightmost cell
    scores +1.0, and every interior cell scores 0.0.

    Args:
        nb_cells: Total number of cells, including the two terminal cells.
            The default of 5 gives a start position of 2 and terminal
            cells 0 and 4.

    Raises:
        ValueError: If ``nb_cells`` is smaller than 3 (there would be no
            interior cell to start from).
    """

    def __init__(self, nb_cells: int = 5):
        if nb_cells < 3:
            raise ValueError(
                "nb_cells must be at least 3 (two terminal cells and one interior cell)"
            )
        self.nb_cells = nb_cells
        # Middle cell; integer division gives 2 for the default 5-cell corridor.
        self.start_pos = nb_cells // 2
        self.agent_pos = self.start_pos

    def available_actions(self) -> List[int]:
        """Return the legal actions: ``[0, 1]`` on interior cells, ``[]`` otherwise."""
        if 0 < self.agent_pos < self.nb_cells - 1:
            return [0, 1]  # 0: left, 1: right
        return []

    def is_game_over(self) -> bool:
        """Return True when the agent stands on one of the two end cells."""
        return self.agent_pos in (0, self.nb_cells - 1)

    def state_id(self) -> int:
        """Return the agent position, which doubles as the state identifier."""
        return self.agent_pos

    def step(self, action: int):
        """Move the agent one cell left (0) or right (1).

        Raises:
            AssertionError: If the episode is already over or ``action`` is
                not one of the currently available actions.
        """
        assert not self.is_game_over(), "cannot step: the episode is over"
        assert action in self.available_actions(), "action is not available in this state"

        if action == 0:
            self.agent_pos -= 1
        else:
            self.agent_pos += 1

    def score(self) -> float:
        """Return -1.0 on the leftmost cell, 1.0 on the rightmost cell, else 0.0."""
        if self.agent_pos == 0:
            return -1.0
        if self.agent_pos == self.nb_cells - 1:
            return 1.0
        return 0.0

    def display(self):
        """Print the corridor on one line, marking the agent with 'X' and other cells with '_'."""
        for i in range(self.nb_cells):
            print('X' if self.agent_pos == i else '_', end='')
        print()

    def reset(self):
        """Put the agent back on the start (middle) cell."""
        self.agent_pos = self.start_pos

In [3]:
# Create the environment instance used by the walkthrough cells below;
# the agent starts on cell 2.
env = LineWorld()

In [4]:
# Render the corridor: 'X' marks the agent, '_' marks every other cell.
env.display()

__X__


In [5]:
# Actions available on the current cell (0: left, 1: right).
env.available_actions()

[0, 1]

In [6]:
# Take one step to the left (action 0) and render the result.
env.step(0)
env.display()

_X___


In [7]:
# A second step to the left; in the output below the agent is on the leftmost cell.
env.step(0)
env.display()

X____


In [8]:
# Score of the current position; score() returns -1.0 on cell 0.
env.score()

-1.0

In [9]:
# Put the agent back on its starting cell (2) and render the corridor.
env.reset()
env.display()

__X__


In [10]:
# Take one step to the right (action 1) and render the result.
env.step(1)
env.display()

___X_


In [11]:
# A second step to the right; in the output below the agent is on the rightmost cell.
env.step(1)
env.display()

____X


In [12]:
# Score of the current position; score() returns 1.0 on cell 4.
env.score()

1.0

In [19]:
# Q-learning (off-policy TD control)
def _ensure_state_values(Q, state, actions, rng):
    """Give ``state`` a random value in [0, 1) per action the first time it is seen."""
    if state not in Q:
        Q[state] = {action: rng.random_sample() for action in actions}


def naive_q_learning(env_type,
                     alpha: float = 0.1,
                     epsilon: float = 0.1,
                     gamma: float = 0.999,
                     nb_iter: int = 100000,
                     seed: Optional[int] = None):
    """Learn a tabular action-value function with epsilon-greedy Q-learning.

    Args:
        env_type: Zero-argument callable returning an environment that
            provides ``reset()``, ``is_game_over()``, ``state_id()``,
            ``available_actions()``, ``step(action)`` and ``score()``.
            The reward of a step is the change in ``score()`` it causes.
        alpha: Learning rate used to blend the old estimate with the target.
        epsilon: Probability of picking a uniformly random action instead
            of the greedy one.
        gamma: Discount factor applied to the best next-state value.
        nb_iter: Number of episodes to run.
        seed: Optional seed for a private ``numpy.random.RandomState``.
            When ``None`` (default) the global ``numpy.random`` stream is
            used, exactly as before this parameter existed.

    Returns:
        A tuple ``(Pi, Q)``: ``Q`` maps ``state -> {action: value}`` for every
        non-terminal state encountered, and ``Pi`` maps each of those states
        to its highest-valued action (the first one in case of a tie, the
        same rule used for greedy selection during training).

    Side effects:
        Prints the learned ``Q`` table once training has finished.
    """
    # A dedicated generator makes runs repeatable without touching global state.
    rng = np.random if seed is None else np.random.RandomState(seed)

    Q = {}
    env = env_type()
    for _ in tqdm(range(nb_iter)):
        env.reset()

        while not env.is_game_over():
            s = env.state_id()
            aa = env.available_actions()
            _ensure_state_values(Q, s, aa, rng)

            # Epsilon-greedy behaviour policy.
            if rng.random_sample() < epsilon:
                a = rng.choice(aa)
            else:
                a = max(aa, key=lambda action: Q[s][action])

            prev_score = env.score()
            env.step(a)
            r = env.score() - prev_score

            s_p = env.state_id()
            aa_p = env.available_actions()

            if env.is_game_over():
                # Terminal transition: nothing to bootstrap from.
                target = r
            else:
                _ensure_state_values(Q, s_p, aa_p, rng)
                # Off-policy target: bootstrap from the best next action.
                target = r + gamma * max(Q[s_p][a_p] for a_p in aa_p)

            Q[s][a] = (1 - alpha) * Q[s][a] + alpha * target

    print(Q)
    # Greedy policy with respect to the learned values.
    Pi = {s: max(q_s, key=q_s.get) for s, q_s in Q.items()}

    return Pi, Q

In [20]:
# Train on LineWorld with the default hyperparameters; the cell displays the
# returned (policy, Q-table) tuple.
naive_q_learning(LineWorld)

100%|██████████| 100000/100000 [00:07<00:00, 12599.53it/s]

{2: {0: 0.9970029989999992, 1: 0.9989999999999996}, 1: {0: -0.9999999999999878, 1: 0.9980009999999991}, 3: {0: 0.9980009999999991, 1: 0.9999999999999994}}





({2: 1, 1: 1, 3: 1},
 {2: {0: 0.9970029989999992, 1: 0.9989999999999996},
  1: {0: -0.9999999999999878, 1: 0.9980009999999991},
  3: {0: 0.9980009999999991, 1: 0.9999999999999994}})