In [1]:
import numpy as np

# Grid world environment
grid_size = 5
num_actions = 4

In [2]:
# Q-learning parameters
num_episodes = 1000
alpha = 0.5
gamma = 0.9

In [3]:
# Initialize Q-table
Q = np.zeros((grid_size, grid_size, num_actions))

In [4]:
# Define reward function
def get_reward(state, action):
    if state == (grid_size - 1, grid_size - 1):
        return 10
    else:
        return -1

In [5]:
# Q-learning algorithm
for episode in range(num_episodes):
    # Initialize the starting state
    state = (0, 0)

    while state != (grid_size - 1, grid_size - 1):
        # Choose an action based on the epsilon-greedy policy
        if np.random.rand() < 0.1:
            action = np.random.randint(num_actions)
        else:
            action = np.argmax(Q[state])

        # Take the action and observe the next state and reward
        if action == 0:  # Up
            next_state = (max(state[0] - 1, 0), state[1])
        elif action == 1:  # Down
            next_state = (min(state[0] + 1, grid_size - 1), state[1])
        elif action == 2:  # Left
            next_state = (state[0], max(state[1] - 1, 0))
        elif action == 3:  # Right
            next_state = (state[0], min(state[1] + 1, grid_size - 1))

        reward = get_reward(next_state, action)

        # Update Q-table using Q-learning update rule
        Q[state][action] += alpha * (reward + gamma * np.max(Q[next_state]) - Q[state][action])

        # Move to the next state
        state = next_state

In [6]:
# Print the learned Q-values
print("Learned Q-values:")
for i in range(grid_size):
    for j in range(grid_size):
        print(f"State ({i}, {j}):")
        print("  Up:", Q[i][j][0])
        print("  Down:", Q[i][j][1])
        print("  Left:", Q[i][j][2])
        print("  Right:", Q[i][j][3])

Learned Q-values:
State (0, 0):
  Up: -1.3906558059373813
  Down: -0.43407470316366903
  Left: -1.390655801019927
  Right: -0.4340620000000006
State (0, 1):
  Up: -0.4340620004367673
  Down: 0.6288199992091221
  Left: -1.3906561860229638
  Right: 0.6288199999999993
State (0, 2):
  Up: 0.6288199752855492
  Down: 1.8097999999999992
  Left: -0.43406201158589
  Right: 1.809799999938773
State (0, 3):
  Up: 0.07159638671874946
  Down: 3.121999999999999
  Left: -1.8152187499999999
  Right: 0.8477080078124994
State (0, 4):
  Up: -0.7344687500000002
  Down: 4.579976806640625
  Left: 0.9346931772786119
  Right: 0.0029687499999997424
State (1, 0):
  Up: -1.8696729465820316
  Down: -2.7729480468750003
  Left: -3.0166270390625005
  Right: 0.6288195658344791
State (1, 1):
  Up: -1.0063280820312506
  Down: -0.7762420863106033
  Left: -0.7016675608947511
  Right: 1.8097999999999992
State (1, 2):
  Up: 0.6288199297092052
  Down: 2.498950608956008
  Left: 0.6288199762916693
  Right: 3.121999999999999
St