Skip to content

Commit

Permalink
more cleaning
Browse files Browse the repository at this point in the history
  • Loading branch information
kalifou committed May 6, 2019
1 parent 8b268e8 commit 83587e6
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 22 deletions.
13 changes: 7 additions & 6 deletions environments/dataset_fusioner.py
Expand Up @@ -10,22 +10,23 @@


def main():
parser = argparse.ArgumentParser(description='Dataset Manipulator: useful to fusion two datasets by concatenating '
'episodes. PS: Deleting sources after fusion into destination folder.')
parser = argparse.ArgumentParser(description='Dataset Manipulator: useful to merge two datasets by concatenating '
+ 'episodes. PS: Deleting sources after merging into the destination '
+ 'folder.')
group = parser.add_mutually_exclusive_group()
group.add_argument('--merge', type=str, nargs=3, metavar=('source_1', 'source_2', 'destination'),
default=argparse.SUPPRESS,
help='Fusion two datasets by appending the episodes, deleting sources right after.')
help='Merge two datasets by appending the episodes, deleting sources right after.')

args = parser.parse_args()

if 'merge' in args:
# let make sure everything is in order
assert os.path.exists(args.merge[0]), "Error: dataset '{}' could not be found".format(args.merge[0])
assert (not os.path.exists(args.merge[2])), "Error: dataset '{}' already exists, cannot rename '{}' to '{}'"\
.format(args.merge[2], args.merge[0], args.merge[2])
assert (not os.path.exists(args.merge[2])), \
"Error: dataset '{}' already exists, cannot rename '{}' to '{}'".format(args.merge[2], args.merge[0],
args.merge[2])
# create the output
print(args)
os.mkdir(args.merge[2])

# copy files from first source
Expand Down
8 changes: 6 additions & 2 deletions environments/dataset_generator.py
Expand Up @@ -14,6 +14,7 @@

from environments import ThreadingType
from environments.registry import registered_env
from real_robots.constants import USING_OMNIROBOT
from srl_zoo.utils import printRed, printYellow

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' # used to remove debug info of tensorflow
Expand Down Expand Up @@ -61,6 +62,7 @@ def env_thread(args, thread_num, partition=True, use_ppo2=False):

env_class = registered_env[args.env][0]
env = env_class(**env_kwargs)
using_real_omnibot = args.env == "OmnirobotEnv-v0" and USING_OMNIROBOT

model = None
if use_ppo2:
Expand Down Expand Up @@ -92,7 +94,9 @@ def env_thread(args, thread_num, partition=True, use_ppo2=False):
if use_ppo2:
action, _ = model.predict([obs])
else:
if episode_toward_target_on and np.random.rand() < args.toward_target_timesteps_proportion:
# Using a target reaching policy (untrained, from camera) when collecting data from real OmniRobot
if episode_toward_target_on and np.random.rand() < args.toward_target_timesteps_proportion and \
using_real_omnibot:
action = [env.actionPolicyTowardTarget()]
else:
action = [env.action_space.sample()]
Expand All @@ -103,7 +107,7 @@ def env_thread(args, thread_num, partition=True, use_ppo2=False):
frames += 1
t += 1
if done:
if np.random.rand() < args.toward_target_timesteps_proportion:
if np.random.rand() < args.toward_target_timesteps_proportion and using_real_omnibot:
episode_toward_target_on = True
else:
episode_toward_target_on = False
Expand Down
12 changes: 6 additions & 6 deletions real_robots/omnirobot_utils/marker_finder.py
Expand Up @@ -16,15 +16,15 @@ def rotateMatrix90(matrix):
return new_matrix


def hammingDistance(s1, s2):
def hammingDistance(string_1, string_2):
"""
:param s1: (str)
:param s2: (str)
:return: (int) Hamming distance between s1 & s2
:param string_1: (str)
:param string_2: (str)
:return: (int) Hamming distance between string_1 & string_2
"""
assert len(s1) == len(s2)
return sum(ch1 != ch2 for ch1, ch2 in zip(s1, s2))
assert len(string_1) == len(string_2)
return sum(ch1 != ch2 for ch1, ch2 in zip(string_1, string_2))


class MakerFinder():
Expand Down
2 changes: 1 addition & 1 deletion real_robots/omnirobot_utils/omnirobot_manager_base.py
Expand Up @@ -4,7 +4,7 @@


class OmnirobotManagerBase(object):
def __init__(self, second_cam_topic=None):
def __init__(self):
"""
This class is the basic class for omnirobot server, and omnirobot simulator's server.
This class takes omnirobot position at instant t, and takes the action at instant t,
Expand Down
15 changes: 8 additions & 7 deletions rl_baselines/base_classes.py
Expand Up @@ -137,9 +137,10 @@ def save(self, save_path, _locals=None):
}
with open(save_path, "wb") as f:
pickle.dump(save_param, f)

def setLoadPath(self, load_path):
"""
Load the only the parameters of the neuro-network model from a path
Set the path to later load the parameters of a trained rl model
:param load_path: (str)
:return: None
"""
Expand Down Expand Up @@ -237,12 +238,12 @@ def train(self, args, callback, env_kwargs=None, train_kwargs=None):
self.ob_space = envs.observation_space
self.ac_space = envs.action_space

policy_fn = {'cnn': CnnPolicy,
'cnn-lstm': CnnLstmPolicy,
'cnn-lnlstm': CnnLnLstmPolicy,
'mlp': MlpPolicy,
'lstm': MlpLstmPolicy,
'lnlstm': MlpLnLstmPolicy}[args.policy]
policy_fn = {'cnn': "CnnPolicy",
'cnn-lstm': "CnnLstmPolicy",
'cnn-lnlstm': "CnnLnLstmPolicy",
'mlp': "MlpPolicy",
'lstm': "MlpLstmPolicy",
'lnlstm': "MlpLnLstmPolicy"}[args.policy]
if self.load_rl_model_path is not None:
print("Load trained model from the path: ", self.load_rl_model_path)
self.model = self.model_class.load(self.load_rl_model_path, envs, **train_kwargs)
Expand Down

0 comments on commit 83587e6

Please sign in to comment.