# **Burgers' equation**:
# $ u_t + u u_x = \nu u_{xx} \qquad x\in[-1,1], \quad t\in [0,1], \quad u \equiv u(x,t), \quad \nu = 0.01/\pi$

## Boundary conditions (Dirichlet):
* $ u(-1, t) = 0, \qquad t\in [0,1]$
* $ u(1, t) = 0, \qquad t\in [0,1]$

## Initial condition:
* $ u(x, 0) = -\sin(\pi x), \qquad x\in[-1,1] $

<!-- ## *ANALYTICAL SOLUTION*
## $ u(x, y, t) = e^{-13\pi^2t}\sin(3\pi x)\sin(2\pi y)  $ -->




# Imports

In [None]:
import math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
from tqdm.auto import tqdm
import warnings
warnings.filterwarnings("ignore")

# Helpers

## Data

In [None]:
def generate_random_numbers(min_, max_, N, dtype):
    """Draw ``N`` values uniformly at random between ``min_`` and ``max_``.

    Returns a 1-D tensor of length ``N`` with the requested ``dtype``.
    """
    unit_samples = torch.rand(N, dtype=dtype)
    span = max_ - min_
    # Affine map of the unit-interval samples onto the requested range.
    return min_ + span * unit_samples


class Data():
    """Sampler for the training and test points of the (x, t) problem.

    Every ``sample_*`` method returns points as an ``(n, 2)`` tensor whose
    columns are ``(x, t)``. Point tensors are moved to ``device`` first and
    only then flagged with ``requires_grad`` so that they are leaf tensors on
    any device.

    Args:
        x_min, x_max: Spatial bounds.
        t_min, t_max: Temporal bounds.
        Nx_domain, Nt_domain: Number of random x / t values whose Cartesian
            product forms the in-domain collocation points.
        Nx_init: Number of random x values on the initial line ``t = t_min``.
        Nt_bound: Number of random t values on each spatial boundary.
        N_test: Number of random test points.
        device: Device on which the returned tensors live.
        dtype: Floating-point type of the returned tensors.
    """

    def __init__(self,
                 x_min, x_max,
                 t_min, t_max,
                 Nx_domain, Nt_domain,
                 Nx_init, Nt_bound,
                 N_test,
                 device='cpu',
                 dtype=torch.float32):

        self.x_min = x_min
        self.x_max = x_max
        self.t_min = t_min
        self.t_max = t_max
        self.Nx_domain = Nx_domain
        self.Nt_domain = Nt_domain
        self.Nx_init = Nx_init
        self.Nt_bound = Nt_bound
        self.N_test = N_test
        self.device = device
        self.dtype = dtype

    @staticmethod
    def _pair_points(x, t):
        """Return every ``(x_i, t_j)`` pair as a ``(len(x) * len(t), 2)`` tensor.

        ``x`` varies slowest, i.e. the rows are ordered
        ``(x_0, t_0), (x_0, t_1), ..., (x_1, t_0), ...``.
        """
        return torch.stack(
            [x.repeat_interleave(t.numel()), t.repeat(x.numel())], dim=1)

    def _to_leaf(self, points):
        """Move ``points`` to the target device and enable gradient tracking."""
        # Order matters: calling ``.to`` after ``requires_grad_`` would return
        # a non-leaf copy whenever the device actually changes.
        return points.to(self.device).requires_grad_(True)

    # *** Create in-domain points ***
    def sample_domain(self):
        """Return ``Nx_domain * Nt_domain`` collocation points on a random grid."""
        x_domain = generate_random_numbers(self.x_min, self.x_max, self.Nx_domain, self.dtype)
        t_domain = generate_random_numbers(self.t_min, self.t_max, self.Nt_domain, self.dtype)
        return self._to_leaf(self._pair_points(x_domain, t_domain))

    # *** Boundary Conditions ***
    def sample_boundary(self):
        """Return random points on ``x = x_min`` and ``x = x_max`` with zero targets.

        Returns:
            ``(bound_data, u_bound)`` where ``bound_data`` has shape
            ``(2 * Nt_bound, 2)`` (left boundary first, then right) and
            ``u_bound`` is a ``(2 * Nt_bound, 1)`` tensor of zeros.
        """
        t_bound = generate_random_numbers(self.t_min, self.t_max, self.Nt_bound, self.dtype)

        # Use the configured spatial bounds rather than a hard-coded +/-1.
        left = torch.stack([torch.full_like(t_bound, self.x_min), t_bound], dim=1)
        right = torch.stack([torch.full_like(t_bound, self.x_max), t_bound], dim=1)
        bound_data = self._to_leaf(torch.cat([left, right]))

        u_bound = torch.zeros(len(bound_data), 1, dtype=self.dtype, device=self.device)

        return bound_data, u_bound

    # *** Initial Condition ***
    def sample_initial(self):
        """Return random points on ``t = t_min`` and the targets ``-sin(pi * x)``.

        Returns:
            ``(init_data, u_init)`` where ``init_data`` has shape
            ``(Nx_init, 2)`` and ``u_init`` is a 1-D tensor of length
            ``Nx_init`` living on the same device as ``init_data``.
        """
        x_init = generate_random_numbers(self.x_min, self.x_max, self.Nx_init, self.dtype)
        t_init = torch.full_like(x_init, self.t_min)
        init_data = self._to_leaf(torch.stack([x_init, t_init], dim=1))

        # Targets must live on the same device as the network output they are
        # compared with.
        u_init = (-torch.sin(math.pi * x_init)).to(self.device)

        return init_data, u_init

    # *** Test set ***
    def sample_test(self):
        """Return ``N_test`` points drawn uniformly from the (x, t) domain."""
        x_test = self.x_min + (self.x_max - self.x_min) * torch.rand(self.N_test, dtype=self.dtype)
        t_test = self.t_min + (self.t_max - self.t_min) * torch.rand(self.N_test, dtype=self.dtype)
        return self._to_leaf(torch.stack([x_test, t_test], dim=1))

