In [1]:
import random

In [2]:
class QLearning:
    """Tabular Q-learning agent on a rectangular grid world.

    States are ``(row, col)`` tuples with ``0 <= row < rows`` and
    ``0 <= col < cols``. Every state has four actions (``'UP'``, ``'DOWN'``,
    ``'LEFT'``, ``'RIGHT'``); a move that would leave the grid keeps the agent
    in place. Entering the goal state yields reward 10 and ends the episode,
    entering the penalty state yields -1, every other move yields 0.

    Args:
        rows: Number of grid rows.
        cols: Number of grid columns.
        start_state: Cell in which every training episode begins.
        goal_state: Terminal cell.
        penalty_state: Cell that yields the negative reward, or ``None`` to
            have no penalty cell.
        epsilon: Probability of picking a uniformly random action instead of
            the greedy one.
        alpha: Learning rate of the Q-value update.
        gamma: Discount factor applied to the best next-state Q-value.
        seed: If given, exploration uses a private ``random.Random(seed)`` so
            the run is reproducible. If ``None``, the module-level ``random``
            functions are used, so a caller's ``random.seed()`` still applies.
    """

    def __init__(self, rows, cols, start_state=(0, 0), goal_state=(2, 2),
                 penalty_state=(1, 1), epsilon=0.1, alpha=0.1, gamma=0.9,
                 seed=None):
        self.rows = rows
        self.cols = cols
        # Normalise to tuples so equality checks against (row, col) states
        # also work when a caller passes a list.
        self.start_state = tuple(start_state)
        self.goal_state = tuple(goal_state)
        self.penalty_state = None if penalty_state is None else tuple(penalty_state)
        self.epsilon = epsilon
        self.alpha = alpha
        self.gamma = gamma
        # The `random` module itself exposes random() and choice(), so it can
        # stand in for a Random instance when no seed is requested.
        self._rng = random if seed is None else random.Random(seed)
        self.actions = ['UP', 'DOWN', 'LEFT', 'RIGHT']
        # q_table[state][action] -> estimated return, initialised to 0.0.
        self.q_table = {
            (row, col): {action: 0.0 for action in self.actions}
            for row in range(rows)
            for col in range(cols)
        }

    def get_next_state(self, state, action):
        """Return the state reached by taking ``action`` in ``state``.

        Moves are clamped at the grid border, so walking into a wall returns
        the same cell. An unrecognised action also returns ``state`` unchanged.
        """
        row, col = state
        if action == 'UP':
            return max(row - 1, 0), col
        elif action == 'DOWN':
            return min(row + 1, self.rows - 1), col
        elif action == 'LEFT':
            return row, max(col - 1, 0)
        elif action == 'RIGHT':
            return row, min(col + 1, self.cols - 1)
        else:
            return state

    def is_terminal(self, state):
        """Return True if ``state`` is the goal state."""
        return state == self.goal_state

    def is_penalty(self, state):
        """Return True if ``state`` is the penalty state."""
        return state == self.penalty_state

    def getReward(self, state):
        """Return the reward for entering ``state``: 10 (goal), -1 (penalty) or 0."""
        if self.is_terminal(state):
            return 10
        if self.is_penalty(state):
            return -1
        return 0

    def get_reward(self, state):
        """snake_case alias of :meth:`getReward`, matching the other method names."""
        return self.getReward(state)

    def choose_action(self, state):
        """Pick an action for ``state`` with an epsilon-greedy rule.

        With probability ``epsilon`` a uniformly random action is returned;
        otherwise the greedy action from :meth:`get_best_action`.
        """
        if self._rng.random() < self.epsilon:
            return self._rng.choice(self.actions)
        return self.get_best_action(state)

    def update_q_value(self, state, action, reward, next_state):
        """Apply one Q-learning update to ``q_table[state][action]``.

        new = old + alpha * (reward + gamma * max_a' Q(next_state, a') - old)
        """
        old_q = self.q_table[state][action]
        max_next_q = max(self.q_table[next_state].values())
        self.q_table[state][action] = old_q + self.alpha * (
            reward + self.gamma * max_next_q - old_q
        )

    def get_best_action(self, state):
        """Return the action with the highest Q-value in ``state``.

        Ties resolve to the earliest action in ``self.actions`` order, because
        ``max`` returns the first maximal key it encounters; a state whose
        Q-values are all equal therefore reports ``'UP'``.
        """
        action_values = self.q_table[state]
        return max(action_values, key=action_values.get)

    def _check_states_in_grid(self):
        """Raise ValueError if the start or goal state is not a grid cell."""
        for label, state in (('start_state', self.start_state),
                             ('goal_state', self.goal_state)):
            if state not in self.q_table:
                raise ValueError(
                    f"{label} {state!r} is outside the "
                    f"{self.rows}x{self.cols} grid"
                )

    def train(self, episodes):
        """Run ``episodes`` training episodes from ``start_state`` to the goal.

        Raises:
            ValueError: If the start or goal state lies outside the grid. An
                out-of-grid goal can never be reached, so without this check
                the episode loop would never terminate.
        """
        self._check_states_in_grid()
        for _ in range(episodes):
            state = self.start_state
            while not self.is_terminal(state):
                action = self.choose_action(state)
                next_state = self.get_next_state(state, action)
                reward = self.getReward(next_state)

                # Moves blocked by a wall leave the agent in place and are
                # deliberately not learned from, so their Q-values keep the
                # initial 0.0. (`state` is never terminal inside this loop,
                # so a move into the goal always satisfies this condition.)
                if next_state != state:
                    self.update_q_value(state, action, reward, next_state)

                state = next_state

    def print_optimal_policy(self):
        """Print the greedy action for every state, row by row; the goal prints as 'Goal'."""
        for row in range(self.rows):
            for col in range(self.cols):
                state = (row, col)
                if self.is_terminal(state):
                    print(f"State {state}: Goal")
                else:
                    best_action = self.get_best_action(state)
                    print(f"State {state}: Best action = {best_action}")

    def print_q_table(self):
        """Print the action -> Q-value mapping of every state, row by row."""
        for row in range(self.rows):
            for col in range(self.cols):
                state = (row, col)
                print(f"State {state}: {self.q_table[state]}")

