In [1]:
import tkinter as tk
import time
import numpy as np
import random
from PIL import ImageTk, Image

# --- Grid-world configuration ------------------------------------------------
UNIT = 100           # side length of one grid cell, in pixels
HEIGHT = 5           # grid height, in cells
WIDTH = 5            # grid width, in cells
TRANSITION_PROB = 1  # value Env stores as its transition probability

# Offsets added to a [state[0], state[1]] pair, ordered: up, down, left, right.
ACTIONS = [(-1, 0), (1, 0), (0, -1), (0, 1)]
# Action indices into ACTIONS: 0 = up, 1 = down, 2 = left, 3 = right.
POSSIBLE_ACTIONS = list(range(len(ACTIONS)))


In [2]:
class Env:
    """Grid world with one goal cell (+1) and two penalty cells (-1).

    States are ``[row, col]`` lists: ``state[0]`` indexes the rows
    (``0 .. HEIGHT - 1``) and ``state[1]`` the columns (``0 .. WIDTH - 1``),
    which is also how the reward table ``reward[row][col]`` is laid out.
    Actions are indices into the module-level ``ACTIONS`` list.
    """

    def __init__(self):
        self.transition_probability = TRANSITION_PROB
        self.width = WIDTH
        self.height = HEIGHT
        self.possible_actions = POSSIBLE_ACTIONS

        # reward[row][col]: HEIGHT rows, each WIDTH columns wide.
        self.reward = [[0] * WIDTH for _ in range(HEIGHT)]
        self.reward[2][2] = 1    # goal
        self.reward[1][2] = -1   # penalty cell
        self.reward[2][1] = -1   # penalty cell

        # Enumerate states with the same [row, col] orientation as the reward
        # table, so every state is a valid index into it even when the grid
        # is not square.
        self.all_state = [[row, col]
                          for row in range(HEIGHT)
                          for col in range(WIDTH)]

    def get_reward(self, state, action):
        """Return the reward of the cell reached by taking ``action`` in ``state``."""
        next_state = self.state_after_action(state, action)
        return self.reward[next_state[0]][next_state[1]]

    def state_after_action(self, state, action_index):
        """Return the new ``[row, col]`` after applying ``ACTIONS[action_index]``.

        The input ``state`` is not modified; moves that would leave the grid
        are clamped to the border, so the agent stays where it is.
        """
        action = ACTIONS[action_index]
        return self.check_boundary([state[0] + action[0], state[1] + action[1]])

    @staticmethod
    def check_boundary(state):
        """Clamp ``state`` into the grid, in place, and return it.

        ``state[0]`` (row) is limited to ``0 .. HEIGHT - 1`` and ``state[1]``
        (column) to ``0 .. WIDTH - 1``.
        """
        state[0] = max(0, min(HEIGHT - 1, state[0]))
        state[1] = max(0, min(WIDTH - 1, state[1]))
        return state

    def get_transition_prob(self, state, action):
        """Return the stored transition probability (same for every state/action)."""
        return self.transition_probability

    def get_all_states(self):
        """Return the list of all ``[row, col]`` states."""
        return self.all_state