## Networks

In [None]:
class MLPBase(nn.Module):
    """Fully connected network assembled from a list of layer widths.

    ``layers`` holds the width of every layer including input and output.
    Each hidden transition is a ``Linear`` followed by ``activation``; the
    last transition is a plain ``Linear`` that is left with PyTorch's default
    initialisation (the custom initialisers are applied to hidden layers only).

    Args:
        layers: Layer widths, e.g. ``[2, 20, 20, 1]``.
        activation: Module applied after every hidden linear layer. The same
            instance is shared by all hidden layers.
        weight_init: Optional callable applied to each hidden layer's weight.
        bias_init: Optional callable applied to each hidden layer's bias.
        device: Device the assembled network is moved to.
    """

    def __init__(self, layers, activation=nn.Tanh(), weight_init=None, bias_init=None, device='cpu'):
        super().__init__()
        self.n_layers = len(layers) - 1
        self.layers = layers
        self.activation = activation
        self.weight_init = weight_init
        self.bias_init = bias_init

        # Consecutive width pairs for every transition except the last one.
        hidden_shapes = zip(self.layers[:-2], self.layers[1:-1])
        stages = [
            self.dense_layer(in_features=fan_in, out_features=fan_out)
            for fan_in, fan_out in hidden_shapes
        ]
        # Output layer: linear only, no activation and no custom init.
        stages.append(nn.Linear(self.layers[-2], self.layers[-1]))

        self.mlp = nn.Sequential(*stages).to(device)

    def dense_layer(self, in_features, out_features):
        """Build one hidden stage: an initialised ``Linear`` plus the activation."""
        linear = nn.Linear(in_features=in_features, out_features=out_features)

        if self.weight_init is not None:
            self.weight_init(linear.weight)
        if self.bias_init is not None:
            self.bias_init(linear.bias)

        # The linear layer is registered under "0" and the activation under
        # "activation", which fixes the parameter names in the state dict.
        stage = nn.Sequential(linear)
        stage.add_module("activation", self.activation)
        return stage


