# Fourier Neural Operators

## Preparation

In [1]:
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
from neuralop import LpLoss, H1Loss
from neuralop import Trainer
from neuralop.datasets import load_spherical_swe
from neuralop.models import FNO
from neuralop.utils import count_model_params

# Seed NumPy and PyTorch so repeated runs start from the same random state.
np.random.seed(0)
torch.manual_seed(0)

# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

# %%
# Load the spherical SWE data loaders: training at 32x64, testing at 32x64 and 64x128.
train_loader, test_loaders = load_spherical_swe(
    n_train=200,
    batch_size=4,
    train_resolution=(32, 64),
    test_resolutions=[(32, 64), (64, 128)],
    n_tests=[50, 50],
    test_batch_sizes=[10, 10],
)

# Build the FNO with dense (unfactorized) spectral weights, then move it to the chosen device.
model = FNO(
    n_modes=(32, 64),
    in_channels=3,
    out_channels=3,
    hidden_channels=32,
    projection_channels=64,
    factorization='dense',
)
model = model.to(device)

Loading train dataloader at resolution (32, 64) with 200 samples and batch-size=4
Loading test dataloader at resolution (32, 64) with 50 samples and batch-size=10
Loading test dataloader at resolution (64, 128) with 50 samples and batch-size=10


## Train

In [2]:
# Report the model size; flush so the line shows up immediately in the cell output.
n_params = count_model_params(model)
print(f'\nOur model has {n_params} parameters.', flush=True)


# %%
# Optimizer: Adam over all model parameters, with weight decay disabled.
optimizer = torch.optim.Adam(model.parameters(), lr=8e-4, weight_decay=0.0)

# Learning-rate schedule: cosine annealing, stepped once per epoch by the training loop.
# NOTE(review): T_max=30 is larger than the 20 epochs the training loop runs, so the
# schedule never reaches its minimum learning rate - confirm this is intended.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=30)


# %%
# Loss setup: a single LpLoss (p=2) instance is shared between training and evaluation.
# The H1 loss is kept below for reference but is currently disabled.
l2loss = LpLoss(d=2, p=2, reduce_dims=(0, 1))
# h1loss = H1Loss(d=2, reduce_dims=(0, 1))

train_loss = l2loss
eval_losses = dict(l2=l2loss)  # add h1=h1loss here to also evaluate with the H1 loss


# %%


# Dump the full training configuration (model, optimizer, scheduler, losses) to the cell output.
for banner, component in (
    ('\n### MODEL ###\n', model),
    ('\n### OPTIMIZER ###\n', optimizer),
    ('\n### SCHEDULER ###\n', scheduler),
):
    print(banner, component)

print('\n### LOSSES ###')
print(f'\n * Train: {train_loss}')
# Flush on the last line so everything above is visible before training starts.
print(f'\n * Test: {eval_losses}', flush=True)

# Global step counter across all epochs; used only to label lines in the loss log.
step = 0

# open(..., 'w') does not create missing parent directories, so make sure the
# log directory exists before training starts.
loss_log_path = Path('script/sfno_loss.txt')
loss_log_path.parent.mkdir(parents=True, exist_ok=True)

# Switch to training mode explicitly: a later cell calls model.eval(), so re-running
# this cell would otherwise train a model that is still in evaluation mode.
model.train()

with open(loss_log_path, 'w') as f:
    for epoch in range(20):
        # Sum of per-batch losses for this epoch (only consumed by ReduceLROnPlateau).
        train_err = 0.0

        for sample in train_loader:
            optimizer.zero_grad(set_to_none=True)

            # Move every tensor entry of the batch to the training device;
            # non-tensor entries are dropped.
            sample = {
                k: v.to(device)
                for k, v in sample.items()
                if torch.is_tensor(v)
            }

            out = model(sample["x"])

            # The whole batch dict is forwarded as keyword arguments, so the loss
            # picks up the target from sample["y"].
            loss = l2loss(out, **sample)

            loss.backward()
            del out

            optimizer.step()

            # Read the scalar once and reuse it for the running total and the log file.
            loss_value = loss.item()
            train_err += loss_value

            print("=======Loss:",loss.detach().cpu().numpy(),"======")
            f.write(f'Step {step + 1}, Loss: {loss_value}\n')
            step += 1

        # ReduceLROnPlateau needs the monitored metric; other schedulers step unconditionally.
        if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            scheduler.step(train_err)
        else:
            scheduler.step()


Our model has 8666531 parameters.

