In [1]:
from geoh5py.workspace import Workspace
from geoh5py.groups import ContainerGroup
import numpy as np
from scipy.spatial import cKDTree
from SimPEG import dask
import dask
from SimPEG import (
    maps,
    utils,
    data_misfit,
    regularization,
    optimization,
    inverse_problem,
    directives,
    inversion,
    objective_function,
    data
)
from SimPEG.utils.drivers import create_nested_mesh
from SimPEG.electromagnetics.static import resistivity as dc, utils as DCutils
from SimPEG.electromagnetics.static import induced_polarization as ip
from dask.distributed import Client, LocalCluster, get_client
from geoapps.io.DirectCurrent import DirectCurrentParams

In [2]:
from geoapps.utils import octree_2_treemesh, treemesh_2_octree
from discretize import utils as d_utils
from discretize.utils import mesh_builder_xyz, refine_tree_xyz
from discretize import TensorMesh
from pymatsolver.direct import Pardiso as Solver

@dask.delayed
def create_tile_dc(source, obs, uncert, global_mesh, global_active, tile_id, buffer=200.):
    """
    Build the data misfit for one tile of the DC survey.

    Decorated with ``dask.delayed``: calling it returns a lazy task, and
    computing that task returns a ``data_misfit.L2DataMisfit`` whose
    simulation runs on a local mesh built by ``create_nested_mesh`` from the
    tile's electrode positions.

    :param source: List of DC sources making up the tile's survey.
    :param obs: Observed data for the tile; assumed to follow the data order
        of ``dc.Survey(source)`` (TODO confirm with caller).
    :param uncert: Standard deviations matching ``obs`` element by element.
    :param global_mesh: Global mesh passed to ``create_nested_mesh`` and ``TileMap``.
    :param global_active: Active-cell indicator of ``global_mesh``.
    :param tile_id: Tile label used in the progress message and sensitivity path.
    :param buffer: ``max_distance`` of the radial nested-mesh construction.
    :return: The tile's ``L2DataMisfit``, with ``model_map`` set to the tile's ``TileMap``.
    """
    print(f"Processing tile {tile_id}")
    local_survey = dc.Survey(source)
    # Every A, B, M and N electrode position feeds the nested-mesh construction.
    electrodes = np.vstack((
        local_survey.locations_a,
        local_survey.locations_b,
        local_survey.locations_m,
        local_survey.locations_n,
    ))
    local_survey.dobs = obs
    local_survey.std = uncert

    local_mesh = create_nested_mesh(
        electrodes, global_mesh, method="radial", max_distance=buffer
    )
    local_map = maps.TileMap(global_mesh, global_active, local_mesh)
    # Inactive cells are injected as log(1e-8), i.e. sigma = 1e-8 after the ExpMap.
    actmap = maps.InjectActiveCells(
        local_mesh, indActive=local_map.local_active, valInactive=np.log(1e-8)
    )
    expmap = maps.ExpMap(local_mesh)
    mapping = expmap * actmap

    simulation = dc.Simulation3DNodal(
        local_mesh, survey=local_survey, sigmaMap=mapping, storeJ=True,
        Solver=Solver, max_ram=1
    )

    # Drop references to the local mesh so the returned misfit stays light.
    # NOTE(review): the simulation's mesh is replaced right after construction;
    # confirm the dask-patched simulation never needs the real mesh again.
    simulation.mesh = TensorMesh([1])  # Light dummy
    del local_mesh
    local_map.local_mesh = None
    actmap.mesh = None
    expmap.mesh = None

    simulation.sensitivity_path = f"./sensitivity/Tile{tile_id}/"
    # The constructor already stores dobs and standard_deviation; no re-assignment needed.
    data_object = data.Data(
        local_survey,
        dobs=obs,
        standard_deviation=uncert,
    )
    local_misfit = data_misfit.L2DataMisfit(
        data=data_object, simulation=simulation, model_map=local_map
    )
    # Explicit data weights: the reciprocal of the standard deviations.
    local_misfit.W = 1 / uncert

    return local_misfit

