In [1]:
# PyTorch
import torch
from torch import nn, optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# Scheduler - OneCycleLR, CosineAnnealingLR
from torch.optim.lr_scheduler import OneCycleLR, CosineAnnealingLR

# PyTorch Lightning
import lightning as L

# wandb
import wandb

import optuna
from optuna.samplers import TPESampler
from optuna.integration import WeightsAndBiasesCallback
from optuna.visualization import plot_optimization_history, plot_param_importances

# Split the data into training and test sets
from sklearn.model_selection import train_test_split

import matplotlib.pyplot as plt
import numpy as np
import polars as pl

In [2]:
import warnings
import os

# Keep notebook output readable: mute Python warnings and wandb's console logging.
# NOTE(review): the blanket 'ignore' also hides deprecation warnings from
# optuna / torch / lightning; consider narrowing it to specific categories.
warnings.filterwarnings('ignore')
os.environ["WANDB_SILENT"] = "true"

In [3]:
# Seed the global RNGs through Lightning so data splits and inits are reproducible.
L.seed_everything(seed=42)

Seed set to 42


42

In [4]:
# Load the long-format sample tables (one row per grid point, "group" = sample id):
#   df_grf     -> columns x, grf, group      (input functions, presumably GRF samples)
#   df_grf_int -> columns y, grf_int, group  (their integrals evaluated on grid y)
# NOTE(review): paths are relative to the notebook's working directory.
df_grf = pl.read_parquet("../data/grf_random_l.parquet")
df_grf_int = pl.read_parquet("../data/grf_random_l_int.parquet")

In [5]:
# Number of distinct function samples (one "group" id per sample).
n_samples = df_grf.get_column("group").n_unique()
n_samples

10000

In [6]:
# Quick look at both long-format frames (x / grf / group and y / grf_int / group).
print(df_grf, df_grf_int)

shape: (10_000_000, 3)
┌───────┬───────────┬───────┐
│ x     ┆ grf       ┆ group │
│ ---   ┆ ---       ┆ ---   │
│ f64   ┆ f64       ┆ u64   │
╞═══════╪═══════════╪═══════╡
│ 0.0   ┆ 0.680026  ┆ 0     │
│ 0.001 ┆ 0.678594  ┆ 0     │
│ 0.002 ┆ 0.67716   ┆ 0     │
│ 0.003 ┆ 0.675726  ┆ 0     │
│ 0.004 ┆ 0.674292  ┆ 0     │
│ …     ┆ …         ┆ …     │
│ 0.996 ┆ -0.895085 ┆ 9999  │
│ 0.997 ┆ -0.89404  ┆ 9999  │
│ 0.998 ┆ -0.89299  ┆ 9999  │
│ 0.999 ┆ -0.891941 ┆ 9999  │
│ 1.0   ┆ -0.890879 ┆ 9999  │
└───────┴───────────┴───────┘ shape: (1_000_000, 3)
┌──────┬───────────┬───────┐
│ y    ┆ grf_int   ┆ group │
│ ---  ┆ ---       ┆ ---   │
│ f64  ┆ f64       ┆ u64   │
╞══════╪═══════════╪═══════╡
│ 0.0  ┆ 0.0       ┆ 0     │
│ 0.01 ┆ 0.006729  ┆ 0     │
│ 0.02 ┆ 0.013313  ┆ 0     │
│ 0.03 ┆ 0.019753  ┆ 0     │
│ 0.04 ┆ 0.026047  ┆ 0     │
│ …    ┆ …         ┆ …     │
│ 0.96 ┆ -0.390887 ┆ 9999  │
│ 0.97 ┆ -0.400144 ┆ 9999  │
│ 0.98 ┆ -0.409306 ┆ 9999  │
│ 0.99 ┆ -0.418369 ┆ 9999  │
│ 1.0  ┆ -

In [7]:
# Downsample the fine x-grid to the 101 nominal points 0.00, 0.01, ..., 1.00.
# NOTE(review): `is_in` compares floats exactly. The result below has 1_000_000 rows
# for 10000 groups, i.e. 100 points per group, so one of the 101 nominal points
# apparently fails to match because of floating-point representation. A tolerant
# match (e.g. on `(pl.col("x") * 100).round()`) would be safer, but it changes the
# grid size and therefore `num_input=100` used by the model downstream.
df_grf = df_grf.filter(pl.col("x").is_in([round(x * 0.01, 2) for x in range(101)]))
print(df_grf)

shape: (1_000_000, 3)
┌──────┬───────────┬───────┐
│ x    ┆ grf       ┆ group │
│ ---  ┆ ---       ┆ ---   │
│ f64  ┆ f64       ┆ u64   │
╞══════╪═══════════╪═══════╡
│ 0.0  ┆ 0.680026  ┆ 0     │
│ 0.01 ┆ 0.665668  ┆ 0     │
│ 0.02 ┆ 0.651229  ┆ 0     │
│ 0.03 ┆ 0.636693  ┆ 0     │
│ 0.04 ┆ 0.622044  ┆ 0     │
│ …    ┆ …         ┆ …     │
│ 0.96 ┆ -0.93019  ┆ 9999  │
│ 0.97 ┆ -0.920982 ┆ 9999  │
│ 0.98 ┆ -0.911356 ┆ 9999  │
│ 0.99 ┆ -0.901277 ┆ 9999  │
│ 1.0  ┆ -0.890879 ┆ 9999  │
└──────┴───────────┴───────┘


In [8]:
def _per_group_matrix(frame, column):
    """Gather `column` of a long-format frame into a float32 array with one row per group.

    Rows follow the first-appearance order of "group"; the array is reshaped to
    (n_samples, -1), so every group must contribute the same number of points.
    """
    flat = (
        frame.group_by("group", maintain_order=True)
        .agg(pl.col(column))[column]
        .explode()
        .to_numpy()
    )
    return flat.reshape(n_samples, -1).astype(np.float32)


# Sensor grid taken from the first sample (assumed shared by all samples).
x = df_grf.filter(pl.col("group") == 0)["x"].to_numpy()
y = _per_group_matrix(df_grf_int, "y")
grfs = _per_group_matrix(df_grf, "grf")
grf_ints = _per_group_matrix(df_grf_int, "grf_int")

print(f"x: {x.shape}, y: {y.shape}")
print(f"grfs: {grfs.shape}, grf_ints: {grf_ints.shape}")

x: (100,), y: (10000, 100)
grfs: (10000, 100), grf_ints: (10000, 100)


## DeepONet from Scratch

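$$
G: u \in C[\mathcal{D}] \rightarrow G(u) \in C[\mathcal{R}] \quad \text{where } \mathcal{D}, \mathcal{R} \text{ are compact}
$$
$$
u(x) \overset{G}{\longrightarrow} G(u)(y) = \int_0^y u(x)\,dx
$$

The operator to learn is the antiderivative: the branch net receives $u$ sampled at fixed sensor points, the trunk net receives a query point $y$, and the model predicts $G(u)(y)$.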

In [9]:
# Sequential (unshuffled) 80 / 10 / 10 split of the function samples.
n_train = int(0.8 * n_samples)
n_val = int(0.1 * n_samples)
n_test = n_samples - n_train - n_val

# np.split at [n_train, n_train + n_val] yields the same views as the three slices.
grf_train, grf_val, grf_test = np.split(grfs, [n_train, n_train + n_val])
y_train, y_val, y_test = np.split(y, [n_train, n_train + n_val])
grf_int_train, grf_int_val, grf_int_test = np.split(grf_ints, [n_train, n_train + n_val])

In [10]:
class IntegralData(Dataset):
    """Map-style dataset of (input function, query points, target values) triples.

    Item ``idx`` is ``(grf[idx], y[idx], grf_int[idx])``.

    Args:
        grf: array-like of input-function samples, indexed along the first axis.
        y: array-like of query coordinates, same number of samples as ``grf``.
        grf_int: array-like of target values, same number of samples as ``grf``.

    Raises:
        ValueError: if the three inputs do not have the same number of samples.
    """

    def __init__(self, grf, y, grf_int):
        # torch.tensor copies the data, so later edits to the source arrays do not leak in.
        self.grf = torch.tensor(grf)
        self.y = torch.tensor(y)
        self.grf_int = torch.tensor(grf_int)
        # Fail fast here instead of with an IndexError in the middle of a DataLoader pass.
        if not len(self.grf) == len(self.y) == len(self.grf_int):
            raise ValueError(
                "grf, y and grf_int must have the same number of samples, got "
                f"{len(self.grf)}, {len(self.y)}, {len(self.grf_int)}"
            )

    def __len__(self):
        return len(self.grf)

    def __getitem__(self, idx):
        return self.grf[idx], self.y[idx], self.grf_int[idx]

In [11]:
# Wrap each split as a dataset of (u, y, G(u)(y)) triples.
ds_train, ds_val, ds_test = (
    IntegralData(u_split, y_split, target_split)
    for u_split, y_split, target_split in (
        (grf_train, y_train, grf_int_train),
        (grf_val, y_val, grf_int_val),
        (grf_test, y_test, grf_int_test),
    )
)

In [12]:
class DeepONetScratch(nn.Module):
    """Unstacked DeepONet: a branch MLP on the input function and a trunk MLP on query points.

    The prediction at query point y_l is ``sum_p branch(u)_p * trunk(y_l)_p + bias``.

    Args:
        hparams: dict with keys
            ``num_input``    - number of sensor values per input function (branch input size),
            ``num_branch``   - width p of the shared branch/trunk feature space,
            ``dim_output``   - dimension of each query point (trunk input size),
            ``hidden_size``  - width of the hidden layers of both MLPs,
            ``hidden_depth`` - number of hidden Linear+GELU layers in each MLP.
            ``num_output`` may be present (kept in configs) but is not used here.
    """

    def __init__(self, hparams):
        super().__init__()

        num_input = hparams["num_input"]
        num_branch = hparams["num_branch"]
        dim_output = hparams["dim_output"]
        hidden_size = hparams["hidden_size"]
        hidden_depth = hparams["hidden_depth"]

        # Build order (branch, trunk, bias) is kept so seeded initialisation and
        # state_dict keys ("branch_net.0.weight", ...) are unchanged.
        self.branch_net = self._build_mlp(num_input, hidden_size, hidden_depth, num_branch)
        self.trunk_net = self._build_mlp(dim_output, hidden_size, hidden_depth, num_branch)

        self.bias = nn.Parameter(torch.randn(1))

    @staticmethod
    def _build_mlp(in_features, hidden_size, hidden_depth, out_features):
        """Return Linear+GELU x hidden_depth followed by a linear head, as nn.Sequential."""
        layers = [nn.Linear(in_features, hidden_size), nn.GELU()]
        for _ in range(hidden_depth - 1):
            layers += [nn.Linear(hidden_size, hidden_size), nn.GELU()]
        layers.append(nn.Linear(hidden_size, out_features))
        return nn.Sequential(*layers)

    def forward(self, u, y):
        """Evaluate G(u) at the query points y.

        Args:
            u: (batch, num_input) sensor values of the input functions.
            y: (batch, n_points) scalar query points, or (batch, n_points, dim_output).

        Returns:
            (batch, n_points) predictions.
        """
        branch_out = self.branch_net(u)  # (batch, p)
        # nn.Linear acts on the last dimension, so all query points go through the
        # trunk in one batched call instead of a Python loop over points.
        coords = y.unsqueeze(-1) if y.dim() == 2 else y
        trunk_out = self.trunk_net(coords)  # (batch, n_points, p)
        return torch.einsum("bp,blp->bl", branch_out, trunk_out) + self.bias

In [13]:
def train_epoch(model, optimizer, dataloader, device):
    """Run one optimisation pass over `dataloader` and return the mean training MSE.

    The model is moved to `device` and switched to train mode. Each batch is
    ``(u, y, Guy)``: input functions, query points and target values.

    Returns:
        MSE averaged over every target element seen in the epoch. Each batch's mean
        loss is weighted by its element count, so a short final batch is not
        over-weighted as it would be with a plain mean of batch means.

    Raises:
        ZeroDivisionError: if `dataloader` yields no batches.
    """
    model = model.to(device)
    model.train()
    loss_sum = 0.0
    n_elements = 0
    for u, y, Guy in dataloader:
        u, y, Guy = u.to(device), y.to(device), Guy.to(device)
        optimizer.zero_grad()
        pred = model(u, y)
        loss = F.mse_loss(pred, Guy)
        loss.backward()
        optimizer.step()
        batch_elements = Guy.numel()
        loss_sum += loss.item() * batch_elements
        n_elements += batch_elements
    return loss_sum / n_elements

In [14]:
def evaluate(model, test_loader, device):
    """Return the mean MSE of `model` over `test_loader` without updating it.

    The model is moved to `device` (consistent with `train_epoch`, so evaluating a
    freshly built model does not fail on a device mismatch) and left in eval mode.
    Each batch is ``(u, y, Guy)``: input functions, query points and target values.

    Returns:
        MSE averaged over every target element in the loader (batches weighted by
        their element count, so a short final batch is not over-weighted).

    Raises:
        ZeroDivisionError: if `test_loader` yields no batches.
    """
    model = model.to(device)
    model.eval()
    loss_sum = 0.0
    n_elements = 0
    with torch.no_grad():
        for u, y, Guy in test_loader:
            u, y, Guy = u.to(device), y.to(device), Guy.to(device)
            pred = model(u, y)
            batch_elements = Guy.numel()
            loss_sum += F.mse_loss(pred, Guy).item() * batch_elements
            n_elements += batch_elements
    return loss_sum / n_elements

In [15]:
# Batches of 500 samples; only the training loader is reshuffled each epoch.
dl_train, dl_val, dl_test = (
    DataLoader(ds_train, batch_size=500, shuffle=True),
    DataLoader(ds_val, batch_size=500),
    DataLoader(ds_test, batch_size=500),
)

## Optuna for hyperparameter tuning

In [16]:
class WandbOptunaCallback(WeightsAndBiasesCallback):
    """Optuna W&B callback that also logs study-level plots after every trial.

    ``study.optimize`` invokes callbacks as ``callback(study, frozen_trial)``, i.e.
    through ``__call__``. Overriding only ``after_trial`` (a sampler hook name) meant
    the plotting code never ran, so ``__call__`` now runs the parent callback (which
    logs the trial's metric) and then ``after_trial``.

    Args:
        metric_name: name under which the objective value is logged.
        **kwargs: forwarded to ``WeightsAndBiasesCallback`` (e.g. ``wandb_kwargs``).
    """

    def __init__(self, metric_name, **kwargs):
        super().__init__(metric_name, **kwargs)

    def __call__(self, study, trial):
        super().__call__(study, trial)
        self.after_trial(study, trial)

    def after_trial(self, study, trial):
        """Log the optimisation history and parameter importances as Plotly charts."""
        if wandb.run is None:
            return
        completed = study.get_trials(
            deepcopy=False, states=(optuna.trial.TrialState.COMPLETE,)
        )
        if not completed:
            return
        # optuna.visualization returns Plotly figures, which wandb.Image cannot render.
        charts = {"optimization_history": wandb.Plotly(plot_optimization_history(study))}
        # Parameter importances cannot be computed from a single completed trial.
        if len(completed) >= 2:
            charts["param_importances"] = wandb.Plotly(plot_param_importances(study))
        wandb.log(charts)

In [17]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

def objective(trial):
    """Optuna objective: train a DeepONetScratch with sampled hyperparameters.

    Samples the branch width, hidden width/depth and learning rate. Trains for
    ``epochs`` epochs with Adam + OneCycleLR, reports the validation MSE to the
    trial after every epoch (so the pruner can stop it) and logs metrics to a W&B run.

    Returns:
        Validation MSE after the final epoch.

    Raises:
        optuna.TrialPruned: when the pruner stops the trial early.
    """
    hparams = {
        "num_input": 100,
        "num_branch": trial.suggest_categorical("num_branch", [10, 20, 30, 40]),
        "num_output": 100,
        "dim_output": 1,
        "hidden_size": trial.suggest_categorical("hidden_size", [40, 80, 120, 160]),
        "hidden_depth": trial.suggest_int("hidden_depth", 2, 4),
        # suggest_loguniform is deprecated; suggest_float(log=True) samples the same distribution.
        "learning_rate": trial.suggest_float("learning_rate", 1e-4, 1e-2, log=True),
        # NOTE(review): informational only - the DataLoaders are built with batch_size=500 elsewhere.
        "batch_size": 500,
        "epochs": 200
    }
    L.seed_everything(42)
    model = DeepONetScratch(hparams).to(device)

    run = wandb.init(project="DeepONet-Optuna", config=hparams, group="Optuna1", reinit=True)
    try:
        optimizer = optim.Adam(model.parameters(), lr=hparams["learning_rate"])
        # The scheduler is stepped once per epoch, so the one-cycle schedule spans exactly
        # `epochs` steps. The former steps_per_epoch=len(dl_train) // batch_size + 1 mixed
        # batches with samples and only evaluated to 1 by coincidence.
        scheduler = OneCycleLR(optimizer, max_lr=hparams["learning_rate"], total_steps=hparams["epochs"])

        for epoch in range(hparams["epochs"]):
            train_loss = train_epoch(model, optimizer, dl_train, device)
            val_loss = evaluate(model, dl_val, device)
            scheduler.step()

            trial.report(val_loss, epoch)
            wandb.log({"train_loss": train_loss, "val_loss": val_loss, "epoch": epoch+1})

            if trial.should_prune():
                raise optuna.TrialPruned()

        # Test loss is only logged, never used for trial selection.
        test_loss = evaluate(model, dl_test, device)
        wandb.log({"test_loss": test_loss})
    finally:
        # Close the W&B run on success, pruning or interruption so no run is left dangling.
        run.finish()

    return val_loss

cuda


In [18]:
# Seeded TPE sampler so the sequence of hyperparameter proposals is reproducible.
sampler = TPESampler(seed=42)

study = optuna.create_study(sampler=sampler, direction="minimize")
wandb_callback = WandbOptunaCallback(metric_name="val_loss")
study.optimize(objective, n_trials=100, callbacks=[wandb_callback])

[I 2024-03-25 07:20:51,369] A new study created in memory with name: no-name-b7de3fc1-473c-4ab2-a0bd-0164717344a6
Seed set to 42


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011115388066456136, max=1.0…

[I 2024-03-25 07:22:10,845] Trial 0 finished with value: 3.916575951734558e-05 and parameters: {'num_branch': 20, 'hidden_size': 160, 'hidden_depth': 3, 'learning_rate': 0.0026070247583707684}. Best is trial 0 with value: 3.916575951734558e-05.


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011116839200258254, max=1.0…

Seed set to 42


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011115994911071741, max=1.0…

[I 2024-03-25 07:23:34,107] Trial 1 finished with value: 0.0001559476731927134 and parameters: {'num_branch': 20, 'hidden_size': 160, 'hidden_depth': 3, 'learning_rate': 0.0003823475224675188}. Best is trial 0 with value: 3.916575951734558e-05.


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011112630878212966, max=1.0…

Seed set to 42


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011116905166353616, max=1.0…

[I 2024-03-25 07:24:51,449] Trial 2 finished with value: 0.002596413716673851 and parameters: {'num_branch': 10, 'hidden_size': 80, 'hidden_depth': 3, 'learning_rate': 0.0001238513729886094}. Best is trial 0 with value: 3.916575951734558e-05.


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011112604499794542, max=1.0…

Seed set to 42


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01111226150015783, max=1.0)…

[I 2024-03-25 07:26:22,032] Trial 3 finished with value: 0.0004765285848407075 and parameters: {'num_branch': 40, 'hidden_size': 40, 'hidden_depth': 4, 'learning_rate': 0.0007591104805282694}. Best is trial 0 with value: 3.916575951734558e-05.


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011112525345136721, max=1.0…

Seed set to 42


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01111643024487421, max=1.0)…

[I 2024-03-25 07:27:39,268] Trial 4 finished with value: 0.0005009327578591183 and parameters: {'num_branch': 40, 'hidden_size': 80, 'hidden_depth': 3, 'learning_rate': 0.00023426581058204064}. Best is trial 0 with value: 3.916575951734558e-05.


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011112617688357002, max=1.0…

Seed set to 42


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011115559566921245, max=1.0…

[I 2024-03-25 07:27:47,664] Trial 5 pruned. 
Seed set to 42


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011115585945339667, max=1.0…

[I 2024-03-25 07:28:50,415] Trial 6 finished with value: 0.0018284181714989245 and parameters: {'num_branch': 30, 'hidden_size': 160, 'hidden_depth': 2, 'learning_rate': 0.009413993046829943}. Best is trial 0 with value: 3.916575951734558e-05.


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01111479441040299, max=1.0)…

Seed set to 42


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011113884144773085, max=1.0…

[I 2024-03-25 07:28:58,791] Trial 7 pruned. 
Seed set to 42


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011116548978154443, max=1.0…

[I 2024-03-25 07:29:05,459] Trial 8 pruned. 
Seed set to 42


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011111971266412486, max=1.0…

[I 2024-03-25 07:29:12,065] Trial 9 pruned. 
Seed set to 42


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011115572756777208, max=1.0…

[I 2024-03-25 07:29:18,657] Trial 10 pruned. 
Seed set to 42


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011115546366717253, max=1.0…

[W 2024-03-25 07:29:36,450] Trial 11 failed with parameters: {'num_branch': 20, 'hidden_size': 160, 'hidden_depth': 3, 'learning_rate': 0.0024442522440721454} because of the following error: KeyboardInterrupt().
Traceback (most recent call last):
  File "/home/kavis/Documents/Project/Machine_Learning/DeepONet_from_scratch/.venv/lib/python3.11/site-packages/optuna/study/_optimize.py", line 196, in _run_trial
    value_or_values = func(trial)
                      ^^^^^^^^^^^
  File "/tmp/ipykernel_85570/1217486451.py", line 25, in objective
    train_loss = train_epoch(model, optimizer, dl_train, device)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/ipykernel_85570/1984425491.py", line 8, in train_epoch
    pred = model(u, y)
           ^^^^^^^^^^^
  File "/home/kavis/Documents/Project/Machine_Learning/DeepONet_from_scratch/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **k

KeyboardInterrupt: 

In [None]:
# Summarise the best completed trial of the study.
print("Best trial:")
trial = study.best_trial
print("  Value: ", trial.value)
print("  Params: ")
for param_name, param_value in trial.params.items():
    print(f"    {param_name}: {param_value}")

In [None]:
# Objective value of every trial together with the running best value.
plot_optimization_history(study)

In [None]:
# Relative importance of each tuned hyperparameter (needs at least two completed trials).
plot_param_importances(study)