diff --git a/src/axolotl/monkeypatch/moe/moe.py b/src/axolotl/monkeypatch/moe/moe.py
index f2a506883..7c651d0a2 100644
--- a/src/axolotl/monkeypatch/moe/moe.py
+++ b/src/axolotl/monkeypatch/moe/moe.py
@@ -21,14 +21,38 @@ def __init__(self, experts, gate, hidden_dim, ffn_dim, num_experts, top_k):
         )
 
     def _post_training(self, model, name):
-        # get original weights back: reverse the concat + stack in the fused experts
+        # Get original weights back: reverse the concat + stack in the fused experts
         w1s, w3s = torch.split(torch.unbind(self.experts.experts.weight, dim=0), 2, dim=1)
         w2s = torch.unbind(self.experts.output_experts.weight, dim=0)
 
-        # TODO: recreate MoE class with original weights
-        experts = []
-        for i in range(self.num_experts):
-            pass
+        # Recreate the structure of the original MixtralSparseMoeBlock
+        original_moe = nn.Module()
+        original_moe.hidden_dim = self.hidden_dim
+        original_moe.ffn_dim = self.ffn_dim
+        original_moe.num_experts = self.num_experts
+        original_moe.top_k = self.top_k
+
+        # Recreate the gating module
+        original_moe.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
+        original_moe.gate.weight.data = self.gate.weight.data
+
+        # Recreate the experts as a ModuleList
+        original_moe.experts = nn.ModuleList()
+        for expert_idx in range(self.num_experts):
+            expert = nn.Module()
+            expert.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
+            expert.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
+            expert.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
+            expert.act_fn = self.experts.activation
+
+            expert.w1.weight.data = w1s[expert_idx]
+            expert.w2.weight.data = w2s[expert_idx]
+            expert.w3.weight.data = w3s[expert_idx]
+
+            original_moe.experts.append(expert)
+
+        # Replace the SparseMoeBlock with the recreated MixtralSparseMoeBlock structure
+        setattr(model, name, original_moe)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        batch_size, sequence_length, hidden_dim = hidden_states.shape
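As a side note, here is a minimal, self-contained sketch of the round trip that `_post_training` reverses. It assumes the fused parameter is laid out as [num_experts, 2 * ffn_dim, hidden_dim] with each expert's w1 concatenated above its w3; the toy sizes and the exact split call differ from the patched module, so this only illustrates the concat + stack being undone, not the module's actual layout.

import torch

num_experts, ffn_dim, hidden_dim = 4, 16, 32  # toy sizes for illustration

# Hypothetical per-expert weights as they exist in the unfused model
w1 = [torch.randn(ffn_dim, hidden_dim) for _ in range(num_experts)]
w3 = [torch.randn(ffn_dim, hidden_dim) for _ in range(num_experts)]

# Fuse: concat w1/w3 per expert, then stack across experts
fused = torch.stack([torch.cat([a, b], dim=0) for a, b in zip(w1, w3)], dim=0)
assert fused.shape == (num_experts, 2 * ffn_dim, hidden_dim)

# Un-fuse: reverse the stack (unbind over experts), then the concat (chunk the rows)
for idx, per_expert in enumerate(torch.unbind(fused, dim=0)):
    w1_back, w3_back = torch.chunk(per_expert, 2, dim=0)
    assert torch.equal(w1_back, w1[idx]) and torch.equal(w3_back, w3[idx])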