In [1]:
import numpy as np
import matplotlib.pyplot as plt
from timeit import default_timer
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
import torch.nn as nn
import torch.nn.functional as F

#SOBOLEV NORM

In [3]:
# Sobolev norm (HS norm)
# where we also compare the numerical derivatives between the output and target
class HsLoss(object):
    def __init__(self, d=2, p=2, k=1, a=None, group=False, size_average=True, reduction=True):
        super(HsLoss, self).__init__()

        #Dimension and Lp-norm type are postive
        assert d > 0 and p > 0

        self.d = d
        self.p = p
        self.k = k
        self.balanced = group
        self.reduction = reduction
        self.size_average = size_average

        if a == None:
            a = [1,] * k
        self.a = a

    def rel(self, x, y):
        num_examples = x.size()[0]
        diff_norms = torch.norm(x.reshape(num_examples,-1) - y.reshape(num_examples,-1), self.p, 1)
        y_norms = torch.norm(y.reshape(num_examples,-1), self.p, 1)
        if self.reduction:
            if self.size_average:
                return torch.mean(diff_norms/y_norms)
            else:
                return torch.sum(diff_norms/y_norms)
        return diff_norms/y_norms

    def __call__(self, x, y, a=None):
        nx = x.size()[1]
        ny = x.size()[2]
        k = self.k
        balanced = self.balanced
        a = self.a
        x = x.view(x.shape[0], nx, ny, -1)
        y = y.view(y.shape[0], nx, ny, -1)

        k_x = torch.cat((torch.arange(start=0, end=nx//2, step=1),
                         torch.arange(start=-nx//2, end=0, step=1)), 0).reshape(nx,1).repeat(1,ny)
        k_y = torch.cat((torch.arange(start=0, end=ny//2, step=1),
                         torch.arange(start=-ny//2, end=0, step=1)), 0).reshape(1,ny).repeat(nx,1)
        k_x = torch.abs(k_x).reshape(1,nx,ny,1).to(x.device)
        k_y = torch.abs(k_y).reshape(1,nx,ny,1).to(x.device)

        x = torch.fft.fftn(x, dim=[1, 2])
        y = torch.fft.fftn(y, dim=[1, 2])

        if balanced==False:
            weight = 1
            if k >= 1:
                weight += a[0]**2 * (k_x**2 + k_y**2)
            if k >= 2:
                weight += a[1]**2 * (k_x**4 + 2*k_x**2*k_y**2 + k_y**4)
            weight = torch.sqrt(weight)
            loss = self.rel(x*weight, y*weight)
        else:
            loss = self.rel(x, y)
            if k >= 1:
                weight = a[0] * torch.sqrt(k_x**2 + k_y**2)
                loss += self.rel(x*weight, y*weight)
            if k >= 2:
                weight = a[1] * torch.sqrt(k_x**4 + 2*k_x**2*k_y**2 + k_y**4)
                loss += self.rel(x*weight, y*weight)
            loss = loss / (k+1)

        return loss


In [4]:
x = torch.randn(20,32,128,128)
y = torch.randn(20,32,128,128)

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import os
PATH = '/content/drive/MyDrive/Parallel_PDE_project/fourier_neural_operator'
os.chdir(PATH)
import sys  
sys.path.insert(0, '/content/drive/MyDrive/Parallel_PDE_project/fourier_neural_operator/Notebooks')

In [None]:
from utils import *
from Adam import * 
from FNO import Lifting, Proj, set_activ

In [None]:
def anderson(f, x0, m=5, lam=1e-4, max_iter=50, tol=1e-2, beta = 1.0):
    """ Anderson acceleration for fixed point iteration. """
    bsz, d, H, W = x0.shape
    X = torch.zeros(bsz, m, d*H*W, dtype=x0.dtype, device=x0.device)
    F = torch.zeros(bsz, m, d*H*W, dtype=x0.dtype, device=x0.device)
    X[:,0], F[:,0] = x0.reshape(bsz, -1), f(x0).view(bsz, -1)
    X[:,1], F[:,1] = F[:,0], f(F[:,0].view_as(x0)).view(bsz, -1)
    
    H = torch.zeros(bsz, m+1, m+1, dtype=x0.dtype, device=x0.device)
    H[:,0,1:] = H[:,1:,0] = 1
    y = torch.zeros(bsz, m+1, 1, dtype=x0.dtype, device=x0.device)
    y[:,0] = 1
    
    res = []
    for k in range(2, max_iter):
        n = min(k, m)
        G = F[:,:n]-X[:,:n]
        H[:,1:n+1,1:n+1] = torch.bmm(G,G.transpose(1,2)) + lam*torch.eye(n, dtype=x0.dtype,device=x0.device)[None]
        alpha = torch.solve(y[:,:n+1], H[:,:n+1,:n+1])[0][:, 1:n+1, 0]   # (bsz x n)
        
        X[:,k%m] = beta * (alpha[:,None] @ F[:,:n])[:,0] + (1-beta)*(alpha[:,None] @ X[:,:n])[:,0]
        F[:,k%m] = f(X[:,k%m].view_as(x0)).view(bsz, -1)
        res.append((F[:,k%m] - X[:,k%m]).norm().item()/(1e-5 + F[:,k%m].norm().item()))
        if (res[-1] < tol):
            break
    return X[:,k%m].view_as(x0), res

In [None]:
import torch.autograd as autograd

class DEQFixedPoint(nn.Module):
    def __init__(self, f, solver, **kwargs):
        super().__init__()
        self.f = f
        self.solver = solver
        self.kwargs = kwargs
        
    def forward(self, x):
        # compute forward pass and re-engage autograd tape
        with torch.no_grad():
            z, self.forward_res = self.solver(lambda z : self.f(z, x), torch.zeros_like(x), **self.kwargs)
        z = self.f(z,x)
        
        # set up Jacobian vector product (without additional forward calls)
        z0 = z.clone().detach().requires_grad_()
        f0 = self.f(z0,x)
        def backward_hook(grad):
            g, self.backward_res = self.solver(lambda y : autograd.grad(f0, z0, y, retain_graph=True)[0] + grad,
                                               grad, **self.kwargs)
            return g
                
        z.register_hook(backward_hook)
        return z

In [None]:
################################################################
# Spectral Convolution Layers 
# 1D and 2D
################################################################

class fourier_conv_1d(nn.Module):
  def __init__(self, in_, out_, wavenumber1):
    super(fourier_conv_1d, self).__init__()
    self.in_ = in_
    self.out_ = out_
    self.wavenumber1 = wavenumber1
    self.scale = (1 / (in_ * out_))
    self.weights1 = nn.Parameter(self.scale * torch.rand(in_, out_, self.wavenumber1, dtype= torch.complex128))
   
    # Complex multiplication
  def compl_mul1d(self, input, weights):
    # (batch, in_channel, x ), (in_channel, out_channel, x) -> (batch, out_channel, x)
    return torch.einsum("bix,iox->box", input, weights)

  def forward(self, x):
    batchsize = x.shape[0]
    #Compute Fourier coeffcients up to factor of e^(- something constant)
    x_ft = torch.fft.rfft(x)

    # Multiply relevant Fourier modes
    out_ft = torch.zeros(batchsize, self.out_, x.size(-1)//2 + 1,  device=x.device, dtype= torch.complex128)
    out_ft[:, :, :self.wavenumber1] = self.compl_mul1d(x_ft[:, :, :self.wavenumber1], self.weights1)

    #Return to physical space
    x = torch.fft.irfft(out_ft, n=x.size(-1))
    return x

  def print(self):
    return f'FourierConv1d({self.in_}, {self.out_}, wavenumber={self.wavenumber1})'

class fourier_conv_2d(nn.Module):
  def __init__(self, in_, out_, wavenumber1, wavenumber2):
    super(fourier_conv_2d, self).__init__()
    self.in_ = in_
    self.out_ = out_
    self.wavenumber1 = wavenumber1
    self.wavenumber2 = wavenumber2
    self.scale = (1 / (in_ * out_))
    self.weights1 = nn.Parameter(self.scale * torch.rand(in_, out_, self.wavenumber1, self.wavenumber2, dtype= torch.complex128))
    self.weights2 = nn.Parameter(self.scale * torch.rand(in_, out_, self.wavenumber1, self.wavenumber2, dtype= torch.complex128))

    # Complex multiplication
  def compl_mul2d(self, input, weights):
    # (batch, in_channel, x,y ), (in_channel, out_channel, x,y) -> (batch, out_channel, x,y)
    return torch.einsum("bixy,ioxy->boxy", input, weights)

  def forward(self, x):
    batchsize = x.shape[0]
    #Compute Fourier coeffcients up to factor of e^(- something constant)
    x_ft = torch.fft.rfft2(x)
    # Multiply relevant Fourier modes
    out_ft = torch.zeros(batchsize, self.in_,  x.size(-2), x.size(-1)//2 + 1, dtype= torch.complex128, device=x.device)
    out_ft[:, :, :self.wavenumber1, :self.wavenumber2] = \
        self.compl_mul2d(x_ft[:, :, :self.wavenumber1, :self.wavenumber2], self.weights1)
    out_ft[:, :, -self.wavenumber1:, :self.wavenumber2] = \
        self.compl_mul2d(x_ft[:, :, -self.wavenumber1:, :self.wavenumber2], self.weights2)
    #Return to physical space
    x = torch.fft.irfft2(out_ft, s=(x.size(-2), x.size(-1)))
    return x

  def print(self):
    return f'FourierConv2d({self.in_}, {self.out_}, wavenumber={self.wavenumber1, self.wavenumber2})'


In [None]:
################################################################
# Fourier Layer 
################################################################
class Fourier_layer(nn.Module):
  def __init__(self,  features_, wavenumber, activation = 'relu', is_last = False):
    super(Fourier_layer, self).__init__()
    self.is_last = is_last
    self.activation = activation.lower()
    self.features_ = features_
    self.wavenumber = wavenumber
    self.dim = len(wavenumber)
    self.W =  nn.Conv1d(features_, features_, 1).double() if self.dim==1 else nn.Conv2d(features_, features_, 1).double()
    self.fourier_conv = self.set_conv_dim()
    self.nonlinear = set_activ(activation)

  def set_conv_dim(self):
    if self.dim== 1:
      return  fourier_conv_1d(self.features_, self.features_, *self.wavenumber)
    elif self.dim== 2:
      return  fourier_conv_2d(self.features_, self.features_, *self.wavenumber)
   
  def forward(self, x):
        x1 = self.fourier_conv(x)
        x2 = self.W(x)
        x = x1 + x2
        if self.is_last == True:
          return x
        else:
          x = self.nonlinear(x)
          return x
            
  def __repr__(self):
    with torch.no_grad():
      return self.activation+'('+self.fourier_conv.print() +' + '+ self.W.__repr__()+')'

In [None]:
class ResNetLayerFourier(nn.Module):
    def __init__(self, features_, wavenumber, activation = 'relu'):
        super(ResNetLayerFourier, self).__init__()
        self.dim = len(wavenumber)
        self.activation = activation.lower()
        self.FourierLayer = Fourier_layer(features_, wavenumber, activation = self.activation)
        self.nonlinear = set_activ(activation)
        self.W =nn.Conv1d(features_, features_, 1).double() if self.dim==1 else nn.Conv2d(features_, features_, 1).double()
    def forward(self, z, x):
        return self.nonlinear(self.W(x) +self.FourierLayer(z))

In [None]:
from torch.autograd import gradcheck
# run a very small network with double precision, iterating to high precision
f = ResNetLayerFourier(32,[12,12]).double()
deq = DEQFixedPoint(f, anderson, tol=1e-5, max_iter=10).double()
gradcheck(deq, torch.randn(1,32,64,64).double().requires_grad_(), eps=1e-5, atol=1e-3, check_undefined_grad=False)


In [None]:
################################################################
# Lifting map
################################################################

class Lifting(nn.Module):
  def __init__(self, input, width, activation = 'relu'):
    super().__init__()
    self.fc1 = nn.Linear(input, width//2)
    self.nonlinear =set_activ(activation)
    self.fc2 = nn.Linear(width//2, width)
  def forward(self,x):
    x = self.fc1(x)
    x = self.nonlinear(x)
    x = self.fc2(x)
    return x

################################################################
# Projection map
################################################################

class Proj(nn.Module):
  def __init__(self,width1, width2=1, activation = 'relu'):
    super().__init__()
    self.fc1 = nn.Linear(width1, 128)
    self.fc2 = nn.Linear(128, width2)
    self.nonlinear =set_activ(activation)
  def forward(self,x):
    x = self.fc1(x)
    x = self.nonlinear(x)
    x = self.fc2(x)
    return x

In [None]:
################################################################
# FNO_DEQ map 1D and 2D
################################################################

class FNO_DEQ(nn.Module):
  def __init__(self, wavenumber, features_, 
                    padding = 9, 
                    activation= 'relu',
                    lifting = Lifting, 
                    proj = Proj, 
                    max_iter=100, 
                    beta=2.0):
    super(FNO_DEQ, self).__init__()
    self.wavenumber = wavenumber
    self.dim = len(wavenumber)
    self.activation = activation.lower() 
    self.padding = padding   
    self.features_ =features_
    self.lifting = lifting(self.dim+1, self.features_)
    self.fno = ResNetLayerFourier(features_ = self.features_, 
                                  wavenumber=self.wavenumber, 
                                  activation = self.activation)
    self.deq = DEQFixedPoint(self.fno, anderson, tol=1e-4, max_iter=max_iter, beta=beta)
    self.proj = proj(self.features_, 1)
    
  def forward(self, x):
    grid = self.get_grid2D(x.shape, x.device) if self.dim == 2 else self.get_grid1D(x.shape, x.device)
    x = torch.cat((x, grid), dim=-1)
    ####Lifting Map 
    x = self.lifting(x)
    ###Actual Neural Operator
    x = x.permute(0, 3, 1, 2) if self.dim == 2 else x.permute(0, 2, 1)
    x = F.pad(x, [0,self.padding, 0,self.padding]) if self.dim == 2 else F.pad(x, [0,self.padding])
    x = self.deq(x)
    x = x[..., :-self.padding, :-self.padding] if self.dim == 2 else x[..., :-self.padding]
    x = x.permute(0, 2, 3, 1) if self.dim == 2 else x.permute(0, 2, 1)
    ####Projection Map
    x =self.proj(x)
    return x
    
  def get_grid2D(self, shape, device):
    batchsize, size_x, size_y = shape[0], shape[1], shape[2]
    gridx = torch.tensor(np.linspace(0, 1, size_x), dtype=torch.float)
    gridx = gridx.reshape(1, size_x, 1, 1).repeat([batchsize, 1, size_y, 1])
    gridy = torch.tensor(np.linspace(0, 1, size_y), dtype=torch.float)
    gridy = gridy.reshape(1, 1, size_y, 1).repeat([batchsize, size_x, 1, 1])
    return torch.cat((gridx, gridy), dim=-1).to(device)

  def get_grid1D(self, shape, device):
    batchsize, size_x = shape[0], shape[1]
    gridx = torch.tensor(np.linspace(0, 1, size_x), dtype=torch.float)
    gridx = gridx.reshape(1, size_x, 1).repeat([batchsize, 1, 1])
    return gridx.to(device)

In [None]:
################################################################
# load data and data normalization
################################################################
EqFile = 'DarcyFlow'
PDE_dir =  os.path.join(PATH, 'data', EqFile)
x_train = torch.load( os.path.join(PDE_dir, f'{EqFile}_x_train.pt')).double()
y_train = torch.load(os.path.join(PDE_dir, f'{EqFile}_y_train.pt')).double()
x_test = torch.load(os.path.join(PDE_dir, f'{EqFile}_x_test.pt')).double()
y_test= torch.load(os.path.join(PDE_dir, f'{EqFile}_y_test.pt')).double()

In [None]:
#Parameters
ntrain = x_train.shape[0]
ntest = x_test.shape[0]
s = x_test.shape[-2]
batch_size = 20

In [None]:
x_normalizer = UnitGaussianNormalizer(x_train)
x_train = x_normalizer.encode(x_train)
x_test = x_normalizer.encode(x_test)

y_normalizer = UnitGaussianNormalizer(y_train)
y_train = y_normalizer.encode(y_train)

x_train = x_train.reshape(ntrain,s,s,1)
#x_train = x_train.reshape(ntrain,s,1)
x_test = x_test.reshape(ntest,s,s,1)
#x_test = x_test.reshape(ntest,s,1)

In [None]:
train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_train, y_train), batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, y_test), batch_size=batch_size, shuffle=False)

In [None]:
loss_train = []
loss_test =  []
epoch_vec = []

In [None]:
#NN parameters
activ_vec = ['relu', 'tanh', 'sine', 'gelu']
activ = activ_vec[0]
layers = 5
learning_rate = 0.001

epochs = 100
step_size = epochs//5
gamma = 0.5

wavenumber = [12, 12]
features_ = 32

In [None]:
model = FNO_DEQ(wavenumber = wavenumber,
            features_ =features_, 
            activation = 'relu', 
            lifting = Lifting,
            proj = Proj,
            max_iter = 10).double().to(device)
model

FNO_DEQ(
  (lifting): Lifting(
    (fc1): Linear(in_features=3, out_features=16, bias=True)
    (fc2): Linear(in_features=16, out_features=32, bias=True)
  )
  (fno): ResNetLayerFourier(
    (FourierLayer): relu(FourierConv2d(32, 32, wavenumber=(12, 12)) + Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1)))
    (W): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1))
  )
  (deq): DEQFixedPoint(
    (f): ResNetLayerFourier(
      (FourierLayer): relu(FourierConv2d(32, 32, wavenumber=(12, 12)) + Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1)))
      (W): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1))
    )
  )
  (proj): Proj(
    (fc1): Linear(in_features=32, out_features=128, bias=True)
    (fc2): Linear(in_features=128, out_features=1, bias=True)
  )
)

In [None]:
print(count_params(model))

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-5)
#scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=learning_rate, steps_per_epoch=len(train_loader), epochs=epochs)
scheduler= torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)
myloss = LpLoss(size_average=False)
y_normalizer.to(device)

