Commit
feat: implement tensor to device. (#43, #34)
StepNeverStop committed Jul 12, 2021
1 parent 0478118 commit 597632f
Showing 51 changed files with 165 additions and 170 deletions.
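The common thread across the changed files: the base policy now resolves a single self.device, and every network is moved onto it with .to(self.device) at construction time, replacing the old per-call data_convert helper. A minimal sketch of that pattern, assuming the repository's usual import torch as t alias (the ExamplePolicy class and the t.nn.Linear network below are illustrative placeholders, not code from the repository):

import torch as t

class Base:
    def __init__(self, **kwargs):
        # Resolve the device once: honor an explicit 'device' kwarg, otherwise
        # fall back to CUDA when available (mirrors the new code in base.py).
        self.device = kwargs.get('device', None) or ("cuda" if t.cuda.is_available() else "cpu")

class ExamplePolicy(Base):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Every nn.Module is moved to the chosen device right after construction.
        self.net = t.nn.Linear(8, 2).to(self.device)

policy = ExamplePolicy()
print(next(policy.net.parameters()).device)  # cuda:0 when a GPU is visible, otherwise cpu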
1 change: 1 addition & 0 deletions README.md
@@ -139,6 +139,7 @@ For now, these algorithms are available:
| Algorithms(29) | Discrete | Continuous | Image | RNN | Command parameter |
| :-----------------------------: | :------: | :--------: | :---: | :--: | :---------------: |
| Q-Learning/Sarsa/Expected Sarsa || | | | qs |
| ~~CEM~~ ||| | | cem |
| PG |||| | pg |
| AC ||||| ac |
| A2C |||| | a2c |
16 changes: 2 additions & 14 deletions rls/algos/base/base.py
@@ -33,6 +33,8 @@ def __init__(self, *args, **kwargs):
super().__init__()
self.no_save = bool(kwargs.get('no_save', False))
self.base_dir = base_dir = kwargs.get('base_dir')
self.device = kwargs.get('device', None) or ("cuda" if t.cuda.is_available() else "cpu")
logger.info(colorize(f"PyTorch Tensor Device: {self.device}"))

self.cp_dir, self.log_dir, self.excel_dir = [os.path.join(base_dir, i) for i in ['model', 'log', 'excel']]

@@ -52,20 +54,6 @@ def __init__(self, *args, **kwargs):
if bool(kwargs.get('logger2file', False)):
set_log_file(log_file=os.path.join(self.log_dir, 'log.txt'))

def data_convert(self, data: Union[np.ndarray, List]) -> t.Tensor:
'''
TODO: Annotation
'''
if isinstance(data, tuple):
return tuple(
t.as_tensor(d)
if d is not None
else d
for d in data
)
else:
return t.as_tensor(data)

def _create_writer(self, log_dir: str) -> SummaryWriter:
if not self.no_save:
check_or_create(log_dir, 'logs(summaries)')
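The removed data_convert above only wrapped t.as_tensor and never chose a device. A device-aware equivalent is a one-argument change, sketched here as a hypothetical helper that is not part of this commit (the commit itself appears to route conversion through the iTensor_oNumpy decorator used in the algorithm files, whose implementation is not shown in this diff):

import torch as t

def to_device_tensor(data, device):
    # Hypothetical helper: same job as the removed data_convert,
    # but placing tensors on the chosen device up front.
    if isinstance(data, tuple):
        return tuple(t.as_tensor(d, device=device) if d is not None else d
                     for d in data)
    return t.as_tensor(data, device=device)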
7 changes: 3 additions & 4 deletions rls/algos/base/off_policy.py
@@ -151,9 +151,8 @@ def _learn(self, function_dict: Dict = {}) -> NoReturn:
# -------------------------------------- Prioritized experience replay: fetch the importance-sampling weights
if self.use_priority and self.use_isw:
_isw = self.data.get_IS_w().reshape(-1, 1) # [B, ] => [B, 1]
_isw = self.data_convert(_isw)
else:
_isw = t.tensor(1.)
_isw = 1.
# --------------------------------------

# -------------------------------------- Main training routine: returns the TD error (possibly used to update PER priorities) and the summaries to write to TensorBoard
@@ -191,7 +190,7 @@ def _apex_learn(self, function_dict: Dict, data: BatchExperiences, priorities) -
self.intermediate_variable_reset()
data = self._data_process2dict(data=data)

cell_state = (None,)
cell_state = None

if self.use_curiosity:
crsty_r, crsty_summaries = self.curiosity_model(data, cell_state)
@@ -213,6 +212,6 @@ def _apex_cal_td(self, data: BatchExperiences, function_dict: Dict = {}) -> np.n
'''
data = self._data_process2dict(data=data)

cell_state = (None,)
cell_state = None
td_error = self._cal_td(data, cell_state)
return np.squeeze(td_error.numpy())
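In the _learn hunk above, the default _isw drops from t.tensor(1.) to a plain Python float because a scalar broadcasts against a per-sample loss tensor on any device, so no conversion or placement is needed for the uniform-weight case. A small standalone illustration (the shapes are made up for the example):

import torch as t

td_error = t.randn(32, 1)               # per-sample TD errors, shape [B, 1]
_isw = 1.                               # uniform importance-sampling weight, the new default
loss = (_isw * td_error.pow(2)).mean()  # a Python float multiplies a tensor on any device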
1 change: 0 additions & 1 deletion rls/algos/base/on_policy.py
@@ -69,7 +69,6 @@ def _learn(self, function_dict: Dict) -> NoReturn:
all_data = self.data.sample_generater()

for data, cell_state in all_data:
cell_state = self.data_convert(cell_state)
cell_state = {'obs': cell_state, 'obs_': cell_state}
summaries = _train(data, cell_state)

6 changes: 3 additions & 3 deletions rls/algos/base/policy.py
@@ -47,7 +47,7 @@ def __init__(self, envspec: EnvGroupArgs, **kwargs):
self.representation_net_params = dict(kwargs.get('representation_net_params', defaultdict(dict)))
self.use_rnn = bool(self.representation_net_params.get('use_rnn', False))
self.rep_net = DefaultRepresentationNetwork(obs_spec=self.obs_spec,
representation_net_params=self.representation_net_params)
representation_net_params=self.representation_net_params).to(self.device)

self.use_curiosity = bool(kwargs.get('use_curiosity', False))
if self.use_curiosity:
@@ -80,7 +80,7 @@ def rnn_cell_nums(self):
def initial_cell_state(self, batch: int) -> Tuple[t.Tensor]:
if self.use_rnn:
return self.rep_net.memory_net.initial_cell_state(batch=batch)
return (None,)
return None

def get_cell_state(self) -> Tuple[Optional[t.Tensor]]:
return self.cell_state
@@ -96,7 +96,7 @@ def _partial_reset_cell_state(self, index: Union[List, np.ndarray]) -> NoReturn:
Partially reset the RNN hidden state according to the indices of environments that are done
'''
assert isinstance(index, (list, np.ndarray)), 'assert isinstance(index, (list, np.ndarray))'
if self.cell_state[0] is not None and len(index) > 0:
if self.cell_state is not None and len(index) > 0:
_arr = np.ones(shape=self.cell_state[0].shape, dtype=np.float32) # h, c
_arr[index] = 0.
self.cell_state = [c * _arr for c in self.cell_state] # [A, B] * [A, B] => [A, B] zero out the rows at the given indices
6 changes: 3 additions & 3 deletions rls/algos/hierarchical/aoc.py
@@ -100,7 +100,7 @@ def __init__(self,
action_dim=self.a_dim,
options_num=self.options_num,
network_settings=network_settings,
is_continuous=self.is_continuous)
is_continuous=self.is_continuous).to(self.device)
if self.is_continuous:
self.log_std = -0.5 * t.ones((self.options_num, self.a_dim), requires_grad=True) # [P, A]
self.oplr = OPLR([self.net, self.rep_net, self.log_std], lr)
@@ -221,10 +221,10 @@ def _train(data, cell_state):
})

@iTensor_oNumpy
def train(self, BATCH, cell_state, kl_coef):
def train(self, BATCH, cell_states, kl_coef):
last_options = BATCH.last_options # [B,]
options = BATCH.options
feat, _ = self.rep_net(BATCH.obs, cell_state=cell_state['obs']) # [B, P], [B, P, A], [B, P], [B, P]
feat, _ = self.rep_net(BATCH.obs, cell_state=cell_states['obs']) # [B, P], [B, P, A], [B, P], [B, P]
(q, pi, beta) = self.net(feat)
options_onehot = t.nn.functional.one_hot(options, self.options_num).float() # [B, P]
options_onehot_expanded = options_onehot.unsqueeze(-1) # [B, P, 1]
14 changes: 7 additions & 7 deletions rls/algos/hierarchical/hiro.py
@@ -95,13 +95,13 @@ def __init__(self,

self.high_actor = ActorDPG(vector_dim=self.concat_vector_dim,
output_shape=self.sub_goal_dim,
network_settings=network_settings['high_actor'])
network_settings=network_settings['high_actor']).to(self.device)
self.high_critic = CriticQvalueOne(vector_dim=self.concat_vector_dim,
action_dim=self.sub_goal_dim,
network_settings=network_settings['high_critic'])
network_settings=network_settings['high_critic']).to(self.device)
self.high_critic2 = CriticQvalueOne(vector_dim=self.concat_vector_dim,
action_dim=self.sub_goal_dim,
network_settings=network_settings['high_critic'])
network_settings=network_settings['high_critic']).to(self.device)
self.high_actor_target = deepcopy(self.high_actor)
self.high_actor_target.eval()
self.high_critic_target = deepcopy(self.high_critic)
@@ -112,18 +112,18 @@ def __init__(self,
if self.is_continuous:
self.low_actor = ActorDPG(vector_dim=self.concat_vector_dim + self.sub_goal_dim,
output_shape=self.a_dim,
network_settings=network_settings['low_actor'])
network_settings=network_settings['low_actor']).to(self.device)
else:
self.low_actor = ActorDct(vector_dim=self.concat_vector_dim + self.sub_goal_dim,
output_shape=self.a_dim,
network_settings=network_settings['low_actor'])
network_settings=network_settings['low_actor']).to(self.device)
self.gumbel_dist = td.gumbel.Gumbel(0, 1)
self.low_critic = CriticQvalueOne(vector_dim=self.concat_vector_dim + self.sub_goal_dim,
action_dim=self.a_dim,
network_settings=network_settings['low_critic'])
network_settings=network_settings['low_critic']).to(self.device)
self.low_critic2 = CriticQvalueOne(vector_dim=self.concat_vector_dim + self.sub_goal_dim,
action_dim=self.a_dim,
network_settings=network_settings['low_critic'])
network_settings=network_settings['low_critic']).to(self.device)
self.low_actor_target = deepcopy(self.low_actor)
self.low_actor_target.eval()
self.low_critic_target = deepcopy(self.low_critic)
8 changes: 4 additions & 4 deletions rls/algos/hierarchical/ioc.py
@@ -67,22 +67,22 @@ def __init__(self,

self.q_net = CriticQvalueAll(self.rep_net.h_dim,
output_shape=self.options_num,
network_settings=network_settings['q'])
network_settings=network_settings['q']).to(self.device)
self.q_target_net = deepcopy(self.q_net)
self.q_target_net.eval()

self.intra_option_net = OcIntraOption(vector_dim=self.rep_net.h_dim,
output_shape=self.a_dim,
options_num=self.options_num,
network_settings=network_settings['intra_option'])
network_settings=network_settings['intra_option']).to(self.device)
self.termination_net = CriticQvalueAll(vector_dim=self.rep_net.h_dim,
output_shape=self.options_num,
network_settings=network_settings['termination'],
out_act='sigmoid')
out_act='sigmoid').to(self.device)
self.interest_net = CriticQvalueAll(vector_dim=self.rep_net.h_dim,
output_shape=self.options_num,
network_settings=network_settings['interest'],
out_act='sigmoid')
out_act='sigmoid').to(self.device)

if self.is_continuous:
self.log_std = -0.5 * t.ones((self.options_num, self.a_dim), requires_grad=True) # [P, A]
6 changes: 3 additions & 3 deletions rls/algos/hierarchical/oc.py
@@ -77,7 +77,7 @@ def __init__(self,

self.q_net = CriticQvalueAll(self.rep_net.h_dim,
output_shape=self.options_num,
network_settings=network_settings['q'])
network_settings=network_settings['q']).to(self.device)
self.q_target_net = deepcopy(self.q_net)
self.q_target_net.eval()

@@ -87,11 +87,11 @@
self.intra_option_net = OcIntraOption(vector_dim=self.rep_net.h_dim,
output_shape=self.a_dim,
options_num=self.options_num,
network_settings=network_settings['intra_option'])
network_settings=network_settings['intra_option']).to(self.device)
self.termination_net = CriticQvalueAll(vector_dim=self.rep_net.h_dim,
output_shape=self.options_num,
network_settings=network_settings['termination'],
out_act='sigmoid')
out_act='sigmoid').to(self.device)

if self.is_continuous:
self.log_std = -0.5 * t.ones((self.options_num, self.a_dim), requires_grad=True) # [P, A]
6 changes: 3 additions & 3 deletions rls/algos/hierarchical/ppoc.py
@@ -102,7 +102,7 @@ def __init__(self,
action_dim=self.a_dim,
options_num=self.options_num,
network_settings=network_settings,
is_continuous=self.is_continuous)
is_continuous=self.is_continuous).to(self.device)

if self.is_continuous:
self.log_std = -0.5 * t.ones((self.options_num, self.a_dim), requires_grad=True) # [P, A]
@@ -231,10 +231,10 @@ def _train(data, cell_state):
})

@iTensor_oNumpy
def share(self, BATCH, cell_state, kl_coef):
def share(self, BATCH, cell_states, kl_coef):
last_options = BATCH.last_options # [B,]
options = BATCH.options
feat, _ = self.rep_net(BATCH.obs, cell_state=cell_state['obs']) # [B, P], [B, P, A], [B, P], [B, P]
feat, _ = self.rep_net(BATCH.obs, cell_state=cell_states['obs']) # [B, P], [B, P, A], [B, P], [B, P]
(q, pi, beta, o) = self.net(feat)
options_onehot = t.nn.functional.one_hot(options, self.options_num).float() # [B, P]
options_onehot_expanded = options_onehot.unsqueeze(-1) # [B, P, 1]
4 changes: 2 additions & 2 deletions rls/algos/multi/iql.py
@@ -47,10 +47,10 @@ def __init__(self,
self.oplrs = []
for i in range(self.n_models_percopy):
rep_net = DefaultRepresentationNetwork(obs_spec=self.envspecs[i].obs_spec,
representation_net_params=self.representation_net_params)
representation_net_params=self.representation_net_params).to(self.device)
q_net = CriticQvalueAll(rep_net.h_dim,
output_shape=self.envspecs[i].a_dim,
network_settings=network_settings)
network_settings=network_settings).to(self.device)
target_rep_net = deepcopy(rep_net)
target_rep_net.eval()
q_target_net = deepcopy(q_net)
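One reason the target networks in these constructors need no separate .to(self.device): deepcopy of an nn.Module copies its parameters on whatever device they already live on, so copying after the move is sufficient. A standalone check (the t.nn.Linear module stands in for the repository's networks):

from copy import deepcopy
import torch as t

device = "cuda" if t.cuda.is_available() else "cpu"
q_net = t.nn.Linear(8, 4).to(device)
q_target_net = deepcopy(q_net)   # the copy inherits q_net's device placement
q_target_net.eval()
print(next(q_target_net.parameters()).device)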
8 changes: 4 additions & 4 deletions rls/algos/multi/maddpg.py
@@ -68,18 +68,18 @@ def __init__(self,

for i in range(self.n_models_percopy):
rep_net = DefaultRepresentationNetwork(obs_spec=self.envspecs[i].obs_spec,
representation_net_params=self.representation_net_params)
representation_net_params=self.representation_net_params).to(self.device)
if self.envspecs[i].is_continuous:
actor = ActorDPG(rep_net.h_dim,
output_shape=self.envspecs[i].a_dim,
network_settings=network_settings['actor_continuous'])
network_settings=network_settings['actor_continuous']).to(self.device)
else:
actor = ActorDct(rep_net.h_dim,
output_shape=self.envspecs[i].a_dim,
network_settings=network_settings['actor_discrete'])
network_settings=network_settings['actor_discrete']).to(self.device)
critic = CriticQvalueOne(rep_net.h_dim*self.n_models_percopy,
action_dim=sum([envspec.a_dim for envspec in self.envspecs]),
network_settings=network_settings['q'])
network_settings=network_settings['q']).to(self.device)
target_rep_net = deepcopy(rep_net)
target_rep_net.eval()
critic_target = deepcopy(critic)
4 changes: 2 additions & 2 deletions rls/algos/multi/vdn.py
@@ -54,10 +54,10 @@ def __init__(self,
self.q_target_nets = []
for i in range(self.n_models_percopy):
rep_net = DefaultRepresentationNetwork(obs_spec=self.envspecs[i].obs_spec,
representation_net_params=self.representation_net_params)
representation_net_params=self.representation_net_params).to(self.device)
q_net = CriticDueling(rep_net.h_dim,
output_shape=self.envspecs[i].a_dim,
network_settings=network_settings)
network_settings=network_settings).to(self.device)
target_rep_net = deepcopy(rep_net)
target_rep_net.eval()
q_target_net = deepcopy(q_net)
10 changes: 5 additions & 5 deletions rls/algos/single/a2c.py
@@ -48,13 +48,13 @@ def __init__(self,
if self.is_continuous:
self.actor = ActorMuLogstd(self.rep_net.h_dim,
output_shape=self.a_dim,
network_settings=network_settings['actor_continuous'])
network_settings=network_settings['actor_continuous']).to(self.device)
else:
self.actor = ActorDct(self.rep_net.h_dim,
output_shape=self.a_dim,
network_settings=network_settings['actor_discrete'])
network_settings=network_settings['actor_discrete']).to(self.device)
self.critic = CriticValue(self.rep_net.h_dim,
network_settings=network_settings['critic'])
network_settings=network_settings['critic']).to(self.device)

self.actor_op = OPLR(self.actor, actor_lr)
self.critic_op = OPLR([self.critic, self.rep_net], critic_lr)
@@ -116,8 +116,8 @@ def _train(data, cell_state):
})

@iTensor_oNumpy
def train(self, BATCH, cell_state):
feat, _ = self.rep_net(BATCH.obs, cell_state=cell_state['obs'])
def train(self, BATCH, cell_states):
feat, _ = self.rep_net(BATCH.obs, cell_state=cell_states['obs'])
output = self.actor(feat)
v = self.critic(feat)
if self.is_continuous:
