In [1]:
import numpy as np
import random

# Environment settings
grid_size = 4
goal_state = (3, 3)
start_state = (0, 0)
obstacles = [(1, 1), (2, 2)]

# Actions: up, down, left, right
actions = ['up', 'down', 'left', 'right']

# Rewards
goal_reward = 10
obstacle_penalty = -1
step_penalty = -0.1

# Q-table initialization
Q = np.zeros((grid_size, grid_size, len(actions)))

# Learning parameters
alpha = 0.1  # Learning rate
gamma = 0.9  # Discount factor
epsilon = 0.1  # Exploration rate

def is_valid(state):
    if (0 <= state[0] < grid_size) and (0 <= state[1] < grid_size) and (state not in obstacles):
        return True
    return False

def get_next_state(state, action):
    if action == 'up':
        next_state = (state[0] - 1, state[1])
    elif action == 'down':
        next_state = (state[0] + 1, state[1])
    elif action == 'left':
        next_state = (state[0], state[1] - 1)
    elif action == 'right':
        next_state = (state[0], state[1] + 1)
    
    if is_valid(next_state):
        return next_state
    return state

def get_reward(state):
    if state == goal_state:
        return goal_reward
    if state in obstacles:
        return obstacle_penalty
    return step_penalty

# Training the agent
num_episodes = 1000
for _ in range(num_episodes):
    state = start_state
    while state != goal_state:
        if random.uniform(0, 1) < epsilon:
            action = random.choice(actions)
        else:
            action = actions[np.argmax(Q[state[0], state[1]])]
        
        next_state = get_next_state(state, action)
        reward = get_reward(next_state)
        
        old_value = Q[state[0], state[1], actions.index(action)]
        next_max = np.max(Q[next_state[0], next_state[1]])
        
        Q[state[0], state[1], actions.index(action)] = old_value + alpha * (reward + gamma * next_max - old_value)
        
        state = next_state

# Display the learned Q-table
print("Learned Q-table:")
print(Q)


Learned Q-table:
[[[ 4.39722717  0.72221871  4.51504064  5.49539   ]
  [ 4.52836115  5.13046274  4.56626313  6.2171    ]
  [ 5.80423889  7.019       4.91454874  5.70756889]
  [ 0.64408185  7.83098038  0.59352912 -0.029701  ]]

 [[-0.1177221   1.8793397  -0.12167317 -0.09347526]
  [ 0.          0.          0.          0.        ]
  [ 5.76207134  6.56389338  6.56791226  7.91      ]
  [ 5.66974671  8.9         5.84076384  6.94378139]]

 [[-0.09985463 -0.06297286 -0.06793465  3.65269017]
  [ 0.16378947  5.83702771  0.16297545  0.2345111 ]
  [ 0.          0.          0.          0.        ]
  [ 7.17984664 10.          8.38242017  8.50931788]]

 [[-0.04098644 -0.0385219  -0.03940399  0.59720465]
  [-0.028891   -0.0199     -0.02152     8.05477072]
  [ 1.28932296  1.19499714  0.09631667  9.81751996]
  [ 0.          0.          0.          0.        ]]]
