Support Mixtral 8*7B MOE #667

Open · wants to merge 6 commits into base: main
megatron/arguments.py (8 additions, 0 deletions)
@@ -661,6 +661,14 @@ def _add_network_size_args(parser):
'launch to improve the utilization and performance by '
'leveraging the Grouped GEMM feature introduced since '
'CUTLASS 2.8 (https://github.com/fanshiqing/grouped_gemm).')
    group.add_argument('--num-experts-per-tok', type=int, default=2,
                       help='Number of MLP experts each input token is routed to '
                            '(top-k routing).')
    group.add_argument('--moe-type', type=str, default=None,
                       help='Variant of the MoE block to use. The default None selects '
                            'the Switch Transformer MLP; optional: mixtral.')
    group.add_argument('--moe-load-balancing-mode', default=None,
                       help="Load balancing applied to the expert routing probabilities "
                            "in the Mixtral MoE block. Only sinkhorn and None are "
                            "supported for now.")
group.add_argument('--untie-embeddings-and-output-weights', action='store_true',
help='Untie embeddings and output weights.'),
return parser
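
A minimal sketch of how the new flags parse, for reference (standalone argparse rather than the real megatron.arguments module; the example values are illustrative, not taken from this PR):

import argparse

# Stand-in parser that mirrors only the three flags added above.
parser = argparse.ArgumentParser()
parser.add_argument('--num-experts-per-tok', type=int, default=2)
parser.add_argument('--moe-type', type=str, default=None)
parser.add_argument('--moe-load-balancing-mode', default=None)

# A Mixtral-8x7B-style configuration: Mixtral MoE block, top-2 routing,
# sinkhorn load balancing (--num-experts would still select the 8 experts).
args = parser.parse_args(['--moe-type', 'mixtral',
                          '--num-experts-per-tok', '2',
                          '--moe-load-balancing-mode', 'sinkhorn'])
print(args.moe_type, args.num_experts_per_tok, args.moe_load_balancing_mode)
# -> mixtral 2 sinkhorn
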
megatron/model/transformer.py (120 additions, 1 deletion)
@@ -294,6 +294,122 @@ def forward(self, hidden_states):
return output_total, output_bias_total


class MixtralParallelMLP(torch.nn.Module):
def __init__(self, config):
super().__init__()
self.ffn_dim = config.ffn_hidden_size
self.hidden_dim = config.hidden_size

self.w1 = tensor_parallel.ColumnParallelLinear(
config.hidden_size,
config.ffn_hidden_size,
config=config,
init_method=config.init_method,
bias=False,
gather_output=False,
skip_bias_add=True,
is_expert=False,
)

self.w3 = tensor_parallel.ColumnParallelLinear(
config.hidden_size,
config.ffn_hidden_size,
config=config,
init_method=config.init_method,
bias=False,
gather_output=False,
skip_bias_add=True,
is_expert=False,
)

self.w2 = tensor_parallel.RowParallelLinear(
config.ffn_hidden_size,
config.hidden_size,
config=config,
init_method=config.output_layer_init_method,
bias=False,
skip_bias_add=True,
input_is_parallel=True,
is_expert=False,
)

self.act_fn = F.silu

def forward(self, hidden_states):
selects, h = hidden_states.shape
hidden_states = hidden_states.view(selects, 1, h)
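        # SwiGLU, as in the Mixtral MLP: silu(w1(x)) * w3(x), then project back with w2.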
current_hidden_states = self.act_fn(self.w1(hidden_states)[0]) * self.w3(hidden_states)[0]
current_hidden_states = self.w2(current_hidden_states)[0].view(selects, h)
return current_hidden_states


class MixtralSparseMoeBlock(MegatronModule):
"""
This is a megatron implementation refer to HuggingFace Mixtral Model.
Which strictly equivalent to standard MoE with full capacity (no
dropped tokens). It's faster since it formulates MoE operations
in terms of block-sparse operations to accomodate imbalanced
assignments of tokens to experts, whereas standard MoE either
(1) drop tokens at the cost of reduced performance or (2) set
capacity factor to number of experts and thus waste computation
and memory on padding.
"""

def __init__(self, config):
super().__init__()
args = get_args()
self.hidden_dim = args.hidden_size
self.ffn_dim = args.ffn_hidden_size
self.num_experts = getattr(args, "num_experts", 8)
self.top_k = getattr(args, "num_experts_per_tok", 2)
self.moe_load_balancing_mode = getattr(args, "moe_load_balancing_mode", None)

# gating
self.gate = torch.nn.Linear(self.hidden_dim, self.num_experts, bias=False)

self.experts = torch.nn.ModuleList(
[MixtralParallelMLP(config) for _ in range(self.num_experts)]
)

def forward(self, hidden_states: torch.Tensor):
s, b, h = hidden_states.shape
hidden_states = hidden_states.view(-1, h)

        # route: (sequence_length * batch_size, n_experts)
router_logits = self.gate(hidden_states)
if self.moe_load_balancing_mode == "sinkhorn":
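            # Sinkhorn normalization of the router logits, intended to balance
            # the token-to-expert assignment before the softmax below.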
router_logits = sinkhorn(router_logits)

routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
routing_weights = routing_weights.to(hidden_states.dtype)
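        # routing_weights: (num_tokens, top_k), re-normalized so that each
        # token's selected expert weights sum to 1.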

output_total = torch.zeros_like(hidden_states)

        # One-hot encode the selected experts to create an expert mask;
        # this is used to index which tokens each expert will process.
expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
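        # expert_mask: (num_experts, top_k, num_tokens)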

# Loop over all available experts in the model and perform the computation on each expert
for expert_idx in range(self.num_experts):
expert_layer = self.experts[expert_idx]
row, column = torch.where(expert_mask[expert_idx])

if column.shape[0] == 0:
continue

# Index the correct hidden states and compute the expert hidden state for
# the current expert. We need to make sure to multiply the output hidden
# states by `routing_weights` on the corresponding tokens (top-1 and top-2)
current_hidden_states = expert_layer(hidden_states[column]) * routing_weights[column, row, None]
output_total[column] = output_total[column] + current_hidden_states.to(hidden_states.dtype)

output_total = output_total.view(s, b, h).contiguous()

return output_total, None


class CoreAttention(MegatronModule):

def __init__(self, layer_number, config,
@@ -899,7 +1015,10 @@ def __init__(self, config,

# MLP
if args.num_experts is not None:
self.mlp = SwitchMLP(config)
if args.moe_type == "mixtral":
self.mlp = MixtralSparseMoeBlock(config)
else:
self.mlp = SwitchMLP(config)
else:
self.mlp = ParallelMLP(config)
