Ref Paper: [FOURIER NEURAL OPERATOR FOR PARAMETRIC PARTIAL DIFFERENTIAL EQUATIONS](https://arxiv.org/pdf/2010.08895) \

Ref Code: https://github.com/neuraloperator/neuraloperator/blob/master/fourier_3d.py

# Imports

In [48]:
import os
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.lines import Line2D
import functools
import operator
import pandas as pd
from tqdm import tqdm
import h5py
import math
import copy
import scipy
import pickle
from timeit import default_timer
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset, Dataset
import torch.nn.functional 
from torch.optim import Adam
from torch.utils.tensorboard import SummaryWriter
from torchsummary import summary
import scipy.io

%matplotlib inline

In [69]:
# Binary (IEC) unit suffixes indexed by the power of 1024 they represent.
units = {
    0: 'B',
    1: 'KiB',
    2: 'MiB',
    3: 'GiB',
    4: 'TiB'
}


def format_mem(x):
    """
    Scale a byte count 'x' to the largest matching binary unit.

    Returns a ``(value, unit)`` tuple. ``value`` is ``x`` divided by the
    appropriate power of 1024; it is returned as an ``int`` when the scaled
    value is integral and otherwise rounded to at most 2 decimal places.

    The magnitude of ``value`` lies in [0, 1024) for every input below
    1024 ** (largest unit index + 1). Larger inputs are expressed in the
    largest unit available in ``units`` (so ``value`` may exceed 1024)
    instead of failing with a ``KeyError`` on a missing table entry.
    """
    magnitude = abs(x)
    if magnitude < 1024:
        return round(x, 2), 'B'

    # Pick the power of 1024 by exact comparison rather than a float
    # logarithm: log2 can round up just below a unit boundary, and its
    # result is not bounded by the size of the unit table.
    largest_scale = max(units)
    scale = 0
    next_threshold = 1024
    while scale < largest_scale and magnitude >= next_threshold:
        scale += 1
        next_threshold *= 1024

    scaled_x = x / (1024 ** scale)
    unit = units[scale]

    if int(scaled_x) == scaled_x:
        return int(scaled_x), unit

    # rounding leads to 2 or fewer decimal places, as required
    return round(scaled_x, 2), unit


def format_tensor_size(x):
    """Render the byte count 'x' as '<value><unit>' using format_mem."""
    return '{}{}'.format(*format_mem(x))


In [70]:
class SpectralConv3d(nn.Module):
    """
    3D Fourier layer. It does FFT, linear transform, and Inverse FFT.

    Only the lowest ``modes1`` x ``modes2`` x ``modes3`` frequencies in each
    of the four (+/- first axis, +/- second axis) corners of the real-FFT
    spectrum are transformed; every other frequency is zeroed.
    """

    def __init__(self, in_channels, out_channels, modes1, modes2, modes3):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        # Number of Fourier modes kept along each of the last three axes.
        self.modes1 = modes1
        self.modes2 = modes2
        self.modes3 = modes3

        self.scale = 1 / (in_channels * out_channels)

        # One complex weight tensor per spectrum corner, created in the
        # order weights1..weights4.
        weight_shape = (in_channels, out_channels, modes1, modes2, modes3)
        for weight_name in ("weights1", "weights2", "weights3", "weights4"):
            initial = self.scale * torch.rand(*weight_shape, dtype=torch.cfloat)
            setattr(self, weight_name, nn.Parameter(initial))

    # Complex multiplication
    def compl_mul3d(self, input, weights):
        # (batch, in_channel, x, y, t), (in_channel, out_channel, x, y, t)
        #   -> (batch, out_channel, x, y, t), summing over in_channel.
        return torch.einsum("bixyz,ioxyz->boxyz", input, weights)

    def forward(self, x):
        # x: [batch, in_channels, size_1, size_2, size_3]
        batch = x.shape[0]
        size_1, size_2, size_3 = x.size(-3), x.size(-2), x.size(-1)

        # Real FFT over the last three axes; the last axis of the result
        # has length size_3 // 2 + 1.
        spectrum = torch.fft.rfftn(x, dim=[-3, -2, -1])

        out_ft = torch.zeros(
            batch,
            self.out_channels,
            size_1,
            size_2,
            size_3 // 2 + 1,
            dtype=torch.cfloat,
            device=x.device,
        )

        head_1, tail_1 = slice(None, self.modes1), slice(-self.modes1, None)
        head_2, tail_2 = slice(None, self.modes2), slice(-self.modes2, None)
        head_3 = slice(None, self.modes3)

        # Corner blocks of retained modes, paired with their weights and
        # written in the same order as weights1..weights4.
        corners = (
            (head_1, head_2, self.weights1),
            (tail_1, head_2, self.weights2),
            (head_1, tail_2, self.weights3),
            (tail_1, tail_2, self.weights4),
        )
        for axis_1, axis_2, weights in corners:
            out_ft[:, :, axis_1, axis_2, head_3] = self.compl_mul3d(
                spectrum[:, :, axis_1, axis_2, head_3], weights
            )

        # Return to physical space with the original spatial sizes.
        return torch.fft.irfftn(out_ft, s=(size_1, size_2, size_3))

class MLP(nn.Module):
    """
    Pointwise two-layer perceptron built from 1x1x1 3-D convolutions.

    Applies ``Conv3d(in_channels -> mid_channels, kernel 1)``, a GELU
    non-linearity, then ``Conv3d(mid_channels -> out_channels, kernel 1)``.
    With kernel size 1 only the channel axis is mixed, so an input of shape
    ``[batch, in_channels, d1, d2, d3]`` yields
    ``[batch, out_channels, d1, d2, d3]``.
    """

    def __init__(self, in_channels, out_channels, mid_channels):
        super(MLP, self).__init__()
        self.mlp1 = nn.Conv3d(in_channels, mid_channels, 1)
        self.mlp2 = nn.Conv3d(mid_channels, out_channels, 1)

    def forward(self, x):
        # [batch, in_channels, ...] -> [batch, mid_channels, ...]
        x = self.mlp1(x)
        # The functional API lives in the lowercase module
        # `torch.nn.functional`; `torch.nn.Functional` does not exist.
        x = torch.nn.functional.gelu(x)
        # [batch, mid_channels, ...] -> [batch, out_channels, ...]
        x = self.mlp2(x)
        return x

class FNO3d(nn.Module):
    """
    The overall network. It contains 4 layers of the Fourier layer.

    1. Lift the input to the desired channel dimension by self.p .
    2. 4 layers of the integral operators u' = (W + K)(u).
        W defined by self.w0..w3; K defined by self.conv0..conv3 followed
        by self.mlp0..mlp3 .
    3. Project from the channel space to the output space by self.q .

    Args:
        modes1, modes2, modes3: number of Fourier modes kept along the
            three non-channel axes by every spectral convolution.
        width: channel count used inside the network.
        in_channels: number of channels fed to the lifting layer, i.e. the
            channels of the tensor passed to ``forward`` plus the 3 grid
            coordinates appended there. Defaults to 13 (10 + 3).
        padding: number of zeros appended to the end of the last axis
            before the Fourier layers and removed afterwards (useful when
            the input is non-periodic). 0 disables padding. Defaults to 6.

    forward input shape:  (batchsize, size_x, size_y, size_t, in_channels - 3)
    forward output shape: (batchsize, size_x, size_y, size_t, 1)
    """

    def __init__(self, modes1, modes2, modes3, width, in_channels=13, padding=6):
        super(FNO3d, self).__init__()

        if padding < 0:
            raise ValueError(f"padding must be non-negative, got {padding}")

        self.modes1 = modes1
        self.modes2 = modes2
        self.modes3 = modes3
        self.width = width
        self.padding = padding  # pad the domain if input is non-periodic

        # Lifting layer: acts on the last (channel) axis.
        self.p = nn.Linear(in_channels, self.width)
        self.conv0 = SpectralConv3d(self.width, self.width, self.modes1, self.modes2, self.modes3)
        self.conv1 = SpectralConv3d(self.width, self.width, self.modes1, self.modes2, self.modes3)
        self.conv2 = SpectralConv3d(self.width, self.width, self.modes1, self.modes2, self.modes3)
        self.conv3 = SpectralConv3d(self.width, self.width, self.modes1, self.modes2, self.modes3)
        self.mlp0 = MLP(self.width, self.width, self.width)
        self.mlp1 = MLP(self.width, self.width, self.width)
        self.mlp2 = MLP(self.width, self.width, self.width)
        self.mlp3 = MLP(self.width, self.width, self.width)
        self.w0 = nn.Conv3d(self.width, self.width, 1)
        self.w1 = nn.Conv3d(self.width, self.width, 1)
        self.w2 = nn.Conv3d(self.width, self.width, 1)
        self.w3 = nn.Conv3d(self.width, self.width, 1)
        # Projection head: width -> 4 * width -> 1 output channel.
        self.q = MLP(self.width, 1, self.width * 4)

    def forward(self, x):
        # Append the 3 normalized grid coordinates as extra channels.
        grid = self.get_grid(x.shape, x.device)  # [batch, size_x, size_y, size_t, 3]
        x = torch.cat((x, grid), dim=-1)         # [batch, size_x, size_y, size_t, c + 3]
        x = self.p(x)                            # [batch, size_x, size_y, size_t, width]

        # Channels first for the convolutional layers.
        x = x.permute(0, 4, 1, 2, 3)             # [batch, width, size_x, size_y, size_t]
        if self.padding > 0:
            # A 2-element pad list pads only the last axis: (left, right).
            x = torch.nn.functional.pad(x, [0, self.padding])

        fourier_blocks = (
            (self.conv0, self.mlp0, self.w0),
            (self.conv1, self.mlp1, self.w1),
            (self.conv2, self.mlp2, self.w2),
            (self.conv3, self.mlp3, self.w3),
        )
        last_index = len(fourier_blocks) - 1
        for index, (spectral, pointwise, skip) in enumerate(fourier_blocks):
            # Spectral path followed by a pointwise MLP, plus a linear skip.
            x = pointwise(spectral(x)) + skip(x)
            # No activation after the final Fourier block.
            if index != last_index:
                x = torch.nn.functional.gelu(x)

        if self.padding > 0:
            # Guarded because x[..., :-0] would be an empty slice.
            x = x[..., :-self.padding]           # [batch, width, size_x, size_y, size_t]

        x = self.q(x)                            # [batch, 1, size_x, size_y, size_t]

        # Back to channels-last layout.
        x = x.permute(0, 2, 3, 4, 1)             # [batch, size_x, size_y, size_t, 1]
        return x

    def get_grid(self, shape, device):
        """
        Build normalized coordinates for the three non-batch leading axes.

        ``shape`` is indexed at positions 0..3 as (batch, size_x, size_y,
        size_z). Returns a float tensor of shape
        ``[batch, size_x, size_y, size_z, 3]`` on ``device`` whose last axis
        holds evenly spaced values in [0, 1] along x, y and z respectively.
        """
        batchsize, size_x, size_y, size_z = shape[0], shape[1], shape[2], shape[3]
        gridx = torch.tensor(np.linspace(0, 1, size_x), dtype=torch.float)
        gridx = gridx.reshape(1, size_x, 1, 1, 1).repeat([batchsize, 1, size_y, size_z, 1])
        gridy = torch.tensor(np.linspace(0, 1, size_y), dtype=torch.float)
        gridy = gridy.reshape(1, 1, size_y, 1, 1).repeat([batchsize, size_x, 1, size_z, 1])
        gridz = torch.tensor(np.linspace(0, 1, size_z), dtype=torch.float)
        gridz = gridz.reshape(1, 1, 1, size_z, 1).repeat([batchsize, size_x, size_y, 1, 1])

        return torch.cat((gridx, gridy, gridz), dim=-1).to(device)

    def print_size(self):
        """
        Print the total parameter count and memory, and return a per-parameter table.

        The returned DataFrame has one row per parameter tensor with columns
        'ParamSize' (shape, with a trailing 2 appended for complex tensors),
        'NParams' (``numel()``; a complex element counts once) and
        'Memory(KB)' (bytes / 1000).
        """
        properties = []

        for param in self.parameters():
            param_shape = list(param.size() + (2,) if param.is_complex() else param.size())
            memory_kb = (param.data.element_size() * param.numel()) / 1000
            properties.append([param_shape, param.numel(), memory_kb])

        elementFrame = pd.DataFrame(properties, columns = ['ParamSize', 'NParams', 'Memory(KB)'])

        print(f'Total number of model parameters: {elementFrame["NParams"].sum()} with (~{format_tensor_size(elementFrame["Memory(KB)"].sum()*1000)})')
        return elementFrame
    
    


In [74]:
# Sanity check: a 2-element pad list pads only the last axis, so appending
# 6 zeros grows it from 40 to 46.
x = torch.nn.functional.pad(torch.rand(1, 20, 64, 64, 40), [0, 6])
x.shape

torch.Size([1, 20, 64, 64, 46])

In [71]:
# 8 retained Fourier modes per axis, 20 hidden channels.
model = FNO3d(modes1=8, modes2=8, modes3=8, width=20)

In [72]:
# Prints the total parameter count and memory; the returned per-parameter
# DataFrame is displayed as the cell output.
model.print_size()

Total number of model parameters: 3283881 with (~25.03MiB)


Unnamed: 0,ParamSize,NParams,Memory(KB)
0,"[20, 13]",260,1.04
1,[20],20,0.08
2,"[20, 20, 8, 8, 8, 2]",204800,1638.4
3,"[20, 20, 8, 8, 8, 2]",204800,1638.4
4,"[20, 20, 8, 8, 8, 2]",204800,1638.4
5,"[20, 20, 8, 8, 8, 2]",204800,1638.4
6,"[20, 20, 8, 8, 8, 2]",204800,1638.4
7,"[20, 20, 8, 8, 8, 2]",204800,1638.4
8,"[20, 20, 8, 8, 8, 2]",204800,1638.4
9,"[20, 20, 8, 8, 8, 2]",204800,1638.4
