### Estimate the initial conditions using PINNs (inverse problem)
The goal of the inverse problem is to estimate the initial pressure field $p_\text{ini}(x,y) = p(x,y,0)$ of a pressure field $p(x,y,t)$ that satisfies the wave equation,
$$\frac{\partial^2 p}{\partial x^2} + \frac{\partial^2 p}{\partial y^2} - \frac{1}{c^2} \frac{\partial^2 p}{\partial t^2} = 0.$$
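In a PINN, the residual of this equation is evaluated by automatic differentiation of the network output with respect to its inputs. The sketch below is a hypothetical stand-in for the `pde_residual` helper the notebook imports later (its real signature may differ). It accepts any callable that maps an `(N, 3)` tensor of $(x, y, t)$ points to pressures. As a check, we apply it to the exact standing wave $p = \sin(\pi x)\sin(\pi y)\cos(\sqrt{2}\,\pi c t)$, for which the residual should vanish.

```python
import math
import torch

def wave_residual(model, xyt, c=1.0):
    """Hypothetical helper: residual p_xx + p_yy - p_tt / c**2 at points xyt of shape (N, 3)."""
    xyt = xyt.clone().requires_grad_(True)
    p = model(xyt)                      # shape (N, 1)
    ones = torch.ones_like(p)
    # first derivatives with respect to (x, y, t), kept in the graph for a second pass
    grad = torch.autograd.grad(p, xyt, ones, create_graph=True)[0]
    # pure second derivatives: differentiate the i-th first derivative w.r.t. input i
    second = [
        torch.autograd.grad(grad[:, i:i + 1], xyt, ones, create_graph=True)[0][:, i:i + 1]
        for i in range(3)
    ]
    p_xx, p_yy, p_tt = second
    return p_xx + p_yy - p_tt / c**2

def standing_wave(xyt, c=1.0):
    """Exact solution of the wave equation, used only to test the residual."""
    x, y, t = xyt[:, 0:1], xyt[:, 1:2], xyt[:, 2:3]
    return torch.sin(math.pi * x) * torch.sin(math.pi * y) * torch.cos(math.sqrt(2) * math.pi * c * t)

torch.manual_seed(0)
pts = torch.rand(100, 3, dtype=torch.float64)
res = wave_residual(standing_wave, pts, c=1.0)
print(res.abs().max().item())  # close to zero (round-off only)
```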

The inference of $p_\text{ini}(x,y)$ will be based on:
- Observed data
- Physics, via the underlying PDE
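Both sources of information enter training as terms of a single loss. The following is a minimal sketch, assuming a mean-squared data misfit at the observation points plus a mean-squared PDE residual at collocation points, weighted by a scalar `lam`. The names `pinn_loss`, `wave_residual` and `lam` are hypothetical, and the notebook's own helpers may weight the terms differently:

$$\mathcal{L}(\theta) = \frac{1}{N_d}\sum_{i=1}^{N_d}\big(p_\theta(\mathbf{r}_i) - p_i\big)^2 + \lambda\,\frac{1}{N_c}\sum_{j=1}^{N_c}\Big(\frac{\partial^2 p_\theta}{\partial x^2} + \frac{\partial^2 p_\theta}{\partial y^2} - \frac{1}{c^2}\frac{\partial^2 p_\theta}{\partial t^2}\Big)^2\Big|_{\mathbf{r}_j}$$

```python
import torch

def wave_residual(model, xyt, c=1.0):
    # residual of the wave equation via automatic differentiation
    xyt = xyt.clone().requires_grad_(True)
    p = model(xyt)
    ones = torch.ones_like(p)
    g = torch.autograd.grad(p, xyt, ones, create_graph=True)[0]
    d2 = [torch.autograd.grad(g[:, i:i + 1], xyt, ones, create_graph=True)[0][:, i:i + 1]
          for i in range(3)]
    return d2[0] + d2[1] - d2[2] / c**2

def pinn_loss(model, r_data, p_data, r_colloc, lam=1.0, c=1.0):
    """Hypothetical composite loss: data misfit + lam * PDE misfit."""
    loss_data = torch.mean((model(r_data) - p_data) ** 2)            # fit the observations
    loss_pde = torch.mean(wave_residual(model, r_colloc, c) ** 2)    # satisfy the physics
    return loss_data + lam * loss_pde, loss_data, loss_pde

# usage with a small toy network and random points
torch.manual_seed(0)
model = torch.nn.Sequential(torch.nn.Linear(3, 32), torch.nn.Tanh(), torch.nn.Linear(32, 1))
r_data, p_data = torch.rand(50, 3), torch.zeros(50, 1)
r_colloc = torch.rand(200, 3)
loss, loss_data, loss_pde = pinn_loss(model, r_data, p_data, r_colloc, lam=0.1)
loss.backward()  # gradients with respect to the network weights
```

The weight $\lambda$ balances the two terms; if it is too small the network may simply interpolate the data, and if it is too large the data are effectively ignored.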

We will use physics-informed neural networks (PINNs) to estimate $p_\text{ini}(x,y)$.
We start by importing the required packages, including the neural network class `FCN` and a few auxiliary functions from `PINNs_util`.
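The implementation of `FCN` in `PINNs_util.PINNs_aux` is not shown here. As an assumption about what such a class typically looks like, a fully connected network mapping $(x, y, t)$ to $p$ with `tanh` activations could be sketched as follows (the class name `SketchFCN` and its arguments are hypothetical):

```python
import torch
import torch.nn as nn

class SketchFCN(nn.Module):
    """Hypothetical fully connected network: (x, y, t) -> p."""

    def __init__(self, n_in=3, n_out=1, n_hidden=64, n_layers=4):
        super().__init__()
        layers = [nn.Linear(n_in, n_hidden), nn.Tanh()]
        for _ in range(n_layers - 1):
            layers += [nn.Linear(n_hidden, n_hidden), nn.Tanh()]
        layers.append(nn.Linear(n_hidden, n_out))  # linear output: pressure can take any sign
        self.net = nn.Sequential(*layers)

    def forward(self, xyt):
        return self.net(xyt)

model = SketchFCN()
p = model(torch.rand(10, 3))
print(p.shape)  # torch.Size([10, 1])
```

A smooth activation such as `tanh` matters here because the PDE residual needs second derivatives of the network; with ReLU these vanish almost everywhere.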

In [None]:
import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm
from PINNs_util.PINNs_fdiff import solver
from PINNs_util.PINNs_aux import FCN,\
                        xyt_tensor,\
                        pde_residual,\
                        update_lambda,\
                        absorbing_boundary,\
                        rand_colloc,\
                        rand_bound
from PINNs_util.PINNs_plots import plot_train_log_bound,\
                        plot_train_log,\
                        plot_field,\
                        plot_data,\
                        plot_inital_estimation,\
                        plot_estimation

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

### Problem setup
We define a 2+1D domain (two spatial dimensions plus time). The spatial domain is a square of side length $L=5$ km, and the time domain has a duration $T=1.5$ s. The speed of sound is fixed to a constant value, $c=3$ km/s. The initial condition is the sum of two anisotropic Gaussian pulses placed inside the domain, one elongated along $y$ and the other along $x$, with standard deviations `gpulse_std_x` and `gpulse_std_y` and centers given by `r_source` (in scaled units).

It is important to scale the dimensions of the domain so that none of the terms in the PDE is very large or very small. For details on this see *section 3. Non-dimensionalization* in [Wang 2023] https://doi.org/10.48550/arXiv.2308.08468 and the book *Scaling of Differential Equations* by Langtangen and Pedersen. 
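The setup cell scales lengths by $L$ and time by $L/c$. Writing $x' = x/L$, $y' = y/L$ and $t' = tc/L$, the chain rule shows why this leaves a wave equation with unit speed:

$$\frac{\partial^2 p}{\partial x^2} = \frac{1}{L^2}\frac{\partial^2 p}{\partial x'^2}, \qquad \frac{\partial^2 p}{\partial y^2} = \frac{1}{L^2}\frac{\partial^2 p}{\partial y'^2}, \qquad \frac{\partial^2 p}{\partial t^2} = \frac{c^2}{L^2}\frac{\partial^2 p}{\partial t'^2},$$

$$\frac{1}{L^2}\left(\frac{\partial^2 p}{\partial x'^2} + \frac{\partial^2 p}{\partial y'^2} - \frac{\partial^2 p}{\partial t'^2}\right) = 0 .$$

The scaled domain is $x', y' \in [0, 1]$ and $t' \in [0, Tc/L]$, with $Tc/L = 1.5 \cdot 3 / 5 = 0.9$.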

To simulate the wave propagation, we use a Python implementation of the finite difference method based on Hans Petter Langtangen's book "Finite Difference Computing with PDEs" https://hplgit.github.io/fdm-book/doc/pub/book/sphinx/index.html. We compute and plot the reference solution.
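The code of `solver` in `PINNs_util.PINNs_fdiff` is not reproduced here. A rough sketch of the standard explicit scheme from Langtangen's book (second-order centered differences in space and time) is given below. It assumes $p=0$ on the boundary, zero initial velocity and no source term; these are assumptions, and the wrapped solver may handle them differently. In Langtangen's `wave2D_u0` code a non-positive `dt` is interpreted as a safety factor times the stability limit, which is presumably why `dt = -1` is passed to `solver`; the hypothetical `safety` argument mimics this.

```python
import numpy as np

def leapfrog_wave2d(I, c, Lx, Ly, Nx, Ny, T, safety=1.0):
    """Hypothetical explicit solver for p_tt = c^2 (p_xx + p_yy) with p = 0 on the
    boundary, zero initial velocity and no source. p has shape (Nx+1, Ny+1, Nt+1)."""
    x = np.linspace(0, Lx, Nx + 1)
    y = np.linspace(0, Ly, Ny + 1)
    dx, dy = x[1] - x[0], y[1] - y[0]
    # largest stable time step (CFL condition), scaled by a safety factor
    dt = safety / (c * np.sqrt(1 / dx**2 + 1 / dy**2))
    Nt = int(round(T / dt))
    t = np.linspace(0, Nt * dt, Nt + 1)
    Cx2, Cy2 = (c * dt / dx) ** 2, (c * dt / dy) ** 2

    def spatial_term(u):
        # Cx2 * second x-difference + Cy2 * second y-difference, interior points only
        return (Cx2 * (u[2:, 1:-1] - 2 * u[1:-1, 1:-1] + u[:-2, 1:-1])
                + Cy2 * (u[1:-1, 2:] - 2 * u[1:-1, 1:-1] + u[1:-1, :-2]))

    X, Y = np.meshgrid(x, y, indexing="ij")
    p = np.zeros((Nx + 1, Ny + 1, Nt + 1))  # boundary values stay zero for n >= 1
    p[:, :, 0] = I(X, Y)
    # first step uses the zero initial velocity: p^1 = p^0 + 0.5 * spatial term
    p[1:-1, 1:-1, 1] = p[1:-1, 1:-1, 0] + 0.5 * spatial_term(p[:, :, 0])
    for n in range(1, Nt):
        # leapfrog update: p^{n+1} = 2 p^n - p^{n-1} + spatial term
        p[1:-1, 1:-1, n + 1] = (2 * p[1:-1, 1:-1, n] - p[1:-1, 1:-1, n - 1]
                                + spatial_term(p[:, :, n]))
    return p, x, y, t, dt

def gaussian(x, y):
    return np.exp(-0.5 * (((x - 0.5) / 0.05) ** 2 + ((y - 0.5) / 0.05) ** 2))

# scaled setup: unit square, unit wave speed, T' = 0.9
p, x, y, t, dt = leapfrog_wave2d(gaussian, c=1.0, Lx=1.0, Ly=1.0, Nx=30, Ny=30, T=0.9)
print(p.shape)  # (31, 31, 39)
```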

In [None]:
# define domain
L = 5
T = 1.5
c = 3

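# scaling: lengths by L, time by L/c (the scaled wave speed is 1)
T = T*c/L
L = L/L
c = 1

# simulation grid
Lx, Ly = L, L
Nx, Ny = 30, 30
dt = -1  # non-positive dt: let the solver choose the time step (assuming Langtangen's stability-limit convention)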

# initial condition: sum of two anisotropic Gaussian pulses (scaled units)
def I(x, y):
    gpulse_std_x = 5e-2
    gpulse_std_y = 10e-2
    r_source = np.array([0.3, 0.3])
    I = np.exp(-0.5*( ((x-r_source[0])/gpulse_std_x)**2 +\
                      ((y-r_source[1])/gpulse_std_y)**2 ))
    r_source = np.array([0.5, 0.6])
    I = I + np.exp(-0.5*( ((x-r_source[0])/gpulse_std_y)**2 +\
                      ((y-r_source[1])/gpulse_std_x)**2 ))
    return I

# solve
p_ref, x, y, t, dt = solver(I, 0, 0, c, Lx, Ly, Nx, Ny, dt, T)
p_ref, x, y, t = p_ref.astype(np.float32), x.astype(np.float32), y.astype(np.float32), t.astype(np.float32)

# flatten the grid into (x, y) pairs and build the (x, y, t) input tensor for the network
xx, yy = np.meshgrid(x, y)
xy = np.column_stack((np.reshape(xx,(-1,1)), np.reshape(yy,(-1,1))))
r_ref = xyt_tensor(xy, t, device)
n_T = t.shape[0]
n_L = x.shape[0]

# plot reference field (it might take a few seconds)
ani = plot_field(p_ref, L, 'reference pressure')
ani

### Generate data
We now define the observed data used to estimate the field. Our data will be the pressure field recorded at a row of grid points aligned with the x-axis (black triangles in the left figure). The observed data, called `p_data`, are shown in the right figure.
In addition, we assume that we know the position, propagation direction, and shape of the source. This information is encoded in two early-time snapshots (see [Rasht-Behesht 2022] https://doi.org/10.1029/2021JB023120), which we call `p_ini`. With this information we are ready to train our PINN.
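Both the receiver line and the snapshots are slices of the reference field, and the flattened snapshots must line up row by row with the flattened grid coordinates `xy`. The small NumPy check below uses a toy field with the `[y-index, x-index, time]` layout, which is the layout the pairing of `p_ref[i_data,:,:]` with `xx[i_data,:]` implies. This layout is an assumption; if the solver returned `[x-index, y-index, time]`, the field would need transposing. The check illustrates why `reshape(-1, 2)` gives matching rows:

```python
import numpy as np

# toy grid and a toy field p[i, j, k] = f(x[j], y[i], t[k])  (assumed [y, x, t] layout)
x = np.linspace(0.0, 1.0, 4)
y = np.linspace(0.0, 1.0, 3)
t = np.linspace(0.0, 0.5, 5)
xx, yy = np.meshgrid(x, y)                 # shape (len(y), len(x)); xx[i, j] = x[j], yy[i, j] = y[i]
p = np.sin(xx)[:, :, None] + np.cos(yy)[:, :, None] * t[None, None, :]

# receivers: one grid row, i.e. fixed y = y[i_data], all x, all times
i_data = 1
p_data = p[i_data, :, :]                   # shape (len(x), len(t))

# flattened coordinates and the first two snapshots, both in row-major order
xy = np.column_stack((xx.reshape(-1), yy.reshape(-1)))
p_ini = p[:, :, 0:2].reshape(-1, 2)        # row k <-> point xy[k]; columns = t[0], t[1]

# each row of p_ini equals the field evaluated at the matching xy row
expected = np.column_stack([np.sin(xy[:, 0]) + np.cos(xy[:, 1]) * tk for tk in t[0:2]])
print(np.allclose(p_ini, expected))        # True
```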

In [None]:
i_data = 20  # grid row index of the receiver line
x_data = xx[i_data,:]
y_data = yy[i_data,:]
xy_data = np.column_stack((x_data, y_data))
r_data = xyt_tensor(xy_data, t, device)
p_data = torch.tensor(p_ref[i_data,:,:])

# Plot data
plot_data(p_ref, p_data, x_data, y_data, L, T)

p_data = p_data.to(device)

# Initial condition
t_ini = t[0:2]
r_ini = xyt_tensor(xy, t_ini, device)
p_ini = p_ref[:,:,0:2].reshape(-1,2)
p_ini = torch.tensor(p_ini, device=device)