Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Record autotest wrong code #5923

Merged
merged 29 commits into from Aug 19, 2021
Merged

Record autotest wrong code #5923

merged 29 commits into from Aug 19, 2021

Conversation

BBuf
Copy link
Contributor

@BBuf BBuf commented Aug 17, 2021

  • 在自动测试框架失败的时候,自动生成绿色的代码文本。
  • 添加lonflow.log和oneflow.og1p的自动测试。
  • 重写比较运算符的Tensor方法测试的样例。
  • refine Tensor.xxx的一些测试方法,删掉一些重复的测试样例。

以Conv3d为例子,产生的代码文本如下:

Conv3d(in_channels=3, out_channels=3, kernel_size=2, stride=[4, 1, 4], padding=1, dilation=1, groups=3, padding_mode='zeros')
train(True)
to('cpu')
to('cpu')
sum()
backward()

log1p产生的代码文本如下:

to('cuda')
log1p()
sum()
backward()

在pow的module中制造一个bug,然后测试可以发现输出了可以帮助定位BUG的代码文本:

图片

图片

@BBuf BBuf requested a review from daquexian August 17, 2021 09:54
python/oneflow/test/tensor/test_tensor.py Show resolved Hide resolved
@@ -33,6 +33,9 @@
def torch_tensor_to_flow(x):
return flow.tensor(x.cpu().numpy())

note_pytorch_method_names = []
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

如果是作为全局的一个列表来用,我建议用全大写字母,并且放在最上面

@oneflow-ci-bot oneflow-ci-bot self-requested a review August 18, 2021 09:04


def note_print_kwargs(x, y, end=True):
if end == True:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if end == True:
if end:

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

好的


def note_print_kwargs(x, y, end=True):
if end == True:
if type(y) is str:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if type(y) is str:
if isinstance(y, str):

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

好的

if type(y) is str:
print("\033[32m{}='{}'\033[0m".format(x, y), end="")
else:
print("\033[32m{}={}\033[0m".format(x, y), end="")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

可以用 f-string,比 str.format 更易用一些

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

好的

@@ -168,6 +190,7 @@ def dual_method(self, *args, **kwargs):
*pytorch_args, **pytorch_kwargs
)
except Exception as e:
clear_note_fake_program()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这和 python/oneflow/test_utils/automated_test_util/torch_flow_dual_object.py:373 的改动是不是重复了

如果代码存在多个返回点(比如多个 return、抛出异常)的话,用 with 语句块的 __exit__ 代替在每个返回点手动调用 clear_note_fake_program() 吧

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

嗯,是重复了,想了下只需要在程序中保留一个清空函数就可以了。

@daquexian
Copy link
Contributor

有没有可能性把 module 对象的 __call__ 方法的参数也打印出来,以及对 tensor x=torch.tensor([1,2]) 调用 x.pow() 的时候,打印 tensor([1,2]).pow()

@oneflow-ci-bot oneflow-ci-bot removed their request for review August 18, 2021 10:22
@oneflow-ci-bot oneflow-ci-bot self-requested a review August 18, 2021 15:29
@BBuf BBuf added the bug label Aug 18, 2021
@BBuf BBuf requested review from oneflow-ci-bot and removed request for oneflow-ci-bot August 18, 2021 15:53
@oneflow-ci-bot oneflow-ci-bot requested review from oneflow-ci-bot and removed request for oneflow-ci-bot August 18, 2021 17:10
@oneflow-ci-bot oneflow-ci-bot requested review from oneflow-ci-bot and removed request for oneflow-ci-bot August 18, 2021 19:27
@oneflow-ci-bot oneflow-ci-bot requested review from oneflow-ci-bot and removed request for oneflow-ci-bot August 18, 2021 20:38
@oneflow-ci-bot oneflow-ci-bot requested review from oneflow-ci-bot and removed request for oneflow-ci-bot August 18, 2021 21:35
@oneflow-ci-bot oneflow-ci-bot self-requested a review August 18, 2021 23:28
@oneflow-ci-bot oneflow-ci-bot removed their request for review August 19, 2021 00:49
@oneflow-ci-bot oneflow-ci-bot requested review from oneflow-ci-bot and removed request for oneflow-ci-bot August 19, 2021 02:18
@github-actions
Copy link
Contributor

Speed stats:
GPU Name: GeForce GTX 1080 

PyTorch resnet50 time: 139.5ms (= 6977.0ms / 50, input_shape=[16, 3, 224, 224], backward is enabled)
OneFlow resnet50 time: 127.7ms (= 6386.4ms / 50, input_shape=[16, 3, 224, 224], backward is enabled)
Relative speed: 1.09 (= 139.5ms / 127.7ms)

PyTorch resnet50 time: 83.2ms (= 4159.5ms / 50, input_shape=[8, 3, 224, 224], backward is enabled)
OneFlow resnet50 time: 74.2ms (= 3710.7ms / 50, input_shape=[8, 3, 224, 224], backward is enabled)
Relative speed: 1.12 (= 83.2ms / 74.2ms)

PyTorch resnet50 time: 59.3ms (= 2964.6ms / 50, input_shape=[4, 3, 224, 224], backward is enabled)
OneFlow resnet50 time: 47.3ms (= 2365.4ms / 50, input_shape=[4, 3, 224, 224], backward is enabled)
Relative speed: 1.25 (= 59.3ms / 47.3ms)

PyTorch resnet50 time: 48.5ms (= 2424.0ms / 50, input_shape=[2, 3, 224, 224], backward is enabled)
OneFlow resnet50 time: 38.2ms (= 1908.1ms / 50, input_shape=[2, 3, 224, 224], backward is enabled)
Relative speed: 1.27 (= 48.5ms / 38.2ms)

PyTorch resnet50 time: 44.2ms (= 2212.1ms / 50, input_shape=[1, 3, 224, 224], backward is enabled)
OneFlow resnet50 time: 36.9ms (= 1847.3ms / 50, input_shape=[1, 3, 224, 224], backward is enabled)
Relative speed: 1.20 (= 44.2ms / 36.9ms)

@oneflow-ci-bot oneflow-ci-bot removed their request for review August 19, 2021 03:31
@oneflow-ci-bot oneflow-ci-bot merged commit 738480b into master Aug 19, 2021
@oneflow-ci-bot oneflow-ci-bot deleted the record_autotest_wrong_code branch August 19, 2021 03:34
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants