
nn.Graph with_dynamic_input_shape #9984

Merged
merged 10 commits into master on Mar 14, 2023
Conversation

strint (Contributor) commented Mar 13, 2023

nn.Graph can now run with dynamic input shapes, backed by a graph cache.

define graph (with dynamic input shape)

```python
import oneflow as flow

class LinearGraph(flow.nn.Graph):
    # [New] use this decorator to enable dynamic input shape
    @flow.nn.Graph.with_dynamic_input_shape()
    def __init__(self):
        super().__init__()
        self.my_linear = flow.nn.Linear(3, 8, False)

    def build(self, x):
        return self.my_linear(x)
```

init and call graph (with dynamic input shape)

```python
linear_g = LinearGraph()

# x_with_shape0 and x_with_shape1 can have different shapes
linear_g(x_with_shape0)
linear_g(x_with_shape1)
```

config cache size

```python
import oneflow as flow

class LinearGraph(flow.nn.Graph):
    # [New] config the cache size with the size parameter
    @flow.nn.Graph.with_dynamic_input_shape(size=8)
    def __init__(self):
        super().__init__()
        self.my_linear = flow.nn.Linear(3, 8, False)

    def build(self, x):
        return self.my_linear(x)
```

Save and load nn.Graph with dynamic input shape

save

```python
class LinearGraph(flow.nn.Graph):
    @flow.nn.Graph.with_dynamic_input_shape()
    def __init__(self):
        # [New] enable getting the runtime_state_dict
        super().__init__(enable_get_runtime_state_dict=True)
        self.my_linear = flow.nn.Linear(3, 8, False)

    def build(self, x):
        return self.my_linear(x)

linear_g = LinearGraph()

# call graph
linear_g(x_with_shape0)
linear_g(x_with_shape1)

# [New] save graph (with dynamic input shape)
state_dict = linear_g.runtime_state_dict()
flow.save(state_dict, filename)
```

load

```python
linear_g = LinearGraph()

state_dict = flow.load(filename)
# [New] load graph (with dynamic input shape)
linear_g.load_runtime_state_dict(state_dict)

# call graph
linear_g(x_with_shape0)
linear_g(x_with_shape1)
```

@strint marked this pull request as ready for review March 14, 2023 01:45
@strint requested review from BBuf and daquexian as code owners March 14, 2023 01:45
@strint added the graph (graph mode) and feature labels Mar 14, 2023
@strint requested a review from oneflow-ci-bot March 14, 2023 01:54
```python
        return self.my_linear(x)

linear_g = LinearGraph()
linear_g.enable_save_runtime_state_dict()
```
Collaborator:

Should this be made a keyword argument?

Contributor Author:

Where would the parameter be passed in? The graph's `__init__` function?

Collaborator:

Yes, something like LinearGraph(enable_save_runtime_state_dict=True)? Because this config: 1) does not keep changing state, so there is no need for a function interface; 2) has constraints on when it may be called, and a function interface would just add an extra layer of worrying about call timing.

@github-actions
Contributor

Speed stats:
GPU Name: GeForce GTX 1080 

❌ OneFlow resnet50 time: 141.2ms (= 14121.0ms / 100, input_shape=[16, 3, 224, 224])
PyTorch resnet50 time: 143.7ms (= 14368.5ms / 100, input_shape=[16, 3, 224, 224])
❌ Relative speed: 1.02 (= 143.7ms / 141.2ms)

OneFlow resnet50 time: 82.7ms (= 8273.9ms / 100, input_shape=[8, 3, 224, 224])
PyTorch resnet50 time: 87.6ms (= 8761.1ms / 100, input_shape=[8, 3, 224, 224])
✔️ Relative speed: 1.06 (= 87.6ms / 82.7ms)

OneFlow resnet50 time: 51.2ms (= 10230.4ms / 200, input_shape=[4, 3, 224, 224])
PyTorch resnet50 time: 57.7ms (= 11549.0ms / 200, input_shape=[4, 3, 224, 224])
✔️ Relative speed: 1.13 (= 57.7ms / 51.2ms)

OneFlow resnet50 time: 33.9ms (= 6781.6ms / 200, input_shape=[2, 3, 224, 224])
PyTorch resnet50 time: 45.9ms (= 9171.0ms / 200, input_shape=[2, 3, 224, 224])
✔️ Relative speed: 1.35 (= 45.9ms / 33.9ms)

OneFlow resnet50 time: 26.5ms (= 5302.8ms / 200, input_shape=[1, 3, 224, 224])
PyTorch resnet50 time: 43.0ms (= 8608.4ms / 200, input_shape=[1, 3, 224, 224])
✔️ Relative speed: 1.62 (= 43.0ms / 26.5ms)

OneFlow swin dataloader time: 0.239s (= 47.838s / 200, num_workers=1)
PyTorch swin dataloader time: 0.150s (= 30.021s / 200, num_workers=1)
Relative speed: 0.628 (= 0.150s / 0.239s)

OneFlow swin dataloader time: 0.068s (= 13.647s / 200, num_workers=4)
PyTorch swin dataloader time: 0.042s (= 8.343s / 200, num_workers=4)
Relative speed: 0.611 (= 0.042s / 0.068s)

OneFlow swin dataloader time: 0.045s (= 9.074s / 200, num_workers=8)
PyTorch swin dataloader time: 0.023s (= 4.550s / 200, num_workers=8)
Relative speed: 0.501 (= 0.023s / 0.045s)

❌ OneFlow resnet50 time: 153.0ms (= 15296.7ms / 100, input_shape=[16, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 166.8ms (= 16675.8ms / 100, input_shape=[16, 3, 224, 224], ddp, world size=2)
❌ Relative speed: 1.09 (= 166.8ms / 153.0ms)

OneFlow resnet50 time: 92.9ms (= 9286.1ms / 100, input_shape=[8, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 103.8ms (= 10382.6ms / 100, input_shape=[8, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.12 (= 103.8ms / 92.9ms)

OneFlow resnet50 time: 60.3ms (= 12066.0ms / 200, input_shape=[4, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 80.1ms (= 16020.6ms / 200, input_shape=[4, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.33 (= 80.1ms / 60.3ms)

OneFlow resnet50 time: 43.7ms (= 8732.0ms / 200, input_shape=[2, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 72.3ms (= 14451.0ms / 200, input_shape=[2, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.65 (= 72.3ms / 43.7ms)

OneFlow resnet50 time: 36.6ms (= 7326.9ms / 200, input_shape=[1, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 75.1ms (= 15013.1ms / 200, input_shape=[1, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 2.05 (= 75.1ms / 36.6ms)

```python
class LinearGraph(flow.nn.Graph):
    @flow.nn.Graph.with_dynamic_input_shape
    def __init__(self):
        super().__init__(enable_runtime_state_dict=True)
```
Contributor Author:

enable_runtime_state_dict with kwargs

```python
warnings.warn(
    f"nn.Graph {self._name} WARNING: current oneflow version ({oneflow.__version__}) is loading "
    f"runtime_state_dict from a different version ({state_dict['oneflow_version']}), "
    "there may be compatibility problems."
)
```
Contributor Author:

nn.Graph LinearGraph_1 WARNING: current oneflow version (0.9.1+cu112.git.62e58c9390) is loading runtime_state_dict from a different version (0.9.1+cu112.git.92e58c939), there may be compatibility problems.
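The compatibility guard above can be sketched in plain Python. This is a minimal illustration; the helper name `check_version_compatibility` and its signature are assumptions, not OneFlow's actual internals:

```python
import warnings

def check_version_compatibility(current_version, saved_version, graph_name):
    """Warn when a runtime_state_dict saved under one OneFlow version is
    loaded under another; return True only on an exact version match."""
    if current_version == saved_version:
        return True
    warnings.warn(
        f"nn.Graph {graph_name} WARNING: current oneflow version ({current_version}) "
        f"is loading runtime_state_dict from a different version ({saved_version}), "
        "there may be compatibility problems."
    )
    return False
```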

@github-actions
Contributor

Code got formatted by CI. Please request CI again if you still want to have this PR merged. If the PR is from a forked repo, please download the patch files from the GitHub Actions web page and apply them locally.

```python
linear_reshape = LinearReshapeModule(linear, with_reshape)

class LinearGraph(flow.nn.Graph):
    @flow.nn.Graph.with_dynamic_input_shape(size=4)
```
Contributor Author:

Supports setting the cache size.

```
@@ -169,6 +170,10 @@ def __init__(self):

     # For load graph from runtime states.
     self._enable_save_runtime_state_dict = False
+    self.enable_save_runtime_state_dict(enable_runtime_state_dict)
```
Contributor:

If it is initialized here, isn't line 172 redundant?

Contributor Author:

Yes, it should be deleted.

```
@@ -103,7 +104,7 @@ class Graph(object):
     """
     _child_init_cnt = dict()

-    def __init__(self):
+    def __init__(self, enable_runtime_state_dict: bool = False):
```
Contributor:

At first glance I didn't understand what the enable_runtime_state_dict parameter means. Should it be named:

save_runtime_state_dict

Contributor Author:

This interface decides whether graph.runtime_state_dict can be called, because generating the runtime_state_dict saves some extra things.

```python
destination = OrderedDict()
destination._metadata = OrderedDict()

for (key, graph) in self._cache.items():
```
Contributor:

Is the key here the input shape list?

When saving the runtime state dict, if there are multiple graphs, are they all written into one file?

Contributor Author:

The key is obtained by DFS-traversing the input tensors, collecting their shapes into a tuple, and hashing that tuple.

If there are multiple graphs in the cache, they are put into one dict, with each graph's name as the key and its runtime_state_dict as the value. It is best to save them into one file with flow.save(state_dict, file_name).
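That keying scheme can be sketched in plain Python. The function name `shape_cache_key` and the duck-typed tensor leaves are assumptions for illustration; OneFlow's real implementation walks its own tensor types:

```python
def shape_cache_key(inputs):
    """Build a cache key by DFS-traversing nested inputs, collecting every
    tensor-like leaf's shape into a flat tuple, then hashing that tuple."""
    shapes = []

    def visit(obj):
        if hasattr(obj, "shape"):          # tensor-like leaf
            shapes.append(tuple(obj.shape))
        elif isinstance(obj, (list, tuple)):
            for item in obj:
                visit(item)
        elif isinstance(obj, dict):
            for key in sorted(obj):        # deterministic traversal order
                visit(obj[key])

    visit(inputs)
    return hash(tuple(shapes))
```

Two calls with the same nested shapes hash to the same key and hit the same cached sub-graph; any shape change produces a new key.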

```python
    graph = self._base_graph
else:
    # Create new graph from base
    graph = self._base_graph.__class__(
```
Contributor:

What is the name of the new graph?

Contributor Author:

> What is the name of the new graph?

It is auto-generated, in the same way as before.


```python
self._cache = LRUCache(self._cache_size)
for _, sub_state_dict in sorted(graph_dict.items()):
    cache_key = sub_state_dict["cache_key"]
```
Contributor:

So here we don't care about the graph name, only the shape hash?

Contributor Author:

> So here we don't care about the graph name, only the shape hash?

Yes, the input shapes alone serve as the cache key, because one graph instance corresponds to one cache, and that cache stores the instance's sub-graphs for multiple shapes.
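So each graph instance owns an LRU cache of shape-keyed sub-graphs, and the `size` argument of `with_dynamic_input_shape(size=...)` bounds how many are kept. A minimal sketch of such a cache (the class name matches the `LRUCache` seen in the diff, but the method names here are assumptions):

```python
from collections import OrderedDict

class LRUCache:
    """Least-recently-used cache: when full, evict the entry touched
    longest ago, so `size` bounds how many compiled sub-graphs
    (one per distinct input shape) stay alive."""

    def __init__(self, size):
        self._size = size
        self._data = OrderedDict()

    def get(self, key):
        if key not in self._data:
            return None
        self._data.move_to_end(key)         # mark as most recently used
        return self._data[key]

    def put(self, key, value):
        if key in self._data:
            self._data.move_to_end(key)
        self._data[key] = value
        if len(self._data) > self._size:
            self._data.popitem(last=False)  # evict least recently used
```

With `size=2`, inserting a third distinct shape evicts the least recently used sub-graph, which would then need recompiling on its next call.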

@github-actions
Contributor

Speed stats:

```
@@ -103,7 +104,7 @@ class Graph(object):
     """
     _child_init_cnt = dict()

-    def __init__(self):
+    def __init__(self, enable_get_runtime_state_dict: bool = False):
```
Contributor Author:

Renamed from enable_runtime_state_dict to enable_get_runtime_state_dict.

@github-actions
Contributor

Speed stats:
GPU Name: GeForce GTX 1080 

❌ OneFlow resnet50 time: 140.8ms (= 14080.3ms / 100, input_shape=[16, 3, 224, 224])
PyTorch resnet50 time: 142.5ms (= 14249.9ms / 100, input_shape=[16, 3, 224, 224])
❌ Relative speed: 1.01 (= 142.5ms / 140.8ms)

OneFlow resnet50 time: 80.6ms (= 8059.4ms / 100, input_shape=[8, 3, 224, 224])
PyTorch resnet50 time: 85.5ms (= 8547.7ms / 100, input_shape=[8, 3, 224, 224])
✔️ Relative speed: 1.06 (= 85.5ms / 80.6ms)

OneFlow resnet50 time: 48.8ms (= 9768.9ms / 200, input_shape=[4, 3, 224, 224])
PyTorch resnet50 time: 58.8ms (= 11750.6ms / 200, input_shape=[4, 3, 224, 224])
✔️ Relative speed: 1.20 (= 58.8ms / 48.8ms)

OneFlow resnet50 time: 32.6ms (= 6517.7ms / 200, input_shape=[2, 3, 224, 224])
PyTorch resnet50 time: 46.9ms (= 9383.9ms / 200, input_shape=[2, 3, 224, 224])
✔️ Relative speed: 1.44 (= 46.9ms / 32.6ms)

OneFlow resnet50 time: 25.6ms (= 5111.3ms / 200, input_shape=[1, 3, 224, 224])
PyTorch resnet50 time: 37.2ms (= 7445.1ms / 200, input_shape=[1, 3, 224, 224])
✔️ Relative speed: 1.46 (= 37.2ms / 25.6ms)

OneFlow swin dataloader time: 0.241s (= 48.210s / 200, num_workers=1)
PyTorch swin dataloader time: 0.151s (= 30.111s / 200, num_workers=1)
Relative speed: 0.625 (= 0.151s / 0.241s)

OneFlow swin dataloader time: 0.072s (= 14.404s / 200, num_workers=4)
PyTorch swin dataloader time: 0.043s (= 8.646s / 200, num_workers=4)
Relative speed: 0.600 (= 0.043s / 0.072s)

OneFlow swin dataloader time: 0.040s (= 7.939s / 200, num_workers=8)
PyTorch swin dataloader time: 0.022s (= 4.419s / 200, num_workers=8)
Relative speed: 0.557 (= 0.022s / 0.040s)

❌ OneFlow resnet50 time: 152.2ms (= 15221.3ms / 100, input_shape=[16, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 161.8ms (= 16184.0ms / 100, input_shape=[16, 3, 224, 224], ddp, world size=2)
❌ Relative speed: 1.06 (= 161.8ms / 152.2ms)

OneFlow resnet50 time: 91.0ms (= 9102.3ms / 100, input_shape=[8, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 103.7ms (= 10373.9ms / 100, input_shape=[8, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.14 (= 103.7ms / 91.0ms)

OneFlow resnet50 time: 59.4ms (= 11877.1ms / 200, input_shape=[4, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 80.9ms (= 16186.5ms / 200, input_shape=[4, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.36 (= 80.9ms / 59.4ms)

OneFlow resnet50 time: 42.7ms (= 8549.7ms / 200, input_shape=[2, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 68.3ms (= 13665.5ms / 200, input_shape=[2, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.60 (= 68.3ms / 42.7ms)

OneFlow resnet50 time: 35.8ms (= 7152.9ms / 200, input_shape=[1, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 67.4ms (= 13482.1ms / 200, input_shape=[1, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.88 (= 67.4ms / 35.8ms)

@github-actions
Contributor

CI failed when running job: cuda-module. PR label automerge has been removed

@strint
Contributor Author

strint commented Mar 14, 2023

https://github.com/Oneflow-Inc/oneflow/actions/runs/4415299017/jobs/7741189188

FAILED python/oneflow/test/modules/test_loss.py::TestCrossEntropyLossModule::test_cross_entropy_prob_loss_with_random_data_dim_2

https://github.com/Oneflow-Inc/oneflow/actions/runs/4416988939/jobs/7742090625

FAILED python/oneflow/test/modules/test_ctc_loss.py::TestCTCLoss1n1d::test_ctc_loss_functional

These failures are in unrelated functionality.

@strint requested review from oneflow-ci-bot and removed request for oneflow-ci-bot March 14, 2023 14:45
@github-actions
Contributor

CI failed when running job: cuda-module. PR label automerge has been removed

@github-actions
Contributor

Speed stats:

@strint requested review from oneflow-ci-bot and removed request for oneflow-ci-bot March 14, 2023 14:58
@github-actions
Contributor

Speed stats:
GPU Name: GeForce GTX 1080 

❌ OneFlow resnet50 time: 141.0ms (= 14102.3ms / 100, input_shape=[16, 3, 224, 224])
PyTorch resnet50 time: 143.1ms (= 14306.2ms / 100, input_shape=[16, 3, 224, 224])
❌ Relative speed: 1.01 (= 143.1ms / 141.0ms)

OneFlow resnet50 time: 80.7ms (= 8074.5ms / 100, input_shape=[8, 3, 224, 224])
PyTorch resnet50 time: 83.5ms (= 8350.7ms / 100, input_shape=[8, 3, 224, 224])
❌ Relative speed: 1.03 (= 83.5ms / 80.7ms)

OneFlow resnet50 time: 48.7ms (= 9739.3ms / 200, input_shape=[4, 3, 224, 224])
PyTorch resnet50 time: 57.6ms (= 11529.8ms / 200, input_shape=[4, 3, 224, 224])
✔️ Relative speed: 1.18 (= 57.6ms / 48.7ms)

OneFlow resnet50 time: 33.1ms (= 6611.4ms / 200, input_shape=[2, 3, 224, 224])
PyTorch resnet50 time: 50.4ms (= 10078.0ms / 200, input_shape=[2, 3, 224, 224])
✔️ Relative speed: 1.52 (= 50.4ms / 33.1ms)

OneFlow resnet50 time: 25.1ms (= 5015.4ms / 200, input_shape=[1, 3, 224, 224])
PyTorch resnet50 time: 43.4ms (= 8689.4ms / 200, input_shape=[1, 3, 224, 224])
✔️ Relative speed: 1.73 (= 43.4ms / 25.1ms)

OneFlow swin dataloader time: 0.255s (= 50.991s / 200, num_workers=1)
PyTorch swin dataloader time: 0.150s (= 30.057s / 200, num_workers=1)
Relative speed: 0.589 (= 0.150s / 0.255s)

OneFlow swin dataloader time: 0.069s (= 13.856s / 200, num_workers=4)
PyTorch swin dataloader time: 0.045s (= 9.068s / 200, num_workers=4)
Relative speed: 0.654 (= 0.045s / 0.069s)

OneFlow swin dataloader time: 0.046s (= 9.239s / 200, num_workers=8)
PyTorch swin dataloader time: 0.022s (= 4.429s / 200, num_workers=8)
Relative speed: 0.479 (= 0.022s / 0.046s)

❌ OneFlow resnet50 time: 152.7ms (= 15272.2ms / 100, input_shape=[16, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 162.2ms (= 16220.5ms / 100, input_shape=[16, 3, 224, 224], ddp, world size=2)
❌ Relative speed: 1.06 (= 162.2ms / 152.7ms)

OneFlow resnet50 time: 90.8ms (= 9080.3ms / 100, input_shape=[8, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 104.3ms (= 10433.1ms / 100, input_shape=[8, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.15 (= 104.3ms / 90.8ms)

OneFlow resnet50 time: 58.6ms (= 11726.4ms / 200, input_shape=[4, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 78.3ms (= 15658.8ms / 200, input_shape=[4, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.34 (= 78.3ms / 58.6ms)

OneFlow resnet50 time: 41.8ms (= 8358.1ms / 200, input_shape=[2, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 79.0ms (= 15791.2ms / 200, input_shape=[2, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.89 (= 79.0ms / 41.8ms)

OneFlow resnet50 time: 36.5ms (= 7306.5ms / 200, input_shape=[1, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 67.7ms (= 13546.9ms / 200, input_shape=[1, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.85 (= 67.7ms / 36.5ms)

@github-actions
Contributor

View latest API docs preview at: https://staging.oneflow.info/docs/Oneflow-Inc/oneflow/pr/9984/

@strint requested review from oneflow-ci-bot and removed request for oneflow-ci-bot March 14, 2023 16:07
@mergify bot merged commit cabf428 into master Mar 14, 2023
@mergify bot deleted the feat/graph_with_cache branch March 14, 2023 17:48
jackalcooper added a commit to siliconflow/onediff that referenced this pull request Mar 20, 2023
Depends on: Oneflow-Inc/oneflow#9984

---------

Co-authored-by: jackalcooper <jackalcooper@gmail.com>