In [1]:
import torch as tc
import torch.nn as nn
import numpy as np

In [2]:
class FNOLayer(nn.Module):
    def __init__(self, features: int, k_max: int):
        super().__init__()
        self.features = features
        self.k_max = k_max
        #Fourier space filters
        self.filter = nn.Parameter(tc.randn(features, features, k_max, dtype=tc.cfloat))

        #Physical space transformation
        self.linear_trans = nn.Linear(features, features)

    def forward(self, x: tc.Tensor) -> tc.Tensor:
        # x: (batch, grid_size, features), represents the function which solves desired diff eq
        batch, grid_size, features = x.shape
        x_tilde = tc.fft.rfft(x, dim = 1)[:,:self.k_max,:]
        # x_tilde: (batch, k_max, features)
        x_filtered = tc.einsum('bkf,fgk->bkg', x_tilde, self.filter)
        x_tilde_full = tc.zeros(batch, grid_size//2 + 1, features, dtype=tc.cfloat, device=x.device)
        x_tilde_full[:, :self.k_max, :] = x_filtered
        freq_branch = tc.fft.irfft(
            x_tilde_full,
            n=grid_size,
            dim = 1).real
        space_branch = self.linear_trans(x)
        output = space_branch + freq_branch
        return nn.GELU()(output)

In [21]:
class FNO(nn.Module):
    def __init__(self, features: int, k_max: int, num_layers: int, grid_size: int):
        super().__init__()

        # This can be included if we don't want the grid to be created dynamically in the forward call
        # self.register_buffer('grid', tc.linspace(0, 1, grid_size).reshape(1, grid_size, 1))

        self.input_proj = nn.Linear(2, features)
        self.GNO_layers = nn.ModuleList([FNOLayer(features, k_max) for _ in range(num_layers)])
        self.output_proj = nn.Linear(features, 1)

    def forward(self, x: tc.Tensor) -> tc.Tensor:
        batch, grid_size, _ = x.shape
        grid = tc.linspace(0, 1, grid_size, device=x.device).reshape(1, grid_size, 1).expand(batch, -1, -1)
        x = tc.cat([x, grid], dim=-1)
        x = self.input_proj(x)
        for layer in self.GNO_layers:
            x = layer(x)
        x = self.output_proj(x)
        return x.squeeze(-1)

In [4]:
def random_initial_condition(x, n_modes=5):
    u0 = np.zeros_like(x)
    amps = []
    for k in range(1, n_modes + 1):
        amplitude = np.random.uniform(-1, 1)
        u0 += amplitude * np.sin(k * np.pi * x)
        amps.append(amplitude)
    return u0, amps

In [5]:
from scipy.integrate import solve_ivp

def solve_heat_equation(u0, x, t_end=0.1, alpha=1.0):
    N = len(x)
    dx = x[1] - x[0]

    def heat_rhs(t, u):
        dudt = np.zeros_like(u)
        # finite difference for second derivative
        dudt[1:-1] = alpha * (u[2:] - 2*u[1:-1] + u[:-2]) / dx**2
        # boundary conditions: u[0] = u[-1] = 0 (already zero)
        return dudt

    sol = solve_ivp(heat_rhs, [0, t_end], u0, method='RK45',
                    t_eval=[t_end], rtol=1e-6, atol=1e-8)
    return sol.y[:, -1]  # solution at t_end

In [6]:
N = 64          # spatial grid points during training
x = np.linspace(0, 1, N)
n_samples = 10000

inputs = []
targets = []

for _ in range(n_samples):
    u0, _ = random_initial_condition(x)
    uT = solve_heat_equation(u0, x)
    inputs.append(u0)
    targets.append(uT)

inputs = np.array(inputs)   # (10000, 64)
targets = np.array(targets) # (10000, 64)
np.save('inputs.npy', inputs)
np.save('targets.npy', targets)

In [9]:
from torch.utils.data import TensorDataset, DataLoader
# Don't need to load if calculating the data for the first time
# inputs = np.load('inputs.npy')
# targets = np.load('targets.npy')
X = tc.tensor(inputs, dtype=tc.float32).unsqueeze(-1)  # (10000, 64, 1)
Y = tc.tensor(targets, dtype=tc.float32)               # (10000, 64)

dataset = TensorDataset(X, Y)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_set, val_set = tc.utils.data.random_split(dataset, [train_size, val_size])

train_loader = DataLoader(train_set, batch_size=64, shuffle=True)
val_loader = DataLoader(val_set, batch_size=64)

In [22]:
model = FNO(features=64, k_max=16, num_layers=4, grid_size=64)
optimizer = tc.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()
scheduler = tc.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5, factor=0.5)

