In [1]:
!pip install gym



In [2]:
import gym
# Create the CartPole-v0 environment used by the cells below.
env = gym.make('CartPole-v0')

In [3]:
# Start a new episode and look up the action space, then show both.
state = env.reset()
action_space = env.action_space

print(state)         # initial state
print(action_space)  # action space of the environment (not just its size)

[-0.03185282  0.02733258 -0.03137989  0.01263766]
Discrete(2)


状態（state）は次の4つの要素からなる：
- カートの位置
- カートの速度
- 棒の角度
- 棒の角速度

In [4]:
# Advance the environment by one step with a fixed action and print the
# state that results from it.
action = 0 # 0 or 1
next_state, reward, done, info = env.step(action)
print(next_state)

[-0.03130617 -0.16732562 -0.03112714  0.29525703]


env.step(action)の結果
- 次の状態（next_state）
- 報酬（reward）
- 終了かどうかのフラグ（done）
- 追加の情報（info）

### ランダムなエージェント

In [9]:
import numpy as np
import gym


# Play a single episode with an agent that picks action 0 or 1 at random.
env = gym.make('CartPole-v0')
try:
    state = env.reset()
    done = False

    while not done:
        # env.render()  # uncomment to watch the episode
        action = np.random.choice([0, 1])
        next_state, reward, done, info = env.step(action)
finally:
    # Release the environment even if the episode is interrupted or raises.
    env.close()

### 経験再生の実装

In [10]:
from collections import deque
import random
import numpy as np

class ReplayBuffer:
    """Fixed-capacity store of transitions with random mini-batch sampling.

    Each stored item is a ``(state, action, reward, next_state, done)``
    tuple. Because the storage is a ``deque`` with ``maxlen``, the oldest
    items are dropped once ``buffer_size`` entries have been added.
    """

    def __init__(self, buffer_size, batch_size):
        """Create an empty buffer.

        Args:
            buffer_size: Maximum number of transitions kept.
            batch_size: Number of transitions returned by ``get_batch``.
        """
        self.buffer = deque(maxlen=buffer_size)
        self.batch_size = batch_size

    def add(self, state, action, reward, next_state, done):
        """Append one transition to the buffer."""
        self.buffer.append((state, action, reward, next_state, done))

    def __len__(self):
        """Return the number of transitions currently stored."""
        return len(self.buffer)

    def get_batch(self):
        """Sample ``batch_size`` distinct transitions and return them by field.

        Returns:
            A tuple ``(state, action, reward, next_state, done)`` of NumPy
            arrays, each with the batch along the first axis. ``done`` is
            converted to ``int32``.

        Raises:
            ValueError: Raised by ``random.sample`` when the buffer holds
                fewer than ``batch_size`` transitions.
        """
        samples = random.sample(self.buffer, self.batch_size)

        def column(idx):
            # Collect field `idx` from every sampled transition.
            return [transition[idx] for transition in samples]

        return (
            np.stack(column(0)),
            np.stack(column(1)),
            np.stack(column(2)),
            np.stack(column(3)),
            np.array(column(4)).astype(np.int32),
        )

In [11]:
import gym

env = gym.make('CartPole-v0')
replay_buffer = ReplayBuffer(buffer_size=10000, batch_size=32)

# Fill the buffer by playing 10 episodes that always take action 0.
for episode in range(10):
    state, done = env.reset(), False

    while not done:
        action = 0
        next_state, reward, done, info = env.step(action)
        replay_buffer.add(state, action, reward, next_state, done)
        state = next_state

# Draw one mini-batch and show the shape of every field.
state, action, reward, next_state, done = replay_buffer.get_batch()
for batch_field in (state, action, reward, next_state, done):
    print(batch_field.shape)

(32, 4)
(32,)
(32,)
(32, 4)
(32,)


### ターゲットネットワークの実装

In [12]:
!pip install dezero

Collecting dezero
  Downloading dezero-0.0.13-py3-none-any.whl (28 kB)
Installing collected packages: dezero
Successfully installed dezero-0.0.13


In [14]:
import copy
from dezero import Model
from dezero import optimizers
import dezero.functions as F
import dezero.layers as L

