Skip to content

Commit

Permalink
Use universal start global device id for all streams (#3701)
Browse files Browse the repository at this point in the history
* universal start global device for one machine

* update machine_id() and device_id()

Co-authored-by: lixinqi <lixinqi0703106@163.com>
Co-authored-by: oneflow-bot <69100618+oneflow-bot@users.noreply.github.com>
  • Loading branch information
3 people committed Oct 20, 2020
1 parent fbc0989 commit 112b786
Show file tree
Hide file tree
Showing 17 changed files with 26 additions and 31 deletions.
1 change: 0 additions & 1 deletion oneflow/core/vm/control_stream_type.cpp
Expand Up @@ -161,7 +161,6 @@ ObjectMsgPtr<StreamDesc> ControlStreamType::MakeStreamDesc(const Resource& resou
ret->set_num_machines(1);
ret->set_num_streams_per_machine(1);
ret->set_num_streams_per_thread(1);
ret->set_start_global_device_id(this_machine_id);
return ret;
}

Expand Down
1 change: 0 additions & 1 deletion oneflow/core/vm/cpu_stream_type.cpp
Expand Up @@ -65,7 +65,6 @@ ObjectMsgPtr<StreamDesc> CpuStreamType::MakeStreamDesc(const Resource& resource,
ret->set_num_machines(1);
ret->set_num_streams_per_machine(device_num);
ret->set_num_streams_per_thread(1);
ret->set_start_global_device_id(this_machine_id * device_num);
return ret;
}

Expand Down
1 change: 0 additions & 1 deletion oneflow/core/vm/cuda_copy_d2h_stream_type.cpp
Expand Up @@ -71,7 +71,6 @@ ObjectMsgPtr<StreamDesc> CudaCopyD2HStreamType::MakeStreamDesc(const Resource& r
ret->set_num_machines(1);
ret->set_num_streams_per_machine(device_num);
ret->set_num_streams_per_thread(1);
ret->set_start_global_device_id(this_machine_id * device_num);
return ret;
}

Expand Down
1 change: 0 additions & 1 deletion oneflow/core/vm/cuda_copy_h2d_stream_type.cpp
Expand Up @@ -65,7 +65,6 @@ ObjectMsgPtr<StreamDesc> CudaCopyH2DStreamType::MakeStreamDesc(const Resource& r
ret->set_num_machines(1);
ret->set_num_streams_per_machine(device_num);
ret->set_num_streams_per_thread(1);
ret->set_start_global_device_id(this_machine_id * device_num);
return ret;
}

Expand Down
1 change: 0 additions & 1 deletion oneflow/core/vm/cuda_stream_type.cpp
Expand Up @@ -71,7 +71,6 @@ ObjectMsgPtr<StreamDesc> CudaStreamType::MakeStreamDesc(const Resource& resource
ret->set_num_machines(1);
ret->set_num_streams_per_machine(device_num);
ret->set_num_streams_per_thread(1);
ret->set_start_global_device_id(this_machine_id * device_num);
return ret;
}

Expand Down
1 change: 0 additions & 1 deletion oneflow/core/vm/device_helper_stream_type.cpp
Expand Up @@ -67,7 +67,6 @@ ObjectMsgPtr<StreamDesc> DeviceHelperStreamType::MakeStreamDesc(const Resource&
ret->set_num_machines(1);
ret->set_num_streams_per_machine(device_num);
ret->set_num_streams_per_thread(1);
ret->set_start_global_device_id(this_machine_id * device_num);
return ret;
}

Expand Down
1 change: 0 additions & 1 deletion oneflow/core/vm/host_stream_type.cpp
Expand Up @@ -59,7 +59,6 @@ ObjectMsgPtr<StreamDesc> HostStreamType::MakeStreamDesc(const Resource& resource
ret->set_num_machines(1);
ret->set_num_streams_per_machine(1);
ret->set_num_streams_per_thread(1);
ret->set_start_global_device_id(this_machine_id);
return ret;
}

Expand Down
1 change: 0 additions & 1 deletion oneflow/core/vm/nop_stream_type.cpp
Expand Up @@ -88,7 +88,6 @@ ObjectMsgPtr<StreamDesc> NopStreamType::MakeStreamDesc(const Resource& resource,
ret->set_num_machines(1);
ret->set_num_streams_per_machine(1);
ret->set_num_streams_per_thread(1);
ret->set_start_global_device_id(this_machine_id);
return ret;
}

Expand Down
9 changes: 3 additions & 6 deletions oneflow/core/vm/object_instruction_type.cpp
Expand Up @@ -86,7 +86,6 @@ class NewObjectInstructionType final : public InstructionType {
FlatMsgView<NewObjectInstruction> view(instr_msg->operand());
const auto& parallel_desc = CHECK_JUST(vm->GetInstructionParallelDesc(*instr_msg));
CHECK(static_cast<bool>(parallel_desc));
const char* device_tag = CHECK_JUST(DeviceTag4DeviceType(parallel_desc->device_type()));
FOR_RANGE(int, i, 0, view->logical_object_id_size()) {
int64_t logical_object_id = GetLogicalObjectId(view->logical_object_id(i));
auto logical_object = ObjectMsgPtr<LogicalObject>::NewFrom(vm->mut_vm_thread_only_allocator(),
Expand All @@ -97,7 +96,7 @@ class NewObjectInstructionType final : public InstructionType {
ForEachMachineIdAndDeviceIdInRange(
*parallel_desc, vm->machine_id_range(), [&](int64_t machine_id, int64_t device_id) {
int64_t global_device_id =
vm->vm_resource_desc().GetGlobalDeviceId(machine_id, device_tag, device_id);
vm->vm_resource_desc().GetGlobalDeviceId(machine_id, device_id);
auto mirrored_object = ObjectMsgPtr<MirroredObject>::NewFrom(
vm->mut_allocator(), logical_object.Mutable(), global_device_id);
CHECK(global_device_id2mirrored_object->Insert(mirrored_object.Mutable()).second);
Expand Down Expand Up @@ -145,7 +144,6 @@ class BroadcastObjectReferenceInstructionType final : public InstructionType {
CHECK_NOTNULL(sole_rw_mutexed_object);
}
CHECK(static_cast<bool>(parallel_desc));
const char* device_tag = CHECK_JUST(DeviceTag4DeviceType(parallel_desc->device_type()));
int64_t new_object = GetLogicalObjectId(args->new_object());
auto logical_object = ObjectMsgPtr<LogicalObject>::NewFrom(vm->mut_vm_thread_only_allocator(),
new_object, parallel_desc);
Expand All @@ -154,7 +152,7 @@ class BroadcastObjectReferenceInstructionType final : public InstructionType {
ForEachMachineIdAndDeviceIdInRange(
*parallel_desc, vm->machine_id_range(), [&](int64_t machine_id, int64_t device_id) {
int64_t global_device_id =
vm->vm_resource_desc().GetGlobalDeviceId(machine_id, device_tag, device_id);
vm->vm_resource_desc().GetGlobalDeviceId(machine_id, device_id);
auto mirrored_object = ObjectMsgPtr<MirroredObject>::NewFrom(
vm->mut_allocator(), logical_object.Mutable(), global_device_id);
mirrored_object->reset_rw_mutexed_object(*sole_rw_mutexed_object);
Expand Down Expand Up @@ -206,12 +204,11 @@ class ReplaceMirroredInstructionType final : public InstructionType {
};
const auto& parallel_desc = CHECK_JUST(vm->GetInstructionParallelDesc(*instr_msg));
CHECK(static_cast<bool>(parallel_desc));
const char* device_tag = CHECK_JUST(DeviceTag4DeviceType(parallel_desc->device_type()));
ForEachMachineIdAndDeviceIdInRange(
*parallel_desc, vm->machine_id_range(), [&](int64_t machine_id, int64_t device_id) {
FOR_RANGE(int, i, 0, args->lhs_object_id_size()) {
int64_t global_device_id =
vm->vm_resource_desc().GetGlobalDeviceId(machine_id, device_tag, device_id);
vm->vm_resource_desc().GetGlobalDeviceId(machine_id, device_id);
int64_t lhs_object_id = GetLogicalObjectId(args->lhs_object_id(i));
auto* lhs = vm->MutMirroredObject(lhs_object_id, global_device_id);
if (lhs != nullptr) { DoEachRhsObject(lhs, global_device_id); }
Expand Down
14 changes: 7 additions & 7 deletions oneflow/core/vm/stream.cpp
Expand Up @@ -20,19 +20,19 @@ limitations under the License.
namespace oneflow {
namespace vm {

void Stream::__Init__(ThreadCtx* thread_ctx, const StreamId& stream_id) {
void Stream::__Init__(ThreadCtx* thread_ctx, const StreamId& stream_id,
const int64_t max_device_num_per_machine) {
set_thread_ctx(thread_ctx);
mut_stream_id()->CopyFrom(stream_id);
// InitDeviceCtx may use max_device_num_per_machine,
// so max_device_num_per_machine must be set before InitDeviceCtx
set_max_device_num_per_machine(max_device_num_per_machine);
stream_type().InitDeviceCtx(mut_device_ctx(), this);
}

int64_t Stream::machine_id() const {
return global_device_id() / thread_ctx().stream_rt_desc().stream_desc().num_streams_per_machine();
}
int64_t Stream::machine_id() const { return global_device_id() / max_device_num_per_machine(); }

int64_t Stream::device_id() const {
return global_device_id() % thread_ctx().stream_rt_desc().stream_desc().num_streams_per_machine();
}
int64_t Stream::device_id() const { return global_device_id() % max_device_num_per_machine(); }

const StreamType& Stream::stream_type() const {
return thread_ctx().stream_rt_desc().stream_type();
Expand Down
3 changes: 2 additions & 1 deletion oneflow/core/vm/stream.msg.h
Expand Up @@ -28,7 +28,7 @@ class ThreadCtx;
// clang-format off
OBJECT_MSG_BEGIN(Stream);
// methods
OF_PUBLIC void __Init__(ThreadCtx* thread_ctx, const StreamId& stream_id);
OF_PUBLIC void __Init__(ThreadCtx* thread_ctx, const StreamId& stream_id, const int64_t max_device_num_per_machine);
OF_PUBLIC ObjectMsgPtr<Instruction> NewInstruction(InstructionMsg* instr_msg, const std::shared_ptr<ParallelDesc>& parallel_desc);
OF_PUBLIC void DeleteInstruction(ObjectMsgPtr<Instruction>&&);
OF_PUBLIC int64_t global_device_id() const { return stream_id().global_device_id(); }
Expand All @@ -40,6 +40,7 @@ OBJECT_MSG_BEGIN(Stream);
// fields
OBJECT_MSG_DEFINE_PTR(ThreadCtx, thread_ctx);
OBJECT_MSG_DEFINE_STRUCT(std::unique_ptr<DeviceCtx>, device_ctx);
OBJECT_MSG_DEFINE_OPTIONAL(int64_t, max_device_num_per_machine);

// links
OBJECT_MSG_DEFINE_LIST_LINK(active_stream_link);
Expand Down
1 change: 0 additions & 1 deletion oneflow/core/vm/stream_desc.msg.h
Expand Up @@ -69,7 +69,6 @@ OBJECT_MSG_BEGIN(StreamDesc);
OBJECT_MSG_DEFINE_OPTIONAL(int32_t, num_machines);
OBJECT_MSG_DEFINE_OPTIONAL(int32_t, num_streams_per_machine);
OBJECT_MSG_DEFINE_OPTIONAL(int32_t, num_streams_per_thread);
OBJECT_MSG_DEFINE_OPTIONAL(int32_t, start_global_device_id);

// links
OBJECT_MSG_DEFINE_SKIPLIST_KEY(7, StreamTypeId, stream_type_id);
Expand Down
1 change: 0 additions & 1 deletion oneflow/core/vm/transport_stream_type.cpp
Expand Up @@ -63,7 +63,6 @@ ObjectMsgPtr<StreamDesc> TransportStreamType::MakeTransportStreamDesc(
ret->set_num_streams_per_machine(device_num);
// TODO(lixinqi): refactor to a num_threads_per_machine field
ret->set_num_streams_per_thread(1);
ret->set_start_global_device_id(this_machine_id * device_num);
return ret;
}

Expand Down
5 changes: 3 additions & 2 deletions oneflow/core/vm/virtual_machine.cpp
Expand Up @@ -377,9 +377,10 @@ void VirtualMachine::__Init__(const VmDesc& vm_desc, ObjectMsgAllocator* allocat
for (int j = bs.At(i).begin(); j < bs.At(i).end(); ++j, ++rel_global_device_id) {
StreamId stream_id;
stream_id.__Init__(stream_desc->stream_type_id(),
stream_desc->start_global_device_id() + rel_global_device_id);
this_start_global_device_id() + rel_global_device_id);
auto stream =
ObjectMsgPtr<Stream>::NewFrom(mut_allocator(), thread_ctx.Mutable(), stream_id);
ObjectMsgPtr<Stream>::NewFrom(mut_allocator(), thread_ctx.Mutable(), stream_id,
vm_resource_desc().max_device_num_per_machine());
CHECK(stream_rt_desc->mut_stream_id2stream()->Insert(stream.Mutable()).second);
thread_ctx->mut_stream_list()->PushBack(stream.Mutable());
}
Expand Down
3 changes: 3 additions & 0 deletions oneflow/core/vm/virtual_machine.msg.h
Expand Up @@ -50,6 +50,9 @@ OBJECT_MSG_BEGIN(VirtualMachine);
int64_t global_device_id);

OF_PUBLIC int64_t this_machine_id() const;
OF_PUBLIC int64_t this_start_global_device_id() const {
return this_machine_id() * vm_resource_desc().max_device_num_per_machine();
}

// fields
OBJECT_MSG_DEFINE_OPTIONAL(VmResourceDesc, vm_resource_desc);
Expand Down
10 changes: 6 additions & 4 deletions oneflow/core/vm/vm_resource_desc.cpp
Expand Up @@ -29,16 +29,18 @@ void VmResourceDesc::__Init__(int64_t machine_num,
const DeviceTag2DeviceNum& device_tag2device_num) {
set_machine_num(machine_num);
*mutable_device_tag2device_num() = device_tag2device_num;
set_max_device_num_per_machine(0);
for (const auto& pair : device_tag2device_num) {
if (max_device_num_per_machine() < pair.second) { set_max_device_num_per_machine(pair.second); }
}
}

void VmResourceDesc::CopyFrom(const VmResourceDesc& vm_resource_desc) {
__Init__(vm_resource_desc.machine_num(), vm_resource_desc.device_tag2device_num());
}

int64_t VmResourceDesc::GetGlobalDeviceId(int64_t machine_id, const std::string& device_tag,
int64_t device_id) const {
int64_t device_num = device_tag2device_num().at(device_tag);
return machine_id * device_num + device_id;
int64_t VmResourceDesc::GetGlobalDeviceId(int64_t machine_id, int64_t device_id) const {
return machine_id * max_device_num_per_machine() + device_id;
}

void VmResourceDesc::GenerateParallelConf(const char* device_tag, ParallelConf* parallel_conf) {
Expand Down
3 changes: 2 additions & 1 deletion oneflow/core/vm/vm_resource_desc.msg.h
Expand Up @@ -35,11 +35,12 @@ OBJECT_MSG_BEGIN(VmResourceDesc);
OF_PUBLIC void __Init__(
int64_t machine_num, const DeviceTag2DeviceNum& device_tag2device_num);
OF_PUBLIC void CopyFrom(const VmResourceDesc& vm_resource_desc);
OF_PUBLIC int64_t GetGlobalDeviceId(int64_t machine_id, const std::string& device_tag, int64_t device_id) const;
OF_PUBLIC int64_t GetGlobalDeviceId(int64_t machine_id, int64_t device_id) const;
OF_PUBLIC void GenerateParallelConf(const char* device_tag, ParallelConf* parallel_conf);

// fields
OBJECT_MSG_DEFINE_OPTIONAL(int64_t, machine_num);
OBJECT_MSG_DEFINE_OPTIONAL(int64_t, max_device_num_per_machine);
OBJECT_MSG_DEFINE_STRUCT(DeviceTag2DeviceNum, device_tag2device_num);
OBJECT_MSG_END(VmResourceDesc);
// clang-format on
Expand Down

0 comments on commit 112b786

Please sign in to comment.