Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Isolating multiprocessing change #39

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 23 additions & 3 deletions mjrl/samplers/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def sample_paths(
start_time = timer.time()
print("####### Gathering Samples #######")

results = _try_multiprocess(do_rollout, input_dict_list,
results = _try_multiprocess_cf(do_rollout, input_dict_list,
num_cpu, max_process_time, max_timeouts)
paths = []
# result is a paths type and results is list of paths
Expand Down Expand Up @@ -186,7 +186,7 @@ def sample_data_batch(
return paths


def _try_multiprocess(func, input_dict_list, num_cpu, max_process_time, max_timeouts):
def _try_multiprocess_mp(func, input_dict_list, num_cpu, max_process_time, max_timeouts):

# Base case
if max_timeouts == 0:
Expand All @@ -202,9 +202,29 @@ def _try_multiprocess(func, input_dict_list, num_cpu, max_process_time, max_time
pool.close()
pool.terminate()
pool.join()
return _try_multiprocess(func, input_dict_list, num_cpu, max_process_time, max_timeouts-1)
return _try_multiprocess_mp(func, input_dict_list, num_cpu, max_process_time, max_timeouts-1)

pool.close()
pool.terminate()
pool.join()
return results

def _try_multiprocess_cf(func, input_dict_list, num_cpu, max_process_time, max_timeouts):
import concurrent.futures
results = None
if max_timeouts != 0:
with concurrent.futures.ProcessPoolExecutor(max_workers=num_cpu) as executor:
submit_futures = [executor.submit(func, **input_dict) for input_dict in input_dict_list]
try:
results = [f.result() for f in submit_futures]
except TimeoutError as e:
print(str(e))
print("Timeout Error raised...")
except concurrent.futures.CancelledError as e:
print(str(e))
print("Future Cancelled Error raised...")
except Exception as e:
print(str(e))
print("Error raised...")
raise e
return results
2 changes: 1 addition & 1 deletion mjrl/utils/tensor_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def high_res_normalize(probs):


def stack_tensor_list(tensor_list):
return np.array(tensor_list)
return np.array(tensor_list, dtype='object')
# tensor_shape = np.array(tensor_list[0]).shape
# if tensor_shape is tuple():
# return np.array(tensor_list)
Expand Down