Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【Hackathon 5th No.57】Neural networks for topology optimization #597

Merged
merged 15 commits into from Nov 21, 2023

Conversation

NKNaN
Copy link
Contributor

@NKNaN NKNaN commented Oct 23, 2023

PR types

Others

PR changes

Others

Describe

按照意见修改了之前的pr: #559

@paddle-bot
Copy link

paddle-bot bot commented Oct 23, 2023

Thanks for your contribution!

@lijialin03
Copy link
Contributor

辛苦大佬提交代码~
请参照 模型复现流程及验收标准 检查一下代码
另外代码需要提交到 example/案例名称 下,文档提交到docs/zh/examples,规范也请参照上面链接或其他example的文档

nn.Sigmoid(),
)

def forward(self, x):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

参照UNetEx,是否可以尽量使用for循环,或调用UNetEx中的encodedecode来写forward?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

x1 = x[self.input_keys[0]][:, k, :, :]
x2 = x[self.input_keys[0]][:, k - 1, :, :]
x = paddle.stack((x1, x1 - x2), axis=1)
# Layer 1 (bs, 2, 40, 40) -> (bs, 16, 40, 40)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

可以删除注释

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改


def augmentation(input_dict, label_dict, weight_dict=None):
"""Apply random transformation from D4 symmetry group
# Arguments
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

type hint的格式不对,可参考ppsci下文件修改

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

@@ -0,0 +1,409 @@
{
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

当前github代码中不放ipynb,需要改成代码。最终案例文档会附加AI Studio链接,可以把这个ipynb放到AI Studio项目中


if __name__ == "__main__":
### CASE 1: poisson(5)
SIMP_stop_point_sampler = poisson_sampler(5)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这几个例子是不是除了这一行都一样?如果是的话,只写一个文件就行,其他的备注写一下,或者写一个函数,如
image

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

VOL_COEFF = 1
LEARNING_RATE = 0.001 / (1 + NUM_EPOCHS // 15)
ITERS_PER_EPOCH = int(N_SAMPLE * TRAIN_TEST_RATIO / BATCH_SIZE)
NUM_PARAMS = 192113 # the number given in paper
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

参考其他的案例的开头,如example/bracket.py,增加一些必要部分,调整一下参数顺序,以及加一点点说明参数是干什么的备注,不用加太多

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

@@ -0,0 +1,31 @@
import os
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个文件在case中似乎没有被import,看Readme是生成数据用的,代码中没有OriginalData,所以这个其实也跑不了。数据集可以先上传到AI Studio,这个文件就删掉不用提交了

# optimizer
optimizer = ppsci.optimizer.Adam(learning_rate=LEARNING_RATE, epsilon=1e-07)(
model
) # epsilon = 1e-07 is the default in tf
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这句注释可以删掉

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

iters_per_epoch=ITERS_PER_EPOCH,
)

solver.train()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

参照其他案例如bracket,在solver这边加几句注释,不用加多

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

@@ -0,0 +1,38 @@
# TopOpt
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

正式提交的时候需要写md文档,但不是README

)
),
dtype="float32",
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

w00 = paddle.sum(paddle.multiply(paddle.equal(paddle.round(output), 0.0),paddle.equal(paddle.round(y), 0.0),),dtype=paddle.get_default_dtype())应该就行

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

paddle.equal(paddle.round(y), 1.0),
)
),
dtype="float32",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

paddle.equal(paddle.round(y), 0.0),
)
),
dtype="float32",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

paddle.equal(paddle.round(y), 1.0),
)
),
dtype="float32",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

"_vol_coeff",
str(VOL_COEFF),
]
),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

可以写成 OUTPUT_DIR = (f"{OUTPUT_DIR}/{sampler_key}{num}_vol_coeff{VOL_COEFF}" if num is not Noneelse f"{OUTPUT_DIR}/{sampler_key}_vol_coeff{VOL_COEFF}")

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

model_path,
"".join([model_name, "_vol_coeff", str(VOL_COEFF)]),
"checkpoints",
"latest",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

model_path = f"{model_path}/{model_name}_vol_coeff{VOL_COEFF}/checkpoints/latest"

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

from ppsci.utils import logger


def evaluation_and_plot(dataloader):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个函数里eval的部分是否可以用validator写?用ppsci.loss.FunctionalLoss这个API自定义一下loss,想存4个实验的结果也可以写成类似于这样?
image

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

current_iou_results = []

# only evaluate for NUM_VAL_STEP times of iteration
for x, y, _ in iter(dataloader):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这边只是取了n个batch的数据?n相当于iteration?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

是的,这里evaluate的逻辑是,对于每一个训练的模型,用16组不同的数据做16次评估(不同的数据是指分别取原始数据的第5,10,15,20,...,80通道作为输入,对应代码中的 stop_iter , stop_iter 这里指SIMP的初始迭代次数——原始数据有100个通道对应的是SIMP算法100次的迭代结果,这个模型想做的就是用SIMP中间某一步的迭代结果直接预测最后一步的迭代结果,所以他是这样来评估的),每一次评估的时候只取 NUM_VAL_STEP 个batch的数据。

iou_results = []

# evaluation for different fixed iteration stop times
for stop_iter in iterations_stop_times:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这边相当于epoch?



# NCHW data format
class TopOptNN(ppsci.arch.UNetEx):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

参照MLP等网络加点注释

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

@lijialin03
Copy link
Contributor

辛苦~
有几点需要再修改一下:

  1. 参照【快乐开源】基于hydra的案例改造计划 ,需要将案例改为hydra格式
  2. 参照hpinns案例,按照功能,将除了网络和训练/eval的其他代码,尽量整理在一两个文件中
  3. 案例训练代码文件名字改为案例名字,eval代码用hydra格式后,应该能和训练代码整合在同一个文件中?
  4. 将数据集上传AI Studio,以便我这边review效果
  5. md文档如果写了也可以一起提交,当然或者后面再单独提交pr也行
    谢谢

@NKNaN
Copy link
Contributor Author

NKNaN commented Nov 2, 2023

辛苦~ 有几点需要再修改一下:

  1. 参照【快乐开源】基于hydra的案例改造计划 ,需要将案例改为hydra格式
  2. 参照hpinns案例,按照功能,将除了网络和训练/eval的其他代码,尽量整理在一两个文件中
  3. 案例训练代码文件名字改为案例名字,eval代码用hydra格式后,应该能和训练代码整合在同一个文件中?
  4. 将数据集上传AI Studio,以便我这边review效果
  5. md文档如果写了也可以一起提交,当然或者后面再单独提交pr也行
    谢谢
  1. 已修改
  2. 已修改
  3. 已修改
  4. 数据集:https://aistudio.baidu.com/datasetdetail/245115
  5. md已提交

Copy link
Contributor

@lijialin03 lijialin03 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

辛苦~

  1. 代码有一些需要修改的地方。
  2. 需要确定一下指标,以便进行ci检查,一般是eval的指标,如 l2 error,acc等。同时为了确认复现效果(参考官网),将指标和原代码的对比贴到comment中(不是md文档)吧,格式类似这样:
total acc possion5 xxx
paper xx xx xx
ppsci xx xx xx
diff xx% xx% xx%
  1. 文档内容大体没什么问题,主要需要检查一下用英文符号、公式是否正确显示、改代码行数(等代码完全确认后)
  2. 写完文档后提交之前,可以通过下图方法(可能需要pip安装一些包),用mkdocs提前预览显示效果,以确保正确显示
image 6. 另外需要在PaddleScience/mkdocs.yml中添加相应部分,以确保该案例文档可以通过官网主页进入

image
谢谢

hydra:
run:
# dynamic output directory
dir: outputs_topopt/
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

改成:outputs_topopt/${now:%Y-%m-%d}/${now:%H-%M-%S}/${hydra.job.override_dirname}
这样子会生成类似下面这样的目录结构,就不用担心训练结果被覆盖了
image

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

dir: outputs_topopt/
job:
name: ${mode} # name of logfile
chdir: false # keep current working direcotry unchaned
Copy link
Contributor

@lijialin03 lijialin03 Nov 2, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

参考bracket,增加exclude_keys,这样才能允许用户指定参数,如用 "EVAL.pretrained_model_path"指定预训练模型

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

layers: 2

# training settings
TRAIN:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

参照bracket,这边主要放跟训练有直接关系的值,比如epoch等,然后改一下名字。其他的值放到TRAIN这一层外面吧。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

shuffle: true

# evaluation settings
EVAL:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同样的,不是直接相关的值放外面一层就行

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

CASE_PARAM: [[Poisson, 5], [Poisson, 10], [Poisson, 30], [Uniform, null]]

# set data path
DATA_PATH: ./Dataset/PreparedData/top_dataset.h5
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

改成./datasets/top_dataset.h5吧

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

--8<--
```

优化器选用 Adam,训练代码如下:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这边还是分"优化器构建"里写一下

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

--8<--
```

### 3.9 metric构建
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

loss和metric构建,loss既然也是自定义的,也介绍一下

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

本案例选择 binary accuracy 和 IoU 进行评估:
$\text{Bin. Acc.} = \frac{w_{00}+w_{11}}{n_{0}+n_{1}}$
$\text{IoU} = \frac{1}{2}\left[\frac{w_{00}}{n_{0}+w_{10}} + \frac{w_{11}}{n_{1}+w_{01}}\right]$
$n_{0} = w_{00} + w_{01}, \quad n_{1} = w_{10} + w_{11}$
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

公式还是不居中

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

--8<--
```

### 3.10 评估模型
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

主要介绍一下评估是怎么做的,评估的是哪些指标。然后可以在介绍之后加个 #### 3.10.1 评估器构建,把validator的代码单独写一下,避免展示重复代码

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

--8<--
```

## 5. 结果展示
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

参照bracket,增加一点对图片内容的介绍,主要是让其他人能看懂。
参照tempoGAN,把指标(eval结果)用表格放上

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

Copy link
Contributor Author

@NKNaN NKNaN left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

感谢review!

hydra:
run:
# dynamic output directory
dir: outputs_topopt/
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

dir: outputs_topopt/
job:
name: ${mode} # name of logfile
chdir: false # keep current working direcotry unchaned
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

CASE_PARAM: [[Poisson, 5], [Poisson, 10], [Poisson, 30], [Uniform, null]]

# set data path
DATA_PATH: ./Dataset/PreparedData/top_dataset.h5
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

layers: 2

# training settings
TRAIN:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

shuffle: true

# evaluation settings
EVAL:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

# define metric
def val_metric(output_dict, label_dict, weight_dict=None):
output = output_dict["output"]
y = label_dict["output"]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

# evaluation for different fixed iteration stop times
for stop_iter in iterations_stop_times:
# only evaluate for NUM_VAL_STEP times of iteration
X_data, Y_data = generate_train_test(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

"drop_last": cfg.EVAL.drop_last,
"shuffle": cfg.EVAL.shuffle,
},
"transforms": (
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

if cfg.mode == "train":
train(cfg)
elif cfg.mode == "eval":
evaluate_and_plot(cfg)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

iou_results_summary["thresholding"] = th_iou_results

# plot and save figures
plt.figure(figsize=(12, 6))
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改用ppsci.utils.misc.plot_curve()

label_dict (Dict[str, np.adarray]): label dict of np.adarray size `(batch_size, 1, height, width)`
weight_dict (Dict[str, np.adarray]): weight dict if any

Returns:
Copy link
Contributor

@lijialin03 lijialin03 Nov 6, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Returns删掉吧

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已删除

label_dict: Dict[str, np.ndarray],
weight_dict: Dict[str, np.ndarray] = None,
) -> Tuple[
Dict[str, paddle.Tensor], Dict[str, paddle.Tensor], Dict[str, paddle.Tensor]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个函数的返回dict里的值应该是np.darray形式,其他案例写错了

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

Returns:
Tuple[Dict[str, paddle.Tensor], Dict[str, paddle.Tensor], Dict[str, paddle.Tensor]]: (transformed_input_dict, transformed_label_dict, transformed_weight_dict)
"""
inputs = paddle.to_tensor(input_dict["input"], dtype=paddle.get_default_dtype())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不用变tensor了,不然会报错

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