class gMLP(MLPBase):
    """MLP whose two outputs are squashed into bounded (x, t) coordinates.

    Column 0 of the raw network output goes through ``tanh`` (range (-1, 1))
    and column 1 through ``sigmoid`` (range (0, 1)).
    """

    def __init__(self, layers, activation=nn.Tanh(), weight_init=None, bias_init=None, device='cpu'):
        super().__init__(layers,
                         activation=activation,
                         weight_init=weight_init,
                         bias_init=bias_init,
                         device=device)

    def forward(self, x):
        """Map a batch ``x`` to an ``(n, 2)`` tensor of squashed (x, t) pairs."""
        raw = self.mlp(x)
        squashed_x = torch.tanh(raw[:, 0])
        squashed_t = torch.sigmoid(raw[:, 1])
        # Stack the two 1-D results side by side as columns.
        return torch.stack((squashed_x, squashed_t), dim=1)


class dMLP(MLPBase):
    """Plain MLP: ``forward`` returns the raw output of the underlying stack."""

    def __init__(self, layers, activation=nn.Tanh(), weight_init=None, bias_init=None, device='cpu'):
        super().__init__(layers,
                         activation=activation,
                         weight_init=weight_init,
                         bias_init=bias_init,
                         device=device)

    def forward(self, x):
        """Evaluate the network on the batch ``x``."""
        prediction = self.mlp(x)
        return prediction


## Generator

In [None]:
class Generator():
    """Adversarial point generator.

    Wraps a ``gMLP`` that maps noise to (x, t) points and an Adam optimizer
    over its parameters. The generator is trained to *maximise* the
    discriminator's PDE residual at the generated points, which is expressed
    as minimising the negated residual loss.
    """

    def __init__(self,
                 layers,
                 activation,
                 device):

        tanh_gain = nn.init.calculate_gain('tanh')

        def init_weight(param):
            # Xavier-normal initialisation scaled by the tanh gain.
            return nn.init.xavier_normal_(param.data, tanh_gain)

        def init_bias(param):
            return nn.init.zeros_(param.data)

        # Define the model
        self.model = gMLP(layers=layers,
                          activation=activation,
                          weight_init=init_weight,
                          bias_init=init_bias,
                          device=device)

        # Set the optimizer
        self.optimizer = torch.optim.Adam(self.model.parameters())

    def forward(self, x):
        """Generate points from the noise batch ``x``."""
        return self.model(x)

    def calculate_loss(self, fake_data, discriminator):
        """Return minus the discriminator's residual criterion at ``fake_data``."""
        residual = discriminator.calculate_pde_residual(fake_data)
        target = torch.zeros_like(residual)
        return -discriminator.criterion(residual, target)

    def train(self, fake_data, discriminator):
        """Run one optimizer step on the generator and return the loss tensor."""
        # Clear gradients left over from the previous step.
        self.optimizer.zero_grad()

        loss = self.calculate_loss(fake_data, discriminator)
        loss.backward()

        # Apply the parameter update.
        self.optimizer.step()

        return loss

## Discriminator

