In [34]:
import torch
import numpy as np
import sklearn.preprocessing
import xarray as xr

from pathlib import Path

In [35]:
# Directory that holds the input NetCDF files.
datapath = Path("./layer")

# From prog.nc keep: the first `yh` index, every `xq` point except the last,
# and the final 40 entries along `Time`.
prog_selection = dict(yh=0, xq=slice(0, -1), Time=slice(-40, None))
ds = xr.open_dataset(datapath / "prog.nc").isel(**prog_selection)

# Geometry file stored alongside prog.nc, opened without any sub-selection.
grid = xr.open_dataset(datapath / "ocean_geometry.nc")

In [36]:
# Pick the compute device. Priority: CUDA first, then Apple's MPS backend,
# and the CPU when neither accelerator is available.
mps_ready = torch.backends.mps.is_available()  # Apple Silicon
cuda_ready = torch.cuda.is_available()  # Nvidia GPU
if cuda_ready:
    device = torch.device("cuda")
elif mps_ready:
    device = torch.device("mps")
else:
    device = torch.device("cpu")


In [37]:
def preprocess_dataarray(dataarray):
    """Flatten a three-dimensional field into rows of clean float32 values.

    The first two axes are merged into a single row axis, NaN and both
    infinities are replaced by zero, and the result is cast to ``float32``.

    Args:
        dataarray: object exposing a three-element ``shape`` and a ``data``
            array that supports ``reshape``.

    Returns:
        Two-dimensional array with ``shape[0] * shape[1]`` rows.
    """
    n_time, n_layer, _ = dataarray.shape
    flattened = dataarray.data.reshape(n_time * n_layer, -1)
    finite = np.nan_to_num(flattened, nan=0.0, posinf=0.0, neginf=0.0)
    return finite.astype(np.float32)

def create_scaler_dataarray(dataarray, Scaler):
    """Fit a fresh ``Scaler`` on ``dataarray`` and return the scaled data with it.

    Args:
        dataarray: data handed to the scaler's ``fit_transform``.
        Scaler: zero-argument callable producing an object with a
            ``fit_transform`` method.

    Returns:
        Tuple ``(scaled_data, fitted_scaler)``.
    """
    fitted = Scaler()
    scaled = fitted.fit_transform(dataarray)
    return scaled, fitted

def tensorize(array, device):
    """Turn a NumPy array into a tensor on ``device`` with an extra axis at position 1.

    Args:
        array: NumPy array accepted by ``torch.from_numpy``.
        device: target ``torch.device``.

    Returns:
        Tensor with a size-one axis inserted at index 1.
    """
    tensor = torch.from_numpy(array)
    with_channel = tensor.unsqueeze(1)
    return with_channel.to(device)

def preprocess_tensorize(dataarray, device):
    """Clean ``dataarray`` with ``preprocess_dataarray`` and tensorize the result on ``device``."""
    return tensorize(preprocess_dataarray(dataarray), device)


def preprocess_scale(dataarray, Scaler):
    """Clean ``dataarray`` with ``preprocess_dataarray`` and scale it with a new ``Scaler``.

    Returns the ``(scaled_array, fitted_scaler)`` pair produced by
    ``create_scaler_dataarray``.
    """
    cleaned = preprocess_dataarray(dataarray)
    return create_scaler_dataarray(cleaned, Scaler)

In [38]:
# Scaler class shared by every scaled input below.
Scaler = sklearn.preprocessing.MinMaxScaler

# `ubt_corrector`: clean and scale, then move to the device. `tensorize`
# inserts a size-one axis at position 1.
u_scaled, u_scaler = preprocess_scale(ds.ubt_corrector, Scaler)
u = tensorize(u_scaled, device)

# `h_continuity_start`: same treatment as `ubt_corrector`.
h_start_scaled, h_scaler = preprocess_scale(ds.h_continuity_start, Scaler)
h_start = tensorize(h_start_scaled, device)

# `uh` is cleaned and tensorized but NOT scaled; the scaling call stays disabled:
# uh, uh_scaler = preprocess_scale(ds.uh, Scaler)
uh = preprocess_tensorize(ds.uh, device)


In [57]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models


