# PINN Burgers' Equation - Func vs Loss
Autor: Taco de Wolff\
Fecha: 24 mayo 2021

In [2]:
import wandb
import numpy as np
import scipy.io
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import grad

# Authenticate this kernel with Weights & Biases; the trailing ';' keeps the
# notebook from echoing the call's return value.
wandb.login();

[34m[1mwandb[0m: Currently logged in as: [33moceania[0m (use `wandb login --relogin` to force relogin)


In [5]:
# Read the reference Burgers solution and flatten its (x, t) grid into samples:
# within each block of rows t is constant while x runs over the spatial grid.
data = scipy.io.loadmat('../Data/Burgers/burgers_shock.mat')
x = np.tile(data['x'], (len(data['t']), 1))       # TN x 1
t = np.repeat(data['t'], len(data['x']), axis=0)  # TN x 1
X = np.hstack((x, t))                             # TN x 2, columns are (x, t)
Y = np.reshape(data['usol'].T, (-1, 1))           # TN x 1

# Shapes of the raw grid and of the flattened sample matrices.
print(
    "x.shape: %s" % (data['x'].shape,),
    "t.shape: %s" % (data['t'].shape,),
    "u.shape: %s" % (data['usol'].shape,),
    "X.shape: %s" % (X.shape,),
    "Y.shape: %s" % (Y.shape,),
    sep="\n",
)

# Per-column mean and standard deviation of the inputs and of the target.
print("X: %s ± %s" % (np.mean(X, axis=0), np.std(X, axis=0)))
print("Y: %s ± %s" % (np.mean(Y, axis=0), np.std(Y, axis=0)))

x.shape: (256, 1)
t.shape: (100, 1)
u.shape: (256, 100)
X.shape: (25600, 2)
Y.shape: (25600, 1)
X: [0.    0.495] ± [0.57960997 0.2886607 ]
Y: [6.42819131e-16] ± [0.61433751]


In [3]:
class PINN_Burgers(nn.Module):
    """Fully connected network for u(x, t) with a Burgers-equation residual loss.

    The network maps a 2-column input ``(x, t)`` to a single output ``u``
    through ``layer_depth`` hidden ``Linear + activation`` blocks followed by a
    plain linear output layer.

    Args:
        layer_width: Number of units in every hidden layer.
        layer_depth: Number of hidden ``Linear + activation`` blocks.
        activation_function: Key into ``_ACTIVATIONS`` selecting the hidden
            activation.
        initializer: ``'xavier'`` applies Xavier-uniform initialisation to the
            weight of every ``nn.Linear``; any other value keeps PyTorch's
            default initialisation.

    Raises:
        KeyError: If ``activation_function`` is not a known activation name.
    """

    # Activation name -> module class. Classes rather than instances are stored
    # so that only the requested activation is instantiated for each block.
    _ACTIVATIONS = {
        'lrelu': nn.LeakyReLU,
        'relu': nn.ReLU,
        'tanh': nn.Tanh,
        'sigmoid': nn.Sigmoid,
        'softplus': nn.Softplus,
        'softsign': nn.Softsign,
        'tanhshrink': nn.Tanhshrink,
        'celu': nn.CELU,
        'gelu': nn.GELU,
        'elu': nn.ELU,
        'selu': nn.SELU,
        'logsigmoid': nn.LogSigmoid,
    }

    def __init__(self, layer_width=20, layer_depth=8,
                 activation_function='tanh', initializer='none'):
        super().__init__()

        input_width = 2   # columns (x, t)
        output_width = 1  # u

        # Fixed PDE coefficients used by the residual in ``f``.
        self.lambda1 = 1.0
        self.lambda2 = 0.01/np.pi

        sizes = [input_width] + [layer_width]*layer_depth + [output_width]
        # One block per hidden layer; the last size pair is handled separately
        # because the output layer has no activation.
        hidden_blocks = [self.block(dim_in, dim_out, activation_function)
                         for dim_in, dim_out in zip(sizes[:-2], sizes[1:-1])]
        self.net = nn.Sequential(
            *hidden_blocks,
            nn.Linear(sizes[-2], sizes[-1])  # output layer is regular linear transformation
        )

        if initializer == 'xavier':
            def init_weights(m):
                # Only weights are re-initialised; biases keep their defaults.
                if isinstance(m, nn.Linear):
                    torch.nn.init.xavier_uniform_(m.weight)
            self.net.apply(init_weights)

    def forward(self, x):
        """Return the network prediction for a batch of ``(x, t)`` rows."""
        return self.net(x)

    def block(self, dim_in, dim_out, activation_function):
        """Build one hidden block: ``Linear(dim_in, dim_out)`` + activation.

        Raises:
            KeyError: If ``activation_function`` is not a known activation name.
        """
        try:
            activation_cls = self._ACTIVATIONS[activation_function]
        except KeyError:
            raise KeyError(
                "unknown activation_function %r; expected one of %s"
                % (activation_function, sorted(self._ACTIVATIONS))
            ) from None
        return nn.Sequential(
            nn.Linear(dim_in, dim_out),
            activation_cls(),
        )

    def f(self, x, t, u):
        """Return the residual ``u_t + lambda1*u*u_x - lambda2*u_xx``.

        ``u`` must have been computed from ``x`` and ``t`` with autograd
        tracking enabled. ``create_graph=True`` keeps the derivative graph so
        that the residual itself can be back-propagated through.
        """
        u_t  = grad(u,   t, create_graph=True, grad_outputs=torch.ones_like(u))[0]
        u_x  = grad(u,   x, create_graph=True, grad_outputs=torch.ones_like(u))[0]
        u_xx = grad(u_x, x, create_graph=True, grad_outputs=torch.ones_like(u_x))[0]
        return u_t + self.lambda1*u*u_x - self.lambda2*u_xx

    def loss(self, Xu, Yu, Xf=None):
        """Return ``[data_mse]`` or ``[data_mse, physics_mse]``.

        Args:
            Xu: Inputs with known targets, columns ``(x, t)``.
            Yu: Targets for ``Xu``.
            Xf: Optional collocation points, columns ``(x, t)``, at which the
                mean squared PDE residual is evaluated.

        Returns:
            A list whose first entry is the MSE between the prediction on
            ``Xu`` and ``Yu``; when ``Xf`` is given, a second entry holds the
            MSE of the residual ``f`` against zero.
        """
        losses = [F.mse_loss(self(Xu), Yu)]

        if Xf is not None:
            # Differentiate w.r.t. a private view instead of flipping
            # requires_grad on the caller's tensor: that mutation made every
            # later backward() accumulate gradients into Xf.grad. A tensor
            # that already tracks gradients is used as-is.
            points = Xf if Xf.requires_grad else Xf.detach().requires_grad_(True)
            x = points[:, 0]
            t = points[:, 1]
            # Re-stack so the network input is built from exactly the x and t
            # tensors that ``f`` differentiates with respect to.
            u = self(torch.stack((x, t), 1))[:, 0]
            residual = self.f(x, t, u)
            losses.append(F.mse_loss(residual, torch.zeros_like(residual)))
        return losses

In [8]:
# Global experiment settings shared by every run of the sweep.
project = 'burgers'
epochs = 10000
N = len(X)  # number of rows in the flattened (x, t) grid

# Seed torch's global RNG.
torch.manual_seed(2021)

# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print("Device: %s" % device)

# Table of NumPy seeds; the sweep's 'seed' parameter is an index into it.
rng = np.random.default_rng(2021)
seeds = rng.integers(1000, size=5)

# Grid sweep over the physics-loss weight, the sample counts, the seed index
# and the hidden activation function.
sweep_config = dict(
    project=project,
    method='grid',
    parameters=dict(
        alpha=dict(values=[0.0, 0.1]),
        Nu=dict(values=[200]),
        Nf=dict(values=[12800]),
        seed=dict(values=[0, 1, 2, 3, 4]),
        func=dict(values=[
            'tanh', 'softplus', 'relu', 'sigmoid', 'logsigmoid',
            'celu', 'gelu', 'softsign', 'tanhshrink',
        ]),
    ),
)
sweep_id = wandb.sweep(sweep_config)

# Every trained model is appended here by model_train.
models = []
def model_train():
    """Train one PINN for the current sweep configuration and log it to wandb.

    Reads ``alpha``, ``Nu``, ``Nf``, ``seed`` and ``func`` from ``wandb.config``,
    draws ``Nu`` labelled points and ``Nf`` collocation points from the global
    ``X``/``Y`` grid, runs ``epochs`` full-batch Adam steps and validates on the
    whole grid after every step. When finished, the weights are written to
    ``<name>.pth`` and the model is appended to the global ``models`` list.

    With ``alpha == 0`` only the data loss is optimised; otherwise the loss is
    ``(1 - alpha) * data_loss + alpha * physics_loss``. The logged
    'Data loss (training)' and 'Physics loss' are these *weighted* terms.
    """
    wandb.init()
    config = wandb.config
    run_seed = int(seeds[config.seed])
    rng = np.random.default_rng(run_seed)
    # Seed torch per run as well, so the weight initialisation depends on the
    # run's seed rather than on how many runs this kernel executed before it.
    torch.manual_seed(run_seed)
    name = 'burgers_A1s_a%g_nu%g_nf%g_s%d_%s' % (config.alpha, config.Nu, config.Nf, config.seed, config.func)

    # data: labelled points first, then collocation points, from the same rng
    Nu = int(config.Nu)
    Xu_idx = rng.choice(X.shape[0], Nu, replace=False)
    Xu = X[Xu_idx,:]
    Yu = Y[Xu_idx,:]

    Nf = int(config.Nf)
    Xf_idx = rng.choice(X.shape[0], Nf, replace=False)
    Xf = X[Xf_idx,:]

    print("Xu.shape:", Xu.shape)
    print("Yu.shape:", Yu.shape)
    print("Xf.shape:", Xf.shape)

    Xu = torch.tensor(Xu, dtype=torch.float, device=device)
    Yu = torch.tensor(Yu, dtype=torch.float, device=device)
    Xf = torch.tensor(Xf, dtype=torch.float, device=device)
    # Validation uses the complete grid.
    Xval = torch.tensor(X, dtype=torch.float, device=device)
    Yval = torch.tensor(Y, dtype=torch.float, device=device)

    # model
    model = PINN_Burgers(layer_width=20,
                         layer_depth=8,
                         activation_function=config.func,
                         initializer='xavier',
                        )
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    # training
    use_physics = config.alpha != 0.0
    no_phys_loss = torch.tensor(0.0, device=device)
    window = 250  # number of most recent epochs considered for the "lowest" metric
    val_data_losses = []
    start_time = time.time()
    for epoch in range(epochs):
        if use_physics:
            data_term, phys_term = model.loss(Xu, Yu, Xf)
            train_data_loss = (1.0-config.alpha)*data_term
            phys_loss = config.alpha*phys_term
            loss = train_data_loss + phys_loss
        else:
            train_data_loss = model.loss(Xu, Yu, None)[0]
            phys_loss = no_phys_loss
            loss = train_data_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Validation is evaluation only: skip building an autograd graph over
        # the full grid on every epoch.
        with torch.no_grad():
            val_data_loss = model.loss(Xval, Yval, None)[0]
        # The physics term is the (weighted) training-time value from before
        # the optimizer step; it is not re-evaluated for validation.
        val_loss = val_data_loss + phys_loss.detach()

        val_data_losses.append(val_data_loss.item())
        # Slicing also covers the first epochs, when fewer than `window`
        # values exist.
        lowest_val_data_loss = min(val_data_losses[-window:])

        wandb.log({
            'Data loss (training)': train_data_loss.item(),
            'Physics loss': phys_loss.item(),
            'Loss (training)': loss.item(),
            'Data loss (validation)': val_data_loss.item(),
            'Loss (validation)': val_loss.item(),
            'Data loss lowest (validation)': lowest_val_data_loss,
        }, step=epoch)

        if epoch % 1000 == 0:
            elapsed = time.time() - start_time
            print('Epoch: %d, Loss: %.3e, Time: %.2fs' %
                      (epoch, val_data_loss.item(), elapsed))
            start_time = time.time()

    torch.save(model.state_dict(), name + '.pth')
    models.append(model)

# Start a sweep agent in this kernel with model_train as the function it runs
# for the sweep's configurations; ';' suppresses the echoed return value.
wandb.agent(sweep_id, function=model_train, project=project);

Device: cpu
Create sweep with ID: h4ya872h
Sweep URL: https://wandb.ai/oceania/burgers/sweeps/h4ya872h


[34m[1mwandb[0m: Agent Starting Run: wxp1wp74 with config:
[34m[1mwandb[0m: 	Nf: 12800
[34m[1mwandb[0m: 	Nu: 200
[34m[1mwandb[0m: 	alpha: 0.47599660098671426
[34m[1mwandb[0m: 	warmup: 0


Xu.shape: (200, 2)
Yu.shape: (200, 1)
Xf.shape: (12800, 2)
Epoch: 0, Loss: 3.749e-01, Time: 0.20s
Epoch: 100, Loss: 1.256e-01, Time: 11.95s
Epoch: 200, Loss: 9.301e-02, Time: 14.56s
Epoch: 300, Loss: 7.316e-02, Time: 17.42s
Epoch: 400, Loss: 2.473e-02, Time: 17.50s
Epoch: 500, Loss: 1.827e-02, Time: 14.88s
Epoch: 600, Loss: 2.174e-02, Time: 14.27s
Epoch: 700, Loss: 2.012e-02, Time: 13.00s
Epoch: 800, Loss: 2.021e-02, Time: 13.41s
Epoch: 900, Loss: 1.944e-02, Time: 13.22s
Epoch: 1000, Loss: 1.938e-02, Time: 12.75s
Epoch: 1100, Loss: 1.414e-02, Time: 12.50s
Epoch: 1200, Loss: 1.392e-02, Time: 12.74s
Epoch: 1300, Loss: 1.770e-02, Time: 12.59s
Epoch: 1400, Loss: 1.701e-02, Time: 15.19s
Epoch: 1500, Loss: 1.642e-02, Time: 13.39s
Epoch: 1600, Loss: 1.519e-02, Time: 13.08s
Epoch: 1700, Loss: 1.503e-02, Time: 13.70s
Epoch: 1800, Loss: 1.395e-02, Time: 14.15s
Epoch: 1900, Loss: 1.381e-02, Time: 14.27s
Epoch: 2000, Loss: 1.284e-02, Time: 16.16s
Epoch: 2100, Loss: 1.293e-02, Time: 14.09s
Epoch: 2

[34m[1mwandb[0m: Ctrl + C detected. Stopping sweep.


In [None]:
import matplotlib.pyplot as plt

# Compare saved models against the reference solution at three time slices.
# NOTE(review): these checkpoint names differ from the 'burgers_A1s_...' names
# written by the training cell; presumably they come from another run, confirm.
model_names = ['burgers_A1_w1.pth', 'burgers_A1_w0.pth']
row_titles = {0: 'Loss data + physics', 1: 'Loss data'}

x_grid = data['x'][:, 0]   # spatial coordinates shared by every time slice
n_x = x_grid.shape[0]      # rows per time slice in the flattened X / Y
X_eval = torch.tensor(X, dtype=torch.float)

fig, ax = plt.subplots(len(model_names), 3, figsize=(18,5*len(model_names)), sharex=True,
                       sharey=True, tight_layout=True, facecolor='white', squeeze=False)
# The diffusion term is the second derivative in x, matching u_xx in PINN_Burgers.f.
fig.suptitle(r'Burgers equation: $\frac{\partial u}{\partial t} = -\lambda_1 u \frac{\partial u}{\partial x} + \lambda_2 \frac{\partial^2 u}{\partial x^2}\;$',
             size=24)

for j, model_name in enumerate(model_names):
    model = PINN_Burgers(layer_width=20,
                             layer_depth=8,
                             activation_function='tanh',
                             initializer='xavier',
                            )
    # map_location lets checkpoints saved on a GPU load on a CPU-only machine.
    model.load_state_dict(torch.load(model_name, map_location='cpu'))
    model.eval()

    print('Model:', model_name)
    print('  Lambda1:', model.lambda1)
    print('  Lambda2:', model.lambda2)
    # Inference only: no autograd graph is needed for plotting.
    with torch.no_grad():
        Y_hat = model(X_eval).numpy()
    for i, pos in enumerate([0, 50, 99]):
        # `pos` is an index into the time grid; rows [start, end) hold that slice.
        start = pos * n_x
        end = start + n_x
        t_value = data['t'][pos, 0]

        if j in row_titles:
            # Show the actual time value rather than its index.
            ax[j,i].set_title('%s (t=%.2f)' % (row_titles[j], t_value), size=16)
        ax[j,i].plot(x_grid, Y[start:end,0], label='Truth')
        ax[j,i].plot(x_grid, Y_hat[start:end,0], label='PINN')
        ax[j,i].set_xlabel('x')
        ax[j,i].set_ylabel('u')
        ax[j,i].legend()