In [None]:
class Discriminator():
    """Physics-informed network approximating ``u(x, t)``.

    Holds the ``dMLP`` model, an Adam and an L-BFGS optimizer over its
    parameters, and the loss functions. Inputs are ``(n, 2)`` tensors whose
    columns are ``(x, t)`` and which must have ``requires_grad`` enabled
    wherever derivatives are taken.

    The PDE residual implemented here is ``u_t + u * u_x - v * u_xx`` with
    ``v = 0.01 / pi``.
    """

    def __init__(self,
                 layers,
                 activation,
                 device):

        self.v = 0.01 / math.pi

        # Define the model
        self.model = dMLP(layers=layers,
                          activation=activation,
                          weight_init=lambda m: nn.init.xavier_normal_(m.data, nn.init.calculate_gain('tanh')),
                          bias_init=lambda m: nn.init.zeros_(m.data),
                          device=device)

        # Set the optimizers
        adam = torch.optim.Adam(self.model.parameters())
        lbfgs = torch.optim.LBFGS(self.model.parameters(),
                                  lr=1,
                                  max_iter=250,  # max_iter=2000,
                                  max_eval=None,
                                  tolerance_grad=1e-07,
                                  tolerance_change=1e-09,
                                  history_size=100,
                                  line_search_fn='strong_wolfe')

        self.optimizers = {"adam": adam, "lbfgs": lbfgs}

        # State read by ``closure``. Initialised here so the attributes always
        # exist; the training driver overwrites the data before each L-BFGS run.
        self.lbfgs_optimizer = lbfgs
        self.real_data = None
        self.fake_data = None

        # Set the Loss function
        self.criterion = nn.MSELoss()

        # Set the MAE criterion for test data only
        self.l1_loss = nn.L1Loss()

    def forward(self, x):
        """Evaluate the model on the batch ``x``."""
        return self.model(x)

    def grad(self, output, input):
        """Differentiate ``sum(output)`` with respect to ``input``.

        The graph is retained and the result is itself differentiable, so the
        call can be nested for higher-order derivatives.
        """
        return torch.autograd.grad(
                    output, input,
                    grad_outputs=torch.ones_like(output),
                    retain_graph=True,
                    create_graph=True
                )[0]

    def calculate_pde_residual(self, x):
        """Return the point-wise residual ``u_t + u * u_x - v * u_xx``.

        Args:
            x: ``(n, 2)`` tensor of ``(x, t)`` points with gradients enabled.

        Returns:
            1-D tensor of length ``n``.
        """
        # Forward pass
        u = self.forward(x)

        # First derivatives: column 0 is du/dx, column 1 is du/dt.
        du_dX = self.grad(u, x)
        du_dx = du_dX[:, 0]
        du_dt = du_dX[:, 1]

        # Differentiate du/dx on its own. Differentiating the whole du_dX
        # tensor would sum both columns first and yield u_xx + u_tx instead
        # of u_xx.
        du_dxx = self.grad(du_dx, x)[:, 0]

        return du_dt + u.flatten() * du_dx - self.v * du_dxx

    def calculate_pde_loss(self, data):
        """Return the criterion (MSE) of the PDE residual against zero."""
        pde_res = self.calculate_pde_residual(data)
        pde_target = torch.zeros_like(pde_res)
        return self.criterion(pde_res, pde_target)

    def calculate_real_loss(self, real_data):
        """Return boundary loss + initial loss + PDE loss on the real data.

        ``real_data`` is a dict with the keys ``"bound_data"``, ``"u_bound"``,
        ``"init_data"``, ``"u_init"`` and ``"domain_data"``.
        """
        # Calculate boundary loss
        loss_b = self.criterion(
            self.forward(real_data["bound_data"]).flatten(),
            real_data["u_bound"].flatten()
        )

        # Calculate initial loss
        loss_i = self.criterion(
            self.forward(real_data["init_data"]).flatten(),
            real_data["u_init"].flatten()
        )

        # Calculate the domain loss
        loss_pde = self.calculate_pde_loss(real_data["domain_data"])

        # Calculate total discriminator loss
        return loss_b + loss_i + loss_pde

    def calculate_fake_loss(self, fake_data):
        """Return the PDE loss at the generator-produced points."""
        return self.calculate_pde_loss(fake_data)

    def calculate_test_loss(self, test_data):
        """Return the mean absolute PDE residual on ``test_data``.

        The result is detached: it is a metric only, and callers keep it
        around (e.g. as the best score so far), which would otherwise keep the
        whole second-order autograd graph of the test batch alive.
        """
        pde_res = self.calculate_pde_residual(test_data)
        pde_target = torch.zeros_like(pde_res)
        return self.l1_loss(pde_res, pde_target).detach()

    def train_on_real(self, real_data):
        """Back-propagate the real-data loss and return it."""
        loss_real = self.calculate_real_loss(real_data)
        loss_real.backward()
        return loss_real

    def train_on_fake(self, fake_data):
        """Back-propagate the PDE loss at ``fake_data`` and return it.

        The points are detached from the generator graph so that only the
        discriminator receives gradients, then re-enabled for differentiation
        with respect to the inputs.
        """
        loss_fake = self.calculate_fake_loss(fake_data.detach().requires_grad_(True))
        loss_fake.backward()
        return loss_fake

    def closure(self):
        """L-BFGS closure: zero gradients, back-propagate, return total loss.

        Reads ``self.lbfgs_optimizer``, ``self.real_data`` and
        ``self.fake_data``; the two data attributes must be set before the
        optimizer's ``step`` is called with this closure.
        """
        self.lbfgs_optimizer.zero_grad()
        loss_real = self.train_on_real(self.real_data)
        loss_fake = self.train_on_fake(self.fake_data)
        return loss_real + loss_fake

