# PDE-Net 2.0 for 2D Burgers Equation
## Pretraining

## Overview
PDE-Net 2.0, also called Poly PDE-Net, improves on PDE-Net by adding a symbolic neural network and a pseudo-upwind technique.
More details can be found in https://arxiv.org/pdf/1812.04426.pdf.

This work implements PDE-Net 2.0 with MindSpore 1.10.1 to solve the 2D Burgers equations.

## Problem Description
This case solves the inverse problem for the 2D Burgers partial differential equations with variable parameters and performs long-term prediction.

## 2D Burgers Equation
$$
u_t = -uu_x - vu_y + \nu u_{xx} + \nu u_{yy}, \quad (x,y) \in[0,2 \pi] \times[0,2 \pi]
$$

$$
v_t = -uv_x - vv_y + \nu v_{xx} + \nu v_{yy}, \quad (x,y) \in[0,2 \pi] \times[0,2 \pi]
$$

Here $\nu$ denotes the viscosity.

$$
u|_{t=0} = u_0(x,y), v|_{t=0} = v_0(x,y)
$$

Boundary conditions: periodic.
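
To make the right-hand side concrete, the following is a minimal NumPy sketch that evaluates $u_t$ and $v_t$ on a periodic grid with second-order central differences. It assumes the standard advective form of the Burgers equations, with a minus sign on both advection terms. The function name `burgers_rhs`, the grid size and the viscosity value are illustrative assumptions. This is not the data generator used by this case.

```python
import numpy as np


def burgers_rhs(u, v, dx, viscosity):
    """Return (u_t, v_t) for 2D Burgers on a periodic grid (axis 0 = y, axis 1 = x)."""

    def d_dx(f):
        # central difference along x; np.roll wraps around, which gives periodicity
        return (np.roll(f, -1, axis=1) - np.roll(f, 1, axis=1)) / (2 * dx)

    def d_dy(f):
        # central difference along y (same spacing assumed in both directions)
        return (np.roll(f, -1, axis=0) - np.roll(f, 1, axis=0)) / (2 * dx)

    def laplacian(f):
        # five-point stencil for f_xx + f_yy
        return (np.roll(f, -1, axis=0) + np.roll(f, 1, axis=0)
                + np.roll(f, -1, axis=1) + np.roll(f, 1, axis=1) - 4 * f) / dx ** 2

    u_t = -u * d_dx(u) - v * d_dy(u) + viscosity * laplacian(u)
    v_t = -u * d_dx(v) - v * d_dy(v) + viscosity * laplacian(v)
    return u_t, v_t


# hypothetical setup: 64 x 64 grid on [0, 2*pi) x [0, 2*pi)
mesh_size = 64
dx = 2 * np.pi / mesh_size
grid = np.arange(mesh_size) * dx
x, y = np.meshgrid(grid, grid)  # x varies along axis 1, y along axis 0
u0 = np.sin(x)
v0 = np.zeros_like(u0)
u_t, v_t = burgers_rhs(u0, v0, dx, viscosity=0.05)
```

For this initial condition the exact value is $u_t = -\sin x \cos x - \nu \sin x$ and $v_t = 0$, so the sketch can be checked against a closed form.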

## Model structure
Like PDE-Net, PDE-Net 2.0 consists of a series of $\delta T$ blocks that share parameters. The shared parameters include trainable moment matrices, each of which corresponds to a convolutional kernel that approximates a spatial derivative of a specific order.
Building on PDE-Net, PDE-Net 2.0 introduces a symbolic network to aggregate the outputs of the different convolutional kernels. In addition, a method called pseudo-upwind is used to obtain stable predictions.

![block](images/poly_pdenet_block.png)

![symbolic](images/poly_pdenet_symbolic.png)

![poly_pdenet](images/poly_pdenet.png)
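
The link between a convolutional kernel and its moment matrix can be illustrated with a small pure-Python sketch. It assumes the moment definition from the PDE-Net papers, $M_{i,j} = \frac{1}{i!\,j!}\sum_{k_1,k_2} k_1^i k_2^j\, q[k_1,k_2]$, with offsets centered on the kernel. The helper `moment_matrix` is hypothetical, and the normalization and axis order in `src.pdenet` may differ.

```python
from math import factorial


def moment_matrix(kernel):
    """Moment matrix of a square kernel with odd side length.

    M[i][j] = 1 / (i! * j!) * sum over (k1, k2) of k1**i * k2**j * kernel[k1][k2],
    where k1 and k2 are offsets from the kernel center (-1, 0, 1 for size 3).
    """
    size = len(kernel)
    half = (size - 1) // 2
    offsets = range(-half, half + 1)
    moments = [[0.0] * size for _ in range(size)]
    for i in range(size):
        for j in range(size):
            total = 0.0
            for row, k1 in enumerate(offsets):
                for col, k2 in enumerate(offsets):
                    total += (k1 ** i) * (k2 ** j) * kernel[row][col]
            moments[i][j] = total / (factorial(i) * factorial(j))
    return moments


# central difference along the second axis, unit grid spacing
d_dx_kernel = [[0.0, 0.0, 0.0],
               [-0.5, 0.0, 0.5],
               [0.0, 0.0, 0.0]]
M = moment_matrix(d_dx_kernel)
```

The only nonzero moment of this kernel sits at index (0, 1), which marks it as a first-order difference along the second axis. Constraining selected entries of a trainable moment matrix in the same way ties each learned kernel to one derivative order.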

## Technical roadmap
We solve the problem described above as follows:

1. Model construction.
2. Warmup training.
3. Multistep training.
4. Test and visualization.
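
The warmup and multistep phases differ only in a few switches. The sketch below lists, for each step, whether the moment matrices are frozen, whether regularization is on, and which checkpoint is loaded. It mirrors the branching in `single_train` defined later in this notebook. The helper `training_schedule` is hypothetical and does not exist in `src`.

```python
def training_schedule(max_step):
    """List the settings single_train would use for step_num = 1..max_step."""
    schedule = []
    for step_num in range(1, max_step + 1):
        warmup = step_num == 1
        if warmup:
            load_from = None  # warmup starts from a freshly initialized model
        elif step_num == 2:
            # the step-1 checkpoint is saved after config['warmup_epochs'] epochs
            load_from = (1, 'warmup_epochs')
        else:
            # later checkpoints are saved after config['epochs'] epochs
            load_from = (step_num - 1, 'epochs')
        schedule.append({
            'step_num': step_num,
            'frozen': warmup,              # moment matrices fixed during warmup
            'regularization': not warmup,  # regularization only after warmup
            'load_from': load_from,        # (previous step, config key for its epoch count)
        })
    return schedule


for entry in training_schedule(3):
    print(entry)
```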

In [1]:
import argparse
import os
import mindspore as ms
import numpy as np
from mindspore import nn
from mindspore.amp import all_finite
from src.dataset import DataGenerator
from src.pdenet import PDENetWithLoss, PolyPDENet2D
from src.utils import init_env, get_config, init_model, load_param_dict, mkdir, generate_train_data
from src.utils import test, evaluate
from mindspore.train.serialization import load_param_into_net

## Training environment settings

In [2]:
ms.set_seed(1999)
np.random.seed(1999)
my_config = get_config('config.yaml')
my_config['device_target'] = 'CPU'
mkdir(config=my_config)
ms.set_context(mode=ms.PYNATIVE_MODE, device_target=my_config['device_target'], device_id=0)

## Model construction
PDE-Net 2.0 is defined in the class PolyPDENet2D. You need to specify the time step, mesh width, mesh height, kernel size, highest order of the kernels, number of hidden layers in the symbolic network, initial range of the symbolic network's parameters, and whether to use the pseudo-upwind method.

In [3]:
def init_model(config):
    pde_net = PolyPDENet2D(dt=config['dt'], dx=config['dx'], kernel_size=config['kernel_size'],
                           symnet_hidden_num=config['symnet_hidden_num'], symnet_init_range=config['symnet_init_range'],
                           max_order=config['max_order'], if_upwind=config['if_upwind'], dtype=ms.float32)
    return pde_net

## Define single step training
The model is trained with an increasing number of steps. During the warmup phase (step 1), the moment matrices are frozen and not updated. When the step is larger than 1, the moment matrices are updated normally. Each time a $\delta T$ block is added, the program generates data and reads the dataset. After the model is initialized, the program loads the checkpoint trained in the previous step and defines the optimizer, training mode, and loss function. During training, the model's performance is evaluated and its equivalent expression is printed periodically.

In [4]:
def single_train(step_num: int, config: dict, data_generator: DataGenerator):
    # generate data: (data_num, step_num + 1, 2, sample_mesh_size_y, sample_mesh_size_x)
    train_dataset = generate_train_data(config=config, data_generator=data_generator)
    pde_net = init_model(config=config)
    if step_num == 1:
        # warm up
        regularization = False
        frozen = True
        pde_net.set_frozen(frozen=frozen)
        epochs = config['warmup_epochs']
        lr = config['warmup_lr']
    else:
        if step_num == 2:
            load_epoch = config['warmup_epochs']
        else:
            load_epoch = config['epochs']

        param_dict = load_param_dict(save_directory=config['save_directory'], step=step_num - 1, epoch=load_epoch)
        param_not_load = load_param_into_net(net=pde_net, parameter_dict=param_dict)
        print('=============== Net saved at last step is loaded. ===============')
        print('!!!!!!!!! param not loaded: ', param_not_load)

        regularization = True
        frozen = False
        pde_net.set_frozen(frozen=frozen)
        epochs = config['epochs']
        lr = config['lr'] * np.power(config['lr_reduce_gamma'], (step_num - 1) // config['lr_reduce_interval'])

    # optimizer: lr stays constant within a step and is decayed across steps by the formula above
    my_optimizer = nn.Adam(params=pde_net.trainable_params(), learning_rate=lr)
    net_with_loss = PDENetWithLoss(pde_net=pde_net,
                                   moment_loss_threshold=config['moment_loss_threshold'],
                                   symnet_loss_threshold=config['symnet_loss_threshold'],
                                   moment_loss_scale=config['moment_loss_scale'],
                                   symnet_loss_scale=config['symnet_loss_scale'],
                                   step_num=step_num, regularization=regularization)

    def forward_fn(trajectory):
        loss = net_with_loss.get_loss(batch_trajectory=trajectory)
        return loss
    value_and_grad = ms.ops.value_and_grad(forward_fn, None, weights=my_optimizer.parameters)

    def train_process(trajectory):
        # swap batch and time axes: NTCHW -> TNCHW
        trajectory = ms.numpy.swapaxes(trajectory, 0, 1)
        loss, grads = value_and_grad(trajectory)
        if config['device_target'].upper() == 'ASCEND':
            status = ms.numpy.zeros((8, ))
        else:
            status = None
        if all_finite(grads, status=status):
            my_optimizer(grads)
        return loss

    for epoch_idx in range(1, epochs + 1):
        pde_net.set_train(mode=True)
        avg_loss = 0  # accumulates the summed batch loss (not divided by the batch count)
        for batch_trajectory in train_dataset.fetch():
            train_loss = train_process(batch_trajectory)
            avg_loss += train_loss.asnumpy()
        print('step_num: {} -- epoch: {} -- lr: {} -- loss: {}'.format(step_num, epoch_idx,
                                                                       my_optimizer.learning_rate.value(),
                                                                       avg_loss))
        # generate new data
        if epoch_idx % config['generate_data_interval'] == 0:
            train_dataset = generate_train_data(config=config, data_generator=data_generator)
        # evaluate
        if epoch_idx % config['evaluate_interval'] == 0:
            evaluate_error = evaluate(model=pde_net, data_generator=data_generator, config=config, step_num=step_num)
            print('=============== Max evaluate error: {} ==============='.format(evaluate_error))
            print('=============== Current Expression ===============')
            pde_net.show_expression(coe_threshold=config['coe_threshold'])

        if epoch_idx == epochs:
            print('=============== Current Expression ===============')
            pde_net.show_expression(coe_threshold=config['coe_threshold'])
            pde_net.show_kernels()
            save_path = os.path.join(config['save_directory'],
                                     'pde_net_step{}_epoch{}.ckpt'.format(step_num, epoch_idx))
            ms.save_checkpoint(pde_net, save_path)
    return
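
The step-dependent learning rate used in `single_train` is a staircase decay: the rate is multiplied by `lr_reduce_gamma` once every `lr_reduce_interval` steps. The standalone sketch below reproduces that formula in plain Python. The numeric values are illustrative only and are not taken from `config.yaml`.

```python
def step_learning_rate(base_lr, gamma, interval, step_num):
    """Same formula as in single_train: base_lr * gamma ** ((step_num - 1) // interval)."""
    return base_lr * gamma ** ((step_num - 1) // interval)


# hypothetical values chosen so the staircase is easy to read
rates = [step_learning_rate(base_lr=1.0, gamma=0.5, interval=2, step_num=s)
         for s in range(1, 7)]
print(rates)  # [1.0, 1.0, 0.5, 0.5, 0.25, 0.25]
```

The warmup step does not use this formula. It reads `warmup_lr` from the configuration directly.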

## Warmup training

In [5]:
def pretrain(config):
    data_generator = DataGenerator(config=config)
    single_train(config=config, step_num=1, data_generator=data_generator)
    return

In [6]:
pretrain(config=my_config)

generating data ...
step 20 generated.
generated.
step_num: 1 -- epoch: 1 -- lr: 0.001 -- loss: 26.776227951049805
step_num: 1 -- epoch: 2 -- lr: 0.001 -- loss: 26.633957386016846
step_num: 1 -- epoch: 3 -- lr: 0.001 -- loss: 26.502192497253418
step_num: 1 -- epoch: 4 -- lr: 0.001 -- loss: 26.375828742980957
step_num: 1 -- epoch: 5 -- lr: 0.001 -- loss: 26.273977279663086
step_num: 1 -- epoch: 6 -- lr: 0.001 -- loss: 26.165760040283203
step_num: 1 -- epoch: 7 -- lr: 0.001 -- loss: 26.068905353546143
step_num: 1 -- epoch: 8 -- lr: 0.001 -- loss: 25.981074810028076
step_num: 1 -- epoch: 9 -- lr: 0.001 -- loss: 25.893912315368652
step_num: 1 -- epoch: 10 -- lr: 0.001 -- loss: 25.81773567199707
step_num: 1 -- epoch: 11 -- lr: 0.001 -- loss: 25.738539695739746
step_num: 1 -- epoch: 12 -- lr: 0.001 -- loss: 25.66281223297119
step_num: 1 -- epoch: 13 -- lr: 0.001 -- loss: 25.60450839996338
step_num: 1 -- epoch: 14 -- lr: 0.001 -- loss: 25.535609245300293
step_num: 1 -- epoch: 15 -- lr: 0.001 