<a href="https://colab.research.google.com/github/YaoGroup/DIFFICE_jax/blob/main/ice_shelf/pinn_icedata.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Viscosity inversion for real ice shelves via PINNs

Running the code efficiently requires **GPU hardware** for acceleration.

To choose GPU hardware, click "Edit" in the toolbar above and select "Notebook settings".

In the pop-up window, under "Hardware accelerator", select "T4 GPU".

To ensure the code runs correctly in Colab, install the specific version of JAX below.

(Note that training is accelerated only when the GPU version of JAX is used together with the GPU hardware selected above.)

In [1]:
# Install specific version (0.4.23) of JAX and Jaxlib
!pip install --upgrade jax==0.4.23 jaxlib==0.4.23+cuda12.cudnn89 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

# Verify the installed version
import jax
print(jax.__version__)

Looking in links: https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Collecting jax==0.4.23
  Downloading jax-0.4.23-py3-none-any.whl (1.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.7/1.7 MB[0m [31m8.4 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting jaxlib==0.4.23+cuda12.cudnn89
  Downloading https://storage.googleapis.com/jax-releases/cuda12/jaxlib-0.4.23%2Bcuda12.cudnn89-cp310-cp310-manylinux2014_x86_64.whl (131.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m131.8/131.8 MB[0m [31m5.4 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: jaxlib, jax
  Attempting uninstall: jaxlib
    Found existing installation: jaxlib 0.4.26+cuda12.cudnn89
    Uninstalling jaxlib-0.4.26+cuda12.cudnn89:
      Successfully uninstalled jaxlib-0.4.26+cuda12.cudnn89
  Attempting uninstall: jax
    Found existing installation: jax 0.4.26
    Uninstalling jax-0.4.26:
      Successfully uninstalled jax-0.4.26
Successfully installed j

Copy the required Python functions from the [DIFFICE_jax](https://github.com/YaoGroup/DIFFICE_jax) GitHub repository

In [2]:
!git clone https://github.com/YaoGroup/DIFFICE_jax

Cloning into 'DIFFICE_jax'...
remote: Enumerating objects: 724, done.[K
remote: Counting objects: 100% (444/444), done.[K
remote: Compressing objects: 100% (265/265), done.[K
remote: Total 724 (delta 339), reused 249 (delta 178), pack-reused 280[K
Receiving objects: 100% (724/724), 25.82 MiB | 11.23 MiB/s, done.
Resolving deltas: 100% (430/430), done.


Import the required JAX libraries and the functions from the GitHub repository

In [3]:
import jax.numpy as jnp
import numpy as np
from jax import random
from jax.tree_util import tree_map
from scipy.io import loadmat
import time

from DIFFICE_jax.data.sampling import data_sample_create
from DIFFICE_jax.equation.ssa_eqn_iso import vectgrad, gov_eqn, front_eqn
from DIFFICE_jax.model.initialization import init_MLP
from DIFFICE_jax.model.networks import solu_create
from DIFFICE_jax.model.loss import loss_create
from DIFFICE_jax.model.prediction import predict
from DIFFICE_jax.optimizer.optimizer import adam_optimizer, lbfgs_optimizer
from DIFFICE_jax.ice_shelf.load_icedata import iceshelf_data


# Setting hyperparameters

Hyperparameters used for training. Users are free to modify their values and check their influence on the training results.


In [4]:
# select a random seed (used for both the JAX keys and NumPy's global RNG)
seed = 2134
key = random.PRNGKey(seed)
np.random.seed(seed)

# create the subkeys
keys = random.split(key, 4)

# select the size of neural network
n_hl = 6  # number of hidden layers
n_unit = 30  # number of units in each hidden layer

# select the weights for the equation and boundary conditions
lw = [0.05, 0.1]

# number of points sampled for each variable from the ice-shelf data
n_smp = 6000
# number of collocation points to evaluate the equation residue
n_col = 6000
# number of collocation points to evaluate the boundary conditions
n_cbd = 600
# group all the points
n_pt = jnp.array([n_smp, n_col, n_cbd], dtype='int32')

# double all the points for L-BFGS training (the points stay fixed over iterations)
n_pt2 = n_pt * 2

# select the ice shelf for the training
shelfname = 'Amery'

# create the dataset filename (e.g. 'Data_Amery.mat')
DataFile = 'Data_' + shelfname + '.mat'
# re-organize the ice-shelf data for the PINN code
data_all = iceshelf_data(DataFile, step=1)
# extract the scale information for each variable
scale = data_all[4][0:2]


# Initializing the model and the loss function

In [5]:
# Weights and biases of the MLP, drawn with the first subkey
trained_params = init_MLP(keys[0], n_hl, n_unit)

# Network solution function, plus the equations entering the loss
# (governing equation and boundary condition)
pred_u = solu_create()
eqn_all = (gov_eqn, front_eqn)

# --- Sampling for Adam: n_pt points per draw ---
dataf = data_sample_create(data_all, n_pt)
keys_adam = random.split(keys[1], 5)
# one initial draw, used below to compute the reference loss
data = dataf(keys_adam[0])

# --- Sampling for L-BFGS: twice as many points (n_pt2) ---
dataf_l = data_sample_create(data_all, n_pt2)
key_lbfgs = random.split(keys[2], 5)

# Assemble the loss, then store the loss of the untrained network
# as its reference value (index 0 of the returned tuple)
NN_loss = loss_create(pred_u, eqn_all, scale, lw)
NN_loss.lref = NN_loss(trained_params, data)[0]


# Training using the Adam optimizer

A relatively accurate trained model requires at least 10000 iterations of Adam,
followed by 10000 iterations of L-BFGS later.


In [None]:
# Adam settings: learning rate and number of training iterations
lr = 1e-3
epoch1 = 10_000

# Train with Adam, driven by the second Adam subkey and the sampler `dataf`.
# NOTE: `trained_params` is overwritten, so re-running this cell continues
# training from the current parameters instead of restarting.
trained_params, loss1 = adam_optimizer(keys_adam[1], NN_loss, trained_params,
                                       dataf, epoch1, lr=lr)


Step: 20 | Loss: 1.7636e-01 | Loss_d: 2.9135e-01 | Loss_e: 2.0864e-01 | Loss_b: 1.8249e-02
Step: 40 | Loss: 1.1493e-01 | Loss_d: 1.7732e-01 | Loss_e: 3.9534e-01 | Loss_b: 7.5765e-03
Step: 60 | Loss: 9.7267e-02 | Loss_d: 1.6523e-01 | Loss_e: 3.9759e-02 | Loss_b: 2.2117e-03
Step: 80 | Loss: 8.9029e-02 | Loss_d: 1.5177e-01 | Loss_e: 2.5959e-02 | Loss_b: 1.9754e-03
Step: 100 | Loss: 7.8914e-02 | Loss_d: 1.3462e-01 | Loss_e: 2.1024e-02 | Loss_b: 1.7285e-03
Step: 120 | Loss: 7.6020e-02 | Loss_d: 1.2896e-01 | Loss_e: 3.5013e-02 | Loss_b: 1.5663e-03
Step: 140 | Loss: 7.1118e-02 | Loss_d: 1.1897e-01 | Loss_e: 6.5185e-02 | Loss_b: 1.9583e-03
Step: 160 | Loss: 6.7068e-02 | Loss_d: 1.1122e-01 | Loss_e: 7.8701e-02 | Loss_b: 3.0145e-03
Step: 180 | Loss: 6.1565e-02 | Loss_d: 1.0081e-01 | Loss_e: 9.5210e-02 | Loss_b: 4.1576e-03
Step: 200 | Loss: 5.5484e-02 | Loss_d: 8.8942e-02 | Loss_e: 1.1929e-01 | Loss_b: 6.0737e-03
Step: 220 | Loss: 4.7944e-02 | Loss_d: 7.4678e-02 | Loss_e: 1.3959e-01 | Loss_b: 8.7

# Training using L-BFGS optimizer
Extra training using L-BFGS to reach higher accuracy

Recommended number of iterations: 10000

In [None]:
# L-BFGS setting: number of training iterations
epoch2 = 10_000

# Draw one sample of data and collocation points (n_pt2 points), which stays
# fixed for all L-BFGS iterations
data_l = dataf_l(key_lbfgs[1])

# Refine the Adam result with L-BFGS; the Adam parameters remain in `trained_params`
trained_params2, loss2 = lbfgs_optimizer(NN_loss, trained_params, data_l, epoch2)

# Computing the output of trained networks

Compute the solution variables and the equation residue on high-resolution grids

In [None]:
# Wrap the L-BFGS-trained network into plain functions of the coordinates x
def f_u(x):
    """Network solution at ``x`` using the L-BFGS-trained parameters."""
    return pred_u(trained_params2, x)


def f_gu(x):
    """First six columns of the first output of ``vectgrad`` applied to ``f_u``."""
    return vectgrad(f_u, x)[0][:, 0:6]


def f_eqn(x):
    """Governing-equation residue of the trained solution at ``x``."""
    return gov_eqn(f_u, x, scale)


# bundle the functions and evaluate them on the data grids for visualization
func_all = (f_u, f_gu, f_eqn)
results = predict(func_all, data_all)


# Plotting the results:

Compare the ice-shelf data for velocity or thickness with the corresponding network approximation (the cell below plots the velocity component $u$)

In [None]:
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable

# Plot extent [x_min, x_max, y_min, y_max].
# NOTE(review): these bounds appear to be carried over from the synthetic-data
# example; confirm they match the domain of the selected ice shelf.
extent = [0., 50000., 0, 80000.]


def plot_field(ax, field, title, extent=extent):
    """Draw a 2-D gridded field on ``ax`` with a colorbar attached on the right.

    Parameters
    ----------
    ax : matplotlib Axes to draw into.
    field : 2-D array to display (first row at the bottom, via origin='lower').
    title : axis title (may contain LaTeX).
    extent : [x_min, x_max, y_min, y_max] passed to ``imshow``.

    Returns
    -------
    The image object returned by ``imshow``.
    """
    im = ax.imshow(field, interpolation='nearest', cmap='rainbow',
                   extent=extent, origin='lower', aspect='auto')
    # a dedicated colorbar axis keeps the colorbar the same height as the image
    cax = make_axes_locatable(ax).append_axes("right", size="4%", pad=0.05)
    ax.figure.colorbar(im, cax=cax)
    ax.set_xlabel('$x$', fontsize=15)
    # raw string: '\ ' is not a valid Python escape sequence
    ax.set_ylabel(r'$y\ $', fontsize=15, rotation=0)
    ax.set_title(title, fontsize=15)
    return im


u_g = results['u_g']
u = results['u']

fig, (ax, ax2) = plt.subplots(2, 1, figsize=[10, 10], dpi=70)
h = plot_field(ax, u_g, 'Data $u_g(x,y)$ (m/s)')
h2 = plot_field(ax2, u, 'Network approx. $u(x,y)$ (m/s)')
plt.show()


Show the viscosity of the ice shelf inferred by the PINN

In [None]:
# load the PINN inference of viscosity
mu = results['mu']

fig, ax = plt.subplots(figsize=[10, 5], dpi=70)

# NOTE(review): extent appears to be carried over from the synthetic-data
# example; confirm it matches the domain of the selected ice shelf.
h = ax.imshow(mu, interpolation='nearest', cmap='rainbow',
              extent=[0., 50000., 0, 80000.],
              origin='lower', aspect='auto')
# a dedicated colorbar axis keeps the colorbar the same height as the image
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="4%", pad=0.05)
fig.colorbar(h, cax=cax)

ax.set_xlabel('$x$', fontsize=15)
# raw strings: '\ ' and '\m' are not valid Python escape sequences
ax.set_ylabel(r'$y\ $', fontsize=15, rotation=0)
ax.set_title(r'Inferred viscosity $\mu(x,y)$', fontsize=15)
plt.show()