## GAN-PINN

In [None]:
class GAN_PINN():
    """Training driver coupling a point ``Generator`` with a PINN ``Discriminator``.

    Each iteration samples fresh real data and noise, trains the
    discriminator with Adam and then L-BFGS on the real points plus the
    generator's points, takes one generator step, and evaluates the mean
    absolute PDE residual on a fixed test set for checkpointing and early
    stopping.

    Args:
        x_min, x_max, t_min, t_max: Bounds of the (x, t) domain.
        Nx_domain, Nt_domain, Nx_init, Nt_bound, N_test: Sample counts
            forwarded to ``Data``.
        N_noise: Number of noise points fed to the generator per iteration.
        g_layers, g_activation: Generator architecture.
        d_layers, d_activation: Discriminator architecture.
        checkpoint_path: Destination passed to ``torch.save`` for checkpoints.
        device: Device used for data and models.
        dtype: Floating-point type used for sampled data.
    """

    def __init__(self,
                 x_min, x_max,
                 t_min, t_max,
                 Nx_domain, Nt_domain,
                 Nx_init, Nt_bound,
                 N_test, N_noise,
                 g_layers, g_activation,
                 d_layers, d_activation,
                 checkpoint_path,
                 device='cpu',
                 dtype=torch.float32):

        # Constants
        self.checkpoint_path = checkpoint_path
        self.device = device
        self.dtype = dtype
        self.N_noise = N_noise
        self.N_test = N_test

        # Domain bounds, kept so the generator noise covers the same domain
        # as the real data.
        self.x_min = x_min
        self.x_max = x_max
        self.t_min = t_min
        self.t_max = t_max

        # Bookkeeping state used by checkpointing / early stopping.
        self.flag = 0
        self.early_stopping_applied = 0
        self.best_val_loss = None
        self.best_epoch = None

        # Create real data
        self.real_data_init = Data(x_min, x_max,
                                   t_min, t_max,
                                   Nx_domain, Nt_domain,
                                   Nx_init, Nt_bound,
                                   N_test,
                                   device,
                                   dtype)

        # Create test data
        self.test_data = self.real_data_init.sample_test()

        # Create Generator
        self.generator = Generator(g_layers, g_activation, device)

        # Create Discriminator
        self.discriminator = Discriminator(d_layers, d_activation, device)

    def generate_data(self):
        """Sample a fresh batch of real data and generator noise.

        Returns:
            ``(real_data, noise)``. ``real_data`` is a dict with the keys
            ``"domain_data"``, ``"bound_data"``, ``"u_bound"``,
            ``"init_data"`` and ``"u_init"``. ``noise`` is an
            ``(N_noise, 2)`` tensor of uniform (x, t) points inside the
            configured domain, on the configured device.
        """
        # Create real data
        real_data = {}
        real_data["domain_data"] = self.real_data_init.sample_domain()
        real_data["bound_data"], real_data["u_bound"] = self.real_data_init.sample_boundary()
        real_data["init_data"], real_data["u_init"] = self.real_data_init.sample_initial()

        # Create noise (Generator's input): column 0 is x, column 1 is t.
        lower = torch.tensor([self.x_min, self.t_min], dtype=self.dtype)
        span = torch.tensor([self.x_max - self.x_min, self.t_max - self.t_min], dtype=self.dtype)
        unit = torch.rand(self.N_noise, 2, dtype=self.dtype)
        noise = (lower + span * unit).to(self.device)

        return real_data, noise

    def train_with_adam(self, N_adam, real_data, fake_data):
        """Run ``N_adam`` Adam steps on the discriminator."""
        optimizer = self.discriminator.optimizers['adam']

        for _ in range(N_adam):
            optimizer.zero_grad()
            # Both calls back-propagate, accumulating into the same gradients.
            self.discriminator.train_on_real(real_data)
            self.discriminator.train_on_fake(fake_data)
            optimizer.step()

    def train_with_lbfgs(self, N_lbfgs, real_data, fake_data):
        """Run ``N_lbfgs`` L-BFGS steps on the discriminator.

        Returns:
            The loss reported by the last L-BFGS step. When ``N_lbfgs`` is 0
            the current discriminator loss is evaluated without an update so
            that a value is always returned.
        """
        optimizer = self.discriminator.optimizers["lbfgs"]
        self.discriminator.lbfgs_optimizer = optimizer
        self.discriminator.real_data = real_data
        self.discriminator.fake_data = fake_data

        loss_D = None
        for _ in range(N_lbfgs):
            loss_D = optimizer.step(self.discriminator.closure)

        if loss_D is None:
            loss_real = self.discriminator.calculate_real_loss(real_data)
            loss_fake = self.discriminator.calculate_fake_loss(
                fake_data.detach().requires_grad_(True))
            loss_D = (loss_real + loss_fake).detach()

        return loss_D

    def checkpoint(self):
        """Save the discriminator's state dict to ``checkpoint_path``."""
        torch.save({
            "model": self.discriminator.model.state_dict()
        }, self.checkpoint_path)

    def format_loss(self, loss):
        """Format a scalar loss in scientific notation with two decimals.

        Exactly zero is rendered as ``"0.0e+00"``. Non-finite values are
        rendered as ``"nan"`` / ``"inf"`` instead of raising.
        """
        value = float(loss)
        if value == 0:
            return "0.0e+00"
        return f"{value:.2e}"

    def keep_checkpoints_and_print_losses(self, iter, patience, print_every,
                                          loss_D, loss_G, loss_test):
        """Update the best checkpoint, the early-stopping flag, and print progress.

        A checkpoint is written on the first iteration and whenever
        ``loss_test`` improves on the best value so far. Early stopping is
        flagged once more than ``patience`` iterations have passed since the
        best one, regardless of the printing cadence. Progress lines are
        printed on the first iteration and every ``print_every`` iterations.
        """
        loss_test = float(loss_test)
        status = (f"Iteration: {iter} | loss_D: {self.format_loss(loss_D)} | "
                  f"loss_G: {self.format_loss(loss_G)} | "
                  f"test_mae: {self.format_loss(loss_test)}")
        should_print = iter % print_every == 0

        improved = iter == 1 or loss_test < self.best_val_loss
        self.flag = int(improved)

        if improved:
            self.best_val_loss = loss_test
            # Record the real iteration number so patience is counted from it.
            self.best_epoch = iter
            self.checkpoint()
            if iter == 1 or should_print:
                print(f"{status} - *Checkpoint*")
            return

        if iter - self.best_epoch > patience:
            self.early_stopping_applied = 1

        if should_print:
            print(status)

    def train(self, iters, patience, print_every, N_adam, N_lbfgs):
        """Run up to ``iters`` GAN-PINN iterations with early stopping.

        Args:
            iters: Maximum number of outer iterations.
            patience: Iterations without test improvement tolerated before
                stopping.
            print_every: Printing cadence in iterations.
            N_adam: Adam steps on the discriminator per iteration.
            N_lbfgs: L-BFGS steps on the discriminator per iteration.
        """
        print(f"GAN-PINN: {iters} iterations")
        print(f"a. PINN: {N_adam} epochs --> Adam")
        print(f"b. PINN: {N_lbfgs} epochs --> L-BFGS")
        print(f"c. Generator: 1 epoch --> Adam\n")

        for iteration in tqdm(range(1, iters + 1)):
            self.flag = 0
            self.early_stopping_applied = 0

            real_data, noise = self.generate_data()
            fake_data = self.generator.model(noise)

            self.train_with_adam(N_adam, real_data, fake_data)
            # Losses are converted to plain floats: they are only used for
            # reporting and comparison, and must not keep autograd graphs alive.
            loss_D = float(self.train_with_lbfgs(N_lbfgs, real_data, fake_data))
            loss_G = float(self.generator.train(fake_data, self.discriminator))

            loss_test = float(self.discriminator.calculate_test_loss(self.test_data))

            self.keep_checkpoints_and_print_losses(iteration, patience, print_every,
                                                   loss_D, loss_G, loss_test)

            if self.early_stopping_applied:
                print(f"\nEarly stopping applied at epoch {iteration}.")
                break

