In [68]:
import numpy as np

class Environment:
    def __init__(self):
        self.num_states = 6
        self.num_actions = 4
        self.reward_table = np.array([
            [1, 1, 100],
            [1, 1, 1,],
        ])
        self.location=np.array([
            [1,2,6],
            [3,4,5]
        ])
        self.actions=np.array(["up","right","down","left"])

    def reset(self):
        self.state = np.random.randint(1, self.num_states+1)
        return self.state

    def step(self, action):
        r,c=np.argwhere(self.location==self.state)[0]
        next_state = self.get_next_location(r,c,action)
        reward = self.reward_table[r, c]
        done = (next_state == 6)
        self.state = next_state
        return next_state, reward, done
    
    def _step(self,state,action):
        r,c=np.argwhere(self.location==state)[0]
        next_state = self.get_next_location(r,c,action)
        return next_state

    def get_next_location(self,current_row_index, current_column_index, action_index):
        new_row_index = current_row_index
        new_column_index = current_column_index
        if self.actions[action_index] == 'up' and current_row_index > 0:
          new_row_index -= 1
        elif self.actions[action_index] == 'right' and current_column_index < 2:
          new_column_index += 1
        elif self.actions[action_index] == 'down' and current_row_index < 1:
          new_row_index += 1
        elif self.actions[action_index] == 'left' and current_column_index > 0:
          new_column_index -= 1
        return self.location[new_row_index, new_column_index]

class QLearning:
    def __init__(self, env, num_episodes, gamma=0.8, alpha=0.5, epsilon=0.1):
        self.env = env
        self.num_episodes = num_episodes
        self.gamma = gamma
        self.prev = None
        self.epsilon = epsilon
        self.q_table = np.zeros((env.num_states, env.num_actions))

    def choose_action(self, state,epsilon=None):
        if epsilon==None:
          epsilon=self.epsilon
        if np.random.random() < epsilon:
          # print("SP")
          return np.argmax(self.q_table[state-1, :])
        else:
          return np.random.randint(0, self.env.num_actions)

    def train(self):
        for episode in range(self.num_episodes):
            state = self.env.reset()
            done = False
            
            while not done:
                action = self.choose_action(state)
                # print(state," : ",action)
                next_state, reward, done = self.env.step(action)
                # r,c=np.argwhere(self.env.location==state)[0]
                self.q_table[state-1, action] = reward + self.gamma * np.max(self.q_table[next_state-1, :])
                state = next_state
        return self.q_table

    def get_shortest_path(self,state):
        #return immediately if this is an invalid starting location
        if state==6:
          return []
        else: #if this is a 'legal' starting location
          shortest_path = []
          shortest_path.append([state])
          cnt=0
          #continue moving along the path until we reach the goal (i.e., the item packaging location)
          while not state==6:
            cnt+=1
            #get the best action to take
            action = self.choose_action(state, 1.)
            print(state," : ",action)
            #move to the next location on the path, and add the new location to the list
            state= self.env._step(state,action)
            # print("NS",next_state)
            shortest_path.append([state])
          return shortest_path


In [72]:
env = Environment()
q_learning = QLearning(env, num_episodes=100)
q_table = q_learning.train()
print(q_table)
q_learning.get_shortest_path(5)

[[220.32216525 274.15270656 160.10470656 220.32216525]
 [228.73888    353.35270656 198.8808832  220.32216525]
 [198.8808832  198.8808832  160.10470656 160.10470656]
 [247.351104   274.15270656 198.8808832  160.10470656]
 [353.35270656 247.351104   247.351104   198.8808832 ]
 [440.4408832  425.551104   327.73888    346.351104  ]]
5  :  0


[[5], [6]]