Fix optimizer for not supporting all kinds of iterables #5355

Merged
merged 11 commits into from
Jun 30, 2021

Conversation

OscarXZQ
Contributor

BUG

Traceback (most recent call last):
  File "train.py", line 95, in <module>
    main(args)
  File "train.py", line 48, in main
    model = CycleGANModel(opt)
  File "/home/xuzhiqiu/models/cycleGAN/cycleGAN.py", line 31, in __init__
    self.optimizer_G_A = flow.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()), lr=opt.lr, betas=(0.5, 0.999))
  File "/home/xuzhiqiu/backup/oneflow/build/python_scripts/oneflow/python/nn/optimizer/adam.py", line 103, in __init__
    self.param_groups.append(ParamGroup(param, self._default_options))
  File "/home/xuzhiqiu/backup/oneflow/build/python_scripts/oneflow/python/nn/optimizer/optimizer.py", line 36, in __init__
    assert "params" in parameters
AssertionError

The cause is in optimizer.py:

if isinstance(parameters, GeneratorType):
    self._parameters = list(parameters)
    self._options = default_options
else:  # Dict
    assert "params" in parameters
    self._parameters = list(parameters["params"])
    self._options = default_options
    for key in self._options:
        if key in parameters:
            self._options[key] = parameters[key]

But the corresponding PyTorch source looks like this:

param_groups = list(params)
if len(param_groups) == 0:
    raise ValueError("optimizer got an empty parameter list")
if not isinstance(param_groups[0], dict):
    param_groups = [{'params': param_groups}]

What itertools.chain produces is indeed an iterable, but it is not a generator type, so it gets rejected by the if condition we wrote here. PyTorch's code does not reject other kinds of iterables, which is why PyTorch can run that snippet while we cannot.
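The distinction can be verified directly: a chain object is an iterator in the `collections.abc` sense, but only objects produced by generator functions or generator expressions pass an `isinstance(..., GeneratorType)` check.

```python
import itertools
from types import GeneratorType
from collections.abc import Iterator

# itertools.chain returns an iterator, but NOT a generator
chained = itertools.chain([1, 2], [3])
print(isinstance(chained, GeneratorType))  # False -- this is what broke the old check
print(isinstance(chained, Iterator))       # True

# A generator expression satisfies both checks
gen = (x for x in [1, 2, 3])
print(isinstance(gen, GeneratorType))      # True
print(isinstance(gen, Iterator))           # True
```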

In the end we did not adopt PyTorch's approach (it is too odd). Instead, the original GeneratorType check was changed to collections.abc.Iterator, and every optimizer containing this check was updated accordingly.
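A minimal sketch of the fixed branching logic, with the class context stripped away for illustration (the function name `build_param_group` and the tuple return are hypothetical, not the actual ParamGroup API):

```python
import itertools
from collections.abc import Iterator

def build_param_group(parameters, default_options):
    """Sketch of the fixed check: accept any Iterator, not just GeneratorType."""
    if isinstance(parameters, Iterator):  # was: isinstance(parameters, GeneratorType)
        return list(parameters), dict(default_options)
    else:  # dict form: {"params": ..., plus optional per-group options}
        assert "params" in parameters
        options = dict(default_options)
        for key in options:
            if key in parameters:
                options[key] = parameters[key]
        return list(parameters["params"]), options

# itertools.chain now passes the isinstance check
params, opts = build_param_group(itertools.chain([1], [2, 3]), {"lr": 0.1})
print(params, opts)  # [1, 2, 3] {'lr': 0.1}
```

Since every generator is also a `collections.abc.Iterator`, the broadened check keeps the old generator path working while additionally accepting `itertools.chain`, `map`, `filter`, and similar iterator objects.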

@oneflow-ci-bot oneflow-ci-bot self-requested a review June 30, 2021 11:40
@oneflow-ci-bot oneflow-ci-bot merged commit ed82d1d into master Jun 30, 2021
@oneflow-ci-bot oneflow-ci-bot deleted the fix_optimizer branch June 30, 2021 13:03