In [1]:
import torch
from helpers.ann_tools import *
import matplotlib.pyplot as plt
%load_ext autoreload
%autoreload 3

# ANNs on staggered grids with stencil 3x3
In this notebook, we enlarge the spatial stencil to 3x3 and also simplify the features: all input features are on the same grid. In addition, we train two distinct networks, one for corner points and one for center points, because this allows us to account for rotation symmetry directly.

![figure](Staggered_grid_3x3.png)

# ZB20 as ANN
Note the following important changes:

In [2]:
def ZB20_Txy(x):
    '''
    Off-diagonal stress component Txy of ZB20 evaluated on a 3x3 stencil.

    x is the vector containing three input features:
    sh_xy(corner), sh_xx (interpolated to corner), vort_xy(corner)
    They are stacked together to form 3*9 = 27 input features
    9 consecutive points describe the stencil 3x3 where
    the fast index corresponds to zonal ("x") direction (as in MOM6),
    so offset 4 inside each group of 9 is the central (prediction) point.
    Only the central values of sh_xx and vort_xy are read; sh_xy is not
    used by this component.

    This function is NN, being a part of mapping T = delta^2 * ||X||^2 * NN(X / ||X||)
    with norm being second norm, and delta is the grid step.

    The output vector y contains values Txy, where
    T = [0, Txy; Txy, 0] and
    du/dt = -div(T)

    Note that the leftmost dimension is batch, i.e.
    x(:,27), y(:,1)

    The result has the same dtype and device as x.

    Raises:
        ValueError: if x is not two-dimensional with 27 columns.
    '''
    if x.ndim != 2 or x.shape[1] != 27:
        raise ValueError(
            f'ZB20_Txy expects input of shape (batch, 27), got {tuple(x.shape)}')

    # Vorticity in corner point, corresponding to prediction point
    # 9*2 means that vort_xy is the last feature in input vector
    vort_xy = x[:, 4 + 9*2].reshape(-1, 1)
    # Again in the same corner point, but now it is the second feature, so 1*9
    sh_xx = x[:, 4 + 9*1].reshape(-1, 1)

    # The product is already a fresh (batch, 1) tensor that inherits dtype and
    # device from x, so no pre-allocated float32/CPU buffer is needed.
    return vort_xy * sh_xx

In [3]:
def ZB20_Txx_Tyy(x):
    '''
    Diagonal stress components Txx, Tyy of ZB20 evaluated on a 3x3 stencil.

    x is the vector containing three input features:
    sh_xy(interpolated to center), sh_xx (center), vort_xy(interpolated_to center)
    They are stacked together to form 3*9 = 27 input features
    9 consecutive points describe the stencil 3x3 where
    the fast index corresponds to zonal ("x") direction (as in MOM6),
    so offset 4 inside each group of 9 is the central (prediction) point.
    Only the central value of each feature is read.

    This function is NN, being a part of mapping T = delta^2 * ||X||^2 * NN(X / ||X||)
    with norm being second norm, and delta is the grid step.

    The output vector y contains values Txx, Tyy (in this order), where
    Txx = Ttr + Tdd, Tyy = Ttr - Tdd with
    Tdd = -vort_xy * sh_xy,
    Ttr = 0.5 * (sh_xy^2 + sh_xx^2 + vort_xy^2),
    T = [Txx, 0; 0, Tyy] and
    du/dt = -div(T)

    Note that the leftmost dimension is batch, i.e.
    x(:,27), y(:,2)

    The result has the same dtype and device as x.

    Raises:
        ValueError: if x is not two-dimensional with 27 columns.
    '''
    if x.ndim != 2 or x.shape[1] != 27:
        raise ValueError(
            f'ZB20_Txx_Tyy expects input of shape (batch, 27), got {tuple(x.shape)}')

    # Central stencil point of the third, second and first feature respectively
    vort_xy = x[:, 4 + 9*2].reshape(-1, 1)
    sh_xx = x[:, 4 + 9*1].reshape(-1, 1)
    sh_xy = x[:, 4].reshape(-1, 1)

    # Completed-square forms of Ttr + Tdd and Ttr - Tdd
    Txx = 0.5 * (sh_xx**2 + (vort_xy - sh_xy)**2)
    Tyy = 0.5 * (sh_xx**2 + (vort_xy + sh_xy)**2)

    # Concatenating keeps the dtype and device of x instead of forcing
    # a float32/CPU output buffer.
    return torch.cat([Txx, Tyy], dim=1)

