## Inverse Problem

This notebook is a simple example of how to solve an inverse problem using pyshdom. It assumes that we have run the 'SimulateRadiances' notebook and saved the result.

In [1]:
#imports

import pyshdom
import numpy as np
import xarray as xr
from collections import OrderedDict
import pylab as py
np.random.seed(1)

### Load Measurements

Here we load the synthetic measurements and all of the inputs to the solver. We will use several of these inputs to commit 'inverse crimes', whereby some aspects of the problem are fixed to exactly the values used in the forward simulation. With real-world measurements no such shortcuts are available, and a reconstruction involves the following steps:

* First, we have to look at the measurements and select a region of interest and define our `rte_grid`. 
* Then we have to model the sensor sub-pixel geometry.
* We need to check whether the grid and sensor geometry are consistent; the SpaceCarver is useful for this. If they are not, we may need to change the resolution of `rte_grid` to match the resolution of the measurements.
* We need to decide how to represent the surface, which is currently fixed.
* We need to decide how to represent the atmosphere.
    * What scattering species are we modelling? What are their optical models? 
    * Which quantities will be unknowns and which are fixed?
    * What will be the abstract state that we will reconstruct? For this we need to set the mapping between the abstract state and the RTE solver.
* Now that we are organized we need to initialize our state vector of unknowns. The method for this initialization may itself be quite involved as starting nearer the answer is better. This typically goes hand in hand with the selection of any fixed variables.
* Lastly we perform the optimization.

In this tutorial we will reconstruct only the `extinction` and we will use forward quantities from the ground truth synthetic measurements for simplicity.

In [2]:
sensors, solvers, rte_grid = pyshdom.util.load_forward_model('./SimulateRadiances.nc')

0.1 0.015601813796583135 False False


In [3]:
sensor_list = []
for sensor in sensors['MSPI']['sensor_list']:
    copied = sensor.copy(deep=True)
    weights = np.zeros(sensor.sizes['nrays'])
    # np.int was removed in NumPy 1.24; the builtin int is equivalent.
    ray_mask = np.zeros(sensor.sizes['nrays'], dtype=int)
    
    ray_mask_pixel = np.zeros(sensor.npixels.size, dtype=int)
    ray_mask_pixel[np.where(sensor.I.data > 1e-4)] = 1

    copied['weights'] = ('nrays',np.ones(sensor.nrays.size))#[sensor.pixel_index.data])
    copied['cloud_mask'] = ('nrays', ray_mask_pixel[sensor.pixel_index.data])
    sensor_list.append(copied)
    
def mean_ext_estimate(rte_grid, sensors, solar_mu, solar_azimuth,
                     chi=2/3, g=0.86, sun_distance_reflect=0.1,
                     sun_distance_transmit=0.3,
                     length_scale_method='max_height'):
    """
    Estimate the extinction of a cloud using diffusion theory.

    A masked volume is first obtained by space carving the sensors onto `rte_grid`,
    and the geometric distance between each point of that volume and the sun is
    calculated. The value of the geometric distance
    from the sun through the cloud at the first intersection of a sensor ray
    with the cloud volume is used to classify whether sensor pixels are
    observing shadowed or directly illuminated portions of the cloud.

    The mean of all 'shadowed' and 'illuminated' pixels is used to derive an
    optical diameter using diffusion theory and the extrapolation length `chi`
    and an asymmetry factor. This optical diameter is converted to an extinction
    using the length scale of the maximum chord length through the cloud in the solar
    direction. This length scale is chosen because it collapses to the relevant case
    for several geometries.
    """
    space_carver = pyshdom.space_carve.SpaceCarver(rte_grid)
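    if isinstance(sensors, xr.Dataset):
        sensor_list = [sensors]
    elif isinstance(sensors, list):
        sensor_list = sensors
    elif isinstance(sensors, pyshdom.containers.SensorsDict):
        sensor_list = []
        for instrument in sensors:
            sensor_list.extend(sensors[instrument]['sensor_list'])
    else:
        # fail early instead of hitting a NameError on `sensor_list` below.
        raise TypeError("`sensors` must be an xr.Dataset, a list of datasets or a SensorsDict.")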

    volume = space_carver.carve(sensor_list, agreement=(0.0, 1.0), linear_mode=False)
    sundistance = space_carver.shadow_mask(volume.mask, sensor_list, solar_mu, solar_azimuth)

    reflected = []
    transmitted = []
    for sensor in sensor_list:
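        # use the keyword arguments rather than hard-coded thresholds.
        reflected.extend(sensor.I.data[np.where((sensor.sun_distance.data < sun_distance_reflect)&(sensor.I.data > 1e-2))])
        transmitted.extend(sensor.I.data[np.where((sensor.sun_distance.data >= sun_distance_transmit)&(sensor.I.data > 1e-2))])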
    
    if length_scale_method == 'sun_distance':
        radius = sundistance.sun_distance.data[np.where(sundistance.sun_distance > 0.0 )].max()
    elif length_scale_method == 'max_height':
        heights = rte_grid.z[np.where(volume.mask > 0.0)[-1]]
        radius = heights.max() - heights.min()

    tau_estimate = 2*chi*np.mean(reflected)/np.mean(transmitted)/(1.0-g)
    ext_estimate = tau_estimate/radius

    extinction = np.zeros(volume.mask.shape)
    extinction[np.where(volume.mask == 1.0)] = ext_estimate
    extinction = xr.Dataset(
        data_vars={
            'extinction': (['x', 'y', 'z'], extinction)
        },
        coords={
            'x': rte_grid.x,
            'y': rte_grid.y,
            'z': rte_grid.z,
        }
    )
    return extinction
    
extinction = mean_ext_estimate(rte_grid,sensor_list, -0.5, 0.0)

noise_std = np.std(solvers[0.86].medium['cloud'].extinction.data[np.where(solvers[0.86].medium['cloud'].extinction.data > 0.0)])
extinct_perturb = np.random.normal(loc=extinction.extinction.max(),
                                   scale=noise_std,
                                   size=extinction.extinction.shape)
ext_ref = extinction.extinction.data
ext_ref[0] = ext_ref[-1] = ext_ref[:,0] = ext_ref[:,-1] = 0.0
#ext_ref[np.where(ext_ref >0.0)] += extinct_perturb[np.where(ext_ref > 0.0)]
#ext_ref[np.where(ext_ref < 0.0)] = 0.0
extinction.extinction[:] = ext_ref
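
The arithmetic at the end of `mean_ext_estimate` can be checked in isolation. The following is a minimal sketch that only reproduces the formula `tau = 2*chi*mean(reflected)/mean(transmitted)/(1-g)` and the division by a length scale; the function name `diffusion_extinction_estimate` and the radiance values are made up for illustration and are not part of pyshdom.

```python
import numpy as np

def diffusion_extinction_estimate(reflected, transmitted, length_scale,
                                  chi=2/3, g=0.86):
    """Toy version of the final two lines of mean_ext_estimate (illustrative only)."""
    # ratio of the mean 'illuminated' radiance to the mean 'shadowed' radiance
    ratio = np.mean(reflected) / np.mean(transmitted)
    # optical diameter from the extrapolation length chi and asymmetry factor g
    tau = 2.0 * chi * ratio / (1.0 - g)
    # convert the optical diameter to an extinction with a geometric length scale
    return tau, tau / length_scale

# made-up radiances, chosen so that the ratio of the means is 4
reflected = np.array([0.30, 0.34, 0.32])
transmitted = np.array([0.08, 0.08])
tau, ext = diffusion_extinction_estimate(reflected, transmitted, length_scale=0.5)
print(tau, ext)
```

A larger contrast between illuminated and shadowed pixels gives a larger optical diameter, and a smaller cloud (smaller `length_scale`) gives a larger extinction for the same optical diameter.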

In [4]:
extinction.extinction.max()

In [5]:

# only take MSPI observations.
# iterate over a copy of the keys so that entries can be deleted safely.
for name in list(sensors):
    if name != 'MSPI':
        del sensors[name]
        
# make forward_sensors which will hold synthetic measurements from the evaluation of the forward model.
forward_sensors = sensors.make_forward_sensors()

# add an uncertainty model to the observations.
uncertainty = pyshdom.uncertainties.Uncertainty(np.diag(np.array([1e5,0.0,0.0,0.0])),'L2')
sensors.add_uncertainty_model('MSPI', uncertainty)

# prepare all of the static inputs to the solver; these are copied directly from the forward model.
surfaces = OrderedDict()
numerical_parameters = OrderedDict()
sources = OrderedDict()
num_stokes = OrderedDict()
background_optical_scatterers = OrderedDict()
for key in forward_sensors.get_unique_solvers():
    surfaces[key] = solvers[key].surface
    numerical_parameters[key] = solvers[key].numerical_params
    sources[key] = solvers[key].source
    num_stokes[key] = solvers[key]._nstokes
    background_optical_scatterers[key] = {}




In [6]:
# set the generator for the unknown scatterer using ground truth optical properties
# and unknown extinction.
# OpticalGenerator holds the fixed optical properties and forms a full set of optical properties
# when it is called with extinction as the argument.

# MicrophysicalGenerator does the same but taking microphysics as the input. However that
# needs an OpticalPropertyGenerator to map between microphysics and optical properties as well.
optprop = solvers[0.86].medium['cloud']

deriv_gen = pyshdom.medium.OpticalGenerator(rte_grid,'cloud', 0.86, 
                                            optprop.legcoef, optprop.ssalb,
                                           optprop.phase_weights, optprop.table_index,
                                           )

# UnknownScatterers is a container for all of the unknown variables.
unknown_scatterers = pyshdom.containers.UnknownScatterers()
unknown_scatterers.add_unknowns(['extinction'], deriv_gen)

In [7]:
# now we form state_gen which updates the solvers with an input_state.
# Note that `pyshdom.medium.StateGenerator` also takes a state_transform, a state_representation
# and a state_to_grid object as inputs. These objects handle multi-variable
solvers_reconstruct = pyshdom.containers.SolversDict()
# np.bool was removed in NumPy 1.24; the builtin bool is equivalent.
mask = np.zeros((rte_grid.x.size, rte_grid.y.size, rte_grid.z.size), dtype=bool)
mask[np.where(extinction.extinction.data > 0.0)] = True
state_gen = pyshdom.medium.StateGenerator(solvers_reconstruct,
                                         unknown_scatterers, rte_grid,surfaces,
                                         numerical_parameters, sources, background_optical_scatterers,
                                         num_stokes, state_transform=None, state_to_grid=mask)
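
Passing a boolean mask as `state_to_grid` restricts the unknowns to the masked grid points. Conceptually this is a mapping between a 1D state vector and a 3D gridded field. The sketch below shows that idea with plain NumPy; the helper names `grid_to_state` and `state_to_grid` are hypothetical, and the actual pyshdom object may behave differently in its details.

```python
import numpy as np

# a small made-up grid with a 'cloud' occupying four grid points
grid_shape = (3, 4, 2)
field = np.zeros(grid_shape)
field[1, 1:3, :] = 5.0

# boolean mask of the grid points that are treated as unknowns
mask = field > 0.0

def grid_to_state(grid_values, mask):
    """Collect the masked grid points into a 1D state vector."""
    return grid_values[mask]

def state_to_grid(state, mask):
    """Scatter a 1D state vector back onto the grid; unmasked points stay zero."""
    out = np.zeros(mask.shape)
    out[mask] = state
    return out

state = grid_to_state(field, mask)
roundtrip = state_to_grid(state, mask)
print(state.shape, np.array_equal(roundtrip, field))
```

Reducing the state to the masked points keeps the optimization from spending unknowns on grid points that the space carving has already identified as cloud free.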

In [83]:
truth_solver = solvers[0.86]
#reshaped_ext = truth_solver._total_ext[:truth_solver._nbpts].reshape(truth_solver._nx1, truth_solver._ny1, truth_solver._nz)
reshaped_ext = truth_solver.medium['cloud'].extinction
# average the eight corner values of every grid cell.
cell_averaged_extinct = (reshaped_ext[1:, 1:, 1:] + reshaped_ext[1:, 1:, :-1] +
                         reshaped_ext[1:, :-1, 1:] + reshaped_ext[1:, :-1, :-1] +
                         reshaped_ext[:-1, 1:, 1:] + reshaped_ext[:-1, 1:, :-1] +
                         reshaped_ext[:-1, :-1, 1:] + reshaped_ext[:-1, :-1, :-1])/8.0
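
The eight shifted slices in the cell above average the corner values of every grid cell, so a field with shape `(nx, ny, nz)` becomes a cell-centred field with shape `(nx-1, ny-1, nz-1)`. A minimal sketch on a small synthetic array (the helper name `cell_average` is made up for illustration):

```python
import numpy as np

def cell_average(field):
    """Average the eight corner values of each cell of a 3D point-based field."""
    return (field[1:, 1:, 1:] + field[1:, 1:, :-1] +
            field[1:, :-1, 1:] + field[1:, :-1, :-1] +
            field[:-1, 1:, 1:] + field[:-1, 1:, :-1] +
            field[:-1, :-1, 1:] + field[:-1, :-1, :-1]) / 8.0

# a linear ramp, so each cell average equals the value at the cell centre
field = np.arange(27, dtype=float).reshape(3, 3, 3)
averaged = cell_average(field)
print(averaged.shape, averaged[0, 0, 0])
```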

In [84]:
py.figure()
for i in range(0,cell_averaged_extinct.shape[-1],3):
    py.hist(cell_averaged_extinct[...,i].data[np.where(cell_averaged_extinct[...,i] > 0.0)],
           histtype='step',bins=np.linspace(0.0,130.0,20),label='{}'.format(i))
py.legend()
    

<matplotlib.legend.Legend at 0x184738df0>

In [47]:
py.figure()
py.imshow(cell_averaged_extinct[12])

<matplotlib.image.AxesImage at 0x144606130>

In [8]:
# get bounds automatically.
min_bounds, max_bounds = state_gen.transform_bounds()

In [9]:
# transform initial physical state to abstract state. 
# This is
x0 = extinction.extinction.data#np.zeros((rte_grid.x.size,rte_grid.y.size,rte_grid.z.size)) + 1e-1
a = state_gen._state_to_grid.inverse(x0, 'cloud', 'extinction')
x0 = state_gen.state_transform.inverse(a)

In [121]:
x0

array([28.55629108, 28.55629108, 28.55629108, ..., 28.55629108,
       28.55629108, 28.55629108])

In [139]:
state_gen.state_transform(x0)

array([29.74613654, 29.74613654, 29.74613654, ..., 29.74613654,
       29.74613654, 29.74613654])

In [140]:
state_gen(x0)

0.1 0.04692769985720488 False False




In [10]:

objective_function = pyshdom.optimize.ObjectiveFunction.LevisApproxUncorrelatedL2(
    sensors, solvers_reconstruct, forward_sensors, unknown_scatterers, state_gen,
    state_gen.project_gradient_to_state,
    parallel_solve_kwargs={'n_jobs': 4, 'verbose': True},
    gradient_kwargs={'cost_function': 'L2', 'exact_single_scatter': True},
    uncertainty_kwargs={'add_noise': False},
    min_bounds=min_bounds, max_bounds=max_bounds)
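
The objective function returns a misfit between simulated and measured radiances together with its gradient with respect to the state. As a rough illustration of what an uncorrelated L2 misfit looks like, here is a generic weighted least-squares cost for a toy linear forward model. This is not pyshdom's implementation: the function name, the factor of 0.5, the scalar `weight` and the explicit `jacobian` are all assumptions made for the sketch, and the approximate gradient used by pyshdom is not modelled here.

```python
import numpy as np

def l2_cost_and_gradient(simulated, measured, jacobian, weight=1.0):
    """Generic least-squares misfit and its gradient (illustrative only)."""
    residual = simulated - measured
    # uncorrelated errors: every measurement contributes independently
    cost = 0.5 * weight * np.sum(residual ** 2)
    # chain rule through the forward model: d(cost)/dx = J^T (weight * residual)
    gradient = weight * jacobian.T @ residual
    return cost, gradient

# toy linear forward model: simulated = A @ x
A = np.array([[1.0, 2.0],
              [3.0, 4.0]])
x = np.array([1.0, 1.0])
measured = np.array([2.0, 5.0])

cost, gradient = l2_cost_and_gradient(A @ x, measured, A)
print(cost, gradient)
```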

In [20]:
rte_grid

In [141]:
state_gen(x0)
forward_sensors.get_measurements(solvers_reconstruct)



0.1 0.04692769985720488 False False


In [11]:
_, gradient = objective_function(x0)



0.1 0.04692769985720488 False False
[1622446675.320618, 1622446676.154758, 1622446676.386637, 1622446676.462264, 1622446676.539957, 1622446687.1492522]


In [39]:
np.where(np.isnan(grad_shaped))

(array([5, 5, 5, 5]), array([36, 36, 36, 36]), array([ 6,  7,  9, 10]))

In [37]:
np.unravel_index(920,(extinction.extinction.shape))

(0, 30, 20)
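
The two cells above locate problematic entries of the gradient, once as 3D indices of NaN values and once by converting a flat index into a 3D index. The sketch below shows both operations on a synthetic array. The shape `(2, 40, 30)` is hypothetical and chosen only to make the arithmetic easy to follow; it is not taken from the notebook's `rte_grid`.

```python
import numpy as np

shape = (2, 40, 30)  # hypothetical (nx, ny, nz), not the notebook's grid

# convert a flat (C-ordered) index into a 3D index and back again
flat_index = 920
index_3d = tuple(int(i) for i in np.unravel_index(flat_index, shape))
flat_again = int(np.ravel_multi_index(index_3d, shape))

# locate NaN entries in a synthetic gradient field
grad = np.zeros(shape)
grad[1, 36, 6] = np.nan
nan_locations = np.argwhere(np.isnan(grad))

print(index_3d, flat_again, nan_locations)
```

For a C-ordered array the flat index is `i*ny*nz + j*nz + k`, so here `920 = 0*1200 + 30*30 + 20`.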

In [40]:
solver = solvers_reconstruct[0.86]
grad_shaped=state_gen._state_to_grid(gradient,'cloud','extinction')
#grad_shaped = gradient.reshape(solver._grid.x.size, solver._grid.y.size, solver._grid.z.size)

py.figure()
py.imshow(grad_shaped[5])
py.colorbar()

<matplotlib.colorbar.Colorbar at 0x15d7d5460>

In [14]:
%matplotlib qt
for im,im2 in zip(forward_sensors.get_images('MSPI'), sensors.get_images('MSPI')):
    py.figure()
    (im.I-im2.I).plot()

In [12]:
optimizer = pyshdom.optimize.Optimizer(objective_function, prior_fn=None)


In [32]:
x0

array([0., 0., 0., ..., 0., 0., 0.])

In [13]:
optimizer.minimize(x0)

0.1 0.04692769985720488 False False




[1622446687.3207302, 1622446688.123911, 1622446688.3571699, 1622446688.439382, 1622446688.5324311, 1622446698.0156631]
0.1 0.04692728242309267 False False




[1622446698.176322, 1622446698.9439669, 1622446699.169319, 1622446699.2396889, 1622446699.313718, 1622446708.743819]
0.1 0.046926864516718354 False False




[1622446708.88127, 1622446709.650889, 1622446709.8729851, 1622446709.944235, 1622446710.018441, 1622446719.590812]
0.1 0.046925193188457476 False False




[1622446719.7232509, 1622446720.482968, 1622446720.7052422, 1622446720.775743, 1622446720.847164, 1622446730.3550668]
0.1 0.04691850771281291 False False




[1622446730.47809, 1622446731.2446358, 1622446731.466182, 1622446731.53636, 1622446731.608772, 1622446740.968043]
0.1 0.04689176590726424 False False




[1622446741.0910778, 1622446741.868387, 1622446742.092563, 1622446742.16449, 1622446742.236945, 1622446751.76212]
0.1 0.04678479869159676 False False




[1622446751.885102, 1622446752.64351, 1622446752.8665369, 1622446752.9363441, 1622446753.008574, 1622446762.863082]
0.1 0.046356929904409754 False False




[1622446762.9865978, 1622446763.756649, 1622446763.984641, 1622446764.05659, 1622446764.130261, 1622446773.7379289]
0.1 0.044645454539271816 False False




[1622446773.861683, 1622446774.617426, 1622446774.8375208, 1622446774.908299, 1622446774.98113, 1622446784.390604]
0.1 0.03779955317214335 False False




[1622446784.513973, 1622446785.2506518, 1622446785.470631, 1622446785.5410938, 1622446785.613021, 1622446795.04843]
0.1 0.022790300409110807 False False




[1622446795.177598, 1622446795.78401, 1622446796.000231, 1622446796.0702589, 1622446796.143064, 1622446805.4213219]
0.1 0.010489348918752351 False False




[1622446805.548786, 1622446806.2544892, 1622446806.4729862, 1622446806.545482, 1622446806.6174989, 1622446815.8703501]
0.1 0.013684543708317514 False False




[1622446815.997483, 1622446816.696914, 1622446816.922286, 1622446816.998701, 1622446817.072715, 1622446826.226757]
0.1 0.015756053082455902 False False




[1622446826.355309, 1622446827.03844, 1622446827.2574852, 1622446827.328002, 1622446827.399149, 1622446836.689227]
0.1 0.015204041274163454 False False




[1622446836.816592, 1622446837.569655, 1622446837.784509, 1622446837.854003, 1622446837.92706, 1622446847.1616268]
0.1 0.014922324711506682 False False




[1622446847.2891278, 1622446848.1100202, 1622446848.3287208, 1622446848.399711, 1622446848.47281, 1622446857.906178]
0.1 0.01513136976035404 False False




[1622446858.036593, 1622446858.8372722, 1622446859.057327, 1622446859.128014, 1622446859.19931, 1622446868.95714]
0.1 0.015620574004603364 False False




[1622446869.086336, 1622446869.9065971, 1622446870.132704, 1622446870.2049642, 1622446870.277796, 1622446879.638149]
0.1 0.015671132062406563 False False




[1622446879.7694001, 1622446880.573164, 1622446880.7978282, 1622446880.868225, 1622446880.94199, 1622446890.314065]
0.1 0.015427017951745694 False False




[1622446890.445269, 1622446891.2562969, 1622446891.473911, 1622446891.5444422, 1622446891.617197, 1622446900.858407]
0.1 0.01528540133523678 False False




[1622446900.98813, 1622446901.7510269, 1622446901.9708622, 1622446902.040677, 1622446902.1121101, 1622446911.3412359]
0.1 0.015244460122636716 False False




[1622446911.486874, 1622446912.394345, 1622446912.625155, 1622446912.7001631, 1622446912.781461, 1622446922.3034618]
0.1 0.015235791653774377 False False




[1622446922.432472, 1622446923.242628, 1622446923.4645529, 1622446923.5355148, 1622446923.606849, 1622446932.978695]
0.1 0.015197241957173459 False False




[1622446933.104023, 1622446933.911199, 1622446934.1289668, 1622446934.199971, 1622446934.272137, 1622446943.7163038]
0.1 0.015163329873581823 False False




[1622446943.844213, 1622446944.6495519, 1622446944.864585, 1622446944.936982, 1622446945.009642, 1622446954.315137]
0.1 0.015193090305318976 False False




[1622446954.4437811, 1622446955.2604532, 1622446955.490106, 1622446955.564409, 1622446955.6371, 1622446964.9330292]
0.1 0.015234734109752072 False False




[1622446965.05895, 1622446965.9106581, 1622446966.1323252, 1622446966.2059588, 1622446966.277836, 1622446975.6734238]
0.1 0.015254789149153675 False False




[1622446975.7998452, 1622446976.650712, 1622446976.872237, 1622446976.94381, 1622446977.015598, 1622446986.4157891]
0.1 0.015230444536802785 False False




[1622446986.5430741, 1622446987.401129, 1622446987.6224189, 1622446987.693918, 1622446987.765022, 1622446997.6034591]
0.1 0.01522076159460306 False False




[1622446997.765766, 1622446998.664206, 1622446998.886092, 1622446998.9575622, 1622446999.028574, 1622447008.5168571]
0.1 0.015234362449362751 False False




[1622447008.6449091, 1622447009.5472052, 1622447009.7723782, 1622447009.8423848, 1622447009.9127831, 1622447019.251792]
0.1 0.015234008599980638 False False




[1622447019.381515, 1622447020.2668211, 1622447020.485664, 1622447020.5569558, 1622447020.628035, 1622447030.224844]
0.1 0.015228094930737993 False False




KeyboardInterrupt: 

In [21]:
state_gen._unknown_scatterers

UnknownScatterers([('cloud',
                    {'variable_name_list': ['extinction'],
                     'dataset_generator': <pyshdom.medium.OpticalGenerator at 0x134ac7d60>})])

In [23]:
solvers_reconstruct[0.86]._unknown_scatterer_indices

array([[2]], dtype=int32)

In [15]:
py.figure()
py.plot(solvers_reconstruct[0.86].medium['cloud'].extinction.data.ravel(), 
        solvers[0.86].medium['cloud'].extinction.data.ravel(),'x')

[<matplotlib.lines.Line2D at 0x1435fb490>]

In [17]:
cond = np.where(solvers_reconstruct[0.86].medium['cloud'].extinction > 0.0)
one = solvers_reconstruct[0.86].medium['cloud'].extinction.data[cond]
ref = solvers[0.86].medium['cloud'].extinction.data[cond]
print(np.mean(one), np.mean(ref), np.mean(np.abs(one - ref))/np.mean(ref))

9.747940899135358 9.989887744836532 0.31471784498055466
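
The third printed number is the mean absolute difference between the reconstructed and reference extinction, normalized by the mean of the reference. A minimal sketch of that metric on made-up values (the helper name `relative_mean_abs_error` is hypothetical):

```python
import numpy as np

def relative_mean_abs_error(estimate, reference):
    """Mean absolute error normalized by the mean of the reference."""
    return np.mean(np.abs(estimate - reference)) / np.mean(reference)

# made-up values for illustration
reference = np.array([10.0, 20.0, 30.0])
estimate = np.array([12.0, 18.0, 33.0])
err = relative_mean_abs_error(estimate, reference)
print(err)
```

Because positive and negative errors do not cancel in this metric, the means of the two fields can agree closely while the relative error remains large, as in the output above.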


In [24]:
py.figure()
solvers_reconstruct[0.86].medium['cloud'].extinction[12,:,:-3].plot()

<matplotlib.collections.QuadMesh at 0x142e7baf0>

In [25]:
py.figure()
solvers[0.86].medium['cloud'].extinction[12,:,:-3].plot()

<matplotlib.collections.QuadMesh at 0x144680bb0>