-
Notifications
You must be signed in to change notification settings - Fork 755
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
split vector-matrix norm #5478
split vector-matrix norm #5478
Conversation
oneflow/python/nn/modules/norm.py
Outdated
def __init__(self, ord=None, dim=None, keepdim=False) -> None: | ||
super().__init__() | ||
|
||
self.ord = ord | ||
self.ord = 2 if ord == None else ord |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
后续感觉还需要一个if else判断,让他变成float输入到后续
if not isinstance(ord, str):
ord = float(ord)
oneflow/python/nn/modules/norm.py
Outdated
), | ||
1.0 / ord, | ||
) | ||
|
||
elif isinstance(ord, int): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个分支处理没有必要,前面做一个float转换。因为实际上int分支和float分支的计算逻辑是一样的
oneflow/python/nn/modules/norm.py
Outdated
@@ -51,6 +58,19 @@ def _vector_norm(self, x, ord, dim): | |||
else: | |||
raise ValueError("Invalid norm order: {}".format(ord)) | |||
|
|||
def forward(self, x): | |||
return self._vector_norm(x.reshape((1, -1))[0], ord = self.ord, dim=self.dim) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
reshape两维是不正确的,torch是可以支持多维的。
oneflow/python/nn/modules/norm.py
Outdated
def __init__(self, ord=None, dim=None, keepdim=False) -> None: | ||
super().__init__() | ||
|
||
self.ord = ord | ||
self.ord = 2 if ord == None else ord | ||
self.dim = dim | ||
self.keepdim = keepdim | ||
|
||
def _vector_norm(self, x, ord, dim): | ||
if isinstance(ord, str) and ord in ["fro", "nuc"]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
还需要检查dim是否超过了len(x.shape)
self.ord = ord | ||
if ord == None: | ||
self.ord = 2.0 | ||
elif isinstance(ord, (int, float)): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
float情况可以去除
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
去除了 float 的话,传入 float 就会报错了吧
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
噢噢那这里不用去
oneflow/python/nn/modules/norm.py
Outdated
|
||
|
||
class Vector_Norm(Module): | ||
def __init__(self, ord, dim, keepdim) -> None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
oneflow/python/nn/modules/norm.py
Outdated
num_dims = len(x.shape) | ||
dim = check_dim(num_dims, self.dim) | ||
if dim == None: | ||
return self._vector_norm(x.reshape((1, -1))[0], ord = self.ord, dim=self.dim, keepdim= self.keepdim) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
reshape((1, -1))这里可以改成flatten吧
oneflow/python/nn/modules/norm.py
Outdated
def _matrix_norm(self, x, ord, dim): | ||
|
||
class Matrix_Norm(Module): | ||
def __init__(self, ord, dim, keepdim) -> None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
oneflow/python/nn/modules/norm.py
Outdated
raise TypeError("linalg_matrix_norm(): argument 'ord' must be Number, not {}".format(type(ord))) | ||
if isinstance(dim,tuple) and len(dim) == 2 and dim[0] != dim[1]: | ||
self.dim = dim | ||
elif dim == None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这部分按照前面设置好默认值,可以去掉这部分判断逻辑,不用判断dim==None
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这边感觉不能去掉吧,linalg.norm 的默认 dim == None,在那个里面调 Matrix_norm 的话就会把 none 传进来了(line 157)。或者就是我在 linalg.norm 里判断一下
oneflow/python/nn/modules/norm.py
Outdated
raise NotImplementedError | ||
elif ord == "fro": | ||
return flow.experimental.sqrt( | ||
flow.experimental.sum(flow.experimental.square(x), dim=dim, keepdim= keepdim) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
keepdim= 后面有个空格
oneflow/python/nn/modules/norm.py
Outdated
|
||
def _norm_min_max(input, ord,dim,keepdim): | ||
if ord > 0: | ||
return flow.experimental.max(input, dim= dim,keepdim = keepdim) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
=后面这些空格都去掉
oneflow/python/nn/modules/norm.py
Outdated
|
||
def _norm_min_max(input, ord,dim,keepdim): | ||
if ord > 0: | ||
return flow.experimental.max(input, dim= dim,keepdim = keepdim) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
=后面这些空格都去掉
self.ord = ord | ||
if ord == None: | ||
self.ord = 2.0 | ||
elif isinstance(ord, (int, float)): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
噢噢那这里不用去
torch.norm
后续会弃用,因此选择与torch.linalg.norm
接口对齐。vector_norm
支持 float 类型 order