In [3]:
# params = DirectCurrentParams()
# params.write_input_file()

In [4]:
# In-process (threaded) Dask cluster used to build the survey tiles in parallel.
# Re-running this cell reuses the default client that is already registered,
# instead of spawning another cluster (and leaking the previous one).
try:
    client = get_client()
    cluster = client.cluster
except ValueError:  # no default client registered yet
    cluster = LocalCluster(processes=False)
    client = Client(cluster)

In [5]:
ws = Workspace("FlinFlon.geoh5")
# Receiver dipoles (M-N cells) and their current electrodes (A-B cells), in lookup order.
rx_obj, tx_obj = (
    ws.get_entity(entity_name)[0]
    for entity_name in ("DC_survey", "DC_survey (currents)")
)
topo = ws.get_entity("Topography")[0].vertices

# Synthetic-data alternative (forward model on the stored mesh):
# octree = ws.get_entity("DC_mesh")[0]
# model = octree.get_data("Forward_con")[0]
# mesh = octree_2_treemesh(octree)

# Observed DC data, one value per receiver cell.
dobs = rx_obj.get_data("dc")[0]

In [6]:
# Build one DC dipole source per transmitter (A-B) cell. Its receivers are every
# M-N cell of ``rx_obj`` tagged with that transmitter's A-B cell id. Sources and
# the receiver-cell indices they cover are also grouped per ``tx_obj.parts`` line.
ab_id = np.unique(rx_obj.ab_cell_id.values).astype(int).tolist()
# Invert ab_cell_id's value map: label string -> stored integer code.
value_map = {value: key for key, value in rx_obj.ab_cell_id.value_map.map.items()}
# Integer A-B code of every receiver cell (loop-invariant, so computed once).
rx_ab_codes = rx_obj.ab_cell_id.values.astype(int)
src_lists = []
data_id = []
lines = {ii: {"sources": [], "data_id": []} for ii in np.unique(tx_obj.parts)}
for ab, cell in enumerate(tx_obj.cells.tolist()):
    # Transmitter cell ``ab`` is labelled ``str(ab + 1)``; a label absent from
    # the value map has no receivers, exactly like an empty match below.
    ab_code = value_map.get(str(ab + 1))
    if ab_code is None:
        continue

    rx_id = np.where(rx_ab_codes == ab_code)[0]
    if len(rx_id) == 0:
        continue

    rx_M = rx_obj.vertices[rx_obj.cells[rx_id, 0]]
    rx_N = rx_obj.vertices[rx_obj.cells[rx_id, 1]]
    receivers = dc.receivers.Dipole(rx_M, rx_N)
    src_lists.append(
        dc.sources.Dipole(
            [receivers],
            tx_obj.vertices[cell[0]],
            tx_obj.vertices[cell[1]]
        )
    )
    line_id = tx_obj.parts[cell[0]]
    lines[line_id]["sources"].append(src_lists[-1])
    lines[line_id]["data_id"].append(rx_id)
    data_id.append(rx_id)

survey_dc = dc.Survey(src_lists)

# Receiver-cell index of every datum, in survey (source-by-source) order.
data_id = np.hstack(data_id) if data_id else np.array([], dtype=int)

## Assign uncertainties based on background resistivity

In [7]:
# Uncertainty floor: |G * floor_res|, with G the survey's geometric factors.
# NOTE(review): presumably the |dV| that a homogeneous half-space of resistivity
# ``floor_res`` would give per unit current; confirm against SimPEG's convention.
floor_res = 0.2
geofact = DCutils.geometric_factor(survey_dc)
floor = np.abs(geofact * floor_res)

# ``floor`` follows survey (source-by-source) order. ``data_id`` holds the
# receiver cell of each datum, so the observed values are reordered to match.
# The values array is used rather than the geoh5 Data entity itself.
survey_dc.dobs = dobs.values[data_id]
survey_dc.std = floor

