Fix upsample sbp infer bug and add global test #7884

Merged: 44 commits into master from add_var_upsample_global_test, Apr 11, 2022

clackhan
Contributor

Add global tests for var and upsample.

python/oneflow/test/modules/test_consistent_var.py (outdated)

# backward compute result of oneflow is not same with pytorch
@autotest(n=1, auto_backward=False, check_graph=False)
def _test_global_upsample2d_bicubic(test_case, placement, sbp):
Contributor Author

@clackhan clackhan Mar 23, 2022

In bicubic mode, OneFlow's backward results do not match PyTorch's. I'm not sure whether the two implementations differ.

To reproduce, set `auto_backward=True` and run:

python test_consistent_upsample.py --verbose --failfast
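
When two frameworks disagree on a backward result, one way to decide which is closer to correct is to check each against a numerical gradient. The sketch below does this for a 1-D cubic-convolution upsample written in plain Python. It is an illustration only, not OneFlow's or PyTorch's kernel: the coefficient `A = -0.75`, the `align_corners=False` coordinate mapping, the border clamping, and all function names are assumptions made for the example.

```python
import math

A = -0.75  # cubic-convolution coefficient; assumed value for this sketch


def cubic_weights(t):
    """Return the four tap weights for a fractional offset t in [0, 1)."""
    def near(x):  # kernel branch for |x| <= 1
        return ((A + 2.0) * x - (A + 3.0)) * x * x + 1.0

    def far(x):  # kernel branch for 1 < |x| < 2
        return ((A * x - 5.0 * A) * x + 8.0 * A) * x - 4.0 * A

    return [far(t + 1.0), near(t), near(1.0 - t), far(2.0 - t)]


def taps(dst, in_size, out_size):
    """Border-clamped source indices and weights for one output position."""
    scale = in_size / out_size
    src = (dst + 0.5) * scale - 0.5  # assumed align_corners=False mapping
    base = math.floor(src)
    weights = cubic_weights(src - base)
    indices = [min(max(base - 1 + k, 0), in_size - 1) for k in range(4)]
    return indices, weights


def forward(x, out_size):
    out = []
    for dst in range(out_size):
        indices, weights = taps(dst, len(x), out_size)
        out.append(sum(w * x[i] for i, w in zip(indices, weights)))
    return out


def backward(grad_out, in_size):
    """Analytic gradient: scatter each output gradient back through its taps."""
    grad_in = [0.0] * in_size
    for dst, g in enumerate(grad_out):
        indices, weights = taps(dst, in_size, len(grad_out))
        for i, w in zip(indices, weights):
            grad_in[i] += w * g
    return grad_in


def numeric_grad(x, grad_out, eps=1e-3):
    """Central-difference gradient of sum(grad_out * forward(x))."""
    out_size = len(grad_out)

    def loss(v):
        return sum(g * o for g, o in zip(grad_out, forward(v, out_size)))

    grads = []
    for i in range(len(x)):
        hi = list(x)
        lo = list(x)
        hi[i] += eps
        lo[i] -= eps
        grads.append((loss(hi) - loss(lo)) / (2.0 * eps))
    return grads


x = [0.1, -0.4, 0.7, 0.2, -0.9]
grad_out = [0.3, -0.2, 0.5, 0.1, -0.6, 0.4, 0.8, -0.7]
analytic = backward(grad_out, len(x))
numeric = numeric_grad(x, grad_out)
max_diff = max(abs(a - n) for a, n in zip(analytic, numeric))
print(max_diff)
```

Because the upsample is linear in its input, the central difference is exact up to rounding, so `max_diff` should be close to zero whenever the forward and backward passes are consistent with each other.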

Contributor

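We need to confirm whether the results really don't match; PyTorch 1.11 is the reference. If they truly don't match, please post the details at https://github.com/Oneflow-Inc/OneTeam/issues/1207#issuecomment-1073432125 and I'll debug it.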

Contributor Author

@clackhan clackhan Mar 24, 2022

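> We need to confirm whether the results really don't match; PyTorch 1.11 is the reference. If they truly don't match, please post the details at Oneflow-Inc/OneTeam#1207 (comment) and I'll debug it.

After upgrading PyTorch to 1.11 (I had originally tested under 1.10), the backward results still differ. I've posted the details at Oneflow-Inc/OneTeam#1207 (comment).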

Contributor

@BBuf BBuf Mar 29, 2022

@clackhan This bug has been fixed in #7916.

Contributor Author

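> @clackhan This bug has been fixed in https://github.com/Oneflow-Inc/oneflow/pull/7916.

After merging PR #7916, enabling the backward test makes the process abort immediately. With backward disabled there is no problem and the test runs normally. The error output is as follows: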

python test_consistent_upsample.py --verbose --failfast
test_global_upsample2d_bicubic (__main__.TestGlobalUpsample2d) ... Environment has been initialized, this env init will do nothing.
/home/hanbinbin/anaconda3/envs/oneflow/lib/python3.8/site-packages/torch/_tensor.py:1104: UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the .grad field to be populated for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. See github.com/pytorch/pytorch/pull/30531 for more informations. (Triggered internally at  aten/src/ATen/core/TensorBody.h:475.)
  return self._grad
free(): corrupted unsorted chunks
Aborted (core dumped)

Contributor

Strange. I'll take another look.

Contributor Author

> Strange. I'll take another look.

OK.

Comment on lines +79 to +83
@unittest.skip(
"The nearest interpolate operation in pytorch has bug, https://github.com/pytorch/pytorch/issues/65200"
)
@globaltest
def test_global_upsample2d_nearest(test_case):
Contributor Author

@clackhan clackhan Mar 24, 2022

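The PyTorch issue referenced here has been closed, but the test results still don't match PyTorch's. I'm not sure whether the problem is in OneFlow or in PyTorch.

To reproduce, comment out @unittest.skip and run: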

python test_consistent_upsample.py --verbose --failfast

Contributor

This may be related to the fairly old PyTorch version in our CI environment. For now, let's skip it here as well, the same as the local test.
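
If the mismatch did depend on the reference PyTorch version, the skip could be made conditional instead of unconditional. The following is a minimal sketch using only the standard library; the helper `parse_version`, the constant names, the hard-coded version string, and the test class are all hypothetical and are not part of this PR.

```python
import re
import unittest


def parse_version(text):
    """Turn a string such as '1.10.2+cu113' into (1, 10, 2)."""
    match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", text)
    if match is None:
        raise ValueError(f"unrecognized version string: {text!r}")
    # A missing patch component is treated as 0.
    return tuple(int(part) for part in match.groups(default="0"))


# Hypothetical value; a real test would read torch.__version__ instead.
INSTALLED_TORCH_VERSION = "1.10.0+cu113"
BASELINE = (1, 11, 0)  # compatibility baseline assumed for this sketch


class TestUpsampleNearest(unittest.TestCase):
    @unittest.skipIf(
        parse_version(INSTALLED_TORCH_VERSION) < BASELINE,
        "reference PyTorch is older than the 1.11 compatibility baseline",
    )
    def test_nearest(self):
        self.assertTrue(True)  # placeholder for the real comparison
```

With this shape the test would start running automatically once the installed version reaches the baseline, rather than staying skipped until someone removes the decorator.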

Contributor

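The PyTorch compatibility plan includes a preliminary agreement to use PyTorch 1.11 as the compatibility baseline, so presumably CI can be uniformly upgraded to PyTorch 1.11 later?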

@BBuf @hjchen2

Contributor

@caishenghang

Contributor

Shenghang and I ran into many problems during the earlier upgrade attempt and haven't had time to resolve them one by one.

Contributor

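I re-ran the tests and found the cause: PyTorch's CPU and GPU results disagree when the scale factor is not an integer. I did report this bug earlier, but PyTorch closed my issue without fixing it, so let's set this problem aside for now.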
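
The sensitivity to non-integer scale factors can be illustrated in plain Python. The sketch below maps each output index to a source index with `floor(dst * scale)`, once with the scale derived from the tensor sizes and once from the reciprocal of the requested scale factor. Whether PyTorch's CPU and GPU kernels differ in exactly this way is not confirmed by this thread; the function name and the sizes are assumptions chosen for illustration.

```python
import math


def nearest_indices(in_size, out_size, scale):
    """Source index chosen for each output index under floor(dst * scale)."""
    return [min(math.floor(dst * scale), in_size - 1) for dst in range(out_size)]


in_size = 10
scale_factor = 1.75                            # non-integer upsampling factor
out_size = math.floor(in_size * scale_factor)  # floor(17.5) = 17

# Convention 1: scale recomputed from the input and output sizes.
by_sizes = nearest_indices(in_size, out_size, in_size / out_size)
# Convention 2: scale taken as the reciprocal of the requested factor.
by_factor = nearest_indices(in_size, out_size, 1.0 / scale_factor)

# Output positions where the two conventions pick different source elements.
mismatches = [d for d in range(out_size) if by_sizes[d] != by_factor[d]]
print(mismatches)
```

With an integer factor such as 2.0 the two scales coincide, so the index maps are identical; the disagreement only appears once the output size has to be rounded.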


github-actions bot commented Apr 9, 2022

View latest API docs preview at: https://staging.oneflow.info/docs/Oneflow-Inc/oneflow/pr/7884/


github-actions bot commented Apr 9, 2022

CI failed when running job: cpu-misc. PR label automerge has been removed

@github-actions github-actions bot removed the automerge label Apr 9, 2022

github-actions bot commented Apr 9, 2022

Speed stats:


Speed stats:
GPU Name: GeForce GTX 1080 

✔️ OneFlow resnet50 time: 128.8ms (= 12875.5ms / 100, input_shape=[16, 3, 224, 224])
PyTorch resnet50 time: 140.2ms (= 14020.3ms / 100, input_shape=[16, 3, 224, 224])
✔️ Relative speed: 1.09 (= 140.2ms / 128.8ms)

OneFlow resnet50 time: 80.1ms (= 8005.1ms / 100, input_shape=[8, 3, 224, 224])
PyTorch resnet50 time: 84.2ms (= 8415.7ms / 100, input_shape=[8, 3, 224, 224])
✔️ Relative speed: 1.05 (= 84.2ms / 80.1ms)

OneFlow resnet50 time: 49.1ms (= 9822.5ms / 200, input_shape=[4, 3, 224, 224])
PyTorch resnet50 time: 55.3ms (= 11056.8ms / 200, input_shape=[4, 3, 224, 224])
✔️ Relative speed: 1.13 (= 55.3ms / 49.1ms)

OneFlow resnet50 time: 44.3ms (= 8861.1ms / 200, input_shape=[2, 3, 224, 224])
PyTorch resnet50 time: 47.5ms (= 9493.3ms / 200, input_shape=[2, 3, 224, 224])
✔️ Relative speed: 1.07 (= 47.5ms / 44.3ms)

OneFlow resnet50 time: 38.5ms (= 7705.8ms / 200, input_shape=[1, 3, 224, 224])
PyTorch resnet50 time: 38.6ms (= 7720.6ms / 200, input_shape=[1, 3, 224, 224])
✔️ Relative speed: 1.00 (= 38.6ms / 38.5ms)

OneFlow swin dataloader time: 0.250s (= 50.057s / 200, num_workers=1)
PyTorch swin dataloader time: 0.251s (= 50.246s / 200, num_workers=1)
✔️ Relative speed: 1.004 (= 0.251s / 0.250s)

OneFlow swin dataloader time: 0.066s (= 13.257s / 200, num_workers=4)
PyTorch swin dataloader time: 0.068s (= 13.576s / 200, num_workers=4)
✔️ Relative speed: 1.024 (= 0.068s / 0.066s)

OneFlow swin dataloader time: 0.036s (= 7.240s / 200, num_workers=8)
PyTorch swin dataloader time: 0.038s (= 7.600s / 200, num_workers=8)
✔️ Relative speed: 1.050 (= 0.038s / 0.036s)

✔️ OneFlow resnet50 time: 135.5ms (= 13554.3ms / 100, input_shape=[16, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 161.1ms (= 16107.4ms / 100, input_shape=[16, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.19 (= 161.1ms / 135.5ms)

OneFlow resnet50 time: 87.0ms (= 8696.5ms / 100, input_shape=[8, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 98.6ms (= 9857.6ms / 100, input_shape=[8, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.13 (= 98.6ms / 87.0ms)

OneFlow resnet50 time: 61.5ms (= 12293.1ms / 200, input_shape=[4, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 76.2ms (= 15248.5ms / 200, input_shape=[4, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.24 (= 76.2ms / 61.5ms)

OneFlow resnet50 time: 51.9ms (= 10382.6ms / 200, input_shape=[2, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 66.3ms (= 13258.9ms / 200, input_shape=[2, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.28 (= 66.3ms / 51.9ms)

OneFlow resnet50 time: 49.8ms (= 9953.5ms / 200, input_shape=[1, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 62.1ms (= 12425.4ms / 200, input_shape=[1, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.25 (= 62.1ms / 49.8ms)

@mergify mergify bot merged commit d3d7f2c into master Apr 11, 2022
@mergify mergify bot deleted the add_var_upsample_global_test branch April 11, 2022 08:24

6 participants