
Change executor to load new execution plan #538

Merged (82 commits) on Jul 7, 2025
Changes from 32 commits
Commits (82)
6b3d728
init pipeline
Binyang2014 May 19, 2025
995f09b
WIP
Binyang2014 May 20, 2025
d5e13d6
compile pass
Binyang2014 May 20, 2025
b97f744
WIP
Binyang2014 May 20, 2025
32bfcb0
WIP
Binyang2014 May 21, 2025
0aa7777
add more ops
Binyang2014 May 21, 2025
d8ab0fd
WIP
Binyang2014 May 21, 2025
e3c12ea
WIP
Binyang2014 May 22, 2025
ccae90c
WIP
Binyang2014 May 22, 2025
9dab715
WIP
Binyang2014 May 22, 2025
32778b7
WIP
Binyang2014 May 22, 2025
2fec3e0
start hard part
Binyang2014 May 22, 2025
537b379
WIP
Binyang2014 May 23, 2025
033e940
remove chunkGroups
Binyang2014 May 23, 2025
47d63d4
WIP
Binyang2014 May 25, 2025
27350ad
WIP
Binyang2014 May 25, 2025
995fb5e
WIP
Binyang2014 May 27, 2025
5238989
WIP
Binyang2014 May 27, 2025
47cd0e8
WIP
Binyang2014 May 27, 2025
09e4723
WIP
Binyang2014 May 28, 2025
256c80b
revert
Binyang2014 May 28, 2025
b2be1bb
Merge branch 'feature/dsl' into binyli/refactor
Binyang2014 May 28, 2025
10d7285
WIP
Binyang2014 May 28, 2025
bb7b8cc
WIP
Binyang2014 May 28, 2025
419c102
format
Binyang2014 May 28, 2025
e6a8351
kernel update
Binyang2014 May 28, 2025
7c87feb
rename
Binyang2014 May 29, 2025
fd3a926
update to make json work
Binyang2014 May 29, 2025
abd8f86
WIP
Binyang2014 May 30, 2025
725a44f
WIP
Binyang2014 May 30, 2025
0ce3f14
Merge branch 'feature/dsl' into binyli/refactor
Binyang2014 May 30, 2025
c3fedd3
WIP
Binyang2014 May 30, 2025
f2654f6
WIP
Binyang2014 May 30, 2025
6830e15
work for nvls
Binyang2014 May 30, 2025
0264068
update with nvls json
Binyang2014 May 30, 2025
a593f1b
update pipeline
Binyang2014 Jun 1, 2025
57160b2
WIP
Binyang2014 Jun 1, 2025
7ed2c4a
WIP
Binyang2014 Jun 2, 2025
445a1fd
WIP
Binyang2014 Jun 3, 2025
5de882f
pipeline with WA
Binyang2014 Jun 3, 2025
c0eb84c
WIP
Binyang2014 Jun 3, 2025
46bb6e4
correctness issue fixed
Binyang2014 Jun 3, 2025
49f1271
WIP
Binyang2014 Jun 4, 2025
ed9720d
WIP
Binyang2014 Jun 5, 2025
48ea2a7
try to only load execution per rank
Binyang2014 Jun 5, 2025
382bfb6
WIP
Binyang2014 Jun 5, 2025
22bd20d
WIP
Binyang2014 Jun 6, 2025
c560548
WIP
Binyang2014 Jun 7, 2025
ce29ed6
WIP
Binyang2014 Jun 8, 2025
3cc56c3
clean up
Binyang2014 Jun 8, 2025
17b8f70
WIP
Binyang2014 Jun 9, 2025
f49dbc4
WIP
Binyang2014 Jun 9, 2025
98ae324
WIP
Binyang2014 Jun 9, 2025
55f2ecb
code clean
Binyang2014 Jun 9, 2025
8c75ef3
run with hang
Binyang2014 Jun 10, 2025
15f95b3
bug fix
Binyang2014 Jun 11, 2025
7af4a4f
minor fix
Binyang2014 Jun 11, 2025
6a79375
Merge branch 'feature/dsl' into binyli/refactor
Binyang2014 Jun 11, 2025
1f984fa
bug fix
Binyang2014 Jun 11, 2025
6c7184d
Merge branch 'feature/dsl' into binyli/refactor
Binyang2014 Jun 11, 2025
748676a
update
Binyang2014 Jun 11, 2025
bad29fb
add back more ops
Binyang2014 Jun 11, 2025
437dc2d
WIP
Binyang2014 Jun 12, 2025
4ff95ed
finish all ops and ready to debug
Binyang2014 Jun 12, 2025
4eb136b
Align with json plan
Binyang2014 Jun 13, 2025
4f95aed
enable double scratch buffer
Binyang2014 Jun 17, 2025
f771bfa
Merge branch 'feature/dsl' into binyli/refactor
Binyang2014 Jun 17, 2025
cc66300
remove some files
Binyang2014 Jun 17, 2025
3ea04ac
address comments
Binyang2014 Jun 18, 2025
00e78e9
fix barrier issue
Binyang2014 Jun 18, 2025
b3b9272
update for npkit
Binyang2014 Jun 18, 2025
4ac3d93
code clean
Binyang2014 Jun 18, 2025
5843b12
update
Binyang2014 Jun 19, 2025
e180fc0
update
Binyang2014 Jul 1, 2025
9a4fc74
Merge branch 'feature/dsl' into binyli/refactor
Binyang2014 Jul 1, 2025
d2a55b2
some bug fix
Binyang2014 Jul 1, 2025
8a07297
update
Binyang2014 Jul 2, 2025
391e29e
add reduce kernel
Binyang2014 Jul 2, 2025
0b5fe79
update structure
Binyang2014 Jul 2, 2025
014fe26
Merge branch 'feature/dsl' into binyli/refactor
Binyang2014 Jul 3, 2025
d7aa1f2
WIP
Binyang2014 Jul 3, 2025
9f85d70
fix
Binyang2014 Jul 3, 2025
10 changes: 10 additions & 0 deletions include/mscclpp/copy_device.hpp
@@ -96,6 +96,16 @@ MSCCLPP_DEVICE_INLINE void copy(void* dst, void* src, uint64_t bytes, uint32_t t
}
}

template <typename T>
MSCCLPP_DEVICE_INLINE void write(void* dst, uint64_t index, const T& v) {
*(reinterpret_cast<T*>(dst) + index) = v;
}

template <typename T>
MSCCLPP_DEVICE_INLINE T read(void* src, uint64_t index) {
return *(reinterpret_cast<T*>(src) + index);
}

/// Read data from the origin and write packets to the target buffer.
///
/// @param targetPtr The target buffer.
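The new `write<T>`/`read<T>` helpers wrap the repeated `reinterpret_cast` arithmetic into typed element accessors over untyped buffer pointers. Below is a minimal, self-contained sketch of the call pattern; the kernel, buffer names, and sizes are illustrative only, and the helper bodies are duplicated locally so the example builds with `nvcc` without the mscclpp headers.

```cpp
#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>

// Stand-ins with the same bodies as the new helpers in copy_device.hpp,
// so this sketch compiles without pulling in the mscclpp headers.
template <typename T>
__device__ __forceinline__ void write(void* dst, uint64_t index, const T& v) {
  *(reinterpret_cast<T*>(dst) + index) = v;
}

template <typename T>
__device__ __forceinline__ T read(void* src, uint64_t index) {
  return *(reinterpret_cast<T*>(src) + index);
}

// Element-wise accumulate of src into dst through the typed helpers,
// instead of repeating cast-and-offset arithmetic at every call site.
__global__ void accumulateKernel(void* dst, void* src, uint64_t numElems) {
  for (uint64_t i = blockIdx.x * blockDim.x + threadIdx.x; i < numElems;
       i += static_cast<uint64_t>(gridDim.x) * blockDim.x) {
    write<float>(dst, i, read<float>(dst, i) + read<float>(src, i));
  }
}

int main() {
  constexpr uint64_t n = 1 << 20;
  float *dst = nullptr, *src = nullptr;
  cudaMalloc(&dst, n * sizeof(float));
  cudaMalloc(&src, n * sizeof(float));
  cudaMemset(dst, 0, n * sizeof(float));
  cudaMemset(src, 0, n * sizeof(float));
  accumulateKernel<<<128, 256>>>(dst, src, n);
  cudaDeviceSynchronize();
  std::printf("done\n");
  cudaFree(dst);
  cudaFree(src);
  return 0;
}
```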
304 changes: 141 additions & 163 deletions src/executor/execution_plan.cc

Large diffs are not rendered by default.

136 changes: 65 additions & 71 deletions src/executor/executor.cc
@@ -113,11 +113,17 @@ struct ExecutionContext {
std::shared_ptr<ProxyService> proxyService;
std::unordered_map<int, std::shared_ptr<Connection>> connections;
std::vector<std::shared_ptr<NvlsConnection>> nvlsConnections;
std::unordered_map<std::pair<BufferType, int>, mscclpp::RegisteredMemory> registeredMemories;

// For registered memories, registeredMemoryAddresses is used for memoryChannel and registeredMemoryIds is used for
// proxy channel
std::vector<mscclpp::RegisteredMemory> registeredMemories;
std::vector<void*> registeredMemoryAddresses;
std::vector<mscclpp::MemoryId> registeredMemoryIds;

std::vector<std::shared_ptr<mscclpp::MemoryDevice2DeviceSemaphore>> memorySemaphores;
std::vector<mscclpp::SemaphoreId> proxySemaphores;
std::vector<mscclpp::MemoryChannel> memoryChannels;
std::vector<mscclpp::PortChannel> portChannels;
std::vector<mscclpp::BaseMemoryChannel> memoryChannels;
std::vector<mscclpp::BasePortChannel> portChannels;
std::vector<mscclpp::NvlsConnection::DeviceMulticastPointer> nvlsChannels;
std::unordered_map<DeviceExecutionPlanKey, std::vector<DeviceExecutionPlan>> deviceExecutionPlans;
std::unordered_map<DeviceExecutionPlanKey, std::shared_ptr<char>> deviceExecutionPlansBuffers;
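The `{BufferType, peer}` keyed map is replaced with flat, plan-indexed tables, as the new comment notes: raw device pointers feed memory channels, proxy-service memory ids feed port channels. A condensed sketch of that split, using simplified stand-in types rather than the real `RegisteredMemory`/`MemoryId`:

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

using MemoryId = uint32_t;  // stand-in for mscclpp::MemoryId

// Simplified view of the two plan-indexed tables ExecutionContext now carries
// instead of the {BufferType, peer} -> RegisteredMemory map.
struct RemoteBufferTables {
  std::vector<void*> addresses;  // dereferenced directly by memory channels
  std::vector<MemoryId> ids;     // handed to the proxy thread by port channels
};

int main() {
  RemoteBufferTables tables;
  int peerScratch = 0;                       // pretend this is a mapped peer buffer
  tables.addresses.push_back(&peerScratch);  // slot 0 for memory-channel access
  tables.ids.push_back(MemoryId{7});         // slot 0 for port-channel access

  // A memory-channel operation reads/writes through the raw pointer...
  void* dst = tables.addresses.at(0);
  // ...while a port-channel operation only names the buffer by id; the
  // host-side proxy resolves that id against its own registration table.
  MemoryId remoteId = tables.ids.at(0);

  std::printf("addr=%p id=%u\n", dst, remoteId);
  return 0;
}
```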
@@ -165,7 +171,7 @@ struct Executor::Impl {
}

plan.impl_->reset();
plan.impl_->loadExecutionPlan(inputMessageSize, outputMessageSize, constSrcOffset, constDstOffset);
plan.impl_->loadExecutionPlan(rank, inputMessageSize, outputMessageSize, constSrcOffset, constDstOffset);

ExecutionContext context;
size_t maxScratchBufferSize = plan.impl_->getMaxScratchBufferSize(rank);
@@ -177,8 +183,8 @@
context.proxyService = std::make_shared<ProxyService>();
context.nthreadsPerBlock = plan.impl_->getNThreadsPerBlock();
this->setupConnections(context, rank, plan, sendMemRange, recvMemRange);
this->setupChannels(context, rank, plan);
this->setupRegisteredMemories(context, sendbuff, recvbuff, sendMemRange, recvMemRange, rank, plan);
this->setupChannels(context, sendbuff, recvbuff, sendMemRange, recvMemRange, rank, plan);
this->setupNvlsChannels(context, sendbuff, recvbuff, sendMemRange, recvMemRange, rank, plan);
this->setupDeviceExecutionPlan(context, devicePlanKey, rank, plan);
context.deviceExecutionPlansBuffers[devicePlanKey] =
@@ -192,18 +198,16 @@
return context;
}

TransportFlags getTransportFlags(std::vector<ChannelInfo>& infos, int rank) {
TransportFlags getTransportFlags(const BufferInfo& info, int rank) {
TransportFlags flags;
for (ChannelInfo& info : infos) {
if (info.channelType == ChannelType::MEMORY) {
for (const ChannelType& type : info.accessChannelTypes) {
if (type == ChannelType::MEMORY) {
flags |= Transport::CudaIpc;
} else if (info.channelType == ChannelType::PORT) {
for (int peer : info.connectedPeers) {
if (!inSameNode(rank, peer, this->nranksPerNode)) {
flags |= IBs[rank % this->nranksPerNode];
} else
flags |= Transport::CudaIpc;
}
} else if (type == ChannelType::PORT) {
if (!inSameNode(rank, info.accessRank, this->nranksPerNode)) {
flags |= IBs[rank % this->nranksPerNode];
} else
flags |= Transport::CudaIpc;
}
}
return flags;
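Transport selection is now driven by one `BufferInfo` (the accessing rank plus its channel types) instead of a list of `ChannelInfo`s. The sketch below reproduces the same decision with mock flag values; `inSameNode` here assumes the usual contiguous ranks-per-node layout, and the real flag set comes from `mscclpp::TransportFlags`, not these bit values.

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

enum class ChannelType : uint8_t { NONE, MEMORY, PORT };

// Mock transport bitmask; the real code uses mscclpp::TransportFlags.
enum TransportBits : uint32_t { CudaIpc = 1u << 0, IB = 1u << 1 };

// Mirrors the fields of BufferInfo that the transport decision consumes.
struct BufferInfo {
  int accessRank;                               // rank that will access this buffer
  std::vector<ChannelType> accessChannelTypes;  // how it will be accessed
};

// Assumes ranks are laid out contiguously per node, i.e. node = rank / nranksPerNode.
bool inSameNode(int rank, int peer, int nranksPerNode) {
  return rank / nranksPerNode == peer / nranksPerNode;
}

uint32_t getTransportFlags(const BufferInfo& info, int rank, int nranksPerNode) {
  uint32_t flags = 0;
  for (ChannelType type : info.accessChannelTypes) {
    if (type == ChannelType::MEMORY) {
      flags |= CudaIpc;  // peer maps the buffer directly
    } else if (type == ChannelType::PORT) {
      if (!inSameNode(rank, info.accessRank, nranksPerNode)) {
        flags |= IB;       // cross-node access goes through InfiniBand
      } else {
        flags |= CudaIpc;  // same node: the proxy still moves data via CUDA IPC
      }
    }
  }
  return flags;
}

int main() {
  BufferInfo crossNode{/*accessRank=*/9, {ChannelType::PORT}};
  BufferInfo sameNode{/*accessRank=*/1, {ChannelType::MEMORY, ChannelType::PORT}};
  std::cout << getTransportFlags(crossNode, /*rank=*/0, /*nranksPerNode=*/8) << "\n";  // IB
  std::cout << getTransportFlags(sameNode, /*rank=*/0, /*nranksPerNode=*/8) << "\n";   // CudaIpc
  return 0;
}
```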
@@ -244,40 +248,42 @@ struct Executor::Impl {
throw Error("Invalid buffer type", ErrorCode::ExecutorError);
}
};
auto getConnectedPeers = [&](std::vector<ChannelInfo>& infos) {
std::set<int> peers;
for (ChannelInfo& info : infos) {
for (int peer : info.connectedPeers) {
peers.insert(peer);
}

// Add local src,dst and scratch to registeredMemoryIds
for (auto& bufferType : {BufferType::INPUT, BufferType::OUTPUT, BufferType::SCRATCH}) {
TransportFlags flags = Transport::CudaIpc;
#if defined(USE_IBVERBS)
flags |= IBs[rank % this->nranksPerNode];
#endif
RegisteredMemory localMemory;
auto bufferInfo = getBufferInfo(bufferType);
if (bufferInfo.second > 0) {
localMemory =
this->comm->registerMemory(getBufferInfo(bufferType).first, getBufferInfo(bufferType).second, flags);
}
return std::vector<int>(peers.begin(), peers.end());
};
context.proxyService->addMemory(localMemory);
}

std::vector<BufferType> bufferTypes = plan.impl_->getConnectedBufferTypes(rank);
for (BufferType bufferType : bufferTypes) {
std::vector<ChannelInfo> channelInfos = plan.impl_->getChannelInfosByDstRank(rank, bufferType);
TransportFlags transportFlags = getTransportFlags(channelInfos, rank);
for (const auto& bufferInfo : plan.impl_->getLocalBufferToSend(rank)) {
RegisteredMemory memory =
this->comm->registerMemory(getBufferInfo(bufferType).first, getBufferInfo(bufferType).second, transportFlags);
std::vector<int> connectedPeers = getConnectedPeers(channelInfos);
std::vector<std::shared_future<mscclpp::RegisteredMemory>> remoteRegMemoryFutures;
for (int peer : connectedPeers) {
comm->sendMemory(memory, peer, 0);
}
channelInfos = plan.impl_->getChannelInfos(rank, bufferType);
connectedPeers = getConnectedPeers(channelInfos);
for (int peer : connectedPeers) {
remoteRegMemoryFutures.push_back(comm->recvMemory(peer, 0));
}
for (size_t i = 0; i < remoteRegMemoryFutures.size(); i++) {
context.registeredMemories[{bufferType, connectedPeers[i]}] = std::move(remoteRegMemoryFutures[i].get());
this->comm->registerMemory(getBufferInfo(bufferInfo.bufferType).first,
getBufferInfo(bufferInfo.bufferType).second, getTransportFlags(bufferInfo, rank));
comm->sendMemory(memory, bufferInfo.accessRank, 0);
}
for (const auto& bufferInfo : plan.impl_->getRemoteBufferInfos(rank)) {
std::shared_future<RegisteredMemory> remoteRegMemoryFuture = comm->recvMemory(bufferInfo.rank, 0);
context.registeredMemories.emplace_back(std::move(remoteRegMemoryFuture.get()));
for (ChannelType chanType : bufferInfo.accessChannelTypes) {
if (chanType == ChannelType::MEMORY) {
context.registeredMemoryAddresses.push_back(context.registeredMemories.back().data());
} else if (chanType == ChannelType::PORT) {
context.registeredMemoryIds.push_back(context.proxyService->addMemory(context.registeredMemories.back()));
}
}
}
}
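The registration flow now has three steps: register the local input/output/scratch buffers and hand them to the proxy, send the buffers the plan exposes to their accessing ranks, then receive remote registrations and record either an address or a proxy memory id per access channel type. The single-process mock below sketches that bookkeeping; the `Communicator` send/receive is omitted, and the claim that the first three `addMemory` calls line up with the reserved `*_BUFFER_MEMORY_ID` constants is an inference from ordering, not something this hunk states.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

enum class ChannelType : uint8_t { NONE, MEMORY, PORT };
using MemoryId = uint32_t;

// Minimal stand-ins for the mscclpp pieces this flow touches.
struct RegisteredMemory {
  void* ptr;
  size_t size;
};

struct ProxyService {
  std::vector<RegisteredMemory> memories;
  MemoryId addMemory(const RegisteredMemory& m) {
    memories.push_back(m);
    return static_cast<MemoryId>(memories.size() - 1);
  }
};

// What getRemoteBufferInfos(rank) would describe for each peer buffer.
struct RemoteBufferInfo {
  int rank;
  std::vector<ChannelType> accessChannelTypes;
};

struct Context {
  ProxyService proxy;
  std::vector<RegisteredMemory> registeredMemories;
  std::vector<void*> registeredMemoryAddresses;  // memory-channel lookups
  std::vector<MemoryId> registeredMemoryIds;     // port-channel lookups
};

int main() {
  Context ctx;
  char input[16], output[16], scratch[16];

  // 1) Local input/output/scratch are added to the proxy first, so they take
  //    the lowest MemoryIds (0, 1, 2); this appears to be what the reserved
  //    INPUT/OUTPUT/SCRATCH_BUFFER_MEMORY_ID constants rely on (inferred).
  for (void* buf : {static_cast<void*>(input), static_cast<void*>(output), static_cast<void*>(scratch)}) {
    ctx.proxy.addMemory(RegisteredMemory{buf, sizeof(input)});
  }

  // 2) Buffers listed by getLocalBufferToSend(rank) would be registered with the
  //    transports from getTransportFlags and sent to bufferInfo.accessRank
  //    (comm->sendMemory in the real code; omitted in this single-process mock).

  // 3) Buffers listed by getRemoteBufferInfos(rank) are received and recorded
  //    once per access channel type: a raw address for MEMORY, a proxy id for PORT.
  std::vector<RemoteBufferInfo> remoteInfos = {{/*rank=*/1, {ChannelType::MEMORY, ChannelType::PORT}}};
  for (const auto& info : remoteInfos) {
    RegisteredMemory remote{scratch, sizeof(scratch)};  // would come from comm->recvMemory(info.rank, 0)
    ctx.registeredMemories.push_back(remote);
    for (ChannelType t : info.accessChannelTypes) {
      if (t == ChannelType::MEMORY) {
        ctx.registeredMemoryAddresses.push_back(ctx.registeredMemories.back().ptr);
      } else if (t == ChannelType::PORT) {
        ctx.registeredMemoryIds.push_back(ctx.proxy.addMemory(ctx.registeredMemories.back()));
      }
    }
  }
  return 0;
}
```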

void setupChannels(ExecutionContext& context, void* sendbuff, void* recvbuff, size_t sendBufferSize,
size_t recvBufferSize, int rank, const ExecutionPlan& plan) {
void setupChannels(ExecutionContext& context, int rank, const ExecutionPlan& plan) {
const auto channelTypes = {ChannelType::MEMORY, ChannelType::PORT};
std::vector<std::shared_ptr<MemoryDevice2DeviceSemaphore>> memorySemaphores;
std::vector<mscclpp::SemaphoreId> proxySemaphores;
@@ -300,44 +306,22 @@
// Current semaphore construction requires two-way communication, e.g., to construct a semaphore signaling from
// rank 0 to rank 1, both rank 0 and rank 1 need to send a message to each other. This PR fixes an executor bug
// that fails to conduct two-way communication for constructing such one-way semaphores, and instead hangs
// during the semaphore construction. In the future, we may need to change the implementation to construct
// semaphore via one-way communication.
// during the semaphore construction.
channelInfos = plan.impl_->getUnpairedChannelInfos(rank, nranks, channelType);
processChannelInfos(channelInfos);
}
context.memorySemaphores = std::move(memorySemaphores);
context.proxySemaphores = std::move(proxySemaphores);

auto getBufferSize = [&](BufferType type) {
switch (type) {
case BufferType::INPUT:
return sendBufferSize;
case BufferType::OUTPUT:
return recvBufferSize;
case BufferType::SCRATCH:
return context.scratchBufferSize;
default:
throw Error("Invalid buffer type", ErrorCode::ExecutorError);
}
};

for (ChannelType channelType : channelTypes) {
std::vector<ChannelInfo> channelInfos = plan.impl_->getChannelInfos(rank, channelType);
int index = 0;
for (ChannelInfo& info : channelInfos) {
void* src = getBuffer(info.srcBufferType, sendbuff, recvbuff, context.scratchBuffer.get());
size_t bufferSize = getBufferSize(info.srcBufferType);
TransportFlags transport = getTransportFlags(channelInfos, rank);
RegisteredMemory localMemory = this->comm->registerMemory(src, bufferSize, transport);
for (int peer : info.connectedPeers) {
for (size_t i = 0; i < info.connectedPeers.size(); i++) {
if (channelType == ChannelType::MEMORY) {
context.memoryChannels.emplace_back(context.memorySemaphores[index++],
context.registeredMemories[{info.dstBufferType, peer}], src, nullptr);
context.memoryChannels.emplace_back(context.memorySemaphores[index++]);
} else if (channelType == ChannelType::PORT) {
context.portChannels.emplace_back(context.proxyService->portChannel(
context.proxySemaphores[index++],
context.proxyService->addMemory(context.registeredMemories[{info.dstBufferType, peer}]),
context.proxyService->addMemory(localMemory)));
context.portChannels.emplace_back(context.proxyService->basePortChannel(context.proxySemaphores[index++]));
}
}
}
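Channels are now built from semaphores alone (`BaseMemoryChannel`, `basePortChannel`); the buffers an operation touches are supplied separately through the remote-buffer tables, so a channel no longer has to be duplicated per buffer it serves. A conceptual sketch of that decoupling with mock types (not the mscclpp classes):

```cpp
#include <cstdint>
#include <vector>

// Mock types; the real ones are mscclpp semaphores, channels, and registered memories.
struct Semaphore {
  int id;
};

// Before this change a channel carried its buffer binding, so a distinct
// channel (and semaphore slot) was needed per (peer, buffer) pair.
struct BundledMemoryChannel {
  Semaphore semaphore;
  void* remoteBuffer;
  void* localBuffer;
};

// After: the channel is only the synchronization endpoint...
struct BaseMemoryChannel {
  Semaphore semaphore;
};

// ...and each operation names the buffer it touches by index into the
// plan-built remote-buffer tables.
struct OperationRef {
  uint8_t channelIndex;       // which BaseMemoryChannel to signal/wait on
  uint8_t remoteBufferIndex;  // index into registeredMemoryAddresses / Ids
};

int main() {
  std::vector<BaseMemoryChannel> channels = {{Semaphore{0}}, {Semaphore{1}}};
  std::vector<void*> remoteBufferAddresses(4, nullptr);

  OperationRef op{/*channelIndex=*/1, /*remoteBufferIndex=*/2};
  BaseMemoryChannel& chan = channels[op.channelIndex];      // sync endpoint only
  void* dst = remoteBufferAddresses[op.remoteBufferIndex];  // data target, resolved late
  (void)chan;
  (void)dst;
  return 0;
}
```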
@@ -367,17 +351,27 @@
deviceExecutionPlan.nMemoryChannels = plan.impl_->threadblockMemoryChannelMap.at(rank).at(threadblock).size();
deviceExecutionPlan.nPortChannels = plan.impl_->threadblockPortChannelMap.at(rank).at(threadblock).size();
int chanIndex = 0;
for (const auto& [index, _] : plan.impl_->threadblockMemoryChannelMap.at(rank).at(threadblock)) {
for (const int index : plan.impl_->threadblockMemoryChannelMap.at(rank).at(threadblock)) {
deviceExecutionPlan.channels.memoryChannels[chanIndex++] = mscclpp::deviceHandle(context.memoryChannels[index]);
}
chanIndex = 0;
for (const auto& [index, _] : plan.impl_->threadblockPortChannelMap.at(rank).at(threadblock)) {
for (const int index : plan.impl_->threadblockPortChannelMap.at(rank).at(threadblock)) {
deviceExecutionPlan.channels.portChannels[chanIndex++] = mscclpp::deviceHandle(context.portChannels[index]);
}
chanIndex = 0;
for (const auto& [index, _] : plan.impl_->threadblockNvlsChannelMap.at(rank).at(threadblock)) {
for (const int index : plan.impl_->threadblockNvlsChannelMap.at(rank).at(threadblock)) {
deviceExecutionPlan.channels.nvlsChannels[chanIndex++] = mscclpp::deviceHandle(context.nvlsChannels[index]);
}
int memIndex = 0;
for (const int index : plan.impl_->threadblockMemoryChannelBufferMap.at(rank).at(threadblock)) {
deviceExecutionPlan.remoteBuffers.remoteBuffersViaMemoryChannel[memIndex++] =
context.registeredMemoryAddresses[index];
}
memIndex = 0;
for (const int index : plan.impl_->threadblockPortChannelBufferMap.at(rank).at(threadblock)) {
deviceExecutionPlan.remoteBuffers.remoteBuffersViaPortChannel[memIndex++] = context.registeredMemoryIds[index];
}

if (ops.size() > MAX_OPERATION) {
throw Error("Executor plan launching " + std::to_string(ops.size()) +
" operations, exceeding device execution plan support (" + std::to_string(MAX_OPERATION) + ")",
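Per threadblock, the plan's global channel and buffer indexes are flattened into the fixed-size arrays of `DeviceExecutionPlan`. The sketch below mirrors the two new remote-buffer fill loops with mock types; the `MAX_CHANNEL` bound check is added for illustration only, since this hunk only shows the `MAX_OPERATION` check.

```cpp
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

constexpr size_t MAX_CHANNEL = 16;  // mirrors the constant in execution_common.hpp
using MemoryId = uint32_t;

// Mock per-threadblock slice of the device plan; the real DeviceExecutionPlan
// also carries the channel handles and the operation array.
struct DevicePlanSlice {
  void* remoteBuffersViaMemoryChannel[MAX_CHANNEL];
  MemoryId remoteBuffersViaPortChannel[MAX_CHANNEL];
};

// Flatten one threadblock's plan-level buffer indexes into the fixed-size
// device arrays, mirroring the fill loops in setupDeviceExecutionPlan.
DevicePlanSlice flattenThreadblock(const std::vector<int>& memBufferIndexes,
                                   const std::vector<int>& portBufferIndexes,
                                   const std::vector<void*>& registeredMemoryAddresses,
                                   const std::vector<MemoryId>& registeredMemoryIds) {
  if (memBufferIndexes.size() > MAX_CHANNEL || portBufferIndexes.size() > MAX_CHANNEL) {
    throw std::runtime_error("threadblock references more remote buffers than MAX_CHANNEL");
  }
  DevicePlanSlice slice = {};
  size_t i = 0;
  for (int index : memBufferIndexes) {
    slice.remoteBuffersViaMemoryChannel[i++] = registeredMemoryAddresses.at(index);
  }
  i = 0;
  for (int index : portBufferIndexes) {
    slice.remoteBuffersViaPortChannel[i++] = registeredMemoryIds.at(index);
  }
  return slice;
}

int main() {
  std::vector<void*> addresses(3, nullptr);
  std::vector<MemoryId> ids = {10, 11, 12};
  DevicePlanSlice slice = flattenThreadblock({2, 0}, {1}, addresses, ids);
  (void)slice;
  return 0;
}
```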
62 changes: 40 additions & 22 deletions src/include/execution_common.hpp
@@ -13,6 +13,10 @@ namespace mscclpp {
constexpr int MAX_CHANNEL = 16;
constexpr int MAX_CHANNEL_PER_OPERATION = 8;
constexpr int MAX_OPERATION = 64;
constexpr int INPUT_BUFFER_MEMORY_ID = 0;
constexpr int OUTPUT_BUFFER_MEMORY_ID = 1;
constexpr int SCRATCH_BUFFER_MEMORY_ID = 2;
constexpr int MAX_RESERVED_MEMORY_IDS = 3;

enum class BufferType : uint8_t {
NONE,
@@ -51,47 +55,60 @@ enum class OperationType : uint8_t {
READ_REDUCE_COPY,
READ_REDUCE_COPY_SEND,
MULTI_LOAD_REDUCE_STORE,
RELAXED_SIGNAL,
RELAXED_WAIT,
PIPELINE,
};

struct Channels {
mscclpp::DeviceHandle<mscclpp::MemoryChannel> memoryChannels[MAX_CHANNEL];
mscclpp::DeviceHandle<mscclpp::PortChannel> portChannels[MAX_CHANNEL];
mscclpp::DeviceHandle<mscclpp::BaseMemoryChannel> memoryChannels[MAX_CHANNEL];
mscclpp::DeviceHandle<mscclpp::BasePortChannel> portChannels[MAX_CHANNEL];
mscclpp::DeviceHandle<mscclpp::NvlsConnection::DeviceMulticastPointer> nvlsChannels[MAX_CHANNEL];
};

struct RemoteBuffers {
void* remoteBuffersViaMemoryChannel[MAX_CHANNEL];
MemoryId remoteBuffersViaPortChannel[MAX_CHANNEL];
};

union BufferRef {
uint8_t id;
BufferType type;
};

struct Operation {
OperationType type;
ChannelType channelType;
BufferType srcBufferType;
BufferType dstBufferType;
uint8_t nInputs;
uint8_t nOutputs;
union {
// For ops which require reading from multiple remote sources
uint8_t inputChannelIndexes[MAX_CHANNEL_PER_OPERATION];
// For ops which require reading from multiple local sources
BufferType inputBufferType;
BufferRef inputBufferRefs[MAX_CHANNEL_PER_OPERATION];
uint8_t nvlsInputIndex;
};
union {
// For ops which require writing to multiple remote destinations
uint8_t outputChannelIndexes[MAX_CHANNEL_PER_OPERATION];
// For ops which require writing to multiple local destinations
BufferType outputBufferType;
BufferRef outputBufferRefs[MAX_CHANNEL_PER_OPERATION];
uint8_t nvlsOutputIndex;
};

union {
// For Barrier operation
struct {
uint32_t deviceSyncerIndex;
uint32_t nThreadBlocks;
};
struct {
uint8_t channelIndexes[MAX_CHANNEL_PER_OPERATION];
uint32_t inputOffsets[MAX_CHANNEL_PER_OPERATION];
uint32_t outputOffsets[MAX_CHANNEL_PER_OPERATION];
uint32_t srcOffset;
uint32_t dstOffset;
uint32_t size;
uint32_t inputBufferSizes[MAX_CHANNEL_PER_OPERATION];
uint32_t outputBufferSizes[MAX_CHANNEL_PER_OPERATION];

uint8_t nChannels;
uint8_t nInputs;
uint8_t nOutputs;
};
struct {
uint32_t unitSize;
uint32_t maxBufferSize;
uint8_t nIterations;
uint8_t nOperations;
};
struct {
uint32_t deviceSyncerIndex;
uint32_t nThreadBlocks;
};
};
};
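`BufferRef` overlays a raw `id` with a `BufferType`, and together with the reserved-id constants above it suggests one id space in which 0..2 select the local input/output/scratch buffers and larger ids index the remote-buffer tables. That reading is inferred, not stated in this hunk; the sketch below follows it as an assumption.

```cpp
#include <cstdint>
#include <vector>

constexpr int INPUT_BUFFER_MEMORY_ID = 0;
constexpr int OUTPUT_BUFFER_MEMORY_ID = 1;
constexpr int SCRATCH_BUFFER_MEMORY_ID = 2;
constexpr int MAX_RESERVED_MEMORY_IDS = 3;

enum class BufferType : uint8_t { NONE, INPUT, OUTPUT, SCRATCH };

union BufferRef {
  uint8_t id;
  BufferType type;
};

// Resolve a BufferRef to a raw pointer, assuming reserved ids select local
// buffers and any other id indexes the remote-buffer table (inferred usage).
void* resolve(BufferRef ref, void* input, void* output, void* scratch,
              const std::vector<void*>& remoteBuffersViaMemoryChannel) {
  switch (ref.id) {
    case INPUT_BUFFER_MEMORY_ID:   return input;
    case OUTPUT_BUFFER_MEMORY_ID:  return output;
    case SCRATCH_BUFFER_MEMORY_ID: return scratch;
    default:
      return remoteBuffersViaMemoryChannel[ref.id - MAX_RESERVED_MEMORY_IDS];
  }
}

int main() {
  int in = 0, out = 0, scratch = 0, remote = 0;
  std::vector<void*> remotes = {&remote};

  BufferRef local{};
  local.id = OUTPUT_BUFFER_MEMORY_ID;
  BufferRef far{};
  far.id = MAX_RESERVED_MEMORY_IDS;  // first remote slot

  void* a = resolve(local, &in, &out, &scratch, remotes);  // -> &out
  void* b = resolve(far, &in, &out, &scratch, remotes);    // -> &remote
  (void)a;
  (void)b;
  return 0;
}
```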
@@ -102,6 +119,7 @@ struct __attribute__((aligned(16))) DeviceExecutionPlan {
uint8_t nPortChannels; // 1 bytes
uint16_t nOperations; // 2 bytes
Channels channels; // 2304 bytes
RemoteBuffers remoteBuffers; // 192 bytes
Operation operations[MAX_OPERATION]; // 64 * 100 = 6400 bytes
};
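The 192-byte annotation on `remoteBuffers` checks out if pointers are 8 bytes and `MemoryId` is a 4-byte value (the latter is not shown in this diff): 16 pointers plus 16 ids is 128 + 64 bytes with no padding. A compile-time check of that arithmetic:

```cpp
#include <cstdint>

constexpr int MAX_CHANNEL = 16;  // from execution_common.hpp
using MemoryId = uint32_t;       // assumed 4-byte id; not defined in this hunk

struct RemoteBuffers {
  void* remoteBuffersViaMemoryChannel[MAX_CHANNEL];   // 16 * 8 = 128 bytes
  MemoryId remoteBuffersViaPortChannel[MAX_CHANNEL];  // 16 * 4 =  64 bytes
};

static_assert(sizeof(void*) == 8, "sketch assumes 64-bit pointers");
// 128 + 64 = 192 bytes, matching the annotation on the new member.
static_assert(sizeof(RemoteBuffers) == 192, "expected 128 + 64 bytes");

int main() { return 0; }
```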