## Create a mesh

In [8]:
# Build a comparatively coarse tree mesh around the electrodes, then refine it
# near the electrodes and the topography.
h = [15., 15., 15.]  # smallest cell size along x, y, z
survey_xyz = np.vstack([rx_obj.vertices, tx_obj.vertices])

mesh = mesh_builder_xyz(
    survey_xyz,
    h,
    padding_distance=[[1000., 1000.]] * 3,
    depth_core=100,
    mesh_type="tree",
)
# Electrode refinement is left unfinalized so the topography pass can add to it.
mesh = refine_tree_xyz(
    mesh,
    survey_xyz,
    method="surface",
    octree_levels=[6, 6, 6],
    octree_levels_padding=[6, 6, 6],
    finalize=False,
)
mesh = refine_tree_xyz(
    mesh,
    topo,
    method="surface",
    octree_levels=[0, 0, 0, 2],
    finalize=True,
)

In [9]:
# Active-cell indicator of the mesh relative to the topography surface.
activeCells = d_utils.active_from_xyz(mesh, topo)
nC = int(np.count_nonzero(activeCells))
# Move the survey electrodes onto the top of the active cells.
survey_dc.drape_electrodes_on_topography(mesh, activeCells, option='top')

# Synthetic-data recipe kept for reference (``mtrue`` is not defined in this notebook):
# expmap = maps.ExpMap(mesh)
# mapactive = maps.InjectActiveCells(mesh=mesh, indActive=activeCells, valInactive=np.log(1e-8))
# mapping = expmap * mapactive
# simulation_g = dc.Simulation3DNodal(
#     mesh, survey=survey_dc, sigmaMap=mapping, solver=Solver, model=mstart
# )
# global_data = simulation_g.make_synthetic_data(mtrue[activeCells], relative_error=0.05, noise_floor=1., add_noise=True)

In [10]:
%%time

local_misfits = []
for ab_id, part in lines.items():
    ind = np.hstack(part["data_id"])
    local_misfits.append(
            client.compute(
                create_tile_dc(
                    part["sources"],  
                    dobs.values[ind],
                    floor[ind], 
                    mesh, 
                    activeCells, 
                    ab_id,
                    buffer=200
                )
            )
    )
    
local_misfits = client.gather(local_misfits)
global_misfit = objective_function.ComboObjectiveFunction(
    local_misfits
)

Processing tile 1
Processing tile 0
Processing tile 2
Processing tile 9
Processing tile 8
Processing tile 7
Processing tile 6
Processing tile 5
Processing tile 4
Processing tile 3
Wall time: 25.9 s


In [11]:
# Container group in the workspace holding everything this inversion writes.
name = "Inversion_1"
inv_group = ContainerGroup.create(ws, name=name)
# The inversion mesh, written as a geoh5 octree for saving model iterations.
octree = treemesh_2_octree(ws, mesh, parent=inv_group)

# Copy of the receiver object that will receive predicted data.
pred = rx_obj.copy(parent=inv_group, copy_children=False)
pred.name = "Predicted"
ab_cell_data = rx_obj.get_data("A-B Cell ID")[0]
ab_cell_data.copy(parent=pred)

# Matching copy of the current electrodes, linked to the predicted receivers.
src = tx_obj.copy(parent=inv_group, copy_children=False)
src.name = "Predicted (currents)"
pred.current_electrodes = src

observed = {"values": dobs.values, "association": "CELL"}
obs_entity = pred.add_data({"Observed": observed})
    
# Beta schedule and initial beta ratio.
coolingFactor = 2
coolingRate = 1
beta0_ratio = 10.0

