Merge pull request neuraloperator#262 from dhpitt/preprocessor
Move to DataProcessor API
JeanKossaifi committed Nov 13, 2023
2 parents f172ba7 + 2ae3e8a commit e2daca6
Showing 20 changed files with 487 additions and 405 deletions.
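In short: output normalization and positional encoding move out of Trainer callbacks (`OutputEncoderCallback`, `MGPatchingCallback`) and into a single `DataProcessor` object that the dataset loaders now return and that is passed directly to the `Trainer`. A condensed before/after sketch, assembled from the example diffs below (`model` and `device` as defined in the examples):

# Before this PR: normalization was handled by a Trainer callback.
# trainer = Trainer(model=model, n_epochs=20, device=device,
#                   callbacks=[OutputEncoderCallback(output_encoder)], ...)

# After this PR: the loader returns a DataProcessor, passed straight to the Trainer.
train_loader, test_loaders, data_processor = load_darcy_flow_small(
    n_train=1000, batch_size=32,
    test_resolutions=[16, 32], n_tests=[100, 50],
    test_batch_sizes=[32, 32],
)
trainer = Trainer(model=model, n_epochs=20,
                  device=device,
                  data_processor=data_processor,
                  wandb_log=False)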
22 changes: 17 additions & 5 deletions doc/source/modules/api.rst
@@ -203,11 +203,8 @@ Trainer, we provide a Callback class and several examples of common domain-speci
   :template: class.rst

   Callback
-   SimpleWandBLoggerCallback
-   OutputEncoderCallback
-   MGPatchingCallback
-   ModelCheckpointCallback
-   MonitorMetricCheckpointCallback
+   BasicLoggerCallback
+   CheckpointCallback


Datasets
@@ -223,3 +220,18 @@ We ship a small dataset for testing:
   :template: function.rst

   load_darcy_flow_small
+
+Much like PyTorch's `Torchvision.Datasets` module, our Datasets module also includes
+utilities to transform data from its raw form into the form expected by models and
+loss functions:
+
+.. automodule:: neuralop.datasets.data_transforms
+    :no-members:
+    :no-inherited-members:
+
+.. autosummary::
+    :toctree: generated
+    :template: class.rst
+
+   DefaultDataProcessor
+   MGPatchingDataProcessor
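The two classes documented above wrap normalization and positional embedding behind one object. A minimal sketch of how a `DefaultDataProcessor` is assembled, mirroring the `load_darcy_pt` changes later in this diff; the `preprocess`/`postprocess` hook names and the per-batch flow are assumptions, not shown in this hunk:

from neuralop.datasets.data_transforms import DefaultDataProcessor
from neuralop.datasets.output_encoder import UnitGaussianNormalizer

# x_train, y_train: training tensors, assumed already loaded.
in_normalizer = UnitGaussianNormalizer(dim=[0])   # pixel-wise reduction, as in load_darcy_pt
in_normalizer.fit(x_train)
out_normalizer = UnitGaussianNormalizer(dim=[0])
out_normalizer.fit(y_train)

data_processor = DefaultDataProcessor(in_normalizer=in_normalizer,
                                      out_normalizer=out_normalizer,
                                      positional_encoding=None)

# Assumed per-batch flow inside the Trainer:
#   batch = data_processor.preprocess(batch)             # normalize inputs, append positional encoding
#   out = model(batch['x'])
#   out, batch = data_processor.postprocess(out, batch)  # un-normalize predictions for the loss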
8 changes: 4 additions & 4 deletions examples/checkpoint_FNO_darcy.py
@@ -13,7 +13,7 @@
import sys
from neuralop.models import TFNO
from neuralop import Trainer
-from neuralop.training import OutputEncoderCallback, CheckpointCallback
+from neuralop.training import CheckpointCallback
from neuralop.datasets import load_darcy_flow_small
from neuralop.utils import count_model_params
from neuralop import LpLoss, H1Loss
@@ -23,7 +23,7 @@

# %%
# Loading the Navier-Stokes dataset in 128x128 resolution
-train_loader, test_loaders, output_encoder = load_darcy_flow_small(
+train_loader, test_loaders, data_processor = load_darcy_flow_small(
        n_train=1000, batch_size=32,
        test_resolutions=[16, 32], n_tests=[100, 50],
        test_batch_sizes=[32, 32],
@@ -75,12 +75,12 @@
trainer = Trainer(model=model, n_epochs=20,
                  device=device,
                  callbacks=[
-                      OutputEncoderCallback(output_encoder),
                       CheckpointCallback(save_dir='./checkpoints',
                                          save_interval=10,
                                          save_optimizer=True,
                                          save_scheduler=True)
-                      ],
+                  ],
+                  data_processor=data_processor,
                  wandb_log=False,
                  log_test_interval=3,
                  use_distributed=False,
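Note how callbacks and the data processor now compose: checkpointing stays a `Callback`, while normalization moves into `data_processor`. The resulting Trainer call, condensed from the diff above:

trainer = Trainer(model=model, n_epochs=20,
                  device=device,
                  callbacks=[
                      CheckpointCallback(save_dir='./checkpoints',
                                         save_interval=10,
                                         save_optimizer=True,
                                         save_scheduler=True)
                  ],
                  data_processor=data_processor,
                  wandb_log=False)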
5 changes: 2 additions & 3 deletions examples/plot_FNO_darcy.py
@@ -15,7 +15,6 @@
import sys
from neuralop.models import TFNO
from neuralop import Trainer
-from neuralop.training import OutputEncoderCallback
from neuralop.datasets import load_darcy_flow_small
from neuralop.utils import count_model_params
from neuralop import LpLoss, H1Loss
@@ -25,7 +24,7 @@

# %%
# Loading the Navier-Stokes dataset in 128x128 resolution
-train_loader, test_loaders, output_encoder = load_darcy_flow_small(
+train_loader, test_loaders, data_processor = load_darcy_flow_small(
        n_train=1000, batch_size=32,
        test_resolutions=[16, 32], n_tests=[100, 50],
        test_batch_sizes=[32, 32],
@@ -76,7 +75,7 @@
# Create the trainer
trainer = Trainer(model=model, n_epochs=20,
                  device=device,
-                  callbacks=[OutputEncoderCallback(output_encoder)],
+                  data_processor=data_processor,
                  wandb_log=False,
                  log_test_interval=3,
                  use_distributed=False,
2 changes: 1 addition & 1 deletion examples/plot_SFNO_swe.py
@@ -1,6 +1,6 @@
"""
Training a SFNO on the spherical Shallow Water equations
-=============================
+==========================================================
In this example, we demonstrate how to use the small Spherical Shallow Water Equations example we ship with the package
to train a Spherical Fourier-Neural Operator
5 changes: 2 additions & 3 deletions examples/plot_UNO_darcy.py
@@ -14,7 +14,6 @@
import matplotlib.pyplot as plt
import sys
from neuralop.models import TFNO, UNO
-from neuralop.training import OutputEncoderCallback
from neuralop import Trainer
from neuralop.datasets import load_darcy_flow_small
from neuralop.utils import count_model_params
@@ -25,7 +24,7 @@

# %%
# Loading the Darcy Flow dataset
-train_loader, test_loaders, output_encoder = load_darcy_flow_small(
+train_loader, test_loaders, data_processor = load_darcy_flow_small(
        n_train=1000, batch_size=32,
        test_resolutions=[16, 32], n_tests=[100, 50],
        test_batch_sizes=[32, 32],
@@ -77,7 +76,7 @@
trainer = Trainer(model=model,
                  n_epochs=20,
                  device=device,
-                  callbacks=[OutputEncoderCallback(output_encoder)],
+                  data_processor=data_processor,
                  wandb_log=False,
                  log_test_interval=3,
                  use_distributed=False,
2 changes: 1 addition & 1 deletion examples/plot_darcy_flow_spectrum.py
@@ -44,7 +44,7 @@

# %%
# Loading the Navier-Stokes dataset in 128x128 resolution
-train_loader, test_loaders, output_encoder = load_darcy_flow_small(
+train_loader, test_loaders, data_processor = load_darcy_flow_small(
        n_train=50, batch_size=50,
        test_resolutions=[16, 32], n_tests=[50],
        test_batch_sizes=[32], positional_encoding=False,
44 changes: 24 additions & 20 deletions neuralop/datasets/darcy.py
@@ -1,9 +1,10 @@
from pathlib import Path
import torch

-from ..utils import UnitGaussianNormalizer
+from .output_encoder import UnitGaussianNormalizer
from .tensor_dataset import TensorDataset
-from .transforms import PositionalEmbedding
+from .transforms import PositionalEmbedding2D
+from .data_transforms import DefaultDataProcessor


def load_darcy_flow_small(
@@ -115,9 +116,10 @@ def load_darcy_pt(
        elif encoding == "pixel-wise":
            reduce_dims = [0]

-        input_encoder = UnitGaussianNormalizer(x_train, reduce_dim=reduce_dims)
-        x_train = input_encoder.encode(x_train)
-        x_test = input_encoder.encode(x_test.contiguous())
+        input_encoder = UnitGaussianNormalizer(dim=reduce_dims)
+        input_encoder.fit(x_train)
+        #x_train = input_encoder.transform(x_train)
+        #x_test = input_encoder.transform(x_test.contiguous())
    else:
        input_encoder = None

@@ -127,17 +129,15 @@ def load_darcy_pt(
        elif encoding == "pixel-wise":
            reduce_dims = [0]

-        output_encoder = UnitGaussianNormalizer(y_train, reduce_dim=reduce_dims)
-        y_train = output_encoder.encode(y_train)
+        output_encoder = UnitGaussianNormalizer(dim=reduce_dims)
+        output_encoder.fit(y_train)
+        #y_train = output_encoder.transform(y_train)
    else:
        output_encoder = None

    train_db = TensorDataset(
        x_train,
        y_train,
-        transform_x=PositionalEmbedding(grid_boundaries, 0)
-        if positional_encoding
-        else None,
    )
    train_loader = torch.utils.data.DataLoader(
        train_db,
@@ -151,9 +151,6 @@ def load_darcy_pt(
    test_db = TensorDataset(
        x_test,
        y_test,
-        transform_x=PositionalEmbedding(grid_boundaries, 0)
-        if positional_encoding
-        else None,
    )
    test_loader = torch.utils.data.DataLoader(
        test_db,
@@ -177,15 +174,12 @@ def load_darcy_pt(
        )
        y_test = data["y"][:n_test, :, :].unsqueeze(channel_dim).clone()
        del data
-        if input_encoder is not None:
-            x_test = input_encoder.encode(x_test)
+        #if input_encoder is not None:
+            #x_test = input_encoder.transform(x_test)

        test_db = TensorDataset(
            x_test,
            y_test,
-            transform_x=PositionalEmbedding(grid_boundaries, 0)
-            if positional_encoding
-            else None,
        )
        test_loader = torch.utils.data.DataLoader(
            test_db,
@@ -195,6 +189,16 @@ def load_darcy_pt(
            pin_memory=True,
            persistent_workers=False,
        )
-            test_loaders[res] = test_loader
+        test_loaders[res] = test_loader

-    return train_loader, test_loaders, output_encoder
+    if positional_encoding:
+        pos_encoding = PositionalEmbedding2D(grid_boundaries=grid_boundaries)
+    else:
+        pos_encoding = None
+    data_processor = DefaultDataProcessor(
+        in_normalizer=input_encoder,
+        out_normalizer=output_encoder,
+        positional_encoding=pos_encoding
+    )
+    return train_loader, test_loaders, data_processor
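The behavioral change in this file: normalizers are now only fit at load time (note the commented-out `transform` calls), and the actual encoding is applied per batch by the returned `DefaultDataProcessor`, rather than eagerly over whole tensors. A sketch of the old-vs-new normalizer usage, with `x` standing in for any batch tensor:

# Old: encode the entire dataset eagerly at load time.
#   input_encoder = UnitGaussianNormalizer(x_train, reduce_dim=reduce_dims)
#   x_train = input_encoder.encode(x_train)

# New: fit statistics at load time; transform lazily, batch by batch.
input_encoder = UnitGaussianNormalizer(dim=reduce_dims)
input_encoder.fit(x_train)             # estimate mean/std only; data left untouched
x_normed = input_encoder.transform(x)  # applied per batch via the DataProcessor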
