In [1]:
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.utils.data import Subset
import gdown
from IPython.display import clear_output as clc

In [None]:
try:
     from dlroms import*
except:
     !pip install git+https://github.com/NicolaRFranco/dlroms.git
     from dlroms import*

In [None]:
# Select the compute device: use the current accelerator when one is available,
# otherwise the CPU. `torch.accelerator` is only present in recent PyTorch
# releases, so older installs fall back to the classic CUDA check instead of
# failing with AttributeError.
if hasattr(torch, "accelerator"):
    device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
else:
    device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")

In [None]:
class DeepONet(nn.Module):
    """Operator network combining a branch MLP and a trunk MLP.

    Both sub-networks are three-layer perceptrons with ReLU activations and
    ``p`` outputs; the prediction is the dot product of their outputs along
    that last axis.
    """

    def __init__(self, m, d, p, h):
        """
        Parameters
        ----------
        m : int
            dimension for the input parameters
        d : int
            dimension for the input data
        p : int
            dimension of branch and trunk networks output
        h : int
            number of neurons in hidden layers
        """
        super().__init__()

        # Branch is created before trunk so the parameter creation order
        # (and hence seeded initialisation) is the same as before.
        self.branch = self._mlp(m, h, p)
        self.trunk = self._mlp(d, h, p)

    @staticmethod
    def _mlp(n_in, n_hidden, n_out):
        """Build Linear -> ReLU -> Linear -> ReLU -> Linear."""
        layers = [
            nn.Linear(n_in, n_hidden),
            nn.ReLU(),
            nn.Linear(n_hidden, n_hidden),
            nn.ReLU(),
            nn.Linear(n_hidden, n_out),
        ]
        return nn.Sequential(*layers)

    def forward(self, u, y):
        """Evaluate the network.

        Expected shapes (as used elsewhere in this notebook):
        ``u`` is (batch_size, m) and ``y`` is (batch_size, n_points, d);
        the result is then (batch_size, n_points).
        """
        # (batch_size, p) -> (batch_size, 1, p) so it broadcasts over the points axis
        branch_out = self.branch(u).unsqueeze(1)
        # (batch_size, n_points, p)
        trunk_out = self.trunk(y)
        # Contract the latent axis p
        return (branch_out * trunk_out).sum(dim=-1)

In [None]:
# Data
class FomDataset(Dataset):
    """Dataset yielding ``(mu[i], u[i], y)`` triples for DeepONet training."""

    def __init__(self, mu, u, y):
        """
        Parameters
        ----------
        mu : array-like
            input parameters, referred as u in the DeepONet paper and in the above model
        u : array-like
            solution of the PDE, referred as G(u) in the DeepONet paper and in the above model
        y : array-like
            spatial domain coordinate, referred as y in the DeepONet paper and in the above model

        Every input is converted to a float32 tensor and moved to the
        notebook-level ``device``.
        """
        def _on_device(values):
            return torch.tensor(values, dtype=torch.float32).to(device)

        self.mu = _on_device(mu)
        self.u = _on_device(u)
        self.y = _on_device(y)

    def __len__(self):
        # One sample per entry along the first axis of mu
        n_samples = len(self.mu)
        return n_samples

    def __getitem__(self, idx):
        # y is not indexed: the same coordinates are returned with every sample
        sample = (self.mu[idx], self.u[idx], self.y)
        return sample

def train_val_test_split(dataset, train_size, val_size):
    """
    Split a dataset into contiguous training, validation and test subsets.

    The first ``train_size - val_size`` samples form the training set, the
    next ``val_size`` samples the validation set, and everything from index
    ``train_size`` onwards the test set. No shuffling is performed.

    Parameters
    ----------
    dataset : FomDataset
        The full dataset.
    train_size : int
        Total number of training samples (includes validation).
    val_size : int
        Number of validation samples.

    Returns
    -------
    tuple of Subset
        ``(train_set, val_set, test_set)``.
    """
    val_start = train_size - val_size
    bounds = (
        (0, val_start),               # training
        (val_start, train_size),      # validation
        (train_size, len(dataset)),   # test
    )
    train_set, val_set, test_set = (
        Subset(dataset, list(range(start, stop))) for start, stop in bounds
    )
    return train_set, val_set, test_set

from torch.utils.data import Subset