# Sparse regularization on the active-cell model, via an identity map of size nC.
regmap = maps.IdentityMap(nP=nC)
# reg = regularization.Tikhonov(mesh, indActive=global_actinds, mapping=regmap)
reg = regularization.Sparse(mesh, indActive=activeCells, mapping=regmap)
reg.norms = np.c_[0, 2, 2, 2]
print('[INFO] Getting things started on inversion...')
# TODO: set alpha length scales

# Unbounded projected Gauss-Newton-CG optimizer.
opt = optimization.ProjectedGNCG(
    maxIter=15,
    upper=np.inf,
    lower=-np.inf,
    maxIterCG=20,
    tolCG=1e-4,
)
invProb = inverse_problem.BaseInvProblem(global_misfit, reg, opt)

print("Pre-computing Jmatrix and predicted_0")
# Uniform starting model of log(0.2) on every active cell.
mstart = np.full(nC, np.log(.2))
invProb.dpred = invProb.get_dpred(mstart, compute_J=True)

# Maps from the inverted active-cell model to the full mesh (inactive -> NaN,
# then exponentiated) used when saving model iterations to the octree.
actmap = maps.InjectActiveCells(
    mesh, indActive=activeCells, valInactive=np.nan
)
expmap = maps.ExpMap(mesh)

# Directives are kept in their original order; SimPEG checks directive
# ordering, so confirm before rearranging.
directive_list = [
    directives.UpdateSensitivityWeights(threshold=1e-3),
    # The cooling schedule comes from the inversion settings (coolingFactor, coolingRate).
    directives.Update_IRLS(
        f_min_change=1e-4,
        minGNiter=1,
        coolingFactor=coolingFactor,
        coolingRate=coolingRate,
    ),
    directives.BetaEstimate_ByEig(beta0_ratio=beta0_ratio, method="ratio"),
    directives.UpdatePreconditioner(),
    directives.SaveIterationsGeoH5(
        h5_object=octree,
        mapping=expmap * actmap,
        attribute_type="model",
        association="CELL",
        sorting=mesh._ubc_order,
    ),
    # NOTE(review): predicted data are concatenated tile by tile, while the
    # cells of ``pred`` follow ``rx_obj`` order; confirm the saved values land
    # on the right cells (a sorting step may be needed).
    directives.SaveIterationsGeoH5(
        h5_object=pred,
        channels=["dc"],
        attribute_type="predicted",
        association="CELL",
        data_type={"": {"dc": obs_entity.entity_type}},
        save_objective_function=True,
    ),
]

# TODO: need a basic saving function
inv = inversion.BaseInversion(
    invProb, directiveList=directive_list)
opt.LSshorten = 0.5
opt.remember('xc')

# Run Inversion ================================================================
minv = inv.run(mstart)

[INFO] Getting things started on inversion...
Pre-computing Jmatrix and predicted_0
SimPEG.InvProblem will set Regularization.mref to m0.

        SimPEG.InvProblem is setting bfgsH0 to the inverse of the eval2Deriv.
        ***Done using same Solver and solver_opts as the Simulation3DNodal problem***
model has any nan: 0
  #     beta     phi_d     phi_m       f      |proj(x-g)-x|  LS    Comment   
-----------------------------------------------------------------------------
   0  2.59e-01  4.55e+05  0.00e+00  4.55e+05    6.11e+03      0              
   1  1.30e-01  1.46e+05  3.36e+05  1.90e+05    4.02e+03      0              
   2  6.48e-02  5.00e+04  5.22e+05  8.38e+04    1.99e+03      0   Skip BFGS  
   3  3.24e-02  1.28e+04  7.41e+05  3.68e+04    7.65e+02      0   Skip BFGS  
   4  1.62e-02  4.97e+03  8.72e+05  1.91e+04    3.64e+02      0   Skip BFGS  
   5  8.10e-03  2.45e+03  9.18e+05  9.89e+03    2.72e+02      0   Skip BFGS  
   6  4.05e-03  2.18e+03  9.46e+05  6.01e+03    3.15

KeyboardInterrupt: 