Skip to content

Commit

Permalink
Add test_graph_activation_checkpoint.py
Browse files Browse the repository at this point in the history
  • Loading branch information
tingkuanpei committed Sep 7, 2021
1 parent 6b20bd4 commit 9a2b3b0
Show file tree
Hide file tree
Showing 6 changed files with 119 additions and 4 deletions.
2 changes: 2 additions & 0 deletions oneflow/api/python/framework/framework.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ ONEFLOW_API_PYBIND11_MODULE("", m) {
m.def("GetSerializedJobSet", []() { return py::bytes(GetSerializedJobSet()); });
m.def("GetSerializedStructureGraph", &GetSerializedStructureGraph /* a prototxt saved to file*/);
m.def("GetSerializedCurrentJob", []() { return py::bytes(GetSerializedCurrentJob()); });
m.def("GetSerializedJob",
[](const std::string& job_name) { return py::bytes(GetSerializedJob(job_name)); });

m.def("GetFunctionConfigDef", &GetFunctionConfigDef);
m.def("GetScopeConfigDef", &GetScopeConfigDef);
Expand Down
8 changes: 8 additions & 0 deletions oneflow/api/python/framework/framework.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,14 @@ inline Maybe<std::string> GetSerializedCurrentJob() {
return job_ctx->job().SerializeAsString();
}

// Look up the job named `job_name` in the lazy job build/infer context manager
// and return its Job proto serialized to a binary string.
// Returns an error (via Maybe) if the context manager is not initialized or no
// context exists for `job_name`.
inline Maybe<std::string> GetSerializedJob(const std::string& job_name) {
  auto* job_ctx_mgr = Global<LazyJobBuildAndInferCtxMgr>::Get();
  CHECK_NOTNULL_OR_RETURN(job_ctx_mgr);
  auto* job_ctx = JUST(job_ctx_mgr->FindJobBuildAndInferCtx(job_name));
  CHECK_NOTNULL_OR_RETURN(job_ctx);
  return job_ctx->job().SerializeAsString();
}

inline Maybe<std::string> GetFunctionConfigDef() {
std::string ret;
google::protobuf::TextFormat::PrintToString(GlobalFunctionConfigDef(), &ret);
Expand Down
4 changes: 4 additions & 0 deletions oneflow/api/python/framework/framework_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ inline std::string GetSerializedCurrentJob() {
return oneflow::GetSerializedCurrentJob().GetOrThrow();
}

// Python-facing wrapper: return the serialized Job proto for `job_name`,
// converting a failed Maybe into a thrown exception via GetOrThrow().
inline std::string GetSerializedJob(const std::string& job_name) {
  return oneflow::GetSerializedJob(job_name).GetOrThrow();
}

// Python-facing wrapper around oneflow::GetFunctionConfigDef(); throws on error.
inline std::string GetFunctionConfigDef() { return oneflow::GetFunctionConfigDef().GetOrThrow(); }

// Python-facing wrapper around oneflow::GetScopeConfigDef(); throws on error.
inline std::string GetScopeConfigDef() { return oneflow::GetScopeConfigDef().GetOrThrow(); }
Expand Down
7 changes: 7 additions & 0 deletions python/oneflow/framework/c_api_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,3 +250,10 @@ def GetCurrentJob():
ret = job_pb.Job()
ret.ParseFromString(serialized_job)
return ret


def GetJob(job_name):
    """Fetch the job named `job_name` from the C++ runtime and parse it into
    a `job_pb.Job` proto message."""
    job_proto = job_pb.Job()
    serialized = oneflow._oneflow_internal.GetSerializedJob(job_name)
    job_proto.ParseFromString(serialized)
    return job_proto
23 changes: 19 additions & 4 deletions python/oneflow/nn/graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,10 @@ def __init__(self):
self._grad_scaler = None
self._variables_conf = OrderedDict()
self._is_compiled = False
self._job_proto = None
# forward graph job proto
self._forward_job_proto = None
# forward and backward graph job proto
self._full_job_proto = None
self._args_repr = []
self._outs_repr = []
self._debug = False
Expand Down Expand Up @@ -338,7 +341,15 @@ def _optimization_conf_proto(self):

    @property
    def _graph_proto(self):
        # Forward-only job proto, saved in _build_forward_graph via
        # c_api_util.GetCurrentJob(); None until the graph has been built.
        return self._forward_job_proto

@property
def _full_graph_proto(self):
if self._debug and self._full_job_proto is not None:
return self._full_job_proto
else:
print("[ERROR](You can't get full graph when debug mode is disable)")
raise

def _generate_name(self):
child_name = self.__class__.__name__
Expand Down Expand Up @@ -481,8 +492,12 @@ def _build_forward_graph(self, *args):
state_op_names, self._states_tensor_tuple
)

# Save job proto for debug
self._job_proto = c_api_util.GetCurrentJob()
# Save forward job proto for debug
self._forward_job_proto = c_api_util.GetCurrentJob()

# Save forward and backward graph job proto for debug
if self._debug:
self._full_job_proto = c_api_util.GetJob(self.config.proto.job_name())

return list_to_func_return(self._eager_outputs_buffer[0])

Expand Down
79 changes: 79 additions & 0 deletions python/oneflow/test/graph/test_graph_activation_checkpoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import re
import unittest

import numpy as np

import oneflow
import oneflow as flow
import oneflow.framework.graph_build_util as graph_build_util
import oneflow.unittest


class TestGraphActivationCheckpoint(flow.unittest.TestCase):
    def test_activation_checkpoint(test_case):
        """Verify activation checkpointing: the build scope carries the
        `checkpointing` attribute, and flatten gradient ops in the compiled
        full (forward + backward) job consume checkpointing outputs."""
        loss_fn = flow.nn.MSELoss(reduction="sum")
        model = flow.nn.Sequential(flow.nn.Linear(3, 1), flow.nn.Flatten(0, 1))
        optimizer = flow.optim.SGD(model.parameters(), lr=1e-6)

        class SubModule0(flow.nn.Module):
            def __init__(self):
                super().__init__()
                self.model = model

            def forward(self, x):
                # The checkpointing flag must be visible in the build scope.
                scope = oneflow.current_scope()
                scope_proto = graph_build_util.scope_to_proto(scope)
                ck_bool = scope_proto.attr_name2attr_value["checkpointing"].at_bool
                test_case.assertEqual(ck_bool, True)
                out = self.model(x)
                return out

        class LinearTrainGraph(flow.nn.Graph):
            def __init__(self):
                super().__init__()
                self.model = SubModule0()
                self.loss_fn = loss_fn
                # Add an optimizer so a backward graph is generated.
                self.add_optimizer(optimizer)
                self.model.config.activation_checkpointing = True

            def build(self, x, y):
                y_pred = self.model(x)
                loss = self.loss_fn(y_pred, y)
                loss.backward()
                return loss

        linear_graph = LinearTrainGraph()
        # Debug mode is required for _full_graph_proto to be available.
        linear_graph.debug()
        x = flow.randn(10, 3)
        y = flow.randn(10)
        linear_graph._compile(x, y)

        graph_proto = linear_graph._full_graph_proto
        found_flatten_grad = False
        for op in graph_proto.net.op:
            # Check that flatten gradient operators take checkpointing as input.
            if re.search("flatten.*grad", op.name, re.I) is not None:
                found_flatten_grad = True
                find_check_point = False
                for value in op.user_conf.input.values():
                    if re.search("checkpointing", str(value), re.I) is not None:
                        find_check_point = True
                test_case.assertTrue(find_check_point)
        # Guard against a vacuous pass: the original loop asserted nothing
        # when no flatten-grad op was present in the graph at all.
        test_case.assertTrue(found_flatten_grad)


# Run the tests when this file is executed directly.
if __name__ == "__main__":
    unittest.main()

0 comments on commit 9a2b3b0

Please sign in to comment.