### MODEL ###
 FNO(
  (fno_blocks): FNOBlocks(
    (convs): SpectralConv(
      (weight): ModuleList(
        (0-3): 4 x ComplexDenseTensor(shape=torch.Size([32, 32, 32, 33]), rank=None)
      )
    )
    (fno_skips): ModuleList(
      (0-3): 4 x Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
    )
  )
  (lifting): MLP(
    (fcs): ModuleList(
      (0): Conv2d(3, 256, kernel_size=(1, 1), stride=(1, 1))
      (1): Conv2d(256, 32, kernel_size=(1, 1), stride=(1, 1))
    )
  )
  (projection): MLP(
    (fcs): ModuleList(
      (0): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1))
      (1): Conv2d(64, 3, kernel_size=(1, 1), stride=(1, 1))
    )
  )
)

### OPTIMIZER ###
 Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    initial_lr: 0.0008
    lr: 0.0008
    maximize: False
    weight_decay: 0.0
)

### SCHEDULER ###
 <torch.

## Test

In [None]:
# Evaluate the trained model on 50 test samples at the training resolution (32, 64).
resolution = (32, 64)
# Only the test loaders are needed here; the train loader returned first is discarded.
_, test_loaders = load_spherical_swe(n_train=200, batch_size=4, train_resolution=(32, 64),
                                                test_resolutions=[resolution], n_tests=[50], test_batch_sizes=[1])

test_losses = []

# Evaluate in eval mode and without autograd bookkeeping.
model.eval()
with torch.no_grad():
    for sample in test_loaders[resolution]:
        inputs = sample['x'].to(device)
        targets = sample['y'].to(device)
        outputs = model(inputs)
        loss = l2loss(outputs, targets)
        test_losses.append(loss.item())

# Build the tensor of per-sample losses once and derive both statistics from it.
test_losses_tensor = torch.tensor(test_losses)
mean_loss = torch.mean(test_losses_tensor)
std_loss = torch.std(test_losses_tensor)

print(f'Test Loss: {mean_loss:.4f} ± {std_loss:.4f}')

In [3]:
# Evaluate the trained model on 50 test samples at (64, 128), a resolution above
# the (32, 64) training resolution.
resolution = (64, 128)
# Only the test loaders are needed here; the train loader returned first is discarded.
_, test_loaders = load_spherical_swe(n_train=200, batch_size=4, train_resolution=(32, 64),
                                                test_resolutions=[resolution], n_tests=[50], test_batch_sizes=[1])

test_losses = []

# Evaluate in eval mode and without autograd bookkeeping.
model.eval()
with torch.no_grad():
    for sample in test_loaders[resolution]:
        inputs = sample['x'].to(device)
        targets = sample['y'].to(device)
        outputs = model(inputs)
        loss = l2loss(outputs, targets)
        test_losses.append(loss.item())

# Build the tensor of per-sample losses once and derive both statistics from it.
test_losses_tensor = torch.tensor(test_losses)
mean_loss = torch.mean(test_losses_tensor)
std_loss = torch.std(test_losses_tensor)

print(f'Test Loss: {mean_loss:.4f} ± {std_loss:.4f}')

Loading train dataloader at resolution (32, 64) with 200 samples and batch-size=4
Loading test dataloader at resolution (64, 128) with 50 samples and batch-size=1
Test Loss: 1.0431 ± 0.0477


In [4]:
# Evaluate the trained model on 50 test samples at (128, 256), a resolution above
# the (32, 64) training resolution.
resolution = (128, 256)
# Only the test loaders are needed here; the train loader returned first is discarded.
_, test_loaders = load_spherical_swe(n_train=200, batch_size=4, train_resolution=(32, 64),
                                                test_resolutions=[resolution], n_tests=[50], test_batch_sizes=[1])

test_losses = []

# Evaluate in eval mode and without autograd bookkeeping.
model.eval()
with torch.no_grad():
    for sample in test_loaders[resolution]:
        inputs = sample['x'].to(device)
        targets = sample['y'].to(device)
        outputs = model(inputs)
        loss = l2loss(outputs, targets)
        test_losses.append(loss.item())

# Build the tensor of per-sample losses once and derive both statistics from it.
test_losses_tensor = torch.tensor(test_losses)
mean_loss = torch.mean(test_losses_tensor)
std_loss = torch.std(test_losses_tensor)

print(f'Test Loss: {mean_loss:.4f} ± {std_loss:.4f}')

Loading train dataloader at resolution (32, 64) with 200 samples and batch-size=4
Loading test dataloader at resolution (128, 256) with 50 samples and batch-size=1
Test Loss: 2.3513 ± 0.0391