# Main

In [None]:
# --- Sampling configuration -------------------------------------------------
x_min, x_max = -1, 1
t_min, t_max = 0, 1
Nx_domain, Nt_domain = 200, 100   # random grid of collocation points
Nx_init, Nt_bound = 100, 100      # initial-line / per-boundary points
N_noise = 1_000                   # generator inputs per iteration
N_test = 100_000                  # fixed test points

# --- Generator: 2 inputs -> hidden layers -> 2 outputs (x, t) ---------------
Ng_layers, Ng_neurons = 3, 64
g_layers = [2, *([Ng_neurons] * Ng_layers), 2]
g_activation = nn.Tanh()

# --- Discriminator: 2 inputs (x, t) -> hidden layers -> 1 output (u) --------
Nd_layers, Nd_neurons = 3, 20
d_layers = [2, *([Nd_neurons] * Nd_layers), 1]
d_activation = nn.Tanh()

# --- Miscellaneous -----------------------------------------------------------
checkpoint_path = "discriminator.pth"
dtype = torch.float32
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): `device` and `dtype` are defined above but are not passed to
# GAN_PINN below, so the constructor defaults ('cpu', torch.float32) are used.
# Confirm whether that is intended before forwarding them.

# --- GAN-PINN initialization -------------------------------------------------
gan_pinn = GAN_PINN(
    x_min=x_min, x_max=x_max,
    t_min=t_min, t_max=t_max,
    Nx_domain=Nx_domain, Nt_domain=Nt_domain,
    Nx_init=Nx_init, Nt_bound=Nt_bound,
    N_test=N_test, N_noise=N_noise,
    g_layers=g_layers, g_activation=g_activation,
    d_layers=d_layers, d_activation=d_activation,
    checkpoint_path=checkpoint_path,
)

