# Reinforcement Learning (DQN) tutorial

- http://pytorch.org/tutorials/intermediate/reinforcement_q_learning.html

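- Uses the CartPole task from OpenAI Gym
- The environment exposes its state as four numbers (position, velocity, ...),
- but the DQN here takes an image centered on the cart as its input
- Strictly speaking, the state is the difference between the current image and the previous one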

> Strictly speaking, we will present the state as the difference between the current screen patch and the previous one. This will allow the agent to take the velocity of the pole into account from one image.
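
The idea in the quote can be sketched with toy NumPy arrays instead of real rendered screens. The helper `difference_state` and the frame values below are hypothetical, chosen only to illustrate how a difference image encodes motion; they are not part of the tutorial code.

```python
import numpy as np

def difference_state(current_screen, previous_screen):
    """Return the per-pixel difference between two equally shaped frames."""
    current = np.asarray(current_screen, dtype=np.float32)
    previous = np.asarray(previous_screen, dtype=np.float32)
    return current - previous

# Two toy 1x5 grayscale "frames": a bright pixel moves one step to the right.
previous_frame = np.array([[0, 1, 0, 0, 0]])
current_frame = np.array([[0, 0, 1, 0, 0]])

state = difference_state(current_frame, previous_frame)
# Negative where the object was, positive where it is now.
print(state)
```

A single frame says nothing about direction, while the sign pattern of the difference shows which way the object moved.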

TODO: Instead of DQN, train with Q-learning using the four numbers as the state

- This uses OpenAI Gym, so install it with `pip install gym`

In [18]:
import gym
import math
import random
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple
from itertools import count
from copy import deepcopy
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
import torchvision.transforms as T

%matplotlib inline

In [22]:
# setup matplotlib
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    from IPython import display
plt.ion()

In [23]:
# if gpu is to be used
use_cuda = torch.cuda.is_available()
FloatTensor = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if use_cuda else torch.LongTensor
ByteTensor = torch.cuda.ByteTensor if use_cuda else torch.ByteTensor
Tensor = FloatTensor

In [15]:
env = gym.make('CartPole-v0').unwrapped
env

[2018-02-12 17:04:39,640] Making new env: CartPole-v0


<gym.envs.classic_control.cartpole.CartPoleEnv at 0x1128c7240>

## Experience Replay

- DQN stores its observations and later samples them in shuffled (random) order for training

> Transition - a named tuple representing a single transition in our environment
>
> ReplayMemory - a cyclic buffer of bounded size that holds the transitions observed recently. It also implements a .sample() method for selecting a random batch of transitions for training.
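
As a point of comparison, a bounded buffer with the same interface can be sketched with `collections.deque`, which discards the oldest item automatically once `maxlen` is reached. This is an alternative sketch, not the list-based class this notebook defines; the class name `DequeReplayMemory` and the toy values are assumptions made for illustration.

```python
import random
from collections import deque, namedtuple

Transition = namedtuple('Transition',
                        ('state', 'action', 'next_state', 'reward'))

class DequeReplayMemory:
    """Bounded replay buffer; the oldest transition is dropped when full."""

    def __init__(self, capacity):
        self.memory = deque(maxlen=capacity)

    def push(self, *args):
        """Save a transition."""
        self.memory.append(Transition(*args))

    def sample(self, batch_size):
        # random.sample draws without replacement from the stored transitions.
        return random.sample(list(self.memory), batch_size)

    def __len__(self):
        return len(self.memory)

memory = DequeReplayMemory(3)
for i in range(5):
    memory.push(i, i, i + 1, 1.0)

print(len(memory))                        # 3
print([t.state for t in memory.memory])   # [2, 3, 4]
batch = memory.sample(2)
```

Unlike the index-based cyclic buffer, the deque keeps its items in insertion order, so the oldest surviving transition is always at the front.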

In [24]:
Transition = namedtuple('Transition',
                        ('state', 'action', 'next_state', 'reward'))

In [40]:
# How to use a namedtuple
t = Transition(1, 2, 3, 4)
print(t)
print(t.state, t.action, t.next_state, t.reward)

Transition(state=1, action=2, next_state=3, reward=4)
1 2 3 4


In [42]:
class ReplayMemory(object):

    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        self.position = 0
    
    def push(self, *args):
        """Save a transition."""
        if len(self.memory) < self.capacity:
            self.memory.append(None)
        self.memory[self.position] = Transition(*args)
        # Once memory is full, overwrite the oldest entries first
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)
    
    def __len__(self):
        return len(self.memory)

In [54]:
# Check how ReplayMemory behaves
rm = ReplayMemory(3)
rm.push(1, 1, 1, 1)
rm.push(2, 2, 2, 2)
rm.push(3, 3, 3, 3)
print(len(rm))
print(rm.memory)
rm.push(4, 4, 4, 4)
print(len(rm))
print(rm.memory)

3
[Transition(state=1, action=1, next_state=1, reward=1), Transition(state=2, action=2, next_state=2, reward=2), Transition(state=3, action=3, next_state=3, reward=3)]
3
[Transition(state=4, action=4, next_state=4, reward=4), Transition(state=2, action=2, next_state=2, reward=2), Transition(state=3, action=3, next_state=3, reward=3)]
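
The check above confirms that the fourth push overwrites the oldest entry. Once transitions are sampled, a common next step is to regroup them field by field so that all states, actions, and rewards can be handled together. The following is a minimal sketch of that idiom with toy values; the list `transitions` stands in for the buffer contents, and the fixed seed is an assumption used only to keep the run repeatable.

```python
import random
from collections import namedtuple

Transition = namedtuple('Transition',
                        ('state', 'action', 'next_state', 'reward'))

# Stand-in for the contents of a replay memory (toy values).
transitions = [Transition(s, s % 2, s + 1, 1.0) for s in range(10)]

random.seed(0)
sampled = random.sample(transitions, 4)

# zip(*sampled) regroups the fields: a list of Transitions becomes
# one Transition whose fields are tuples of length batch_size.
batch = Transition(*zip(*sampled))

print(batch.state)   # tuple of the 4 sampled states
print(batch.reward)  # (1.0, 1.0, 1.0, 1.0)
```

Each field of `batch` now holds one tuple per attribute, which is a convenient shape to convert into tensors later.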
