# A Note on this Notebook
This notebook analyses the data produced by running the `solve.sh` script in this directory.
It plots the fields and error norms for whichever parts of that output are available.

# Problem Description

We want to find out how the solution of our inverse problem converges as we increase the number of points for both the new and traditional methods of data interpolation.

If we have what is known as **"posterior consistency"** then we expect that the error in our solution, when compared to the true solution, will always decrease as we increase the number of points we are assimilating.

## Posterior Consistency

From a Bayesian point of view, the regularisation we choose and the weighting we give it encode information about our assumed prior probability distribution of $q$ before we start assimilating data (adding observations).
Take, for example, the regularisation used in this problem

$$
\alpha^2\int_\Omega|\nabla q|^2dx
$$

which asserts a prior that the solution $q$ which minimises $J$ should be smooth and gives a weighting $\alpha$ to the assertion.
If we have posterior consistency, the contribution of increasing numbers of measurements $u_{obs}$ should increase the weighting of our data relative to our prior and we should converge towards the true solution.
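
A sketch of this Bayesian reading follows, under assumptions that are not stated in this notebook: independent Gaussian noise of standard deviation $\sigma$ on $N$ point measurements at locations $x_i$, and a Gaussian prior on $\nabla q$ with a scale $\gamma$ (a symbol introduced here only for illustration). Here $u$ is the solution of the forward problem for a given $q$.

$$
p(q \mid u_{obs}) \propto p(u_{obs} \mid q) \, p(q)
$$

$$
-\log p(q \mid u_{obs}) = \frac{1}{2\sigma^2}\sum_{i=1}^{N} \left( u_{obs}^i - u(x_i) \right)^2 + \frac{1}{2\gamma^2}\int_\Omega|\nabla q|^2dx + \text{const}
$$

Multiplying through by $2\sigma^2$ gives a functional of the same shape as the one minimised here, with $\alpha^2$ playing the role of $\sigma^2/\gamma^2$. Under these assumptions the data term is a sum of $N$ contributions while the prior term stays fixed, which is why more measurements should shift the weighting towards the data.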

## Hypothesis

Our two methods minimise two different functionals. 
The first minimises $J$

$$
J[u, q] = \underbrace{
                        \int_{\Omega_v} ( u_{\text{obs}} - \mathcal{I}_{\text{P0DG}(\Omega_v)}(u) )^2 dx
                        }_{J_{\text{model-data misfit}}^{\text{point}}} + 
            \underbrace{
                        \alpha^2\int_\Omega|\nabla q|^2 dx
                        }_{J_{\text{regularisation}}}
$$

whilst the second minimises $J'$

$$
J'[u, q] = \underbrace{
                        \int_{\Omega} ( u_{\text{interpolated}} - u )^2 dx
                        }_{J_{\text{model-data misfit}}^{\text{field}}} + 
            \underbrace{
                        \alpha^2\int_\Omega|\nabla q|^2 dx
                        }_{J_{\text{regularisation}}}
$$

**where $\alpha$ is an appropriate value found with an L-curve analysis.**

As set up here, increasing the number of points to assimilate increases the size of the misfit term in $J$ relative to the regularisation term, so we expect to converge to $q_\text{true}$ as the number of measurements increases.

As we increase the number of measurements in $J'$ we hope that our calculated $u_\text{interpolated}$ approaches $u$ (thereby minimising the misfit). There is, however, no mechanism to cause the misfit term to increase relative to the regularisation term.

We therefore predict that minimising $J$ will display posterior consistency and that minimising the various $J'$ for each $u_\text{interpolated}$ will not.
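
The scaling argument above can be illustrated with a toy example. This is a minimal one-dimensional NumPy sketch, not the Firedrake problem solved by `solve.sh`: the function `misfit_terms`, the choice $u(x) = \sin(\pi x)$, the noise level and the use of linear interpolation are all assumptions made for illustration. Both misfits are evaluated at the true $u$ rather than at a minimiser.

```python
import numpy as np


def misfit_terms(num_points, sigma=0.1, seed=0):
    """Return (point misfit, field misfit) for a toy 1D problem on [0, 1]."""
    rng = np.random.default_rng(seed)

    def u(x):
        # Hypothetical "true" field, chosen only for illustration
        return np.sin(np.pi * x)

    # Noisy point observations at random, sorted locations
    xs = np.sort(rng.uniform(0.0, 1.0, num_points))
    u_obs = u(xs) + sigma * rng.normal(size=num_points)

    # Point misfit: a sum over the observation locations, as in J
    point = np.sum((u_obs - u(xs)) ** 2)

    # Field misfit: interpolate the observations, then integrate over the
    # whole domain, as in J'. np.interp holds the end values constant
    # outside the observed range.
    grid = np.linspace(0.0, 1.0, 2001)
    diff2 = (np.interp(grid, xs, u_obs) - u(grid)) ** 2
    field = np.sum(0.5 * (diff2[1:] + diff2[:-1]) * np.diff(grid))  # trapezoid rule
    return point, field


for n in (16, 128, 1024):
    point, field = misfit_terms(n)
    print(f"N={n:5d}  point misfit={point:8.3f}  field misfit={field:8.5f}")
```

In this toy setting the point misfit grows roughly in proportion to the number of points, whereas the field misfit stays on the order of the noise variance, so only the former gains weight against a fixed regularisation term.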

## Hypothesis Amendment! A note on finite element method error
Note that our solutions all exist in finite element spaces, which usually only approximate the true solution, with an error that (hopefully) decreases as the mesh density and the order of the solution space increase.
Since I am comparing to a solution $u_\text{true}$ in CG2 space I expect, at best, that we will converge to $u_\text{true}$ when we have, on average, enough points per cell to fully specify the Lagrange polynomials in that cell.
Were we in CG1 this would be 3 points per cell (for CG2 on triangles it is 6, the number of nodes of a degree-2 Lagrange element) to give convergence if those measurements had no noise.
Since our measurements are noisy I do not expect exact convergence, but I do anticipate some slowing in convergence.
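
As a rough guide to the point counts involved, the node counts can be worked out directly. This is a small sketch that assumes the default triangular `UnitSquareMesh(32, 32)` used below, in which each of the 32 × 32 squares is split into two triangles; the helper names are hypothetical.

```python
def lagrange_nodes_per_triangle(degree):
    """Number of nodes of a degree-`degree` Lagrange element on a triangle."""
    return (degree + 1) * (degree + 2) // 2


def cg_dofs_on_unit_square(n, degree):
    """Global CG degrees of freedom on an n x n triangulated unit square.

    The nodes form a regular (degree * n + 1) x (degree * n + 1) lattice.
    """
    return (degree * n + 1) ** 2


n = 32
num_cells = 2 * n * n  # two triangles per square
for degree in (1, 2):
    print(
        f"CG{degree}: {lagrange_nodes_per_triangle(degree)} nodes per cell, "
        f"{cg_dofs_on_unit_square(n, degree)} global DOFs on {num_cells} cells"
    )
```

Because neighbouring cells share nodes, the number of global degrees of freedom is much smaller than the per-cell node count multiplied by the number of cells.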

# Setup

In [None]:
import matplotlib.pyplot as plt
import firedrake

import numpy as np
from numpy import pi as π
from numpy import random

from mpl_toolkits.axes_grid1 import make_axes_locatable

import os, sys

import h5py

currentdir = os.path.dirname(os.path.realpath('__file__'))

In [None]:
mesh = firedrake.UnitSquareMesh(32, 32)

# Solution Space
V = firedrake.FunctionSpace(mesh, family='CG', degree=2)

# q (Control) Space
Q = firedrake.FunctionSpace(mesh, family='CG', degree=2)

## Load $u_\text{true}$ and $q_\text{true}$

In [None]:
u_true = firedrake.Function(V)
q_true = firedrake.Function(Q)

filename = os.path.join(currentdir, 'true-fields')
with firedrake.DumbCheckpoint(filename, mode=firedrake.FILE_READ) as chk:
    chk.load(q_true, name='q_true')
    chk.load(u_true, name='u_true')

# Plot Results

In [None]:
num_points_set = [2**k for k in range(19)]  # 1, 2, 4, ..., 262144
methods = ['point-cloud', 'nearest', 'linear', 'clough-tocher', 'gaussian']
l2norms = {method: [] for method in methods}

## Plot Fields and Save L2 Error Norms

In [None]:
for num_points in num_points_set:

    # Test Loading and adjust plots as necessary
    filename = os.path.join(currentdir, 'observed-data.h5')
    try:
        with h5py.File(filename, 'r') as file:
            xs = file[f"xs_{num_points}"][:]
            u_obs_vals = file[f"u_obs_vals_{num_points}"][:]
            σ = firedrake.Constant(file[f"sigma_{num_points}"])
    except Exception:
        # Can't load so move on
        continue
    methods_available = []        
    for method in methods:
        try:
            filename = os.path.join(currentdir, 'q-estimations')
            with firedrake.DumbCheckpoint(filename, mode=firedrake.FILE_READ) as chk:
                if method != 'point-cloud':
                    u_interpolated = firedrake.Function(V)
                    chk.load(u_interpolated, name=f'u_interpolated_{method}_{num_points}')
                q_min = firedrake.Function(Q)
                chk.load(q_min, name=f'q_min_{method}_{num_points}')
                q_err = firedrake.Function(Q)
                chk.load(q_err, name=f'q_err_{method}_{num_points}')
                methods_available.append(method)
        except Exception:
            pass
    if not methods_available:
        # Nothing to plot so move on
        continue

    # Setup Plot
    ukw = {'vmin': 0.0, 'vmax': +0.2}
    kw = {'vmin': -4, 'vmax': +4, 'shading': 'gouraud'}
    ucbarticks = [0, 0.1, 0.2]
    qcbarticks = [-4, 0, 4]
    suptitle_fontsize = 25
    title_fontsize = 22
    title_padding = 10
    text_fontsize = 25
    default_fontsize = 15
    plt.rc('font', size=default_fontsize)
    cbarlabel = 'Magnitude (arb.)'
    xlabel = 'x'
    ylabel = 'y'
    # Change sharex and sharey to true to remove axes from inner plots
    fig, axes = plt.subplots(ncols=3, nrows=1+len(methods_available), sharex=False, sharey=False, figsize=(15,22.5), dpi=300)
    plt.suptitle('Estimating Log-Conductivity $q$ \n    where $k = k_0e^q$ and $-\\nabla \\cdot k \\nabla u = f$ for known $f$\n', fontsize=suptitle_fontsize)
    for ax in axes.ravel():
        ax.set_aspect('equal')
        ax.set_xticks([0, 0.5, 1])
        ax.set_yticks([0, 0.5, 1])
        # Comment out below if labelling only outer bits of plots
        ax.set(xlabel=xlabel, ylabel=ylabel)

    # Column 0
    axes[0, 0].set_title('$u_{true}$', fontsize=title_fontsize, pad=title_padding)
    colors = firedrake.tripcolor(u_true, axes=axes[0, 0], shading='gouraud', **ukw)
    cax = make_axes_locatable(axes[0, 0]).append_axes("right", size="5%", pad=0.05)
    fig.colorbar(colors, cax=cax, label=cbarlabel, ticks=ucbarticks)

    # Column 1
    axes[0, 1].set_title('$q_{true}$', fontsize=title_fontsize, pad=title_padding)
    colors = firedrake.tripcolor(q_true, axes=axes[0, 1], **kw)
    cax = make_axes_locatable(axes[0, 1]).append_axes("right", size="5%", pad=0.05)
    fig.colorbar(colors, cax=cax, label=cbarlabel, ticks=qcbarticks)

    # Column 2
    axes[0, 2].set_title('$q_{true}-q_{true}$', fontsize=title_fontsize, pad=title_padding)
    zero_func = firedrake.Function(Q).assign(q_true-q_true)
    axes[0, 2].text(0.5, 0.35, f'$L^2$ Norm\n{firedrake.norm(zero_func, "L2"):.2f}', ha='center', fontsize=text_fontsize)
    colors = firedrake.tripcolor(zero_func, axes=axes[0, 2], **kw)
    cax = make_axes_locatable(axes[0, 2]).append_axes("right", size="5%", pad=0.05)
    fig.colorbar(colors, cax=cax, label=cbarlabel, ticks=qcbarticks)

    for method_i, method in enumerate(methods_available):

        # Load fields
        filename = os.path.join(currentdir, 'q-estimations')
        with firedrake.DumbCheckpoint(filename, mode=firedrake.FILE_READ) as chk:
            if method != 'point-cloud':
                u_interpolated = firedrake.Function(V)
                chk.load(u_interpolated, name=f'u_interpolated_{method}_{num_points}')
            q_min = firedrake.Function(Q)
            chk.load(q_min, name=f'q_min_{method}_{num_points}')
            q_err = firedrake.Function(Q)
            chk.load(q_err, name=f'q_err_{method}_{num_points}')

        # Recalculate l2 norm error and save in l2norms
        l2norm = firedrake.norm(q_err, "L2")
        l2norms[method].append((num_points, l2norm))

        row = method_i+1

        # column 0
        if method == 'point-cloud':
            axes[row, 0].set_title('Sampled Noisy $u_{obs}$', fontsize=title_fontsize, pad=title_padding)
            colors = axes[row, 0].scatter(xs[:, 0], xs[:, 1], c=u_obs_vals, vmin=0.0, vmax=0.2)
        else:
            axes[row, 0].set_title(f'$u_{{interpolated}}^{{{method}}}$', fontsize=title_fontsize, pad=title_padding)
            colors = firedrake.tripcolor(u_interpolated, axes=axes[row, 0], shading='gouraud', **ukw)
        cax = make_axes_locatable(axes[row, 0]).append_axes("right", size="5%", pad=0.05)
        fig.colorbar(colors, cax=cax, label=cbarlabel, ticks=ucbarticks)

        # column 1
        if method == 'point-cloud':
            axes[row, 1].set_title('$q_{est}$ from $u_{obs}$', fontsize=title_fontsize, pad=title_padding)
        else:
            axes[row, 1].set_title(f'$q_{{est}}^{{{method}}}$', fontsize=title_fontsize, pad=title_padding)
        colors = firedrake.tripcolor(q_min, axes=axes[row, 1], **kw)
        cax = make_axes_locatable(axes[row, 1]).append_axes("right", size="5%", pad=0.05)
        fig.colorbar(colors, cax=cax, label=cbarlabel, ticks=qcbarticks)

        # column 2
        if method == 'point-cloud':
            axes[row, 2].set_title('$q_{est}-q_{true}$', fontsize=title_fontsize, pad=title_padding)
        else:
            axes[row, 2].set_title(f'$q_{{est}}^{{{method}}}-q_{{true}}$', fontsize=title_fontsize, pad=title_padding)
        axes[row, 2].text(0.5, 0.35, f'$L^2$ Norm\n{l2norm:.2f}', ha='center', fontsize=text_fontsize)
        colors = firedrake.tripcolor(q_err, axes=axes[row, 2], **kw)
        cax = make_axes_locatable(axes[row, 2]).append_axes("right", size="5%", pad=0.05)
        fig.colorbar(colors, cax=cax, label=cbarlabel, ticks=qcbarticks)
        
#     # label all first columns
#     [ax.set_ylabel(ylabel) for ax in axes[:, 0]]
#     # label all bottom rows
#     [ax.set_xlabel(xlabel) for ax in axes[row, :]]
        
    # Set spacing
    fig.tight_layout()
        
    # Save figure
    plt.savefig(f'posterior-consistency-{num_points}-pts.png', bbox_inches='tight')
    if num_points in (256, 32768):
        # Only save pdfs of what we plan to publish since it takes a while
        plt.savefig(f'posterior-consistency-{num_points}-pts.pdf', bbox_inches='tight')


## Plot L2 Errors

In [None]:
fig, ax = plt.subplots(dpi=300)
ax.set_xscale('log', base=2)
cmap = plt.get_cmap("tab10")
lmap = ['solid', 'dashed', 'dashdot', 'dotted', 'solid', 'dashed', 'dashdot', 'dotted']
mmap = ['o', 'v', '^', '<', '>', 's', 'P', '*']
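for i, method in enumerate(methods):
    if not l2norms[method]:
        # No results were loaded for this method, so there is nothing to plot
        continue
    arr = np.asarray(l2norms[method])
    method_num_points = arr[:, 0]
    method_l2norms = arr[:, 1]
    ax.plot(method_num_points, method_l2norms, marker=mmap[i], color=cmap(i), linestyle=lmap[i], label=method)
# Use the line labels so the legend stays correct when a method is skipped
ax.legend(title='Method', fontsize=10)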
ax.set_xlabel('Number of Points')
ax.set_ylabel('$||q_{est}-q_{true}||_{L^2}$')
ax.set_title('Estimating Log-Conductivity $q$ \n where $k = k_0e^q$ and $-\\nabla \\cdot k \\nabla u = f$ for known $f$')
ax.axvline(256, linestyle='--', color='black')
ax.text(1, 0.35, 'L-Curves\nProduced\nFor $N=2^8$')
ax.arrow(32, 0.75, 224, 0, head_width=0.1, head_length=55, length_includes_head=True, color='black')
plt.savefig('l2norms.png', bbox_inches='tight')
plt.savefig('l2norms.pdf', bbox_inches='tight')
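
To put a number on how quickly each curve falls, one option is to fit a straight line to the errors on log-log axes. The following is a minimal sketch rather than part of the original analysis: `convergence_rate` is a hypothetical helper, and the data used to exercise it are synthetic, generated from a known power law, not results from this notebook.

```python
import numpy as np


def convergence_rate(num_points, errors):
    """Least-squares slope of log2(error) against log2(num_points)."""
    slope, _intercept = np.polyfit(np.log2(num_points), np.log2(errors), 1)
    return slope


# Synthetic check only: errors that decay exactly like N**-0.5
ns = np.array([2.0**k for k in range(4, 12)])
synthetic_errors = 2.0 * ns**-0.5
rate = convergence_rate(ns, synthetic_errors)
print(f"fitted rate: {rate:.3f}")
```

With the `l2norms` dictionary built above, the same helper could be applied per method by converting `l2norms[method]` to an array and passing its first and second columns. A slope near zero at large point counts would indicate that the error has stopped decreasing.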