Add intermediate_size to GPT-NeoX models #1212

Merged · 11 commits · Sep 7, 2024
2 changes: 1 addition & 1 deletion configs/llama/13B.yml
@@ -22,5 +22,5 @@
"use_bias_in_norms": false,
"use_bias_in_attn_linear": false,
"mlp_type": "llama",
-"activation": "silu",
+"activation": "swiglu",
}
2 changes: 1 addition & 1 deletion configs/llama/30B.yml
@@ -22,5 +22,5 @@
"use_bias_in_norms": false,
"use_bias_in_attn_linear": false,
"mlp_type": "llama",
-"activation": "silu",
+"activation": "swiglu",
}
2 changes: 1 addition & 1 deletion configs/llama/65B.yml
@@ -22,5 +22,5 @@
"use_bias_in_norms": false,
"use_bias_in_attn_linear": false,
"mlp_type": "llama",
-"activation": "silu",
+"activation": "swiglu",
}
2 changes: 1 addition & 1 deletion configs/llama/7B.yml
@@ -22,5 +22,5 @@
"use_bias_in_norms": false,
"use_bias_in_attn_linear": false,
"mlp_type": "llama",
-"activation": "silu",
+"activation": "swiglu",
}
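Renaming the Llama configs from "silu" to "swiglu" matches the reworked `get_activation` further down in this diff: "swiglu" now selects the plain swish function together with `is_gated = True`, so the Llama MLP keeps its gated form. A rough sketch of the gated forward pass this corresponds to (the `w1`/`w2`/`w3` projection names and shapes are illustrative, not taken from this diff):

```python
import torch
import torch.nn.functional as F

def swiglu_mlp(x, w1, w3, w2):
    # "swiglu" = swish (SiLU) on the gate branch, multiplied into the linear
    # branch, then projected back down to the hidden size.
    gate = F.silu(x @ w1.T)       # gate branch; swish == silu for beta = 1
    value = x @ w3.T              # linear (un-gated) branch
    return (gate * value) @ w2.T  # down-projection

# Toy shapes: hidden_size = 8, intermediate_size = 16 (illustrative only).
x = torch.randn(2, 8)
w1, w3 = torch.randn(16, 8), torch.randn(16, 8)
w2 = torch.randn(8, 16)
print(swiglu_mlp(x, w1, w3, w2).shape)  # torch.Size([2, 8])
```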
12 changes: 6 additions & 6 deletions megatron/data/helpers.cpp
@@ -428,9 +428,9 @@ py::array build_mapping_impl(const py::array_t<int64_t>& docs_,
}

} // for (auto sent_index=sent_index_first; ...
-} // if (num_remain_sent > 1) {
-} // for (int doc=0; doc < num_docs; ++doc) {
-} // for (int epoch=0; epoch < num_epochs; ++epoch) {
+} // if (num_remain_sent > 1) {
+} // for (int doc=0; doc < num_docs; ++doc) {
+} // for (int epoch=0; epoch < num_epochs; ++epoch) {

if (!second) {
if (verbose) {
@@ -660,9 +660,9 @@ py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_,
num_sent = 0;
}
} // for (auto sent_index=sent_index_first; ...
-} // if (num_remain_sent > 1) {
-} // for (int doc=0; doc < num_docs; ++doc) {
-} // for (int epoch=0; epoch < num_epochs; ++epoch) {
+} // if (num_remain_sent > 1) {
+} // for (int doc=0; doc < num_docs; ++doc) {
+} // for (int epoch=0; epoch < num_epochs; ++epoch) {

if (!second) {
if (verbose) {
38 changes: 17 additions & 21 deletions megatron/model/activations.py
@@ -25,9 +25,23 @@


def get_activation(neox_args):
-"""retrieves the activation function specified in neox_args"""
+"""retrieves the activation function specified in neox_args and whether or not the activation is gated"""
+is_gated = False
if neox_args.activation == "geglu":
-activation_func = GEGLU(neox_args=neox_args)
+is_gated = True
+activation_func = F.gelu
+elif neox_args.activation == "reglu":
+is_gated = True
+activation_func = F.relu
+elif neox_args.activation == "bilinear":
+is_gated = True
+activation_func = lambda x: x
+elif neox_args.activation == "swiglu":
+is_gated = True
+activation_func = swish
+elif neox_args.activation == "glu":
+is_gated = True
+activation_func = F.sigmoid
elif neox_args.activation == "gelu":
if neox_args.onnx_safe and neox_args.bias_gelu_fusion:
raise ValueError("onnx_safe + bias_gelu_fusion not compatible")
@@ -49,7 +63,7 @@ def get_activation(neox_args):
activation_func = F.silu
else:
raise ValueError(f"Activation function {neox_args.activation} not recognized")
-return activation_func
+return activation_func, is_gated


###### BIAS GELU FUSION/ NO AUTOGRAD ################
@@ -119,21 +133,3 @@ def swish(x, beta: float = 1.0):
@torch.jit.script
def mish(x):
return x * torch.tanh(F.softplus(x))


-class GEGLU(torch.nn.Module):
-def __init__(self, neox_args):
-super(GEGLU, self).__init__()
-if neox_args.onnx_safe:
-self.activation_func = erf_gelu
-else:
-self.activation_func = F.gelu
-
-def forward(self, x, bias=None):
-x, gate = x.chunk(2, dim=-1)
-if bias is not None:
-bias_1, bias_2 = bias.chunk(2, dim=-1)
-x = x + bias_1
-gate = gate + bias_2
-intermediate_parallel = self.activation_func(gate)
-return intermediate_parallel * x
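With the GEGLU module removed, gating is no longer baked into the activation object: `get_activation` now returns a plain element-wise function plus an `is_gated` flag, and the caller splits the up-projected tensor and applies the gate itself, just as the deleted `GEGLU.forward` did. A minimal sketch of that calling pattern (the function below is illustrative, not the MLP code from this PR):

```python
import torch
import torch.nn.functional as F

def apply_activation(hidden, activation_func, is_gated):
    # For gated activations the up-projection is assumed to produce 2 * ffn_dim
    # features, split into a value half and a gate half (same scheme as the
    # removed GEGLU.forward: activation(gate) * x).
    if is_gated:
        x, gate = hidden.chunk(2, dim=-1)
        return activation_func(gate) * x
    return activation_func(hidden)

# "swiglu" case: activation_func is swish (silu for beta = 1), is_gated is True.
h = torch.randn(4, 2 * 64)
print(apply_activation(h, F.silu, is_gated=True).shape)  # torch.Size([4, 64])
```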
2 changes: 1 addition & 1 deletion megatron/model/gmlp.py
@@ -112,7 +112,7 @@ def __init__(
init_method=init_method,
skip_bias_add=True,
)
-self.activation_func = get_activation(neox_args)
+self.activation_func, _ = get_activation(neox_args)
ff_dim_parallel = mpu.divide(ff_dim, mpu.get_model_parallel_world_size())
if neox_args.attention_config[layer_number] == "amlp":
d_attn = neox_args.gmlp_attn_dim
9 changes: 7 additions & 2 deletions megatron/model/mamba/mamba.py
@@ -45,12 +45,17 @@ def __init__(
neox_args.mamba_use_bias_in_linears and neox_args.mamba_inner_func_fusion
), "Mamba fused inner fn and bias in x_proj not compatible!"

+assert neox_args.intermediate_size == None or neox_args.expansion_factor == None, "Must pass either the absolute intermediate size or the relative expansion factor for the mamba projections"

# set variables, mostly following mamba defaults
self.d_model = neox_args.hidden_size
self.d_state = 16 # state dimensions per channel
self.d_conv = 4 # convolution width
-self.expand = 2 # linear projection expansion factors
-self.d_inner = int(self.expand * self.d_model)
+if neox_args.intermediate_size:
+self.d_inner = neox_args.intermediate_size
+else:
+self.expand = neox_args.expansion_factor if neox_args.expansion_factor else 2
+self.d_inner = int(self.expand * self.d_model)
self.dt_rank = math.ceil(self.d_model / 16) # rank of dt / Delta parameter
self.dt_scale = 1.0

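The new sizing rule for the Mamba inner dimension, restated as a standalone function (an illustration of the logic above, not code from the PR; the numbers are made up):

```python
def mamba_d_inner(hidden_size, intermediate_size=None, expansion_factor=None):
    # An absolute intermediate_size takes precedence; otherwise the relative
    # expansion factor is used, defaulting to Mamba's usual 2x.
    assert intermediate_size is None or expansion_factor is None, (
        "Pass either the absolute intermediate size or the relative expansion factor"
    )
    if intermediate_size:
        return intermediate_size
    expand = expansion_factor if expansion_factor else 2
    return int(expand * hidden_size)

print(mamba_d_inner(768))                          # 1536 (default 2x expansion)
print(mamba_d_inner(768, intermediate_size=2048))  # 2048 (absolute size wins)
print(mamba_d_inner(768, expansion_factor=3))      # 2304
```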
17 changes: 11 additions & 6 deletions megatron/model/rwkv/v6/rwkv.py
@@ -247,11 +247,11 @@ def __init__(self, neox_args, layer_number):
self.time_maa_k = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
self.time_maa_r = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))

-self.key = nn.Linear(neox_args.hidden_size, neox_args.dim_ffn, bias=False)
+self.key = nn.Linear(neox_args.hidden_size, neox_args.ffn_dim, bias=False)
self.receptance = nn.Linear(
neox_args.hidden_size, neox_args.hidden_size, bias=False
)
-self.value = nn.Linear(neox_args.dim_ffn, neox_args.hidden_size, bias=False)
+self.value = nn.Linear(neox_args.ffn_dim, neox_args.hidden_size, bias=False)

def forward(self, x):
xx = self.time_shift(x) - x
@@ -275,14 +275,19 @@ def __init__(self, neox_args, layer_number):
self.layer_number = layer_number
self.fp16 = neox_args.precision == "fp16"
self.bf16 = neox_args.precision == "bfloat16"
+assert neox_args.intermediate_size == None or neox_args.expansion_factor == None, "Must pass either the absolute intermediate size or the relative expansion factor for the mamba projections"
if not hasattr(neox_args, "dim_att"):
neox_args.dim_att = neox_args.hidden_size
if not hasattr(neox_args, "dim_ffn"):
# Make hidden size 3.5x. Round to nearest multiple of 32 until we add hdim rounding logic
neox_args.dim_ffn = int((neox_args.hidden_size * 3.5) // 32 * 32)
if neox_args.intermediate_size:
neox_args.ffn_dim = neox_args.intermediate_size
else:
self.expand = neox_args.expansion_factor if neox_args.expansion_factor else 3.5
neox_args.ffn_dim = int(self.expand * neox_args.hidden_size)
# Make hidden size 3.5x by default. Round to nearest multiple of 32 until we add hdim rounding logic
neox_args.ffn_dim = int(neox_args.ffn_dim // 32 * 32)
assert neox_args.hidden_size % 32 == 0
assert neox_args.dim_att % 32 == 0
-assert neox_args.dim_ffn % 32 == 0
+assert neox_args.ffn_dim % 32 == 0
self.neox_args.head_size = neox_args.dim_att // neox_args.num_attention_heads
self.head_size = self.neox_args.head_size
self.num_attention_heads = neox_args.num_attention_heads
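The RWKV FFN width follows the same precedence, with the result still rounded down to a multiple of 32. A standalone sketch of the arithmetic (illustrative numbers, not code from the PR):

```python
def rwkv_ffn_dim(hidden_size, intermediate_size=None, expansion_factor=None):
    # Absolute intermediate_size wins; otherwise expand the hidden size by the
    # given factor (default 3.5x). Either way, round down to a multiple of 32.
    assert intermediate_size is None or expansion_factor is None
    if intermediate_size:
        ffn_dim = intermediate_size
    else:
        expand = expansion_factor if expansion_factor else 3.5
        ffn_dim = int(expand * hidden_size)
    return int(ffn_dim // 32 * 32)

print(rwkv_ffn_dim(1024))                          # 3584 (1024 * 3.5, already a multiple of 32)
print(rwkv_ffn_dim(1000))                          # 3488 (3500 rounded down)
print(rwkv_ffn_dim(1024, intermediate_size=2000))  # 1984 (2000 rounded down)
```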