# Testing

In [4]:
# Three random scalars, drawn in the order D, D_hat, zeta.
D, D_hat, zeta = torch.randn(1), torch.randn(1), torch.randn(1)
# Single-sample input with a spatially constant stencil: D fills the 9 slots
# of the first feature, D_hat the second feature and zeta the third.
x = torch.tensor([[value for value in (D, D_hat, zeta) for _ in range(9)]])

In [5]:
# Reference stress components computed directly from the scalar inputs,
# independently of the stencil-based functions.
Ttr = 0.5*(D**2+D_hat**2+zeta**2)
Tdd = -zeta*D
Txy = zeta * D_hat
Txx, Tyy = Ttr + Tdd, Ttr - Tdd
# Packed in the order the two ZB20 functions return them: Txx, Tyy, then Txy
T = torch.tensor([[Txx, Tyy, Txy]])

In [6]:
# Evaluate both stencil-based functions on the constant-stencil input x;
# the printed values are meant to be compared with the reference tensor T.
print(ZB20_Txx_Tyy(x), ZB20_Txy(x))

tensor([[0.0402, 0.0592]]) tensor([[0.0108]])


In [7]:
print(T)
# Automated version of the visual comparison: the stencil-based functions must
# reproduce the directly computed reference. Tolerances are loose because the
# two formulas are algebraically equal but round differently in float32.
assert torch.allclose(
    torch.cat([ZB20_Txx_Tyy(x), ZB20_Txy(x)], dim=1), T, rtol=1e-4, atol=1e-5
), 'ZB20 functions disagree with the reference tensor T'

tensor([[0.0402, 0.0592, 0.0108]])


# Training ANN on ZB20 data
Here we generate input features from a Gaussian distribution and normalize them by their L2 norm.
We will do the same at inference. The output normalization constant is needed only to normalize the MSE loss. We pass it to the loss function but, for convenience, do not use it as part of the model.

In [8]:
def noise_on_a_unit_sphere(Nsamples, Nfeatures):
    '''
    Draw Nsamples standard-normal vectors with Nfeatures components and
    rescale each row to unit L2 norm.

    Returns a tensor of shape (Nsamples, Nfeatures).
    '''
    samples = torch.randn(Nsamples, Nfeatures)
    # Per-row sum of squares, kept as a column so the division broadcasts
    squared_length = samples.pow(2).sum(dim=1, keepdim=True)
    return samples / torch.sqrt(squared_length)

# Txy part of the model

In [9]:
# Synthetic data for the Txy network: inputs are random directions on the
# unit sphere, targets are the ZB20 values evaluated on them.
Nsamples = 10_000_000
x_train = noise_on_a_unit_sphere(Nsamples, 27)
x_test = noise_on_a_unit_sphere(Nsamples, 27)
y_train = ZB20_Txy(x_train)
y_test = ZB20_Txy(x_test)

# Root-mean-square of the training targets, as a plain Python float
output_norm = y_train.pow(2).mean().sqrt().item()

In [165]:
# Network for Txy built with the ANN helper: 27 inputs and 1 output.
# NOTE(review): presumably 20 is the width of a single hidden layer and
# output_norm only scales the loss — confirm against helpers.ann_tools.ANN.
ann = ANN(layer_sizes=[27,20,1], output_norm=output_norm)

