In [73]:
ReloadProject('deep_learning')

notebook_init.py imported and reloaded
forwarded symbol: Activation
forwarded symbol: Dense
forwarded symbol: Dict
forwarded symbol: List
forwarded symbol: Sequential
forwarded symbol: Tuple
reloaded: gym
forwarded symbol: gym
reloaded: keras
forwarded symbol: keras
reloaded: q_learning
forwarded symbol: q_learning
reloaded: q_learning_impl
forwarded symbol: q_learning_impl


## Example problem 1
Let's assume a world with 11 states: 0-10. Each time the agent and move +1 or -1, with 0-1 -> 10 and 10+1 -> 0. All actions that gets the agent closer to state "5" gets reward +1, otherwise gets reward -1.

In [38]:
TARGET_STATE = 5

class IntState(q_learning_impl.HashableState):
    
    def __init__(self, state: int):
        self.value = state
        
    def __hash__(self):
        return hash(self.value)
        
    def __str__(self):
        return '%d' % self.value


class IntAction(q_learning_impl.HashableAction):
    
    def __init__(self, action: int):
        self.value = action
        
    def __hash__(self):
        return hash(self.value)
        
    def __str__(self):
        return '%d' % self.value


class CircularWorld(q_learning.Environment):
    
    def __init__(self):
        super().__init__()
        self._current_state = IntState(0)
        self._last_action = IntAction(0)
        self._last_reward = 0.0
        
        self._action_plus = IntAction(1)
        self._action_minus = IntAction(-1)
        
    def Print(self):
        print('At: %s (last action: %s; last reward: %s)' % (
            self._current_state, self._last_action, self._last_reward))
    
    #@ Override
    def GetActionSpace(self) -> List[IntAction]:
        return [self._action_plus, self._action_minus]
        
    #@ Override
    def TakeAction(self, action: IntAction) -> None:
        if action == self._action_plus:
            if self._current_state.value < TARGET_STATE:
                self._last_reward = 1.0
            else:
                self._last_reward = -1.0
            
            new_state = self._current_state.value + 1
            if new_state == 11:
                new_state = 0
            self._current_state = IntState(new_state)
        else:
            if self._current_state.value > TARGET_STATE:
                self._last_reward = 1.0
            else:
                self._last_reward = -1.0
            
            new_state = self._current_state.value - 1
            if new_state == -1:
                new_state = 10
            self._current_state = IntState(new_state)
        self._last_action = action

Let's try out the environment.

In [27]:
env = CircularWorld()
for _ in range(20):
    env.Print()
    env.TakeAction(np.random.choice(env.GetActionSpace()))

At: 0 (last action: 0; last reward: 0.0)
At: 1 (last action: 1; last reward: 1.0)
At: 0 (last action: -1; last reward: -1.0)
At: 1 (last action: 1; last reward: 1.0)
At: 2 (last action: 1; last reward: 1.0)
At: 3 (last action: 1; last reward: 1.0)
At: 2 (last action: -1; last reward: -1.0)
At: 1 (last action: -1; last reward: -1.0)
At: 0 (last action: -1; last reward: -1.0)
At: 10 (last action: -1; last reward: -1.0)
At: 9 (last action: -1; last reward: 1.0)
At: 10 (last action: 1; last reward: -1.0)
At: 9 (last action: -1; last reward: 1.0)
At: 8 (last action: -1; last reward: 1.0)
At: 9 (last action: 1; last reward: -1.0)
At: 8 (last action: -1; last reward: 1.0)
At: 9 (last action: 1; last reward: -1.0)
At: 8 (last action: -1; last reward: 1.0)
At: 9 (last action: 1; last reward: -1.0)
At: 8 (last action: -1; last reward: 1.0)


Looks good! Now try out the learner.

In [81]:
%%time

env = CircularWorld()
qfunc = q_learning_impl.FiniteStateQFunction()
qfunc.SetLearningRate(0.9)
qfunc.SetDiscountFactor(0.9)
max_policy = q_learning_impl.MaxValuePolicy()

for _ in range(20000):
#     env.Print()
#     qfunc.Print()
    s = env.GetCurrentState()
    a = max_policy.Decide(qfunc, s, env.GetActionSpace())
    env.TakeAction(a)
    s_new = env.GetCurrentState()
    qfunc.UpdateWithTransition(s, a, env.GetLastReward(), s_new, env.GetActionSpace())
    
# qfunc.Print()

CPU times: user 420 ms, sys: 0 ns, total: 420 ms
Wall time: 499 ms


In [82]:
for state in range(10):
    for action in (-1, 1):
        print('(%d, %d): %s' % (state, action, qfunc.GetValue(IntState(state), IntAction(action))))

(0, -1): 0.0
(0, 1): 1.719
(1, -1): 0.0
(1, 1): 1.719
(2, -1): 0.0
(2, 1): 1.719
(3, -1): 0.0
(3, 1): 1.719
(4, -1): 0.0
(4, 1): 0.5263157894736845
(5, -1): -0.5263157894736838
(5, 1): -0.9
(6, -1): 0.0
(6, 1): -0.9
(7, -1): 0.0
(7, 1): -0.9
(8, -1): 0.0
(8, 1): -0.9
(9, -1): 0.0
(9, 1): -0.9
