# Imports and Globals

In [None]:
import tkinter as tk
import time
import random
import numpy as np # for argmax
# global variables for MDP
# The 9 non-terminal, non-wall cells, in the same row-major order as the
# 9 tuples the display functions expect.
states = [(0, 0), (0, 1), (0, 2), (1, 0), (1, 2), (2, 0), (2, 1), (2, 2), (2, 3)]

# global variables for GUI elements (assigned by setup_gui)
current_grid_mode = "v"  # "v" = value/policy board, "q" = Q-value board
# (row, col) -> Tk widget currently drawn for that board position; the
# display functions clear and repopulate it on every redraw.
cells = {}
root = None
grid_frame = None
control_panel_frame = None
value_iteration_button = None
q_learning_button = None
policy_iteration_button = None
epsilon_greedy_q_button = None
reset_button = None
output_label = None  # status line; declared here like the other widget globals
x_value_entry = None
r_value_entry = None
epsilon_entry = None
discount_entry = None
speed_slider = None


# MDP Functions

In [91]:
def get_mdp_model(x_value=None, r_value=None, discount=None):
    """Build the 3x4 gridworld MDP used by the solvers.

    The grid has a wall at (1, 1) and terminal states at (0, 3) (+1) and
    (1, 3) (-1). Every parameter defaults to None, meaning "read it from the
    GUI entry box", so the existing no-argument call keeps working. Passing
    explicit values builds the model without a running GUI.

    Args:
        x_value: percent chance (0-100) that the intended move happens; the
            remainder is split evenly between the two perpendicular moves.
            Read from ``x_value_entry`` when None.
        r_value: living reward for every action in non-terminal, non-wall
            states. Read from ``r_value_entry`` when None.
        discount: discount factor. Read from ``discount_entry`` when None.

    Returns:
        ``(states, rewards, transition, discount, actions)`` where ``rewards``
        is ``{state: {action: reward}}`` and ``transition`` is
        ``{state: {action: {next_state: probability}}}``.

    Raises:
        ValueError: if a value cannot be converted to float, or ``x_value``
            is outside 0-100 (which would yield negative probabilities).
    """
    grid_rows, grid_cols = 3, 4
    wall = (1, 1)
    terminal_rewards = {(0, 3): 1.0, (1, 3): -1.0}
    # All 12 grid positions in row-major order, including the wall, which is
    # kept so the model covers the full board.
    states = [(row, col) for row in range(grid_rows) for col in range(grid_cols)]
    actions = ["up", "down", "left", "right"]

    # Fall back to the GUI only for values the caller did not supply.
    if r_value is None:
        r_value = r_value_entry.get()
    if x_value is None:
        x_value = x_value_entry.get()
    if discount is None:
        discount = discount_entry.get()
    reward_step_cost = float(r_value)
    x_prob_value = float(x_value)
    discount = float(discount)
    if not 0.0 <= x_prob_value <= 100.0:
        raise ValueError(f"X value must be between 0 and 100, got {x_prob_value}")

    # rewards (following example here: https://www.youtube.com/watch?v=UuTkioxL9bQ)
    rewards = {}
    for state in states:
        if state in terminal_rewards:
            step_reward = terminal_rewards[state]
        elif state == wall:
            step_reward = 0  # the wall is never entered
        else:
            step_reward = reward_step_cost
        rewards[state] = {action: step_reward for action in actions}

    prob_intended = x_prob_value / 100.0
    prob_side = (100.0 - x_prob_value) / 200.0

    action_offsets = {
        "up": (-1, 0),
        "down": (1, 0),
        "left": (0, -1),
        "right": (0, 1),
    }
    # Perpendicular "slip" directions for each intended action.
    side_actions = {
        "up": ["left", "right"],
        "down": ["left", "right"],
        "left": ["up", "down"],
        "right": ["up", "down"],
    }

    def resolve(state, action):
        """Return where `action` leads from `state`; off-grid or wall moves stay put."""
        row_diff, col_diff = action_offsets[action]
        candidate = (state[0] + row_diff, state[1] + col_diff)
        in_bounds = 0 <= candidate[0] < grid_rows and 0 <= candidate[1] < grid_cols
        return candidate if in_bounds and candidate != wall else state

    transition = {}
    for state in states:
        if state in terminal_rewards or state == wall:
            # Absorbing: every action keeps the agent where it is.
            transition[state] = {action: {state: 1.0} for action in actions}
            continue
        transition[state] = {}
        for action in actions:
            outcomes = {resolve(state, action): prob_intended}
            for side_action in side_actions[action]:
                side_state = resolve(state, side_action)
                # Accumulate: a blocked side move may land on the same cell
                # as the intended move (or the other side move).
                outcomes[side_state] = outcomes.get(side_state, 0.0) + prob_side
            transition[state][action] = outcomes

    return states, rewards, transition, discount, actions

