In [32]:
import lightning as L
import numpy
import torch
import torch.nn as nn
from torchdiffeq import odeint
from torch.utils.data import Dataset, DataLoader

In [33]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")  # prefer GPU when present

In [60]:
def burgers_exact_eqn(x, t, mu):
    """Closed-form solution u(x, t; mu) of u_t + u u_x = (1/mu) u_xx.

    Cole-Hopf form u = -(2/mu) * phi_x / phi with
        phi = 1 + 0.25 exp(-pi^2 t / mu) cos(pi x) + 0.5 exp(-4 pi^2 t / mu) cos(2 pi x),
    where phi solves the heat equation phi_t = (1/mu) phi_xx.
    All arguments are tensors and are combined elementwise with broadcasting.
    """
    pi = torch.pi
    # decay factors of the two heat-equation modes
    slow_decay = torch.exp(-(pi**2) * t / mu)
    fast_decay = torch.exp(-(4 * pi**2) * t / mu)

    # numerator is -phi_x / pi, denominator is phi itself
    phi_x_part = 0.25 * slow_decay * torch.sin(pi * x) + fast_decay * torch.sin(2 * pi * x)
    phi = 1.0 + 0.25 * slow_decay * torch.cos(pi * x) + 0.5 * fast_decay * torch.cos(2 * pi * x)

    prefactor = 2 * pi / mu
    return prefactor * (phi_x_part / phi)

In [61]:
def generate_burgers_solution_grid(mu_values, n_x, n_t, T_final=1.0):
    """Evaluate the exact Burgers solution on a full (mu, t, x) tensor grid.

    Args:
        mu_values: sequence (or tensor) of viscosity parameters mu.
        n_x: number of spatial points on [0, 2].
        n_t: number of time steps; the time axis has n_t + 1 points on [0, T_final].
        T_final: final time.

    Returns:
        (x_axis, t_axis, mu_axis, u_grid) with shapes (n_x,), (n_t + 1,),
        (n_mu,), and (n_mu, n_t + 1, n_x). u_grid[i_mu, i_t, i_x] is
        burgers_exact_eqn(x_axis[i_x], t_axis[i_t], mu_axis[i_mu]) with the two
        spatial end points set to exactly 0.
    """
    # create range of values as vectors
    x_axis = torch.linspace(0.0, 2.0, n_x)          # (nx,)
    t_axis = torch.linspace(0.0, T_final, n_t + 1)  # (nt+1,)
    # Use a floating dtype matching x/t: integer mus (e.g. [20, 30]) would otherwise give an
    # int64 mu_axis that is later fed to the network next to float32 x, t and alpha.
    mu_axis = torch.as_tensor(mu_values, dtype=x_axis.dtype)  # (nmu,)

    # create full grid across all space, time, parameters (broadcast views, no copies)
    grid_shape = (mu_axis.shape[0], t_axis.shape[0], x_axis.shape[0])
    X_grid = x_axis[None, None, :].expand(grid_shape)
    T_grid = t_axis[None, :, None].expand(grid_shape)
    Mu_grid = mu_axis[:, None, None].expand(grid_shape)

    # evaluate solution on full grid (a fresh tensor, so in-place writes below are safe)
    u_grid = burgers_exact_eqn(X_grid, T_grid, Mu_grid)

    # enforce homogeneous Dirichlet boundary conditions exactly; sin(2*pi) is only
    # approximately zero in floating point
    u_grid[:, :, 0] = 0.0
    u_grid[:, :, -1] = 0.0

    return x_axis, t_axis, mu_axis, u_grid

