add dtypes for reduce scatters and all gathers
siddharth9820 committed Feb 27, 2024
1 parent da66c0a commit d5a8ec5
Showing 2 changed files with 19 additions and 6 deletions.
axonn/intra_layer/__init__.py (10 changes: 9 additions & 1 deletion)
@@ -43,11 +43,19 @@ def gather(
 OVERLAP_REDUCE_SCATTER = False
 OVERLAP_ALL_REDUCE = False
 ALL_GATHER_ITERATOR = None
-ALL_GATHER_DTYPE = torch.bfloat16
+ALL_GATHER_DTYPE = torch.float32
+REDUCE_SCATTER_DTYPE = torch.float32
 handles = []
 pending_grad_accumulations = []
 weights_cache = {}
 
+def set_all_gather_dtype(dtype):
+    global ALL_GATHER_DTYPE
+    ALL_GATHER_DTYPE = dtype
+
+def set_reduce_scatter_dtype(dtype):
+    global REDUCE_SCATTER_DTYPE
+    REDUCE_SCATTER_DTYPE = dtype
+
 
 def register_handle(handle):
     # ToDo: This might be unnecesary since
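For context, a minimal usage sketch of the two new setters (hypothetical caller code, not part of this commit). Both dtypes default to torch.float32, so communication behavior is unchanged unless a caller opts in; switching to bfloat16 halves the bytes moved by these collectives:

import torch
from axonn.intra_layer import set_all_gather_dtype, set_reduce_scatter_dtype

# Opt in to bfloat16 on the wire; the defaults keep both collectives in float32.
set_all_gather_dtype(torch.bfloat16)
set_reduce_scatter_dtype(torch.bfloat16)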
axonn/intra_layer/communication.py (15 changes: 10 additions & 5 deletions)
@@ -44,12 +44,14 @@ def _gather(input_, dim, process_group=None, cache=False):
     input_ = input_.contiguous()
     # Size and dimension.
     rank = dist.get_rank(process_group)
 
+    from axonn.intra_layer import ALL_GATHER_DTYPE
+
     tensor_list = [
-        torch.empty_like(input_) for _ in range(dist.get_world_size(process_group))
+        torch.empty_like(input_, dtype=ALL_GATHER_DTYPE) for _ in range(dist.get_world_size(process_group))
     ]
     tensor_list[rank] = input_
-    dist.all_gather(tensor_list, input_, group=process_group)
+    dist.all_gather(tensor_list, input_.to(ALL_GATHER_DTYPE), group=process_group)
 
     # Note: torch.cat already creates a contiguous tensor.
     output = torch.cat(tensor_list, dim=dim).contiguous()
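The cast-before-gather pattern above can be reproduced outside AxoNN with plain torch.distributed; a minimal standalone sketch, assuming an already initialized process group (gather_in_dtype is a hypothetical name):

import torch
import torch.distributed as dist

def gather_in_dtype(x, dim=0, comm_dtype=torch.bfloat16, group=None):
    # Allocate receive buffers in the (typically lower-precision) wire dtype.
    world_size = dist.get_world_size(group)
    buffers = [torch.empty_like(x, dtype=comm_dtype) for _ in range(world_size)]
    # Cast the local shard once and gather; as in the change above, the
    # concatenated output stays in comm_dtype.
    dist.all_gather(buffers, x.to(comm_dtype), group=group)
    return torch.cat(buffers, dim=dim)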
@@ -70,17 +72,20 @@ def _reduce_scatter(input_, dim, process_group=None, overlap_comm=False):
     assert input_.shape[dim] % total_chunks == 0
     tensor_shape = list(input_.shape)
     tensor_shape[dim] //= total_chunks
+
+    from axonn.intra_layer import REDUCE_SCATTER_DTYPE
+
     output = torch.empty(
-        tensor_shape, dtype=input_.dtype, device=torch.cuda.current_device()
+        tensor_shape, dtype=REDUCE_SCATTER_DTYPE, device=torch.cuda.current_device()
     )
 
     if hasattr(torch.distributed, "reduce_scatter_tensor"):
         handle = torch.distributed.reduce_scatter_tensor(
-            output, input_, group=process_group, async_op=overlap_comm
+            output, input_.to(REDUCE_SCATTER_DTYPE), group=process_group, async_op=overlap_comm
         )
     else:
         handle = torch.distributed._reduce_scatter_base(
-            output, input_, group=process_group, async_op=overlap_comm
+            output, input_.to(REDUCE_SCATTER_DTYPE), group=process_group, async_op=overlap_comm
         )
 
     if overlap_comm:
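The reduce-scatter counterpart, again as a standalone sketch (reduce_scatter_in_dtype is a hypothetical name; torch.distributed.reduce_scatter_tensor requires a reasonably recent PyTorch, hence the hasattr fallback above). The trade-off is that the sum is accumulated in the communication dtype, which is presumably why both defaults stay at float32:

import torch
import torch.distributed as dist

def reduce_scatter_in_dtype(x, comm_dtype=torch.bfloat16, group=None):
    # reduce_scatter_tensor splits its input along the first dimension.
    world_size = dist.get_world_size(group)
    assert x.shape[0] % world_size == 0
    out_shape = list(x.shape)
    out_shape[0] //= world_size
    # Output buffer in the wire dtype, matching the cast input below.
    output = torch.empty(out_shape, dtype=comm_dtype, device=x.device)
    # Reduction happens in comm_dtype: cheaper to move, less accurate to sum.
    dist.reduce_scatter_tensor(output, x.to(comm_dtype), group=group)
    return output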
