Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
168 changes: 72 additions & 96 deletions simpeg/dask/electromagnetics/time_domain/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,40 +200,41 @@ def compute_J(self, m, f=None):
if len(block) == 0:
continue

if client:
field_derivatives = client.scatter(
ATinv_df_duT_v[ind], workers=self.worker
)
j_row_updates.append(
client.submit(
compute_rows,
sim,
tInd,
block,
field_derivatives,
fields_array,
time_mask,
workers=self.worker,
)
)
else:
j_row_updates.append(
array.from_delayed(
delayed_compute_rows(
for row, field_derivatives in zip(block, ATinv_df_duT_v[ind]):
if client:
# field_derivatives = client.scatter(
# ATinv_df_duT_v[ind], workers=self.worker
# )
j_row_updates.append(
client.submit(
compute_rows,
sim,
tInd,
block,
ATinv_df_duT_v[ind],
row,
field_derivatives,
fields_array,
time_mask,
),
dtype=np.float32,
shape=(
np.sum([len(chunk[1][0]) for chunk in block]),
m.size,
),
workers=self.worker,
)
)
else:
j_row_updates.append(
array.from_delayed(
delayed_compute_rows(
sim,
tInd,
row,
field_derivatives,
fields_array,
time_mask,
),
dtype=np.float32,
shape=(
np.sum([len(chunk[1][0]) for chunk in block]),
m.size,
),
)
)
)

if client:
j_row_updates = np.vstack(client.gather(j_row_updates))
Expand Down Expand Up @@ -390,59 +391,39 @@ def get_field_deriv_block(
"""
Stack the blocks of field derivatives for a given timestep and call the direct solver.
"""
stacked_blocks = []
if len(ATinv_df_duT_v) == 0:
ATinv_df_duT_v = [[] for _ in block]
indices = []
count = 0

Asubdiag = None
if tInd < self.nT - 1:
Asubdiag = self.getAsubdiag(tInd + 1)

updated_ATinv_df_duT_v = []

for (_, (rx_ind, _, shape)), field_deriv, ATinv_chunk in zip(
block, field_derivs, ATinv_df_duT_v
):

# Cut out early data
time_check = np.kron(time_mask, np.ones(shape, dtype=bool))[rx_ind]
local_ind = np.arange(rx_ind.shape[0])[time_check]
indices.append(
(np.arange(count, count + len(local_ind)), local_ind),
)
count += len(local_ind)

if len(ATinv_chunk) == 0:
# last timestep (first to be solved)
stacked_block = field_deriv.toarray()[:, local_ind]

else:
stacked_block = np.asarray(
field_deriv[:, local_ind] - Asubdiag.T * ATinv_chunk[:, local_ind]
)

stacked_blocks.append(stacked_block)

blocks = np.hstack(stacked_blocks)
if blocks.ndim == 2 and blocks.shape[1] > 0:
solve = (AdiagTinv * blocks).reshape(blocks.shape)
else:
solve = None

updated_ATinv_df_duT_v = []

for (_, arrays), field_deriv, ATinv_chunk, (columns, local_ind) in zip(
block, field_derivs, ATinv_df_duT_v, indices, strict=True
):

if len(ATinv_chunk) == 0:
time_block = field_deriv.toarray()[:, local_ind]
shape = (
field_deriv.shape[0],
len(arrays[0]),
len(rx_ind),
)
ATinv_chunk = np.zeros(shape, dtype=np.float32)
else:
time_block = np.asarray(
field_deriv[:, local_ind] - Asubdiag.T * ATinv_chunk[:, local_ind]
)

if solve is not None:
ATinv_chunk[:, local_ind] = solve[:, columns]
if time_block.ndim == 2 and time_block.shape[1] > 0:
solve = (AdiagTinv * time_block).reshape(time_block.shape)
ATinv_chunk[:, local_ind] = solve

updated_ATinv_df_duT_v.append(ATinv_chunk)

Expand Down Expand Up @@ -513,52 +494,47 @@ def compute_rows(
simulation,
tInd,
chunks,
ATinv_df_duT_v,
field_derivs,
fields,
time_mask,
):
"""
Compute the rows of the sensitivity matrix for a given source and receiver.
"""
rows = []
(address, ind_array) = chunks
# for (address, ind_array), field_derivs in zip(chunks, ATinv_df_duT_v):
src = simulation.survey.source_list[address[0]]
time_check = np.kron(time_mask, np.ones(ind_array[2], dtype=bool))[ind_array[0]]
local_ind = np.arange(len(ind_array[0]))[time_check]

for (address, ind_array), field_derivs in zip(chunks, ATinv_df_duT_v):
src = simulation.survey.source_list[address[0]]
time_check = np.kron(time_mask, np.ones(ind_array[2], dtype=bool))[ind_array[0]]
local_ind = np.arange(len(ind_array[0]))[time_check]

if len(local_ind) < 1:
row_block = np.zeros(
(len(ind_array[1]), simulation.model.size), dtype=np.float32
)
rows.append(row_block)
continue

dAsubdiagT_dm_v = simulation.getAsubdiagDeriv(
tInd,
fields[:, address[0], tInd],
field_derivs[:, local_ind],
adjoint=True,
)

dRHST_dm_v = simulation.getRHSDeriv(
tInd + 1, src, field_derivs[:, local_ind], adjoint=True
) # on nodes of time mesh

un_src = fields[:, address[0], tInd + 1]
# cell centered on time mesh
dAT_dm_v = simulation.getAdiagDeriv(
tInd, un_src, field_derivs[:, local_ind], adjoint=True
)
if len(local_ind) < 1:
row_block = np.zeros(
(len(ind_array[1]), simulation.model.size), dtype=np.float32
)
row_block[time_check, :] = (-dAT_dm_v - dAsubdiagT_dm_v + dRHST_dm_v).T.astype(
np.float32
)
rows.append(row_block)
return row_block

dAsubdiagT_dm_v = simulation.getAsubdiagDeriv(
tInd,
fields[:, address[0], tInd],
field_derivs[:, local_ind],
adjoint=True,
)

dRHST_dm_v = simulation.getRHSDeriv(
tInd + 1, src, field_derivs[:, local_ind], adjoint=True
) # on nodes of time mesh

un_src = fields[:, address[0], tInd + 1]
# cell centered on time mesh
dAT_dm_v = simulation.getAdiagDeriv(
tInd, un_src, field_derivs[:, local_ind], adjoint=True
)
row_block = np.zeros((len(ind_array[1]), simulation.model.size), dtype=np.float32)
row_block[time_check, :] = (-dAT_dm_v - dAsubdiagT_dm_v + dRHST_dm_v).T.astype(
np.float32
)

return np.vstack(rows)
return row_block


def evaluate_dpred_block(indices, sources, mesh, time_mesh, fields):
Expand Down