In [62]:
class BurgersExactDataset(Dataset):
    """Random (x, t, mu, u) samples drawn from a precomputed exact-solution grid.

    Every __getitem__ call draws fresh random indices into u_grid; ``idx`` is
    ignored, so ``len`` only sets how many samples make up one epoch.

    Args:
        x_axis: (n_x,) spatial coordinates.
        t_axis: (n_t + 1,) time coordinates; index 0 (t = t_axis[0]) is never sampled.
        mu_axis: (n_mu,) parameter values.
        u_grid: (n_mu, n_t + 1, n_x) solution values.
        n_samples: nominal dataset length (converted with int()).
    """

    def __init__(self, x_axis, t_axis, mu_axis, u_grid, n_samples=200000):
        super().__init__()
        self.x_axis, self.t_axis, self.mu_axis = x_axis, t_axis, mu_axis
        self.u_grid = u_grid
        self.n_mu, self.n_tp1, self.n_x = u_grid.shape
        self.n_samples = int(n_samples)

    def __len__(self):
        return self.n_samples

    @staticmethod
    def _draw_index(low, high):
        """Uniform random integer in [low, high) from torch's global RNG."""
        return torch.randint(low, high, (1,)).item()

    def __getitem__(self, idx):
        # draw order (mu, t, x) fixes how the global RNG stream is consumed
        mu_idx = self._draw_index(0, self.n_mu)
        t_idx = self._draw_index(1, self.n_tp1)  # starts at 1: skip the first time point
        x_idx = self._draw_index(0, self.n_x)

        return (
            self.x_axis[x_idx],
            self.t_axis[t_idx],
            self.mu_axis[mu_idx],
            self.u_grid[mu_idx, t_idx, x_idx],
        )

In [63]:
class BurgersPhysicsDataset(Dataset):
    """Random collocation points (x, t, mu) for the physics-informed loss.

    Each item is drawn independently:
      * mu uniformly from ``mus``;
      * x = 2 * U[0, 1), clamped to [1e-6, 2 - 1e-6] (interior of the domain);
      * t = U[0, 1), clamped to [1e-6, 1].
    ``idx`` is ignored, so a DataLoader's ``shuffle`` flag has no effect.

    Args:
        mus: parameter values to sample from (list, array or tensor).
        n_samples: nominal dataset length (samples per epoch).
    """

    def __init__(self, mus, n_samples):
        super().__init__()
        # Keep mu on the CPU in the same float dtype torch.rand gives x and t.
        # The training loop moves whole batches to the model's device. Allocating
        # CUDA tensors inside a Dataset mixes devices within one item and fails
        # in DataLoader worker processes.
        self.mus = torch.as_tensor(mus, dtype=torch.get_default_dtype())
        self.n_mu = len(self.mus)
        self.n_samples = n_samples

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        i_mu = torch.randint(0, self.n_mu, (1,)).item()
        mu_val = self.mus[i_mu]

        # keep collocation points off the boundaries x = 0, x = 2 and off t = 0
        eps = 1e-6
        x_coord = torch.clamp(torch.rand(()) * 2.0, eps, 2.0 - eps)
        t_coord = torch.clamp(torch.rand(()), eps, 1.0)

        return x_coord, t_coord, mu_val

In [64]:
# Parameter split: mu_train feeds the exact-solution grid; mu_all (train + test)
# feeds the physics collocation dataset.
mu_train = [20, 30]
mu_test = [15, 25]
mu_all = sorted([*mu_train, *mu_test])

# Discretisation of the exact-solution grid.
n_x = 16   # spatial points on [0, 2]
n_t = 100  # time steps, giving n_t + 1 time points
T_f = 1.0  # final time
x_axis, t_axis, mu_axis, u_grid = generate_burgers_solution_grid(mu_train, n_x, n_t, T_final=T_f)  # exact solution, training mus only

In [65]:
# Supervised loader: random (x, t, mu, u) samples from the exact-solution grid.
# NOTE(review): both datasets draw random samples in __getitem__ and ignore idx,
# so shuffle=True only permutes indices that are never used.
exact_train_dataset = BurgersExactDataset(
    x_axis, t_axis, mu_axis, u_grid, n_samples=1000
)
exact_train_loader = DataLoader(
    dataset=exact_train_dataset, batch_size=32, shuffle=True
)

# Physics loader: random collocation points (x, t, mu) over every mu in mu_all.
physics_train_dataset = BurgersPhysicsDataset(mu_all, n_samples=1000)
physics_train_loader = DataLoader(
    dataset=physics_train_dataset, batch_size=32, shuffle=True
)

In [66]:
class FourierFeatures(nn.Module):
    """Sin/cos embedding of a single scalar coordinate per sample.

    For frequencies f_k = linspace(1, max_freq, n_freqs), maps x to
    [sin(pi f_k x) for all k] ++ [cos(pi f_k x) for all k].

    Input x is (B,) or (B, 1); output is (B, 2 * n_freqs).
    """

    def __init__(self, n_freqs=16, max_freq=10.0):
        super().__init__()
        # buffer (not a parameter): moves with .to(device) but is never trained
        self.register_buffer("freqs", torch.linspace(1.0, max_freq, n_freqs))

    def forward(self, x):
        column = x.unsqueeze(-1) if x.dim() == 1 else x  # (B, 1)
        phase = column * self.freqs.unsqueeze(0) * torch.pi  # (B, n_freqs)
        return torch.cat([phase.sin(), phase.cos()], dim=-1)  # (B, 2*n_freqs)

