In [1]:
import numpy as np

# Grid-world environment: a square board and a fixed set of moves
grid_size: int = 5    # the board is grid_size x grid_size cells
num_actions: int = 4  # one action per direction: up, down, left, right

In [2]:
# Q-learning hyperparameters
num_episodes: int = 1_000  # number of training episodes to run
alpha: float = 0.5         # learning rate: step size applied to each Q-value correction
gamma: float = 0.9         # discount factor applied to the best next-state Q-value

In [3]:
# Action-value table indexed as Q[row, column, action]; every entry starts at 0.0
Q = np.zeros(shape=(grid_size, grid_size, num_actions), dtype=float)

In [4]:
# Define reward function
def get_reward(state, action):
    """Return the reward associated with ``state``.

    Parameters
    ----------
    state : tuple
        ``(row, column)`` position to score.
    action : int
        Accepted for interface symmetry; it does not influence the result.

    Returns
    -------
    int
        ``10`` when ``state`` is the bottom-right corner of the grid
        (``(grid_size - 1, grid_size - 1)``), ``-1`` for any other state.
    """
    goal_corner = (grid_size - 1, grid_size - 1)
    return 10 if state == goal_corner else -1

In [5]:
# Q-learning algorithm (tabular, epsilon-greedy exploration)
epsilon = 0.1    # probability of taking a uniformly random exploratory action
random_seed = 0  # fixed seed so a fresh "Restart & Run All" reproduces the same Q-table
rng = np.random.default_rng(random_seed)

# (row offset, column offset) for each action index: 0=Up, 1=Down, 2=Left, 3=Right
action_moves = ((-1, 0), (1, 0), (0, -1), (0, 1))
goal_state = (grid_size - 1, grid_size - 1)

for episode in range(num_episodes):
    # Every episode starts in the top-left corner
    state = (0, 0)

    # An episode ends as soon as the agent reaches the bottom-right corner
    while state != goal_state:
        # Epsilon-greedy policy: explore with probability epsilon, otherwise
        # take the action with the highest current Q-value (ties -> lowest index)
        if rng.random() < epsilon:
            action = int(rng.integers(num_actions))
        else:
            action = int(np.argmax(Q[state]))

        # Apply the move; clamping to [0, grid_size - 1] means a move into a
        # wall leaves the agent where it is
        row_step, col_step = action_moves[action]
        next_state = (
            min(max(state[0] + row_step, 0), grid_size - 1),
            min(max(state[1] + col_step, 0), grid_size - 1),
        )

        # The reward is determined by the state the agent lands in
        reward = get_reward(next_state, action)

        # Q-learning update: move Q(s, a) towards r + gamma * max_a' Q(s', a')
        td_target = reward + gamma * np.max(Q[next_state])
        Q[state][action] += alpha * (td_target - Q[state][action])

        # Move to the next state
        state = next_state

In [6]:
# Report the learned Q-values, one state at a time in row-major order
print("Learned Q-values:")
action_labels = ("Up", "Down", "Left", "Right")
for row, col in np.ndindex(grid_size, grid_size):
    print(f"State ({row}, {col}):")
    # Labels are listed in action-index order (0..3)
    for action_index, label in enumerate(action_labels):
        print(f"  {label}:", Q[row][col][action_index])

Learned Q-values:
State (0, 0):
  Up: -1.3906559445871287
  Down: -0.4343823217835996
  Left: -1.3906561125547081
  Right: -0.4340620000000006
State (0, 1):
  Up: -0.4340620217974215
  Down: 0.6288198867810769
  Left: -1.3906558036636616
  Right: 0.6288199999999993
State (0, 2):
  Up: 0.6286252312088005
  Down: 1.809799993466446
  Left: -0.43406203496260853
  Right: 1.8097999999999992
State (0, 3):
  Up: 1.809688161086695
  Down: 3.121999999999999
  Left: 0.6288194130779767
  Right: 3.1219996528546163
State (0, 4):
  Up: 2.551164488115844
  Down: 4.579999999999998
  Left: -1.4825
  Right: -1.42625
State (1, 0):
  Up: -2.987289439453125
  Down: -1.4691596081542966
  Left: -3.2004000117187497
  Right: 0.6288120452041339
State (1, 1):
  Up: -3.0633219887695313
  Down: -1.1974857421874998
  Left: -1.003772959057164
  Right: 1.8097999999999121
State (1, 2):
  Up: 0.26195384765624935
  Down: 3.121999999999999
  Left: -0.03789314098091512
  Right: 2.963432617187499
State (1, 3):
  Up: 1.80979