# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import mxnet as mx
import mxnet.ndarray as nd
import numpy
import cv2
from scipy.stats import entropy
from utils import *


class DQNOutput(mx.operator.CustomOp):
def __init__(self):
super(DQNOutput, self).__init__()
def forward(self, is_train, req, in_data, out_data, aux):
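        # The forward pass simply copies the Q-value predictions through
        # unchanged; the loss only shapes the gradient in the backward pass.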
self.assign(out_data[0], req[0], in_data[0])
def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
        # TODO: backward with NDArray can cause trouble; see https://github.com/dmlc/mxnet/issues/1720
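        # Only the Q-value of the action that was actually taken receives a
        # gradient: the difference between that Q-value and the target reward,
        # clipped to [-1, 1] as in the DQN papers.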
x = out_data[0].asnumpy()
        action = in_data[1].asnumpy().astype(int)
reward = in_data[2].asnumpy()
dx = in_grad[0]
ret = numpy.zeros(shape=dx.shape, dtype=numpy.float32)
ret[numpy.arange(action.shape[0]), action] \
= numpy.clip(x[numpy.arange(action.shape[0]), action] - reward, -1, 1)
self.assign(dx, req[0], ret)


@mx.operator.register("DQNOutput")
class DQNOutputProp(mx.operator.CustomOpProp):
def __init__(self):
super(DQNOutputProp, self).__init__(need_top_grad=False)
def list_arguments(self):
return ['data', 'action', 'reward']
def list_outputs(self):
return ['output']
def infer_shape(self, in_shape):
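        # `data` holds one Q-value per action, shape (batch_size, action_num);
        # `action` and `reward` carry one scalar per sample, shape (batch_size,).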
data_shape = in_shape[0]
action_shape = (in_shape[0][0],)
reward_shape = (in_shape[0][0],)
output_shape = in_shape[0]
return [data_shape, action_shape, reward_shape], [output_shape], []
def create_operator(self, ctx, shapes, dtypes):
return DQNOutput()
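

# DQNOutputNpyOp implements the same output layer with the older
# mx.operator.NumpyOp interface, operating directly on NumPy arrays.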
class DQNOutputNpyOp(mx.operator.NumpyOp):
def __init__(self):
super(DQNOutputNpyOp, self).__init__(need_top_grad=False)
def list_arguments(self):
return ['data', 'action', 'reward']
def list_outputs(self):
return ['output']
def infer_shape(self, in_shape):
data_shape = in_shape[0]
action_shape = (in_shape[0][0],)
reward_shape = (in_shape[0][0],)
output_shape = in_shape[0]
return [data_shape, action_shape, reward_shape], [output_shape]
def forward(self, in_data, out_data):
x = in_data[0]
y = out_data[0]
y[:] = x
def backward(self, out_grad, in_data, out_data, in_grad):
x = out_data[0]
        action = in_data[1].astype(int)
reward = in_data[2]
dx = in_grad[0]
dx[:] = 0
dx[numpy.arange(action.shape[0]), action] \
= numpy.clip(x[numpy.arange(action.shape[0]), action] - reward, -1, 1)


def dqn_sym_nips(action_num, data=None, name='dqn'):
    """Structure of the Deep Q Network in the NIPS 2013 workshop paper:
    Playing Atari with Deep Reinforcement Learning (https://www.cs.toronto.edu/~vmnih/docs/dqn.pdf)

    Parameters
    ----------
    action_num : int
        Number of discrete actions, i.e. the size of the Q-value output.
    data : mxnet.sym.Symbol, optional
        Input symbol; a new 'data' variable is created when omitted.
    name : str, optional
        Name given to the DQNOutput custom operator.
    """
if data is None:
net = mx.symbol.Variable('data')
else:
net = data
net = mx.symbol.Convolution(data=net, name='conv1', kernel=(8, 8), stride=(4, 4), num_filter=16)
net = mx.symbol.Activation(data=net, name='relu1', act_type="relu")
net = mx.symbol.Convolution(data=net, name='conv2', kernel=(4, 4), stride=(2, 2), num_filter=32)
net = mx.symbol.Activation(data=net, name='relu2', act_type="relu")
net = mx.symbol.Flatten(data=net)
net = mx.symbol.FullyConnected(data=net, name='fc3', num_hidden=256)
net = mx.symbol.Activation(data=net, name='relu3', act_type="relu")
net = mx.symbol.FullyConnected(data=net, name='fc4', num_hidden=action_num)
net = mx.symbol.Custom(data=net, name=name, op_type='DQNOutput')
return net


def dqn_sym_nature(action_num, data=None, name='dqn'):
    """Structure of the Deep Q Network in the Nature 2015 paper:
    Human-level control through deep reinforcement learning
    (http://www.nature.com/nature/journal/v518/n7540/full/nature14236.html)

    Parameters
    ----------
    action_num : int
        Number of discrete actions, i.e. the size of the Q-value output.
    data : mxnet.sym.Symbol, optional
        Input symbol; a new 'data' variable is created when omitted.
    name : str, optional
        Name given to the DQNOutput custom operator.
    """
if data is None:
net = mx.symbol.Variable('data')
else:
net = data
net = mx.symbol.Convolution(data=net, name='conv1', kernel=(8, 8), stride=(4, 4), num_filter=32)
net = mx.symbol.Activation(data=net, name='relu1', act_type="relu")
net = mx.symbol.Convolution(data=net, name='conv2', kernel=(4, 4), stride=(2, 2), num_filter=64)
net = mx.symbol.Activation(data=net, name='relu2', act_type="relu")
net = mx.symbol.Convolution(data=net, name='conv3', kernel=(3, 3), stride=(1, 1), num_filter=64)
net = mx.symbol.Activation(data=net, name='relu3', act_type="relu")
net = mx.symbol.Flatten(data=net)
net = mx.symbol.FullyConnected(data=net, name='fc4', num_hidden=512)
net = mx.symbol.Activation(data=net, name='relu4', act_type="relu")
net = mx.symbol.FullyConnected(data=net, name='fc5', num_hidden=action_num)
net = mx.symbol.Custom(data=net, name=name, op_type='DQNOutput')
return net
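

# A minimal usage sketch: build the Nature DQN symbol and print the inferred
# argument shapes. The batch size (32), action count (4), and 84x84 stacked
# frame input are illustrative assumptions matching typical Atari preprocessing,
# not values fixed by this file.
if __name__ == '__main__':
    dqn = dqn_sym_nature(action_num=4)
    # Only the data shape is supplied; the DQNOutputProp.infer_shape hook above
    # derives the per-sample 'action' and 'reward' shapes from the batch size.
    arg_shapes, out_shapes, _ = dqn.infer_shape(data=(32, 4, 84, 84))
    for arg_name, arg_shape in zip(dqn.list_arguments(), arg_shapes):
        print('%s: %s' % (arg_name, str(arg_shape)))
    print('output: %s' % str(out_shapes[0]))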