<a href="https://colab.research.google.com/github/YaoGroup/DIFFICE_jax/blob/main/examples/colab/train_xpinns_iso.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Isotropic viscosity inversion of large ice shelves via extended-PINNs (XPINNs)

Running this notebook at a practical speed requires **GPU hardware**.

To choose GPU hardware, click "Edit" in the toolbar above and select "Notebook settings".

In the pop-up window, under "Hardware accelerator", select "A100 GPU" (a Pro account or purchased compute units are required).

Inverting large ice shelves via XPINNs is computationally demanding. We **highly recommend** choosing an **A100** GPU for the training; otherwise, you might encounter out-of-memory (OOM) issues.

To ensure the code runs correctly in Colab, install the specific version of JAX below.

(Note that training is accelerated only when the GPU version of JAX is used together with the GPU hardware selected above.)
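
A quick way to confirm that JAX actually sees an accelerator is to query its default backend. The sketch below uses only standard JAX calls; the printed warning text is illustrative and not part of DIFFICE_jax.

```python
import jax

# Ask JAX which backend it selected and which devices it can see.
backend = jax.default_backend()
devices = jax.devices()
print(f"JAX backend: {backend}")
print(f"Devices: {devices}")

if backend == "cpu":
    # Training still runs on CPU, but far more slowly.
    print("No accelerator detected; check 'Notebook settings' before training.")
```

If the backend is reported as `cpu` after a GPU runtime was selected, the installed JAX build is most likely the CPU-only version.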

Install the [DIFFICE_jax](https://github.com/YaoGroup/DIFFICE_jax) package and download the example data from its GitHub repository

In [None]:
# install the DIFFICE_jax
!pip install DIFFICE_jax

# download the data from the GitHub repository
!wget https://github.com/YaoGroup/DIFFICE_jax/raw/main/examples/real_data/data_xpinns_RnFlch.mat


Import the required JAX libraries and the DIFFICE_jax functions

In [3]:
import jax.numpy as jnp
import numpy as np
from jax import random, lax
from jax.tree_util import tree_map
from scipy.io import loadmat
import time

from diffice_jax import normdata_xpinn, dsample_xpinn
from diffice_jax import vectgrad, ssa_iso, dbc_iso
from diffice_jax import init_xpinn, solu_xpinn
from diffice_jax import loss_iso_xpinn
from diffice_jax import predict_xpinn
from diffice_jax import adam_opt, lbfgs_opt


# Setting hyperparameters

Hyperparameters used for the training. Users are free to modify their values and check their influence on the training results.


In [4]:
# select a random seed
seed = 2134
key = random.PRNGKey(seed)
np.random.seed(seed)

# create the subkeys
keys = random.split(key, 4)

# select the size of neural network
n_hl = 5
n_unit = 30
# set the weights for 1. equation loss, 2. boundary condition loss
# 3. matching condition loss and 4. regularization loss
lw = [0.05, 0.1, 1, 0.25]

# number of sampling points
n_smp = 4000    # for velocity data
nh_smp = 3500   # for thickness data
n_col = 4000    # for collocation points
n_cbd = 400     # for boundary condition (calving front)
# group all the number of points
n_pt = jnp.array([n_smp, nh_smp, n_col, n_cbd], dtype='int32')
# double the points for L-BFGS training
n_pt2 = n_pt * 2
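
JAX handles randomness through explicit keys, which is why the seed is split into subkeys above. The following minimal sketch, using only `jax.random`, shows that different subkeys give different draws while the same seed reproduces them exactly. The variable names `a`, `b`, and `keys_again` are introduced here for illustration.

```python
import jax.numpy as jnp
from jax import random

seed = 2134
key = random.PRNGKey(seed)
keys = random.split(key, 4)   # four independent subkeys

# Different subkeys give different random draws ...
a = random.uniform(keys[0], (3,))
b = random.uniform(keys[1], (3,))

# ... while rebuilding the keys from the same seed reproduces them exactly.
keys_again = random.split(random.PRNGKey(seed), 4)
a_again = random.uniform(keys_again[0], (3,))

print("same seed, same subkey match:", bool(jnp.all(a == a_again)))
print("different subkeys differ:    ", bool(jnp.any(a != b)))
```

Giving each stage of the workflow (initialization, Adam sampling, L-BFGS sampling) its own subkey keeps the runs reproducible without reusing the same random stream.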


# Data Loading
Load and normalize the observed data before the XPINNs training

In [5]:
# load the remote-sensing data
rawdata = loadmat('data_xpinns_RnFlch.mat')

# normalize the remote-sensing data for the XPINNs training
data_all, idxgall, posi_all, idxcrop_all = normdata_xpinn(rawdata)
# extract the scale information for each variable
scale = tree_map(lambda x: data_all[x][4][0:2], idxgall)
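
Before normalizing, it can help to check which variables a `.mat` file contains. The sketch below is one way to do that with `scipy.io`. The helper `summarize_mat` and the toy variable names `ud` and `hd` are hypothetical stand-ins, since the contents of the real file are not listed here.

```python
import io
import numpy as np
from scipy.io import loadmat, savemat

def summarize_mat(mat):
    """Return {variable name: array shape}, skipping MATLAB header entries."""
    return {k: np.shape(v) for k, v in mat.items() if not k.startswith("__")}

# Toy stand-in for the real file; these names and sizes are hypothetical.
buf = io.BytesIO()
savemat(buf, {"ud": np.zeros((4, 5)), "hd": np.ones((4, 5))})
buf.seek(0)

summary = summarize_mat(loadmat(buf))
for name, shape in summary.items():
    print(f"{name}: {shape}")

# With the downloaded file: summarize_mat(loadmat('data_xpinns_RnFlch.mat'))
```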


# Initialization

Initialize the neural networks and the loss function

In [6]:
# initialize the weights and biases of the network
trained_params = init_xpinn(keys[0], n_hl, n_unit,
                             n_sub=len(idxgall))

# create the solution function [tuple(callable, callable)]
solNN = solu_xpinn(scale)

# create the data function for Adam
dataf = dsample_xpinn(data_all, idxgall, n_pt)
keys_adam = random.split(keys[1], 5)
# generate the data
data = dataf(keys_adam[0])

# create the data function for L-BFGS
dataf_l = dsample_xpinn(data_all, idxgall, n_pt2)
key_lbfgs = keys[2]

# group the gov. eqn and bd cond.
eqn_all = (ssa_iso, dbc_iso)
# calculate the loss function
NN_loss = loss_iso_xpinn(solNN, eqn_all, scale, idxgall, lw)
# calculate the initial loss and set it as the loss reference value
NN_loss.lref = NN_loss(trained_params, data)[0]
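
To give a feel for what the loss weights `lw` and the reference value `lref` do, here is a generic weighted-sum sketch. The form below is an assumption made for illustration: the function `weighted_loss` and the loss values are hypothetical, and the combination used inside `loss_iso_xpinn` may differ.

```python
import jax.numpy as jnp

def weighted_loss(loss_data, loss_terms, weights, loss_ref=1.0):
    """Data misfit plus weighted penalty terms, scaled by a reference value.

    Assumed form for illustration only; the library's own combination
    may differ.
    """
    penalty = sum(w * l for w, l in zip(weights, loss_terms))
    return (loss_data + penalty) / loss_ref

lw = [0.05, 0.1, 1, 0.25]
# Hypothetical loss values: equation, boundary, matching, regularization.
terms = [jnp.array(2.0), jnp.array(1.0), jnp.array(0.5), jnp.array(4.0)]
total = weighted_loss(jnp.array(1.0), terms, lw)
print(float(total))
```

In a scheme like this, raising one weight makes the optimizer prioritize that term, and dividing by the initial loss keeps the reported value of order one at the start of training.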


# Network training

The real ice-shelf data has a more complicated profile than the synthetic data, so 10000 iterations of Adam followed by another 10000 iterations of L-BFGS can only infer a **very rough** profile of the ice viscosity.

To train a highly accurate model, the number of iterations required for both Adam and L-BFGS optimization exceeds 100k.


In [None]:
# set the learning rate for Adam
lr = 1e-3
# set the training iteration
epoch1 = 10000

# training with Adam with reducing w
trained_params, loss1 = adam_opt(
    keys_adam[1], NN_loss, trained_params, dataf, epoch1, lr=lr)


Step: 100 | Loss: 8.3111e-02 | Loss_d: 1.4055e-01 | Loss_e: 2.4170e-02 | Loss_b: 2.1272e-03


Extra training using L-BFGS to reach higher accuracy

Recommended number of iterations: 10000

In [None]:
# set the training iteration
epoch2 = 10000
# re-sample the data and collocation points
# (key_lbfgs is already a single PRNG key, so pass it without indexing)
data_l = dataf_l(key_lbfgs)

# training with L-BFGS
trained_params2, loss2 = lbfgs_opt(NN_loss, trained_params, data_l, epoch2)
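
The two-stage idea (a first-order optimizer for a coarse fit, then a quasi-Newton method for refinement) can be illustrated on a toy problem. The sketch below is not how `adam_opt` or `lbfgs_opt` are implemented: it uses a hand-written Adam loop and SciPy's `L-BFGS-B` on a hypothetical least-squares loss, purely to show the pattern.

```python
import jax
import jax.numpy as jnp
import numpy as np
from scipy.optimize import minimize

# Toy least-squares problem standing in for the PINN loss (hypothetical).
x = jnp.linspace(-1.0, 1.0, 50)
y = 3.0 * x + 0.5

def loss(p):
    return jnp.mean((p[0] * x + p[1] - y) ** 2)

loss_and_grad = jax.jit(jax.value_and_grad(loss))
loss_init = float(loss(jnp.zeros(2)))

# Stage 1: a hand-written Adam loop for a coarse fit.
p = jnp.zeros(2)
m, v = jnp.zeros(2), jnp.zeros(2)
lr, b1, b2, eps = 1e-2, 0.9, 0.999, 1e-8
for t in range(1, 201):
    _, g = loss_and_grad(p)
    m = b1 * m + (1 - b1) * g           # first-moment estimate
    v = b2 * v + (1 - b2) * g ** 2      # second-moment estimate
    m_hat = m / (1 - b1 ** t)           # bias corrections
    v_hat = v / (1 - b2 ** t)
    p = p - lr * m_hat / (jnp.sqrt(v_hat) + eps)
loss_adam = float(loss(p))

# Stage 2: refine from the Adam result with SciPy's L-BFGS-B.
def fun(p_np):
    # SciPy expects a float value and a float64 gradient array.
    val, g = loss_and_grad(jnp.asarray(p_np, dtype=jnp.float32))
    return float(val), np.asarray(g, dtype=np.float64)

res = minimize(fun, np.asarray(p, dtype=np.float64), jac=True, method="L-BFGS-B")
print(f"initial loss:      {loss_init:.3e}")
print(f"loss after Adam:   {loss_adam:.3e}")
print(f"loss after L-BFGS: {res.fun:.3e}")
```

Note that the L-BFGS stage here uses one fixed set of points, which mirrors the notebook drawing `data_l` once before calling `lbfgs_opt`.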


# Prediction

Compute the solution variables and equation residuals on high-resolution grids

In [7]:
# create the function for trained solution and equation residues
f_u = lambda x, idx: solNN[0](trained_params2, x, idx)

# group all the function
func_all = (f_u, ssa_iso)

# calculate the solution and equation residue at given grids for visualization
results = predict_xpinn(func_all, data_all, posi_all, idxcrop_all, idxgall)


# Plotting the results:

Compare the observed data for either velocity or thickness with the corresponding network approximation

In [None]:
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable

u_g = results['u_g']
u = results['u']

fig = plt.figure(figsize = [10, 10], dpi = 70)

ax = plt.subplot(2,1,1)
h = ax.imshow(u_g, interpolation='nearest', cmap='rainbow',
              extent=[0., 50000., 0,  80000.],
              origin='lower', aspect='auto')
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="4%", pad=0.05)
fig.colorbar(h, cax=cax)

ax.set_xlabel('$x$', fontsize = 15)
ax.set_ylabel(r'$y\ $', fontsize = 15, rotation = 0)
ax.set_title('Obs. data $u_g(x,y)$ (m/s)', fontsize = 15)


ax2 = plt.subplot(2,1,2)
h2 = ax2.imshow(u, interpolation='nearest', cmap='rainbow',
              extent=[0., 50000., 0,  80000.],
              origin='lower', aspect='auto')
divider = make_axes_locatable(ax2)
cax = divider.append_axes("right", size="4%", pad=0.05)
fig.colorbar(h2, cax=cax)

ax2.set_xlabel('$x$', fontsize = 15)
ax2.set_ylabel(r'$y\ $', fontsize = 15, rotation = 0)
ax2.set_title('Network approx. $u(x,y)$ (m/s)', fontsize = 15)
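
Alongside the visual comparison, a single misfit number can be useful. The sketch below computes a relative L2 error between two gridded fields. The helper `rel_l2_error` and the toy arrays are hypothetical, and it assumes that both fields share one grid and that points outside the ice shelf are stored as NaN.

```python
import numpy as np

def rel_l2_error(obs, pred):
    """Relative L2 misfit over grid points where both fields are finite."""
    obs, pred = np.asarray(obs, dtype=float), np.asarray(pred, dtype=float)
    mask = np.isfinite(obs) & np.isfinite(pred)
    return np.linalg.norm(pred[mask] - obs[mask]) / np.linalg.norm(obs[mask])

# Hypothetical 2x2 fields; NaN marks a grid point outside the ice shelf.
u_obs = np.array([[1.0, 2.0], [np.nan, 4.0]])
u_net = np.array([[1.1, 2.0], [3.0, 3.9]])
err = rel_l2_error(u_obs, u_net)
print(f"relative L2 error: {err:.4f}")
```

If the assumptions hold for the arrays in `results`, the same call would be `rel_l2_error(results['u_g'], results['u'])`.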


Show the viscosity of the ice shelf inferred by the XPINNs

In [None]:
# load the PINN inference of viscosity
mu = results['mu']

fig = plt.figure(figsize = [10, 5], dpi = 70)

ax = plt.subplot(1,1,1)
h = ax.imshow(mu, interpolation='nearest', cmap='rainbow',
              extent=[0., 50000., 0,  80000.],
              origin='lower', aspect='auto')
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="4%", pad=0.05)
fig.colorbar(h, cax=cax)

ax.set_xlabel('$x$', fontsize = 15)
ax.set_ylabel(r'$y\ $', fontsize = 15, rotation = 0)
ax.set_title(r'Inferred viscosity $\mu(x,y)$', fontsize = 15)

