In [1]:
import random

In [2]:
# Discount factor: multiplies the bootstrapped next-state Q-value in the TD target.
DISCOUNT = 0.9
# Step size: scales every temporal-difference error before it is added to Q.
LEARNING_RATE = 0.1
# Training stops (after the minimum episode count) once an episode's summed
# TD update is at or below this value.
CONVERGENCE_THRESHOLD = 0.01


def greedy_policy(q, player_X, player_Y):
    """Return an action with the highest Q-value in cell (player_X, player_Y).

    Args:
        q: Grid of action-value dicts, indexed as ``q[y][x][action]``.
        player_X: Column of the cell.
        player_Y: Row of the cell.

    Returns:
        A key of ``q[player_Y][player_X]`` whose value equals the cell's maximum.
        Ties are broken uniformly at random with ``random.choice``, in the
        dict's key order.
    """
    action_values = q[player_Y][player_X]
    best_value = max(action_values.values())
    candidates = [action for action in action_values if action_values[action] == best_value]
    return random.choice(candidates)


def epsilon_greedy_policy(q, player_X, player_Y, epsilon=0.3):
    """Pick an action for cell (player_X, player_Y) epsilon-greedily.

    A single ``random.random()`` draw decides the branch. With probability
    ``epsilon`` a uniformly random action from 'N', 'E', 'S', 'W' is returned.
    Otherwise an action with the highest Q-value is returned, with ties broken
    by ``random.choice``.

    Args:
        q: Grid of action-value dicts, indexed as ``q[y][x][action]``.
        player_X: Column of the cell.
        player_Y: Row of the cell.
        epsilon: Probability of exploring instead of exploiting.

    Returns:
        The chosen action key.
    """
    if not (random.random() < epsilon):
        # Exploit: choose among the maximum-valued actions of this cell.
        cell = q[player_Y][player_X]
        top = max(cell.values())
        return random.choice([action for action in cell if cell[action] == top])
    # Explore: any compass direction, uniformly.
    return random.choice('NESW')
    

# None, None: User fell off board
def take_a(windyworld, wind, a, player_X, player_Y):

    direction_map = {
        'N': (0, -1),  # Move North (decrease Y)
        'S': (0, 1),   # Move South (increase Y)
        'E': (1, 0),   # Move East (increase X)
        'W': (-1, 0),  # Move West (decrease X)
    }

    delta_X, delta_Y = direction_map[a]
    new_X = player_X + delta_X
    new_Y = player_Y + delta_Y
    
    # Check if the initial movement is within bounds
    if 0 <= new_X and new_X < len(windyworld[0]) and 0 <= new_Y and new_Y < len(windyworld):

        if windyworld[new_Y][new_X] == 'X':
            return None, None
        
        wind_map = {
            '': 0,   # No wind
            '^': -1, # Upward wind
            'v': 1   # Downward wind
        }
        new_Y_wind = new_Y + wind_map[wind[new_X]]
        # Check if movement w/ wind is within bounds
        if 0 <= new_Y_wind and new_Y_wind < len(windyworld):

            if windyworld[new_Y_wind][new_X] == 'X':
                return None, None
            else:
                return new_X, new_Y_wind
        
        else:
            return None, None
    else:
        return None, None

In [None]:
def SARSA(windyworld, wind, lr=LEARNING_RATE, d=DISCOUNT, ct=CONVERGENCE_THRESHOLD, min_episode_count=1000):
    """Learn action values for ``windyworld`` with on-policy SARSA.

    Each episode starts from a random '_' cell. The agent acts epsilon-greedily
    (via ``epsilon_greedy_policy``) until it falls off the grid or into an 'X'
    cell (reward -10, terminal) or reaches the '1' goal (reward +10, terminal).
    Non-terminal steps earn 0 when the Manhattan distance to the goal shrinks
    and -1 otherwise, including when the wind pushes the agent back to where it
    started. The next action A' chosen in S' is both the one bootstrapped from
    and the one actually taken next.

    Args:
        windyworld: Grid of cell strings ('_', 'X', '1') indexed ``[y][x]``.
        wind: Wind symbol per column, as understood by ``take_a``.
        lr: Step size applied to each TD error.
        d: Discount applied to the next state's Q-value.
        ct: Convergence threshold on an episode's summed TD update.
        min_episode_count: Minimum number of episodes before stopping.

    Returns:
        A grid ``q[y][x]`` of dicts mapping 'N'/'E'/'S'/'W' to learned values.
    """
    q = [[{'N': 0, 'E': 0, 'S': 0, 'W': 0} for _ in r] for r in windyworld]

    starting_states = []
    goal_X = 0
    goal_Y = 0
    for y, row in enumerate(windyworld):
        for x, cell in enumerate(row):
            if cell == '_':
                starting_states.append((x, y))
            if cell == '1':
                goal_X = x
                goal_Y = y

    episode_count = 0
    while True:
        episode_count += 1
        player_X, player_Y = random.choice(starting_states)  # S
        greedy_eps_A1 = epsilon_greedy_policy(q, player_X, player_Y)  # A
        td_sum = 0
        while True:
            # Take action A and observe R and S'
            upd_X, upd_Y = take_a(windyworld, wind, greedy_eps_A1, player_X, player_Y)  # S'
            terminal = False
            if upd_X is None and upd_Y is None:
                terminal = True
                r = -10
            elif upd_X == player_X and upd_Y == player_Y:
                r = -1
            elif windyworld[upd_Y][upd_X] == '1':
                terminal = True
                r = 10
            else:  # windyworld[upd_Y][upd_X] == '_'
                current_dist = abs(player_X - goal_X) + abs(player_Y - goal_Y)
                new_dist = abs(upd_X - goal_X) + abs(upd_Y - goal_Y)
                r = 0 if new_dist < current_dist else -1

            if terminal:
                # S' is terminal, so Q(S', A') = 0:
                # Q(S, A) <- Q(S, A) + lr[R - Q(S, A)]
                td = lr * (r - q[player_Y][player_X][greedy_eps_A1])
                td_sum += td
                q[player_Y][player_X][greedy_eps_A1] += td
                break

            # A' is only chosen for non-terminal S'. Off-board S' is (None, None)
            # and cannot index q.
            greedy_eps_A2 = epsilon_greedy_policy(q, upd_X, upd_Y)  # A'

            # Q(S, A) <- Q(S, A) + lr[R + d*Q(S', A') - Q(S, A)]
            td = lr * (r + (d * q[upd_Y][upd_X][greedy_eps_A2]) - q[player_Y][player_X][greedy_eps_A1])
            td_sum += td
            q[player_Y][player_X][greedy_eps_A1] += td

            # S' -> S
            player_X = upd_X
            player_Y = upd_Y
            # A' -> A: SARSA commits to the action it bootstrapped from.
            greedy_eps_A1 = greedy_eps_A2

        # NOTE(review): td_sum is a signed sum, so an episode with net-negative
        # updates also passes this check. An abs-based criterion may be intended,
        # but under epsilon-greedy exploration it might never drop below ct.
        if episode_count >= min_episode_count and td_sum <= ct:
            break

    return q

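'''
Q-learning is not simply SARSA with a greedy policy. The key difference is how the next action A' is used.

SARSA (on-policy) updates Q(S, A) using the A' that the epsilon-greedy policy actually picks in S',
and then really takes A' on the next step. Its estimates therefore reflect the exploratory,
possibly suboptimal, path the agent actually follows.

Q-learning (off-policy) updates Q(S, A) using the greedy action in S', i.e. max_a Q(S', a),
but the next action it actually takes is chosen epsilon-greedily and may differ. Its estimates
therefore track the greedy path under the current Q-values (optimal only once those values
have converged), which is not necessarily the path being followed.
'''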
def Q_learning(windyworld, wind, lr=LEARNING_RATE, d=DISCOUNT, ct=CONVERGENCE_THRESHOLD, min_episode_count=1000):
    """Learn action values for ``windyworld`` with off-policy Q-learning.

    Each episode starts from a random '_' cell. Actions are chosen with
    ``epsilon_greedy_policy``, but every non-terminal update bootstraps from the
    greedy action in S' (``greedy_policy``). Rewards are:
    -10 and terminal when the move leaves the grid or hits 'X';
    +10 and terminal on the '1' goal; -1 when the agent ends where it started;
    otherwise 0 if the Manhattan distance to the goal shrank, else -1.

    Args:
        windyworld: Grid of cell strings ('_', 'X', '1') indexed ``[y][x]``.
        wind: Wind symbol per column, as understood by ``take_a``.
        lr: Step size applied to each TD error.
        d: Discount applied to the next state's greedy Q-value.
        ct: Convergence threshold on an episode's summed (signed) TD update.
        min_episode_count: Minimum number of episodes before stopping.

    Returns:
        A grid ``q[y][x]`` of dicts mapping 'N'/'E'/'S'/'W' to learned values.
    """
    q = [[{'N': 0, 'E': 0, 'S': 0, 'W': 0} for _ in r] for r in windyworld]

    # Row-major list of open cells an episode may start from.
    starting_states = [
        (x, y)
        for y, row in enumerate(windyworld)
        for x, cell in enumerate(row)
        if cell == '_'
    ]
    # The last '1' found in row-major order is the goal; (0, 0) if none.
    goal_X, goal_Y = 0, 0
    for y, row in enumerate(windyworld):
        for x, cell in enumerate(row):
            if cell == '1':
                goal_X, goal_Y = x, y

    episodes = 0
    while True:
        episodes += 1
        x, y = random.choice(starting_states)  # S
        td_total = 0
        done = False
        while not done:
            action = epsilon_greedy_policy(q, x, y)  # A
            nx, ny = take_a(windyworld, wind, action, x, y)  # S'

            if nx is None and ny is None:
                reward, done = -10, True
            elif (nx, ny) == (x, y):
                reward = -1
            elif windyworld[ny][nx] == '1':
                reward, done = 10, True
            else:
                before = abs(x - goal_X) + abs(y - goal_Y)
                after = abs(nx - goal_X) + abs(ny - goal_Y)
                reward = 0 if after < before else -1

            # Terminal S' contributes no future value; otherwise bootstrap from
            # max_a Q(S', a) via the greedy action.
            if done:
                target = reward
            else:
                best = greedy_policy(q, nx, ny)
                target = reward + d * q[ny][nx][best]

            td = lr * (target - q[y][x][action])
            td_total += td
            q[y][x][action] += td

            if not done:
                x, y = nx, ny  # S' -> S

        if episodes >= min_episode_count and td_total <= ct:
            return q

In [4]:
# Grid cells: '_' open, 'X' hole/wall (stepping in ends the episode), '1' goal.
windyworld = [
    ['_', '_', '_', '_', '_', 'X', '_', '_'],
    ['_', '_', '_', '_', '_', 'X', '1', '_'],
    ['_', '_', '_', '_', '_', 'X', '_', '_'],
    ['_', '_', 'X', '_', '_', 'X', '_', '_'],
    ['_', '_', 'X', '_', '_', 'X', '_', '_'],
    ['_', '_', 'X', '_', '_', '_', '_', '_'],
    ['_', '_', 'X', '_', '_', '_', '_', '_'],
    ['_', '_', 'X', '_', '_', '_', '_', '_']
]
# Per-column wind: '^' pushes up one row, 'v' down one row, '' none.
wind = ['v', '', 'v', '', '', '^', '', 'v']

q = Q_learning(windyworld, wind)

replace_dict = {'N': '^', 'E': '>', 'S': 'v', 'W': '<'}

# Print every open cell's greedy action(s) as arrows. Other cells ('X', '1')
# are printed verbatim. Each cell is followed by a tab.
arrow_table = str.maketrans(replace_dict)
for y, row in enumerate(windyworld):
    cells = []
    for x, cell in enumerate(row):
        if cell == '_':
            values = q[y][x]
            best = max(values.values())
            cells.append(''.join(a for a in values if values[a] == best))
        else:
            cells.append(cell)
    print(''.join(c + '\t' for c in cells).translate(arrow_table))
print()
print('\t'.join(wind))

v	v	>	v	v	X	v	<	
>	>	>	>	v	X	1	<	
>	^	>	>	v	X	^	<	
>	^	X	>	v	X	^	<	
>	^	X	>	v	X	^	<	
>	^	X	>	v	>	^	<	
>	^	X	>	>	>	^	<	
>	^	X	^	^	^	^	<	

v		v			^		v
