# A Note on this Notebook
This loads a firedrake `DumbCheckpoint` file called `true-fields.h5` containing the values of $u_{true}$ and $q_{true}$ in `Function`s named `u_true` and `q_true` respectively which should have already been generated.
When run as a python/ipython script, this expects positional arguments as follows:
```
$ python estimate-q.py num_points method
```
where `num_points` is the number of points to sample from `u_true` and `method` is one of:
 - `point-cloud` to minimise $J$
 - one of the following to minimise $J'''$, computing $u_{interpolated}$ with:
   - `nearest` to use `u_interpolated = scipy.interpolate.NearestNDInterpolator(xs, u_obs_vals)`
   - `linear` to use `u_interpolated = scipy.interpolate.LinearNDInterpolator(xs, u_obs_vals, fill_value=0.0)`
   - `clough-tocher` to use `u_interpolated = scipy.interpolate.CloughTocher2DInterpolator(xs, u_obs_vals, fill_value=0.0)`
   - `gaussian` to use `u_interpolated = scipy.interpolate.Rbf(xs[:, 0], xs[:, 1], u_obs_vals, function='gaussian')`

where `xs` are the point cloud coordinates and `u_obs_vals` are the simulated measurements of `u_true` at those coordinates, with normally distributed random error (variance $\sigma^2$) added.

Point cloud coordinates `xs`, values `u_obs_vals`, and standard deviation $\sigma$ are taken from an HDF5 file `observed-data.h5` if available (stored as `xs_{num_points}`, `u_obs_vals_{num_points}`, and `sigma_{num_points}` respectively) or are generated and saved as such.

Results are saved in a firedrake `DumbCheckpoint` file called `q-estimations.h5` containing $u_{interpolated}$ (if calculated), $q_{min}$ which minimises the functional $J$ or $J'''$ and $q_{err} = q_{true} - q_{min}$. 
These are named `u_interpolated_{method}_{num_points}`, `q_min_{method}_{num_points}` and `q_err_{method}_{num_points}` respectively.

# Problem Description

We try to enforce posterior consistency in the non-point-cloud case by redefining our objective functional

$$J'''[u, q] = 
\underbrace{ \frac{N}{2\sigma^2}\int_{\Omega}\left(u_{interpolated} - u\right)^2dx}_{\text{model-data misfit}} + 
\underbrace{\frac{\alpha^2}{2}\int_\Omega|\nabla q|^2dx}_{\text{regularization}}$$

which is the same as $J'$ but with $J'''_{\text{misfit}} = N \times J'_{\text{misfit}}$, where $N$ is the number of measurements (`num_points`), so that the misfit term grows with the number of measurements.

Note that we do not use $\hat{\sigma}$ as we did in $J''$.

# Setup

In [None]:
from scipy.interpolate import (
    LinearNDInterpolator,
    NearestNDInterpolator,
    CloughTocher2DInterpolator,
    Rbf,
)

import matplotlib.pyplot as plt
import firedrake
import firedrake_adjoint

from firedrake import Constant, cos, sin

import numpy as np
from numpy import pi as π
from numpy import random

import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable

import os, sys

# NOTE: '__file__' here is a string literal, not the module attribute (which
# is undefined inside a notebook), so this resolves to the current working
# directory. All input and output files are read from / written to there.
currentdir = os.path.dirname(os.path.realpath('__file__'))

import argparse
import warnings

# Methods accepted on the command line; anything else falls back to point-cloud.
methods = ['point-cloud', 'nearest', 'linear', 'clough-tocher', 'gaussian']

parser = argparse.ArgumentParser(description='Estimate q using pyadjoint with a given number of point samples of u_true and chosen method. Expects to find a Firedrake checkpoint file \'true-fields.h5\' in the import directory.')
parser.add_argument('num_points', type=int, help='The number of points to sample from u_true. Points and measurements will be identified from \'observed-data.h5\' or created and saved to it.')
parser.add_argument('method', help="The method to use: one of point-cloud, nearest, linear, clough-tocher, or gaussian")
try:
    args = parser.parse_args()
except SystemExit as exc:
    # argparse signals a usage error by raising SystemExit. This is what
    # happens when running as a notebook, where sys.argv holds the kernel's
    # own arguments, so fall back to defaults instead of terminating.
    # An explicit help request (exit code 0) is honoured.
    if exc.code == 0:
        raise
    warnings.warn('Failed to parse arguments. Defaulting to num_points = 4 and method = point-cloud')
    num_points = 4
    method = 'point-cloud'
else:
    num_points = args.num_points
    method = args.method

if method not in methods:
    warnings.warn(f'Got unexpected method argument {method} defaulting to point-cloud')
    method = 'point-cloud'

print(f"Running with {num_points} points and method {method}")

# Seed for the random generator that produces measurement noise when the
# observed data has to be generated.
seed = 1729

In [None]:
# Structured 32 x 32 mesh of the unit square shared by every field below.
mesh = firedrake.UnitSquareMesh(32, 32)