In [3]:
class ValueIteration:
    """Value iteration agent for a grid-world environment.

    The environment must provide ``width``, ``height``, ``possible_actions``,
    ``get_all_states()``, ``state_after_action(state, action)`` and
    ``get_reward(state, action)``.

    Args:
        env: the environment to plan in.
        discount_factor: weight applied to the value of the next state.
        terminal_state: ``(row, col)`` of the absorbing state whose value is
            kept at 0.0 and which has no actions; ``None`` disables it.
    """

    def __init__(self, env, discount_factor=0.9, terminal_state=(2, 2)):
        self.env = env
        self.value_table = [[0.0] * env.width for _ in range(env.height)]
        self.discount_factor = discount_factor
        self.terminal_state = None if terminal_state is None else list(terminal_state)

    def _is_terminal(self, state):
        """Return True if ``state`` (list or tuple) is the terminal state."""
        if self.terminal_state is None:
            return False
        return list(state) == list(self.terminal_state)

    def _action_values(self, state):
        """Return ``[(action, reward + discount * V(next_state)), ...]`` for ``state``."""
        scored = []
        for action in self.env.possible_actions:
            next_state = self.env.state_after_action(state, action)
            reward = self.env.get_reward(state, action)
            next_value = self.get_value(next_state)
            scored.append((action, reward + self.discount_factor * next_value))
        return scored

    def value_iteration(self):
        """Run one full sweep and replace ``value_table`` with the result.

        Every state's new value is computed from the table as it was before
        the sweep; the new table is swapped in only at the end. Values are
        rounded to two decimals.
        """
        next_value_table = [[0.0] * self.env.width for _ in range(self.env.height)]
        for state in self.env.get_all_states():
            if self._is_terminal(state):
                # The terminal state keeps the 0.0 it was initialised with.
                continue
            best = max(value for _, value in self._action_values(state))
            next_value_table[state[0]][state[1]] = round(best, 2)
        self.value_table = next_value_table

    def get_action(self, state):
        """Return every action that is greedy with respect to ``value_table``.

        Returns an empty list for the terminal state. Ties are detected by
        exact equality and returned in ``possible_actions`` order.
        """
        if self._is_terminal(state):
            return []
        scored = self._action_values(state)
        if not scored:
            return []
        best = max(value for _, value in scored)
        return [action for action, value in scored if value == best]

    def get_value(self, state):
        """Return the current value of ``state``, rounded to two decimals."""
        return round(self.value_table[state[0]][state[1]], 2)


