
Pipeline parallelism failure report: header_size == rhs->blob_desc().ByteSizeOfBlobHeader() check failed #6226

Closed · doombeaker opened this issue Sep 10, 2021 · 4 comments · Fixed by #6240

@doombeaker (Contributor)

Reproduction code: after 200 iterations it fails; with a smaller iteration count (e.g. 100) it does not. Suspected to be related to memory allocation and reclamation.

import oneflow as flow
import oneflow.nn as nn

BROADCAST = [flow.sbp.broadcast]
P0 = flow.placement("cuda", {0: [0]})  # stage 0 on rank 0's GPU
P1 = flow.placement("cuda", {0: [1]})  # stage 1 on rank 1's GPU

class Stage0Module(flow.nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = flow.nn.Flatten()
        self.linear0 = flow.nn.Linear(28*28, 512)
        self.relu0 = flow.nn.ReLU()

        self.flatten.to_consistent(placement=P0, sbp=BROADCAST)
        self.linear0.to_consistent(placement=P0, sbp=BROADCAST)
        self.relu0.to_consistent(placement=P0, sbp=BROADCAST)

    def forward(self, x):
        out = self.flatten(x)
        out = self.linear0(out)
        out = self.relu0(out)
        return out

class Stage1Module(flow.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = flow.nn.Linear(512, 512)
        self.relu1 = flow.nn.ReLU()
        self.linear2 = flow.nn.Linear(512, 10)
        self.relu2 = flow.nn.ReLU()
        
        self.linear1.to_consistent(placement=P1, sbp=BROADCAST)
        self.relu1.to_consistent(placement=P1, sbp=BROADCAST)
        self.linear2.to_consistent(placement=P1, sbp=BROADCAST)
        self.relu2.to_consistent(placement=P1, sbp=BROADCAST)

    def forward(self, x):
        # Hand off the stage-0 activations to stage 1's placement, keeping the same SBP.
        x = x.to_consistent(placement=P1, sbp=x.sbp)
        out = self.linear1(x)
        out = self.relu1(out)
        out = self.linear2(out)
        out = self.relu2(out)
        return out

class PipelineModule(flow.nn.Module):
    def __init__(self):
        super().__init__()
        self.m_stage0 = Stage0Module()
        self.m_stage1 = Stage1Module()

    def forward(self, x):
        out_stage0 = self.m_stage0(x)
        out_stage1 = self.m_stage1(out_stage0)
        return out_stage1

module_pipeline = PipelineModule()

class PipelineGraph(flow.nn.Graph):
    def __init__(self):
        super().__init__()
        self.module_pipeline = module_pipeline
        self.module_pipeline.m_stage0.config.stage_id = 0
        self.module_pipeline.m_stage1.config.stage_id = 1
        # 2 gradient accumulation steps = 2 micro-batches per step for pipelining
        self.config.set_gradient_accumulation_steps(2)

    def build(self, x):
        self.module_pipeline.eval()  # inference only: no backward pass or optimizer
        out = self.module_pipeline(x)
        return out

graph_pipeline = PipelineGraph()

for i in range(200):
    x = flow.randn(16, 1, 28, 28)
    x = x.to_consistent(placement=P0, sbp=BROADCAST)
    loss = graph_pipeline(x)
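
(Per the tracebacks below, the script was saved as ./pipeline_test.py and launched on two GPUs with OneFlow's distributed launcher, e.g. python3 -m oneflow.distributed.launch --nproc_per_node 2 ./pipeline_test.py.)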

Error message:

E0910 11:38:04.811247792 2671012 socket_utils_common_posix.cc:222] check for SO_REUSEPORT: {"created":"@1631245084.811230253","description":"SO_REUSEPORT unavailable on compiling system","file":"/home/yaochi/oneflow/build/grpc/src/grpc/src/core/lib/iomgr/socket_utils_common_posix.cc","file_line":190}
E0910 11:38:04.829096761 2671011 socket_utils_common_posix.cc:222] check for SO_REUSEPORT: {"created":"@1631245084.829077605","description":"SO_REUSEPORT unavailable on compiling system","file":"/home/yaochi/oneflow/build/grpc/src/grpc/src/core/lib/iomgr/socket_utils_common_posix.cc","file_line":190}
F0910 11:38:07.040753 2671648 blob.cpp:62] Check failed: header_size == rhs->blob_desc().ByteSizeOfBlobHeader() (32 vs. 16) 
*** Check failure stack trace: ***
    @     0x7fc683fa12ad  google::LogMessage::Fail()
    @     0x7fc683fa2b3d  google::LogMessage::SendToLog()
    @     0x7fc683fa0d7d  google::LogMessage::Flush()
    @     0x7fc683fa4799  google::LogMessageFatal::~LogMessageFatal()
    @     0x7fc68109aed1  oneflow::Blob::CopyHeaderFrom()
    @     0x7fc6807f9e4c  (unknown)
    @     0x7fc6807f9aaa  (unknown)
    @     0x7fc680eb32ed  (unknown)
    @     0x7fc680eb4297  oneflow::Kernel::Forward()
    @     0x7fc680eb44b9  oneflow::Kernel::Launch()
    @     0x7fc6805682f0  oneflow::Actor::AsyncLaunchKernel()
    @     0x7fc6805c6c0b  oneflow::NaiveActor::Act()
    @     0x7fc68056b80c  oneflow::Actor::ActUntilFail()
    @     0x7fc68056b9f5  oneflow::Actor::HandlerNormal()
    @     0x7fc6810ba9a8  oneflow::Thread::PollMsgChannel()
    @     0x7fc6810badaf  (unknown)
    @     0x7fc6799bb19d  execute_native_thread_routine
    @     0x7fc6d66ae609  start_thread
    @     0x7fc6d65d5293  clone
    @              (nil)  (unknown)
Killing subprocess 2671011
Killing subprocess 2671012
Traceback (most recent call last):
  File "/home/yaochi/anaconda3/envs/oneflow-dev-gcc7/lib/python3.6/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/home/yaochi/anaconda3/envs/oneflow-dev-gcc7/lib/python3.6/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/home/yaochi/oneflow/python/oneflow/distributed/launch.py", line 211, in <module>
    main()
  File "/home/yaochi/oneflow/python/oneflow/distributed/launch.py", line 199, in main
    sigkill_handler(signal.SIGTERM, None)
  File "/home/yaochi/oneflow/python/oneflow/distributed/launch.py", line 168, in sigkill_handler
    returncode=last_return_code, cmd=cmd
subprocess.CalledProcessError: Command '['/home/yaochi/anaconda3/envs/oneflow-dev-gcc7/bin/python3', '-u', './pipeline_test.py']' died with <Signals.SIGABRT: 6>.
@chengtbf (Contributor)

This error is strange. But since your script runs in eval mode (no loss backward and no Optimizer), the configured set_gradient_accumulation_steps and stage id actually have no effect 😂 The relevant job passes bail out when the job has no train conf:

if (!job_conf.has_train_conf()) { return Maybe<void>::Ok(); }

bool IsEnabled(const JobPassCtx& ctx) const { return ctx.job_desc().IsTrain(); }

The exact root cause of the error still needs to be investigated.
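
(For reference, a minimal hedged sketch of what a training version of the graph above would look like, so the stage ids and gradient accumulation actually apply; it builds on the repro script's module_pipeline, P1, and BROADCAST, and the loss/label placement is illustrative rather than a verified fix:)

class PipelineTrainGraph(flow.nn.Graph):
    def __init__(self):
        super().__init__()
        self.module_pipeline = module_pipeline
        self.module_pipeline.m_stage0.config.stage_id = 0
        self.module_pipeline.m_stage1.config.stage_id = 1
        self.loss_fn = flow.nn.CrossEntropyLoss()
        self.config.set_gradient_accumulation_steps(2)
        # With an optimizer attached, the job has a train conf, so the
        # passes quoted above are no longer skipped.
        self.add_optimizer(flow.optim.SGD(module_pipeline.parameters(), lr=0.001))

    def build(self, x, y):
        out = self.module_pipeline(x)
        # Labels must live on the last stage's placement (illustrative).
        loss = self.loss_fn(out, y.to_consistent(placement=P1, sbp=BROADCAST))
        loss.backward()
        return loss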

@chengtbf (Contributor)

chengcheng@oneflow-21:~/debug $ ONEFLOW_TEST_DEVICE_NUM=2 python3 -m oneflow.distributed.launch --nproc_per_node 2 test_simple.py 
graph input x.shape= oneflow.Size([16, 1, 28, 28])
graph input x.shape= oneflow.Size([16, 1, 28, 28])
cclog: in PushCb, input tensor blob: header_size = 32 header_shape = (16,1,28,28) aligned header size = 64 
 input regst blob: header_size = 32 header_shape = (16,1,28,28): aligned header size = 64
cclog: in PushCb, input tensor blob: header_size = 32 header_shape = (16,1,28,28) aligned header size = 64 
 input regst blob: header_size = 32 header_shape = (16,1,28,28): aligned header size = 64
cclog: in PushCb, input tensor blob: header_size = 32 header_shape = (16,1,28,28) aligned header size = 64 
 input regst blob: header_size = 32 header_shape = (16,1,28,28): aligned header size = 64
cclog: in PushCb, input tensor blob: header_size = 32 header_shape = (16,1,28,28) aligned header size = 64 
 input regst blob: header_size = 32 header_shape = (16,1,28,28): aligned header size = 64
cclog: in PushCb, input tensor blob: header_size = 32 header_shape = (16,1,28,28) aligned header size = 64 
 input regst blob: header_size = 32 header_shape = (16,1,28,28): aligned header size = 64
cclog: in PushCb, input tensor blob: header_size = 32 header_shape = (16,1,28,28) aligned header size = 64 
 input regst blob: header_size = 32 header_shape = (16,1,28,28): aligned header size = 64
cclog: in PushCb, input tensor blob: header_size = 8 header_shape = (0,) aligned header size = 64 
 input regst blob: header_size = 32 header_shape = (16,1,28,28): aligned header size = 64
F0910 15:31:40.086340 3750611 blob.cpp:62] Check failed: header_size == rhs->blob_desc().ByteSizeOfBlobHeader() (32 vs. 8) 
*** Check failure stack trace: ***
    @     0x7fe94e5ba353  google::LogMessage::Fail()
    @     0x7fe94e5bf0fb  google::LogMessage::SendToLog()
    @     0x7fe94e5ba04f  google::LogMessage::Flush()
    @     0x7fe94e5ba87f  google::LogMessageFatal::~LogMessageFatal()
    @     0x7fe949d44669  oneflow::Blob::CopyHeaderFrom()
    @     0x7fe949491200  _ZNSt17_Function_handlerIFvlEZNK7oneflow2vm25RunLazyJobInstructionType15MakeJobInstanceEPNS2_11InstructionEEUllE_E9_M_invokeERKSt9_Any_dataOl
    @     0x7fe949490c50  oneflow::(anonymous namespace)::LazyJobInstance::PushBlobByOpName()
    @     0x7fe949b5253d  oneflow::(anonymous namespace)::InputKernel<>::ForwardDataContent()
    @     0x7fe949b52b3c  oneflow::Kernel::Forward()
    @     0x7fe949b52bfe  oneflow::Kernel::Launch()
    @     0x7fe94912877e  oneflow::Actor::AsyncLaunchKernel()
    @     0x7fe949193d0e  oneflow::NaiveActor::Act()
    @     0x7fe94912807e  oneflow::Actor::ActUntilFail()
    @     0x7fe949128ddb  oneflow::Actor::HandlerNormal()
    @     0x7fe949d60ec7  oneflow::Thread::PollMsgChannel()
    @     0x7fe949d61317  _ZNSt6thread11_State_implINS_8_InvokerISt5tupleIJZN7oneflow6ThreadC4ERKNS3_8StreamIdEEUlvE_EEEEE6_M_runEv
    @     0x7fe94415fde4  (unknown)
    @     0x7fe98cc0c609  start_thread
    @     0x7fe98cd48293  clone
    @              (nil)  (unknown)
Killing subprocess 3750257
Killing subprocess 3750258
Traceback (most recent call last):
  File "/usr/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/home/chengcheng/oneflow/python/oneflow/distributed/launch.py", line 211, in <module>
    main()
  File "/home/chengcheng/oneflow/python/oneflow/distributed/launch.py", line 199, in main
    sigkill_handler(signal.SIGTERM, None)
  File "/home/chengcheng/oneflow/python/oneflow/distributed/launch.py", line 167, in sigkill_handler
    raise subprocess.CalledProcessError(
subprocess.CalledProcessError: Command '['/usr/bin/python3', '-u', 'test_simple.py']' died with <Signals.SIGABRT: 6>.

Adding logging to the input push callback shows that the input data occasionally comes through as all zeros (the last PushCb line above has header_shape = (0,)), and copying that input blob is what triggers the check failure.
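
(The mismatched sizes in these logs line up with a simple accounting sketch; the assumption that the blob header stores the shape as one int64 per axis is mine, for illustration. 4 axes give 32 bytes and the degenerate 1-axis shape (0,) gives 8, matching the (32 vs. 8) failure here; the (32 vs. 16) in the original report would correspond to a 2-axis header.)

import struct

# Assumption for illustration: the blob header encodes the shape as one int64 per axis.
def header_bytes(shape):
    return len(shape) * struct.calcsize("q")  # 8 bytes per axis

print(header_bytes((16, 1, 28, 28)))  # 32: the expected header size in the log
print(header_bytes((0,)))             # 8: the size seen in the failing PushCb line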

@chengtbf (Contributor)

As long as I print each x tensor on the Python side, this problem goes away. But then it deadlocks, seemingly stuck in the push callback's buffer.

@chengtbf (Contributor)

buffer_mgr->Get(GetInputBufferName(job_name, op_name))->Send(job_instance);

This sends a job instance to the buffer of every input and output op, whether or not that input has a component on the current rank. But an op with no input component on this rank also never creates an input kernel to pop the job instance back out of the buffer, so once that buffer fills up, the Send blocks and everything hangs.
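
(A minimal Python sketch of that failure mode; names and structure are hypothetical stand-ins, not OneFlow's actual buffer code. Every op's buffer gets a send per step, but only ops with a kernel on this rank ever drain theirs, so the bounded buffer of a rank-less input eventually fills and blocks the sender:)

import queue
import threading

CAPACITY = 2  # bounded buffer, like the job-instance buffer: put() blocks when full
buffers = {
    "input_a": queue.Queue(CAPACITY),  # this rank has an input kernel for input_a
    "input_b": queue.Queue(CAPACITY),  # no component on this rank: never drained
}

def input_kernel():
    while True:
        buffers["input_a"].get()  # only input_a's job instances are ever popped

threading.Thread(target=input_kernel, daemon=True).start()

for step in range(10):
    for op_name in buffers:         # Send(job_instance) for *every* input op...
        buffers[op_name].put(step)  # ...so input_b fills up and put() blocks forever
    print("pushed step", step)      # output stops after CAPACITY steps: deadlock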
