Adding AxoNN's 3D tensor parallelism [WIP] #1086

Draft · wants to merge 6 commits into base: main
24 changes: 18 additions & 6 deletions configs/125M.yml
@@ -2,8 +2,16 @@
{
# parallelism settings ( you will want to change these based on your cluster setup, ideally scheduling pipeline stages
# across the node boundaries )
"pipe_parallel_size": 1,
"model_parallel_size": 1,
"pipe_parallel_size": 0,
"model_parallel_size": 2,

## axonn's arguments
"use_axonn_model_parallelism": true,
## these are the 3 dimensions of AxoNN's TP
"depth_model_parallel_size": 2,
"row_model_parallel_size": 1,
"column_model_parallel_size": 1,
"optimize_axonn_communication": true,

# model settings
"num_layers": 12,
@@ -39,7 +47,7 @@

# for all zero_optimization options, see https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training
"zero_optimization": {
"stage": 1,
"stage": 0,
"allgather_partitions": True,
"allgather_bucket_size": 500000000,
"overlap_comm": True,
@@ -84,11 +92,15 @@
"eval_iters": 10,

# logging
"log_interval": 100,
"log_interval": 1,
"steps_per_print": 10,
"keep_last_n_checkpoints": 4,
"wall_clock_breakdown": true,

# networking
"hostfile": "/mock_path"

"data-path": "./data/enwik8/enwik8_text_document",
"vocab-file": "./data/gpt2-vocab.json",
"merge-file": "./data/gpt2-merges.txt"
# networking
#"hostfile": "/mock_path"
}
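
The three AxoNN dimensions above have to multiply out to `model_parallel_size`, which the new code in `megatron/initialize.py` asserts. A minimal sketch of that arithmetic, with the GPU count of 4 chosen purely as an illustrative assumption (it is not part of this config):

```python
# Hypothetical sanity check for the 125M.yml parallelism settings above.
config = {
    "pipe_parallel_size": 0,
    "model_parallel_size": 2,
    "row_model_parallel_size": 1,
    "column_model_parallel_size": 1,
    "depth_model_parallel_size": 2,
}
world_size = 4  # assumed number of GPUs, for illustration only

tp = (
    config["row_model_parallel_size"]
    * config["column_model_parallel_size"]
    * config["depth_model_parallel_size"]
)
# AxoNN's 3D tensor parallelism must multiply out to model_parallel_size ...
assert tp == config["model_parallel_size"]
# ... and the remaining GPUs become data-parallel replicas.
data_parallel_size = world_size // tp
print(f"tensor parallel = {tp}, data parallel = {data_parallel_size}")
```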
34 changes: 33 additions & 1 deletion configs/neox_arguments.md
@@ -111,7 +111,7 @@ Logging Arguments

- **git_hash**: str

Default = bb1b145
Default = 7438b33

current git hash of repository

@@ -858,6 +858,38 @@ Parallelism Arguments

Default = 1





- **use_axonn_model_parallelism**: bool

    Default = False

    Use AxoNN's 3D tensor parallelism instead of Megatron-style tensor parallelism.



- **row_model_parallel_size**: int

    Default = 1

    Size of the row dimension of AxoNN's 3D tensor parallelism.



- **column_model_parallel_size**: int

    Default = 1

    Size of the column dimension of AxoNN's 3D tensor parallelism.



- **depth_model_parallel_size**: int

    Default = 1

    Size of the depth dimension of AxoNN's 3D tensor parallelism.


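The diff does not include the argument declarations themselves. A hedged sketch of how they would plausibly be added to the `NeoXArgsParallelism` dataclass in `megatron/neox_arguments/neox_args.py` (the file and class names are assumptions based on gpt-neox's usual pattern; names and defaults are taken from the docs above):

```python
from dataclasses import dataclass


# Sketch only: gpt-neox generates neox_arguments.md from dataclass fields and
# their docstrings, so the new options are presumably declared roughly like this.
@dataclass
class NeoXArgsParallelism:
    use_axonn_model_parallelism: bool = False
    """Use AxoNN's 3D tensor parallelism instead of Megatron-style tensor parallelism."""

    row_model_parallel_size: int = 1
    """Size of the row dimension of AxoNN's 3D tensor parallelism."""

    column_model_parallel_size: int = 1
    """Size of the column dimension of AxoNN's 3D tensor parallelism."""

    depth_model_parallel_size: int = 1
    """Size of the depth dimension of AxoNN's 3D tensor parallelism."""
```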
23 changes: 22 additions & 1 deletion megatron/initialize.py
@@ -29,7 +29,7 @@

import deepspeed
import inspect

from axonn import axonn as ax

def initialize_megatron(neox_args, allow_no_cuda=False):
"""Set initialize distributed and set autoresume and random seeds.
@@ -188,6 +188,27 @@ def _initialize_distributed(neox_args):
fp32_allreduce=neox_args.fp32_allreduce,
)



    if neox_args.use_axonn_model_parallelism:
        row_mp = neox_args.row_model_parallel_size
        column_mp = neox_args.column_model_parallel_size
        depth_mp = neox_args.depth_model_parallel_size
        assert row_mp * column_mp * depth_mp == neox_args.model_parallel_size, (
            "product of row-model-parallel-size, column-model-parallel-size, "
            "and depth-model-parallel-size should equal model-parallel-size"
        )
        assert (
            neox_args.pipe_parallel_size == 0
        ), "AxoNN's tensor parallelism has not been tested with pipeline parallelism"
        ax.init(
            G_inter=pp,
            G_data=dp,
            G_intra_r=row_mp,
            G_intra_c=column_mp,
            G_intra_d=depth_mp,
        )
        print(
            f"> initialized AxoNN with G_intra_r={row_mp}, "
            f"G_intra_c={column_mp}, "
            f"G_intra_d={depth_mp}"
        )

# Init DeepSpeed Activation Checkpointing Features
setup_deepspeed_random_and_activation_checkpointing(neox_args=neox_args)
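
For context, `ax.init` splits the job into an inter-layer (pipeline) group, a data-parallel group, and the three intra-layer tensor-parallel groups. A small worked example of that decomposition, assuming a hypothetical 8-GPU run with the dimensions from `configs/125M.yml`:

```python
# Hypothetical 8-GPU example of how AxoNN's process-group sizes multiply out.
world_size = 8                               # assumed GPU count
G_inter = 1                                  # pipeline stages (pipe_parallel_size=0 -> no pipelining)
G_intra_r, G_intra_c, G_intra_d = 1, 1, 2    # row / column / depth tensor parallelism
G_data = world_size // (G_inter * G_intra_r * G_intra_c * G_intra_d)

# Every GPU gets exactly one (inter, data, row, column, depth) coordinate,
# so the product of all group sizes must reproduce the world size.
assert G_inter * G_data * G_intra_r * G_intra_c * G_intra_d == world_size
print(f"G_data={G_data}")                    # -> 4 data-parallel replicas
```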

165 changes: 123 additions & 42 deletions megatron/model/transformer.py
@@ -40,6 +40,8 @@
)
from megatron.model.utils import configure_sparse_attention

from axonn.intra_layer import Linear, drop, gather

# flags required to enable jit fusion kernels
torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False)
@@ -93,30 +95,57 @@ def __init__(
if self.activation_type == "geglu"
else ff_mult * neox_args.hidden_size
)
self.dense_h_to_4h = mpu.ColumnParallelLinear(
neox_args=neox_args,
input_size=neox_args.hidden_size,
output_size=ff_dim,
gather_output=False,
init_method=init_method,
skip_bias_add=True,
)
if neox_args.use_axonn_model_parallelism:
            self.dense_h_to_4h = Linear(
                in_features=neox_args.hidden_size,
                out_features=ff_dim,
                init_method=init_method,
                skip_bias_add=True,
            )
else:
self.dense_h_to_4h = mpu.ColumnParallelLinear(
neox_args=neox_args,
input_size=neox_args.hidden_size,
output_size=ff_dim,
gather_output=False,
init_method=init_method,
skip_bias_add=True,
)
ff_dim_in = ff_dim // 2 if self.activation_type == "geglu" else ff_dim
# Project back to h.
self.dense_4h_to_h = mpu.RowParallelLinear(
neox_args=neox_args,
input_size=ff_dim_in,
output_size=neox_args.hidden_size,
input_is_parallel=True,
init_method=output_layer_init_method,
skip_bias_add=True,
parallel_output=parallel_output,
)

if neox_args.use_axonn_model_parallelism:
            self.dense_4h_to_h = Linear(
                in_features=ff_dim_in,
                out_features=neox_args.hidden_size,
                init_method=output_layer_init_method,
                skip_bias_add=True,
                transpose=True,
            )
            assert not parallel_output, "TODO: implement AxoNN support for parallel_output=True (GPT-J residual)"

else:
self.dense_4h_to_h = mpu.RowParallelLinear(
neox_args=neox_args,
input_size=ff_dim_in,
output_size=neox_args.hidden_size,
input_is_parallel=True,
init_method=output_layer_init_method,
skip_bias_add=True,
parallel_output=parallel_output,
)


self.use_axonn_model_parallelism = neox_args.use_axonn_model_parallelism

def forward(self, hidden_states):

# [s, b, 4hp]
intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states)
if self.use_axonn_model_parallelism:
intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states,
scatter_input=False, gather_output=False)
else:
intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states)

if (
self.activation_type == "gelu" and self.bias_gelu_fusion
@@ -130,7 +159,11 @@ def forward(self, hidden_states):
)

# [s, b, h]
output, output_bias = self.dense_4h_to_h(intermediate_parallel)
if self.use_axonn_model_parallelism:
output, output_bias = self.dense_4h_to_h(intermediate_parallel,
scatter_input=False, gather_output=False)
else:
output, output_bias = self.dense_4h_to_h(intermediate_parallel)
return output, output_bias
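
Since both projections are called with `scatter_input=False` and `gather_output=False`, the activations stay sharded between them and the bias is deferred so it can be fused with the gelu/dropout paths. A single-process sketch of the same dataflow, with plain `torch.nn.Linear` standing in for `axonn.intra_layer.Linear` (shapes follow the 125M config; no AxoNN calls are made here):

```python
import torch
import torch.nn.functional as F

hidden_size, ff_dim = 768, 4 * 768           # 125M.yml-style sizes (assumed example)
seq, batch = 16, 2

# Stand-ins for dense_h_to_4h / dense_4h_to_h; with skip_bias_add=True the
# real layers return (matmul_output, bias) so the bias can be fused later.
h_to_4h = torch.nn.Linear(hidden_size, ff_dim, bias=True)
fourh_to_h = torch.nn.Linear(ff_dim, hidden_size, bias=True)

x = torch.randn(seq, batch, hidden_size)              # [s, b, h]

intermediate = x @ h_to_4h.weight.t()                 # bias deferred, [s, b, 4h]
intermediate = F.gelu(intermediate + h_to_4h.bias)    # the "bias_gelu" fusion point

output = intermediate @ fourh_to_h.weight.t()         # [s, b, h]
output_bias = fourh_to_h.bias                         # added later, fused with dropout
print(output.shape, output_bias.shape)
```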


@@ -162,6 +195,9 @@ def __init__(

ff_dim = int(2 * neox_args.hidden_size * 4 / 3)
ff_dim = self.multiple_of * ((ff_dim + multiple_of - 1) // multiple_of)

        assert not neox_args.use_axonn_model_parallelism, "TODO: implement AxoNN TP for LLaMAParallelMLP"

self.w1 = mpu.ColumnParallelLinear(
neox_args=neox_args,
input_size=neox_args.hidden_size,
@@ -275,7 +311,10 @@ def __init__(
self.attention_softmax_in_fp32 = True
self.layer_number = layer_number
# Per attention head and per partition values.
world_size = mpu.get_model_parallel_world_size()
if neox_args.use_axonn_model_parallelism:
world_size = neox_args.row_model_parallel_size
else:
world_size = mpu.get_model_parallel_world_size()
self.hidden_size_per_partition = mpu.divide(neox_args.hidden_size, world_size)
self.hidden_size_per_attention_head = mpu.divide(
neox_args.hidden_size, neox_args.num_attention_heads
@@ -286,14 +325,24 @@
self.pos_emb = neox_args.pos_emb

# Strided linear layer.
self.query_key_value = mpu.ColumnParallelLinear(
neox_args=neox_args,
input_size=neox_args.hidden_size,
output_size=3 * neox_args.hidden_size,
gather_output=False,
init_method=init_method,
bias=neox_args.use_bias_in_attn_linear,
)
self.use_axonn_model_parallelism = neox_args.use_axonn_model_parallelism
if neox_args.use_axonn_model_parallelism:
self.query_key_value = Linear(
in_features=neox_args.hidden_size,
out_features=3 * neox_args.hidden_size,
init_method=init_method,
bias=neox_args.use_bias_in_attn_linear,
skip_bias_add=True
)
else:
self.query_key_value = mpu.ColumnParallelLinear(
neox_args=neox_args,
input_size=neox_args.hidden_size,
output_size=3 * neox_args.hidden_size,
gather_output=False,
init_method=init_method,
bias=neox_args.use_bias_in_attn_linear,
)

coeff = None
self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
@@ -377,16 +426,27 @@ def __init__(
self.attention_dropout = nn.Dropout(self.dropout_p)

# Output.
self.dense = mpu.RowParallelLinear(
neox_args=neox_args,
input_size=neox_args.hidden_size,
output_size=neox_args.hidden_size,
input_is_parallel=True,
init_method=output_layer_init_method,
skip_bias_add=True,
parallel_output=parallel_output,
bias=neox_args.use_bias_in_attn_linear,
)
if neox_args.use_axonn_model_parallelism:
self.dense = Linear(
in_features=neox_args.hidden_size,
out_features=neox_args.hidden_size,
init_method=output_layer_init_method,
skip_bias_add=True,
bias=neox_args.use_bias_in_attn_linear,
transpose=True
)
            assert not parallel_output, "TODO: implement AxoNN support for parallel_output=True (GPT-J residual)"
else:
self.dense = mpu.RowParallelLinear(
neox_args=neox_args,
input_size=neox_args.hidden_size,
output_size=neox_args.hidden_size,
input_is_parallel=True,
init_method=output_layer_init_method,
skip_bias_add=True,
parallel_output=parallel_output,
bias=neox_args.use_bias_in_attn_linear,
)

def attention(
self, query_layer, key_layer, value_layer, layer_past, attention_mask
@@ -625,7 +685,10 @@ def forward(self, hidden_states, attention_mask, layer_past=None):
# =====================

# Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
mixed_x_layer, _ = self.query_key_value(hidden_states)
if self.use_axonn_model_parallelism:
mixed_x_layer, _ = self.query_key_value(hidden_states, scatter_input=False, gather_output=False)
else:
mixed_x_layer, _ = self.query_key_value(hidden_states)

# [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
new_tensor_shape = mixed_x_layer.size()[:-1] + (
@@ -710,7 +773,10 @@ def forward(self, hidden_states, attention_mask, layer_past=None):
# Output. [sq, b, h]
# =================

output, bias = self.dense(context_layer)
if self.use_axonn_model_parallelism:
output, bias = self.dense(context_layer, scatter_input=False, gather_output=False)
else:
output, bias = self.dense(context_layer)

if self.use_cache:
output = [output, present]
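
With AxoNN, the per-partition attention sizes come from `row_model_parallel_size` rather than the full tensor-parallel world size. A worked example of that bookkeeping, using assumed 125M-style dimensions:

```python
# Assumed example dimensions (125M-style model, row_model_parallel_size=2).
hidden_size = 768
num_attention_heads = 12
row_model_parallel_size = 2   # plays the role of world_size in the division above

hidden_size_per_partition = hidden_size // row_model_parallel_size              # 384
hidden_size_per_attention_head = hidden_size // num_attention_heads             # 64
heads_per_partition = num_attention_heads // row_model_parallel_size            # 6

# query_key_value projects h -> 3h globally; each row-parallel rank then holds
# 3 * hidden_size_per_partition output features of that projection.
qkv_output_features_per_rank = 3 * hidden_size_per_partition                    # 1152
print(hidden_size_per_partition, hidden_size_per_attention_head,
      heads_per_partition, qkv_output_features_per_rank)
```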
@@ -739,11 +805,17 @@ def __init__(

super().__init__()
self.layer_number = layer_number
        self.is_first_layer = layer_number == 0
        self.is_last_layer = layer_number == neox_args.num_layers - 1

norm, eps = get_norm(neox_args)

# Layernorm on the input data.
self.input_layernorm = norm(neox_args.hidden_size, eps=eps)
if neox_args.use_axonn_model_parallelism:
self.input_layernorm = norm(mpu.divide(neox_args.hidden_size,
neox_args.column_model_parallel_size), eps=eps)
else:
self.input_layernorm = norm(neox_args.hidden_size, eps=eps)
self.use_cache = use_cache

self.hidden_dropout = neox_args.hidden_dropout
@@ -771,7 +843,11 @@ def __init__(
# Layernorm on the output of the attention layer.
        # If GPT-J residuals are used, this is superfluous, but leaving it in
# leads to cleaner code
self.post_attention_layernorm = norm(neox_args.hidden_size, eps=eps)
if neox_args.use_axonn_model_parallelism:
self.post_attention_layernorm = norm(mpu.divide(neox_args.hidden_size,
neox_args.column_model_parallel_size), eps=eps)
else:
self.post_attention_layernorm = norm(neox_args.hidden_size, eps=eps)

# MLP
if neox_args.mlp_type == "regular":
@@ -807,6 +883,9 @@ def _get_bias_dropout(self):
def forward(self, x, attention_mask, layer_past=None):
layer_past = layer_past if layer_past is not None else self.layer_past
bias_dropout_fn = self._get_bias_dropout()

if self.is_first_layer:
x = drop(x, batch_dim=1)
# x: [b, s, h]
if self.gpt_j_residual:
# pseudocode:
@@ -904,6 +983,8 @@ def forward(self, x, attention_mask, layer_past=None):
prob=self.hidden_dropout,
)

if self.is_last_layer:
output = gather(output, batch_dim=1)
return output
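
`drop` and `gather` shard and reassemble activations along the batch dimension at the first and last transformer layers. A single-process sketch of the intended round trip, with `torch.chunk`/`torch.cat` standing in for `axonn.intra_layer.drop`/`gather` (the real functions communicate over AxoNN's intra-layer process groups; the group size of 2 and the exact communication semantics are assumptions here):

```python
import torch

# Single-process stand-in for the drop(...)/gather(...) round trip.
group_size, rank = 2, 0
x = torch.randn(32, 8, 768)                           # [s, b, h]

# drop(x, batch_dim=1): each rank keeps only its slice of the batch.
x_local = torch.chunk(x, group_size, dim=1)[rank]     # [s, b/2, h]

# ... the transformer layers then run on the local batch shard ...
output_local = x_local

# gather(output, batch_dim=1): all-gather the shards back into a full batch.
# On one process we reuse the other rank's slice to show the shape bookkeeping.
other = torch.chunk(x, group_size, dim=1)[1]
output = torch.cat([output_local, other], dim=1)      # [s, b, h]
print(x_local.shape, output.shape)
```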

