In [1]:
import numpy as np
import random

In [2]:
# Grid-world layout: a square board of grid_size x grid_size cells whose
# (row, col) positions are flattened into state indices 0 .. grid_size**2 - 1.
grid_size = 4

# Action encoding: 0 = up, 1 = right, 2 = down, 3 = left.
actions = list(range(4))

# Q-table: one row per state (grid_size**2 of them), one column per action;
# every estimate starts at zero.
Q = np.zeros((grid_size ** 2, len(actions)))

In [4]:
# Reward signal for the grid world: one goal cell, one trap cell,
# and a small step cost everywhere else.
def get_reward(state):
    """Return the reward associated with landing in ``state``.

    Args:
        state: flattened state index.

    Returns:
        10 for the goal state (15), -10 for the trap state (6),
        and -1 (step cost) for any other state.
    """
    return 10 if state == 15 else (-10 if state == 6 else -1)

# Flatten a 2D grid coordinate (row, col) into a 1D state index (row-major).
def state_index(row, col):
    """Map grid cell (row, col) to its row-major state index.

    Uses the module-level ``grid_size`` as the number of columns per row,
    so cell (r, c) becomes r * grid_size + c.
    """
    offset_for_row = row * grid_size
    return offset_for_row + col


# Movement model: one step in the chosen direction, blocked by the grid edges.
def take_action(row, col, action):
    """Return the (row, col) reached by applying ``action`` from (row, col).

    Actions: 0 = up, 1 = right, 2 = down, 3 = left. A move that would step
    off the grid (bounds taken from the module-level ``grid_size``) leaves
    the position unchanged, as does any unrecognised action code.
    """
    last = grid_size - 1
    if action == 0:        # up
        if row > 0:
            row = row - 1
    elif action == 1:      # right
        if col < last:
            col = col + 1
    elif action == 2:      # down
        if row < last:
            row = row + 1
    elif action == 3:      # left
        if col > 0:
            col = col - 1
    return row, col

In [5]:
# Q-learning parameters
alpha = 0.1        # learning rate: fraction of the TD error applied per update
gamma = 0.9        # discount factor for future rewards
epsilon = 0.2      # probability of taking a random (exploratory) action
n_episodes = 1000  # number of training episodes
max_steps = 1000   # safety cap on steps per episode so an episode can never loop forever
seed = 42          # fixed seed so Restart & Run All reproduces the same Q-table

# Entering either of these states ends an episode (goal and trap; keep in sync
# with get_reward).
terminal_states = (15, 6)

# Re-create the Q-table and reseed the RNG inside this cell so it is idempotent:
# re-running the cell retrains from scratch instead of continuing from the
# previous run's Q values.
Q = np.zeros((grid_size * grid_size, len(actions)))
random.seed(seed)

# training loop
for _episode in range(n_episodes):
    row, col = 0, 0  # start position (top-left corner)

    for _step in range(max_steps):
        state = state_index(row, col)

        # choose action: explore with probability epsilon, otherwise exploit
        if random.random() < epsilon:
            action = random.choice(actions)
        else:
            # np.argmax breaks ties toward the lowest action index
            action = int(np.argmax(Q[state]))

        new_row, new_col = take_action(row, col, action)
        new_state = state_index(new_row, new_col)
        reward = get_reward(new_state)
        done = new_state in terminal_states

        # Q-learning update. A terminal state has no future, so its target is
        # the reward alone rather than bootstrapping from Q[new_state].
        target = reward if done else reward + gamma * np.max(Q[new_state])
        Q[state, action] += alpha * (target - Q[state, action])

        row, col = new_row, new_col

        if done:
            break  # episode ends at goal or trap

# Sanity check: no action is ever taken from a terminal state, so their rows stay zero.
assert not Q[list(terminal_states)].any()

In [6]:
# Show the learned action values: one row per state index, one column per
# action (0=up, 1=right, 2=down, 3=left).
print("Trained Q-table:", Q, sep="\n")

Trained Q-table:
[[ 0.6162463   1.67443561  1.8098      0.60798219]
 [-0.64199286 -1.51206329  3.1091851  -1.30509134]
 [-1.12801184 -1.12668414 -1.9        -1.12120766]
 [-0.58519851 -0.58519851 -0.43306924 -0.76839311]
 [ 0.62170547  3.10120849  3.122       1.80257894]
 [ 0.77683267 -6.5132156   4.57999861  0.77853303]
 [ 0.          0.          0.          0.        ]
 [-0.199      -0.2881      1.49133578 -1.9       ]
 [ 1.7930089   4.58        4.43413251  3.10015385]
 [ 3.10776326  6.14156481  6.2         3.0901488 ]
 [-4.0951      1.66884367  7.99973533  1.42855487]
 [-0.1667782   0.3961894   7.45813417 -0.1       ]
 [ 1.37020654  6.19973101  1.49081615  2.07678879]
 [ 4.57229752  8.          6.17483859  4.56067598]
 [ 6.15962098 10.          7.98175458  6.17192202]
 [ 0.          0.          0.          0.        ]]