# The state u (space V) and the control q (space Q) both use continuous
# Galerkin (CG) elements of degree 2 on that mesh.
V = firedrake.FunctionSpace(mesh, family='CG', degree=2)  # solution space
Q = firedrake.FunctionSpace(mesh, family='CG', degree=2)  # control space for q

## Get $q_{true}$ and $u_{true}$

In [None]:
# Read the previously generated "true" fields from the 'true-fields'
# checkpoint. q_true is a control-like field, so it lives in the control
# space Q (matching q_min and q_err later on); u_true lives in the solution
# space V.
q_true = firedrake.Function(Q, name='q_true')
u_true = firedrake.Function(V, name='u_true')
filename = os.path.join(currentdir, 'true-fields')
with firedrake.DumbCheckpoint(filename, mode=firedrake.FILE_READ) as chk:
    chk.load(q_true, name='q_true')
    chk.load(u_true, name='u_true')

print("Have fake q_true and u_true")

## Generate and Save or Load "Observed" Data

In [None]:
import h5py

filename = os.path.join(currentdir, 'observed-data.h5')

# HDF5 dataset names used for this number of points.
xs_key = f"xs_{num_points}"
u_obs_vals_key = f"u_obs_vals_{num_points}"
sigma_key = f"sigma_{num_points}"

try:
    # Load if available
    with h5py.File(filename, 'r') as file:
        xs = file[xs_key][:]
        u_obs_vals = file[u_obs_vals_key][:]
        # `[()]` reads the scalar dataset while the file is still open, so
        # Constant receives a plain number rather than an h5py Dataset.
        σ = firedrake.Constant(file[sigma_key][()])
except (OSError, KeyError):
    # Missing file (OSError) or missing dataset (KeyError): generate the data.
    # Coordinates use NumPy's legacy global RNG seeded with 0. This is kept
    # unchanged so that previously generated data sets stay reproducible.
    np.random.seed(0)
    xs = np.random.random_sample((num_points, 2))
    # Choose σ so that the range of u_true is `signal_to_noise` times σ.
    # NOTE(review): dat.data_ro presumably holds only process-local values
    # under MPI — confirm before running in parallel.
    signal_to_noise = 20
    U = u_true.dat.data_ro[:]
    u_range = U.max() - U.min()
    σ = firedrake.Constant(u_range / signal_to_noise)
    # Observations = point evaluations of u_true + σ * standard normal noise.
    generator = random.default_rng(seed)
    ζ = generator.standard_normal(len(xs))
    u_obs_vals = np.array(u_true.at(xs)) + float(σ) * ζ
    # Save, replacing any entries for this number of points left behind by
    # an earlier, partially completed run (create_dataset refuses to
    # overwrite an existing name).
    with h5py.File(filename, 'a') as file:
        for key, data in (
            (xs_key, xs),
            (u_obs_vals_key, u_obs_vals),
            (sigma_key, σ.values()[0]),
        ):
            if key in file:
                del file[key]
            file.create_dataset(key, data=data)
    print(f"Generated and saved xs, u_obs_vals and sigma for {num_points} points.")
else:
    print(f"Loaded xs, u_obs_vals and sigma for {num_points} points.")

# Solve

## Define constants

In [None]:
# Problem coefficients, all wrapped as firedrake Constants.
f = firedrake.Constant(1.0)   # source term on the right-hand side of the forward problem
k0 = firedrake.Constant(0.5)  # coefficient multiplying exp(q) in the diffusion term
α = firedrake.Constant(0.5)   # regularisation weight, enters J as α**2 / 2

## Run forward model with `q = 0` as first guess.

In [None]:
from firedrake import exp, inner, grad, dx

print('Running forward model')
# State u, its test function v, and the control q (a fresh Function, i.e. the
# q = 0 first guess).
u, v = firedrake.Function(V), firedrake.TestFunction(V)
q = firedrake.Function(Q)
# u = 0 on the whole domain boundary.
bc = firedrake.DirichletBC(V, 0, 'on_boundary')
# Residual: ∫ k0·exp(q)·∇u·∇v − f·v dx = 0 for every test function v.
diffusivity = k0 * exp(q)
F = (diffusivity * inner(grad(u), grad(v)) - f * v) * dx
firedrake.solve(F == 0, u, bc)

## Formulate $J$ or $J'''$

In [None]:
if method == 'point-cloud':

    # Store data on the point_cloud using a vertex only mesh
    print('Creating VertexOnlyMesh')
    point_cloud = firedrake.VertexOnlyMesh(mesh, xs)
    print('Creating P0DG(VertexOnlyMesh) space')
    P0DG = firedrake.FunctionSpace(point_cloud, 'DG', 0)
    print('Creating u_obs')
    u_obs = firedrake.Function(P0DG, name=f'u_obs_{method}_{num_points}')
    # NOTE(review): assumes the vertex-only mesh keeps the ordering of `xs`
    # (expected in serial) — confirm before running in parallel.
    u_obs.dat.data[:] = u_obs_vals

    print('Assembling J')
    # Pointwise misfit 0.5 * ((u_obs - u(x_i)) / σ)^2 on the point cloud.
    misfit_expr = 0.5 * ((u_obs - firedrake.interpolate(u, P0DG)) / σ)**2

