In [44]:
"""
@author: Zongyi Li
This file is the Fourier Neural Operator for 2D problem such as the Navier-Stokes equation discussed in Section 5.3 in the [paper](https://arxiv.org/pdf/2010.08895.pdf),
which uses a recurrent structure to propagates in time.
"""

import os
os.environ['OMP_NUM_THREADS'] = '2'
os.environ['export OPENBLAS_NUM_THREADS']='2'

import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

import matplotlib.pyplot as plt
from utilities3 import *

import operator
from functools import reduce
from functools import partial

from timeit import default_timer

from Adam import Adam

torch.manual_seed(0)
np.random.seed(0)


In [45]:
################################################################
# 3d fourier layers
################################################################

class SpectralConv3d(nn.Module):
    def __init__(self, in_channels, out_channels, modes1, modes2, modes3):
        super(SpectralConv3d, self).__init__()

        """
        3D Fourier layer. It does FFT, linear transform, and Inverse FFT.    
        """

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.modes1 = modes1 #Number of Fourier modes to multiply, at most floor(N/2) + 1
        self.modes2 = modes2
        self.modes3 = modes3

        self.scale = (1 / (in_channels * out_channels))
        self.weights1 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, self.modes3, dtype=torch.cfloat))
        self.weights2 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, self.modes3, dtype=torch.cfloat))
        self.weights3 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, self.modes3, dtype=torch.cfloat))
        self.weights4 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, self.modes3, dtype=torch.cfloat))

    # Complex multiplication
    def compl_mul3d(self, input, weights):
        # (batch, in_channel, x,y,t ), (in_channel, out_channel, x,y,t) -> (batch, out_channel, x,y,t)
        return torch.einsum("bixyz,ioxyz->boxyz", input, weights)

    def forward(self, x):
        batchsize = x.shape[0]
        #Compute Fourier coeffcients up to factor of e^(- something constant)
        x_ft = torch.fft.rfftn(x, dim=[-3,-2,-1])
        
        # Multiply relevant Fourier modes
        out_ft = torch.zeros(batchsize, self.out_channels, x.size(-3), x.size(-2), x.size(-1)//2 + 1, dtype=torch.cfloat, device=x.device)
        out_ft[:, :, :self.modes1, :self.modes2, :self.modes3] = \
            self.compl_mul3d(x_ft[:, :, :self.modes1, :self.modes2, :self.modes3], self.weights1)
        out_ft[:, :, -self.modes1:, :self.modes2, :self.modes3] = \
            self.compl_mul3d(x_ft[:, :, -self.modes1:, :self.modes2, :self.modes3], self.weights2)
        out_ft[:, :, :self.modes1, -self.modes2:, :self.modes3] = \
            self.compl_mul3d(x_ft[:, :, :self.modes1, -self.modes2:, :self.modes3], self.weights3)
        out_ft[:, :, -self.modes1:, -self.modes2:, :self.modes3] = \
            self.compl_mul3d(x_ft[:, :, -self.modes1:, -self.modes2:, :self.modes3], self.weights4)

        #Return to physical space
        x = torch.fft.irfftn(out_ft, s=(x.size(-3), x.size(-2), x.size(-1)))
        return x

class FNO3d(nn.Module):
    def __init__(self, modes1, modes2, modes3, width):
        super(FNO3d, self).__init__()

        """
        The overall network. It contains 4 layers of the Fourier layer.
        1. Lift the input to the desire channel dimension by self.fc0 .
        2. 4 layers of the integral operators u' = (W + K)(u).
            W defined by self.w; K defined by self.conv .
        3. Project from the channel space to the output space by self.fc1 and self.fc2 .
        
        input: the solution of the first 10 timesteps + 3 locations (u(1, x, y), ..., u(10, x, y),  x, y, t). It's a constant function in time, except for the last index.
        input shape: (batchsize, x=64, y=64, t=40, c=13)
        output: the solution of the next 40 timesteps
        output shape: (batchsize, x=64, y=64, t=40, c=1)
        """

        self.modes1 = modes1
        self.modes2 = modes2
        self.modes3 = modes3
        self.width = width
        self.padding = 6 # pad the domain if input is non-periodic
        #self.fc0 = nn.Linear(13, self.width)
        self.fc0 = nn.Linear(4, self.width)
        # input channel is 12: the solution of the first 10 timesteps + 3 locations (u(1, x, y), ..., u(10, x, y),  x, y, t)

        self.conv0 = SpectralConv3d(self.width, self.width, self.modes1, self.modes2, self.modes3)
        self.conv1 = SpectralConv3d(self.width, self.width, self.modes1, self.modes2, self.modes3)
        self.conv2 = SpectralConv3d(self.width, self.width, self.modes1, self.modes2, self.modes3)
        self.conv3 = SpectralConv3d(self.width, self.width, self.modes1, self.modes2, self.modes3)
        self.w0 = nn.Conv3d(self.width, self.width, 1)
        self.w1 = nn.Conv3d(self.width, self.width, 1)
        self.w2 = nn.Conv3d(self.width, self.width, 1)
        self.w3 = nn.Conv3d(self.width, self.width, 1)
        self.bn0 = torch.nn.BatchNorm3d(self.width)
        self.bn1 = torch.nn.BatchNorm3d(self.width)
        self.bn2 = torch.nn.BatchNorm3d(self.width)
        self.bn3 = torch.nn.BatchNorm3d(self.width)

        self.fc1 = nn.Linear(self.width, 128)
        self.fc2 = nn.Linear(128, 1)

    def forward(self, x):
        grid = self.get_grid(x.shape, x.device)
        x = torch.cat((x, grid), dim=-1)
        x = self.fc0(x)
        x = x.permute(0, 4, 1, 2, 3)
        x = F.pad(x, [0,self.padding]) # pad the domain if input is non-periodic

        x1 = self.conv0(x)
        x2 = self.w0(x)
        x = x1 + x2
        x = F.gelu(x)

        x1 = self.conv1(x)
        x2 = self.w1(x)
        x = x1 + x2
        x = F.gelu(x)

        x1 = self.conv2(x)
        x2 = self.w2(x)
        x = x1 + x2
        x = F.gelu(x)

        x1 = self.conv3(x)
        x2 = self.w3(x)
        x = x1 + x2

        x = x[..., :-self.padding]
        x = x.permute(0, 2, 3, 4, 1) # pad the domain if input is non-periodic
        x = self.fc1(x)
        x = F.gelu(x)
        x = self.fc2(x)
        return x

    def get_grid(self, shape, device):
        batchsize, size_x, size_y, size_z = shape[0], shape[1], shape[2], shape[3]
        gridx = torch.tensor(np.linspace(0, 1, size_x), dtype=torch.float)
        gridx = gridx.reshape(1, size_x, 1, 1, 1).repeat([batchsize, 1, size_y, size_z, 1])
        gridy = torch.tensor(np.linspace(0, 1, size_y), dtype=torch.float)
        gridy = gridy.reshape(1, 1, size_y, 1, 1).repeat([batchsize, size_x, 1, size_z, 1])
        gridz = torch.tensor(np.linspace(0, 1, size_z), dtype=torch.float)
        gridz = gridz.reshape(1, 1, 1, size_z, 1).repeat([batchsize, size_x, size_y, 1, 1])
        return torch.cat((gridx, gridy, gridz), dim=-1).to(device)

In [46]:
################################################################
# configs
################################################################
DATA_PATH = 'Solutions/solutions_total.npy'

# currently data are 110 samples
ntrain = 80
ntest = 30

modes = 8
width = 20

batch_size = 10

epochs = 100 #500
learning_rate = 0.001
scheduler_step = 100
scheduler_gamma = 0.5

print(epochs, learning_rate, scheduler_step, scheduler_gamma)

path = 'nvs_cylinder_3d_ep100_Tin1'
path_model = 'model/'+path
path_train_err = 'results/'+path+'train.txt'
path_test_err = 'results/'+path+'test.txt'
path_image = 'image/'+path

runtime = np.zeros(2, )
t1 = default_timer()

S1 = 78
S2 = 438
T_in = 1
T = 19
step = 1

100 0.001 100 0.5


In [47]:
################################################################
# load data
################################################################
data_gen = np.load(DATA_PATH)

train_a = torch.tensor(data_gen[:ntrain,:,:,:T_in], dtype=torch.float)
train_u = torch.tensor(data_gen[:ntrain,:,:,T_in:T+T_in], dtype=torch.float)

test_a = torch.tensor(data_gen[-ntest:,:,:,:T_in], dtype=torch.float)
test_u = torch.tensor(data_gen[-ntest:,:,:,T_in:T+T_in], dtype=torch.float)

print(train_u.shape)
print(test_u.shape)
assert (S1 == train_u.shape[-3])
assert (S2 == train_u.shape[-2])
assert (T == train_u.shape[-1])


a_normalizer = UnitGaussianNormalizer(train_a)
train_a = a_normalizer.encode(train_a)
test_a = a_normalizer.encode(test_a)

y_normalizer = UnitGaussianNormalizer(train_u)
train_u = y_normalizer.encode(train_u)

train_a = train_a.reshape(ntrain,S1,S2,1,T_in).repeat([1,1,1,T,1])
test_a = test_a.reshape(ntest,S1,S2,1,T_in).repeat([1,1,1,T,1])

# train_a = train_a.reshape(ntrain,S1,S2,T_in)
# test_a = test_a.reshape(ntest,S1,S2,T_in)

train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(train_a, train_u), batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(test_a, test_u), batch_size=batch_size, shuffle=False)

