Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add intermediate_size to GPT-NeoX models #1212

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion megatron/model/mamba/mamba.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,10 @@ def __init__(
self.d_state = 16 # state dimensions per channel
self.d_conv = 4 # convolution width
self.expand = 2 # linear projection expansion factors
self.d_inner = int(self.expand * self.d_model)
if neox_args.intermediate_size == None:
neox_args.d_inner = self.expand * self.d_model
else:
neox_args.d_inner = neox_args.intermediate_size
self.dt_rank = math.ceil(self.d_model / 16) # rank of dt / Delta parameter
self.dt_scale = 1.0

Expand Down
12 changes: 7 additions & 5 deletions megatron/model/rwkv/v6/rwkv.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,11 +247,11 @@ def __init__(self, neox_args, layer_number):
self.time_maa_k = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
self.time_maa_r = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))

self.key = nn.Linear(neox_args.hidden_size, neox_args.dim_ffn, bias=False)
self.key = nn.Linear(neox_args.hidden_size, neox_args.ff_dim, bias=False)
self.receptance = nn.Linear(
neox_args.hidden_size, neox_args.hidden_size, bias=False
)
self.value = nn.Linear(neox_args.dim_ffn, neox_args.hidden_size, bias=False)
self.value = nn.Linear(neox_args.ff_dim, neox_args.hidden_size, bias=False)

def forward(self, x):
xx = self.time_shift(x) - x
Expand All @@ -277,12 +277,14 @@ def __init__(self, neox_args, layer_number):
self.bf16 = neox_args.precision == "bfloat16"
if not hasattr(neox_args, "dim_att"):
neox_args.dim_att = neox_args.hidden_size
if not hasattr(neox_args, "dim_ffn"):
if neox_args.intermediate_size == None:
# Make hidden size 3.5x. Round to nearest multiple of 32 until we add hdim rounding logic
neox_args.dim_ffn = int((neox_args.hidden_size * 3.5) // 32 * 32)
neox_args.ff_dim = int((neox_args.hidden_size * 3.5) // 32 * 32)
else:
neox_args.ff_dim = neox_args.intermediate_size
assert neox_args.hidden_size % 32 == 0
assert neox_args.dim_att % 32 == 0
assert neox_args.dim_ffn % 32 == 0
assert neox_args.ff_dim % 32 == 0
self.neox_args.head_size = neox_args.dim_att // neox_args.num_attention_heads
self.head_size = self.neox_args.head_size
self.num_attention_heads = neox_args.num_attention_heads
Expand Down
20 changes: 13 additions & 7 deletions megatron/model/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,13 +102,19 @@ def __init__(
self.activation_type = neox_args.activation
self.bias_gelu_fusion = neox_args.bias_gelu_fusion

# auto scale so geglu has equal parameters
ff_mult = int(4 * 2 / 3) if self.activation_type == "geglu" else 4
ff_dim = (
int(ff_mult * neox_args.hidden_size) * 2
if self.activation_type == "geglu"
else ff_mult * neox_args.hidden_size
)
#TODO: revisit this when we add swiglu to make sure ff_dim is initialized correctly.
if neox_args.intermediate_size:
ff_dim = neox_args.intermediate_size

else:
# auto scale so geglu has equal parameters
ff_mult = int(4 * 2 / 3) if self.activation_type == "geglu" else 4
ff_dim = (
int(ff_mult * neox_args.hidden_size) * 2
if self.activation_type == "geglu"
else int(ff_mult * neox_args.hidden_size)
)

self.dense_h_to_4h = mpu.ColumnParallelLinear(
neox_args=neox_args,
input_size=neox_args.hidden_size,
Expand Down
Loading