# GUI Functions

In [92]:
def setup_gui():
    """Create the main window and store every widget in its module-level global.

    Layout: the board frame on top (root row 0) and the control panel below
    (root row 1). Inside the panel:
      - row 0 holds the algorithm and reset buttons;
      - rows 1-2 hold the X / R / Epsilon / Discount entries with their defaults;
      - row 3 holds the speed slider and the status label.
    """
    global root, grid_frame, control_panel_frame, value_iteration_button, q_learning_button, policy_iteration_button, epsilon_greedy_q_button, reset_button
    global output_label, x_value_entry, r_value_entry, epsilon_entry, discount_entry, speed_slider

    root = tk.Tk()
    root.title("Gridworld Display")

    grid_frame = tk.Frame(root)
    grid_frame.grid(row=0, column=0, sticky="nsew")
    control_panel_frame = tk.Frame(root)
    control_panel_frame.grid(row=1, column=0, sticky="ew")

    panel = control_panel_frame
    pad = {"padx": 5, "pady": 5}

    def make_button(caption, callback, column):
        # One button in panel row 0.
        button = tk.Button(panel, text=caption, command=callback)
        button.grid(row=0, column=column, **pad)
        return button

    def make_entry(caption, row, column, default):
        # Right-aligned caption followed by a small entry pre-filled with `default`.
        tk.Label(panel, text=caption).grid(row=row, column=column, sticky="e", **pad)
        entry = tk.Entry(panel, width=5)
        entry.grid(row=row, column=column + 1, sticky="w", **pad)
        entry.insert(0, default)
        return entry

    value_iteration_button = make_button("Run Value Iteration", value_iteration, 0)
    q_learning_button = make_button("Run Q-Learning", q_learning, 1)
    policy_iteration_button = make_button("Run Policy Iteration", policy_iteration, 2)
    epsilon_greedy_q_button = make_button("Run Epsilon Greedy", epsilon_greedy, 3)
    reset_button = make_button("Reset Grid", reset_grid, 4)

    x_value_entry = make_entry("X Value:", 1, 0, "90")
    r_value_entry = make_entry("R Value:", 1, 2, "-0.04")
    epsilon_entry = make_entry("Epsilon:", 2, 0, "0.0001")
    discount_entry = make_entry("Discount:", 2, 2, "0.99")

    output_label = tk.Label(panel, text="", width=50, anchor="w")
    output_label.grid(row=3, column=6, columnspan=6, sticky="w", **pad)

    tk.Label(panel, text="Speed Multiplier:").grid(row=3, column=0, sticky="e", **pad)
    speed_slider = tk.Scale(panel, from_=.5, to=1.5, orient=tk.HORIZONTAL, resolution=0.01)
    speed_slider.set(1)
    speed_slider.grid(row=3, column=1, sticky="w", **pad)

    # Let the board area take up extra space when the window is resized.
    root.grid_rowconfigure(0, weight=1)
    root.grid_columnconfigure(0, weight=1)

In [93]:
def initialize_q_grid():
    """Switch the board to Q-value mode and draw it with every Q-value at 0.00.

    Clears the status label and sets ``current_grid_mode`` to "q".
    """
    global current_grid_mode
    output_label.config(text="")
    current_grid_mode = "q"
    zero_quadtuple = (0.00, 0.00, 0.00, 0.00)
    q_display_grid(grid_frame, [zero_quadtuple for _ in range(9)])

In [94]:
def initialize_v_grid():
    """Switch the board to value/policy mode and draw a blank starting board.

    Clears the status label, sets ``current_grid_mode`` to "v", and shows every
    non-terminal, non-wall cell as 0.00 pointing "up". "up" is the default
    direction, following the course slides.
    """
    global current_grid_mode
    output_label.config(text="")
    current_grid_mode = "v"
    starting_cell = (0.00, "up")
    v_display_grid(grid_frame, [starting_cell] * 9)

In [5]:
def reset_grid():
    """Redraw a blank board in whichever mode ("v" or "q") is currently shown."""
    output_label.config(text="")
    redraw = initialize_v_grid if current_grid_mode == "v" else initialize_q_grid
    redraw()

**Display function for V-Score board**
- Call this method to update the display of the board that shows only V-scores and directions.
- It takes a list of exactly 9 `(v_score, direction)` tuples, one for each non-terminal, non-wall cell.
- The tuples are written into the cells in order, starting at the top-left and ending at the bottom-right. The two terminal cells (+1 and -1) and the wall are skipped.

