In [1]:
from typing import *
import os
from glob import glob
from datetime import datetime

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from nptyping import NDArray
from IPython.display import display

# Imported for type annotations only
from matplotlib.figure import Figure
from matplotlib.axes._subplots import Subplot

# Global plot styling for the notebook: seaborn's 'whitegrid' theme.
sns.set_style('whitegrid')
# Colour and marker lists. NOTE(review): not referenced in the cells visible
# here - presumably consumed by later plotting cells; confirm before removing.
colors = ['#de3838', '#007bc3', '#ffd12a']
markers = ['o', 'x', ',']
# IPython magic (not plain Python): ask the inline backend for SVG figures.
%config InlineBackend.figure_formats = ['svg']

# pandas display limits: rows, columns and character width of printed frames.
pd.set_option('display.max_rows', 100)
pd.set_option('display.max_columns', 100)
pd.set_option('display.width', 100)

cmap = sns.diverging_palette(255, 0, as_cmap=True)  # Diverging colour map; passed to sns.heatmap in show_values

In [2]:
class Gridworld:
    """One-dimensional corridor of cells with a terminal cell at each end.

    Attributes:
        size: Number of cells in the corridor.
        states: Cell indices ``0 .. size-1``.
        actions: Available steps, ``-1`` and ``+1``.
        policy: Maps ``(state, action)`` to the probability of choosing that
            action in that state; starts at 1/2 for every pair.
        value: Maps each state to its value estimate; starts at 0.0.
    """

    def __init__(self, size: int = 8) -> None:
        """Build a corridor of ``size`` cells with a uniform policy and zero values."""
        self.size: int = size
        self.states: List[int] = [*range(size)]
        self.actions: List[int] = [-1, 1]

        # Uniform policy. For every cell the +1 entry is inserted before the
        # -1 entry, so the dict's iteration order is (s, 1), (s, -1), ...
        self.policy: Dict[Tuple[int, int], float] = {
            (cell, step): 1/2
            for cell in self.states
            for step in (1, -1)
        }

        self.value: Dict[int, float] = dict.fromkeys(self.states, 0.0)

    def move(self, state: int, action: int) -> Tuple[int, int]:
        """Apply ``action`` in ``state`` and return ``(reward, next_state)``.

        The two end cells (``0`` and ``size-1``) are terminal: any action
        taken there yields reward 0 and leaves the state unchanged.
        Stepping onto cell ``0`` pays +1, stepping onto cell ``size-1``
        pays -1, and every other step pays 0.
        """
        last_cell = self.size - 1

        # Terminal cells absorb every action.
        if state == 0 or state == last_cell:
            return 0, state

        target: int = state + action

        # Cell 0 is checked first, matching the order of the reward rules.
        if target == 0:
            reward = 1
        elif target == last_cell:
            reward = -1
        else:
            reward = 0
        return reward, target


def show_values(world: Gridworld, subplot: Optional[Subplot] = None, title : str ='values') -> None:
    """Draw the world's state values as a one-row annotated heatmap.

    Args:
        world: Object providing ``size``, ``states`` and ``value``.
        subplot: Axes to draw on. When omitted, a new figure whose width
            scales with ``world.size`` is created and its only axes is used.
        title: Title placed on the heatmap axes.
    """
    axes = subplot
    if axes is None:
        # No axes supplied: create a standalone figure sized to the corridor.
        figure : Figure = plt.figure(figsize=(world.size*0.8, 1.7))
        axes = figure.add_subplot(1, 1, 1)

    # Single row; column ``s`` holds the value of state ``s``.
    grid : NDArray[(1, world.size), float] = np.zeros([1, world.size])
    for cell in world.states:
        grid[0, cell] = world.value[cell]

    heatmap_axes = sns.heatmap(
        grid,
        cmap=cmap,
        square=True,
        cbar=False,
        yticklabels=[],
        annot=True,
        fmt='3.1f',
        ax=axes,
    )
    heatmap_axes.set_title(title)



def policy_eval(world, gamma: float = 1, delta: float = 0.01) -> None:
    """Iterative policy evaluation; updates ``world.value`` in place.

    Sweeps over every state, replacing its value with the expected one-step
    return under ``world.policy``:

        V(s) <- sum_a policy[(s, a)] * (r + gamma * V(s'))

    where ``(r, s') = world.move(s, a)``. Values are overwritten as soon as
    they are computed, so later states in the same sweep see the updated
    values of earlier ones. Sweeps repeat until the largest absolute change
    in a sweep is below ``delta``.

    Args:
        world: Object exposing ``states``, ``actions``, ``policy``, ``value``
            and ``move(state, action) -> (reward, next_state)``.
        gamma: Discount factor applied to the successor state's value.
        delta: Convergence threshold on the largest per-sweep change.
    """
    while True:
        delta_max : float = 0
        for state in world.states:
            # Expected one-step backup for this state under the policy.
            v_new = 0
            for action in world.actions:
                r, s_new = world.move(state, action)
                v_new += world.policy[(state, action)] * (r + gamma * world.value[s_new])

            # These two statements must run once per state. Outside the loop
            # they would only touch the last state of the sweep, leaving all
            # other values untouched.
            delta_max = max(delta_max, abs(world.value[state] - v_new))
            world.value[state] = v_new

        if delta_max < delta:
            break