In [1]:
import numpy as np
import matplotlib.pyplot as plt 
from mpl_toolkits.axes_grid1 import make_axes_locatable
import torch
import jax
import jax.numpy as jnp
import time

## True solution 

We now simulate the solution for a simple mountain glacier model, taken from *Fundamentals of Glacier Dynamics* by CJ van der Veen. The system is modeled by a non-linear, highly diffusive PDE.

$$
 \frac{\partial H}{\partial t } = -\frac{\partial}{\partial x}\left(-D(x)\frac{\partial h}{\partial x}\right) + M\\
  D(x) = CH^{n+2}\left|\frac{\partial h}{\partial x}\right|^{n-1}\\
  C = \frac{2A}{n+2}(\rho g)^n\\
    H(x,t) = h(x,t) - b(x) \\ 
    H_l = 0, H_r > 0
$$

$$\frac{\partial{b}}{\partial{x}} = -0.1\\
M(x) = M_0 - x M_1 \:\text{(accumulation rate, essentially a source term)}\\
M_0 = 4.0 \:\text{m/yr}, \:M_1 = 0.0002 \:\text{yr}^{-1}\\
\rho = 920 \:\text{kg/m}^3\\
g = 9.8 \:\text{m/s}^2\\
A = 10^{-16} \: \text{Pa}^{-3} \text{a}^{-1}\\
n = 3\\
dx = 1.0 \:\text{km}, \:L = 30 \:\text{km}\\
dt = 1 \:\text{month}, \:T = 2000 \:\text{yr}$$

The true solution will be generated using a staggered grid finite volume method on a fine, uniform grid.



In [2]:
def true_solution(x, t, M):

    nx = x.shape[0] - 1
    nt = t.shape[0] - 1
    b = 1.0 - 0.0001*x
    
    A = 1e-16
    rho = 920.0
    g = 9.2
    n = 3

    C = 2*A/(n+2) * (rho*g)**n * (1e3)**n

    h = np.zeros((nx+1,nt+1))
    H = np.zeros((nx+1,nt+1))
    h[:,0] = b
    h[0,:] = b[0]
    h[-1,:] = b[-1]

    H[:,0] = h[:,0] - b
    H[0,:] = h[0,:] - b[0]
    H[-1,:] = h[-1,:] - b[-1]

    for i in range(1,len(t)):

        D = C *((H[1:,i-1]+H[:nx,i-1])/2.0)**(n+2) * ((h[1:,i-1] - h[:nx,i-1])/dx)**(n-1)
        phi = -D*(h[1:,i-1]-h[:nx,i-1])/dx

        h[1:nx,i] = h[1:nx,i-1] + M[1:nx]*dt - dt/dx * (phi[1:]-phi[:nx-1])
        h[1:nx,i] = (h[1:nx,i] < b[1:nx]) * b[1:nx] + (h[1:nx,i] >= b[1:nx]) * h[1:nx,i]
        
        H[:,i] = h[:,i] - b

    V = np.sum(H[:,-1])*dx
    return V

start_time = time.time()

L = 30.0
T = 5000.0
dx = 1.0
dt = 1./12.
nx = int(L/dx)
nt = int(T/dt)

x = np.linspace(0,L,nx+1)
t = np.linspace(0,T,nt+1)
M = 0.004 - 0.0002*x

V = true_solution(x, t, M)

end_time = time.time()
elapsed_time = end_time - start_time
print(f"Elapsed time: {elapsed_time} seconds")

V

Elapsed time: 1.5260663032531738 seconds


np.float64(12.364982412145057)

In [3]:
def true_solution_torch(x, t, M, dx, dt, nx, nt):
    b = 1.0 - 0.0001 * x
    
    A = 1e-16
    rho = 920.0
    g = 9.2 
    n = 3

    C = 2 * A / (n + 2) * (rho * g)**n * (1e3)**n

    h = torch.zeros((nx + 1, nt + 1), dtype=torch.float64)
    H = torch.zeros((nx + 1, nt + 1), dtype=torch.float64)
    h[:, 0] = b
    h[0, :] = b[0]
    h[-1, :] = b[-1]

    H[:, 0] = h[:, 0] - b
    H[0, :] = h[0, :] - b[0]
    H[-1, :] = h[-1, :] - b[-1]

    for i in range(1, len(t)):

        D = C * ((H[1:, i - 1] + H[:nx, i - 1]) / 2.0)**(n + 2) * ((h[1:, i - 1] - h[:nx, i - 1]) / dx)**(n - 1)
        phi = -D * (h[1:, i - 1] - h[:nx, i - 1]) / dx

        h[1:nx, i] = h[1:nx, i - 1] + M[1:nx] * dt - dt / dx * (phi[1:] - phi[:nx - 1])
        h[1:nx, i] = torch.where(h[1:nx, i] < b[1:nx], b[1:nx], h[1:nx, i])
        
        H[:, i] = h[:, i] - b

    V = torch.sum(H[:, -1]) * dx
    return V

