Skip to content

Ensure Consistency among Data Accesses in the same Thread Block #551

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged 14 commits into the base branch on Jul 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 10 additions & 9 deletions python/mscclpp/language/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def get(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int):
tb_channel_ids = get_program().setup_channel(tb, self)

op = GetOperation(
src_buff=[RemoteChunk(tb_chunk_id, src_chunk.index, src_chunk.size)],
src_buff=[RemoteChunk(src_chunk.buffer, src_chunk.index, src_chunk.size, tb_chunk_id)],
dst_buff=[LocalChunk(dst_chunk.buffer, dst_chunk.index, dst_chunk.size)],
channel_ids=tb_channel_ids,
channel_type=self.channel_type,
Expand Down Expand Up @@ -82,7 +82,7 @@ def put(

op = PutOperation(
src_buff=[LocalChunk(src_chunk.buffer, src_chunk.index, src_chunk.size)],
dst_buff=[RemoteChunk(tb_chunk_id, dst_chunk.index, dst_chunk.size)],
dst_buff=[RemoteChunk(dst_chunk.buffer, dst_chunk.index, dst_chunk.size, tb_chunk_id)],
channel_ids=tb_channel_ids,
channel_type=self.channel_type,
)
Expand Down Expand Up @@ -113,7 +113,7 @@ def put_packet(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int, from_packet: b

op = PutOperation(
src_buff=[LocalChunk(src_chunk.buffer, src_chunk.index, src_chunk.size)],
dst_buff=[RemoteChunk(tb_chunk_id, dst_chunk.index, dst_chunk.size)],
dst_buff=[RemoteChunk(dst_chunk.buffer, dst_chunk.index, dst_chunk.size, tb_chunk_id)],
channel_ids=tb_channel_ids,
channel_type=self.channel_type,
from_packet=from_packet,
Expand Down Expand Up @@ -152,14 +152,15 @@ def reduce(

remote_chunks = [
RemoteChunk(
chunk.buffer,
chunk.index,
chunk.size,
get_program().setup_remote_chunk(
self.src_rank,
tb,
RemoteBuffer(local_src_chunk.rank, chunk.rank, chunk.buffer, self.channel_type),
self.channel_type,
),
chunk.index,
chunk.size,
)
for chunk in remote_src_chunks
]
Expand Down Expand Up @@ -235,7 +236,7 @@ def put(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int):

op = PutOperation(
src_buff=[LocalChunk(src_chunk.buffer, src_chunk.index, src_chunk.size)],
dst_buff=[RemoteChunk(tb_chunk_id, dst_chunk.index, dst_chunk.size)],
dst_buff=[RemoteChunk(dst_chunk.buffer, dst_chunk.index, dst_chunk.size, tb_chunk_id)],
channel_ids=tb_channel_ids,
channel_type=self.channel_type,
)
Expand All @@ -262,7 +263,7 @@ def put_with_signal(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int):

op = PutOperation(
src_buff=[LocalChunk(src_chunk.buffer, src_chunk.index, src_chunk.size)],
dst_buff=[RemoteChunk(tb_chunk_id, dst_chunk.index, dst_chunk.size)],
dst_buff=[RemoteChunk(dst_chunk.buffer, dst_chunk.index, dst_chunk.size, tb_chunk_id)],
channel_ids=tb_channel_ids,
channel_type=self.channel_type,
with_signal=True,
Expand Down Expand Up @@ -290,7 +291,7 @@ def put_with_signal_and_flush(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int)

op = PutOperation(
src_buff=[LocalChunk(src_chunk.buffer, src_chunk.index, src_chunk.size)],
dst_buff=[RemoteChunk(tb_chunk_id, dst_chunk.index, dst_chunk.size)],
dst_buff=[RemoteChunk(dst_chunk.buffer, dst_chunk.index, dst_chunk.size, tb_chunk_id)],
channel_ids=tb_channel_ids,
channel_type=self.channel_type,
with_signal_and_flush=True,
Expand Down Expand Up @@ -323,7 +324,7 @@ def put_packet(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int):

op = PutOperation(
src_buff=[LocalChunk(src_chunk.buffer, src_chunk.index, src_chunk.size)],
dst_buff=[RemoteChunk(tb_chunk_id, dst_chunk.index, dst_chunk.size)],
dst_buff=[RemoteChunk(dst_chunk.buffer, dst_chunk.index, dst_chunk.size, tb_chunk_id)],
channel_ids=tb_channel_ids,
channel_type=self.channel_type,
from_packet=True,
Expand Down
19 changes: 19 additions & 0 deletions python/mscclpp/language/collectives.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,3 +87,22 @@ def init_buffers(self):
}
rank_buffers.append(buffers)
return rank_buffers


class ReduceScatter(Collective):
    """Reduce-scatter collective: each rank contributes num_ranks * chunk_factor
    input chunks and ends up with chunk_factor reduced output chunks."""

    def __init__(self, num_ranks, chunk_factor, inplace):
        super().__init__(num_ranks, chunk_factor, inplace)
        self.name = "reducescatter"

    # Builds the per-rank input/output buffers for a reduce-scatter.
    def init_buffers(self):
        """Return, for every rank, a mapping from BufferType to its BaseBuffer.

        The input buffer holds one chunk_factor-sized slice per rank
        (num_ranks * chunk_factor chunks total); the output buffer holds the
        single reduced slice owned by this rank (chunk_factor chunks).
        """
        # Sizes are rank-independent, so compute them once.
        input_size = self.num_ranks * self.chunk_factor
        output_size = self.chunk_factor
        return [
            {
                BufferType.input: BaseBuffer(rank_id, BufferType.input, 0, input_size),
                BufferType.output: BaseBuffer(rank_id, BufferType.output, 0, output_size),
            }
            for rank_id in range(self.num_ranks)
        ]
88 changes: 88 additions & 0 deletions python/mscclpp/language/internal/buffer_access.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
from sortedcontainers import SortedDict
from typing import List
from mscclpp.language.internal.types import BufferType, DataAccessType
from mscclpp.language.internal.operations import *
from enum import Enum


class BuffersAccess:
    """Track buffer intervals touched by operations in a thread block and
    decide where a synchronization must be inserted to keep data accesses
    consistent (see PR #551: "Ensure Consistency among Data Accesses in the
    same Thread Block").

    One sorted interval map is kept per buffer type (input/output/scratch).
    Keys are DataAccess intervals; values are their DataAccessType. The map is
    maintained as a set of disjoint intervals.
    """

    def __init__(self):
        # Per-buffer-type sorted map: DataAccess interval -> DataAccessType.
        # SortedDict keeps keys ordered, so the first potentially overlapping
        # interval can be located with a binary search (see lower_bound).
        self.intervals = {
            BufferType.input: SortedDict(),
            BufferType.output: SortedDict(),
            BufferType.scratch: SortedDict(),
        }

    def process_operations(self, operations):
        """Return a new operation list with a SyncOperation inserted before any
        operation whose local data access conflicts with a previously recorded
        one.

        nop/barrier operations already act as synchronization points: they
        reset the recorded accesses and are NOT copied into the result list.
        """
        result_operations = []
        for operation in operations:
            if operation.name == Instruction.nop or operation.name == Instruction.barrier:
                self.clear_data_access()
            else:
                data_access = operation.local_data_access()
                sync_added = False
                for data_access_element in data_access:
                    # Insert at most one SyncOperation per operation, even if
                    # several of its accesses conflict with earlier ones.
                    if self.compute_data_access(data_access_element) and not sync_added:
                        result_operations.append(SyncOperation())
                        sync_added = True

                result_operations.append(operation)

        return result_operations

    def compute_data_access(self, data_access: DataAccess) -> bool:
        """Record `data_access` in the interval map and report whether it
        conflicts with an already-recorded overlapping access.

        Overlapping but non-conflicting intervals are removed and their
        non-overlapping left/right remainders re-inserted, keeping the map
        disjoint. On the first real conflict (DataAccess.check_conflict) all
        recorded accesses are dropped — the caller will emit a sync — and True
        is returned. The new access is recorded in either case.
        """
        keys = self.intervals[data_access.buffer_type].keys()
        # Index of the first interval that may overlap: the first key whose
        # `end` >= data_access.start.
        idx = self.lower_bound(0, len(keys) - 1, keys, data_access)
        conflict = False

        # `len(keys) > 0` guards the empty map, where lower_bound returns 0.
        while len(keys) > 0 and data_access.overlaps(keys[idx]):
            conflict_data_access = keys[idx]
            conflict_operation_type = self.intervals[data_access.buffer_type][conflict_data_access]
            if data_access.check_conflict(conflict_data_access):
                # Real conflict: a sync will separate the two accesses, so all
                # previously recorded accesses become irrelevant.
                self.clear_data_access()
                conflict = True
                break

            # Same-kind overlap: remove the old interval and keep only the
            # parts that the new access does not cover.
            self.intervals[data_access.buffer_type].pop(conflict_data_access)
            if conflict_data_access.end > data_access.end:
                # Right-hand remainder of the old interval.
                self.intervals[data_access.buffer_type][
                    DataAccess(
                        conflict_data_access.operation_id,
                        data_access.end + 1,
                        conflict_data_access.end,
                        conflict_data_access.buffer_type,
                        conflict_operation_type,
                    )
                ] = conflict_operation_type
            if conflict_data_access.start < data_access.start:
                # Left-hand remainder of the old interval.
                self.intervals[data_access.buffer_type][
                    DataAccess(
                        conflict_data_access.operation_id,
                        conflict_data_access.start,
                        data_access.start - 1,
                        conflict_data_access.buffer_type,
                        conflict_operation_type,
                    )
                ] = conflict_operation_type

            # The map was mutated: refresh the key view and re-run the search.
            keys = self.intervals[data_access.buffer_type].keys()
            idx = self.lower_bound(0, len(keys) - 1, keys, data_access)

        # Record the new access regardless of whether a conflict occurred.
        self.intervals[data_access.buffer_type][data_access] = data_access.data_access_type
        return conflict

    def clear_data_access(self):
        """Forget all recorded accesses (used at explicit sync points)."""
        self.intervals[BufferType.input].clear()
        self.intervals[BufferType.output].clear()
        self.intervals[BufferType.scratch].clear()

    def lower_bound(self, init_pos, final_pos, data_access_list, data_access):
        """Recursive binary search over the sorted key view: return the index
        of the first interval whose `end` >= data_access.start, i.e. the first
        candidate that could overlap the new access. For an empty/degenerate
        range this returns init_pos; callers must bounds-check before indexing.
        """
        if init_pos >= final_pos:
            return init_pos

        mid_pos = (init_pos + final_pos) // 2
        if data_access.start <= data_access_list[mid_pos].end:
            final_pos = mid_pos
        else:
            init_pos = mid_pos + 1
        return self.lower_bound(init_pos, final_pos, data_access_list, data_access)
24 changes: 15 additions & 9 deletions python/mscclpp/language/internal/gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from mscclpp.language.internal.threadblock import ThreadBlock
from mscclpp.language.internal.operations import BaseOperation
from dataclasses import dataclass, field
from collections import *
from typing import List


Expand All @@ -12,16 +13,16 @@ class Gpu:
output_chunks: int = 0
scratch_chunks: int = 0
threadblocks: list = field(default_factory=list)
remote_buffers: dict = field(default_factory=dict)
remote_buffers: OrderedDict = field(default_factory=OrderedDict)

__channels: dict = field(default_factory=dict, init=False)
__nvls_channels: dict = field(default_factory=dict, init=False)
__nvls_channels: list = field(default_factory=list, init=False)

def add_channel(self, channel):
if channel.channel_type == ChannelType.switch:
if channel.buffer_type not in self.__nvls_channels:
self.__nvls_channels[channel.buffer_type] = Gpu.NVLSChannel(buffer_type=channel.buffer_type)
self.__nvls_channels[channel.buffer_type].rank_groups.append(channel.rank_group)
self.__nvls_channels.append(
Gpu.NVLSChannel(buffer_type=channel.buffer_type, rank_groups=[channel.rank_group])
)
else:
if channel.channel_type not in self.__channels:
self.__channels[channel.channel_type] = Gpu.Channel(channel_type=channel.channel_type)
Expand All @@ -37,8 +38,9 @@ def add_remote_buffer(self, tb: int, remote_buffer: RemoteBuffer, channel_access
if remote_buffer not in self.remote_buffers:
remote_buffer_id = len(self.remote_buffers)
else:
remote_buffer_id = self.remote_buffers.pop(remote_buffer)
self.remote_buffers[remote_buffer] = remote_buffer_id
remote_buffer_id, existing_remote_buffer = self.remote_buffers[remote_buffer]
remote_buffer.channel_access |= existing_remote_buffer.channel_access
self.remote_buffers[remote_buffer] = (remote_buffer_id, remote_buffer)

for i in range(len(self.threadblocks), tb + 1):
self.threadblocks.append(ThreadBlock(self.id, i))
Expand All @@ -59,6 +61,10 @@ def adding_data_sync(self):
for tb in self.threadblocks:
tb.adding_data_sync()

def resolve_data_dependency(self):
for tb in self.threadblocks:
tb.resolve_data_dependency()

def to_json(self) -> dict:
return {
"id": self.id,
Expand All @@ -67,8 +73,8 @@ def to_json(self) -> dict:
"scratch_chunks": self.scratch_chunks,
"threadblocks": [tb.to_json() for tb in self.threadblocks],
"channels": [ch.to_json() for ch in self.__channels.values()]
+ [ch.to_json() for ch in self.__nvls_channels.values()],
"remote_buffers": [rb.to_json() for rb in self.remote_buffers.keys()],
+ [ch.to_json() for ch in self.__nvls_channels],
"remote_buffers": [rb[1].to_json() for rb in self.remote_buffers.values()],
}

@dataclass
Expand Down
Loading