# Navier-Stokes Equations with a Fourier Neural Operator

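In this notebook, we learn the solution operator of the 2D Navier-Stokes equations with a Fourier Neural Operator (FNO). Given the first 10 time steps of a trajectory, the model predicts the remaining ones. In vorticity form, the equations read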

$$
\begin{align*}
\partial_t w(x,t) + u(x,t) \cdot \nabla w(x,t) &= \nu \Delta w(x,t) + f(x)\\
\nabla \cdot u(x,t) &= 0\\
w(x,0) &= w_0(x)
\end{align*}
$$

where:

- $w(x,t)$ is the vorticity.
- $u(x,t)$ is the velocity field, which is divergence-free by the second equation.
- $\nu$ is the viscosity.
- $f(x)$ is a time-independent external forcing.
- $w_0(x)$ is the initial vorticity.

In [1]:
# Setup paths to import modules
import sys
import os

# Make the project-local packages (FNO, losses, training, utilities) importable
# by putting the directory one level above the working directory first on sys.path.
parent_dir = os.path.abspath(os.pardir)
needs_insert = parent_dir not in sys.path
if needs_insert:
    sys.path.insert(0, parent_dir)

# Import FNO model from PyTorch implementation
from FNO.PyTorch.fno import FNO

# Import loss function
from losses.lploss import LpLoss

# Import training utilities
from training.train import train_model

# Import data utilities
from utilities.utils import MatlabFileReader

import matplotlib.pyplot as plt
import torch
import torch.nn as nn

# Create a wrapper for FNO to match the expected FNO2DTime interface
class FNO2DTime(nn.Module):
    """
    Thin wrapper that configures the project ``FNO`` for 2D + time problems
    such as the Navier-Stokes example in this notebook.

    Parameters
    ----------
    modes : int or sequence of int, default 12
        Fourier modes per spatial dimension. An int is expanded to
        ``[modes, modes]``; anything else is passed to ``FNO`` unchanged.
    width : int, default 32
        Used for ``lifting_channels``, ``projection_channels`` and
        ``mid_channels`` of the underlying ``FNO``.
    layers : int, default 4
        Passed to ``FNO`` as ``num_fourier_layers``.
    **kwargs
        Extra keyword arguments for ``FNO``. A key given here overrides the
        wrapper's default for it: ``in_channels`` (12), ``out_channels`` (1),
        ``activation`` (``nn.GELU()``), and the width/layer-derived settings.
    """

    def __init__(self, modes=12, width=32, layers=4, **kwargs):
        super().__init__()

        # Expand a scalar mode count to one entry per spatial dimension (x, y).
        modes_list = [modes, modes] if isinstance(modes, int) else modes

        fno_kwargs = {
            'modes': modes_list,
            'num_fourier_layers': layers,
            # Default assumes 10 input time steps + 2 grid coordinates = 12.
            # NOTE(review): the notebook's loaders supply only the 10 raw time
            # steps; confirm whether FNO appends the grid coordinates itself.
            'in_channels': 12,
            # A single output channel by default; pass out_channels explicitly
            # to predict several future time steps at once.
            'out_channels': 1,
            'lifting_channels': width,
            'projection_channels': width,
            'mid_channels': width,
            'activation': nn.GELU(),
        }
        # Merge caller overrides instead of passing them alongside the defaults.
        # Otherwise FNO(in_channels=..., **kwargs) with 'in_channels' also in
        # kwargs raises "got multiple values for keyword argument".
        fno_kwargs.update(kwargs)
        self.fno = FNO(**fno_kwargs)

    def forward(self, x):
        """Apply the wrapped FNO to ``x`` and return its output unchanged."""
        return self.fno(x)

    def cuda(self, device=None):
        """
        Move the module to CUDA when a GPU is available; otherwise return it
        unchanged so CPU-only machines can still run the notebook.

        ``device`` is forwarded to ``nn.Module.cuda`` so the standard
        signature keeps working.
        """
        if torch.cuda.is_available():
            return super().cuda(device)
        return self


## Data

In [3]:
# Pick the compute device first; the reader is asked to load tensors onto it.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

# NOTE(review): the file name suggests viscosity 1e-5, 1200 samples and
# 20 time steps; confirm against the dataset's documentation.
data_path = 'NavierStokes_V1e-5_N1200_T20.mat'
mat_data = MatlabFileReader(data_path, device=device, to_tensor=True)
mat_data

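### Viewing the data

Plot a single snapshot (sample 0, time index 4) of the field stored under key `u` to sanity-check the loaded array and its shape.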

In [5]:
u_data = mat_data.read_file('u')

# Show sample 0 at time index 4 as a quick sanity check of the loaded field.
fig, ax = plt.subplots()
ax.imshow(u_data[0, :, :, 4].cpu().numpy())
ax.set_title(f'u, shape = {u_data.shape}')
plt.show()

## Model

In [7]:
# Define the model.
# Seed first so weight initialisation (and the later DataLoader shuffling, which
# draws from the same global torch RNG) is reproducible on Restart & Run All.
torch.manual_seed(0)

# The preprocessing below feeds the first 10 time steps as input and uses the
# remaining steps as targets, so the model must emit one channel per target
# step. The wrapper's default of a single output channel would not match.
# NOTE(review): in_channels keeps the wrapper default of 12 (10 steps + 2 grid
# coordinates), and the targets are laid out [batch, x, y, t]. Confirm that
# FNO appends the grid itself and returns channels-last output.
model = FNO2DTime(out_channels=u_data.shape[-1] - 10).to(device)

## Data Preprocessing

In [9]:
# Create data loaders.
# Samples 0-999 are used for training; the rest are held out for evaluation.
n_train = 1000
# Each trajectory is split in time: the first t_in steps are the model input
# and the remaining steps are the prediction target.
t_in = 10
batch_size = 50

train_data = u_data[:n_train]
eval_data = u_data[n_train:]

# Inputs (x) and targets (y). The FNO literature usually calls the input `a`
# and the solution `u`, so neutral x/y names avoid inverting that convention.
x_train, y_train = train_data[..., :t_in], train_data[..., t_in:]
x_eval, y_eval = eval_data[..., :t_in], eval_data[..., t_in:]

train_loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(x_train, y_train),
    batch_size=batch_size, shuffle=True)

# Evaluation does not need shuffling. A fixed order keeps per-batch results
# reproducible and avoids drawing from the global RNG that drives the
# training shuffle.
eval_loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(x_eval, y_eval),
    batch_size=batch_size, shuffle=False)

## Training

In [11]:
# Train the model. train_model returns two sequences (loss and MSE) that are
# plotted in the Results section.
n_epochs = 100
loss, mse = train_model(
    model,
    train_loader,
    eval_loader,
    epochs=n_epochs,
    device=device,
)

## Results

In [13]:
# Side-by-side curves of the two sequences returned by train_model.
fig, (ax_loss, ax_mse) = plt.subplots(1, 2, figsize=(10, 5))
for ax, values, title in ((ax_loss, loss, 'Loss'), (ax_mse, mse, 'MSE')):
    ax.plot(values)
    ax.set_title(title)
plt.show()

### Making predictions

In [15]:
# First held-out trajectory, sliced (not indexed) so the batch axis is kept.
data_test = u_data[1000:1001]
data_test.shape

In [16]:
# Make predictions.
# Switch to inference mode explicitly. train_model may leave the model in
# training mode, which matters for any dropout/normalisation layers in FNO.
model.eval()
with torch.no_grad():
    # Feed the first 10 time steps, matching the inputs used during training.
    pred = model(data_test[..., :10].to(device)).cpu()

In [17]:
# Compare the prediction with the ground truth at the last time step
# (assumes the prediction's last channel corresponds to the final time step).
pred_last = pred[0, :, :, -1].numpy()
true_last = data_test[0, :, :, -1].cpu().numpy()

# Share one colour scale across both panels. With independent autoscaling, the
# two images can look alike (or different) regardless of the actual values.
vmin = min(pred_last.min(), true_last.min())
vmax = max(pred_last.max(), true_last.max())

fig, axes = plt.subplots(1, 2, figsize=(10, 5))
for ax, field, title in zip(axes, (pred_last, true_last), ('Prediction', 'Ground truth')):
    im = ax.imshow(field, vmin=vmin, vmax=vmax)
    ax.set_title(title)
fig.colorbar(im, ax=axes, shrink=0.8)
plt.show()