t2 = default_timer()

print('preprocessing finished, time used:', t2-t1)
device = torch.device('cuda')

torch.Size([80, 78, 438, 19])
torch.Size([30, 78, 438, 19])
preprocessing finished, time used: 0.5407307169807609


In [48]:
################################################################
# training and evaluation
################################################################
model = FNO3d(modes, modes, modes, width).cuda()

print(count_params(model))
optimizer = Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=scheduler_step, gamma=scheduler_gamma)


myloss = LpLoss(size_average=False)
y_normalizer.cuda()
for ep in range(epochs):
    model.train()
    t1 = default_timer()
    train_mse = 0
    train_l2 = 0
    for x, y in train_loader:
        x, y = x.cuda(), y.cuda()

        optimizer.zero_grad()
        out = model(x).view(batch_size, S1, S2, T)

        mse = F.mse_loss(out, y, reduction='mean')
        # mse.backward()

        y = y_normalizer.decode(y)
        out = y_normalizer.decode(out)
        l2 = myloss(out.view(batch_size, -1), y.view(batch_size, -1))
        l2.backward()

        optimizer.step()
        train_mse += mse.item()
        train_l2 += l2.item()

    scheduler.step()

    model.eval()
    test_l2 = 0.0
    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.cuda(), y.cuda()

            out = model(x).view(batch_size, S1, S2, T)
            out = y_normalizer.decode(out)
            test_l2 += myloss(out.view(batch_size, -1), y.view(batch_size, -1)).item()

    train_mse /= len(train_loader)
    train_l2 /= ntrain
    test_l2 /= ntest

    t2 = default_timer()
    print(ep, t2-t1, train_mse, train_l2, test_l2)
