In [1]:
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from scipy.io import savemat

from UQpy.scientific_machine_learning.neural_networks import DeepOperatorNetwork
from UQpy.scientific_machine_learning.trainers import Trainer
from dataset import load_data, rescale

In [2]:
import logging
# Lower the threshold of the "UQpy" logger to INFO; the per-epoch loss lines
# printed under the training cell are emitted at that level.
logger = logging.getLogger("UQpy")
logger.setLevel(logging.INFO)

In [3]:
# Define Branch network

class BranchNet(nn.Module):
    """Branch network: dense lift -> convolutional stack -> dense head.

    The last input dimension must hold 101 values. A linear layer with tanh
    maps them to 100 values, which are viewed as a single-channel 10x10 image.
    Four 5x5 "same"-padded convolutions (1 -> 16 -> 16 -> 16 -> 64 channels)
    are each followed by a 2x2 average pool with stride 1, so the spatial size
    shrinks by one per stage (10 -> 9 -> 8 -> 7 -> 6). The flattened
    64 * 6 * 6 feature map is reduced by a tanh MLP to 200 outputs.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Dense lift from the 101 input values to the 100 pixels of the image.
        self.fnn = nn.Sequential(nn.Linear(101, 100), nn.Tanh())

        # (in_channels, out_channels) of each convolution, in execution order.
        channel_plan = [(1, 16), (16, 16), (16, 16), (16, 64)]
        conv_stack = []
        for in_channels, out_channels in channel_plan:
            conv_stack.append(
                nn.Conv2d(in_channels, out_channels, (5, 5), padding="same")
            )
            conv_stack.append(nn.AvgPool2d(2, 1, padding=0))
        self.conv_layers = nn.Sequential(*conv_stack)

        # Dense head acting on the flattened 64-channel 6x6 feature map.
        self.dnn = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 6 * 6, 512),
            nn.Tanh(),
            nn.Linear(512, 512),
            nn.Tanh(),
            nn.Linear(512, 200),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the 200 branch features for every sample contained in ``x``."""
        image = self.fnn(x).view(-1, 1, 10, 10)
        features = self.conv_layers(image)
        return self.dnn(features)

In [4]:
# Define Trunk network

class TrunkNet(nn.Module):
    """Trunk network: a tanh MLP mapping 2-D coordinates to 200 features.

    Each coordinate is first rescaled affinely from ``[Xmin, Xmax]`` (the unit
    square here) to ``[-1, 1]`` and then passed through four linear layers
    (2 -> 128 -> 128 -> 128 -> 200), each followed by ``tanh``.

    ``Xmin`` and ``Xmax`` are float32 tensors of shape ``(1, 2)`` registered as
    non-persistent buffers. As buffers they follow ``.to(device)`` / ``.cuda()``
    together with the weights, and the normalisation stays a pure torch
    operation (mixing a tensor with a NumPy array would round-trip through
    host memory and fail for CUDA or grad-requiring inputs). Being
    non-persistent, they add no entries to ``state_dict()``, so checkpoint
    keys are unchanged.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.fnn = nn.Sequential(
            nn.Linear(2, 128),
            nn.Tanh(),
            nn.Linear(128, 128),
            nn.Tanh(),
            nn.Linear(128, 128),
            nn.Tanh(),
            nn.Linear(128, 200),
            nn.Tanh(),
        )
        # Lower / upper corner of the coordinate domain, one entry per axis.
        self.register_buffer(
            "Xmin", torch.tensor([0.0, 0.0]).reshape((-1, 2)), persistent=False
        )
        self.register_buffer(
            "Xmax", torch.tensor([1.0, 1.0]).reshape((-1, 2)), persistent=False
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalise coordinates ``x`` (last dimension 2) and return trunk features."""
        # Map [Xmin, Xmax] onto [-1, 1] per coordinate.
        x = 2.0 * (x - self.Xmin) / (self.Xmax - self.Xmin) - 1.0
        # The linear layers hold float32 weights; cast in case x arrived as float64.
        x = x.float()
        x = self.fnn(x)
        return x


