Skip to content

Commit

Permalink
feat: implement adaptive multi-vector network (#33)
Browse files Browse the repository at this point in the history
  • Loading branch information
StepNeverStop committed Jan 6, 2021
1 parent 6ec60cd commit e3b46d9
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 15 deletions.
2 changes: 1 addition & 1 deletion rls/_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
# We follow Semantic Versioning (https://semver.org/)
_MAJOR_VERSION = '2'
_MINOR_VERSION = '0'
_PATCH_VERSION = '0'
_PATCH_VERSION = '1'

# Example: '0.4.2'
__version__ = '.'.join([_MAJOR_VERSION, _MINOR_VERSION, _PATCH_VERSION])
10 changes: 7 additions & 3 deletions rls/algos/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,11 @@ general:
normalize_vector_obs: false
logger2file: false
# ----- can be overridden in specific algorithms (e.g. dqn) in order to use a different type of visual net or memory net.
vector_net_kwargs: {}
vector_net_kwargs:
network_type: adaptive # rls.utils.specs.VectorNetworkType
visual_net_kwargs:
visual_feature: 128
network_type: simple
network_type: simple # rls.utils.specs.VisualNetworkType
encoder_net_kwargs:
use_encoder: false
output_dim: 16
Expand Down Expand Up @@ -124,6 +125,9 @@ dqn: &dqn
use_priority: false
n_step: true
network_settings: [64, 64]
visual_net_kwargs:
visual_feature: 128
network_type: nature

ddqn: *dqn

Expand Down Expand Up @@ -474,7 +478,7 @@ dpg:
ddpg:
gamma: 0.99
ployak: 0.995
noise_type: 'ou' # ou or gaussian
noise_type: "ou" # ou or gaussian
gaussian_noise_sigma: 0.2 # specify the variance of gaussian distribution
gaussian_noise_bound: 0.5 # specify the clipping bound of the sampled noise; the noise must be in the range [-bound, bound]
actor_lr: 5.0e-4
Expand Down
56 changes: 45 additions & 11 deletions rls/nn/networks.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#!/usr/bin/env python3
# encoding: utf-8

import math
import tensorflow as tf

from typing import Tuple
Expand All @@ -17,10 +18,19 @@
from rls.nn.layers import ConvLayer
from rls.nn.activations import default_activation
from rls.nn.initializers import initKernelAndBias
from rls.utils.specs import (VisualNetworkType,
from rls.utils.specs import (VectorNetworkType,
VisualNetworkType,
MemoryNetworkType)


def get_vector_network_from_type(network_type: VectorNetworkType):
    """Return the vector-network class registered for ``network_type``.

    Args:
        network_type: a ``VectorNetworkType`` member.

    Returns:
        The class (not an instance) to build a per-vector network with.
        Any value that is not a registered key yields the class registered
        for ``VectorNetworkType.CONCAT``.

    NOTE(review): the registry is keyed by enum members, so a raw string
    such as ``'adaptive'`` is not found and silently resolves to the concat
    network -- confirm that callers convert config strings to the enum first.
    """
    registry = {
        VectorNetworkType.CONCAT: VectorConcatNetwork,
        VectorNetworkType.ADAPTIVE: VectorAdaptiveNetwork,
    }
    fallback = registry[VectorNetworkType.CONCAT]
    return registry.get(network_type, fallback)


def get_visual_network_from_type(network_type: VisualNetworkType):
VISUAL_NETWORKS = {
VisualNetworkType.SIMPLE: lambda: ConvLayer(Conv2D, [16, 32], [[8, 8], [4, 4]], [[4, 4], [2, 2]], padding='valid', activation='elu'),
Expand All @@ -32,6 +42,34 @@ def get_visual_network_from_type(network_type: VisualNetworkType):
return VISUAL_NETWORKS.get(network_type, VISUAL_NETWORKS[VisualNetworkType.SIMPLE])


class VectorConcatNetwork:
    """Pass-through "network" for a single vector input.

    Calling an instance returns its argument unchanged, so the reported
    feature size ``h_dim`` equals the input size ``in_dim``.

    Keyword Args:
        in_dim: size of the input vector (required); converted with ``int``.

    Raises:
        AssertionError: if ``in_dim`` is not supplied as a keyword argument.
    """

    def __init__(self, *args, **kwargs):
        # Positional arguments are accepted but ignored.
        assert 'in_dim' in kwargs, "'in_dim' must be passed as a keyword argument"
        self.h_dim = self.in_dim = int(kwargs['in_dim'])

    def __call__(self, x):
        """Return ``x`` unchanged."""
        return x


class VectorAdaptiveNetwork(Sequential):
    """Funnel-shaped stack of ``Dense`` layers for a single vector input.

    The hidden layers have widths equal to the powers of two lying strictly
    between ``out_dim`` and ``in_dim``, added in descending order, followed
    by a final ``Dense`` layer of width ``out_dim``. When no power of two
    lies strictly between the two sizes (for example ``in_dim <= out_dim``),
    only the final layer is added.

    Example: ``in_dim=100, out_dim=16`` gives layer widths 64, 32, 16.

    Keyword Args:
        in_dim: size of the input vector (required); converted with ``int``.
        out_dim: size of the output feature, defaults to 16.

    Raises:
        AssertionError: if ``in_dim`` is not supplied.
        ValueError: from ``math.log2`` if ``in_dim`` or ``out_dim`` is not
            positive.
    """

    def __init__(self, **kwargs):
        super().__init__()
        assert 'in_dim' in kwargs, "'in_dim' must be passed as a keyword argument"
        self.in_dim = int(kwargs['in_dim'])
        self.h_dim = self.out_dim = int(kwargs.get('out_dim', 16))

        # Smallest exponent whose power of two is strictly greater than out_dim.
        low_exp = math.floor(math.log2(self.out_dim)) + 1
        # Exclusive upper bound: smallest exponent whose power of two is >= in_dim.
        high_exp = math.ceil(math.log2(self.in_dim))

        # Widest hidden layer first, narrowing towards out_dim.
        for exp in reversed(range(low_exp, high_exp)):
            self.add(Dense(2 ** exp, default_activation, **initKernelAndBias))
        self.add(Dense(self.out_dim, default_activation, **initKernelAndBias))


class DeepConvNetwork(Sequential):

def __init__(self,
Expand Down Expand Up @@ -102,14 +140,12 @@ def call(self, x):


class MultiVectorNetwork(M):
def __init__(self, vector_dim=[]):
# TODO
def __init__(self, vector_dim=[], network_type=VectorNetworkType.CONCAT):
super().__init__()
self.nets = []
for _ in vector_dim:
def net(x): return x
self.nets.append(net)
self.h_dim = sum(vector_dim)
for in_dim in vector_dim:
self.nets.append(get_vector_network_from_type(network_type)(in_dim=in_dim))
self.h_dim = sum([net.h_dim for net in self.nets])
if vector_dim:
self(*(I(shape=dim) for dim in vector_dim))

Expand All @@ -118,8 +154,7 @@ def call(self, *vector_inputs):
output = []
for net, s in zip(self.nets, vector_inputs):
output.append(net(s))
if output:
output = tf.concat(output, axis=-1)
output = tf.concat(output, axis=-1)
return output


Expand All @@ -146,8 +181,7 @@ def call(self, *visual_inputs):
net(visual_s)
)
)
if output:
output = tf.concat(output, axis=-1)
output = tf.concat(output, axis=-1)
return output


Expand Down
5 changes: 5 additions & 0 deletions rls/utils/specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,11 @@ class GymVectorizedType(Enum):
MULTIPROCESSING = 'multiprocessing'


class VectorNetworkType(Enum):
    """Kinds of network that can be applied to each vector input.

    The member values are the lowercase strings used to name each kind.
    """

    # Leaves the vector input untouched.
    CONCAT = 'concat'
    # Encodes the vector input with a stack of dense layers.
    ADAPTIVE = 'adaptive'


class VisualNetworkType(Enum):
MATCH3 = 'match3'
SIMPLE = 'simple'
Expand Down

0 comments on commit e3b46d9

Please sign in to comment.