In [72]:
def v_display_grid(grid_frame, tuples_list):
    """Redraw the 3x4 board as value/policy cells.

    Args:
        grid_frame: Tk container the cell labels are gridded into.
        tuples_list: exactly 9 ``(v_score, direction)`` pairs, consumed in
            row-major order for every cell except the terminals (0, 3) and
            (1, 3) and the wall (1, 1).

    Raises:
        ValueError: if ``tuples_list`` does not contain exactly 9 entries.
    """
    if len(tuples_list) != 9:
        raise ValueError("tuples_list must contain exactly 9 tuples.")

    global cells
    # Destroy the previous board's widgets before forgetting them. Clearing
    # the dict alone leaves them alive in Tk, so each redraw (e.g. every
    # value-iteration sweep) would stack 12 more widgets on top of the old ones.
    for old_widget in cells.values():
        old_widget.destroy()
    cells.clear()

    tuple_index = 0
    for row in range(3):
        for col in range(4):
            if row == 0 and col == 3:  # +1 terminal; not part of tuples_list
                text = "1.00"
            elif row == 1 and col == 3:  # -1 terminal; not part of tuples_list
                text = "-1.00"
            elif row == 1 and col == 1:  # wall (2nd column from the left); outside the state space
                text = ""
            else:
                text = f"Max Reward:\n\n{tuples_list[tuple_index][0]} if going {tuples_list[tuple_index][1]}."
                tuple_index += 1

            cell_key = (row, col)
            cell = tk.Label(grid_frame, text=text, relief=tk.SOLID, padx=10, pady=5, width=25, height=15, font=("Comic Sans MS", 10))
            cell.grid(row=row, column=col, sticky="nsew")
            if row == 1 and col == 1:  # shade the wall cell grey
                cell.config(bg="grey")
            cells[cell_key] = cell

    # Equal weights so the cells share space when the window is resized.
    for i in range(3):
        grid_frame.grid_rowconfigure(i, weight=1)
    for i in range(4):
        grid_frame.grid_columnconfigure(i, weight=1)


**Display function for Q-Score board**
- This is the function that is called to display the board that contains the Q-scores of the 4 directions in each cell.
- It takes a list of exactly 9 quadtuples of Q-scores, one for each non-terminal, non-wall cell, and uses them to populate each cell.
    - Read left-to-right, each quadtuple populates the up, right, down, and left directions respectively.
- The cells of the board are populated beginning with the top-left cell and ending with the bottom-right cell.

In [73]:
def q_display_grid(grid_frame, quadtuple_list):
    """Redraw the 3x4 board as Q-value cells, with one label per action around each cell.

    Args:
        grid_frame: Tk container the cell frames are gridded into.
        quadtuple_list: exactly 9 ``(up, right, down, left)`` tuples, consumed
            in row-major order for every cell except the terminals (0, 3) and
            (1, 3) and the wall (1, 1).

    Raises:
        ValueError: if the list does not hold exactly 9 entries, or an entry
            is not a 4-tuple.
    """
    if len(quadtuple_list) != 9:
        raise ValueError("quadtuple_list must contain exactly 9 quadtuples.")

    global cells
    # Destroy the previous board's widgets before forgetting them. Clearing
    # the dict alone leaves them alive in Tk and stacks new widgets on top of
    # the old ones on every redraw.
    for old_widget in cells.values():
        old_widget.destroy()
    cells.clear()

    tuple_index = 0
    for row in range(3):
        for col in range(4):
            cell_key = (row, col)

            # One bordered frame per cell. A 3x3 inner grid puts the centre
            # label in the middle and one Q-value label on each side.
            cell_frame = tk.Frame(grid_frame, relief=tk.SOLID, bd=1)
            cell_frame.grid(row=row, column=col, sticky="nsew")
            # Register immediately so the frame is still destroyed on the next
            # redraw even if validation below raises.
            cells[cell_key] = cell_frame

            cell_label = tk.Label(cell_frame, text="", padx=10, pady=5, width=25, height=15)
            cell_label.grid(row=1, column=1)

            top_label = tk.Label(cell_frame, text="", anchor="s")
            top_label.grid(row=0, column=1, sticky="ew")

            right_label = tk.Label(cell_frame, text="", anchor="w")
            right_label.grid(row=1, column=2, sticky="ns")

            bottom_label = tk.Label(cell_frame, text="", anchor="n")
            bottom_label.grid(row=2, column=1, sticky="ew")

            left_label = tk.Label(cell_frame, text="", anchor="e")
            left_label.grid(row=1, column=0, sticky="ns")

            if row == 0 and col == 3:  # +1 terminal; not part of quadtuple_list
                cell_label.config(text="1.00")
            elif row == 1 and col == 3:  # -1 terminal; not part of quadtuple_list
                cell_label.config(text="-1.00")
            elif row == 1 and col == 1:  # wall (2nd column from the left): all labels stay blank
                pass
            else:
                quadtuple = quadtuple_list[tuple_index]
                if not isinstance(quadtuple, tuple) or len(quadtuple) != 4:
                    raise ValueError(f"Expected a quadtuple at index {tuple_index}, got: {quadtuple}")
                top_val, right_val, bottom_val, left_val = quadtuple
                top_label.config(text=str(top_val))
                right_label.config(text=str(right_val))
                bottom_label.config(text=str(bottom_val))
                left_label.config(text=str(left_val))
                tuple_index += 1

    # Equal weights so the cells share space when the window is resized.
    for i in range(3):
        grid_frame.grid_rowconfigure(i, weight=1)
    for i in range(4):
        grid_frame.grid_columnconfigure(i, weight=1)