In [78]:
class Decoder(nn.Module):
    """MLP decoder D(x, alpha) -> scalar field value.

    x is embedded with FourierFeatures, concatenated with the latent vector
    alpha, and passed through ``n_hidden_layers`` Linear+Tanh layers of width
    ``hidden`` followed by a Linear layer with one output.

    forward(x, alpha): x is (B,) or (B, 1), alpha is (B, latent_dim); returns (B,).
    """

    def __init__(self, latent_dim=10, n_freqs=16, max_freq=10.0, hidden=128, n_hidden_layers=3):
        super().__init__()
        self.ff = FourierFeatures(n_freqs=n_freqs, max_freq=max_freq)

        # layer widths: input -> hidden -> ... -> hidden, then a final scalar head
        widths = [2 * n_freqs + latent_dim] + [hidden] * n_hidden_layers
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.extend((nn.Linear(fan_in, fan_out), nn.Tanh()))
        layers.append(nn.Linear(widths[-1], 1))

        self.model = nn.Sequential(*layers)

    def forward(self, x, alpha):
        features = torch.cat([self.ff(x), alpha], dim=-1)  # (B, 2*n_freqs + latent_dim)
        return self.model(features).squeeze(-1)  # (B, 1) -> (B,)

In [79]:
# parameterized neural ode that takes in mu, alpha and t and outputs time derivative of alpha
class PNODEFunc(nn.Module):
    """Right-hand side f(alpha, t, mu) of the latent ODE d(alpha)/dt = f.

    An MLP maps the concatenation [alpha, t, mu] (width latent_dim + 2) through
    ``n_hidden_layers`` Linear+Tanh layers of width ``hidden`` to a
    latent_dim-sized time derivative.
    """

    def __init__(self, latent_dim=10, hidden=128, n_hidden_layers=2):
        super().__init__()
        dim = latent_dim + 2  # alpha + t + mu

        layers = []
        for _ in range(n_hidden_layers):
            layers += [nn.Linear(dim, hidden), nn.Tanh()]
            dim = hidden
        layers += [nn.Linear(dim, latent_dim)]

        self.model = nn.Sequential(*layers)

    def forward(self, alpha, t, mu):
        """Evaluate d(alpha)/dt.

        Args:
            alpha: (B, latent_dim) latent state.
            t: time tensor broadcastable to (B, 1), e.g. the 0-d scalar odeint
                passes or a per-sample (B, 1) column.
            mu: tensor with B parameter values (any shape, any numeric dtype).

        Returns:
            (B, latent_dim) time derivative of alpha.
        """
        B = alpha.shape[0]
        # Cast t and mu to alpha's dtype. float64 t/mu would otherwise promote the
        # concatenated input to float64 and fail against float32 weights. Integer
        # mu is converted explicitly instead of relying on cat's type promotion.
        t_col = t.to(dtype=alpha.dtype).expand(B, 1)
        # reshape (not view) also accepts non-contiguous mu
        mu_col = mu.reshape(B, 1).to(dtype=alpha.dtype)
        inputs = torch.cat([alpha, t_col, mu_col], dim=-1)
        return self.model(inputs)

In [80]:
class PNODE(nn.Module):
    """Integrates the latent ODE d(alpha)/dt = func(alpha, t, mu) with torchdiffeq.odeint."""

    def __init__(self, func: PNODEFunc, latent_dim=10):
        super().__init__()
        self.func = func
        self.latent_dim = latent_dim

    def solve_alpha(self, t_eval, mu, alpha0=None, method="rk4"):
        """Solve for alpha at every time in ``t_eval``.

        Args:
            t_eval: 1-D tensor of output times; the first entry is the initial time.
            mu: tensor of per-sample parameters; its first dimension is the batch size B.
            alpha0: optional (B, latent_dim) initial state; zeros when omitted.
            method: odeint solver name.

        Returns:
            The odeint trajectory, indexed as [time, sample, latent].
        """
        batch_size = mu.shape[0]
        if alpha0 is None:
            alpha0 = torch.zeros(batch_size, self.latent_dim, device=mu.device)

        def rhs(time, state):
            # odeint calls rhs(t, y); PNODEFunc expects (alpha, t, mu)
            return self.func(state, time, mu)

        return odeint(rhs, alpha0, t_eval, method=method)