start_time = time.time()

# Define parameters
L = 30.0
T = 5000.0
dx = 1.0
dt = 1. / 12.
nx = int(L / dx)
nt = int(T / dt)

x = torch.linspace(0, L, nx + 1, dtype=torch.float64)
t = torch.linspace(0, T, nt + 1, dtype=torch.float64)

M = torch.tensor(0.004 - 0.0002 * torch.linspace(0, L, nx + 1), dtype=torch.float64, requires_grad=True)

# Compute the solution
V = true_solution_torch(x, t, M, dx, dt, nx, nt)

# Compute the gradient of V with respect to M
V.backward()

# The gradient is stored in M.grad
dVdM = M.grad

end_time = time.time()
elapsed_time = end_time - start_time
print(f"Elapsed time: {elapsed_time} seconds")

dVdM

  M = torch.tensor(0.004 - 0.0002 * torch.linspace(0, L, nx + 1), dtype=torch.float64, requires_grad=True)


Elapsed time: 792.7237458229065 seconds


tensor([  0.0000,   7.7235,  13.8033,  19.6687,  25.4607,  31.2770,  37.2099,
         43.3705,  49.9198,  57.1388,  65.6429,  77.4524, 139.2594, 149.0766,
        153.9650, 156.6878, 158.0030, 158.2083, 157.4224, 155.6658, 152.8868,
        148.9615, 143.6755, 136.6793, 127.3935, 114.7918,  96.8271,  68.4260,
          0.0000,   0.0000,   0.0000], dtype=torch.float64)

In [4]:
def true_solution_jax(x, t, M, dx, dt, nx, nt):

    nx = x.shape[0] - 1
    nt = t.shape[0] - 1
    b = 1.0 - 0.0001*x
    
    A = 1e-16
    rho = 920.0
    g = 9.2 
    n = 3

    C = 2*A/(n+2) * (rho*g)**n * (1e3)**n

    h = jnp.zeros((nx+1,nt+1))
    H = jnp.zeros((nx+1,nt+1))
    h = h.at[:,0].set(b)
    h = h.at[0,:].set(b[0])
    h = h.at[-1,:].set(b[-1])

    H = H.at[:,0].set(h[:,0] - b)
    H = H.at[0,:].set(h[0,:] - b[0])
    H = H.at[-1,:].set(h[-1,:] - b[-1])

    for i in range(1,len(t)):

        D = C * ((H[1:,i-1] + H[:nx,i-1])/2.0)**(n+2) * ((h[1:,i-1] - h[:nx,i-1])/dx)**(n-1)
        phi = -D * (h[1:,i-1] - h[:nx,i-1]) / dx

        h = h.at[1:nx,i].set(h[1:nx,i-1] + M[1:nx] * dt - dt/dx * (phi[1:] - phi[:nx-1]))
        h = h.at[1:nx,i].set((h[1:nx,i] < b[1:nx]) * b[1:nx] + (h[1:nx,i] >= b[1:nx]) * h[1:nx,i])
        
        H = H.at[:,i].set(h[:,i] - b)

    V = jnp.sum(H[:,-1]) * dx
    return V

start_time = time.time()

# Set default precision to float64
jax.config.update("jax_enable_x64", True)

# Define parameters
L = 30.0
T = 5000.0
dx = 1.0
dt = 1./12.
nx = int(L/dx)
nt = int(T/dt)

x = jnp.linspace(0, L, nx+1)
t = jnp.linspace(0, T, nt+1)

M = 0.004 - 0.0002*jnp.linspace(0, L, nx+1)

# Use JAX's grad to compute the derivative of V with respect to M
# In 0-indexing, 2 corresponds as the third argument
dVdM = jax.grad(true_solution_jax, argnums=2)(x, t, M, dx, dt, nx, nt)

end_time = time.time()
elapsed_time = end_time - start_time
print(f"Elapsed time: {elapsed_time} seconds")

dVdM

Elapsed time: 10283.995254516602 seconds


Array([  0.        ,   7.7234855 ,  13.8033328 ,  19.66865227,
        25.46071801,  31.27695349,  37.2099016 ,  43.37050976,
        49.91984079,  57.13880789,  65.64284624,  77.45237268,
       139.2596656 , 149.07685504, 153.96525772, 156.68808935,
       158.00321266, 158.20855233, 157.42261785, 155.66608018,
       152.88707221, 148.96174071, 143.67571334, 136.67955242,
       127.39369517, 114.79204539,  96.82724429,  68.42610895,
         0.        ,   0.        ,   0.        ], dtype=float64)