new_perm = list(range(len(inputs.shape)))
new_perm[1], new_perm[2] = new_perm[2], new_perm[1]
inputs = paddle.transpose(inputs, perm=new_perm)
labels = paddle.transpose(labels, perm=new_perm)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

paddle.xx的都需要改成np的

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

cfg, model, model_path, data_iters, data_targets, iterations_stop_times
)

model_name = model_path.split("\\")[-3].split("_")[0]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个"\\"只是windows的,linux就会报错。
还写成类似sampler_key, num = cfg.CASE_PARAM[i]这样的吧

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

嗯嗯,那这样的话就让用户传一个 EVAL.pretrained_model_path 然后这样写?
image

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这样确实很合理,但是参考其他单一模型的案例,其实预训练模型是从网上下载的,把train和evaluate分开,也是为了让用户可以不必一定先训练,而是可以直接加载训练好的模型
image
所以如果要一次性加载多个,虽然很难受,但是似乎必须让用户在运行时输入命令后面加EVAL.pretrained_model_path=[path1,path2,...]或者EVAL.pretrained_model_path_possion5=path1 EVAL.pretrained_model_path_xx=path2,而不能直接指定xx/checkpoints/latest这样的路径

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

明白了,那改成这样可以吗?让用户在运行时输入命令后面加EVAL.pretrained_model_path=[path1,path2,...],同时加上对应的名字EVAL.case_name=[case_name1, case_name2,...]这样
image

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