def train_val_test_split2(dataset, test_size, train_fraction=0.9):
    """
    Split a dataset into training, validation and test subsets.

    The last ``test_size`` samples form the test set. The remaining leading
    samples are split contiguously (no shuffling): the first
    ``int(train_fraction * n_remaining)`` go to training and the rest to
    validation.

    Parameters
    ----------
    dataset : FomDataset
        The full dataset.
    test_size : int
        Total number of test samples.
    train_fraction : float, optional
        Fraction of the non-test samples used for training (default 0.9,
        i.e. a 90% / 10% train / validation split).

    Returns
    -------
    tuple of Subset
        ``(train_set, val_set, test_set)``.

    Raises
    ------
    ValueError
        If ``test_size`` is not in ``[0, len(dataset)]`` or ``train_fraction``
        is not in ``[0, 1]``.
    """
    total_size = len(dataset)
    # An oversized test_size would otherwise yield negative indices that
    # silently wrap around inside Subset.
    if not 0 <= test_size <= total_size:
        raise ValueError(
            f"test_size must be between 0 and {total_size}, got {test_size}"
        )
    if not 0.0 <= train_fraction <= 1.0:
        raise ValueError(
            f"train_fraction must be between 0 and 1, got {train_fraction}"
        )

    n_remaining = total_size - test_size
    n_train = int(train_fraction * n_remaining)

    train_set = Subset(dataset, list(range(0, n_train)))
    val_set = Subset(dataset, list(range(n_train, n_remaining)))
    test_set = Subset(dataset, list(range(n_remaining, total_size)))

    return train_set, val_set, test_set


In [None]:
def load_data(id, output = "data.npz"):
  """
  Download an ``.npz`` archive with gdown and return its ``mu`` and ``u`` arrays.

  Parameters
  ----------
  id : str
      File id passed to ``gdown.download``.
  output : str, optional
      Local path the archive is written to (default "data.npz").

  Returns
  -------
  tuple of numpy.ndarray
      ``(mu, u)`` copied out of the archive.

  Raises
  ------
  RuntimeError
      If ``gdown.download`` reports failure by returning ``None``.
  """
  downloaded = gdown.download(id = id, output = output)
  if downloaded is None:
    # Without this check a stale file left at `output` would be loaded silently.
    raise RuntimeError(f"Download failed for id {id!r}")
  # The context manager closes the archive's file handle once the arrays are copied.
  with np.load(output) as data:
    mu, u = data['mu'].copy(), data['u'].copy()
  return mu, u

def get_fem_space(custom_mesh = None):
  """
  Build the function space ``Vh`` and fetch its coordinates via ``fe``.

  Parameters
  ----------
  custom_mesh : optional
      Mesh to build the space on. When ``None``, ``fe.unitsquaremesh(40, 40)``
      is used.

  Returns
  -------
  tuple
      ``(Vh, y)`` where ``Vh = fe.space(mesh, 'CG', 1)`` and
      ``y = fe.coordinates(Vh)``.
  """
  # NOTE(review): `fe` is presumably provided by the dlroms star-import — confirm.
  mesh = fe.unitsquaremesh(40, 40) if custom_mesh is None else custom_mesh
  Vh = fe.space(mesh, 'CG', 1)
  return Vh, fe.coordinates(Vh)

In [None]:
# Training
class Trainer:
  """
  Minimal training loop for a model called as ``model(mu, y)``.

  Each loader must yield ``(mu, u, y)`` batches. Per epoch the trainer runs
  one pass over the training loader (with optimisation), then evaluates the
  validation and test loaders, and records the batch-averaged loss (and
  optional error metric) in ``train_hist``, ``val_hist`` and ``test_hist``.

  Call conventions
  ----------------
  loss_function : called as ``loss_function(prediction, target)``.
  error_metric  : called as ``error_metric(target, prediction)``, i.e. with the
      ``(true, predicted)`` argument order, so that a relative error is
      normalised by the true solution rather than by the prediction.
  """

  def __init__(self, train_loader, val_loader, test_loader, model, loss_function, optimizer, epochs, error_metric = None, verbose=False):
    self.train_loader = train_loader
    self.val_loader = val_loader
    self.test_loader = test_loader
    self.model = model
    self.loss_function = loss_function
    self.error_metric = error_metric
    self.optimizer = optimizer
    self.epochs = epochs
    self.verbose = verbose

    # Per-epoch averages; "error_metric" stays empty when no metric is given
    self.train_hist = {"loss": [], "error_metric": []}
    self.val_hist = {"loss": [], "error_metric": []}
    self.test_hist = {"loss": [], "error_metric": []}

  @staticmethod
  def _mean(total, count):
    """Average ``total`` over ``count`` batches; nan when there were no batches."""
    return total / count if count else float("nan")

  def _score_batch(self, mu, u, y):
    """Run the model on one batch; return (loss tensor, metric value as float)."""
    Gu_pred = self.model(mu, y)
    loss = self.loss_function(Gu_pred, u)
    metric_value = 0.0
    if self.error_metric is not None:
      # (true, predicted) order: see the class docstring
      metric_value = self.error_metric(u, Gu_pred).item()
    return loss, metric_value

  def _evaluate(self, loader):
    """Batch-averaged (loss, metric) over ``loader`` in eval mode, without gradients."""
    total_loss = 0.0
    total_error_metric = 0.0

    self.model.eval()
    with torch.no_grad():
      for mu, u, y in loader:
        loss, metric_value = self._score_batch(mu, u, y)
        total_loss += loss.item()
        total_error_metric += metric_value

    num_batches = len(loader)
    return self._mean(total_loss, num_batches), self._mean(total_error_metric, num_batches)

  def train_step(self):
    """
    Run one training epoch.

    Returns
    -------
    tuple of float
        Batch-averaged ``(loss, error_metric)``; the metric is 0.0 when no
        error metric was supplied.
    """
    train_size = len(self.train_loader.dataset)
    num_batches = len(self.train_loader)
    total_loss = 0.0  # Accumulated only for reporting
    total_error_metric = 0.0
    samples_seen = 0  # Running count for the progress line

    self.model.train()
    for batch_idx, (mu, u, y) in enumerate(self.train_loader):

      # Predict
      loss, metric_value = self._score_batch(mu, u, y)
      total_loss += loss.item()
      total_error_metric += metric_value

      # Backpropagation (weights are updated batch-by-batch)
      self.optimizer.zero_grad()
      loss.backward()
      self.optimizer.step()

      # Print training stats every 10 batches
      samples_seen += len(mu)
      if batch_idx % 10 == 0:
        loss_value = loss.item()
        print(f"\rCurrent Loss: {loss_value:>7f}  [{samples_seen:>5d}/{train_size:>5d}]", end="", flush=True)

    return self._mean(total_loss, num_batches), self._mean(total_error_metric, num_batches)

  def validation_step(self):
    """
    Evaluate the model on the validation loader.

    Returns batch-averaged ``(loss, error_metric)``.
    """
    return self._evaluate(self.val_loader)

  def test_step(self):
    """
    Evaluate the model on the test loader.

    Returns batch-averaged ``(loss, error_metric)``.
    """
    return self._evaluate(self.test_loader)

  def fit(self):
    """
    Perform the optimization loop over the specified number of epochs.

    Returns
    -------
    tuple of dict
        ``(train_hist, val_hist, test_hist)``.
    """
    for epoch in range(self.epochs):
      print(f"Epoch {epoch+1}/{self.epochs}:")

      # Run train, validation, and test steps
      train_loss, train_err_metric = self.train_step()
      val_loss, val_err_metric = self.validation_step()
      test_loss, test_err_metric = self.test_step()

      # Save history
      self.train_hist["loss"].append(train_loss)
      self.val_hist["loss"].append(val_loss)
      self.test_hist["loss"].append(test_loss)
      if self.error_metric is not None:
        self.train_hist["error_metric"].append(train_err_metric)
        self.val_hist["error_metric"].append(val_err_metric)
        self.test_hist["error_metric"].append(test_err_metric)

      # Print formatted progress
      print(f"\nTrain Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Test Loss: {test_loss:.4f}")
      if self.error_metric is not None and self.verbose:
        print(f"Train Error Metric: {train_err_metric:.4f} | Val Error Metric: {val_err_metric:.4f} | Test Error Metric: {test_err_metric:.4f}")

    print("Done!")

    return self.train_hist, self.val_hist, self.test_hist

In [None]:
# Loss and Metric definitions
def mse_loss(true, predicted):
  """Mean squared error: averaged over the last axis first, then over the remaining entries."""
  residual = true - predicted
  per_sample = (residual * residual).mean(dim = -1)
  return per_sample.mean()