else:

    # Interpolating the mesh coordinates field (which is a vector function space)
    # into the vector function space equivalent of our solution space gets us
    # global DOF values (stored in the dat) which are the coordinates of the global
    # DOFs of our solution space. This is the necessary coordinates field X.
    print('Getting coordinates field X')
    Vc = firedrake.VectorFunctionSpace(mesh, V.ufl_element())
    X = firedrake.interpolate(mesh.coordinates, Vc).dat.data_ro[:]

    # Pick the appropriate "interpolate" method needed to create
    # u_interpolated given the chosen method. The linear and Clough-Tocher
    # interpolators return fill_value (0.0) outside the convex hull of xs.
    print(f'Creating {method} interpolator')
    if method == 'nearest':
        interpolator = NearestNDInterpolator(xs, u_obs_vals)
    elif method == 'linear':
        interpolator = LinearNDInterpolator(xs, u_obs_vals, fill_value=0.0)
    elif method == 'clough-tocher':
        interpolator = CloughTocher2DInterpolator(xs, u_obs_vals, fill_value=0.0)
    elif method == 'gaussian':
        interpolator = Rbf(xs[:, 0], xs[:, 1], u_obs_vals, function='gaussian')
    else:
        # Unreachable after argument validation; fail loudly rather than
        # with a NameError on `interpolator` below.
        raise ValueError(f'Unknown interpolation method {method!r}')
    print('Interpolating to create u_interpolated')
    u_interpolated = firedrake.Function(V, name=f'u_interpolated_{method}_{num_points}')
    u_interpolated.dat.data[:] = interpolator(X[:, 0], X[:, 1])

    print('Assembling J\'\'\'')
    # Scaling by N = num_points lets the misfit grow with the number of measurements.
    misfit_expr = num_points * 0.5 * ((u_interpolated - u) / σ)**2

regularisation_expr = 0.5 * α**2 * inner(grad(q), grad(q))
# Assembled J is the point-cloud functional J or the interpolated-data
# functional J''' depending on which misfit expression was built above.
J = firedrake.assemble(misfit_expr * dx) + firedrake.assemble(regularisation_expr * dx)

## Minimise to Estimate $q$

In [None]:
print('Getting q̂ (control variable) and Ĵ (reduced functional)')
# q is the control; Ĵ is J viewed as a function of q alone.
q̂ = firedrake_adjoint.Control(q)
Ĵ = firedrake_adjoint.ReducedFunctional(J, q̂)

print('Minimising Ĵ to get q_min')
q_min = firedrake_adjoint.minimize(
    Ĵ, method='Newton-CG', options={'disp': True}
)
# Rename so the result is stored under a method- and size-specific name.
q_min.rename(name=f'q_min_{method}_{num_points}')

## Calculate Error

In [None]:
print('Calculating q error field')
# Error field q_min - q_true, named for storage in the results checkpoint.
q_err = firedrake.Function(Q, name=f'q_err_{method}_{num_points}').assign(q_min-q_true)
print('Calculating L2 error norm')
l2norm = firedrake.norm(q_err, "L2")
# Report the norm; otherwise it is computed and silently discarded when run as a script.
print(f'L2 error norm of q_err: {l2norm}')

# Save Results

In [None]:
filename = os.path.join(currentdir, 'q-estimations')

# (progress message, Function) pairs to write, in order. u_obs is never
# stored: its values are already saved as u_obs_vals.
to_store = []
if method != 'point-cloud':
    to_store.append(('Saving u_interpolated', u_interpolated))
to_store += [('Saving q_min', q_min), ('Saving q_err', q_err)]

with firedrake.DumbCheckpoint(filename, mode=firedrake.FILE_UPDATE) as chk:
    for message, function in to_store:
        print(message)
        chk.store(function)

Check that the results were saved correctly

In [None]:
# Re-open the results checkpoint read-only and reload each stored Function to
# confirm the round trip reproduces the in-memory values. Explicit checks are
# used instead of `assert` so they still run under `python -O`.
with firedrake.DumbCheckpoint(filename, mode=firedrake.FILE_READ) as chk:
    if method != 'point-cloud':
        print('Loading u_interpolated')
        u_interpolated_2 = firedrake.Function(V)
        chk.load(u_interpolated_2, name=f'u_interpolated_{method}_{num_points}')
        if not np.allclose(u_interpolated_2.dat.data_ro[:], u_interpolated.dat.data_ro[:]):
            raise RuntimeError('Reloaded u_interpolated does not match the saved values')
    print('Loading q_min')
    q_min_2 = firedrake.Function(Q)
    chk.load(q_min_2, name=f'q_min_{method}_{num_points}')
    if not np.allclose(q_min_2.dat.data_ro[:], q_min.dat.data_ro[:]):
        raise RuntimeError('Reloaded q_min does not match the saved values')
    print('Loading q_err')
    q_err_2 = firedrake.Function(Q)
    chk.load(q_err_2, name=f'q_err_{method}_{num_points}')
    if not np.allclose(q_err_2.dat.data_ro[:], q_err.dat.data_ro[:]):
        raise RuntimeError('Reloaded q_err does not match the saved values')

print('Success!')