Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 38 additions & 31 deletions simpeg/dask/electromagnetics/frequency_domain/simulation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import gc
import os
import shutil

from ....electromagnetics.frequency_domain.simulation import BaseFDEMSimulation as Sim
from ....utils import Zero
Expand Down Expand Up @@ -50,7 +52,9 @@ def receiver_derivs(survey, mesh, fields, blocks):
return field_derivatives


def eval_block(simulation, Ainv_deriv_u, deriv_indices, deriv_m, fields, address):
def compute_rows(
simulation, Ainv_deriv_u, deriv_indices, deriv_m, fields, address, Jmatrix
):
"""
Evaluate the sensitivities for a block of data and write them into the matching rows of Jmatrix
"""
Expand Down Expand Up @@ -92,7 +96,14 @@ def eval_block(simulation, Ainv_deriv_u, deriv_indices, deriv_m, fields, address
if not isinstance(deriv_m, Zero):
du_dmT += deriv_m

return np.array(du_dmT, dtype=complex).reshape((du_dmT.shape[0], -1)).real.T
values = np.array(du_dmT, dtype=complex).reshape((du_dmT.shape[0], -1)).real.T

if isinstance(Jmatrix, zarr.Array):
Jmatrix.set_orthogonal_selection((address[1][1], slice(None)), values)
else:
Jmatrix[address[1][1], :] = values

return None


def getSourceTerm(self, freq, source=None):
Expand Down Expand Up @@ -195,28 +206,39 @@ def compute_J(self, m, f=None):
"Consider creating one misfit per frequency."
)

client, worker = self._get_client_worker()

A_i = list(Ainv.values())[0]
m_size = m.size
compute_row_size = np.ceil(self.max_chunk_size / (A_i.A.shape[0] * 32.0 * 1e-6))
blocks = get_parallel_blocks(
self.survey.source_list, compute_row_size, optimize=True
)

if self.store_sensitivities == "disk":

chunk_size = np.median(
[np.sum([len(chunk[1][1]) for chunk in block]) for block in blocks]
).astype(int)

if os.path.exists(self.sensitivity_path):
shutil.rmtree(self.sensitivity_path)

Jmatrix = zarr.open(
self.sensitivity_path,
mode="w",
shape=(self.survey.nD, m_size),
chunks=(self.max_chunk_size, m_size),
chunks=(chunk_size, m_size),
)
else:
Jmatrix = np.zeros((self.survey.nD, m_size), dtype=np.float32)

compute_row_size = np.ceil(self.max_chunk_size / (A_i.A.shape[0] * 32.0 * 1e-6))
blocks = get_parallel_blocks(
self.survey.source_list, compute_row_size, optimize=False
)
if client:
Jmatrix = client.scatter(Jmatrix, workers=worker)

fields_array = f[:, self._solutionType]
blocks_receiver_derivs = []

client, worker = self._get_client_worker()

if client:
fields_array = client.scatter(f[:, self._solutionType], workers=worker)
fields = client.scatter(f, workers=worker)
Expand Down Expand Up @@ -270,7 +292,6 @@ def compute_J(self, m, f=None):
addresses_chunks,
client,
worker,
store_sensitivities=self.store_sensitivities,
)

for A in Ainv.values():
Expand All @@ -295,7 +316,6 @@ def parallel_block_compute(
addresses,
client,
worker=None,
store_sensitivities="disk",
):
m_size = m.size
block_stack = sp.hstack(blocks_receiver_derivs).toarray()
Expand All @@ -306,29 +326,29 @@ def parallel_block_compute(
ATinvdf_duT = client.scatter(ATinvdf_duT, workers=worker)
else:
ATinvdf_duT = delayed(ATinvdf_duT)

count = 0
rows = []
block_delayed = []

for address, dfduT in zip(addresses, blocks_receiver_derivs):
n_cols = dfduT.shape[1]
n_rows = address[1][2]

if client:
block_delayed.append(
client.submit(
eval_block,
compute_rows,
simulation,
ATinvdf_duT,
np.arange(count, count + n_cols),
Zero(),
fields_array,
address,
Jmatrix,
workers=worker,
)
)
else:
delayed_eval = delayed(eval_block)
delayed_eval = delayed(compute_rows)
block_delayed.append(
array.from_delayed(
delayed_eval(
Expand All @@ -338,35 +358,22 @@ def parallel_block_compute(
Zero(),
fields_array,
address,
Jmatrix,
),
dtype=np.float32,
shape=(n_rows, m_size),
)
)
count += n_cols
rows += address[1][1].tolist()

indices = np.hstack(rows)

if client:
block_delayed = client.gather(block_delayed)
block = np.vstack(block_delayed)
else:
block = compute(array.vstack(block_delayed))[0]

if store_sensitivities == "disk":
Jmatrix.set_orthogonal_selection(
(indices, slice(None)),
block,
)
client.gather(block_delayed)
else:
# Dask process to compute row and store
Jmatrix[indices, :] = block
compute(block_delayed)

return Jmatrix


Sim.parallel_block_compute = parallel_block_compute
Sim.compute_J = compute_J
Sim.getJtJdiag = getJtJdiag
Sim.Jvec = Jvec
Expand Down
25 changes: 16 additions & 9 deletions simpeg/dask/electromagnetics/static/resistivity/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
from ....simulation import getJtJdiag, Jvec, Jtvec, Jmatrix

from .....utils import Zero

import shutil
import os
import dask.array as da
import numpy as np
from scipy import sparse as sp
Expand Down Expand Up @@ -42,23 +43,29 @@ def compute_J(self, m, f=None):

f, Ainv = self.fields(m=m, return_Ainv=True)

m_size = m.size
n_cells = m.size
row_chunks = int(
np.ceil(
float(self.survey.nD)
/ np.ceil(float(m_size) * self.survey.nD * 8.0 * 1e-6 / self.max_chunk_size)
/ np.ceil(
float(n_cells) * self.survey.nD * 8.0 * 1e-6 / self.max_chunk_size
)
)
)

if self.store_sensitivities == "disk":

if os.path.exists(self.sensitivity_path):
shutil.rmtree(self.sensitivity_path)

Jmatrix = zarr.open(
self.sensitivity_path + "J.zarr",
self.sensitivity_path,
mode="w",
shape=(self.survey.nD, m_size),
chunks=(row_chunks, m_size),
shape=(self.survey.nD, n_cells),
chunks=(row_chunks, n_cells),
)
else:
Jmatrix = np.zeros((self.survey.nD, m_size), dtype=np.float32)
Jmatrix = np.zeros((self.survey.nD, n_cells), dtype=np.float32)

blocks = []
count = 0
Expand Down Expand Up @@ -92,7 +99,7 @@ def compute_J(self, m, f=None):
du_dmT += df_dmT

#
du_dmT = du_dmT.T.reshape((-1, m_size))
du_dmT = du_dmT.T.reshape((-1, n_cells))

if len(blocks) == 0:
blocks = du_dmT
Expand Down Expand Up @@ -130,7 +137,7 @@ def compute_J(self, m, f=None):

if self.store_sensitivities == "disk":
del Jmatrix
self._Jmatrix = da.from_zarr(self.sensitivity_path + "J.zarr")
self._Jmatrix = da.from_zarr(self.sensitivity_path)
else:
self._Jmatrix = Jmatrix

Expand Down
Loading