Commit d42f7f8

DUCH714 committed Jan 18, 2024
1 parent 0739d7f commit d42f7f8
Showing 6 changed files with 61 additions and 69 deletions.
examples/phycrnet/conf/burgers_equations.yaml (6 changes: 3 additions & 3 deletions)
@@ -1,7 +1,7 @@
 hydra:
   run:
     # dynamic output directory according to running time and override name
-    dir: outputs_phycrnet/${now:%Y-%m-%d}/${now:%H-%M-%S}/${hydra.job.override_dirname}
+    dir: burgers/${now:%Y-%m-%d}/${now:%H-%M-%S}/${hydra.job.override_dirname}
   job:
     name: ${mode} # name of logfile
     chdir: false # keep current working direcotry unchanged
@@ -23,8 +23,8 @@ hydra:
 mode: train # running mode: train/eval
 seed: 66
 output_dir: ${hydra:run.dir}
-DATA_PATH: ./data/burgers_1501x2x128x128.mat
-
+DATA_PATH: ./data/burgers_2001x2x128x128.mat
+num_convlstm: 1
 # set working condition
 TIME_STEPS: 1001
 DT: 0.002
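Note: the dataset swap above changes both the file name and the trajectory length (1501 to 2001 time steps). A minimal sketch of inspecting the new file; the "uv" key is an assumption carried over from the upstream PhyCRNet data and is not shown in this diff:

import scipy.io

data = scipy.io.loadmat("./data/burgers_2001x2x128x128.mat")
uv = data["uv"]  # assumed key: 2001 steps x 2 channels x 128 x 128 grid
print(uv.shape)  # expected (2001, 2, 128, 128), per the file name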
examples/phycrnet/conf/fitzhugh_nagumo_RD_equation.yaml (4 changes: 2 additions & 2 deletions)
@@ -1,7 +1,7 @@
 hydra:
   run:
     # dynamic output directory according to running time and override name
-    dir: outputs_phycrnet/${now:%Y-%m-%d}/${now:%H-%M-%S}/${hydra.job.override_dirname}
+    dir: fitzhugh/${now:%Y-%m-%d}/${now:%H-%M-%S}/${hydra.job.override_dirname}
   job:
     name: ${mode} # name of logfile
     chdir: false # keep current working direcotry unchanged
@@ -24,7 +24,7 @@ mode: train # running mode: train/eval
 seed: 66
 output_dir: ${hydra:run.dir}
 DATA_PATH: ./data/FN_1001x2x128x128.mat
-
+num_convlstm: 1
 # set working condition
 TIME_STEPS: 751
 DT: 0.006
examples/phycrnet/conf/lambda_omega_RD_equation.yaml (4 changes: 2 additions & 2 deletions)
@@ -1,7 +1,7 @@
 hydra:
   run:
     # dynamic output directory according to running time and override name
-    dir: outputs_phycrnet/${now:%Y-%m-%d}/${now:%H-%M-%S}/${hydra.job.override_dirname}
+    dir: lambda_omega/${now:%Y-%m-%d}/${now:%H-%M-%S}/${hydra.job.override_dirname}
   job:
     name: ${mode} # name of logfile
     chdir: false # keep current working direcotry unchanged
@@ -31,7 +31,7 @@ TIME_STEPS: 201
 DT: 0.025
 DX: [20.0, 512]
 TIME_BATCH_SIZE: 200
-
+num_convlstm: 1
 # model settings
 MODEL:
   input_channels: 2
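Note: each of the three configs above gains a num_convlstm key, which main.py (next file) reads instead of a hard-coded constant. A minimal sketch of how it is consumed, mirroring the code shown below: one (h0, c0) state pair per ConvLSTM layer.

import paddle

num_convlstm = 1  # value taken from the YAML configs above
h0 = paddle.randn((1, 128, 16, 16))
c0 = paddle.randn((1, 128, 16, 16))
initial_state = [(h0, c0) for _ in range(num_convlstm)]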
examples/phycrnet/main.py (18 changes: 5 additions & 13 deletions)
@@ -2,7 +2,6 @@
 PhyCRNet for solving spatiotemporal PDEs
 Reference: https://github.com/isds-neu/PhyCRNet/
 """
-import os
 from os import path as osp

 import functions
@@ -22,10 +21,10 @@ def train(cfg: DictConfig):
     logger.init_logger("ppsci", osp.join(cfg.output_dir, f"{cfg.mode}.log"), "info")

     # set initial states for convlstm
-    num_convlstm = 1
+    NUM_CONVLSTM = cfg.num_convlstm
     (h0, c0) = (paddle.randn((1, 128, 16, 16)), paddle.randn((1, 128, 16, 16)))
     initial_state = []
-    for _ in range(num_convlstm):
+    for _ in range(NUM_CONVLSTM):
         initial_state.append((h0, c0))

     # grid parameters
@@ -111,9 +110,7 @@ def _transform_out(_in, _out):
     # initialize solver
     scheduler = ppsci.optimizer.lr_scheduler.Step(**cfg.TRAIN.lr_scheduler)()

-    # temporary, better than scheduler, to align the code, use:
-    # optimizer = ppsci.optimizer.Adam(scheduler)(model)
-    optimizer = ppsci.optimizer.Adam(cfg.TRAIN.lr_scheduler.learning_rate)(model)
+    optimizer = ppsci.optimizer.Adam(scheduler)(model)
     solver = ppsci.solver.Solver(
         model,
         constraint_pde,
@@ -133,11 +130,6 @@ def _transform_out(_in, _out):
     model.register_output_transform(functions.tranform_output_val)
     solver.eval()

-    # save the model
-    checkpoint_path = os.path.join(cfg.output_dir, "phycrnet.pdparams")
-    layer_state_dict = model.state_dict()
-    paddle.save(layer_state_dict, checkpoint_path)
-

 def evaluate(cfg: DictConfig):
     # set random seed for reproducibility
@@ -146,10 +138,10 @@ def evaluate(cfg: DictConfig):
     logger.init_logger("ppsci", osp.join(cfg.output_dir, f"{cfg.mode}.log"), "info")

     # set initial states for convlstm
-    num_convlstm = 1
+    NUM_CONVLSTM = cfg.num_convlstm
     (h0, c0) = (paddle.randn((1, 128, 16, 16)), paddle.randn((1, 128, 16, 16)))
     initial_state = []
-    for _ in range(num_convlstm):
+    for _ in range(NUM_CONVLSTM):
         initial_state.append((h0, c0))

     # grid parameters
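Note: the optimizer change feeds the Step scheduler into Adam instead of a fixed learning rate, so the rate actually decays during training. A paddle-level sketch of the same pattern; the hyperparameter values here are illustrative, not taken from the configs:

import paddle

scheduler = paddle.optimizer.lr.StepDecay(learning_rate=1e-3, step_size=100, gamma=0.97)
model = paddle.nn.Linear(4, 4)
optimizer = paddle.optimizer.Adam(learning_rate=scheduler, parameters=model.parameters())
# ppsci.optimizer.Adam(scheduler)(model) builds effectively this optimizer for the PhyCRNet model.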
mkdocs.yml (2 changes: 1 addition & 1 deletion)
@@ -59,9 +59,9 @@ nav:
       - LDC2D_unsteady: zh/examples/ldc2d_unsteady.md
       - Labelfree_DNN_surrogate: zh/examples/labelfree_DNN_surrogate.md
       - NSFNets: zh/examples/nsfnet.md
+      - PhyCRNet: zh/examples/phycrnet.md
       - ShockWave: zh/examples/shock_wave.md
       - tempoGAN: zh/examples/tempoGAN.md
-      - PhyCRNet: zh/examples/phycrnet.md
       - ViV: zh/examples/viv.md
   - 结构:
       - Biharmonic2D: zh/examples/biharmonic2d.md
ppsci/arch/phycrnet.py (96 changes: 48 additions & 48 deletions)
@@ -8,7 +8,7 @@
 from ppsci.arch import base

 # define the high-order finite difference kernels
-lapl_op = [
+LALP_OP = [
     [
         [
             [0, 0, -1 / 12, 0, 0],
@@ -20,7 +20,7 @@
     ]
 ]

-partial_y = [
+PARTIAL_Y = [
     [
         [
             [0, 0, 0, 0, 0],
@@ -32,7 +32,7 @@
     ]
 ]

-partial_x = [
+PARTIAL_X = [
     [
         [
             [0, 0, 1 / 12, 0, 0],
@@ -44,8 +44,9 @@
     ]
 ]

+
 # specific parameters for burgers equation
-def initialize_weights(module):
+def _initialize_weights(module):
     if isinstance(module, nn.Conv2D):
         c = 1.0  # 0.5
         initializer = nn.initializer.Uniform(
@@ -75,15 +76,15 @@ class PhyCRNet(base.Arch):
     Examples:
         >>> import ppsci
         >>> model = ppsci.arch.PhyCRNet(
-            input_channels=2,
-            hidden_channels=[8, 32, 128, 128],
-            input_kernel_size=[4, 4, 4, 3],
-            input_stride=[2, 2, 2, 1],
-            input_padding=[1, 1, 1, 1],
-            dt=0.002,
-            num_layers=[3, 1],
-            upscale_factor=8
-        )
+        ...     input_channels=2,
+        ...     hidden_channels=[8, 32, 128, 128],
+        ...     input_kernel_size=[4, 4, 4, 3],
+        ...     input_stride=[2, 2, 2, 1],
+        ...     input_padding=[1, 1, 1, 1],
+        ...     dt=0.002,
+        ...     num_layers=[3, 1],
+        ...     upscale_factor=8
+        ... )
     """

     def __init__(
@@ -118,32 +119,32 @@ def __init__(
         self.num_convlstm = num_layers[1]

         # encoder - downsampling
-        for i in range(self.num_encoder):
-            name = "encoder{}".format(i)
-            cell = encoder_block(
-                input_channels=self.input_channels[i],
-                hidden_channels=self.hidden_channels[i],
-                input_kernel_size=self.input_kernel_size[i],
-                input_stride=self.input_stride[i],
-                input_padding=self.input_padding[i],
-            )
-
-            setattr(self, name, cell)
-            self._all_layers.append(cell)
+        self.encoder = paddle.nn.LayerList(
+            [
+                encoder_block(
+                    input_channels=self.input_channels[i],
+                    hidden_channels=self.hidden_channels[i],
+                    input_kernel_size=self.input_kernel_size[i],
+                    input_stride=self.input_stride[i],
+                    input_padding=self.input_padding[i],
+                )
+                for i in range(self.num_encoder)
+            ]
+        )

         # ConvLSTM
-        for i in range(self.num_encoder, self.num_encoder + self.num_convlstm):
-            name = "convlstm{}".format(i)
-            cell = ConvLSTMCell(
-                input_channels=self.input_channels[i],
-                hidden_channels=self.hidden_channels[i],
-                input_kernel_size=self.input_kernel_size[i],
-                input_stride=self.input_stride[i],
-                input_padding=self.input_padding[i],
-            )
-
-            setattr(self, name, cell)
-            self._all_layers.append(cell)
+        self.ConvLSTM = paddle.nn.LayerList(
+            [
+                ConvLSTMCell(
+                    input_channels=self.input_channels[i],
+                    hidden_channels=self.hidden_channels[i],
+                    input_kernel_size=self.input_kernel_size[i],
+                    input_stride=self.input_stride[i],
+                    input_padding=self.input_padding[i],
+                )
+                for i in range(self.num_encoder, self.num_encoder + self.num_convlstm)
+            ]
+        )

         # output layer
         self.output_layer = nn.Conv2D(
@@ -154,7 +155,7 @@ def __init__(
         self.pixelshuffle = nn.PixelShuffle(self.upscale_factor)

         # initialize weights
-        self.apply(initialize_weights)
+        self.apply(_initialize_weights)
         initializer_0 = paddle.nn.initializer.Constant(0.0)
         initializer_0(self.output_layer.bias)
         self.enable_transform = True
@@ -175,22 +176,20 @@ def forward(self, x):
             xt = x

             # encoder
-            for i in range(self.num_encoder):
-                name = "encoder{}".format(i)
-                x = getattr(self, name)(x)
+            for encoder in self.encoder:
+                x = encoder(x)

             # convlstm
-            for i in range(self.num_encoder, self.num_encoder + self.num_convlstm):
-                name = "convlstm{}".format(i)
+            for i, LSTM in enumerate(self.ConvLSTM):
                 if step == 0:
-                    (h, c) = getattr(self, name).init_hidden_tensor(
+                    (h, c) = LSTM.init_hidden_tensor(
                         prev_state=self.initial_state[i - self.num_encoder]
                     )
                     internal_state.append((h, c))

                 # one-step forward
                 (h, c) = internal_state[i - self.num_encoder]
-                x, new_c = getattr(self, name)(x, h, c)
+                x, new_c = LSTM(x, h, c)
                 internal_state[i - self.num_encoder] = (x, new_c)

             # output
@@ -464,15 +463,15 @@ def __init__(self, dt, dx):

         # spatial derivative operator
         self.laplace = Conv2DDerivative(
-            der_filter=lapl_op, resol=(dx**2), kernel_size=5, name="laplace_operator"
+            der_filter=LALP_OP, resol=(dx**2), kernel_size=5, name="laplace_operator"
         )

         self.dx = Conv2DDerivative(
-            der_filter=partial_x, resol=(dx * 1), kernel_size=5, name="dx_operator"
+            der_filter=PARTIAL_X, resol=(dx * 1), kernel_size=5, name="dx_operator"
         )

         self.dy = Conv2DDerivative(
-            der_filter=partial_y, resol=(dx * 1), kernel_size=5, name="dy_operator"
+            der_filter=PARTIAL_Y, resol=(dx * 1), kernel_size=5, name="dy_operator"
         )

         # temporal derivative operator
@@ -517,6 +516,7 @@ def get_phy_Loss(self, output):
         assert laplace_u.shape == u.shape
         assert laplace_v.shape == v.shape

+        # Reynolds number
         R = 200.0

         # 2D burgers eqn
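Note: the constructor refactor replaces setattr-registered sublayers with paddle.nn.LayerList. Both approaches register parameters, but a LayerList keeps the sublayers directly iterable, which the new forward() relies on. A minimal self-contained sketch; the layer sizes are illustrative, not PhyCRNet's:

import paddle
from paddle import nn

class TinyEncoder(nn.Layer):
    def __init__(self, num_blocks: int = 3):
        super().__init__()
        # LayerList registers each Conv2D, so .parameters() sees them all
        self.blocks = nn.LayerList(
            [nn.Conv2D(2, 2, kernel_size=3, padding=1) for _ in range(num_blocks)]
        )

    def forward(self, x):
        for block in self.blocks:  # iterate directly, no getattr by name
            x = block(x)
        return x

model = TinyEncoder()
out = model(paddle.randn((1, 2, 16, 16)))
print(out.shape)  # [1, 2, 16, 16]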
