In [1]:
import tkinter as tk
import numpy as np
import random
import time

# -----------------------------
# Environment setup (Gridworld)
# -----------------------------

grid_size = 5                      # the world is a grid_size x grid_size square
n_states = grid_size ** 2          # one state index per cell
actions = list(range(4))           # 0 = up, 1 = down, 2 = left, 3 = right
terminal_state = n_states - 1      # the goal is the last cell (bottom-right)

# Rewards, indexed by the state being entered:
# -1 for every ordinary cell, +10 for the goal.
rewards = np.full(shape=n_states, fill_value=-1)
rewards[terminal_state] = 10

# -----------------------------
# SARSA setup
# -----------------------------

# Q-table: one row per state, one column per action, all zeros to start with
Q = np.zeros(shape=(n_states, len(actions)))

# Hyperparameters
alpha = 0.1       # learning rate (step size of each Q update)
gamma = 0.8       # discount factor applied to the next state-action value
epsilon = 0.6     # probability of taking a random (exploratory) action
episodes = 50     # number of training episodes

# -----------------------------
# Helper functions
# -----------------------------

def state_to_coords(state):
    """Return the ``(row, col)`` grid position of a flat state index.

    The quotient of ``state`` by ``grid_size`` is the row and the
    remainder is the column.
    """
    row = state // grid_size
    col = state % grid_size
    return row, col

def coords_to_state(row, col):
    """Return the flat state index of the cell at ``(row, col)``.

    Each full row contributes ``grid_size`` states; the column is the
    offset within the row.
    """
    row_offset = grid_size * row
    return row_offset + col

def step(state, action):
    """Apply ``action`` in ``state`` and return ``(next_state, reward)``.

    Action codes are 0 = up, 1 = down, 2 = left, 3 = right. A move that
    would leave the grid (or an unrecognised action code) leaves the agent
    where it is. The reward is the one stored for the cell the agent ends
    up in, so bumping into a wall still yields that cell's reward.
    """
    row, col = divmod(state, grid_size)
    last_index = grid_size - 1

    if action == 0:        # up
        if row > 0:
            row -= 1
    elif action == 1:      # down
        if row < last_index:
            row += 1
    elif action == 2:      # left
        if col > 0:
            col -= 1
    elif action == 3:      # right
        if col < last_index:
            col += 1

    next_state = row * grid_size + col
    return next_state, rewards[next_state]

def choose_action(state):
    """Pick an action for ``state`` using an epsilon-greedy policy.

    With probability ``epsilon`` a uniformly random action from ``actions``
    is returned (exploration). Otherwise the action with the highest
    Q-value for ``state`` is returned (exploitation).

    Ties between equally good actions are broken at random. ``np.argmax``
    would always return the lowest index, which systematically favours
    "up" while the Q-table still holds equal values; ``step_policy``
    already breaks ties randomly, so this keeps both policies consistent.

    Returns:
        The chosen action code as a plain ``int`` on both branches.
    """
    if random.uniform(0, 1) < epsilon:
        return random.choice(actions)

    q_row = Q[state]
    # Column indices of Q double as action codes (actions are 0..3).
    best_actions = np.flatnonzero(q_row == q_row.max())
    return int(random.choice(best_actions.tolist()))

# -----------------------------
# GUI setup
# -----------------------------

cell_size = 80   # side length of one grid cell, in pixels

root = tk.Tk()
root.title("SARSA Gridworld")

# Square canvas just large enough to hold the whole grid
canvas = tk.Canvas(
    root,
    width=cell_size * grid_size,
    height=cell_size * grid_size,
)
canvas.pack()

# Draw grid: one rectangle per cell, row by row; the goal cell is green
for i in range(grid_size):
    y1 = i * cell_size
    y2 = y1 + cell_size
    for j in range(grid_size):
        x1 = j * cell_size
        x2 = x1 + cell_size
        color = "lightgreen" if coords_to_state(i, j) == terminal_state else "white"
        canvas.create_rectangle(x1, y1, x2, y2, fill=color, outline="black")

# Agent: a red circle inset 5 px from the cell border, created last so it
# is drawn on top of the grid; it starts in the top-left cell
agent = canvas.create_oval(5, 5, cell_size - 5, cell_size - 5, fill="red")