class QNet(Model):
    """Q-network: two 128-unit ReLU layers followed by a linear output layer.

    The output layer is sized by ``action_size``.
    """

    def __init__(self, action_size):
        super().__init__()
        self.l1 = L.Linear(128)
        self.l2 = L.Linear(128)
        self.l3 = L.Linear(action_size)

    def forward(self, x):
        """Pass ``x`` through both hidden layers and the output layer."""
        hidden = x
        for layer in (self.l1, self.l2):
            hidden = F.relu(layer(hidden))
        return self.l3(hidden)

class DQNAgent:
    """DQN agent with experience replay and a separate target network."""

    def __init__(self):
        # Hyperparameters.
        self.gamma = 0.98       # discount factor used in the TD target
        self.lr = 0.0005        # learning rate passed to Adam
        self.epsilon = 0.1      # probability of taking a random action
        self.buffer_size = 10000
        self.batch_size = 32
        self.action_size = 2

        self.replay_buffer = ReplayBuffer(self.buffer_size, self.batch_size)
        self.qnet = QNet(self.action_size)
        self.qnet_target = QNet(self.action_size)
        self.optimizer = optimizers.Adam(self.lr)
        # Only qnet is registered with the optimizer; qnet_target is never
        # passed to it and changes only through sync_qnet().
        self.optimizer.setup(self.qnet)

    def sync_qnet(self):
        """Replace the target network with a deep copy of the online network."""
        self.qnet_target = copy.deepcopy(self.qnet)

    def get_action(self, state):
        """Choose an action epsilon-greedily for a single (unbatched) state.

        With probability ``epsilon`` a random action index is returned;
        otherwise the index of the largest output of ``qnet``.
        """
        if np.random.rand() < self.epsilon:
            return np.random.choice(self.action_size)

        batched_state = state[np.newaxis, :]  # add a leading batch axis
        q_values = self.qnet(batched_state)
        return q_values.data.argmax()

    def update(self, state, action, reward, next_state, done):
        """Store one transition and take one gradient step on a sampled batch.

        Nothing is learned until the replay buffer holds at least
        ``batch_size`` transitions.
        """
        self.replay_buffer.add(state, action, reward, next_state, done)
        if len(self.replay_buffer) < self.batch_size:
            return

        states, actions, rewards, next_states, dones = self.replay_buffer.get_batch()

        # (1) Online-network outputs for every action in each sampled state.
        all_q = self.qnet(states)
        # (2) Keep only the output for the action that was actually taken.
        batch_rows = np.arange(self.batch_size)
        taken_q = all_q[batch_rows, actions]

        # (3) Largest target-network output for each next state.
        max_next_q = self.qnet_target(next_states).max(axis=1)
        # NOTE(review): unchain() presumably detaches this value from the
        # computation graph so it acts as a constant target; confirm in DeZero.
        max_next_q.unchain()
        # (4) TD target; (1 - dones) zeroes the bootstrap term where done is set.
        td_target = rewards + (1 - dones) * self.gamma * max_next_q

        loss = F.mean_squared_error(taken_q, td_target)

        self.qnet.cleargrads()
        loss.backward()
        self.optimizer.update()

### DQNを動かす

In [15]:
episodes = 300
sync_interval = 20
env = gym.make('CartPole-v0')
agent = DQNAgent()
reward_history = []

for episode in range(episodes):
    state, done, total_reward = env.reset(), False, 0

    # Play one episode, learning from every transition as it happens.
    while not done:
        action = agent.get_action(state)
        next_state, reward, done, info = env.step(action)

        agent.update(state, action, reward, next_state, done)
        total_reward += reward
        state = next_state

    # Copy the online network into the target network every
    # `sync_interval` episodes (this includes episode 0).
    if episode % sync_interval == 0:
        agent.sync_qnet()

    reward_history.append(total_reward)

In [16]:
# Run one episode acting greedily.
agent.epsilon = 0  # greedy policy: the random branch in get_action is never taken
state, done, total_reward = env.reset(), False, 0

while not done:
    action = agent.get_action(state)
    next_state, reward, done, info = env.step(action)
    total_reward += reward
    state = next_state
    # env.render()
print('Total Reward:', total_reward)

Total Reward: 178.0
