Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add communication dtypes for all-gathers and reduce scatters in depth tensor parallelism #64

Draft
wants to merge 3 commits into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
13 changes: 11 additions & 2 deletions axonn/intra_layer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,19 @@ def gather(
OVERLAP_REDUCE_SCATTER = False
OVERLAP_ALL_REDUCE = False
ALL_GATHER_ITERATOR = None
ALL_GATHER_DTYPE = torch.float32
REDUCE_SCATTER_DTYPE = torch.bfloat16
handles = []
pending_grad_accumulations = []
weights_cache = {}

def set_all_gather_dtype(dtype):
global ALL_GATHER_DTYPE
ALL_GATHER_DTYPE = dtype

def set_reduce_scatter_dtype(dtype):
global REDUCE_SCATTER_DTYPE
REDUCE_SCATTER_DTYPE = dtype

def register_handle(handle):
# ToDo: This might be unnecesary since
Expand Down Expand Up @@ -99,10 +108,10 @@ def trigger_async_all_gathers(model):
assert weight.ndim == 1
output_shape = weight.shape[0] * world_size
all_gathered_weight = torch.empty(
output_shape, dtype=weight.dtype, device=weight.device
output_shape, dtype=ALL_GATHER_DTYPE, device=weight.device
)
handle = dist.all_gather_into_tensor(
all_gathered_weight, weight, group=process_group, async_op=True
all_gathered_weight, weight.to(ALL_GATHER_DTYPE), group=process_group, async_op=True
)
weights_cache[weight] = [all_gathered_weight, handle]
yield
Expand Down
16 changes: 10 additions & 6 deletions axonn/intra_layer/communication.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,13 @@ def _gather(input_, dim, process_group=None, cache=False):
input_ = input_.contiguous()
# Size and dimension.
rank = dist.get_rank(process_group)

from axonn.intra_layer import ALL_GATHER_DTYPE

tensor_list = [
torch.empty_like(input_) for _ in range(dist.get_world_size(process_group))
torch.empty_like(input_, dtype=ALL_GATHER_DTYPE) for _ in range(dist.get_world_size(process_group))
]
tensor_list[rank] = input_
dist.all_gather(tensor_list, input_, group=process_group)
dist.all_gather(tensor_list, input_.to(ALL_GATHER_DTYPE), group=process_group)

# Note: torch.cat already creates a contiguous tensor.
output = torch.cat(tensor_list, dim=dim).contiguous()
Expand All @@ -70,17 +71,20 @@ def _reduce_scatter(input_, dim, process_group=None, overlap_comm=False):
assert input_.shape[dim] % total_chunks == 0
tensor_shape = list(input_.shape)
tensor_shape[dim] //= total_chunks

from axonn.intra_layer import REDUCE_SCATTER_DTYPE

output = torch.empty(
tensor_shape, dtype=input_.dtype, device=torch.cuda.current_device()
tensor_shape, dtype=REDUCE_SCATTER_DTYPE, device=torch.cuda.current_device()
)

if hasattr(torch.distributed, "reduce_scatter_tensor"):
handle = torch.distributed.reduce_scatter_tensor(
output, input_, group=process_group, async_op=overlap_comm
output, input_.to(REDUCE_SCATTER_DTYPE), group=process_group, async_op=overlap_comm
)
else:
handle = torch.distributed._reduce_scatter_base(
output, input_, group=process_group, async_op=overlap_comm
output, input_.to(REDUCE_SCATTER_DTYPE), group=process_group, async_op=overlap_comm
)

if overlap_comm:
Expand Down
1 change: 1 addition & 0 deletions axonn/intra_layer/fully_connected.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def forward(
ctx.backward_comm_async = backward_comm_async
if not forward_comm_async:
output = input_.matmul(weight.t())

dist.all_reduce(output, group=forward_all_reduce_group, async_op=False)
else:
assert input_.shape[0] % 2 == 0
Expand Down