class Net(nn.Module):
    """Small 1-D convolutional network mapping a pair of fields to one field.

    ``forward`` stacks ``u`` and ``h`` as two channels, applies a width-5
    convolution and then two pointwise (width-1) convolutions. No activation
    is applied between the layers, so the map is affine in its inputs.
    """

    def __init__(self):
        super().__init__()
        # Two input channels: `forward` concatenates u and h along dim 1, so a
        # single-channel first layer would reject its input.
        # padding=2 with kernel_size=5 preserves the length of the last axis.
        self.conv1 = nn.Conv1d(in_channels=2, out_channels=4, kernel_size=5, padding=2)
        self.conv2 = nn.Conv1d(in_channels=4, out_channels=2, kernel_size=1)
        self.conv3 = nn.Conv1d(in_channels=2, out_channels=1, kernel_size=1)
        # The attributes below are not used by `forward`; they are kept so code
        # that refers to them (or to the parameter list) keeps working.
        self.activation = nn.ReLU  # the class itself, not an instance
        self.fc1 = nn.Linear(in_features=4, out_features=2)
        self.fc2 = nn.Linear(in_features=2, out_features=1)
        self.scale = nn.Parameter(torch.ones(1))

    def forward(self, u, h):
        """Combine ``u`` and ``h`` and run them through the three convolutions.

        Args:
            u: tensor expected to be ``(batch, 1, length)``.
            h: tensor expected to be ``(batch, 1, length)``; together with
                ``u`` it must provide exactly two channels.

        Returns:
            Tensor of shape ``(batch, 1, length)``.
        """
        # Stack the two fields as channels so conv1 can mix them.
        x = torch.cat((u, h), dim=1)
        x = self.conv1(x)
        x = self.conv2(x)
        return self.conv3(x)

# Build the network on the selected device and display the output shape
# obtained from the prepared inputs.
net = Net()
net = net.to(device)
sample_output = net(u, h_start)
sample_output.shape

torch.Size([1600, 1, 80])

In [60]:
# Fit a fresh network directly against the fluxes.
net_fluxes = Net().to(device)
loss_fluxes = nn.MSELoss()
optimizer = torch.optim.Adam(params=net_fluxes.parameters(), lr=1e-3)


In [61]:

# Full-batch training: every step feeds all of `u` and `h_start` at once.
net_fluxes.train()
for step in range(10000):
    optimizer.zero_grad()
    uh_pred = net_fluxes(u, h_start)
    loss = loss_fluxes(uh, uh_pred)
    loss.backward()
    optimizer.step()

    # Report on steps 1, 101, 201, ...; "\r" makes each value overwrite the last.
    if step % 100 == 1:
        print(loss.item(), end="\r")


22814439424.0

In [75]:
# (Stray expression `d` removed: the name is not defined anywhere in this
# notebook, so the cell raised NameError on Restart & Run All.)

