In [None]:
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
from neuralop.models import FNO
from torch.utils.data import TensorDataset, DataLoader
from scipy.special import legendre, chebyt, jv, hermite, eval_laguerre
import matplotlib.pyplot as plt

import warnings
warnings.filterwarnings("ignore", category=UserWarning)

In [9]:
function_types = [
    "fourier",          # sum of sinusoids
    "poly",             # standard polynomial
    "gaussian",         # single Gaussian
    "gaussian_mixture", # sum of Gaussians
    "damped_sine",      # exponentially damped sinusoid
    "exp_decay",        # exponential decay
    "piecewise",        # piecewise linear/quadratic
    "trig_combo",       # combination of sin and cos
    "legendre",         # Legendre polynomials
    "chebyshev",        # Chebyshev polynomials (1st kind)
    "bessel",           # Bessel function of first kind
    "hermite",          # Hermite polynomials
    "laguerre",         # Laguerre polynomials
    "windowed_sine",    # sinusoid multiplied by Gaussian
    "rect_pulse",       # rectangular pulse
    "sawtooth",         # sawtooth wave
    "triangle",         # triangle wave
    "modulated",        # product of sinusoids (beats)
    "chirp",            # frequency-increasing sinusoid
    "spikes",           # sparse impulses
    "wavelet"           # Mexican hat wavelet
]

In [10]:
def make_function(x, kind="random", max_freq=10):
    """
    Generate a diverse set of functions for Fourier testing.

    Parameters:
        x        : array of input points
        kind     : type of function to generate; if 'random', one is picked randomly
        max_freq : maximum frequency for Fourier-type functions

    Returns:
        f        : array of function values
    """
    
    if kind == "random":
        kind = np.random.choice(function_types)
    
    # ----------------- Standard types -----------------
    if kind == "fourier":
        coeffs = np.random.randn(max_freq)
        f = np.zeros_like(x, dtype=float)
        for n, a in enumerate(coeffs, start=1):
            f += a * np.sin(np.pi * n * x)
        return f

    elif kind == "poly":
        coeffs = np.random.randn(5)
        return sum(c * x**i for i, c in enumerate(coeffs))

    elif kind == "gaussian":
        mu, sigma = np.random.uniform(-0.5, 0.5), np.random.uniform(0.05, 0.5)
        return np.exp(-((x - mu) ** 2) / (2 * sigma ** 2))

    elif kind == "gaussian_mixture":
        num_gaussians = np.random.randint(2, 4)
        f = np.zeros_like(x)
        for _ in range(num_gaussians):
            mu, sigma, amp = np.random.uniform(-0.5,0.5), np.random.uniform(0.05,0.3), np.random.uniform(0.5,2.0)
            f += amp * np.exp(-((x - mu)**2) / (2 * sigma**2))
        return f

    elif kind == "damped_sine":
        freq = np.random.uniform(1, max_freq)
        decay = np.random.uniform(0.5, 2.0)
        phase = np.random.uniform(0, 2*np.pi)
        return np.exp(-decay * np.abs(x)) * np.sin(2 * np.pi * freq * x + phase)

    elif kind == "exp_decay":
        lam = np.random.uniform(0.5, 2.0)
        return np.exp(-lam * np.abs(x))

    elif kind == "piecewise":
        split = np.random.uniform(x[0], x[-1])
        return np.piecewise(x, [x < split, x >= split],
                            [lambda t: t**2, lambda t: -t + split])

    elif kind == "trig_combo":
        f = np.zeros_like(x)
        num_terms = np.random.randint(2, 5)
        for _ in range(num_terms):
            amp = np.random.uniform(0.5, 2.0)
            freq = np.random.randint(1, max_freq)
            phase = np.random.uniform(0, 2*np.pi)
            f += amp * (np.sin(2*np.pi*freq*x + phase) + np.cos(2*np.pi*freq*x + phase))
        return f

    # ----------------- Special polynomials -----------------
    elif kind == "legendre":
        deg = np.random.randint(1, 6)
        P = legendre(deg)
        return P(x)

    elif kind == "chebyshev":
        deg = np.random.randint(1, 6)
        T = chebyt(deg)
        return T(x)

    elif kind == "bessel":
        order = np.random.randint(0, 6)
        k = np.random.uniform(1, 10)
        return jv(order, k * x)

    elif kind == "hermite":
        deg = np.random.randint(1,5)
        H = hermite(deg)
        return H(x)

    elif kind == "laguerre":
        deg = np.random.randint(1,5)
        return eval_laguerre(deg, np.abs(x))  # Laguerre defined on [0,∞)

    # ----------------- Windowed / localized functions -----------------
    elif kind == "windowed_sine":
        freq = np.random.uniform(1, max_freq)
        alpha = np.random.uniform(1,5)
        return np.sin(2*np.pi*freq*x) * np.exp(-alpha*x**2)

    elif kind == "rect_pulse":
        start, end = np.random.uniform(-0.5, 0), np.random.uniform(0,0.5)
        return np.where((x>=start) & (x<=end), 1.0, 0.0)

    elif kind == "sawtooth":
        return 2*(x - np.floor(x + 0.5))  # normalized sawtooth

    elif kind == "triangle":
        return 2*np.abs(2*(x - np.floor(x + 0.5))) - 1

    elif kind == "modulated":
        f1 = np.sin(5*np.pi*x)
        f2 = np.cos(2*np.pi*x)
        return f1*f2

    elif kind == "chirp":
        return np.sin(2*np.pi*(x + x**2))

    elif kind == "spikes":
        f = np.zeros_like(x)
        num_spikes = np.random.randint(3,8)
        indices = np.random.choice(len(x), num_spikes, replace=False)
        f[indices] = np.random.uniform(1,3, size=num_spikes)
        return f

    elif kind == "wavelet":
        return (1 - x**2) * np.exp(-x**2 / 2)  # Mexican hat

    else:
        raise ValueError(f"Unknown function type '{kind}'")