In [5]:
import time

# Time a single forward pass.
# NOTE: `inputs` is the last batch left in the kernel by the preceding evaluation cell,
# so this cell only works after one of those cells has been run.
model.eval()
with torch.no_grad():
    # CUDA kernels are launched asynchronously: without synchronising before and after
    # the call, the timer would mostly measure launch overhead rather than compute time.
    if inputs.is_cuda:
        torch.cuda.synchronize(inputs.device)
    start_time = time.perf_counter()
    outputs = model(inputs)
    if inputs.is_cuda:
        torch.cuda.synchronize(inputs.device)
print("test time:", time.perf_counter()-start_time)

test time: 0.00400090217590332


## Visualization

In [10]:
# Save the first output channel of the model prediction as an image, once per resolution.
for resolution in [(32, 64), (64, 128), (128, 256)]:
    # Input sample stored on disk for this resolution.
    x = torch.tensor(np.load("../../test_dataset/input_"+str(resolution[0])+"_resolution.npy"))

    # Model prediction: add a batch dimension, run inference without autograd,
    # then keep channel 0 of the result as a NumPy array.
    x_in = x.unsqueeze(0).to(device)
    with torch.no_grad():
        out = model(x_in).squeeze()[0, ...].cpu().numpy()

    # Create the figure inside the loop: plt.close() discards the current figure, so a
    # single figure created before the loop would only apply its size to the first image.
    fig = plt.figure(figsize=(4, 2))
    plt.imshow(out)
    plt.axis('off')
    plt.savefig("./script/output_" + str(resolution[0]) + "_resolution.png", bbox_inches='tight', pad_inches=0)
    plt.close(fig)

In [None]:
# for index, resolution in enumerate([(32, 64), (64, 128), (128, 256)]):
#     # Input x
#     x = torch.tensor(np.load("../../test_dataset/input_"+str(resolution[0])+"_resolution.npy"))
#     # Ground-truth
#     y = np.load("../../test_dataset/label_"+str(resolution[0])+"_resolution.npy")
#     x = x[0, ...].detach().numpy()
    
#     plt.imshow(x)
#     plt.axis('off')
#     plt.savefig("./script/input_" + str(resolution[0]) + "_resolution.png", bbox_inches='tight', pad_inches=0)
#     plt.close()
    
#     plt.imshow(y)
#     plt.axis('off')
#     plt.savefig("./script/label_" + str(resolution[0]) + "_resolution.png", bbox_inches='tight', pad_inches=0)
#     plt.close()
    

In [3]:
# Find the test samples with the highest and lowest L2 loss at (128, 256) and keep the
# first channel of their target and prediction for plotting.
resolution = (128, 256)
_, test_loaders = load_spherical_swe(n_train=200, batch_size=4, train_resolution=resolution,
                                                test_resolutions=[resolution], n_tests=[50], test_batch_sizes=[1])

# Start from +/- infinity so the first sample always initialises both extremes; fixed
# sentinels (1e5 / 0) could leave target_low / target_high undefined for unusual losses.
loss_low = float('inf')
loss_high = float('-inf')

model.eval()
with torch.no_grad():
    for sample in test_loaders[resolution]:
        inputs = sample['x'].to(device)
        targets = sample['y'].to(device)
        outputs = model(inputs)
        # Compare plain Python floats rather than device tensors.
        loss = l2loss(outputs, targets).item()
        if loss > loss_high:
            loss_high = loss
            target_high = targets.squeeze()[0, ...].cpu().numpy()
            output_high = outputs.squeeze()[0, ...].cpu().numpy()
        if loss < loss_low:
            loss_low = loss
            target_low = targets.squeeze()[0, ...].cpu().numpy()
            output_low = outputs.squeeze()[0, ...].cpu().numpy()
            


Loading train dataloader at resolution (128, 256) with 200 samples and batch-size=4
Loading test dataloader at resolution (128, 256) with 50 samples and batch-size=1


In [4]:
# Save the target/prediction images of the highest- and lowest-loss test samples.
# Each image gets its own figure: plt.close() discards the current figure, so a single
# figure created up front would only apply figsize=(4, 2) to the first image.
for name, image in (
    ("target_high", target_high),
    ("output_high", output_high),
    ("target_low", target_low),
    ("output_low", output_low),
):
    fig = plt.figure(figsize=(4, 2))
    plt.imshow(image)
    plt.axis('off')
    plt.savefig(f"./script/{name}.png", bbox_inches='tight', pad_inches=0)
    plt.close(fig)