In [5]:
# Define the model

# Instantiate the two sub-networks and hand them to UQpy's DeepOperatorNetwork.
branch_network = BranchNet()
trunk_network = TrunkNet()
model = DeepOperatorNetwork(branch_network, trunk_network, 2) # Number of outputs: 2

In [6]:
# Load datasets and create data loaders
class ElasticityDataSet(Dataset):
    """Map-style dataset for the Elasticity problem.

    Stores one array ``x`` that is returned unchanged with every sample,
    together with per-sample arrays ``f_x``, ``u_x`` and ``u_y`` indexed along
    their first axis. The number of samples is the leading dimension of
    ``f_x``.
    """

    def __init__(self, x, f_x, u_x, u_y):
        self.x, self.f_x = x, f_x
        self.u_x, self.u_y = u_x, u_y

    def __len__(self):
        n_samples = self.f_x.shape[0]
        return int(n_samples)

    def __getitem__(self, i):
        """Return ``(x, f_x[i], (u_x[i, :, 0], u_y[i, :, 0]))`` for sample ``i``."""
        forcing = self.f_x[i, :]
        # Only index 0 of the trailing axis of each label array is used.
        labels = (self.u_x[i, :, 0], self.u_y[i, :, 0])
        return self.x, forcing, labels

(F_train, Ux_train, Uy_train, F_test, Ux_test, Uy_test,
    X, ux_train_mean, ux_train_std, uy_train_mean, uy_train_std,) = load_data()

# Training batches are reshuffled every epoch.
train_data = DataLoader(
    ElasticityDataSet(
        np.float32(X), np.float32(F_train), np.float32(Ux_train), np.float32(Uy_train)
    ),
    batch_size=100,
    shuffle=True,
)
# Evaluation gains nothing from shuffling; a fixed order keeps the predictions
# written out later aligned with the order of the samples in F_test.
test_data = DataLoader(
    ElasticityDataSet(
        np.float32(X), np.float32(F_test), np.float32(Ux_test), np.float32(Uy_test)
    ),
    batch_size=100,
    shuffle=False,
)


In [7]:
# Define the loss function