################################################################
# training and evaluation
for ep in range(epochs//2):
    model.train()
    t1 = default_timer()
    train_l2 = 0
    for x, y in train_loader:
        x, y = x.double().to(device), y.double().to(device)

        optimizer.zero_grad()
        out = model(x).reshape(batch_size, s, s)
        #out = model(x).reshape(batch_size, s)
        out = y_normalizer.decode(out)
        y = y_normalizer.decode(y)

        loss = myloss(out.view(batch_size,-1), y.view(batch_size,-1))
        loss.backward()

        optimizer.step()
        train_l2 += loss.item()
    
    epoch_vec.append(ep)
    #e= epoch_vec[-1]
    #epoch_vec.append(e+1)
    
    model.eval()
    test_l2 = 0.0
    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.double().to(device), y.double().to(device)
    
            out = model(x).reshape(batch_size, s, s)
            #out = model(x).reshape(batch_size, s)
            out = y_normalizer.decode(out)

            test_l2 += myloss(out.view(batch_size,-1), y.view(batch_size,-1)).item()

    train_l2/= ntrain
    test_l2 /= ntest

    loss_train.append(train_l2)
    loss_test.append(test_l2)

    t2 = default_timer()
    print(ep, t2-t1, train_l2, test_l2)

301985


RuntimeError: ignored

In [None]:
model.double()(x)

torch.linalg.solve has its arguments reversed and does not return the LU factorization.
To get the LU factorization see torch.lu, which can be used with torch.lu_solve or torch.lu_unpack.
X = torch.solve(B, A).solution
should be replaced with
X = torch.linalg.solve(A, B) (Triggered internally at  ../aten/src/ATen/native/BatchLinearAlgebra.cpp:766.)


tensor([[[[0.0463],
          [0.0460],
          [0.0458],
          ...,
          [0.0837],
          [0.0839],
          [0.0838]],

         [[0.0463],
          [0.0462],
          [0.0461],
          ...,
          [0.0845],
          [0.0842],
          [0.0843]],

         [[0.0465],
          [0.0463],
          [0.0462],
          ...,
          [0.0852],
          [0.0850],
          [0.0849]],

         ...,

         [[0.1050],
          [0.1052],
          [0.1052],
          ...,
          [0.0554],
          [0.0555],
          [0.0553]],

         [[0.1051],
          [0.1053],
          [0.1055],
          ...,
          [0.0558],
          [0.0555],
          [0.0555]],

         [[0.1052],
          [0.1053],
          [0.1055],
          ...,
          [0.0556],
          [0.0556],
          [0.0557]]],


        [[[0.0463],
          [0.0460],
          [0.0458],
          ...,
          [0.0458],
          [0.0458],
          [0.0458]],

         [[0.0464],
    

In [None]:
x.shape

In [None]:
grid = model.get_grid(x.shape, x.device)
grid.shape
#x = torch.cat((x, grid), dim=-1)
#x =model.lifting(x)


In [None]:
grid.shape

In [None]:
x.shape

In [None]:
    ####Lifting Map 
    x = self.lifting(x)
    ###Actual Neural Operator
    x = x.permute(0, 3, 1, 2) if self.dim == 2 else x.permute(0, 2, 1)
    x = F.pad(x, [0,self.padding, 0,self.padding]) if self.dim == 2 else F.pad(x, [0,self.padding])
    x = self.fno(x)
    x = x[..., :-self.padding, :-self.padding] if self.dim == 2 else x[..., :-self.padding]
    x = x.permute(0, 2, 3, 1) if self.dim == 2 else x.permute(0, 2, 1)
    ####Projection Map
    x =self.proj(x)