In [81]:
ReloadProject('deep_learning')

notebook_init.py imported and reloaded
forwarded symbol: Activation
forwarded symbol: Dense
forwarded symbol: Dict
forwarded symbol: InputLayer
forwarded symbol: List
forwarded symbol: Model
forwarded symbol: Sequential
forwarded symbol: Tuple
reloaded: gym
forwarded symbol: gym
reloaded: keras
forwarded symbol: keras
reloaded: q_learning
forwarded symbol: q_learning
reloaded: q_learning_impl
forwarded symbol: q_learning_impl
reloaded: q_learning_impl_v2
forwarded symbol: q_learning_impl_v2
reloaded: q_learning_v2
forwarded symbol: q_learning_v2


## Environment Setup
Let's assume a world with 11 states: 0-10. Each time the agent and move +1 or -1, with 0-1 -> 10 and 10+1 -> 0. All actions that gets the agent closer to state "5" gets reward +1, otherwise gets reward -1.

In [2]:
STATE_ZERO_ARRAY = np.zeros(1, dtype=int)
TARGET_STATE = 5


class CircularWorld(q_learning_v2.Environment):
    
    def __init__(self):
        super().__init__(state_array_size=1, action_space_size=3)
                
        self.debug_verbosity = 0
        
        # action encoding
        self._action_minus = 0
        self._action_stay = 1
        self._action_plus = 2

        
    #@ Override
    def TakeAction(self, action: q_learning_v2.Action) -> q_learning_v2.Reward:
        current_state = self.GetState()
        new_state = current_state
        reward = 0
        if action == self._action_plus:
            if current_state < TARGET_STATE:
                reward = 1.0
            else:
                reward = -1.0
            new_state = current_state + 1
            if new_state == 11:
                new_state = STATE_ZERO_ARRAY
        elif action == self._action_minus:
            if current_state > TARGET_STATE:
                reward = 1.0
            else:
                reward = -1.0
            new_state = current_state - 1
            if new_state == -1:
                new_state = STATE_ZERO_ARRAY + 10
        else:
            if current_state != TARGET_STATE:
                reward = -1.0
            else:
                reward = 1.0

        self._protected_SetState(new_state)
        if self.debug_verbosity >= 1:
            print('Action %s: (%s) -> (%s), reward: %s' % (
                action, current_state, new_state, reward))
        return reward

Let's try out the environment.

In [3]:
env = CircularWorld()
env.debug_verbosity = 10
for _ in range(20):
    env.TakeAction(np.random.choice(env.GetActionSpace()))

Action 0: ([0.]) -> ([10]), reward: -1.0
Action 0: ([10]) -> ([9]), reward: 1.0
Action 0: ([9]) -> ([8]), reward: 1.0
Action 2: ([8]) -> ([9]), reward: -1.0
Action 1: ([9]) -> ([9]), reward: -1.0
Action 1: ([9]) -> ([9]), reward: -1.0
Action 2: ([9]) -> ([10]), reward: -1.0
Action 1: ([10]) -> ([10]), reward: -1.0
Action 2: ([10]) -> ([0]), reward: -1.0
Action 2: ([0]) -> ([1]), reward: 1.0
Action 2: ([1]) -> ([2]), reward: 1.0
Action 1: ([2]) -> ([2]), reward: -1.0
Action 2: ([2]) -> ([3]), reward: 1.0
Action 2: ([3]) -> ([4]), reward: 1.0
Action 2: ([4]) -> ([5]), reward: 1.0
Action 1: ([5]) -> ([5]), reward: 1.0
Action 1: ([5]) -> ([5]), reward: 1.0
Action 2: ([5]) -> ([6]), reward: -1.0
Action 2: ([6]) -> ([7]), reward: -1.0
Action 1: ([7]) -> ([7]), reward: -1.0


## Learning

In [82]:
%%time

env = CircularWorld()
qfunc = q_learning_impl_v2.KerasModelQFunction(
    env, (6, 6, 6), learning_rate=0.9, discount_factor=0.9)
policy = q_learning_impl_v2.MaxValueWithRandomnessPolicy(certainty = 0.95)

env.debug_verbosity = 5
qfunc.SetDebugVerbosity(5)
policy.debug_verbosity = 5

# First train qfunc with a random policy.
for _ in range(40):
    s = env.GetState()
    a = policy.Decide(qfunc, s, env.GetActionSpace())
    r = env.TakeAction(a)
    s_new = env.GetState()
    qfunc.UpdateWithTransition(s, a, r, s_new)
    
# Then see its action using a max policy.
for _ in range(0):
    s = env.GetState()
    a = policy.Decide(qfunc, s, env.GetActionSpace())
    r = env.TakeAction(a)
    s_new = env.GetState()
    qfunc.UpdateWithTransition(s, a, r, s_new)


GET: ([0.], 0) -> [[-0.05694097]]
GET: ([0.], 1) -> [[0.00800415]]
GET: ([0.], 2) -> [[-0.07984762]]
Action 1: ([0.]) -> ([0.]), reward: -1.0
GET: ([0.], 0) -> [[-0.05694097]]
GET: ([0.], 1) -> [[0.00800415]]
GET: ([0.], 2) -> [[-0.07984762]]
GET: ([0.], 1) -> [[0.00800415]]
SET: ([0.], 1) <- [[-0.89271617]]
GET: ([0.], 0) -> [[-0.09520315]]
GET: ([0.], 1) -> [[-0.04586127]]
GET: ([0.], 2) -> [[-0.13251716]]
Action 1: ([0.]) -> ([0.]), reward: -1.0
GET: ([0.], 0) -> [[-0.09520315]]
GET: ([0.], 1) -> [[-0.04586127]]
GET: ([0.], 2) -> [[-0.13251716]]
GET: ([0.], 1) -> [[-0.04586127]]
SET: ([0.], 1) <- [[-0.9417337]]
GET: ([0.], 0) -> [[-0.12424774]]
GET: ([0.], 1) -> [[-0.08418848]]
GET: ([0.], 2) -> [[-0.1849702]]
Action 1: ([0.]) -> ([0.]), reward: -1.0
GET: ([0.], 0) -> [[-0.12424774]]
GET: ([0.], 1) -> [[-0.08418848]]
GET: ([0.], 2) -> [[-0.1849702]]
GET: ([0.], 1) -> [[-0.08418848]]
SET: ([0.], 1) <- [[-0.97661155]]
GET: ([0.], 0) -> [[-0.15445906]]
GET: ([0.], 1) -> [[-0.1225706]]