def error_metric(true, predicted):
  """
  Mean relative error.

  For each entry along the leading axes, the mean absolute error over the last
  axis is divided by the mean absolute value of ``true`` over the same axis;
  the resulting ratios are then averaged.
  """
  abs_error = (true - predicted).abs().mean(dim = -1)
  reference = true.abs().mean(dim = -1)
  ratios = abs_error / reference
  return ratios.mean()

In [None]:
# Plotting Utilities
def plot_loss(trainer, ymax = 0.075):
  """
  Plot train / validation / test loss histories on a log-scaled epoch axis.

  Parameters
  ----------
  trainer : Trainer
      Object exposing ``epochs`` and the ``train_hist`` / ``val_hist`` /
      ``test_hist`` dictionaries with a ``'loss'`` list.
  ymax : float, optional
      Upper limit of the y axis (default 0.075).
  """
  n_epochs = trainer.epochs
  curves = (
      (trainer.train_hist['loss'], '-k', 'Train'),
      (trainer.val_hist['loss'], '--b', 'Validation'),
      (trainer.test_hist['loss'], '--r', 'Test'),
  )
  plt.figure(figsize = (5, 3))
  for values, fmt, label in curves:
    # Epochs are numbered from 1: with the implicit x = 0, 1, ... the first
    # epoch would sit at x = 0, which cannot be shown on a logarithmic axis.
    plt.semilogx(range(1, len(values) + 1), values, fmt, label = label)
  plt.xlabel('Epochs')
  plt.ylabel('Loss')
  plt.axis([0.5, n_epochs, 0, ymax])
  plt.legend()
  plt.show()

def plot_error_metric(trainer, ymax = 0.075):
  """
  Plot train / validation / test error-metric histories on a log-scaled epoch axis.

  Parameters
  ----------
  trainer : Trainer
      Object exposing ``epochs`` and the ``train_hist`` / ``val_hist`` /
      ``test_hist`` dictionaries with an ``'error_metric'`` list.
  ymax : float, optional
      Upper limit of the y axis (default 0.075).
  """
  n_epochs = trainer.epochs
  curves = (
      (trainer.train_hist['error_metric'], '-k', 'Train'),
      (trainer.val_hist['error_metric'], '--b', 'Validation'),
      (trainer.test_hist['error_metric'], '--r', 'Test'),
  )
  plt.figure(figsize = (5, 3))
  for values, fmt, label in curves:
    # Epochs are numbered from 1: with the implicit x = 0, 1, ... the first
    # epoch would sit at x = 0, which cannot be shown on a logarithmic axis.
    plt.semilogx(range(1, len(values) + 1), values, fmt, label = label)
  plt.xlabel('Epochs')
  plt.ylabel('MRE')
  plt.axis([0.5, n_epochs, 0, ymax])
  plt.legend()
  plt.show()

In [None]:
def plot_errors(trainer, ymax_loss, ymax_metric):
    """
    Plot loss and (when available) error-metric histories side by side.

    Parameters
    ----------
    trainer : Trainer
        Object exposing ``epochs``, ``error_metric`` and the ``train_hist`` /
        ``val_hist`` / ``test_hist`` dictionaries.
    ymax_loss : float
        Upper y limit of the loss panel.
    ymax_metric : float
        Upper y limit of the error-metric panel.
    """
    n_epochs = trainer.epochs
    histories = (
        (trainer.train_hist, '-k', 'Train'),
        (trainer.val_hist, '--b', 'Validation'),
        (trainer.test_hist, '--r', 'Test'),
    )

    def _draw_panel(position, key, ylabel, title, ymax):
        # Draw one panel of the 1x2 grid for the history entry `key`
        plt.subplot(1, 2, position)
        for hist, fmt, label in histories:
            values = hist[key]
            # Epochs are numbered from 1: x = 0 cannot be shown on a log axis,
            # so the implicit 0-based index would hide the first epoch.
            plt.semilogx(range(1, len(values) + 1), values, fmt, label=label)
        plt.xlabel('Epochs')
        plt.ylabel(ylabel)
        plt.title(title)
        plt.axis([0.5, n_epochs, 0, ymax])
        plt.legend()

    plt.figure(figsize=(10, 4))  # Wider figure for side-by-side plots

    # Plot Loss
    _draw_panel(1, 'loss', 'Loss', 'Loss', ymax_loss)

    # Plot Error Metric if available
    if trainer.error_metric is not None:
        _draw_panel(2, 'error_metric', 'MRE', 'Error Metric', ymax_metric)

    plt.tight_layout()
    plt.show()

