Dev adam graph conf #5709

Merged
50 commits merged from dev_adam_graph_conf into master on Aug 10, 2021

Commits
5e79c4c add conf for adam (MARD1NO, Aug 3, 2021)
8784447 add unittest for adam (MARD1NO, Aug 3, 2021)
562103f Merge branch 'master' into dev_adam_graph_conf (MARD1NO, Aug 4, 2021)
cd0d197 revert back (MARD1NO, Aug 4, 2021)
513099b Add eager like unittest, still have problem (MARD1NO, Aug 4, 2021)
7a17a45 Merge branch 'master' into dev_adam_graph_conf (MARD1NO, Aug 4, 2021)
dc3f117 init tensor in graph cause failure (MARD1NO, Aug 4, 2021)
ea56866 add sgd optim unittest (MARD1NO, Aug 4, 2021)
f9ad6ff small fix (MARD1NO, Aug 4, 2021)
47730b5 modify back (MARD1NO, Aug 4, 2021)
71ca83d update code (MARD1NO, Aug 4, 2021)
fe63384 add adam optimizer unittest (MARD1NO, Aug 5, 2021)
89206d0 remove redundant (MARD1NO, Aug 5, 2021)
c63f453 test sgd (MARD1NO, Aug 5, 2021)
650c367 modify back (MARD1NO, Aug 5, 2021)
8c9b0f9 Merge branch 'master' into dev_adam_graph_conf (MARD1NO, Aug 5, 2021)
e12787c remove useless code (MARD1NO, Aug 5, 2021)
79ce30b Merge branch 'dev_adam_graph_conf' of https://github.com/Oneflow-Inc/… (MARD1NO, Aug 5, 2021)
82294c8 fix comment (MARD1NO, Aug 5, 2021)
582c0e0 support do bias correction (MARD1NO, Aug 6, 2021)
199798c add unittest when do bias correction is True (MARD1NO, Aug 6, 2021)
3ba357c add unittest (MARD1NO, Aug 6, 2021)
37bdf3e remove (MARD1NO, Aug 6, 2021)
59815bf add do bias correction (MARD1NO, Aug 6, 2021)
7fab457 Merge branch 'master' into dev_adam_graph_conf (Ldpe2G, Aug 9, 2021)
9645ac5 auto format by CI (oneflow-ci-bot, Aug 9, 2021)
f8f28d2 Merge branch 'master' into dev_adam_graph_conf (oneflow-ci-bot, Aug 9, 2021)
b45c415 Merge branch 'master' into dev_adam_graph_conf (oneflow-ci-bot, Aug 9, 2021)
291ab71 Merge branch 'master' into dev_adam_graph_conf (oneflow-ci-bot, Aug 9, 2021)
b9a8a41 Merge branch 'master' into dev_adam_graph_conf (MARD1NO, Aug 10, 2021)
b60bc21 remove scale params (MARD1NO, Aug 10, 2021)
6214f29 fix unittest (MARD1NO, Aug 10, 2021)
40cd318 fix unittest (MARD1NO, Aug 10, 2021)
8700038 remove scale (MARD1NO, Aug 10, 2021)
4d2b2e1 remove scale in unittest (MARD1NO, Aug 10, 2021)
e81e8c1 remove scale (MARD1NO, Aug 10, 2021)
bce2435 fix to allclose (MARD1NO, Aug 10, 2021)
5da86af Merge branch 'master' into dev_adam_graph_conf (oneflow-ci-bot, Aug 10, 2021)
c0d925a Merge branch 'master' into dev_adam_graph_conf (oneflow-ci-bot, Aug 10, 2021)
1958017 auto format by CI (oneflow-ci-bot, Aug 10, 2021)
6a086c4 Merge branch 'master' into dev_adam_graph_conf (oneflow-ci-bot, Aug 10, 2021)
6bf8ee9 Merge branch 'master' into dev_adam_graph_conf (oneflow-ci-bot, Aug 10, 2021)
e6d8f16 Merge branch 'master' into dev_adam_graph_conf (oneflow-ci-bot, Aug 10, 2021)
9d707df add test util (MARD1NO, Aug 10, 2021)
17edb71 remove scale (MARD1NO, Aug 10, 2021)
8977784 fix graph unittest (MARD1NO, Aug 10, 2021)
9e0f733 Merge branch 'dev_adam_graph_conf' of https://github.com/Oneflow-Inc/… (MARD1NO, Aug 10, 2021)
36480bd Merge branch 'master' into dev_adam_graph_conf (oneflow-ci-bot, Aug 10, 2021)
a635978 auto format by CI (oneflow-ci-bot, Aug 10, 2021)
e7c9a2d Merge branch 'master' into dev_adam_graph_conf (oneflow-ci-bot, Aug 10, 2021)

41 changes: 36 additions & 5 deletions python/oneflow/nn/optimizer/adam.py
@@ -14,6 +14,7 @@
limitations under the License.
"""
import collections
import math
from typing import Callable, Dict, Iterator, List, Tuple, Union

import oneflow as flow
@@ -51,7 +52,7 @@ class Adam(Optimizer):
eps (float, optional): term added to the denominator to improve
numerical stability (default: 1e-8)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
scale (float, optional): the scale factor of loss (default: 1.0)
do_bias_correction (bool, optional): Whether do bias correction (default: False)

.. _Adam\\: A Method for Stochastic Optimization:
https://arxiv.org/abs/1412.6980
@@ -68,7 +69,7 @@ def __init__(
eps: float = 1e-08,
weight_decay: float = 0,
amsgrad: bool = False,
scale: float = 1.0,
do_bias_correction: bool = False,
):
super().__init__()
assert lr >= 0.0, f"Invalid learning rate: {lr}"
@@ -80,14 +81,13 @@ def __init__(
betas[1] >= 0.0 and betas[1] < 1.0
), f"Invalid beta parameter at index 1: {betas[1]}"
assert weight_decay >= 0.0, f"Invalid weight_decay value: {weight_decay}"
assert scale > 0.0, f"Invalid scale factor: {scale}"
assert amsgrad is False, "Not support AMSGrad now!"
self.do_bias_correction = do_bias_correction
self._default_options["lr"] = lr
self._default_options["eps"] = eps
self._default_options["betas"] = betas
self._default_options["weight_decay"] = weight_decay
self._default_options["amsgrad"] = amsgrad
self._default_options["scale"] = scale
if isinstance(parameters, collections.abc.Iterator):
self.param_groups.append(ParamGroup(parameters, self._default_options))
else:
@@ -124,12 +124,17 @@ def step(self, closure: Callable = None):
for param_group in self.param_groups:
kwargs = {
"learning_rate_val": param_group["lr"],
"scale": param_group["scale"],
"l2": param_group["weight_decay"],
"beta1": param_group["betas"][0],
"beta2": param_group["betas"][1],
"epsilon": param_group["eps"],
}
if self.do_bias_correction:
kwargs["learning_rate_val"] = (
param_group["lr"]
* math.sqrt(1 - kwargs["beta2"] ** (self._state["step"] + 1))
/ (1 - kwargs["beta1"] ** (self._state["step"] + 1))
)
for param in param_group.parameters:
if param.grad is None:
continue
@@ -138,3 +143,29 @@
self._op(param, param.grad, m_tensor, v_tensor, **kwargs)
self._state["step"] = self._state["step"] + 1
return loss

def generate_conf_for_graph(self, train_conf, vars_conf):
for param_group in self.param_groups:
optimizer_conf = train_conf.mutable_optimizer_conf().Add()

lr = param_group["lr"]
l2 = param_group["weight_decay"]
beta1 = param_group["betas"][0]
beta2 = param_group["betas"][1]

epsilon = param_group["eps"]
# TODO(): optimizer_conf need to have loss_scale_factor field to support multi scale factor

optimizer_conf.set_base_learning_rate(lr)

optimizer_conf.mutable_adam_conf().set_beta1(beta1)
optimizer_conf.mutable_adam_conf().set_beta2(beta2)
optimizer_conf.mutable_adam_conf().set_epsilon(epsilon)
optimizer_conf.mutable_adam_conf().set_do_bias_correction(
self.do_bias_correction
) # TODO(zzk): Check this option

for param in param_group.parameters:
vars_conf[param].l2 = l2
if param.requires_grad:
optimizer_conf.add_variable_op_names(vars_conf[param].name)
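
For context, the step size that step() above applies when do_bias_correction is True can be written as a small standalone helper. This is a minimal sketch for illustration only; the function name bias_corrected_lr is not part of the PR, and the arithmetic simply mirrors the expression folded into learning_rate_val above (and the numpy reference in the new unittest).

import math

def bias_corrected_lr(lr, beta1, beta2, step):
    # Fold Adam's bias correction into the step size, mirroring the expression
    # assigned to learning_rate_val in step() above (step counts from 0).
    return lr * math.sqrt(1 - beta2 ** (step + 1)) / (1 - beta1 ** (step + 1))

# Example: with lr=1e-3 and betas=(0.9, 0.999), the first step uses
# 1e-3 * sqrt(1 - 0.999) / (1 - 0.9), about 3.16e-4; the correction factor
# tends to 1 as step grows.
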
25 changes: 4 additions & 21 deletions python/oneflow/nn/optimizer/sgd.py
@@ -49,7 +49,6 @@ class SGD(Optimizer):
lr (float, optional): learning rate (default: 1e-3)
momentum (float, optional): Momentum factor (default: 0.0)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0.0)
scale (float, optional): the scale factor of loss (default: 1.0)

"""

@@ -59,15 +58,12 @@ def __init__(
lr: float = 0.001,
momentum: float = 0.0,
weight_decay: float = 0.0,
scale: float = 1.0,
):
super().__init__()
assert lr >= 0.0, f"Invalid learning rate: {lr}"
assert momentum >= 0.0, f"Invalid momentum: {momentum}"
assert scale >= 0.0, f"Invalid scale factor: {scale}"
assert weight_decay >= 0.0, f"Invalid weight_decay: {weight_decay}"
self._default_options["lr"] = lr
self._default_options["scale"] = scale
self._default_options["momentum"] = momentum
self._default_options["weight_decay"] = weight_decay
if isinstance(parameters, collections.abc.Iterator):
@@ -106,15 +102,12 @@ def step(self, closure: Callable = None):
loss = closure()
for param_group in self.param_groups:
lr = param_group["lr"]
scale = param_group["scale"]
l2 = param_group["weight_decay"]
for param in param_group.parameters:
if param.grad is None:
continue
if param_group["momentum"] == 0.0:
self._sgd(
param, param.grad, learning_rate_val=lr, l2=l2, scale=scale
)
self._sgd(param, param.grad, learning_rate_val=lr, l2=l2)
else:
momentum_buf = self._state[param]["momentum_buf"]
beta = param_group["momentum"]
@@ -124,7 +117,6 @@
momentum_buf,
learning_rate_val=lr,
l2=l2,
scale=scale,
beta=beta,
)
self._state["step"] = self._state["step"] + 1
@@ -135,17 +127,9 @@ def generate_conf_for_graph(self, train_conf, vars_conf):
optimizer_conf = train_conf.mutable_optimizer_conf().Add()
lr = param_group["lr"]
beta = param_group["momentum"]
scale = param_group["scale"]
l2 = param_group["weight_decay"]
# TODO(): optimizer_conf need to have loss_scale_factor field to support multi scale factor
base_scale = train_conf.loss_scale_factor()
assert math.isclose(base_scale, 1, rel_tol=1e-4) or math.isclose(
scale, base_scale, rel_tol=1e-4
), "nn.Graph only support one scale factor at the moment, base_scale {} vs scale {}".format(
base_scale, scale
)

train_conf.set_loss_scale_factor(scale)

optimizer_conf.set_base_learning_rate(lr)
if beta == 0:
optimizer_conf.mutable_naive_conf()
@@ -154,6 +138,5 @@

for param in param_group.parameters:
vars_conf[param].l2 = l2
if not param.requires_grad:
continue
optimizer_conf.add_variable_op_names(vars_conf[param].name)
if param.requires_grad:
optimizer_conf.add_variable_op_names(vars_conf[param].name)
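
To show how generate_conf_for_graph above gets exercised, here is a minimal sketch of attaching an SGD optimizer to an nn.Graph, following the same pattern as the Adam unittest added below. The Linear module, tensor shapes, and hyperparameter values are illustrative assumptions, not part of this PR.

import numpy as np
import oneflow as flow

model = flow.nn.Linear(4, 4)  # illustrative module, not part of this PR
sgd0 = flow.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-3)

class SGDGraph(flow.nn.Graph):
    def __init__(self):
        super().__init__()
        self.m = model
        # Compiling the graph calls SGD.generate_conf_for_graph, which fills
        # optimizer_conf (momentum_conf here, naive_conf when momentum == 0)
        # and records weight_decay as l2 on every trainable variable.
        self.add_optimizer("sgd", sgd0)

    def build(self, x):
        loss = flow.sum(self.m(x))
        loss.backward()
        return loss

graph = SGDGraph()
x = flow.Tensor(np.random.randn(2, 4).astype(np.float32))
loss = graph(x)  # one training step through the compiled graph
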
144 changes: 144 additions & 0 deletions python/oneflow/test/graph/test_graph_adam_optim.py
@@ -0,0 +1,144 @@
"""
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
import os
from collections import OrderedDict
from test_util import GenArgList
import numpy as np

import oneflow as flow
import oneflow.unittest


def compare_with_numpy_adam(
test_case,
device,
x_shape,
learning_rate,
train_iters,
betas,
weight_decay,
eps,
do_bias_correction,
):
random_grad_seq = []
for _ in range(train_iters):
random_grad_seq.append(np.random.uniform(size=x_shape).astype(np.float32))
init_value = np.random.uniform(size=x_shape).astype(np.float32)

class CustomModule(flow.nn.Module):
def __init__(self):
super().__init__()
self.para0 = flow.nn.Parameter(
flow.Tensor(init_value, device=flow.device(device))
)

def forward(self, mask):
return self.para0 * mask

simp_module = CustomModule()
simp_module.to(device)
simp_module.train()

adam0 = flow.optim.Adam(
[
{
"params": simp_module.parameters(),
"lr": learning_rate,
"betas": betas,
"eps": eps,
"weight_decay": weight_decay,
}
],
do_bias_correction=do_bias_correction,
)

class CustomAdamGraph(flow.nn.Graph):
def __init__(self):
super().__init__()
self.m = simp_module
self.add_optimizer("adam", adam0)

def build(self, mask_tensor):
loss = flow.sum(self.m(mask_tensor))
loss.backward()
return loss

of_res_list = []
adam_graph = CustomAdamGraph()

for i in range(train_iters):
mask_tensor = flow.Tensor(
random_grad_seq[i], requires_grad=False, device=flow.device(device)
)
adam_x = adam_graph(mask_tensor)

of_res_list.append(simp_module.para0.numpy())

np_res_list = []

def train_by_numpy():
x = init_value
vt = np.zeros_like(x)
st = np.zeros_like(x)
beta1 = betas[0]
beta2 = betas[1]

def np_train_one_iter(iter, grad):
grad = grad + weight_decay * x

if do_bias_correction:
lr = (
learning_rate
* np.sqrt(1 - beta2 ** (iter + 1))
/ (1 - beta1 ** (iter + 1))
)
else:
lr = learning_rate

v = beta1 * vt + (1 - beta1) * grad
s = beta2 * st + (1 - beta2) * grad * grad
param = x - lr * (v / (np.sqrt(s) + eps))
return (param, v, s)

for i in range(train_iters):
(x, vt, st) = np_train_one_iter(i, random_grad_seq[i])
np_res_list.append(x)
return x

train_by_numpy()

test_case.assertTrue(np.allclose(of_res_list, np_res_list, rtol=0.001, atol=0.001))


@flow.unittest.skip_unless_1n1d()
class TestAdam(flow.unittest.TestCase):
def test_adam(test_case):
arg_dict = OrderedDict()
arg_dict["device"] = ["cpu", "cuda"]
arg_dict["x_shape"] = [(10,)]
arg_dict["learning_rate"] = [1, 1e-3]
arg_dict["train_iters"] = [10]
arg_dict["betas"] = [(0.99, 0.9)]
arg_dict["weight_decay"] = [0.001, 0.0]
arg_dict["eps"] = [1e-8]
arg_dict["do_bias_correction"] = [True, False]
for arg in GenArgList(arg_dict):
compare_with_numpy_adam(test_case, *arg)


if __name__ == "__main__":
unittest.main()
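
A side note on the sweep above: GenArgList comes from test_util and expands arg_dict into one call per combination of the listed values. A minimal sketch of that expansion, assuming it is a plain cartesian product in key order; the helper below is illustrative, not the actual test_util implementation.

from collections import OrderedDict
from itertools import product

def gen_arg_list(arg_dict):
    # Assumed behaviour: cartesian product of the value lists, in key order.
    return [list(args) for args in product(*arg_dict.values())]

arg_dict = OrderedDict()
arg_dict["device"] = ["cpu", "cuda"]
arg_dict["do_bias_correction"] = [True, False]
for arg in gen_arg_list(arg_dict):
    print(arg)  # ['cpu', True], ['cpu', False], ['cuda', True], ['cuda', False]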