-
Notifications
You must be signed in to change notification settings - Fork 756
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
nn.Graph with_dynamic_input_shape #9984
Conversation
return self.my_linear(x) | ||
|
||
linear_g = LinearGraph() | ||
linear_g.enable_save_runtime_state_dict() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个做成kv参数?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
参数输入在哪里呢,graph 的 init 函数?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
嗯,LinearGraph(enable_save_runtime_state_dict=True)
这样?因为这个配置 1)不会不停变化状态,没必要提供函数接口 2)有调用时机的限制,提供函数接口反而令人要多一层纠结调用时机
Speed stats:
|
class LinearGraph(flow.nn.Graph): | ||
@flow.nn.Graph.with_dynamic_input_shape | ||
def __init__(self): | ||
super().__init__(enable_runtime_state_dict=True) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
enable_runtime_state_dict with kwargs
warnings.warn( | ||
f"nn.Graph {self._name} WARNING: current oneflow version ({oneflow.__version__}) is loading " | ||
f"runtime_state_dict from a different version ({state_dict['oneflow_version']}), " | ||
"there may has compatibility problems." |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nn.Graph LinearGraph_1 WARNING: current oneflow version (0.9.1+cu112.git.62e58c9390) is loading runtime_state_dic
t from a different version (0.9.1+cu112.git.92e58c939), there may has compatibility problems.
Code got formatted by CI. Please request CI again if you still want to have this PR merged. If the PR is from a forked repo, please download the patch files from the GitHub Actions web page and apply them locally. |
linear_reshape = LinearReshapeModule(linear, with_reshape) | ||
|
||
class LinearGraph(flow.nn.Graph): | ||
@flow.nn.Graph.with_dynamic_input_shape(size=4) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
支持设置 cache size
python/oneflow/nn/graph/graph.py
Outdated
@@ -169,6 +170,10 @@ def __init__(self): | |||
|
|||
# For load graph from runtime states. | |||
self._enable_save_runtime_state_dict = False | |||
self.enable_save_runtime_state_dict(enable_runtime_state_dict) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里初始化的话, 172 行是不是多余了。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
是应该删掉
python/oneflow/nn/graph/graph.py
Outdated
@@ -103,7 +104,7 @@ class Graph(object): | |||
""" | |||
_child_init_cnt = dict() | |||
|
|||
def __init__(self): | |||
def __init__(self, enable_runtime_state_dict: bool = False): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
enable_runtime_state_dict 这个初看没理解参数的含义,是不是应该叫:
save_runtime_state_dict
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个接口是决定是否可以调用 graph.runtime_state_dict 的,因为生成 runtime_state_dict 会额外保存一些东西。
destination = OrderedDict() | ||
destination._metadata = OrderedDict() | ||
|
||
for (key, graph) in self._cache.items(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里的 key 是 input shape list ?
这里 save runtime state dict,如果有多个 graph ,是都会写到一个 file 里吗?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
key 是 input tensor DFS 遍历后 shape 组成的 tuple,然后对 tuple 做 hash 得到的。
如果 cache 中有多个 graph,则按 graph name 作为 key,runtime_state_dict 作为 value,放到一个词典里面。最好用 flow.save(state_dict, file_name) 保存到一个 file 里面。
graph = self._base_graph | ||
else: | ||
# Create new graph from base | ||
graph = self._base_graph.__class__( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
new graph 的 name 是什么
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
new graph 的 name 是什么
是自动生成的,跟之前的生成方式一样
|
||
self._cache = LRUCache(self._cache_size) | ||
for _, sub_state_dict in sorted(graph_dict.items()): | ||
cache_key = sub_state_dict["cache_key"] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里是不是不关心 graph name,只关心 shape hash
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里是不是不关心 graph name,只关心 shape hash
是的,完全根据输入 shape 来作为缓存 key,因为一个 graph 实例对应一个缓存,缓存里面存了该 graph 实例的多个 shape 的子 graph。
…c/oneflow into feat/graph_with_cache
Speed stats:
|
@@ -103,7 +104,7 @@ class Graph(object): | |||
""" | |||
_child_init_cnt = dict() | |||
|
|||
def __init__(self): | |||
def __init__(self, enable_get_runtime_state_dict: bool = False): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已经从 enable_runtime_state_dict 改为 enable_get_runtime_state_dict
Speed stats:
|
CI failed when running job: cuda-module. PR label automerge has been removed |
https://github.com/Oneflow-Inc/oneflow/actions/runs/4415299017/jobs/7741189188
https://github.com/Oneflow-Inc/oneflow/actions/runs/4416988939/jobs/7742090625
无关功能报错 |
CI failed when running job: cuda-module. PR label automerge has been removed |
Speed stats:
|
Speed stats:
|
View latest API docs preview at: https://staging.oneflow.info/docs/Oneflow-Inc/oneflow/pr/9984/ |
Depends on: Oneflow-Inc/oneflow#9984 --------- Co-authored-by: jackalcooper <jackalcooper@gmail.com>
nn.Graph run with dynamic input shape by the cache.
define graph (with dynamic input shape)
init and call graph (with dynamic input shape)
config cache size
Save and load nn.Graph with dynamic input shape
save
load