class LossFunction(nn.Module):
    """Sum of the MSE losses of the two output components.

    ``prediction`` and ``label`` are indexable containers whose entries 0 and 1
    hold the two components. Each component is scored with ``F.mse_loss``
    using the configured ``reduction`` and the two results are added.
    """

    def __init__(self, reduction: str = "mean", *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.reduction = reduction

    def forward(self, prediction, label):
        """Return ``mse(prediction[0], label[0]) + mse(prediction[1], label[1])``."""
        first_component = F.mse_loss(prediction[0], label[0], reduction=self.reduction)
        second_component = F.mse_loss(prediction[1], label[1], reduction=self.reduction)
        return first_component + second_component

In [8]:
# Define optimizer and trainer

# Adam over every parameter exposed by `model`, with learning rate 1e-4.
optimizer = torch.optim.Adam(model.parameters(), lr = 1e-4)
# The trainer is given the model, its optimizer and the summed two-component MSE loss.
trainer = Trainer(model, optimizer, LossFunction())

In [9]:
# Run the model

# Train for 1000 epochs; the test loader is passed as well, and the log below
# reports a train loss and a test loss for each epoch.
trainer.run(train_data=train_data, test_data=test_data, epochs=1000)

[INFO] - 2024-05-15 17:23:51,114 - UQpy: Scientific Machine Learning: Beginning training and testing DeepOperatorNetwork
[INFO] - 2024-05-15 17:23:54,373 - UQpy: Scientific Machine Learning: Epoch 1 / 1,000 Train Loss 98.72335411372937 Test Loss 4.101278781890869
[INFO] - 2024-05-15 17:23:57,551 - UQpy: Scientific Machine Learning: Epoch 2 / 1,000 Train Loss 8.91635571028057 Test Loss 6.430164337158203
[INFO] - 2024-05-15 17:24:00,708 - UQpy: Scientific Machine Learning: Epoch 3 / 1,000 Train Loss 2.9467433565541317 Test Loss 2.6804041862487793
[INFO] - 2024-05-15 17:24:03,867 - UQpy: Scientific Machine Learning: Epoch 4 / 1,000 Train Loss 2.099652497391952 Test Loss 2.5019302368164062
[INFO] - 2024-05-15 17:24:07,023 - UQpy: Scientific Machine Learning: Epoch 5 / 1,000 Train Loss 1.9123037300611798 Test Loss 2.377794027328491
[INFO] - 2024-05-15 17:24:10,176 - UQpy: Scientific Machine Learning: Epoch 6 / 1,000 Train Loss 1.8285051646985506 Test Loss 2.3296563625335693
[INFO] - 2024-05

In [10]:
# Evaluate test data and save results
def eval_model(test_data, model):
    """Run ``model`` over every batch of ``test_data`` and gather the results.

    Each batch is unpacked as ``(*inputs, labels)``: the inputs are forwarded
    to ``model(*inputs)``, which must return a pair ``(ux_pred, uy_pred)``, and
    ``labels`` must be a pair ``(ux_test, uy_test)``. The model is switched to
    evaluation mode and the forward passes run with autograd disabled, so no
    computation graph is built or kept alive for the collected predictions.

    Parameters
    ----------
    test_data : iterable of batches
        Typically a ``DataLoader``. It must yield at least one batch, because
        ``torch.cat`` raises on an empty list.
    model : torch.nn.Module
        Network to evaluate. It is left in evaluation mode.

    Returns
    -------
    tuple of torch.Tensor
        ``(ux_pred, uy_pred, ux_test, uy_test, x_test)``, each concatenated
        over batches along dimension 0. ``x_test`` is ``inputs[1][:, 0, :]``,
        i.e. index 0 along dimension 1 of the second model input (the ``f_x``
        samples when the batches come from ``ElasticityDataSet``).
    """
    model.eval()
    ux_pred_list = []
    uy_pred_list = []
    ux_test_list = []
    uy_test_list = []
    x_list = []
    # Inference only: gradients are never needed here.
    with torch.no_grad():
        for *x, y in test_data:
            ux_pred, uy_pred = model(*x)
            ux_test, uy_test = y
            ux_pred_list.append(ux_pred)
            uy_pred_list.append(uy_pred)
            ux_test_list.append(ux_test)
            uy_test_list.append(uy_test)
            x_list.append(x[1][:, 0, :])
    return (
        torch.cat(ux_pred_list),
        torch.cat(uy_pred_list),
        torch.cat(ux_test_list),
        torch.cat(uy_test_list),
        torch.cat(x_list),
    )

In [13]:
ux_pred, uy_pred, ux_test, uy_test, x_test = eval_model(test_data, model)

# Training-set statistics with axis 2 squeezed out, computed once per component.
ux_mean = np.squeeze(ux_train_mean, axis=2)
ux_std = np.squeeze(ux_train_std, axis=2)
uy_mean = np.squeeze(uy_train_mean, axis=2)
uy_std = np.squeeze(uy_train_std, axis=2)

# Predictions and reference values are both passed through `rescale` with the
# training statistics of their component.
ux_pred = rescale(ux_pred.detach(), ux_mean, ux_std)
uy_pred = rescale(uy_pred.detach(), uy_mean, uy_std)
ux_test = rescale(ux_test.detach(), ux_mean, ux_std)
uy_test = rescale(uy_test.detach(), uy_mean, uy_std)

In [12]:
# Collect the evaluation tensors as NumPy arrays and write them to one MATLAB file.
results_to_save = {
    'x_test': x_test.detach().numpy(),
    'ux_test': ux_test.detach().numpy(),
    'uy_test': uy_test.detach().numpy(),
    'ux_pred': ux_pred.detach().numpy(),
    'uy_pred': uy_pred.detach().numpy(),
}
savemat('Elastic_plate.mat', results_to_save)