In [1]:
import random


class SimpleTaxiEnv():
    """A small deterministic taxi environment on a 2x2 grid.

    ### Map: D -> destination, R -> origin
        +---+
        |D: |
        | :R|
        +---+

    Cell numbering used for taxi and passenger positions:
        +---+
        |0:1|
        |2:3|
        +---+

    ### Symbols used by render()
        #   -> taxi
        #<$ -> passenger in (taxi)
        #>$ -> passenger out (of taxi)

    ### Actions
        There are 6 discrete deterministic actions:
        - 0: move south
        - 1: move north
        - 2: move east
        - 3: move west
        - 4: pickup passenger
        - 5: drop off passenger

    ### Observations
        There are 12 discrete states since there are 4 taxi positions, 3 possible
        locations of the passenger (including the case when the passenger is in the
        taxi), and 1 destination location:
        - 0..3  : passenger waits at cell 0 (D), taxi is at cell 0, 1, 2, 3.
        - 4..7  : passenger waits at cell 3 (R), taxi is at cell 0, 1, 2, 3.
        - 8..11 : passenger is in the taxi, taxi is at cell 0, 3, 1, 2
                  (note the non-sequential cell order).

    ### Rewards
        - -1 per step unless other reward is triggered.
        - -10  executing "pickup" and "drop-off" actions illegally.
        - +200 delivering passenger.
        - +10  executing "pickup" legally.

        Two penalised actions still change the state: picking up at state 0
        (passenger already at the destination) moves to state 8, and dropping
        off at state 9 (back at the origin) moves to state 7.

    ### Termination
        The episode ends when the passenger is dropped off at the destination
        (drop-off in state 8). Merely starting in states 0..3 does not end it.

    """
    state = None  # current state id; None until reset() is called
    states = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]  # All possible states
    rewards = [-1, -10, 200, 10]  # [step, illegal pickup/drop-off, delivery, legal pickup]
    actions = [0, 1, 2, 3, 4, 5]
    # Episode-finished flag. The misspelled name is kept because it is a public
    # attribute that external code may already read.
    ternimate = False

    # Movement actions: _MOVE_TARGETS[action][state] -> next state.
    # Moving into a wall leaves the state unchanged.
    _MOVE_TARGETS = {
        0: dict(enumerate((2, 3, 2, 3, 6, 7, 6, 7, 11, 9, 9, 11))),   # south
        1: dict(enumerate((0, 1, 0, 1, 4, 5, 4, 5, 8, 10, 10, 8))),   # north
        2: dict(enumerate((1, 1, 3, 3, 5, 5, 7, 7, 10, 9, 10, 9))),   # east
        3: dict(enumerate((0, 0, 2, 2, 4, 4, 6, 6, 8, 11, 8, 11))),   # west
    }

    # Pickup / drop-off transitions that differ from the default
    # "state unchanged, illegal-action penalty":
    # (action, state) -> (next state, index into `rewards`, episode finished)
    _PASSENGER_TRANSITIONS = {
        (4, 0): (8, 1, False),  # pickup at the destination: allowed but penalised
        (4, 7): (9, 3, False),  # legal pickup at the origin
        (5, 8): (0, 2, True),   # passenger is delivered
        (5, 9): (7, 1, False),  # drop-off back at the origin: allowed but penalised
    }

    _BORDER = '+-----------+'

    # state -> (caption, top grid row, bottom grid row) printed by render().
    _RENDERINGS = {
        0: ('State 0 (taxi is in 0, passenger is in 0, passenger is out of taxi)',
            '|#>$D :     |',
            '|     :    R|'),
        1: ('State 1 (taxi is in 1, passenger is in 0)',
            '|$D   :    #|',
            '|     :    R|'),
        2: ('State 2 (taxi is in 2, passenger is in 0)',
            '|$D   :     |',
            '|#    :    R|'),
        3: ('State 3 (taxi is in 3, passenger is in 0)',
            '|$D   :     |',
            '|     :   #R|'),
        4: ('State 4 (taxi is in 0, passenger is in 3)',
            '|#D   :     |',
            '|     :   $R|'),
        5: ('State 5 (taxi is in 1, passenger is in 3)',
            '|D    :    #|',
            '|     :   $R|'),
        6: ('State 6 (taxi is in 2, passenger is in 3)',
            '|D    :     |',
            '|#    :   $R|'),
        7: ('State 7 (taxi is in 3, passenger is in 3, passenger is out of taxi)',
            '|D    :     |',
            '|     : #>$R|'),
        8: ('State 8 (taxi is in 0, passenger is in 0, passenger is in taxi)',
            '|#<$D :     |',
            '|     :    R|'),
        9: ('State 9 (taxi is in 3, passenger is in 3, passenger is in taxi)',
            '|D    :     |',
            '|     : #<$R|'),
        10: ('State 10 (taxi is in 1, passenger is in 1, passenger is in taxi)',
             '|D    :  #<$|',
             '|     :    R|'),
        11: ('State 11 (taxi is in 2, passenger is in 2, passenger is in taxi)',
             '|D    :     |',
             '|#<$  :    R|'),
    }

    def is_terminate(self):
        '''
        Return True once the passenger has been delivered, i.e. after a
        drop-off from state 8. The flag is cleared again by reset().
        '''
        return self.ternimate

    def get_actions(self):
        """Return the list of valid action ids."""
        return self.actions

    def get_action_sample(self):
        """Return one action id chosen uniformly at random."""
        return random.choice(self.actions)

    def get_states(self):
        """Return the list of valid state ids."""
        return self.states

    def get_current_state(self):
        """Return the current state id (None before the first reset)."""
        return self.state

    def encode(self, taxi_row, taxi_col):
        """Set the environment to state 0 and return it.

        NOTE(review): both arguments are currently ignored, so this looks like
        an unfinished stub - confirm the intended (row, col) -> state mapping
        before relying on it.
        """
        self.state = 0
        return self.state

    def reset(self, state: int = None):
        """Start a new episode and return the initial state.

        Args:
            state: state id to start from; a random valid state when None.

        Raises:
            ValueError: if `state` is given but is not one of `states`.
        """
        if state is None:
            self.state = random.choice(self.states)
        elif state not in self.states:
            # Fail here instead of leaving the environment in a state that
            # step() and render() cannot handle.
            raise ValueError(
                f'state must be one of {self.states} but you sent {state}')
        else:
            self.state = state
        self.ternimate = False
        return self.state

    def step(self, action: int):
        """Apply `action` and return `(next_state, reward)`.

        - 0: move south
        - 1: move north
        - 2: move east
        - 3: move west
        - 4: pickup passenger
        - 5: drop off passenger

        Raises:
            TypeError: if `action` is not an int.
            ValueError: if `action` is not in 0..5, or if the environment has
                no valid current state (reset() was never called).
        """
        if not isinstance(action, int):
            raise TypeError(
                f'action must be an integer (0, 1, 2, 3, 4, 5) but you sent {type(action)} ({action})')
        if action not in [0, 1, 2, 3, 4, 5]:
            raise ValueError(
                f'action must be an integer (0, 1, 2, 3, 4, 5) but you sent {type(action)} ({action})')
        # Without these checks the method would fall through and return None,
        # which surfaces in callers as an opaque tuple-unpacking error.
        if self.state is None:
            raise ValueError(
                "State of the environment is not determined! you Should call reset first")
        if self.state not in self.states:
            raise ValueError(
                f'environment is in an unknown state ({self.state}); call reset with one of {self.states}')

        if action in self._MOVE_TARGETS:
            self.state = self._MOVE_TARGETS[action][self.state]
            return self.state, self.rewards[0]

        # Pickup / drop-off: default is "nothing happens, illegal-action penalty".
        next_state, reward_index, delivered = self._PASSENGER_TRANSITIONS.get(
            (action, self.state), (self.state, 1, False))
        self.state = next_state
        if delivered:
            self.ternimate = True
        return self.state, self.rewards[reward_index]

    def render(self):
        """Print a caption and an ASCII drawing of the current state.

        Raises:
            ValueError: if reset() has not been called yet.
        """
        if self.state is None:
            raise ValueError(
                "State of the environment is not determined! you Should call reset first")

        if self.state not in self.states:
            # Unknown states print nothing.
            return

        caption, top_row, bottom_row = self._RENDERINGS[self.state]
        print(caption)
        print(self._BORDER)
        print(top_row)
        print(bottom_row)
        print(self._BORDER)