In [312]:
class FiniteDifferenceNet(nn.Module):
    """One bias-free convolution: one input channel, three output channels, width 3.

    No padding is applied, so the output is two samples shorter than the
    input along the last axis.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv1d(1, 3, kernel_size=3, bias=False)

    def forward(self, x):
        """Return the three-channel convolution of ``x``."""
        return self.conv1(x)

def periodic_differences(x):
    """Forward, central and backward differences of ``x`` along its last axis.

    Neighbours are fetched with ``torch.roll``, i.e. with periodic
    wrap-around. The first and last positions (the only ones whose stencils
    wrap) are then dropped, so every result is two samples shorter than ``x``
    along the last axis.

    The roll acts on ``dims=-1`` so that it always matches the axis being
    trimmed; for 2-D input this is the same axis as ``dims=1``.

    Args:
        x: tensor whose last axis is the one to difference.

    Returns:
        Tuple ``(forward, central, backward)`` of tensors.
    """
    right = torch.roll(x, -1, dims=-1)  # x[i + 1]
    left = torch.roll(x, 1, dims=-1)    # x[i - 1]

    forward_diff = right - x
    backward_diff = x - left
    central_diff = 0.5 * (right - left)

    # Keep interior points only; the end points used wrapped neighbours.
    return forward_diff[..., 1:-1], central_diff[..., 1:-1], backward_diff[..., 1:-1]

In [335]:
# Seed the generator so the synthetic samples (and any weights initialised
# afterwards) are identical on every run.
torch.manual_seed(0)

# 100 random samples with 10 points each.
a = torch.rand(100, 10)
# NOTE(review): this copies the last sample (row) into the first one. If the
# intent was to make each sample periodic it would be `a[:, 0] = a[:, -1]`;
# either way the targets below only keep interior points, which never use
# wrapped neighbours. Left unchanged - confirm intent.
a[0] = a[-1]
gradx = periodic_differences(a)
# Stack the three difference types as channels: fd_true is (100, 3, 8).
fd_true = torch.stack(gradx, dim=1)

In [336]:
# Move the targets and the inputs (with an axis added at position 1) to the device.
fd_true = fd_true.to(device)
a_channel = a.unsqueeze(1).to(device)

# Network, loss and plain SGD optimiser for the stencil-fitting experiment.
fd_net = FiniteDifferenceNet().to(device)
fd_loss = nn.MSELoss()
optimizer = torch.optim.SGD(params=fd_net.parameters(), lr=1e-3)


In [339]:

# Full-batch gradient descent on the finite-difference targets.
fd_net.train()
for step in range(100000):
    optimizer.zero_grad()
    fd_pred = fd_net(a_channel)
    loss = fd_loss(fd_pred, fd_true)
    loss.backward()
    optimizer.step()

    # Report the loss and the current kernel on steps 1, 1001, 2001, ...
    if step % 1000 == 1:
        print(loss.item(), end="\n")
        print(fd_net.conv1.weight.data)

# Final state after the last step.
print(loss.item())
print(fd_net.conv1.weight.data)

2.6550097231847758e-08
tensor([[[-9.3126e-06, -9.9952e-01,  9.9954e-01]],

        [[-4.9976e-01,  1.7119e-05,  4.9974e-01]],

        [[-9.9952e-01,  9.9963e-01, -8.5419e-05]]], device='mps:0')
2.654774711174923e-08
tensor([[[-9.3015e-06, -9.9952e-01,  9.9954e-01]],

        [[-4.9976e-01,  1.7082e-05,  4.9974e-01]],

        [[-9.9952e-01,  9.9963e-01, -8.5343e-05]]], device='mps:0')
2.6545464493210602e-08
tensor([[[-9.2926e-06, -9.9952e-01,  9.9954e-01]],

        [[-4.9976e-01,  1.7048e-05,  4.9974e-01]],

        [[-9.9952e-01,  9.9963e-01, -8.5273e-05]]], device='mps:0')
2.6545489362206354e-08
tensor([[[-9.2856e-06, -9.9952e-01,  9.9954e-01]],

        [[-4.9976e-01,  1.7017e-05,  4.9974e-01]],

        [[-9.9952e-01,  9.9963e-01, -8.5211e-05]]], device='mps:0')
2.6541538744595528e-08
tensor([[[-9.2712e-06, -9.9952e-01,  9.9954e-01]],

        [[-4.9976e-01,  1.6989e-05,  4.9974e-01]],

        [[-9.9952e-01,  9.9963e-01, -8.5152e-05]]], device='mps:0')
2.653937158925146e-08
tens

In [316]:
# Display the current kernel of the network's convolution layer.
fd_net.conv1.weight.data

tensor([[[ 0.0117, -0.9913,  0.9801]],

        [[-0.4875, -0.0025,  0.4902]],

        [[-0.9818,  0.9969, -0.0148]]], device='mps:0')

In [293]:
# Hand-built reference: a bias-free, width-2 convolution whose kernel is set
# to represent the forward finite difference, [-1, 1].
forward_stencil = torch.tensor([[[-1.0, 1.0]]])
conv1d_forward_diff = nn.Conv1d(in_channels=1, out_channels=1, kernel_size=2, bias=False)
conv1d_forward_diff.weight = nn.Parameter(forward_stencil)
conv1d_forward_diff = conv1d_forward_diff.to(device)

In [295]:
# For the first sample, print in order: channel 0 of the trained network, the
# hand-built forward-difference convolution (first output dropped), and
# channel 0 of the targets.
learned = fd_net(a_channel)[0, 0, :]
handmade = conv1d_forward_diff(a_channel)[0, 0, 1:]
target = fd_true[0, 0, :]
for row in (learned, handmade, target):
    print(row)

tensor([ 1.7936, -1.7214,  1.3931,  0.4448, -0.2927,  0.4275, -0.8092, -0.3350],
       device='mps:0', grad_fn=<SliceBackward0>)
tensor([ 0.9173, -0.9518,  0.7136,  0.1781, -0.2131,  0.1715, -0.4759, -0.2001],
       device='mps:0', grad_fn=<SliceBackward0>)
tensor([ 0.9173, -0.9518,  0.7136,  0.1781, -0.2131,  0.1715, -0.4759, -0.2001],
       device='mps:0')


In [291]:
# The hand-built convolution yields a single channel, so compare it only with
# the forward-difference channel (index 0) of fd_true. Passing all three
# channels would make MSELoss broadcast the forward difference against the
# central and backward targets as well.
fd_loss(conv1d_forward_diff(a_channel)[..., 1:], fd_true[:, :1])

tensor(0., device='mps:0', grad_fn=<MseLossBackward0>)

In [292]:
# Display channel 0 of the targets for the first sample.
fd_true[0,0,:]

tensor([ 0.9173, -0.9518,  0.7136,  0.1781, -0.2131,  0.1715, -0.4759, -0.2001],
       device='mps:0')