fix backward index for gamma beta #6149

Merged
merged 14 commits into master from debug_layernorm on Sep 7, 2021
Conversation

@MARD1NO (Contributor) commented Sep 3, 2021

Problem description:

In eager mode, layernorm's forward result matches torch but the backward does not, and the weight and bias grads come back swapped relative to torch.

Initial fix:

Our guess was that the functional code assigned the gradients in reverse order, and indeed the weight and bias gradients were assigned to the wrong parameters; this has been fixed.

  • Remaining issue: the input gradient is still wrong
  • Remaining issue: Autotest

On the relationship between layernorm and batchnorm

Layernorm can be viewed as a variant of batchnorm.

For an input of shape (b, c, t), layernorm over (c, t) is equivalent to reshaping the input to (1, b, c*t) and applying batchnorm(num_features=b) (b = 3 in the example below):

import torch
import torch.nn as nn
import numpy as np

x = np.random.randn(3, 2, 2)

torch_x_tensor1 = torch.Tensor(x)
torch_x_tensor2 = torch.Tensor(x)
torch_x_tensor2 = torch_x_tensor2.reshape((1, 3, -1)) # 1, 3, 4

layernorm = nn.LayerNorm((2, 2))
layernorm_out = layernorm(torch_x_tensor1)
print(layernorm.weight.shape) # its weight should be (2, 2)

print("Layer norm out is: ", layernorm_out)

bn = nn.BatchNorm1d(num_features=3) # its weight has shape (3,), equal to the number of channels
print(bn.weight.shape)
bn_out = bn(torch_x_tensor2)
bn_out = torch.reshape(bn_out, shape=(3, 2, 2))
print("Bn out is: ", bn_out)

However, I don't think the two are strictly equivalent: layernorm's weight and bias have shape (2, 2), while bn's weight has shape (3,).

Reading material

cudnn documentation

batchnorm backward-propagation derivation

Current progress:

  1. Confirmed with Xuan: the earlier lazy-mode GPT implementation aligns with tensorflow and pytorch
  2. According to feedback from chengpeng and liaoxingyu, running layernorm in graph mode is also fine

So the previous implementation can still be considered correct; the problem most likely lies in the eager/functional path.

@yuanms2 (Contributor) commented Sep 3, 2021

This is an interesting finding. Could other colleagues help confirm it?

@yuanms2 (Contributor) commented Sep 3, 2021

Checked git blame: Cai Shenghang migrated this code over from a version developed by other colleagues, and at the time he also found a certain discrepancy against tf:

"Current tests show that when epsilon is large, the error against tf also grows, even reaching 1e-1; the old version behaves the same way."

#2781

@MARD1NO (Contributor, Author) commented Sep 3, 2021

> This is an interesting finding. Could other colleagues help confirm it?

Updated the description with what can be confirmed so far; guoran, yinggang, and depeng are helping to look into it.

@MARD1NO (Contributor, Author) commented Sep 3, 2021

While experimenting on the Python side, I happened to swap the order in which gamma and beta are passed in.

The outputs then stopped matching, but the gradient updates became correct.

We then traced it down: when calling layernorm_param_grad, gamma was being passed where beta was expected. After fixing this, the results align with torch.
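
For intuition, here is a minimal numpy sketch of the affine-parameter gradients (my own illustration, not the operator code; the helper name and axis handling are hypothetical). Swapping the two values returned at the end reproduces exactly the symptom above, with the weight and bias grads exchanged:

import numpy as np

def layernorm_affine_grads(dy, x_hat, batch_axes):
    # x_hat = (x - mean) / sqrt(var + eps) is the normalized input.
    # gamma scales x_hat and beta shifts it, so:
    #   d_gamma = sum over the batch axes of dy * x_hat
    #   d_beta  = sum over the batch axes of dy
    d_gamma = (dy * x_hat).sum(axis=batch_axes)
    d_beta = dy.sum(axis=batch_axes)
    # The bug described above amounts to returning these in the
    # wrong order, so beta received d_gamma and vice versa.
    return d_gamma, d_beta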

@wyg1997 (Contributor) commented Sep 3, 2021

One more issue to record: the LayerNorm module's forward only uses the length of self.normalized_shape to compute norm_axis and never actually checks that the input's shape matches it. The mismatch is only CHECKed later in TensorDescInfer, and the error message there is not clear enough. Worth addressing after this bug is fixed.

  ctx->has_normalized_diff = ctx->scale && inputs.at(0)->requires_grad();
  if (ctx->has_gamma_diff || ctx->has_normalized_diff) {
-   ctx->gamma_index = ctx->SaveTensorForBackward(inputs.at(gamma_index));
+   ctx->gamma_index = ctx->SaveTensorForBackward(inputs.at(1));  // save gamma.
  }
Review comment (Contributor): Does inputs.at(1) here count as hard-coding?

  if (ctx->has_gamma_diff) {
-   in_grads->at(ctx->has_beta_diff + 1) = results->at(ctx->has_beta_diff);
+   in_grads->at(1) =
Review comment (Contributor): Does at(1) count as hard-coding?

Reply (Contributor): The inputs and outputs are now ordered, and that order is fixed when the op expr is built, so hard-coding here is not a problem, and it is more efficient.


@MARD1NO (Contributor, Author) commented Sep 6, 2021

Recording a gradient problem:

In some cases, layernorm's output and gradient still fail to match torch (this happens on both gpu and cpu).

An example:

import oneflow as flow
import torch
import numpy as np 

x_np = np.array([[[[-1.83965693, -1.82964566]]]])
print(x_np.shape)
affine = False
device = "cuda"

normalized_shape = (1, 1, 2)

of_x_tensor = flow.Tensor(x_np).to(device)
of_x_tensor.requires_grad = True


of_layernorm = flow.nn.LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=affine).to(device)

of_out = of_layernorm(of_x_tensor)
print("Oneflow out is: ", of_out)
of_out = of_out.sum()
of_out.backward()
print("Of x tensor grad is: ", of_x_tensor.grad)


torch_x_tensor = torch.Tensor(x_np).to(device)
torch_x_tensor.requires_grad = True

torch_layernorm = torch.nn.LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=affine).to(device)

torch_out = torch_layernorm(torch_x_tensor)
print("Pytorch out is: ", torch_out)
torch_out = torch_out.sum()
torch_out.backward()
print("Torch x tensor grad is: ", torch_x_tensor.grad)

Below are some experiments and conjectures.

Experiment 1

Compose layernorm in the most naive way from mean and var:

import oneflow as flow
# import torch as flow
import numpy as np


device = "cpu"
eps = 1e-5
x_np = np.array([[[[-1.83965693, -1.82964566]]]])
of_x_tensor = flow.Tensor(x_np).to(device)
of_x_tensor.requires_grad = True

mean = flow.mean(of_x_tensor, dim=(2, 3), keepdim=True)
# reciprocal of the standard deviation: 1 / sqrt(var + eps)
inv_std = (flow.var(of_x_tensor, dim=(2, 3), keepdim=True, unbiased=False) + eps).rsqrt()
of_out = (of_x_tensor - mean) * inv_std
print("X out is: ", of_out)

of_out = of_out.sum()
of_out.backward()
print("Grad is: ", of_x_tensor.grad)

From the composed implementation, a few conclusions:

  1. The torch-composed and oneflow-composed versions agree in both output and gradient
  2. Even though it follows the textbook formula, the composed result is still not equivalent to the LayerNorm op (torch.LayerNorm outputs [[[[-0.8428, 0.8428]]]], the composed implementation outputs [[[[-0.8454, 0.8454]]]])

Conjectures:

  1. The current composed implementation is wrong
  2. pytorch's LayerNorm has a bug

Experiment 2

Compare against Paddle's implementation:

import paddle 
import numpy as np 


x = np.array([[[[-1.83965693, -1.82964566]]]]).astype(np.float32)

x_tensor = paddle.to_tensor(x).cpu()
x_tensor.stop_gradient = False

layernorm = paddle.nn.LayerNorm(normalized_shape=(1, 1, 2), epsilon=1e-5)
    
out = layernorm(x_tensor)
print("Out is: ", out)
out = out.sum()
out.backward()

print("X grad is: ", x_tensor.grad.numpy())

"""
Out is:  Tensor(shape=[1, 1, 1, 2], dtype=float32, place=CPUPlace, stop_gradient=False,
       [[[[-0.84543723,  0.84541708]]]])

X grad is:  [[[[-0.00143835  0.00143831]]]]
"""
  1. Paddle's output and gradient agree with oneflow's LayerNorm and with the composed implementation, so there is some chance the pytorch code has a bug

@simonJJJ (Contributor) commented Sep 6, 2021

This is because PyTorch and oneflow use different algorithms for LayerNorm. The Welford algorithm in PyTorch needs higher float precision; after converting the torch tensor to float64, the results align with oneflow.

@MARD1NO (Contributor, Author) commented Sep 6, 2021

> This is because PyTorch and oneflow use different algorithms for LayerNorm. The Welford algorithm in PyTorch needs higher float precision; after converting the torch tensor to float64, the results align with oneflow.

Thanks to Shijie for the help.

oneflow, lightseq, and deepx_core all take the comparatively naive approach: accumulate sum and sum_square, synchronize at the end, and obtain the variance from

D(X) = E(X^2) - E(X)^2

PyTorch instead uses Welford's online algorithm, which updates the mean and variance incrementally: https://changyaochen.github.io/welford/

With dtype torch.float64, the results can be aligned.
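
As a rough illustration (my own sketch, not either framework's actual kernel), the numerical gap between the two approaches is visible in a few lines of numpy: the naive E(X^2) - E(X)^2 form cancels catastrophically in float32 when the values are large relative to their spread, exactly the regime of the example above:

import numpy as np

def var_naive(xs):
    # accumulate sum and sum of squares, then D(X) = E(X^2) - E(X)^2
    s = ss = np.float32(0.0)
    for x in xs:
        s += x
        ss += x * x
    n = np.float32(len(xs))
    return ss / n - (s / n) ** 2

def var_welford(xs):
    # Welford's online update of the running mean and
    # M2 (the sum of squared deviations from the mean)
    mean = m2 = np.float32(0.0)
    for i, x in enumerate(xs, start=1):
        delta = x - mean
        mean += delta / np.float32(i)
        m2 += delta * (x - mean)
    return m2 / np.float32(len(xs))

xs = np.array([-1.83965693, -1.82964566], dtype=np.float32)
print(var_naive(xs))                  # may lose several digits to cancellation
print(var_welford(xs))                # stays close to the reference
print(np.var(xs.astype(np.float64)))  # float64 reference, about 2.506e-05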

@wyg1997 (Contributor) left a review comment

Please add a check at the top of the Python LayerNorm forward:

for i in range(0, len(self.normalized_shape)):
    if x.shape[i + self.begin_params_axis] != self.normalized_shape[i]:
        raise RuntimeError(f"Given normalized_shape={self.normalized_shape}, expected input with shape [*, {self.normalized_shape[-1]}], but got input of size {x.shape}")

@MARD1NO marked this pull request as ready for review September 6, 2021 23:31
@github-actions bot commented Sep 7, 2021

CI failed, removing label automerge

@github-actions bot removed the automerge label Sep 7, 2021

@github-actions bot commented Sep 7, 2021

Speed stats:
GPU Name: GeForce GTX 1080 

OneFlow resnet50 time: 128.6ms (= 6427.8ms / 50, input_shape=[16, 3, 224, 224])
PyTorch resnet50 time: 142.0ms (= 7101.7ms / 50, input_shape=[16, 3, 224, 224])
Relative speed: 1.10 (= 142.0ms / 128.6ms)

OneFlow resnet50 time: 74.8ms (= 3740.4ms / 50, input_shape=[8, 3, 224, 224])
PyTorch resnet50 time: 86.9ms (= 4345.2ms / 50, input_shape=[8, 3, 224, 224])
Relative speed: 1.16 (= 86.9ms / 74.8ms)

OneFlow resnet50 time: 48.7ms (= 2434.7ms / 50, input_shape=[4, 3, 224, 224])
PyTorch resnet50 time: 58.1ms (= 2903.4ms / 50, input_shape=[4, 3, 224, 224])
Relative speed: 1.19 (= 58.1ms / 48.7ms)

OneFlow resnet50 time: 44.1ms (= 2202.7ms / 50, input_shape=[2, 3, 224, 224])
PyTorch resnet50 time: 47.8ms (= 2391.0ms / 50, input_shape=[2, 3, 224, 224])
Relative speed: 1.09 (= 47.8ms / 44.1ms)

OneFlow resnet50 time: 40.2ms (= 2011.6ms / 50, input_shape=[1, 3, 224, 224])
PyTorch resnet50 time: 41.5ms (= 2076.1ms / 50, input_shape=[1, 3, 224, 224])
Relative speed: 1.03 (= 41.5ms / 40.2ms)

OneFlow resnet50 time: 154.6ms (= 7729.0ms / 50, input_shape=[16, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 159.4ms (= 7969.1ms / 50, input_shape=[16, 3, 224, 224], ddp, world size=2)
Relative speed: 1.03 (= 159.4ms / 154.6ms)

OneFlow resnet50 time: 102.2ms (= 5108.5ms / 50, input_shape=[8, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 100.1ms (= 5007.4ms / 50, input_shape=[8, 3, 224, 224], ddp, world size=2)
Relative speed: 0.98 (= 100.1ms / 102.2ms)

OneFlow resnet50 time: 76.5ms (= 3824.0ms / 50, input_shape=[4, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 84.0ms (= 4200.8ms / 50, input_shape=[4, 3, 224, 224], ddp, world size=2)
Relative speed: 1.10 (= 84.0ms / 76.5ms)

OneFlow resnet50 time: 66.9ms (= 3344.0ms / 50, input_shape=[2, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 63.1ms (= 3155.3ms / 50, input_shape=[2, 3, 224, 224], ddp, world size=2)
Relative speed: 0.94 (= 63.1ms / 66.9ms)

OneFlow resnet50 time: 70.5ms (= 3523.2ms / 50, input_shape=[1, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 66.7ms (= 3336.4ms / 50, input_shape=[1, 3, 224, 224], ddp, world size=2)
Relative speed: 0.95 (= 66.7ms / 70.5ms)

@oneflow-ci-bot merged commit 491c9ea into master Sep 7, 2021
@oneflow-ci-bot deleted the debug_layernorm branch September 7, 2021 01:46