In [2]:
from time import sleep
from IPython.display import clear_output

import random
import numpy as np


tx = SimpleTaxiEnv()

# Tabular action-value estimates: one row per state, one column per action.
q_table = np.zeros([len(tx.get_states()), len(tx.get_actions())])
exploration = 0.0001     # epsilon: probability of taking a random action
learning_rate = 0.1      # alpha: step size of the Q-learning update
discount_factor = 0.6    # gamma: weight given to future rewards


epoches = 10             # training episodes started from each initial state
states = tx.get_states()

# Start episodes from every possible state so each row of the table is visited.
for s in states:
    for _ in range(epoches):
        state = tx.reset(s)
        while not tx.is_terminate():
            # Epsilon-greedy action selection.
            if random.uniform(0, 1) < exploration:
                action = tx.get_action_sample()
            else:
                # np.argmax breaks ties in favour of the lowest action index.
                # int() is required because step() only accepts built-in ints.
                action = int(np.argmax(q_table[state]))

            next_state, reward = tx.step(action)

            prev_q = q_table[state, action]
            if tx.is_terminate():
                # The episode is over: there is no future return to bootstrap
                # from, so the target is just the immediate reward.
                next_max_q = 0.0
            else:
                next_max_q = np.max(q_table[next_state])
            new_q = prev_q + learning_rate * \
                (reward + discount_factor * next_max_q - prev_q)

            q_table[state, action] = new_q
            state = next_state

In [3]:
from time import sleep
from IPython.display import clear_output

# Upper bound on steps per demo trip. The loop below follows the learned greedy
# policy; if the Q-table is under-trained that policy can cycle forever, which
# would hang the notebook (one second of sleep per step).
MAX_DEMO_STEPS = 50

# Replay one trip from every start state using the trained q_table.
for tripnum in range(12):
    num = 0  # number of actions taken on this trip
    print(f'trip number {tripnum}')
    taxi = SimpleTaxiEnv()
    state = taxi.reset(tripnum)
    taxi.render()
    sleep(2)
    print('Solving ...')
    while not taxi.is_terminate() and num < MAX_DEMO_STEPS:
        # Greedy action; ties between equally valued actions are broken at random.
        best_actions = np.where(q_table[state] == np.amax(q_table[state]))[0]
        action = int(random.choice(best_actions))
        print(action)
        sleep(1)
        next_state, reward = taxi.step(action)
        num += 1
        # Redraw the board in place so the trip reads as an animation.
        clear_output(wait=True)
        print(f'trip number {tripnum}')
        taxi.render()
        state = next_state
    if taxi.is_terminate():
        print('Problem solved')
    else:
        print(f'Gave up after {MAX_DEMO_STEPS} steps without delivering the passenger')
    sleep(2)
    clear_output(wait=True)
    print('Num=',num)

Num= 2
