
LazyInterpret for FeedVariableOpExpr #5490

Merged: 6 commits merged into master from dev_cc_feed_var on Jul 15, 2021

Conversation

chengtbf (Contributor) commented:

LazyInterpret: support constructing a VariableOp from a passed-in EagerTensor.

// NOTE(chengcheng): Record variable op output LazyTensor
TensorNameScope::Global()->Record(outputs->at(0), op_name + "/" + obn);
// NOTE(chengcheng): Record EagerTensor as variable tensor name
TensorNameScope::Global()->Record(input_tensor, op_name + "/" + obn);
chengtbf (Contributor, Author):

Here I also record the incoming EagerTensor, so that a later UserOp's LazyInterpret can still resolve the correct lbn when its input is an EagerTensor.
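For reference, a minimal conceptual sketch (in Python, not OneFlow's actual C++ TensorNameScope) of the bookkeeping this enables: both the lazy output tensor and the incoming eager tensor map to the same lbn string, so either one can be looked up later.

class TensorNameScope:
    def __init__(self):
        self._tensor_id2lbn = {}

    def record(self, tensor, lbn):
        # Key by object identity; both lazy and eager tensors are just objects here.
        self._tensor_id2lbn[id(tensor)] = lbn

    def lookup(self, tensor):
        return self._tensor_id2lbn.get(id(tensor), "")

scope = TensorNameScope()
lbn = "model.weight" + "/out"        # hypothetical "op_name/obn"
lazy_out, eager_var = object(), object()
scope.record(lazy_out, lbn)          # record the variable op's lazy output
scope.record(eager_var, lbn)         # also record the eager tensor under the same lbn
assert scope.lookup(eager_var) == lbn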

Contributor:

Where will the memory sharing between the variable's lazy tensor and its eager tensor be implemented?

chengtbf (Contributor, Author):

In NNGraph; LazyInterpret does not care about this. The Python-side nn.Graph passes each Variable's EagerTensor and the corresponding names into NNGraph, NNGraph records that information, and the Regsts are bound when the Runtime starts.
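A minimal sketch of that Python-side handoff, assuming a simple Linear model; the commented register_variable_op_names_and_tensors call is a hypothetical name for the NNGraph binding, not a confirmed API.

import oneflow as flow

class LinearGraph(flow.nn.Graph):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def build(self, x):
        return self.model(x)

model = flow.nn.Linear(4, 4)
graph = LinearGraph(model)

# nn.Graph gathers each Variable's eager tensor together with its op name and
# hands both to the C++ NNGraph, which binds them to Regsts at runtime start.
var_names = [name for name, _ in model.named_parameters()]   # e.g. "weight", "bias"
var_tensors = [p for _, p in model.named_parameters()]
# graph._c_nn_graph.register_variable_op_names_and_tensors(var_names, var_tensors)  # hypothetical call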

// TODO(chengcheng): GenerateParallelDistributionString by tensor.
}
if (!input_tensor->requires_grad()) { var_conf->set_trainable(false); }
// TODO(chengcheng, xuxiaoyu): Set L1/L2 RegularizerConf by nn.Graph Optimizer
chengtbf (Contributor, Author):

I left a TODO here: from the EagerTensor alone we cannot tell the Variable's L1 and L2 regularization parameters. PyTorch presumably keeps these in the Optimizer? @strint please look into how nn.Graph can support configuring L1 and L2 for each Variable.

Contributor:

Do you mean configuring a different learning rate for different parameters?

Different parameters can be bound to different optimizers, one lr per optimizer:
https://pytorch.org/docs/stable/optim.html

Contributor:

There is also this form:

optim.SGD([
    {'params': model.base.parameters()},
    {'params': model.classifier.parameters(), 'lr': 1e-3}
], lr=1e-2, momentum=0.9)

The optimizer splits its parameters into multiple groups, each group with its own lr, plus a default lr.

chengtbf (Contributor, Author):

No, LR is the learning rate; what I am talking about are the L1/L2 regularization parameters.

Contributor:

torch does not have one unified mechanism; you can always write it by hand when defining the loss. There is also a standard way of writing it:
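A sketch of those two routes in PyTorch (an illustration, not the snippet that originally followed this comment): L2 via the optimizer's weight_decay, L1 added to the loss by hand.

import torch

model = torch.nn.Linear(4, 4)
opt = torch.optim.SGD(model.parameters(), lr=1e-2, weight_decay=1e-4)  # L2, the standard knob

x, y = torch.randn(8, 4), torch.randn(8, 4)
loss = torch.nn.functional.mse_loss(model(x), y)
l1 = sum(p.abs().sum() for p in model.parameters())                    # L1, written manually
loss = loss + 1e-5 * l1
loss.backward()
opt.step()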

Contributor:

Let's see whether we need to add an l1_norm/l2_norm option to the optimizer's per-group parameter configuration. @wyg1997

Contributor:

Any parameter accepted by an Optimizer's constructor should also be specifiable per group; if l1/l2 end up in the Optimizer's parameter list, every ParamGroup should support them too.
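A sketch of what that could look like if such options were added; the "l1"/"l2" keys below are hypothetical and do not exist in oneflow.optim today, so this illustrates the proposal rather than a working configuration.

import oneflow as flow

class Net(flow.nn.Module):
    def __init__(self):
        super().__init__()
        self.base = flow.nn.Linear(4, 4)
        self.classifier = flow.nn.Linear(4, 2)

model = Net()
optimizer = flow.optim.SGD(
    [
        {"params": model.base.parameters()},                                # uses the defaults below
        {"params": model.classifier.parameters(), "l1": 1e-5, "l2": 1e-4},  # hypothetical per-group keys
    ],
    lr=1e-2,
    momentum=0.9,
)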

out_tensor = var_op.apply([x_tensor_in_c], attrs)[0]
test_case.assertEqual(out_tensor.shape, (1, 1, 10, 10))
test_case.assertTrue(out_tensor.is_lazy)
test_case.assertTrue(out_tensor.is_consistent)
Contributor:

So by default all outputs are consistent now?

chengtbf (Contributor, Author):

Lazy really only has the Consistent concept; Mirror is also expanded into Consistent. Even if you pass in a local tensor, it gets translated into a ConsistentTensor whose placement is the current rank.
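A small illustration of that translation, written with the newer OneFlow names (to_global/is_global, which correspond to the "consistent" terminology in this thread); the single-rank CPU placement is an assumption for the example.

import oneflow as flow

x_local = flow.ones(2, 3)                                      # a local (mirrored) tensor on this rank
placement = flow.placement("cpu", ranks=[flow.env.get_rank()])
x_global = x_local.to_global(placement=placement, sbp=flow.sbp.broadcast)
print(x_global.is_global, x_global.placement)                  # placement covers only the current rank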


OperatorConf op_conf;
op_conf.set_name(op_expr.op_name()); // construct by python nn.Graph
op_conf.set_scope_symbol_id(scope_symbol_id); // TODO(chengcheng): NewScope by cur scope.
strint (Contributor) commented on Jul 15, 2021:

An external scope has already been created for this variable; since it is a single tensor, it seems that scope could be reused instead of calling NewScope?

chengtbf (Contributor, Author):

It is needed. The externally created scope belongs to the Block and carries no real ParallelDesc information. Scopes with ParallelDesc must be created on the spot inside LazyInterpret, whether for Input, Variable, or an ordinary UserOp, because that information lives on the input tensors.

Contributor:

The main reason is that we now infer the ParallelDesc from the tensor, so the ParallelDesc in the Scope created at the Block level is usually of no use, right?

chengtbf (Contributor, Author):

Yes.

oneflow-ci-bot requested review from oneflow-ci-bot and removed the request on Jul 15, 2021, 11:41
oneflow-ci-bot self-requested a review on Jul 15, 2021, 13:30
oneflow-ci-bot requested review from oneflow-ci-bot and removed the request on Jul 15, 2021, 15:48
oneflow-ci-bot requested review from oneflow-ci-bot and removed the request on Jul 15, 2021, 17:21
oneflow-ci-bot merged commit b0c3d7e into master on Jul 15, 2021
oneflow-ci-bot deleted the dev_cc_feed_var branch on Jul 15, 2021, 19:38