明白了,那改成这样可以吗?让用户在运行时输入命令后面加EVAL.pretrained_model_path=[path1,path2,...],同时加上对应的名字EVAL.case_name=[case_name1, case_name2,...]这样 image

这里现在改成了下面这样
image
image
可以这样输入参数:

python topopt.py 'mode=eval' 'EVAL.pretrained_model_path_dict={"Uniform": "path1",  "Poisson5": "path2",  "Poisson10": "path3",  "Poisson30": "path4"}'

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

嗯嗯,这样可以的

current_iou_results = []

# only calculate for NUM_VAL_STEP times of iteration
for _ in range(10):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

10是哪里来的数字呀,是cfg.EVAL.num_val_step吗?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

哦哦这里应该是 cfg.EVAL.num_val_step,已修改

acc_results_summary["thresholding"] = th_acc_results
iou_results_summary["thresholding"] = th_iou_results

# # plot and save figures
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不用的代码就删掉吧

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已删除

data_iters, data_targets, 1.0, cfg.EVAL.batch_size * cfg.EVAL.num_val_step
)

sup_validator = ppsci.validate.SupervisedValidator(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

validator一定要放这里吗?放这里运行的时候会报错,错误比较像是程序运行结束,但是data_loader还在等着读取数据
如果实在不行,evaluate里不用validator也行,可以参考deephpms最新的evaluate代码

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

是因为 validator 放在循环里面了吗,如果要把 validator 放在循环外面就需要它包含所有的10000个数据,但是我不知道怎么设置 solver 让 solver.eval() 运行的时候只跑 cfg.EVAL.num_val_step 个batch (比如10个)。我试了如果把 validator 放在循环外面,然后 solver.epochs 设为1,solver.iters_per_epoch 设为 cfg.EVAL.num_val_step (比如10),validator 里 drop_last 设为 True,solver.eval() 还是会跑 10000/batch_size 个 iteration; solver.epochs 设为 cfg.EVAL.num_val_step,solver.iters_per_epoch 设为1,solver.eval() 还是会跑 10000/batch_size 个 iteration

Copy link
Contributor

@lijialin03 lijialin03 Nov 9, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

先说一下结论,我们又查了一下发现,只要设置一下这个就不会报错了,估计是底层数据读取的时候自动开了多线程
image
然后说一下,eval的逻辑是运行n(1)个epoch,每个epoch 跑 iterations(10)次,每次会从所有数据(160)中读取batch_size(16)个
我理解现在是每个 sampler 调用 solver.eval() 80/5=16次,每次的iteration是10(cfg.EVAL.num_val_step),batch_size是16(cfg.EVAL.batch_size),数据量是10x16=160?其实validator放外面之后,还按照现在这样设置,只是在里面调solver.eval应该就行。但是这样有个问题是按照原先的写法,validator输入的数据是每次循环新生成的,这部分数据应该是所有数据随机后,其中的前160*1.0个,也就是每次循环都会改变,使用不同的数据,这个是个问题
所以就按照第一行,加一下"num_workers": 0

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

另外上次漏了,现在代码里有几个dtype="float32",都改成dtype=paddle.get_default_dtype()

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

是这样的。好的,麻烦你们啦~

Copy link
Contributor

@lijialin03 lijialin03 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

辛苦~
代码部分看起来差不多了,之后应该完善一下文档和指标就行了

CASE_PARAM: [[Poisson, 5], [Poisson, 10], [Poisson, 30], [Uniform, null]]

# set data path
DATA_PATH: ./Dataset/top_dataset.h5
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

跟其他案例统一,改成./datasets/top_dataset.h5吧

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

learning_rate: 0.001
epsilon:
log_loss: 0.0000001
optimizer: 0.0000001
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

太长了改成科学技术法,1.0e-7吧

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

@NKNaN
Copy link
Contributor Author

NKNaN commented Nov 14, 2023

辛苦~ 代码部分看起来差不多了,之后应该完善一下文档和指标就行了

感谢!

又发现了一个问题:在AI Studio上拉取最新的 ppsci 之后构建 SupervisedConstraint 这一步就有问题了,报错是这样子
image

看起来是这个更新导致的

把报错里面的 data 打印出来是这种形式:( ( {'input': input_dataset}, {'output': output_dataset}, {} ), )
看起来多套了一层 tuple,不知道这块要怎么改一下好

@lijialin03
Copy link
Contributor

lijialin03 commented Nov 15, 2023

把报错里面的 data 打印出来是这种形式:( ( {'input': input_dataset}, {'output': output_dataset}, {} ), ) 看起来多套了一层 tuple,不知道这块要怎么改一下好

之前有这个错误,是代码中多加了一个括号导致的,这个pr已经去掉了括号pr619,改成了:
image
尝试使用git pull --rebase upstream develop重新拉取一下上游仓库试试吧

@NKNaN
Copy link
Contributor Author

NKNaN commented Nov 15, 2023

尝试使用git pull --rebase upstream develop重新拉取一下上游仓库试试吧

已解决,感谢!

@NKNaN
Copy link
Contributor Author

NKNaN commented Nov 15, 2023

指标和原代码的对比:

bin_acc eval_dataset_ch_5 eval_dataset_ch_10 eval_dataset_ch_15 eval_dataset_ch_20 eval_dataset_ch_30 eval_dataset_ch_40 eval_dataset_ch_50 eval_dataset_ch_60 eval_dataset_ch_80
ppsci-Poisson5 0.947 0.961 0.970 0.974 0.980 0.982 0.983 0.985 0.987
paper-Poisson5 0.958 0.973 0.977 0.979 0.982 0.984 0.985 0.986 0.987
diff-Poisson5 1.14% 1.13% 0.69% 0.49% 0.19% 0.14% 0.13% 0.10% 0.02%
bin_acc eval_dataset_ch_5 eval_dataset_ch_10 eval_dataset_ch_15 eval_dataset_ch_20 eval_dataset_ch_30 eval_dataset_ch_40 eval_dataset_ch_50 eval_dataset_ch_60 eval_dataset_ch_80
ppsci-Poisson10 0.945 0.970 0.974 0.979 0.984 0.987 0.988 0.989 0.990
paper-Poisson10 0.954 0.976 0.981 0.984 0.987 0.989 0.990 0.990 0.990
diff-Poisson10 0.87% 0.58% 0.66% 0.42% 0.24% 0.20% 0.20% 0.01% 0.04%
bin_acc eval_dataset_ch_5 eval_dataset_ch_10 eval_dataset_ch_15 eval_dataset_ch_20 eval_dataset_ch_30 eval_dataset_ch_40 eval_dataset_ch_50 eval_dataset_ch_60 eval_dataset_ch_80
ppsci-Poisson30 0.925 0.959 0.973 0.983 0.988 0.989 0.991 0.992 0.993
paper-Poisson30 0.927 0.963 0.978 0.985 0.990 0.992 0.994 0.995 0.996
diff-Poisson30 0.13% 0.36% 0.44% 0.18% 0.17% 0.28% 0.24% 0.24% 0.24%
bin_acc eval_dataset_ch_5 eval_dataset_ch_10 eval_dataset_ch_15 eval_dataset_ch_20 eval_dataset_ch_30 eval_dataset_ch_40 eval_dataset_ch_50 eval_dataset_ch_60 eval_dataset_ch_80
ppsci-Uniform 0.941 0.967 0.971 0.972 0.982 0.984 0.989 0.990 0.992
paper-Uniform 0.947 0.968 0.977 0.982 0.987 0.990 0.993 0.994 0.996
diff-Uniform 0.62% 0.07% 0.53% 0.95% 0.46% 0.55% 0.37% 0.33% 0.38%
iou eval_dataset_ch_5 eval_dataset_ch_10 eval_dataset_ch_15 eval_dataset_ch_20 eval_dataset_ch_30 eval_dataset_ch_40 eval_dataset_ch_50 eval_dataset_ch_60 eval_dataset_ch_80
ppsci-Poisson5 0.899 0.926 0.942 0.949 0.961 0.965 0.967 0.970 0.974
paper-Poisson5 0.920 0.947 0.954 0.960 0.965 0.969 0.971 0.973 0.974
diff-Poisson5 2.22% 2.14% 1.24% 1.07% 0.41% 0.33% 0.32% 0.26% 0.08%
iou eval_dataset_ch_5 eval_dataset_ch_10 eval_dataset_ch_15 eval_dataset_ch_20 eval_dataset_ch_30 eval_dataset_ch_40 eval_dataset_ch_50 eval_dataset_ch_60 eval_dataset_ch_80
ppsci-Poisson10 0.896 0.942 0.950 0.960 0.969 0.974 0.976 0.980 0.981
paper-Poisson10 0.911 0.953 0.964 0.969 0.974 0.978 0.980 0.980 0.981
diff-Poisson10 1.54% 1.11% 1.42% 0.88% 0.45% 0.37% 0.38% 0.01% 0.00%
iou eval_dataset_ch_5 eval_dataset_ch_10 eval_dataset_ch_15 eval_dataset_ch_20 eval_dataset_ch_30 eval_dataset_ch_40 eval_dataset_ch_50 eval_dataset_ch_60 eval_dataset_ch_80
ppsci-Poisson30 0.861 0.922 0.948 0.967 0.976 0.978 0.983 0.985 0.987
paper-Poisson30 0.864 0.929 0.957 0.970 0.981 0.985 0.988 0.990 0.992
diff-Poisson30 0.26% 0.73% 0.85% 0.30% 0.41% 0.64% 0.47% 0.47% 0.47%
iou eval_dataset_ch_5 eval_dataset_ch_10 eval_dataset_ch_15 eval_dataset_ch_20 eval_dataset_ch_30 eval_dataset_ch_40 eval_dataset_ch_50 eval_dataset_ch_60 eval_dataset_ch_80
ppsci-Uniform 0.888 0.936 0.945 0.946 0.965 0.969 0.978 0.981 0.984
paper-Uniform 0.900 0.939 0.955 0.964 0.975 0.981 0.986 0.988 0.992
diff-Uniform 1.25% 0.24% 1.02% 1.78% 0.97% 1.17% 0.73% 0.64% 0.75%

@lijialin03
Copy link
Contributor

代码冲突了,看mkdocs.yml文件会有这么多修改,应该是拉的不是最新的,要拉最新分支修改。看你仓库里这个分支的commit,最后一个非你进行的修改在10.23,需要更新一下。

```


### 3.4 transform构建
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

现在文档里的transform都是指网络输入输出transform,所以这边改成data transform吧

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改


### 3.6 采样器构建

原始数据有100个通道对应的是 SIMP 算法 100 次的迭代结果,本案例模型目标是用 SIMP 中间某一步的迭代结果直接预测 SIMP 最后一步的迭代结果,而论文原始代码中的模型输入是原始数据对通道进行采样后的数据,为应用 PaddleScience API,本案例将采样步骤放入模型的 forward 方法中,所以需要传入不同的采样器。
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

文档主要是把内容介绍清楚就行,不用太强调跟原论文的比较,修改一下这边,只将现在的代码是怎么做的,做了什么,为什么这么做,讲清楚就行

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

<figcaption>IoU结果</figcaption>
</figure>

结果与[原始代码结果](https://github.com/ISosnovik/nn4topopt/blob/master/results.ipynb)基本一致
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上,这边去掉这句吧

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改


结果与[原始代码结果](https://github.com/ISosnovik/nn4topopt/blob/master/results.ipynb)基本一致

此外将这些计算的指标与论文中展示的对应指标 (`table 1` 与 `table2` 中的) 对比,指标相对误差均小于 10%,以下表格是所有的指标计算结果:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这边也是,改成“用表格表示上图指标为:”这类的吧

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

dtype=paddle.get_default_dtype(),
)
n0 = paddle.add(w01, w00)
n1 = paddle.add(w11, w10)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这边wxx/nx等变量名改一下吧,让名字有点具体含义,我看这边是求iou计算里的TP、FP、TN、FN?对照改一下吧

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

@@ -0,0 +1,157 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个文件名字改成小写,topopt里的import也一起改一下

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

learning_rate: 0.001
epsilon:
log_loss: 1.0e-7
optimizer: 1.0e-7
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这两个epsilon是不是调试的时候基本不会改,那就直接在程序里写死吧,不用特别拿出来写到yaml文件里了

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

import paddle
from functions import augmentation
from functions import generate_sampler
from functions import generate_train_test
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

参考tempogan,改成import functions as func_module, 用的时候func_module.augmentation这样吧

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

=== "模型训练命令"

``` sh
python topopt.py
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

数据集我已上传到 https://paddle-org.bj.bcebos.com/paddlescience/datasets/topopt/top_dataset.h5,参考 别的案例比如hpinns 添加一下数据下载命令

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

=== "模型评估命令"

``` sh
python topopt.py 'mode=eval' 'EVAL.pretrained_model_path_dict={"Uniform": "path1", "Poisson5": "path2", "Poisson10": "path3", "Poisson30": "path4"}'
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

path1/2/3/4可以先写上,比如 https://paddle-org.bj.bcebos.com/paddlescience/models/topopt/uniform_pretrained.pdparams之后我再上传就行
如果你现在有现成的,方便的话可以传到aistudio上放数据集那里吗,我下下来上传,没有的话就我之后跑一下

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改,模型放到aistudio数据集那里了

Copy link
Collaborator

@HydrogenSulfate HydrogenSulfate left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@HydrogenSulfate HydrogenSulfate merged commit fb6e1bd into PaddlePaddle:develop Nov 21, 2023
4 checks passed
@NKNaN NKNaN deleted the ayase-develop2 branch December 26, 2023 14:57
zhaojiameng pushed a commit to zhaojiameng/PaddleScience that referenced this pull request May 20, 2024
…ePaddle#597)

* update hackthon no. 57

* delete TopOpt

* add example/topopt

* update

* modify hydra

* add md

* fix code

* fix yaml

* fix yaml

* update code

* recommit

* fix code

* fix docs

* remove TopOptModel.py

* add topoptmodel.py
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants