diff --git a/axonn/intra_layer/__init__.py b/axonn/intra_layer/__init__.py index 5d08cbc..3a4ea3a 100644 --- a/axonn/intra_layer/__init__.py +++ b/axonn/intra_layer/__init__.py @@ -43,10 +43,19 @@ def gather( OVERLAP_REDUCE_SCATTER = False OVERLAP_ALL_REDUCE = False ALL_GATHER_ITERATOR = None +ALL_GATHER_DTYPE = torch.float32 +REDUCE_SCATTER_DTYPE = torch.bfloat16 handles = [] pending_grad_accumulations = [] weights_cache = {} +def set_all_gather_dtype(dtype): + global ALL_GATHER_DTYPE + ALL_GATHER_DTYPE = dtype + +def set_reduce_scatter_dtype(dtype): + global REDUCE_SCATTER_DTYPE + REDUCE_SCATTER_DTYPE = dtype def register_handle(handle): # ToDo: This might be unnecesary since @@ -99,10 +108,10 @@ def trigger_async_all_gathers(model): assert weight.ndim == 1 output_shape = weight.shape[0] * world_size all_gathered_weight = torch.empty( - output_shape, dtype=weight.dtype, device=weight.device + output_shape, dtype=ALL_GATHER_DTYPE, device=weight.device ) handle = dist.all_gather_into_tensor( - all_gathered_weight, weight, group=process_group, async_op=True + all_gathered_weight, weight.to(ALL_GATHER_DTYPE), group=process_group, async_op=True ) weights_cache[weight] = [all_gathered_weight, handle] yield diff --git a/axonn/intra_layer/communication.py b/axonn/intra_layer/communication.py index a6c3265..c64fdda 100644 --- a/axonn/intra_layer/communication.py +++ b/axonn/intra_layer/communication.py @@ -44,12 +44,13 @@ def _gather(input_, dim, process_group=None, cache=False): input_ = input_.contiguous() # Size and dimension. rank = dist.get_rank(process_group) + + from axonn.intra_layer import ALL_GATHER_DTYPE tensor_list = [ - torch.empty_like(input_) for _ in range(dist.get_world_size(process_group)) + torch.empty_like(input_, dtype=ALL_GATHER_DTYPE) for _ in range(dist.get_world_size(process_group)) ] - tensor_list[rank] = input_ - dist.all_gather(tensor_list, input_, group=process_group) + dist.all_gather(tensor_list, input_.to(ALL_GATHER_DTYPE), group=process_group) # Note: torch.cat already creates a contiguous tensor. output = torch.cat(tensor_list, dim=dim).contiguous() @@ -70,17 +71,20 @@ def _reduce_scatter(input_, dim, process_group=None, overlap_comm=False): assert input_.shape[dim] % total_chunks == 0 tensor_shape = list(input_.shape) tensor_shape[dim] //= total_chunks + + from axonn.intra_layer import REDUCE_SCATTER_DTYPE + output = torch.empty( - tensor_shape, dtype=input_.dtype, device=torch.cuda.current_device() + tensor_shape, dtype=REDUCE_SCATTER_DTYPE, device=torch.cuda.current_device() ) if hasattr(torch.distributed, "reduce_scatter_tensor"): handle = torch.distributed.reduce_scatter_tensor( - output, input_, group=process_group, async_op=overlap_comm + output, input_.to(REDUCE_SCATTER_DTYPE), group=process_group, async_op=overlap_comm ) else: handle = torch.distributed._reduce_scatter_base( - output, input_, group=process_group, async_op=overlap_comm + output, input_.to(REDUCE_SCATTER_DTYPE), group=process_group, async_op=overlap_comm ) if overlap_comm: diff --git a/axonn/intra_layer/fully_connected.py b/axonn/intra_layer/fully_connected.py index f4561e4..0f56f9d 100644 --- a/axonn/intra_layer/fully_connected.py +++ b/axonn/intra_layer/fully_connected.py @@ -79,6 +79,7 @@ def forward( ctx.backward_comm_async = backward_comm_async if not forward_comm_async: output = input_.matmul(weight.t()) + dist.all_reduce(output, group=forward_all_reduce_group, async_op=False) else: assert input_.shape[0] % 2 == 0