In [11]:
DOMAIN = (-np.pi, np.pi)
RESOLUTION = 256
TRAIN_SAMPLES = 2000
TEST_SAMPLES = 200
BATCH_SIZE = 32
EPOCHS = 50

In [12]:
def generate_fourier_dataset(num_samples=TRAIN_SAMPLES, resolution=RESOLUTION):
    """Generate dataset of functions and their Fourier transforms"""
    x = torch.linspace(DOMAIN[0], DOMAIN[1], resolution)
    inputs = []
    outputs = []
    
    for _ in range(num_samples):
        # Create random function as sum of sinusoids
        f = torch.Tensor(make_function(x.numpy(), kind="random", max_freq=10))
        
        # Compute Fourier transform
        ft = torch.fft.fft(f)
        # Stack real and imaginary parts
        ft_real_imag = torch.stack([ft.real, ft.imag], dim=0)
        
        inputs.append(f.unsqueeze(0))  # Add channel dimension: (1, resolution)
        outputs.append(ft_real_imag)   # (2, resolution) - real and imag parts
    
    # Stack all samples: (num_samples, channels, resolution)
    X = torch.stack(inputs)     # (num_samples, 1, resolution)  
    Y = torch.stack(outputs)    # (num_samples, 2, resolution)
    
    return X, Y

X_train, Y_train = generate_fourier_dataset(num_samples=TRAIN_SAMPLES, resolution=RESOLUTION)
X_val, Y_val = generate_fourier_dataset(num_samples=TEST_SAMPLES, resolution=RESOLUTION)
X_test, Y_test = generate_fourier_dataset(num_samples=TEST_SAMPLES, resolution=RESOLUTION)

train_dataset = TensorDataset(X_train, Y_train)
val_dataset = TensorDataset(X_val, Y_val)
test_dataset = TensorDataset(X_test, Y_test)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
model = FNO(
    n_modes=(32,),           # Number of Fourier modes for 1D
    hidden_channels=64,      # Width of the network
    in_channels=1,           # Input: single function
    out_channels=2,          # Output: real and imaginary parts
    n_layers=4,              # Number of FNO layers
    factorization=None       # No tensorization
)

