In [1]:
# TD法

In [11]:
import numpy as np
from collections import defaultdict, deque
from common import GridWorld
from utils import greedy_probs

In [2]:
### 6.1.2 TD法の実装

In [6]:
class TdAgent:
    """Tabular TD(0) agent that evaluates a fixed uniform-random policy.

    Args:
        gamma: Discount factor applied to the next state's value estimate.
        alpha: Step size of the incremental TD update.
    """

    def __init__(self, gamma=0.9, alpha=0.01):
        self.gamma = gamma
        self.alpha = alpha
        self.action_size = 4

        # Uniform distribution over the actions 0..action_size-1 (0.25 each for 4).
        random_actions = {action: 1 / self.action_size for action in range(self.action_size)}
        # Hand every state its own copy: returning one shared dict object for all
        # states would make an in-place edit of one state's policy leak into all.
        self.pi = defaultdict(lambda: dict(random_actions))
        # Unvisited states start with a value estimate of 0.
        self.V = defaultdict(lambda: 0)

    def get_action(self, state):
        """Sample an action for `state` according to the policy `self.pi`."""
        action_probs = self.pi[state]
        actions = list(action_probs.keys())
        probs = list(action_probs.values())
        return np.random.choice(actions, p=probs)

    def eval(self, state, reward, next_state, done):
        """Apply one TD(0) update to V[state] from a single transition.

        Args:
            state: State the transition started from.
            reward: Reward received on the transition.
            next_state: State reached by the transition.
            done: True if the episode ended on this transition; the value of
                `next_state` is then treated as 0.
        """
        next_V = 0 if done else self.V[next_state]
        target = reward + self.gamma * next_V

        # Move V[state] a fraction `alpha` of the way toward the TD target.
        self.V[state] += (target - self.V[state]) * self.alpha

In [9]:
# Evaluate the uniform-random policy on GridWorld with TD(0).
# Seed the global NumPy RNG (used by np.random.choice in get_action) so that
# Restart & Run All reproduces the same value estimates.
np.random.seed(0)

env = GridWorld()
agent = TdAgent()

episodes = 1000
for episode in range(episodes):
    state = env.reset()

    while True:
        action = agent.get_action(state)
        next_state, reward, done = env.step(action)

        # Update V after every step, including the final one (done=True).
        agent.eval(state, reward, next_state, done)
        if done:
            break
        state = next_state
agent.V

defaultdict(<function __main__.TdAgent.__init__.<locals>.<lambda>()>,
            {(2, 1): -0.23948985631315103,
             (2, 0): -0.09859375674663563,
             (2, 2): -0.46631482821763454,
             (1, 2): -0.5450153272896253,
             (0, 2): 0.11373250550849406,
             (0, 1): 0.0761719787120954,
             (2, 3): -0.8663981775657867,
             (1, 3): -0.5071596521589833,
             (1, 0): -0.030045892344131284,
             (0, 0): 0.029374575131925366})

In [10]:
### 6.2.2 SARSAの実装

In [12]:
class SarsaAgent:
    """Tabular SARSA agent whose policy is rebuilt from Q after each update.

    Transitions are buffered two at a time so that each Q update can use the
    action actually taken in the following state.

    Args:
        gamma: Discount factor.
        alpha: Step size of the Q update.
        epsilon: Exploration parameter passed to `greedy_probs` when the policy
            of an updated state is rebuilt.
    """

    def __init__(self, gamma=0.9, alpha=0.8, epsilon=0.1):
        self.gamma = gamma
        self.alpha = alpha
        self.epsilon = epsilon
        # NOTE(review): greedy_probs is called without an action count, so this
        # presumably has to stay at 4 to match it — confirm against utils.
        self.action_size = 4

        # Uniform distribution over the actions 0..action_size-1 (0.25 each for 4).
        random_actions = {action: 1 / self.action_size for action in range(self.action_size)}
        # Hand every state its own copy so an in-place edit of one state's policy
        # cannot leak into every other not-yet-updated state.
        self.pi = defaultdict(lambda: dict(random_actions))
        # Unseen (state, action) pairs start with a Q value of 0.
        self.Q = defaultdict(lambda: 0)
        # Holds at most the previous and the current (state, action, reward, done).
        self.memory = deque(maxlen=2)

    def get_action(self, state):
        """Sample an action for `state` according to the policy `self.pi`."""
        action_probs = self.pi[state]
        actions = list(action_probs.keys())
        probs = list(action_probs.values())
        return np.random.choice(actions, p=probs)

    def reset(self):
        """Clear the transition buffer; call at the start of every episode."""
        self.memory.clear()

    def update(self, state, action, reward, done):
        """Buffer a transition and update Q/pi for the previous one.

        Nothing is updated until two transitions are buffered. The older one,
        (S, A, R, done), is then updated using the Q value of the newer one's
        (S', A'). To flush the last transition of an episode, call this once
        more with the terminal state and None for the remaining arguments.

        Args:
            state: State in which `action` was taken.
            action: Action taken in `state`.
            reward: Reward received after taking `action`.
            done: True if the episode ended on this transition.
        """
        self.memory.append((state, action, reward, done))
        if len(self.memory) < 2:
            return

        state, action, reward, done = self.memory[0]
        next_state, next_action, _, _ = self.memory[1]

        # When the older transition ended the episode there is no future return,
        # so the next Q value is taken as 0 (next_action may be None here).
        next_q = 0 if done else self.Q[next_state, next_action]

        target = reward + self.gamma * next_q
        self.Q[state, action] += (target - self.Q[state, action]) * self.alpha

        # Rebuild this state's policy from the updated Q (greedy_probs comes from
        # the project-local `utils` module).
        self.pi[state] = greedy_probs(self.Q, state, self.epsilon)

In [14]:
# Learn Q on GridWorld with SARSA.
# Seed the global NumPy RNG (used by np.random.choice in get_action) so that
# Restart & Run All reproduces the same Q table.
np.random.seed(0)

env = GridWorld()
agent = SarsaAgent()

episodes = 10000
for episode in range(episodes):
    state = env.reset()
    agent.reset()

    while True:
        action = agent.get_action(state)
        next_state, reward, done = env.step(action)

        agent.update(state, action, reward, done)

        if done:
            # Flush the final transition. Its next state is the terminal state,
            # and since done=True was recorded, update() uses a next Q of 0.
            agent.update(next_state, None, None, None)
            break
        state = next_state
agent.Q

defaultdict(<function __main__.SarsaAgent.__init__.<locals>.<lambda>()>,
            {((2, 0), 2): 0.3706059974348025,
             ((2, 0), 1): 0.343204195001758,
             ((2, 0), 0): 0.42686789833120464,
             ((2, 0), 3): 0.43701584540100236,
             ((2, 1), 1): 0.3085119551596592,
             ((2, 1), 0): 0.2956546353901475,
             ((2, 1), 2): 0.4165342877496803,
             ((2, 1), 3): 0.30347771559866743,
             ((2, 2), 1): 0.2726559029779189,
             ((2, 2), 0): 0.08999941634508046,
             ((2, 2), 2): 0.30341040788930557,
             ((2, 2), 3): 0.3054725294207578,
             ((1, 2), 2): 0.7748900926317673,
             ((1, 2), 0): 0.8999999999997971,
             ((1, 2), 1): 0.6469431526082596,
             ((1, 2), 3): -0.10000017653760002,
             ((0, 2), 2): 0.6341651225695195,
             ((0, 1), 0): 0.5483945136068872,
             ((0, 2), 0): 0.899880560639984,
             ((0, 2), 1): 0.617563575782853,
   