Skip to content

Commit

Permalink
Merge pull request #105 from arbennett/develop
Browse files Browse the repository at this point in the history
Update ensemble & distributed to allow existing clients
  • Loading branch information
arbennett committed Mar 26, 2020
2 parents c7df1cd + 7251f4a commit 7f781c3
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 15 deletions.
21 changes: 14 additions & 7 deletions pysumma/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ class Distributed(object):

def __init__(self, executable: str, filemanager: str,
num_workers: int=1, threads_per_worker: int=OMP_NUM_THREADS,
chunk_size: int=None, num_chunks: int=None, scheduler: str=None):
chunk_size: int=None, num_chunks: int=None, scheduler: str=None,
client: Client=None):
"""
Initialize a new distributed object
Expand Down Expand Up @@ -70,15 +71,21 @@ def __init__(self, executable: str, filemanager: str,
self.submissions: List = []
self.num_workers: int = num_workers
# Try to get a client, and if none exists then start a new one
try:
self._client = get_client()
# Start more workers if necessary:
if client:
self._client = client
workers = len(self._client.get_worker_logs())
if workers <= self.num_workers:
self._client.cluster.scale(workers)
except ValueError:
self._client = Client(n_workers=self.num_workers,
threads_per_worker=threads_per_worker)
else:
try:
self._client = get_client()
# Start more workers if necessary:
workers = len(self._client.get_worker_logs())
if workers <= self.num_workers:
self._client.cluster.scale(workers)
except ValueError:
self._client = Client(n_workers=self.num_workers,
threads_per_worker=threads_per_worker)
self.chunk_args = self._generate_args(chunk_size, num_chunks)
self._generate_simulation_objects()

Expand Down
20 changes: 13 additions & 7 deletions pysumma/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class Ensemble(object):
def __init__(self, executable: str,configuration: dict,
filemanager: str=None, num_workers: int=1,
threads_per_worker: int=OMP_NUM_THREADS,
scheduler: str=None):
scheduler: str=None, client: Client=None):
"""
Create a new Ensemble object. The API mirrors that of the
Simulation object.
Expand All @@ -46,15 +46,21 @@ def __init__(self, executable: str,configuration: dict,
self.simulations: dict = {}
self.submissions: list = []
# Try to get a client, and if none exists then start a new one
try:
self._client = get_client()
# Start more workers if necessary:
if client:
self._client = client
workers = len(self._client.get_worker_logs())
if workers <= self.num_workers:
self._client.cluster.scale(workers)
except ValueError:
self._client = Client(n_workers=self.num_workers,
threads_per_worker=threads_per_worker)
else:
try:
self._client = get_client()
# Start more workers if necessary:
workers = len(self._client.get_worker_logs())
if workers <= self.num_workers:
self._client.cluster.scale(workers)
except ValueError:
self._client = Client(n_workers=self.num_workers,
threads_per_worker=threads_per_worker)
self._generate_simulation_objects()

def _generate_simulation_objects(self):
Expand Down
2 changes: 1 addition & 1 deletion pysumma/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ def monitor(self):
self.status = 'Success'

try:
self._output = [xr.open_dataset(f) for f in self.get_output()]
self._output = [xr.open_dataset(f) for f in self.get_output_files()]
if len(self._output) == 1:
self._output = self._output[0]
except Exception:
Expand Down

0 comments on commit 7f781c3

Please sign in to comment.