FNO(
  (positional_embedding): GridEmbeddingND()
  (fno_blocks): FNOBlocks(
    (convs): ModuleList(
      (0-3): 4 x SpectralConv(
        (weight): DenseTensor(shape=torch.Size([64, 64, 17]), rank=None)
      )
    )
    (fno_skips): ModuleList(
      (0-3): 4 x Flattened1dConv(
        (conv): Conv1d(64, 64, kernel_size=(1,), stride=(1,), bias=False)
      )
    )
    (channel_mlp): ModuleList(
      (0-3): 4 x ChannelMLP(
        (fcs): ModuleList(
          (0): Conv1d(64, 32, kernel_size=(1,), stride=(1,))
          (1): Conv1d(32, 64, kernel_size=(1,), stride=(1,))
        )
      )
    )
    (channel_mlp_skips): ModuleList(
      (0-3): 4 x SoftGating()
    )
  )
  (lifting): ChannelMLP(
    (fcs): ModuleList(
      (0): Conv1d(2, 128, kernel_size=(1,), stride=(1,))
      (1): Conv1d(128, 64, kernel_size=(1,), stride=(1,))
    )
  )
  (projection): ChannelMLP(
    (fcs): ModuleList(
      (0): Conv1d(64, 128, kernel_size=(1,), stride=(1,))
      (1): Conv1d(128, 2, kernel_size=

In [23]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)
criterion = nn.MSELoss()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

FNO(
  (positional_embedding): GridEmbeddingND()
  (fno_blocks): FNOBlocks(
    (convs): ModuleList(
      (0-3): 4 x SpectralConv(
        (weight): DenseTensor(shape=torch.Size([64, 64, 17]), rank=None)
      )
    )
    (fno_skips): ModuleList(
      (0-3): 4 x Flattened1dConv(
        (conv): Conv1d(64, 64, kernel_size=(1,), stride=(1,), bias=False)
      )
    )
    (channel_mlp): ModuleList(
      (0-3): 4 x ChannelMLP(
        (fcs): ModuleList(
          (0): Conv1d(64, 32, kernel_size=(1,), stride=(1,))
          (1): Conv1d(32, 64, kernel_size=(1,), stride=(1,))
        )
      )
    )
    (channel_mlp_skips): ModuleList(
      (0-3): 4 x SoftGating()
    )
  )
  (lifting): ChannelMLP(
    (fcs): ModuleList(
      (0): Conv1d(2, 128, kernel_size=(1,), stride=(1,))
      (1): Conv1d(128, 64, kernel_size=(1,), stride=(1,))
    )
  )
  (projection): ChannelMLP(
    (fcs): ModuleList(
      (0): Conv1d(64, 128, kernel_size=(1,), stride=(1,))
      (1): Conv1d(128, 2, kernel_size=

In [24]:
for epoch in range(EPOCHS):
    model.train()
    train_loss = 0.0
    
    # Training phase
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        
        train_loss += loss.item()
    
    # Validation phase
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for data, target in val_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            val_loss += criterion(output, target).item()

    # Update scheduler
    scheduler.step()
    
    # Print progress
    if (epoch + 1) % 10 == 0:
        avg_train_loss = train_loss / len(train_loader)
        avg_val_loss = val_loss / len(val_loader)
        print(f'Epoch [{epoch+1}/{EPOCHS}], Train Loss: {avg_train_loss:.6f}, Val Loss: {avg_val_loss:.6f}')

metrics = {"Training Loss": avg_train_loss, "Validation Loss": avg_val_loss}

Epoch [10/50], Train Loss: 15438.202699, Val Loss: 13090.276280
Epoch [20/50], Train Loss: 6193.406149, Val Loss: 9914.399152
Epoch [30/50], Train Loss: 4665.404650, Val Loss: 8718.177211
Epoch [40/50], Train Loss: 4220.761840, Val Loss: 8415.488713
Epoch [50/50], Train Loss: 4017.259726, Val Loss: 8359.490060


In [25]:
model.eval()
test_loss = 0.0
with torch.no_grad():
    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        output = model(data)
        test_loss += criterion(output, target).item()

avg_test_loss = test_loss / len(test_loader)
metrics["Test Loss"] = avg_test_loss
print(f"Test Loss: {avg_test_loss:.6f}")

  return self.__class__(self.tensor[indices])
  out_fft[slices_x] = self._contract(x[slices_x], weight, separable=self.separable)


Test Loss: 3404.498413


In [26]:
metrics_df = pd.DataFrame(metrics, index=[model._name])
metrics_df

Unnamed: 0,Training Loss,Validation Loss,Test Loss
FNO,4017.259726,8359.49006,3404.498413


In [None]:
def test_function(f):
    # Generate test data
    x_test = torch.linspace(DOMAIN[0], DOMAIN[1], RESOLUTION)
    f_test = f(x_test)

    # True Fourier transform
    ft_true = torch.fft.fft(f_test)

    # Predict using trained model
    model.eval()
    with torch.no_grad():
        # Prepare input: add batch and channel dimensions
        f_input = f_test.unsqueeze(0).unsqueeze(0).to(device)  
        ft_pred = model(f_input).cpu().squeeze(0)              

    # Plot results
    plt.figure(figsize=(12, 8))

    # Original function
    plt.subplot(2, 2, 1)
    plt.plot(x_test.numpy(), f_test.numpy())
    plt.title('Input Function')
    plt.xlabel('x')
    plt.ylabel('f(x)')

    # Real part comparison
    plt.subplot(2, 2, 3)
    plt.plot(ft_true.real.numpy(), label='True Real', linewidth=2)
    plt.plot(ft_pred[0].numpy(), '--', label='Predicted Real', linewidth=2)
    plt.title('Fourier Transform - Real Part')
    plt.legend()

    # Imaginary part comparison
    plt.subplot(2, 2, 4)
    plt.plot(ft_true.imag.numpy(), label='True Imag', linewidth=2)
    plt.plot(ft_pred[1].numpy(), '--', label='Predicted Imag', linewidth=2)
    plt.title('Fourier Transform - Imaginary Part')
    plt.legend()

    # Magnitude comparison
    plt.subplot(2, 2, 2)
    ft_true_mag = torch.abs(ft_true)
    ft_pred_complex = ft_pred[0] + 1j * ft_pred[1]
    ft_pred_mag = torch.abs(ft_pred_complex)
    plt.plot(ft_true_mag.numpy(), label='True Magnitude', linewidth=2)
    plt.plot(ft_pred_mag.numpy(), '--', label='Predicted Magnitude', linewidth=2)
    plt.title('Fourier Transform - Magnitude')
    plt.legend()

    plt.tight_layout()
    plt.show()

    # Print error metrics
    mse_real = torch.mean((ft_true.real - ft_pred[0])**2)
    mse_imag = torch.mean((ft_true.imag - ft_pred[1])**2)
    mse = torch.mean((ft_true_mag - ft_pred_mag)**2)
    print(f"MSE Real Part: {mse_real.item():.6f}")
    print(f"MSE Imaginary Part: {mse_imag.item():.6f}")
    print(f"MSE Magnitude: {mse.item():.6f}")

### To be fixed
-   Use a more focussed training data
-   Change layers/channels