In [166]:
# Fit the Txy network on the synthetic data; the log below counts 20 iterations.
# NOTE(review): the positional arguments 20, 1000, 1e-3 are presumably the
# number of epochs, batch size and learning rate — confirm against
# helpers.ann_tools.train.
train(ann, x_train, y_train, x_test, y_test, 20, 1000, 1e-3, print_frequency=1)

Training starts on device cpu, number of samples 10000000
[1/20] [14.12/268.23] Loss: [0.027752, 0.002781]
[2/20] [13.74/250.73] Loss: [0.001920, 0.001005]
[3/20] [13.94/236.86] Loss: [0.000887, 0.000921]
[4/20] [13.96/223.03] Loss: [0.000834, 0.000857]
[5/20] [14.05/209.41] Loss: [0.000829, 0.000852]
[6/20] [13.86/195.22] Loss: [0.000825, 0.000820]
[7/20] [13.91/181.21] Loss: [0.000824, 0.000842]
[8/20] [14.11/167.53] Loss: [0.000822, 0.000811]
[9/20] [22.58/164.10] Loss: [0.000821, 0.000791]
[10/20] [15.33/149.60] Loss: [0.000820, 0.001142]
[11/20] [14.31/134.11] Loss: [0.000717, 0.000720]
[12/20] [14.12/118.68] Loss: [0.000717, 0.000712]
[13/20] [14.04/103.42] Loss: [0.000717, 0.000711]
[14/20] [13.82/88.24] Loss: [0.000716, 0.000718]
[15/20] [14.04/73.31] Loss: [0.000715, 0.000714]
[16/20] [13.95/58.47] Loss: [0.000702, 0.000703]
[17/20] [13.99/43.74] Loss: [0.000702, 0.000702]
[18/20] [13.96/29.09] Loss: [0.000700, 0.000701]
[19/20] [13.99/14.52] Loss: [0.000700, 0.000701]
[20/20]

In [14]:
def R2(target, pred):
    '''
    Skill score 1 - sum((target - pred)^2) / sum(target^2), returned as a float.

    The denominator is the raw (uncentered) sum of squares of the target.
    '''
    residual = target - pred
    unexplained = residual.pow(2).sum() / target.pow(2).sum()
    return float(1 - unexplained)

In [168]:
# Score of the trained Txy network on the test set (1.0 would be an exact fit)
R2(y_test, ann(x_test))

0.999299943447113

In [169]:
# Side-by-side check on the first five test samples: network prediction first,
# exact ZB20 value second. detach() replaces the legacy .data attribute for
# dropping autograd tracking before printing.
print(ann(x_test[0:5]).detach())
print(ZB20_Txy(x_test[0:5]).detach())

tensor([[-0.0616],
        [ 0.0092],
        [-0.0153],
        [-0.0035],
        [ 0.0701]])
tensor([[-0.0615],
        [ 0.0103],
        [-0.0144],
        [-0.0035],
        [ 0.0692]])


In [170]:
# Save the trained Txy network to a file with the export_ANN helper.
# Unit input/output norms are passed.
# NOTE(review): presumably because the L2-norm scaling is applied outside the
# network, and the x_test / y_test lines shown below are printed by export_ANN
# itself — confirm against helpers.ann_tools.export_ANN.
export_ANN(ann, input_norms=torch.ones(27), output_norms=torch.ones(1), 
           filename='trained_models/ANN_Txy_ZB-small.nc')

x_test =  [ 1.125071   -0.66410357  0.03942366 -1.1223322   1.3186445   1.4001092
  0.9710138   1.4577577  -0.58660185  1.0608376   0.5968267  -0.35395458
 -0.7077333   0.8298185   0.6564303  -0.7864476   0.2546004  -1.2482073
  2.5276103  -1.097128    1.9325942  -0.06448472 -0.21903196 -2.2244322
  0.532548   -1.0340395  -2.4342656 ]
y_test =  [-0.14551666]


# Txx, Tyy part of the model