In [None]:
def plot_solution(mu, u, Vh):
  """
  Clear the cell output and plot the solution ``u`` on ``Vh`` via ``fe.plot``.

  Parameters
  ----------
  mu : iterable of numbers
      Parameter values shown in the title, each formatted with two decimals.
      Any number of parameters is accepted.
  u :
      Solution passed to ``fe.plot``.
  Vh :
      Function space passed to ``fe.plot``.
  """
  clc()
  fe.plot(u, Vh)
  mu_text = ",".join("%.2f" % value for value in mu)
  # Raw string: "\m" is not a valid escape sequence in a normal string literal
  plt.title(r"Solution for $\mu$ = [%s]" % mu_text)

def add_batch_dimension(test_sample):
    """
    Prepend a batch axis of size 1 to each tensor of a ``(mu, true_u, y)`` sample.

    Returns the three unsqueezed tensors as a tuple, in the same order.
    """
    mu, true_u, y = test_sample
    return tuple(tensor.unsqueeze(0) for tensor in (mu, true_u, y))

def compare_solutions(test_sample, Vh, net=None):
    """
    Plot the true solution next to the network prediction for one sample.

    Parameters
    ----------
    test_sample : tuple
        ``(mu, true_u, y)`` tensors for a single sample (no batch axis).
    Vh :
        Function space passed to ``fe.plot``.
    net : callable, optional
        Model called as ``net(mu, y)``. When omitted, the notebook-level
        ``model`` variable is used.
    """
    if net is None:
        net = model  # notebook-level model

    # Add batch dimension
    mu, true_u, y = add_batch_dimension(test_sample)

    # Get the predicted solution; no gradients are needed for plotting
    with torch.no_grad():
        pred_u = net(mu, y).squeeze(0)  # Remove batch dimension for plotting

    # Move to the CPU before converting: .numpy() fails on accelerator tensors
    true_u = true_u.squeeze(0).detach().cpu().numpy()
    pred_u = pred_u.detach().cpu().numpy()

    # Plot the true and predicted solutions
    clc()  # Clear the previous cell output
    plt.figure(figsize=(10, 5))

    # Shared colour range taken from the true solution
    vmin, vmax = true_u.min(), true_u.max()

    plt.subplot(1, 2, 1)
    plt.title("True solution")
    fe.plot(true_u, Vh, colorbar=True, vmin=vmin, vmax=vmax, shrink=0.7)

    plt.subplot(1, 2, 2)
    plt.title("DeepONet approximation")
    fe.plot(pred_u, Vh, colorbar=True, vmin=vmin, vmax=vmax, shrink=0.7)

    plt.show()

In [None]:
def compare_solutions_scalar(test_sample, net=None):
  """
  Plot the sampled input next to the true and predicted outputs for one sample.

  Relies on the notebook-level ``xsens`` (x values for the input panel) and
  ``ygrid`` (x values for the output panel).

  Parameters
  ----------
  test_sample : tuple
      ``(mu, true_u, y)`` tensors for a single sample (no batch axis).
  net : callable, optional
      Model called as ``net(mu, y)``. When omitted, the notebook-level
      ``model`` variable is used.
  """
  if net is None:
    net = model  # notebook-level model

  # Add batch dimension
  mu, true_u, y = add_batch_dimension(test_sample)

  # Get the predicted solution; no gradients are needed for plotting
  with torch.no_grad():
    pred_u = net(mu, y).squeeze(0)  # Remove batch dimension for plotting

  # Move to the CPU before converting: .numpy() fails on accelerator tensors
  sensor_values = mu.squeeze(0).detach().cpu().numpy()
  true_u = true_u.squeeze(0).detach().cpu().numpy()
  pred_u = pred_u.detach().cpu().numpy()

  # Plot the true and predicted solutions
  clc()  # Clear the previous cell output
  plt.figure(figsize=(10, 5))

  plt.subplot(1, 2, 1)
  plt.plot(xsens, sensor_values, '.', label = 'sensors')
  plt.title("Input $f$")
  plt.legend()

  plt.subplot(1, 2, 2)
  plt.plot(ygrid, true_u, color = 'orange', label = 'True')
  plt.plot(ygrid, pred_u, '--r', label = 'DeepONet')
  plt.title("Output $u$")
  plt.legend()