In [None]:
import datetime
import git
import matplotlib.pyplot as plt
import numpy as np
import pickle
import torch
import warnings

import neuralpde


# When True, cells below write weights, results and figures to disk.
SAVING = True


def load_data(filename):
    """Unpickle `filename` and push its contents into this module's globals.

    The unpickled object is passed to `globals().update`, so each key becomes a
    module-level name, silently replacing any existing variable of that name
    (e.g. a `*results.pkl` file written further down restores `u`, `kappa`, ...).

    WARNING: `pickle.load` can execute arbitrary code -- only load trusted files.
    """
    with open(filename, mode='rb') as handle:
        contents = pickle.load(handle)
    globals().update(contents)


In [None]:
# Subset of the (x, y) grid to use, by index, ordered as
# (+/- longitude, +/- latitude).
#
# WARNING: the ordering here is not what you expect!  Be careful to check!
# (The same slices are applied to `seaice_conc`, the meshgrid `x`/`y` arrays
# and the flag arrays below, so they are at least consistent with each other.)
subset_xy = np.s_[174:235, 64:110]

# Number of intermediate RK stages to use.
q = 10

# Number of solution maps (forward and backward) to pass to the PINN.
# The total number of maps is 2 * maps + 1 (see `indices` below).
maps = 7

# Date to study.  Alternatives:
#   datetime.date(1980, 8, 15)
#   datetime.date(2023, 8, 15)
date = datetime.date(1979, 8, 15)

# One data file per year: the year before, the year of, and the year after `date`.
files = [
    f'data/V4/seaice_conc_daily_nh_{year}_v04r00.nc'
    for year in range(date.year - 1, date.year + 2)
]


Load the data and run preliminary boundary checks on the window of days around the study date:

In [None]:
d = neuralpde.nc.SeaIceV4(files)

# `np.searchsorted` returns an insertion point, not a match: if `date` is not
# in `d.date` (e.g. a gap in the record) it silently yields the next available
# day, and every downstream result would be mislabeled with `date`.  Compare at
# day precision via datetime64 so the check does not depend on whether
# `d.date` holds `datetime.date`s or `datetime64`s (TODO confirm its dtype).
day = np.searchsorted(d.date, date)
if day >= len(d.date) or np.datetime64(d.date[day], 'D') != np.datetime64(date, 'D'):
    raise ValueError(f'{date} not found in the loaded data ({files})')

# 2 * maps + 1 consecutive indices centered on `day`.
indices = np.arange(day - maps, day + maps + 1)

neuralpde.nc.check_boundaries(indices, d)


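Prepare the data: cut out the spatial subset, replace NaNs with zero, normalize the coordinates, and build the coast and land/hole/lake/missing masks: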

In [None]:
u = d.seaice_conc[indices, *subset_xy]
u[np.isnan(u)] = 0.  # mask out NaN

(scalex, scaley), (x, y) = neuralpde.network.normalize_xy(d.meters_x, d.meters_y)
x, y = np.meshgrid(x, y)  # this ordering is annoying but necessary
x, y = x[subset_xy], y[subset_xy]

mask_coast = (d.flag_coast)[day, *subset_xy]
mask_other = (d.flag_land | d.flag_hole | d.flag_lake | d.flag_missing)[day, *subset_xy]
mask_any = mask_coast | mask_other


Now define the loss hyperparameters, i.e., the relative weights of the loss terms used during training:

In [None]:
# Loss hyperparameters: relative weight of each term in the training loss.
# NOTE(review): the order must match what `net.fit` expects (not visible here).
weights = np.array([
    5.,  # differential loss (t_{n})
    5.,  # differential loss (t_{n+1})
    1.,  # boundary loss, no slip + no pen
    2.,  # kappa regularization
    2.,  # v regularization
    4.,  # f minimization
])

# Rescale to unit Euclidean (L2) norm; the ratios between terms are unchanged.
weights = weights / np.linalg.norm(weights)
weights = weights / np.sqrt(np.sum(weights**2))


Now train the network and compute its predictions.

In [None]:
# Seed the global RNGs so that (presumably random) network initialization and
# training are repeatable across runs.  NOTE(review): which RNG `fit` draws from
# is not visible here, so both torch and numpy are seeded; GPU kernels may still
# be nondeterministic.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

net = neuralpde.network.Network(q=q, shape=u.shape, kernel=5).to(neuralpde.network.DEVICE, neuralpde.network.DTYPE)
losses = net.fit(x, y, u, weights, mask_coast, mask_other, epochs=1000, lr=1e-3)

# Compute a prediction on the training inputs.
uhat_i, uhat_f, kappa, kappa_x, kappa_y, v1, v1_x, v1_y, v2, v2_x, v2_y, f = net.predict(x, y, u)


In [None]:
# Provenance prefix for every output file: study date, short commit SHA and run
# time (to the minute).  NOTE: the SHA does not capture uncommitted changes.
repo = git.Repo(search_parent_directories=True)
sha = repo.head.object.hexsha[:7]
name_slug = f'neuralpde-{date.strftime(r"%Y%m%d")}-{sha}.{datetime.datetime.now().strftime(r"%Y%m%d%H%M")}.'

if SAVING:
    torch.save(net.state_dict(), name_slug + 'weights.pth')

    # Save everything the plotting cell needs (including `maps` and the masks),
    # so that reloading this file with `load_data` is enough to re-plot.
    with open(name_slug + 'results.pkl', 'wb') as fp:
        pickle.dump(
            {
                'date': date,
                'q': q,
                'maps': maps,
                'subset_xy': subset_xy,
                'weights': weights,
                'losses': losses,
                'x': x,
                'scalex': scalex,
                'y': y,
                'scaley': scaley,
                'u': u,
                'mask_coast': mask_coast,
                'mask_other': mask_other,
                'mask_any': mask_any,
                'uhat_i': uhat_i,
                'uhat_f': uhat_f,
                'kappa': kappa,
                'kappa_x': kappa_x,
                'kappa_y': kappa_y,
                'v1': v1,
                'v1_x': v1_x,
                'v1_y': v1_y,
                'v2': v2,
                'v2_x': v2_x,
                'v2_y': v2_y,
                'f': f
            },
            fp
        )


In [None]:
# Plot-only mask: 1 on usable cells, NaN on coast/land/hole/lake/missing cells,
# so multiplying a field by it makes those cells render in the colormap's
# "bad" color.
mask_any_plotting = np.ones_like(mask_any, dtype=float)
mask_any_plotting[mask_any] = np.nan

# Grid coordinates in km from the grid center (normalized coordinates rescaled
# by `scalex`/`scaley`, presumably back to meters, then divided by 1e3).
x_km = x * scalex / 1e3
y_km = y * scaley / 1e3

FIGSIZE = (0.7 * 1.625 * 8, 1 * 8)


def plot_field(field, cmap_name, colorbar_label, filename_suffix):
    """Draw one masked map of `field`, save it when SAVING, and show it.

    Reads `x_km`, `y_km`, `mask_any_plotting`, `date`, `SAVING` and `name_slug`
    from the notebook namespace.

    Args:
        field: 2-D array on the same grid as `x`/`y`
            (NOTE(review): shape assumed to match `mask_any_plotting`; not checked).
        cmap_name: Matplotlib colormap name.
        colorbar_label: Colorbar label text.
        filename_suffix: Appended to `name_slug` to form the saved file name.
    """
    # Copy so that `set_bad` never mutates a globally registered colormap.
    cmap = plt.get_cmap(cmap_name).copy()
    cmap.set_bad(color='gray')

    fig, ax = plt.subplots(figsize=FIGSIZE)
    mesh = ax.pcolormesh(x_km, y_km, field * mask_any_plotting, cmap=cmap)
    fig.colorbar(mesh, ax=ax).set_label(colorbar_label)
    ax.set_xlabel('km from grid center (x)')
    ax.set_ylabel('km from grid center (y)')
    ax.set_title(f'{date.strftime(r"%m/%d/%Y")}')
    ax.set_aspect('equal')
    fig.tight_layout()
    if SAVING:
        fig.savefig(name_slug + filename_suffix, dpi=300)
    plt.show()


# `u[maps]` is the map for `date`: the center of the 2 * maps + 1 window.
plot_field(u[maps], 'Blues_r', 'Sea Ice Concentration (fractional)', 'seaice-conc.png')

# NOTE(review): this panel shows `uhat_i[-1]` while the error panel below uses
# `uhat_f[-1]`; confirm which prediction each panel is meant to display.
plot_field(uhat_i[-1], 'Blues_r', 'Sea Ice Concentration (fractional)', 'seaice-conc-predicted.png')
plot_field(np.abs(u[maps] - uhat_f[-1]), 'jet', 'Absolute Error', 'abs-err.png')

# Learned parameter fields.
plot_field(kappa, 'jet', 'Diffusivity (unitless)', 'kappa.png')
plot_field(v1, 'jet', 'Lateral Velocity (unitless)\n(negative is left, positive is right)', 'v1.png')
plot_field(v2, 'jet', 'Vertical Velocity (unitless)\n(negative is down, positive is up)', 'v2.png')
plot_field(f, 'jet', 'Forcing (unitless)', 'f.png')
