Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DeepONet should support arbitrary trunk/branch networks #149

Merged
merged 5 commits into from
Jul 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
## 0.2.0

- Add `Attention` base class, `MultiHeadAttention`, and `ScaledDotProductAttention` classes.
- Add `branch_network` and `trunk_network` arguments to `DeepONet` to allow for custom network architectures.
- Add `MaskedOperator` base class.

## 0.1.0
Expand Down
79 changes: 52 additions & 27 deletions src/continuiti/operators/deeponet.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,14 @@ class DeepONet(Operator):
trunk_width: Width of trunk network.
trunk_depth: Depth of trunk network.
basis_functions: Number of basis functions.
act: Activation function.
act: Activation function used in default trunk and branch networks.
device: Device.
branch_network: Custom branch network that maps input function
evaluations to `basis_functions` many coefficients (if set,
branch_width and branch_depth will be ignored).
trunk_network: Custom trunk network that maps `shapes.y.dim`-dimensional
evaluation coordinates to `basis_functions` many basis function
evaluations (if set, trunk_width and trunk_depth will be ignored).
samuelburbulla marked this conversation as resolved.
Show resolved Hide resolved
"""

def __init__(
Expand All @@ -44,30 +50,42 @@ def __init__(
basis_functions: int = 8,
act: Optional[torch.nn.Module] = None,
device: Optional[torch.device] = None,
branch_network: Optional[torch.nn.Module] = None,
trunk_network: Optional[torch.nn.Module] = None,
):
super().__init__(shapes, device)

self.basis_functions = basis_functions
self.dot_dim = shapes.v.dim * basis_functions
# trunk network
self.trunk = DeepResidualNetwork(
input_size=shapes.y.dim,
output_size=self.dot_dim,
width=trunk_width,
depth=trunk_depth,
act=act,
device=device,
)
if trunk_network is not None:
self.trunk = trunk_network
self.trunk.to(device)
else:
self.trunk = DeepResidualNetwork(
input_size=shapes.y.dim,
output_size=shapes.v.dim * basis_functions,
width=trunk_width,
depth=trunk_depth,
act=act,
device=device,
)

# branch network
self.branch_input_dim = math.prod(shapes.u.size) * shapes.u.dim
self.branch = DeepResidualNetwork(
input_size=self.branch_input_dim,
output_size=self.dot_dim,
width=branch_width,
depth=branch_depth,
act=act,
device=device,
)
if branch_network is not None:
self.branch = branch_network
self.branch.to(device)
else:
branch_input_dim = math.prod(shapes.u.size) * shapes.u.dim
self.branch = torch.nn.Sequential(
torch.nn.Flatten(),
DeepResidualNetwork(
input_size=branch_input_dim,
output_size=shapes.v.dim * basis_functions,
width=branch_width,
depth=branch_depth,
act=act,
device=device,
),
)

def forward(
self, _: torch.Tensor, u: torch.Tensor, y: torch.Tensor
Expand All @@ -85,24 +103,31 @@ def forward(
assert u.size(0) == y.size(0)
y_num = y.shape[2:]

# flatten inputs for both trunk and branch network
u = u.flatten(1, -1)
assert u.shape[1:] == torch.Size([self.branch_input_dim])

# flatten inputs for trunk network
y = y.swapaxes(1, -1).flatten(0, -2)
assert y.shape[-1:] == torch.Size([self.shapes.y.dim])

# Pass through branch and trunk networks
# Pass through branch network
b = self.branch(u)

# Pass through trunk network
t = self.trunk(y)

assert b.shape[1:] == t.shape[1:], (
f"Branch network output of shape {b.shape[1:]} does not match "
f"trunk network output of shape {t.shape[1:]}"
)

# determine basis functions dynamically
basis_functions = b.shape[1] // self.shapes.v.dim

# dot product
b = b.reshape(-1, self.shapes.v.dim, self.basis_functions)
b = b.reshape(-1, self.shapes.v.dim, basis_functions)
t = t.reshape(
b.size(0),
-1,
self.shapes.v.dim,
self.basis_functions,
basis_functions,
)
dot_prod = torch.einsum("abcd,acd->acb", t, b)
dot_prod = dot_prod.reshape(-1, self.shapes.v.dim, *y_num)
Expand Down
120 changes: 120 additions & 0 deletions tests/operators/test_deeponet.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,123 @@ def test_deeponet():
# Check solution
x, u, y, v = dataset.x, dataset.u, dataset.y, dataset.v
assert MSELoss()(operator, x, u, y, v) < 1e-2


@pytest.mark.slow
def test_custom_branch_network():
    """DeepONet should accept a user-supplied CNN as its branch network."""
    num_sensors = 32
    num_basis = 8

    # Benchmark data set with a single training sample.
    dataset = SineBenchmark(n_train=1, n_sensors=num_sensors).train_dataset

    # Convolutional branch network mapping sensor evaluations to
    # `num_basis` many coefficients.
    cnn_branch = torch.nn.Sequential(
        torch.nn.Conv1d(1, 16, kernel_size=3, padding=1),
        torch.nn.ReLU(),
        torch.nn.Conv1d(16, 1, kernel_size=3, padding=1),
        torch.nn.Flatten(),
        torch.nn.Linear(num_sensors, num_basis),
    )

    # Operator with the custom branch network plugged in.
    operator = DeepONet(
        dataset.shapes,
        branch_network=cnn_branch,
        basis_functions=num_basis,
    )

    # Train until tolerance is reached and verify the final error.
    Trainer(operator).fit(dataset, tol=1e-3)

    x, u, y, v = dataset.x, dataset.u, dataset.y, dataset.v
    assert MSELoss()(operator, x, u, y, v) < 1e-3


@pytest.mark.slow
def test_custom_trunk_network():
    """DeepONet should accept a user-supplied MLP as its trunk network."""
    num_sensors = 32
    num_basis = 32

    # Benchmark data set with a single training sample.
    dataset = SineBenchmark(n_train=1, n_sensors=num_sensors).train_dataset

    # MLP trunk network mapping evaluation coordinates to `num_basis`
    # many basis function evaluations.
    mlp_trunk = torch.nn.Sequential(
        torch.nn.Linear(1, 32),
        torch.nn.LayerNorm(32),
        torch.nn.Sigmoid(),
        torch.nn.Linear(32, 32),
        torch.nn.BatchNorm1d(32),
        torch.nn.GELU(),
        torch.nn.Linear(32, num_basis),
    )

    # Operator with the custom trunk network plugged in.
    operator = DeepONet(
        dataset.shapes,
        trunk_network=mlp_trunk,
        basis_functions=num_basis,
    )

    # Train until tolerance is reached and verify the final error.
    Trainer(operator).fit(dataset, tol=1e-3)

    x, u, y, v = dataset.x, dataset.u, dataset.y, dataset.v
    assert MSELoss()(operator, x, u, y, v) < 1e-3


@pytest.mark.slow
def test_custom_branch_and_trunk_network():
    """DeepONet should work with custom branch and trunk networks together,
    and with the custom trunk network alone."""
    num_sensors = 32
    num_basis = 32

    # Benchmark data set with a single training sample.
    dataset = SineBenchmark(n_train=1, n_sensors=num_sensors).train_dataset

    # Convolutional branch network mapping sensor evaluations to
    # `num_basis` many coefficients.
    cnn_branch = torch.nn.Sequential(
        torch.nn.Conv1d(1, 16, kernel_size=3, padding=1),
        torch.nn.ReLU(),
        torch.nn.Conv1d(16, 1, kernel_size=3, padding=1),
        torch.nn.Flatten(),
        torch.nn.Linear(num_sensors, num_basis),
    )

    # MLP trunk network mapping evaluation coordinates to `num_basis`
    # many basis function evaluations.
    mlp_trunk = torch.nn.Sequential(
        torch.nn.Linear(1, 32),
        torch.nn.LayerNorm(32),
        torch.nn.Sigmoid(),
        torch.nn.Linear(32, 32),
        torch.nn.BatchNorm1d(32),
        torch.nn.GELU(),
        torch.nn.Linear(32, num_basis),
    )

    # Operator with both custom networks plugged in.
    operator = DeepONet(
        dataset.shapes,
        branch_network=cnn_branch,
        trunk_network=mlp_trunk,
    )

    # Train until tolerance is reached and verify the final error.
    Trainer(operator).fit(dataset, tol=1e-3)

    x, u, y, v = dataset.x, dataset.u, dataset.y, dataset.v
    assert MSELoss()(operator, x, u, y, v) < 1e-3

    # Operator with only the custom trunk network plugged in.
    operator = DeepONet(
        dataset.shapes,
        trunk_network=mlp_trunk,
        basis_functions=num_basis,
    )

    # Train again and verify the final error.
    Trainer(operator).fit(dataset, tol=1e-3)

    x, u, y, v = dataset.x, dataset.u, dataset.y, dataset.v
    assert MSELoss()(operator, x, u, y, v) < 1e-3
Loading