In [4]:
class GraphicDisplay(tk.Tk):
    """Tk window that visualises a value-iteration agent on the grid world.

    The window shows the grid, the agent marker (rectangle), two penalty
    cells (triangles) and the goal (circle), plus four buttons:

    * Calculate    - run one sweep of ``agent.value_iteration()`` and show values
    * Print Policy - draw an arrow for every greedy action of every state
    * Move         - walk the marker from the top-left cell along the policy
    * Clear        - reset values, arrows and the marker

    Grid positions are ``(row, col)``; canvas coordinates are ``(x, y)`` with
    ``x`` growing with the column and ``y`` with the row.

    Args:
        value_iteration: agent exposing ``value_table``, ``value_iteration()``
            and ``get_action(state)``. If it has an ``env`` attribute, that
            environment is used to enumerate states; otherwise a fresh
            ``Env()`` is created.
    """

    def __init__(self, value_iteration):
        super(GraphicDisplay, self).__init__()
        try:
            self.title('Value Iteration')
            # As wide as the grid, plus 50 px of height below it for the buttons.
            self.geometry('{0}x{1}'.format(WIDTH * UNIT, HEIGHT * UNIT + 50))
            self.texts = []          # canvas ids of state-value labels
            self.reward_texts = []   # canvas ids of the permanent reward labels
            self.arrows = []         # canvas ids of policy arrows
            # Enumerate the same states the agent plans over.
            agent_env = getattr(value_iteration, 'env', None)
            self.env = agent_env if agent_env is not None else Env()
            self.agent = value_iteration
            self.iteration_count = 0
            self.improvement_count = 0
            self.is_moving = 0

            # Load images and keep strong references
            self.load_images()
            self.canvas = self._build_canvas()
            self.canvas.image_refs = [self.rectangle_img, self.triangle_img, self.circle_img,
                                      self.up, self.down, self.left, self.right]

            self.text_reward(2, 2, "R : 1.0")
            self.text_reward(1, 2, "R : -1.0")
            self.text_reward(2, 1, "R : -1.0")
        except Exception:
            # Do not leave a half-built window behind (e.g. when an image
            # file is missing).
            self.destroy()
            raise

    @staticmethod
    def _load_image(path, size):
        """Open ``path``, resize it to ``size`` and return a Tk-compatible image."""
        with Image.open(path) as source:
            return ImageTk.PhotoImage(source.resize(size))

    @staticmethod
    def _cell_center(row, col):
        """Return the canvas ``(x, y)`` of the centre of grid cell ``(row, col)``."""
        return col * UNIT + UNIT / 2, row * UNIT + UNIT / 2

    def load_images(self):
        """Load the arrow and shape images from the ``img/`` directory."""
        arrow_size = (13, 13)
        shape_size = (65, 65)
        self.up = self._load_image("img/up.png", arrow_size)
        self.down = self._load_image("img/down.png", arrow_size)
        self.left = self._load_image("img/left.png", arrow_size)
        self.right = self._load_image("img/right.png", arrow_size)
        self.rectangle_img = self._load_image("img/rectangle.png", shape_size)
        self.triangle_img = self._load_image("img/triangle.png", shape_size)
        self.circle_img = self._load_image("img/circle.png", shape_size)

    def _build_canvas(self):
        """Create, populate and pack the canvas; return it.

        Also sets ``self.rectangle`` to the canvas id of the agent marker.
        """
        canvas = tk.Canvas(self, bg='white', height=HEIGHT * UNIT, width=WIDTH * UNIT)

        # Buttons: parented to the window (not the canvas) and positioned
        # just below the grid area.
        buttons = (
            (0.13, "Calculate", self.calculate_value),
            (0.37, "Print Policy", self.print_optimal_policy),
            (0.62, "Move", self.move_by_policy),
            (0.87, "Clear", self.clear),
        )
        for fraction, label, command in buttons:
            canvas.create_window(WIDTH * UNIT * fraction, HEIGHT * UNIT + 10,
                                 window=tk.Button(self, text=label, command=command, width=10))

        # Grid lines: vertical lines span the full height, horizontal lines
        # the full width.
        for x in range(0, WIDTH * UNIT, UNIT):
            canvas.create_line(x, 0, x, HEIGHT * UNIT)
        for y in range(0, HEIGHT * UNIT, UNIT):
            canvas.create_line(0, y, WIDTH * UNIT, y)

        # Add images
        self.rectangle = canvas.create_image(*self._cell_center(0, 0), image=self.rectangle_img)
        canvas.create_image(*self._cell_center(1, 2), image=self.triangle_img)
        canvas.create_image(*self._cell_center(2, 1), image=self.triangle_img)
        canvas.create_image(*self._cell_center(2, 2), image=self.circle_img)

        canvas.pack()
        return canvas

    def _delete_items(self, item_ids):
        """Delete the given canvas items and empty the list that tracked them."""
        for item_id in item_ids:
            self.canvas.delete(item_id)
        del item_ids[:]

    def _reset_rectangle(self):
        """Move the agent marker back to the centre of the top-left cell."""
        x, y = self.canvas.coords(self.rectangle)
        home_x, home_y = self._cell_center(0, 0)
        self.canvas.move(self.rectangle, home_x - x, home_y - y)

    def clear(self):
        """Reset counters, values, arrows and the marker (ignored while moving)."""
        if self.is_moving != 0:
            return
        self.iteration_count = 0
        self.improvement_count = 0
        self._delete_items(self.texts)
        self._delete_items(self.arrows)
        self.agent.value_table = [[0.0] * WIDTH for _ in range(HEIGHT)]
        self._reset_rectangle()

    def _create_text(self, row, col, x_offset, y_offset, contents, font, size, style, anchor):
        """Draw ``contents`` in cell ``(row, col)`` at the given pixel offsets; return its id."""
        return self.canvas.create_text(x_offset + UNIT * col, y_offset + UNIT * row,
                                       fill="black", text=contents,
                                       font=(font, str(size), style), anchor=anchor)

    def text_value(self, row, col, contents, font='Helvetica', size=12, style='normal', anchor="nw"):
        """Draw a state-value label in cell ``(row, col)``, offset (70, 85) px from its corner."""
        text = self._create_text(row, col, 70, 85, contents, font, size, style, anchor)
        self.texts.append(text)

    def text_reward(self, row, col, contents, font='Helvetica', size=12, style='normal', anchor="nw"):
        """Draw a reward label at the top-left of cell ``(row, col)``.

        Reward labels are tracked separately from value labels so that
        redrawing or clearing the values does not erase them.
        """
        text = self._create_text(row, col, 5, 5, contents, font, size, style, anchor)
        self.reward_texts.append(text)

    def calculate_value(self):
        """Run one value-iteration sweep and redraw the value labels."""
        self.iteration_count += 1
        self._delete_items(self.texts)
        self.agent.value_iteration()
        self.print_values(self.agent.value_table)

    def move_by_policy(self):
        """Animate the marker from the top-left cell until the agent has no action.

        Does nothing until the policy has been printed at least once, or
        while a previous walk is still running. Each step picks uniformly
        among the greedy actions.
        """
        if self.improvement_count == 0 or self.is_moving == 1:
            return
        self.is_moving = 1
        try:
            self._reset_rectangle()
            actions = self.agent.get_action(list(self.find_rectangle()))
            while len(actions) != 0:
                # The move must happen synchronously: the loop condition
                # depends on the marker's new position.
                self.rectangle_move(random.choice(actions))
                # With no callback, after() simply pauses for 100 ms.
                self.after(100)
                actions = self.agent.get_action(list(self.find_rectangle()))
        finally:
            # Always release the flag so Move/Clear keep working after an error.
            self.is_moving = 0

    def rectangle_move(self, action):
        """Redraw, then shift the marker one cell in direction ``action``.

        ``action`` is 0 (up), 1 (down), 2 (left) or 3 (right). A move that
        would leave the grid, or an unknown action, leaves the marker in place.
        """
        row, col = self.find_rectangle()
        self.render()
        dx = dy = 0
        if action == 0 and row > 0:
            dy = -UNIT
        elif action == 1 and row < HEIGHT - 1:
            dy = UNIT
        elif action == 2 and col > 0:
            dx = -UNIT
        elif action == 3 and col < WIDTH - 1:
            dx = UNIT
        self.canvas.move(self.rectangle, dx, dy)

    def find_rectangle(self):
        """Return the marker's grid position as a ``(row, col)`` tuple."""
        x, y = self.canvas.coords(self.rectangle)
        return int((y / UNIT) - 0.5), int((x / UNIT) - 0.5)

    def draw_one_arrow(self, col, row, action):
        """Draw the arrow for ``action`` inside one cell.

        Note the historical parameter names: callers pass the state's first
        index as ``col`` and its second as ``row``; ``col`` is scaled onto the
        canvas y axis and ``row`` onto the x axis. Nothing is drawn for the
        goal cell ``[2, 2]``.
        """
        if [col, row] == [2, 2]:
            return
        x0, y0 = UNIT * row, UNIT * col
        pos = {
            0: (x0 + 50, y0 + 10),   # up: top edge
            1: (x0 + 50, y0 + 90),   # down: bottom edge
            2: (x0 + 10, y0 + 50),   # left: left edge
            3: (x0 + 90, y0 + 50),   # right: right edge
        }
        arrow_imgs = {0: self.up, 1: self.down, 2: self.left, 3: self.right}
        self.arrows.append(self.canvas.create_image(*pos[action], image=arrow_imgs[action]))

    def draw_from_values(self, state, action_list):
        """Draw one arrow in ``state``'s cell for each action in ``action_list``."""
        for action in action_list:
            self.draw_one_arrow(state[0], state[1], action)

    def print_values(self, values):
        """Draw ``values[row][col]`` as a label in every cell of the table."""
        for row, row_values in enumerate(values):
            for col, value in enumerate(row_values):
                self.text_value(row, col, value)

    def render(self):
        """Pause briefly, bring the marker to the front and refresh the window."""
        time.sleep(0.1)
        self.canvas.tag_raise(self.rectangle)
        self.update()

    def print_optimal_policy(self):
        """Replace the drawn arrows with the agent's current greedy policy."""
        self.improvement_count += 1
        self._delete_items(self.arrows)
        for state in self.env.get_all_states():
            action = self.agent.get_action(state)
            self.draw_from_values(state, action)


In [None]:
# Launch the environment window. Run this as the last cell: it needs Env,
# ValueIteration and GraphicDisplay from the cells above, and the cell keeps
# running while the window's event loop (mainloop) is active.
# If GraphicDisplay is kept in a separate `environment` module instead of this
# notebook, import it from there before running this cell.
env = Env()
vi = ValueIteration(env)
grid_world = GraphicDisplay(vi)
grid_world.mainloop()