In [21]:
# Synthetic data for the Txx/Tyy network: inputs are random directions on the
# unit sphere, targets are the ZB20 values evaluated on them.
Nsamples = 10_000_000
x_train = noise_on_a_unit_sphere(Nsamples, 27)
x_test = noise_on_a_unit_sphere(Nsamples, 27)
y_train = ZB20_Txx_Tyy(x_train)
y_test = ZB20_Txx_Tyy(x_test)

# Root-mean-square of the training targets, as a plain Python float
output_norm = y_train.pow(2).mean().sqrt().item()

In [22]:
# Network for Txx, Tyy built with the ANN helper: 27 inputs and 2 outputs.
# NOTE(review): presumably 20 is the width of a single hidden layer and
# output_norm only scales the loss — confirm against helpers.ann_tools.ANN.
ann = ANN(layer_sizes=[27,20,2], output_norm=output_norm)

In [23]:
# Fit the Txx/Tyy network on the synthetic data; the log below counts 20 iterations.
# NOTE(review): the positional arguments 20, 1000, 1e-3 are presumably the
# number of epochs, batch size and learning rate — confirm against
# helpers.ann_tools.train.
train(ann, x_train, y_train, x_test, y_test, 20, 1000, 1e-3, print_frequency=1)

Training starts on device cpu, number of samples 10000000
[1/20] [15.65/297.43] Loss: [0.040287, 0.006444]
[2/20] [15.32/278.80] Loss: [0.006143, 0.006009]
[3/20] [14.72/258.97] Loss: [0.005661, 0.004934]
[4/20] [15.10/243.22] Loss: [0.004255, 0.003412]
[5/20] [14.74/226.62] Loss: [0.003443, 0.003418]
[6/20] [14.97/211.19] Loss: [0.003442, 0.003495]
[7/20] [16.52/198.76] Loss: [0.003439, 0.003432]
[8/20] [14.97/182.99] Loss: [0.003435, 0.003428]
[9/20] [14.57/166.91] Loss: [0.002855, 0.002320]
[10/20] [14.64/151.21] Loss: [0.002287, 0.002411]
[11/20] [15.69/136.55] Loss: [0.002183, 0.002181]
[12/20] [15.21/121.40] Loss: [0.002183, 0.002187]
[13/20] [17.03/107.22] Loss: [0.002183, 0.002188]
[14/20] [18.74/93.37] Loss: [0.002182, 0.002179]
[15/20] [15.91/77.93] Loss: [0.002182, 0.002185]
[16/20] [14.70/62.12] Loss: [0.002168, 0.002172]
[17/20] [15.56/46.60] Loss: [0.002168, 0.002171]
[18/20] [16.99/31.23] Loss: [0.002166, 0.002170]
[19/20] [15.84/15.62] Loss: [0.002166, 0.002170]
[20/20]

In [24]:
# Score of the trained Txx/Tyy network on the test set (1.0 would be an exact fit)
R2(y_test, ann(x_test))

0.9978306293487549

In [25]:
# Save the trained Txx/Tyy network to a file with the export_ANN helper.
# Unit input/output norms are passed.
# NOTE(review): presumably because the L2-norm scaling is applied outside the
# network, and the x_test / y_test lines shown below are printed by export_ANN
# itself — confirm against helpers.ann_tools.export_ANN.
export_ANN(ann, input_norms=torch.ones(27), output_norms=torch.ones(2), 
           filename='trained_models/ANN_Txx_Tyy_ZB-small-retrain.nc')

x_test =  [ 0.5276723  -0.56469125 -0.02669321 -1.3318404  -0.2089864   0.4190437
  1.7333885  -0.95261437  0.36575687  0.5111949  -1.069809   -0.25853646
 -0.61322135  1.0860585   0.56493235  0.3940468  -0.28736952  1.063206
  0.56260747 -1.4229612   1.4992539   0.8398847   2.1048677  -0.0810979
 -1.1078862  -0.46300948 -1.2692279 ]
y_test =  [1.5305986 1.3409165]