def update_agent_position(state):
    """Redraw the agent's circle inside the cell that ``state`` maps to.

    The circle is inset 5 px from each edge of the cell. After moving it,
    the window is refreshed immediately and execution pauses for 50 ms so
    the movement is visible.
    """
    margin = 5
    row, col = state_to_coords(state)

    left = col * cell_size + margin
    top = row * cell_size + margin
    right = left + cell_size - 2 * margin
    bottom = top + cell_size - 2 * margin

    canvas.coords(agent, left, top, right, bottom)
    root.update()        # process pending GUI events so the move is drawn now
    time.sleep(0.05)     # short pause so each step can be seen

# -----------------------------
# Training phase (SARSA)
# -----------------------------

for ep in range(episodes):
    # Every episode restarts in state 0 (the top-left cell) and picks its
    # first action before entering the loop, as SARSA requires.
    state = 0
    action = choose_action(state)
    update_agent_position(state)
    print("Episode:", ep)

    while state != terminal_state:
        # Act, observe the outcome, then choose the follow-up action with
        # the same epsilon-greedy policy (on-policy).
        next_state, reward = step(state, action)
        next_action = choose_action(next_state)

        # SARSA update: nudge Q(s, a) toward r + gamma * Q(s', a')
        td_target = reward + gamma * Q[next_state, next_action]
        td_error = td_target - Q[state, action]
        Q[state, action] += alpha * td_error

        # The follow-up pair becomes the current pair
        state, action = next_state, next_action
        update_agent_position(state)

    # Brief pause at the goal before the next episode begins
    time.sleep(0.2)

print("Training complete! Learned Q-table:")
print(Q)

# -----------------------------
# Testing phase
# -----------------------------
test_state = 0        # state the greedy walk is currently in
test_path = []        # states visited by the current greedy walk
max_test_steps = 50   # safety cap so a poor policy cannot loop forever

# Generation counter for greedy runs. It is bumped on every button press;
# a scheduled step that belongs to an older generation sees the mismatch
# and stops, so only one walk is ever animating at a time.
_test_run_id = 0


def step_policy(run_id=None):
    """Execute one greedy step and schedule the next one.

    Args:
        run_id: Generation of the run this step belongs to. ``None``
            (the default, kept for callers that pass nothing) means
            "the current run".

    The walk ends when the terminal state is reached or when the path
    grows beyond ``max_test_steps`` entries. Ties between equally valued
    actions are broken at random.
    """
    global test_state, test_path

    if run_id is None:
        run_id = _test_run_id
    if run_id != _test_run_id:
        # A newer run was started; this callback is stale.
        return

    if test_state == terminal_state:
        print(" Optimal path found:", test_path)
        return

    if len(test_path) > max_test_steps:
        print(" Max steps exceeded. Path so far:", test_path)
        return

    qvals = Q[test_state]
    max_q = np.max(qvals)
    best_actions = np.where(qvals == max_q)[0]
    action = np.random.choice(best_actions)

    test_state, _ = step(test_state, action)
    test_path.append(test_state)
    update_agent_position(test_state)

    # update_agent_position processes pending GUI events, so the button may
    # have been pressed while it ran. If so, the new run owns the shared
    # state and this chain must not schedule another step.
    if run_id != _test_run_id:
        return

    root.after(300, step_policy, run_id)


def run_optimal_policy():
    """Start a fresh greedy walk from state 0.

    Any walk that is still animating is invalidated first; without this,
    pressing the button again would leave two callback chains moving the
    same agent and appending to the same path.
    """
    global test_state, test_path, _test_run_id

    _test_run_id += 1
    run_id = _test_run_id

    test_state = 0
    test_path = [test_state]
    update_agent_position(test_state)
    # Passing the captured id makes this a no-op if yet another press was
    # handled while the agent was being redrawn.
    step_policy(run_id)

# Button
# Button that replays the learned greedy policy on demand
button = tk.Button(
    master=root,
    text="Run Optimal Policy",
    command=run_optimal_policy,
)
button.pack()

# Hand control to Tk; this call blocks until the window is closed
root.mainloop()


TclError: no display name and no $DISPLAY environment variable