for epoch in range(350):
    model.train()
    train_loss = 0
    for x_batch, y_batch in train_loader:
        optimizer.zero_grad()
        pred = model(x_batch)
        loss = loss_fn(pred, y_batch)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    model.eval()
    val_loss = 0
    with tc.no_grad():
        for x_batch, y_batch in val_loader:
            pred = model(x_batch)
            val_loss += loss_fn(pred, y_batch).item()

    print(f"Epoch {epoch+1} | Train Loss: {train_loss/len(train_loader):.6f} | Val Loss: {val_loss/len(val_loader):.6f}")
    scheduler.step(val_loss)

Epoch 1 | Train Loss: 1624.555689 | Val Loss: 301.651521
Epoch 2 | Train Loss: 212.039833 | Val Loss: 152.829993
Epoch 3 | Train Loss: 119.874609 | Val Loss: 97.364727
Epoch 4 | Train Loss: 80.128402 | Val Loss: 69.052670
Epoch 5 | Train Loss: 57.990014 | Val Loss: 51.747051
Epoch 6 | Train Loss: 44.259965 | Val Loss: 40.640850
Epoch 7 | Train Loss: 34.975215 | Val Loss: 32.749330
Epoch 8 | Train Loss: 28.379397 | Val Loss: 26.969306
Epoch 9 | Train Loss: 23.486779 | Val Loss: 22.683391
Epoch 10 | Train Loss: 19.720148 | Val Loss: 19.163158
Epoch 11 | Train Loss: 16.763180 | Val Loss: 16.499756
Epoch 12 | Train Loss: 14.387924 | Val Loss: 14.247454
Epoch 13 | Train Loss: 12.443971 | Val Loss: 12.421850
Epoch 14 | Train Loss: 10.851373 | Val Loss: 10.907535
Epoch 15 | Train Loss: 9.519711 | Val Loss: 9.628162
Epoch 16 | Train Loss: 8.384711 | Val Loss: 8.489716
Epoch 17 | Train Loss: 7.412732 | Val Loss: 7.561955
Epoch 18 | Train Loss: 6.583508 | Val Loss: 6.739882
Epoch 19 | Train Loss

In [13]:
def analytical_solution(u0_amplitudes, x, t, alpha=1.0):
    u = np.zeros_like(x)
    for k, A in enumerate(u0_amplitudes, start=1):
        u += A * np.sin(k * np.pi * x) * np.exp(-alpha * (k * np.pi)**2 * t)
    return u

In [14]:
import time

n_test = 100

test_inputs = []
test_amplitudes = []
for _ in range(n_test):
    u0, a = random_initial_condition(x)
    test_inputs.append(u0)
    test_amplitudes.append(a)

test_tensor = tc.tensor(np.array(test_inputs), dtype=tc.float32).unsqueeze(-1)

In [38]:
# scipy
start = time.time()
for u0 in test_inputs:
    solve_heat_equation(u0, x)
scipy_time = time.time() - start

# FNO
start = time.time()
with tc.no_grad():
    model(test_tensor)
FNO_time = time.time() - start

print(
    f"Scipy: {scipy_time:.3f}s | FNO: {FNO_time:.3f}s | Speedup: {scipy_time / FNO_time:.1f}x")


Scipy: 1.999s | FNO: 0.010s | Speedup: 208.0x


In [39]:
FNO_solution = model(test_tensor).detach().numpy()

In [25]:
ana_solution = np.zeros_like(GNO_solution)
for i in range(100):
    ana_solution[i,:] = analytical_solution(test_amplitudes[i], x, t = 0.1)

In [40]:
relative_l2 = 0
for i in range(100):
    relative_l2 += np.sqrt(np.sum((FNO_solution[i, :] - ana_solution[i, :]) ** 2)) / np.sqrt(
        np.sum(ana_solution ** 2)) / 100
print(relative_l2*100)

0.11325549


In [30]:
x_test_128 = np.linspace(0, 1, 128)

test_inputs_128 = []
test_amplitudes_128 = []
for _ in range(n_test):
    u0, a = random_initial_condition(x_test_128)
    test_inputs_128.append(u0)
    test_amplitudes_128.append(a)

test_tensor_128 = tc.tensor(np.array(test_inputs_128), dtype=tc.float32).unsqueeze(-1)

In [41]:
FNO_solution_128 = model(test_tensor_128).detach().numpy()

In [35]:
ana_solution_128 = np.zeros_like(GNO_solution_128)
for i in range(100):
    ana_solution_128[i,:] = analytical_solution(test_amplitudes_128[i], x_test_128, t = 0.1)

In [42]:
relative_l2_128 = 0
for i in range(100):
    relative_l2_128 += np.sqrt(np.sum((FNO_solution_128[i, :] - ana_solution_128[i, :]) ** 2)) / np.sqrt(
        np.sum(ana_solution_128 ** 2)) / 100
print(relative_l2_128*100)

0.1806133
