implement post training #1407

Open · wants to merge 2 commits into base: scatter_moe
33 changes: 28 additions & 5 deletions src/axolotl/monkeypatch/moe/moe.py
@@ -21,14 +21,37 @@ def __init__(self, experts, gate, hidden_dim, ffn_dim, num_experts, top_k):
         )
 
     def _post_training(self, model, name):
-        # get original weights back: reverse the concat + stack in the fused experts
+        # Get original weights back: reverse the concat + stack in the fused experts
         w1s, w3s = torch.split(torch.unbind(self.experts.experts.weight, dim=0), 2, dim=1)
         w2s = torch.unbind(self.experts.output_experts.weight, dim=0)
 
-        # TODO: recreate MoE class with original weights
-        experts = []
-        for i in range(self.num_experts):
-            pass
+        # Recreate the structure of the original MixtralSparseMoeBlock
+        original_moe = nn.Module()
+        original_moe.hidden_dim = self.hidden_dim
+        original_moe.ffn_dim = self.ffn_dim
+        original_moe.num_experts = self.num_experts
+        original_moe.top_k = self.top_k
+
+        # Recreate the gating module
+        original_moe.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
+        original_moe.gate.weight.data = self.gate.weight.data
+
+        # Recreate the experts as a ModuleList
+        original_moe.experts = nn.ModuleList()
+        for expert_idx in range(self.num_experts):
+            expert = nn.Module()
+            expert.w1 = nn.Linear(self.hidden_dim, 2 * self.ffn_dim, bias=False)
+            expert.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
+            expert.w3 = nn.Linear(self.hidden_dim, 2 * self.ffn_dim, bias=False)
+            expert.act_fn = self.experts.activation
+
+            expert.w1.weight.data = torch.cat([w1s[expert_idx], w3s[expert_idx]], dim=0)
+            expert.w2.weight.data = w2s[expert_idx]
+
+            original_moe.experts.append(expert)
+
+        # Replace the SparseMoeBlock with the recreated MixtralSparseMoeBlock structure
+        setattr(model, name, original_moe)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         batch_size, sequence_length, hidden_dim = hidden_states.shape
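
For reference, the reversal that `_post_training` is aiming for can be checked in isolation. The snippet below is a minimal, standalone sketch (not code from this PR): it assumes the fused block stores the concatenated w1/w3 projections as a single parameter of shape `(num_experts, 2 * ffn_dim, hidden_dim)` and the w2 projections as `(num_experts, hidden_dim, ffn_dim)`. Since `torch.split` operates on a tensor rather than on the tuple that `torch.unbind` returns, the split is applied per expert after the unbind.

```python
import torch

# Hypothetical shapes for illustration (not taken from the PR):
num_experts, hidden_dim, ffn_dim = 4, 32, 64

# Assumed fused layout: per-expert w1/w3 concatenated along the output rows and
# stacked over experts; per-expert w2 stacked over experts.
fused_w13 = torch.randn(num_experts, 2 * ffn_dim, hidden_dim)
fused_w2 = torch.randn(num_experts, hidden_dim, ffn_dim)

# Reverse the stack over experts, then split each per-expert matrix along its
# output rows. torch.split works on a tensor, so it runs after the unbind.
w1s, w3s = zip(*(w.split(ffn_dim, dim=0) for w in torch.unbind(fused_w13, dim=0)))
w2s = torch.unbind(fused_w2, dim=0)

assert w1s[0].shape == (ffn_dim, hidden_dim)
assert w3s[0].shape == (ffn_dim, hidden_dim)
assert w2s[0].shape == (hidden_dim, ffn_dim)

# Round trip: re-concatenating and re-stacking reproduces the fused parameter.
refused = torch.stack([torch.cat([w1, w3], dim=0) for w1, w3 in zip(w1s, w3s)], dim=0)
assert torch.equal(refused, fused_w13)
```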
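
If the recreated experts are meant to mirror `MixtralBlockSparseTop2MLP` exactly, each expert keeps separate `ffn_dim`-sized `w1` and `w3` projections, with `w2` mapping back to `hidden_dim` and the forward pass computing `w2(act_fn(w1(x)) * w3(x))`. A hedged sketch of such a per-expert rebuild follows; the helper name and the random weights are illustrative, not part of the PR.

```python
import torch
import torch.nn as nn


def build_mixtral_style_expert(w1, w2, w3, act_fn):
    """Hypothetical helper (not part of the PR): rebuild one expert in Mixtral's layout.

    w1 and w3 are (ffn_dim, hidden_dim) weight matrices; w2 is (hidden_dim, ffn_dim).
    nn.Linear stores weights as (out_features, in_features).
    """
    ffn_dim, hidden_dim = w1.shape
    expert = nn.Module()
    expert.w1 = nn.Linear(hidden_dim, ffn_dim, bias=False)
    expert.w2 = nn.Linear(ffn_dim, hidden_dim, bias=False)
    expert.w3 = nn.Linear(hidden_dim, ffn_dim, bias=False)
    expert.act_fn = act_fn
    with torch.no_grad():
        expert.w1.weight.copy_(w1)
        expert.w2.weight.copy_(w2)
        expert.w3.weight.copy_(w3)
    return expert


# Illustrative usage with random weights of the assumed shapes:
hidden_dim, ffn_dim = 32, 64
expert = build_mixtral_style_expert(
    torch.randn(ffn_dim, hidden_dim),
    torch.randn(hidden_dim, ffn_dim),
    torch.randn(ffn_dim, hidden_dim),
    nn.SiLU(),
)
x = torch.randn(1, hidden_dim)
out = expert.w2(expert.act_fn(expert.w1(x)) * expert.w3(x))  # Mixtral-style expert forward
assert out.shape == (1, hidden_dim)
```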