In [81]:
def phi_xt(x, t):
    """Multiplicative mask x * (2 - x) * t applied to the decoder output.

    It vanishes at x = 0 and x = 2 (the ends of the spatial domain), so the
    masked field satisfies the zero Dirichlet boundary conditions exactly.

    NOTE(review): the factor t also forces u(x, 0) = 0. The exact solution in
    burgers_exact_eqn is non-zero at t = 0, and physics_loss has no
    initial-condition term, so u == 0 is a zero-residual solution in the physics
    phase. Confirm this ansatz is intended; a lifted form such as
    u0(x; mu) + x(2 - x) t D(x, alpha) would keep the exact initial condition.
    """
    boundary_bump = x * (2.0 - x)
    return boundary_bump * t


def u_constrained(decoder, x, t, alpha):
    """Field value u = phi_xt(x, t) * decoder(x, alpha)."""
    mask = phi_xt(x, t)
    return mask * decoder(x, alpha)

In [82]:
def data_loss(model, batch):
    """Mean squared error between model(x, t, mu) and the exact targets.

    ``batch`` is a tuple (x, t, mu, y) of equally sized tensors.
    """
    x, t, mu, target = batch
    error = model(x, t, mu) - target
    return (error**2).mean()


def physics_loss(model, batch):
    """Mean squared residual of viscous Burgers, u_t + u u_x - (1/mu) u_xx.

    ``batch`` is a tuple (x, t, mu). Derivatives come from autograd, with
    create_graph=True so the loss itself can be backpropagated to the model
    parameters. u must be elementwise in the batch (u_i depends only on
    x_i, t_i); the ones-vector VJP then gives per-sample derivatives.
    """
    x, t, mu = batch
    # fresh leaves so autograd can differentiate u with respect to the inputs
    x = x.detach().clone().requires_grad_(True)
    t = t.detach().clone().requires_grad_(True)
    mu = mu.detach().clone()

    u = model(x, t, mu)

    def grad_of(out, wrt):
        return torch.autograd.grad(out, wrt, grad_outputs=torch.ones_like(out), create_graph=True)[0]

    u_x = grad_of(u, x)
    u_xx = grad_of(u_x, x)
    u_t = grad_of(u, t)

    residual = u_t + u * u_x - (1.0 / mu) * u_xx
    return (residual**2).mean()

In [83]:
class CNFROM(nn.Module):
    """Reduced-order model u(x, t; mu) = phi_xt(x, t) * decoder(x, alpha(t; mu)).

    The latent state alpha starts at zero at time 0. It is advanced to each
    sample's own time t_i by the ODE right-hand side ``pnode.func``, then
    decoded at x and masked by ``u_constrained``.
    """

    def __init__(self, decoder: Decoder, pnode: PNODE):
        super().__init__()
        self.decoder = decoder
        self.pnode = pnode

    def forward(self, x, t, mu):
        """Evaluate u for a batch of points.

        Args:
            x, t, mu: tensors of B elements each, as produced by the DataLoaders
                (one value per sample).

        Returns:
            (B,) tensor of field values.
        """
        B = x.shape[0]
        t_col = t.reshape(B, 1)  # per-sample end time t_i

        def rescaled_rhs(s, alpha):
            # Substituting t = s * t_i maps every sample's interval [0, t_i] onto
            # the shared interval s in [0, 1]:
            #     d(alpha_i)/ds = t_i * f(alpha_i, s * t_i, mu_i).
            # One fixed RK4 step over [0, 1] evaluates f at the same stage times as
            # one step over [0, t_i], so it yields the same update (up to rounding).
            # This solves the whole batch in a single odeint call instead of B calls.
            # Row i depends only on t_i, so autograd.grad(u, t) stays a per-sample du/dt.
            return t_col * self.pnode.func(alpha, s * t_col, mu)

        alpha0 = torch.zeros(B, self.pnode.latent_dim, device=x.device)
        s_eval = torch.tensor([0.0, 1.0], device=x.device)
        alpha_trajectory = odeint(rescaled_rhs, alpha0, s_eval, method="rk4")
        alpha_t = alpha_trajectory[-1]  # (B, latent_dim): alpha at each sample's own t
        return u_constrained(self.decoder, x, t, alpha_t)

