In [1]:
import numpy as np
import random

In [2]:
# Grid-world layout and Q-learning hyperparameters
GRID_SIZE = 5  # The grid has GRID_SIZE rows and GRID_SIZE columns
ACTIONS = [(-1, 0), (1, 0), (0, -1), (0, 1)]  # (row, col) offsets: Up, Down, Left, Right
GAMMA = 0.9  # Discount factor
ALPHA = 0.5  # Learning rate
EPISODES = 500  # Number of episodes
REWARD_STEP = -1  # Reward for any move that does not land on the goal
REWARD_GOAL = 10  # Reward for the move that lands on the goal
# Derived from GRID_SIZE so the goal is always the bottom-right cell; a
# hard-coded (4, 4) would fall outside the grid (and never be reached) if
# GRID_SIZE were lowered.
GOAL_STATE = (GRID_SIZE - 1, GRID_SIZE - 1)

In [3]:
def get_next_state(state, action, grid_size=None):
    """Return the cell reached by applying ``action`` to ``state``.

    Args:
        state: Current cell as a ``(row, col)`` pair.
        action: Move as a ``(row_delta, col_delta)`` pair.
        grid_size: Side length of the square grid. Defaults to the
            module-level ``GRID_SIZE`` when omitted.

    Returns:
        The new ``(row, col)`` tuple if it lies inside the grid; otherwise
        ``state`` itself, unchanged (the move is blocked by the wall).
    """
    size = GRID_SIZE if grid_size is None else grid_size
    row, col = state
    new_row, new_col = row + action[0], col + action[1]
    if 0 <= new_row < size and 0 <= new_col < size:
        return (new_row, new_col)
    return state  # Move would leave the grid: stay in place

# Q-table: one action value per (row, column, action index), all starting at zero
Q = np.zeros(shape=(GRID_SIZE, GRID_SIZE, len(ACTIONS)), dtype=float)

In [4]:
def q_learning(episodes=None, alpha=None, gamma=None, seed=None):
    """Train the module-level Q-table with tabular Q-learning.

    Each episode starts in cell (0, 0) and takes uniformly random actions
    until GOAL_STATE is reached; there is no step limit. The table ``Q`` is
    updated in place and is not reset, so calling this function again
    continues training from the current values.

    Args:
        episodes: Number of episodes to run. Defaults to ``EPISODES``.
        alpha: Learning rate. Defaults to ``ALPHA``.
        gamma: Discount factor. Defaults to ``GAMMA``.
        seed: Optional seed. When given, actions are drawn from a private
            ``random.Random(seed)`` so the run is reproducible and the
            global ``random`` state is left untouched. When omitted, the
            shared ``random`` module generator is used.

    Returns:
        The module-level ``Q`` array (the same object that is updated in
        place).
    """
    n_episodes = EPISODES if episodes is None else episodes
    alpha = ALPHA if alpha is None else alpha
    gamma = GAMMA if gamma is None else gamma
    rng = random if seed is None else random.Random(seed)
    n_actions = len(ACTIONS)

    for _ in range(n_episodes):
        state = (0, 0)  # Start state
        while state != GOAL_STATE:
            # Pure exploration: every action is equally likely
            action_idx = rng.randrange(n_actions)
            next_state = get_next_state(state, ACTIONS[action_idx])
            reward = REWARD_GOAL if next_state == GOAL_STATE else REWARD_STEP

            # Q-learning update: move Q(s, a) towards r + gamma * max_a' Q(s', a').
            # The goal cell's values are never updated, so they stay at their
            # initial value when bootstrapping from it.
            row, col = state
            best_next_q = np.max(Q[next_state[0], next_state[1]])
            td_error = reward + gamma * best_next_q - Q[row, col, action_idx]
            Q[row, col, action_idx] += alpha * td_error

            state = next_state  # Move to the next state

    return Q

In [5]:
# Train the agent; q_learning() fills the Q-table in place
q_learning()

# Show the learned action values for every cell, row by row
for i, j in np.ndindex(GRID_SIZE, GRID_SIZE):
    print(f"State ({i}, {j}): {Q[i, j]}")

State (0, 0): [-1.3906558 -0.434062  -1.3906558 -0.434062 ]
State (0, 1): [-0.434062   0.62882   -1.3906558  0.62882  ]
State (0, 2): [ 0.62882   1.8098   -0.434062  1.8098  ]
State (0, 3): [1.8098  3.122   0.62882 3.122  ]
State (0, 4): [3.122  4.58   1.8098 3.122 ]
State (1, 0): [-1.3906558  0.62882   -0.434062   0.62882  ]
State (1, 1): [-0.434062  1.8098   -0.434062  1.8098  ]
State (1, 2): [0.62882 3.122   0.62882 3.122  ]
State (1, 3): [1.8098 4.58   1.8098 4.58  ]
State (1, 4): [3.122 6.2   3.122 4.58 ]
State (2, 0): [-0.434062  1.8098    0.62882   1.8098  ]
State (2, 1): [0.62882 3.122   0.62882 3.122  ]
State (2, 2): [1.8098 4.58   1.8098 4.58  ]
State (2, 3): [3.122 6.2   3.122 6.2  ]
State (2, 4): [4.58 8.   4.58 6.2 ]
State (3, 0): [0.62882 3.122   1.8098  3.122  ]
State (3, 1): [1.8098 4.58   1.8098 4.58  ]
State (3, 2): [3.122 6.2   3.122 6.2  ]
State (3, 3): [4.58 8.   4.58 8.  ]
State (3, 4): [ 6.2 10.   6.2  8. ]
State (4, 0): [1.8098 3.122  3.122  4.58  ]
State (4, 1)