# Part 1 - Value Iteration

In [85]:
def value_iteration():
    initialize_v_grid()
    states, rewards, transition, discount, actions = get_mdp_model()
    epsilon = float(epsilon_entry.get())
    threshold = epsilon * (1 - discount) / discount #definine the breakout condition here to avoid performing the calculation on every loop
    
    # initialize values ignoring the wall
    V = {s: 0 for s in states if s != (1, 1)}
    # manually setting the terminal states to their rewards
    V[(0, 3)] = rewards[(0, 3)]["up"]  # should be 1.0
    V[(1, 3)] = rewards[(1, 3)]["up"]  # should be -1.0
    
    iteration = 0
    while True:
        iteration += 1
        delta = 0
        new_V = V.copy()

        for s in states:
            if s == (1, 1) or s in [(0, 3), (1, 3)]:
                continue  # skip wall and terminal states
                
            # calculate value for each action and take the max
            max_value = float('-inf')
            for a in actions:
                value = rewards[s][a]
                for next_state, prob in transition[s][a].items():
                    if next_state != (1, 1):  # skip the wall state
                        value += discount * prob * V[next_state]
                max_value = max(max_value, value)
            new_V[s] = max_value
            delta = max(delta, abs(new_V[s] - V[s]))
        
        # batch update the values
        V = new_V
        
        # policy determination to pass with the display_tuples
        policy = {}
        for s in states:
            if s == (1, 1) or s in [(0, 3), (1, 3)]:
                policy[s] = None
                continue
            action_values = {}
            best_action = None
            best_value = float('-inf')
            for a in actions:
                value = rewards[s][a]
                for next_state, prob in transition[s][a].items():
                    if next_state != (1, 1):
                        value += discount * prob * V[next_state]
                action_values[a] = value

            best_action = max(action_values, key=action_values.get) # this is the equivalent of the argmax we went over in class but I am using a dictionary for the action value pairs so this made more sense. 
            policy[s] = best_action
        # generating and passing the tuples of (state, optimal policy direction) pairs    
        display_states = [(0, 0), (0, 1), (0, 2), (1, 0), (1, 2), (2, 0), (2, 1), (2, 2), (2, 3)]
        display_tuples = []
        for state in display_states:
            v_score_str = f"{V[state]:.2f}"
            direction_str = policy.get(state, "")
            display_tuples.append((v_score_str, direction_str))
        v_display_grid(grid_frame, tuple(display_tuples))
        grid_frame.update()
        global speed_slider
        wait_time = 0.2 / float(speed_slider.get()) #0.2 is a good default delay, here we're dividing instead of multiplying since an increase in 'speed' technically should be inversely proportional to the delay we are using
        time.sleep(wait_time)
        
        # check for convergence
        if delta <= threshold:
            output_label.config(text=f"Complete. Value iteration converged after {iteration} iterations.")
            break

# Part 2: Policy Iteration

In [88]:
def policy_iteration():
    """Placeholder for Part 2 (policy iteration); not implemented yet.

    This is the "Run Policy Iteration" button's callback, so that button
    currently does nothing.
    """
    return None

# Part 3: Q-Learning

In [86]:
def q_learning():
    """Placeholder for Part 3 (Q-learning).

    The learning itself is not implemented yet; this only switches the board
    to a zeroed Q-value grid.
    """
    initialize_q_grid()
    return None

# Part 4: Epsilon-Greedy

In [87]:
def epsilon_greedy():
    """Placeholder for Part 4 (epsilon-greedy); not implemented yet.

    This is the "Run Epsilon Greedy" button's callback, so that button
    currently does nothing.
    """
    return None

# Main Controller

In [95]:
def main():
    """Build the GUI, draw the initial value grid, then run Tk's event loop."""
    setup_gui()
    initialize_v_grid()
    root.mainloop()  # blocks until the window is closed


if __name__ == "__main__":
    main()