In [84]:
class Model(L.LightningModule):
    """Lightning wrapper that trains ``model`` with either the data or the physics loss.

    Args:
        model: network called as ``model(x, t, mu)`` by data_loss / physics_loss.
        lr: Adam learning rate.
        mode: ``"data"`` (MSE against exact samples) or ``"physics"`` (Burgers
            residual). The attribute may be reassigned between fits to switch phase.
    """

    def __init__(self, model, lr=1e-3, mode="data"):
        super().__init__()
        self.model = model
        self.lr = lr
        self.mode = mode

    def training_step(self, batch, batch_idx):
        # Dispatch on the current phase. An unknown mode (e.g. a typo) raises
        # instead of silently training with the physics loss.
        if self.mode == "data":
            loss = data_loss(self.model, batch)
        elif self.mode == "physics":
            loss = physics_loss(self.model, batch)
        else:
            raise ValueError(f"Unknown training mode {self.mode!r}; expected 'data' or 'physics'.")
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        # Only hand Adam the parameters that currently require grad; frozen ones
        # (e.g. a frozen decoder) would never receive gradients anyway.
        trainable = [p for p in self.model.parameters() if p.requires_grad]
        return torch.optim.Adam(trainable, lr=self.lr)

In [85]:
latent_dim = 10  # size of the latent state alpha
max_epochs = 10  # epochs for each training phase (data, then physics)
lr = 1e-3        # Adam learning rate
seed = 0         # torch RNG seed

# Seed before the models are built so that weight init, the DataLoader shuffle
# order and the random draws in the datasets' __getitem__ are reproducible.
torch.manual_seed(seed)

In [86]:
# Assemble the ROM: decoder D(x, alpha) plus latent ODE d(alpha)/dt = f(alpha, t, mu).
decoder = Decoder(latent_dim=latent_dim)
pnode_func = PNODEFunc(latent_dim=latent_dim)
pnode = PNODE(pnode_func, latent_dim=latent_dim)
cnf = CNFROM(decoder, pnode)
# Pass the learning rate from the config cell explicitly so changing `lr` takes effect.
model = Model(cnf, lr=lr, mode="data")

In [54]:
# Phase 1: supervised fit of decoder + latent ODE on exact-solution samples.
# accelerator="auto" uses the GPU when one is present and otherwise falls back to
# the CPU, consistent with the `device` selection at the top of the notebook.
data_trainer = L.Trainer(accelerator="auto", devices=1, max_epochs=max_epochs)
data_trainer.fit(model, exact_train_loader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name  | Type   | Params | Mode 
-----------------------------------------
0 | model | CNFROM | 58.1 K | train
-----------------------------------------
58.1 K    Trainable params
0         Non-trainable params
58.1 K    Total params
0.232     Total estimated model params size (MB)
19        Modules in train mode
0         Modules in eval mode
/home/kendra/miniconda3/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=31` in the `DataLoader` to improve performance.
/home/kendra/miniconda3/lib/python3.12/site-packages/lightning/pytorch/loops/fit_loop.py:310: The number of training batches (32) is smaller than the logging interval Trainer(

Training: |                                               | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=10` reached.


In [31]:
# Phase 2: physics-informed training with the decoder frozen; only the latent ODE
# parameters remain trainable.
# NOTE(review): physics_loss has no initial-condition or data term, so confirm
# the trained decoder is enough to keep this phase away from the u == 0 solution.
model.model.decoder.requires_grad_(False)
model.mode = "physics"
# accelerator="auto": GPU when available, otherwise CPU.
physics_trainer = L.Trainer(accelerator="auto", devices=1, max_epochs=max_epochs)
physics_trainer.fit(model, physics_train_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name  | Type   | Params | Mode 
-----------------------------------------
0 | model | CNFROM | 58.1 K | train
-----------------------------------------
19.5 K    Trainable params
38.7 K    Non-trainable params
58.1 K    Total params
0.232     Total estimated model params size (MB)
19        Modules in train mode
0         Modules in eval mode
/home/kendra/miniconda3/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=31` in the `DataLoader` to improve performance.
/home/kendra/miniconda3/lib/python3.12/site-packages/lightning/pytorch/loops/fit_loop.py:310: The number of training batches (32) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |                                               | 0/? [00:00<?, ?it/s]


Detected KeyboardInterrupt, attempting graceful shutdown ...

KeyboardInterrupt