In [3]:
# Seed the global RNG that the agent's epsilon-greedy exploration draws from,
# so the Q-table and policy printed below are the same on every
# Restart & Run All.
random.seed(42)

# 1000 x 5 grid, 100 training episodes. The goal (2, 2) sits close to the
# start (0, 0), so states far from that corner are likely never visited and
# keep their initial all-zero Q-values.
q = QLearning(1000, 5)
q.train(100)

# Both printers emit one line per grid cell (rows * cols lines each).
q.print_q_table()
q.print_optimal_policy()

State (0, 0): {'UP': 0.0, 'DOWN': 0.22162868962586793, 'LEFT': 0.0, 'RIGHT': 7.249531006152246}
State (0, 1): {'UP': 0.0, 'DOWN': -0.9936373145588641, 'LEFT': 0.6284048185713833, 'RIGHT': 8.09072734626781}
State (0, 2): {'UP': 0.0, 'DOWN': 8.998135371126468, 'LEFT': 2.3390911714473033, 'RIGHT': 0.18979529897733277}
State (0, 3): {'UP': 0.0, 'DOWN': 0.0, 'LEFT': 3.213750443578949, 'RIGHT': 0.0}
State (0, 4): {'UP': 0.0, 'DOWN': 0.0, 'LEFT': 0.0, 'RIGHT': 0.0}
State (1, 0): {'UP': 1.7635080830047232, 'DOWN': 0.0, 'LEFT': 0.0, 'RIGHT': -0.19}
State (1, 1): {'UP': 0.6782031019321165, 'DOWN': 0.0, 'LEFT': 0.0, 'RIGHT': 0.0}
State (1, 2): {'UP': 3.7047263626581572, 'DOWN': 9.999734386011124, 'LEFT': -0.09993439, 'RIGHT': 0.018230899265317516}
State (1, 3): {'UP': 0.41642064749940677, 'DOWN': 0.0, 'LEFT': 0.0, 'RIGHT': 0.0}
State (1, 4): {'UP': 0.0, 'DOWN': 0.0, 'LEFT': 0.0, 'RIGHT': 0.0}
State (2, 0): {'UP': 0.0, 'DOWN': 0.0, 'LEFT': 0.0, 'RIGHT': 0.0}
State (2, 1): {'UP': -0.1, 'DOWN': 0.0,