torch.save(model, path_model)

6558357
0 10.900625198992202 0.9827209897339344 0.15418877303600312 0.15433122714360556
1 10.657516494014999 1.0067685525864363 0.1475578159093857 0.15104235808054606
2 10.572039994993247 1.0512401759624481 0.1472574219107628 0.15086035331090292
3 10.556920550006907 1.047810886055231 0.1469679519534111 0.15082954565684
4 10.560756140999729 1.0284593179821968 0.1467852920293808 0.15107063055038453
5 10.392515241983347 1.0271881707012653 0.14656072109937668 0.15024382670720418
6 10.246397561975755 1.037585249170661 0.1460908345878124 0.14987285931905112
7 10.491902957001003 1.0539716016501188 0.14601044952869416 0.14880243937174478
8 10.417820462025702 1.0650967918336391 0.14550287425518035 0.14820261001586915
9 10.29135681700427 1.04983033798635 0.14538637548685074 0.14762118260065715
10 10.288151417014888 1.0469951704144478 0.14531975239515305 0.14783263206481934
11 10.403090606996557 1.0560277234762907 0.14495411962270738 0.14681551853815714
12 10.21029408898903 1.069298017770052 0.14

KeyboardInterrupt: 

In [None]:
pred = torch.zeros(test_u.shape)

index = 0
# model = torch.load("model/nvs_cylinder_3d_ep100")
test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(test_a, test_u), batch_size=1, shuffle=False)
first_output = 0
with torch.no_grad():
    for x, y in test_loader:
        test_l2 = 0
        if index == 1:
            first_output = y.clone()
        x, y = x.cuda(), y.cuda()
        out = model(x)
        out = y_normalizer.decode(out[:,:,:,0])
        pred[index] = out

        test_l2 += myloss(out.view(1, -1), y.view(1, -1)).item()
        print(index, test_l2)
        index = index + 1


# for i in range(15,18):
#     cp = plt.imshow(pred[1,:,:,i])
#     plt.colorbar(cp)
#     plt.show()

#     cp = plt.imshow(first_output[0,:,:,i])
#     plt.colorbar(cp)
#     plt.show()

#     cp = plt.imshow(abs(pred[1,:,:,i]-first_output[0,:,:,i]))
#     plt.colorbar(cp)
#     plt.show()

# cp = plt.imshow(pred[0,:,:,9])
# plt.colorbar(cp)
# plt.show()
# cp = plt.imshow(pois_output[0,:,:,0])
# plt.colorbar(cp)
# plt.show()

# cp = plt.imshow(abs(pred[index][:,:,0] - pois_output[0,:,:,0]))
# plt.colorbar(cp)
# plt.show()


0 0.11792492866516113
1 0.09230264276266098
2 0.08759330958127975
3 0.1204468160867691
4 0.08817866444587708
5 0.049013327807188034
6 0.08691585808992386
7 0.07559073716402054
8 0.10876954346895218
9 0.09856343269348145
10 0.06670535355806351
11 0.08798898011445999
12 0.0878012627363205
13 0.09502517431974411
14 0.09636951237916946
15 0.08967442065477371
16 0.08712951093912125
17 0.1478944569826126
18 0.1720321625471115
19 0.11056836694478989
20 0.08735814690589905
21 0.08703753352165222
22 0.06481015682220459
23 0.0888841301202774
24 0.19038017094135284
25 0.08927062153816223
26 0.06854244321584702
27 0.10525257140398026
28 0.08702849596738815
29 0.08736784756183624