# --- Training ----------------------------------------------------------------
iterations = 100
patience = 10
print_every = 1
num_epochs_adam = 20
num_epochs_lbfgs = 5

gan_pinn.train(
    iters=iterations,
    patience=patience,
    print_every=print_every,
    N_adam=num_epochs_adam,
    N_lbfgs=num_epochs_lbfgs,
)

GAN-PINN: 100 iterations
a. PINN: 20 epochs --> Adam
b. PINN: 5 epochs --> L-BFGS
c. Generator: 1 epoch --> Adam



  0%|          | 0/100 [00:00<?, ?it/s]

Iteration: 1 | loss_D: 1.77e-02 | loss_G: -9.12e-04 | test_mae: 5.63e-02 - *Checkpoint*
Iteration: 2 | loss_D: 8.32e-03 | loss_G: -5.43e-04 | test_mae: 4.97e-02 - *Checkpoint*
Iteration: 3 | loss_D: 7.16e-03 | loss_G: -2.61e-04 | test_mae: 3.90e-02 - *Checkpoint*
Iteration: 4 | loss_D: 3.81e-03 | loss_G: -2.21e-04 | test_mae: 2.90e-02 - *Checkpoint*
Iteration: 5 | loss_D: 1.83e-03 | loss_G: -2.17e-04 | test_mae: 2.65e-02 - *Checkpoint*
Iteration: 6 | loss_D: 1.64e-03 | loss_G: -1.76e-04 | test_mae: 2.72e-02
Iteration: 7 | loss_D: 1.11e-03 | loss_G: -1.07e-04 | test_mae: 3.82e-02
Iteration: 8 | loss_D: 2.44e-02 | loss_G: -4.58e-03 | test_mae: 4.14e-02
Iteration: 9 | loss_D: 1.09e-03 | loss_G: -1.17e-04 | test_mae: 4.04e-02
Iteration: 10 | loss_D: 9.70e-04 | loss_G: -7.67e-05 | test_mae: 2.37e-02 - *Checkpoint*
Iteration: 11 | loss_D: 8.62e-04 | loss_G: -4.63e-05 | test_mae: 4.27e-02
Iteration: 12 | loss_D: 7.10e-04 | loss_G: -6.04e-05 | test_mae: 1.94e-02 - *Checkpoint*
Iteration: 13 | 