# Build Your Own MoE with LLaMa3-based Experts

## Install mergoo

In [None]:
!pip install mergoo

## Selecting Experts:

With mergoo, you can easily build your own MoE-style LLM by integrating LLaMa3-based experts.

In this tutorial, we use the following LLaMa3-based models:
- [meta-llama/Meta-Llama-3-8B](https://huggingface.co/meta-llama/Meta-Llama-3-8B): The base LLaMa3 8B model released by Meta.
- [Locutusque/Llama-3-Orca-1.0-8B](https://huggingface.co/Locutusque/Llama-3-Orca-1.0-8B): LLaMa3 8B fine-tuned on [SlimOrca](https://huggingface.co/datasets/Open-Orca/SlimOrca) to improve performance in math, coding, and writing.
- [mlabonne/OrpoLlama-3-8B](https://huggingface.co/mlabonne/OrpoLlama-3-8B): LLaMa3 8B fine-tuned with ORPO on 1k samples of the [ORPO dataset](https://huggingface.co/datasets/mlabonne/orpo-dpo-mix-40k).



**Preparing Config:**
- `model_type`: `llama`
- `num_experts_per_tok`: Number of experts activated for each token. The router selects these experts sparsely (top-k) from all available experts.
- `experts`: List of dictionaries describing the seed models to be merged. `model_id` is mandatory for each expert and can be either a local path or a Hugging Face model ID.
- `router_layers`: Names of the layers to be replaced with MoE layers. The weights of all remaining layers are aggregated by averaging. Support for more aggregation methods from MergeKit is planned.
- `router_layers_index`: List of transformer block indices whose layers are converted to MoE. The default is an empty list, which applies the MoE conversion to every block whose layer names match `router_layers`. Use `[None]` to keep no MoE layers, following the [BTM](https://arxiv.org/abs/2208.03306) architecture.
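The three `router_layers_index` cases listed above can be summarized as a small function. `moe_block_indices` is a hypothetical helper written only for illustration, not part of mergoo's API; it encodes the rules as described in the list and assumes LLaMa3 8B's 32 transformer blocks.

```python
from typing import List, Optional, Set

def moe_block_indices(router_layers_index: List[Optional[int]], num_blocks: int) -> Set[int]:
    """Hypothetical helper: which blocks get their `router_layers` converted to MoE."""
    if not router_layers_index:        # default []: every block is converted
        return set(range(num_blocks))
    if router_layers_index == [None]:  # [None]: no MoE layers are kept (BTM-style)
        return set()
    return set(router_layers_index)    # explicit list: only these blocks

print(len(moe_block_indices([], 32)))       # 32
print(moe_block_indices([None], 32))        # set()
print(sorted(moe_block_indices([0, 31], 32)))  # [0, 31]
```

Under these rules, leaving `router_layers_index` out of the config (as in the cell below) converts the matching projections in all 32 blocks.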

In [2]:
import torch
from mergoo.compose_experts import ComposeExperts

config = \
{
    "model_type": "llama",
    "num_experts_per_tok": 2,
    "experts":[
        {
            "expert_name" : "base_expert",
            "model_id" : "meta-llama/Meta-Llama-3-8B"
        },
        {
            "expert_name" : "expert_1",
            "model_id" : "Locutusque/Llama-3-Orca-1.0-8B"
        },
        {
            "expert_name" : "expert_2",
            "model_id" : "mlabonne/OrpoLlama-3-8B"
        }
    ],
    "router_layers":[
        "gate_proj",
        "up_proj",
        "down_proj"
    ],
}
# create checkpoint
expertmerger = ComposeExperts(config, torch_dtype=torch.float16)
expertmerger.compose()
expertmerger.save_checkpoint("data/llama3_moe")

MoE Layer Index : [*]


Loading checkpoint shards: 100%|██████████| 4/4 [00:03<00:00,  1.04it/s]


merging expert : unsloth/llama-3-8b


100%|██████████| 291/291 [00:06<00:00, 47.77it/s]
Loading checkpoint shards: 100%|██████████| 9/9 [00:04<00:00,  1.92it/s]


merging expert : Locutusque/Llama-3-Orca-1.0-8B


100%|██████████| 291/291 [00:06<00:00, 45.51it/s]
Downloading shards: 100%|██████████| 4/4 [00:29<00:00,  7.47s/it]
Loading checkpoint shards: 100%|██████████| 4/4 [00:02<00:00,  1.57it/s]


merging expert : mlabonne/OrpoLlama-3-8B


100%|██████████| 291/291 [00:06<00:00, 44.84it/s]


count_averaged_layers : 195
count_router_layers : 96
count_total_router_layers : 288
The model is bigger than the maximum size per checkpoint (9GB) and is going to be split in 5 checkpoint shards. You can find where each parameters has been saved in the index located at data/llama3_moe/model.safetensors.index.json.


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


checkpoint saved at data/llama3_moe
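The counts in this log can be checked by hand. The sketch below assumes that the 291 in each progress bar is the number of weight tensors in one LLaMa3 8B checkpoint (token embedding, nine tensors per block for 32 blocks, final norm, and `lm_head`) and that every block contributes the three `router_layers` listed in the config.

```python
num_blocks = 32
tensors_per_block = 9        # q/k/v/o_proj, gate/up/down_proj, two RMSNorm weights
tensors_per_expert = 1 + num_blocks * tensors_per_block + 2  # embed_tokens + blocks + (final norm, lm_head)
router_layers_per_block = 3  # gate_proj, up_proj, down_proj
num_experts = 3

count_router_layers = num_blocks * router_layers_per_block        # layers turned into MoE layers
count_averaged_layers = tensors_per_expert - count_router_layers  # everything else is averaged
count_total_router_layers = count_router_layers * num_experts     # one copy per expert

print(tensors_per_expert, count_averaged_layers, count_router_layers, count_total_router_layers)
# 291 195 96 288
```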


## Training

Now that we have created an MoE checkpoint, all layers of the model are pretrained except for the gating (router) layers we added, which start from random initialization. Each router selects the top K experts per token; in our case K=2, as set by `num_experts_per_tok`. We support the Hugging Face trainers `Trainer` and `SFTTrainer`. In this example, we fine-tune on the [alpaca](https://huggingface.co/datasets/tatsu-lab/alpaca) dataset, training only the router layers and keeping all other layers frozen.
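The top-K routing step can be sketched in a few lines. This is a minimal NumPy illustration of top-k gating in the style of Mixtral-like routers, not mergoo's implementation; the function name `top_k_moe` and the shapes are assumptions. Each expert here is a single linear map, mirroring the config above, where each projection in `router_layers` becomes its own MoE layer.

```python
import numpy as np

def top_k_moe(x, gate_w, experts, k=2):
    """Route one token through the k highest-scoring experts (illustrative sketch).

    x:       (d_in,) hidden state of a single token
    gate_w:  (num_experts, d_in) router weights, the part trained in this tutorial
    experts: list of (d_out, d_in) expert weight matrices, kept frozen
    """
    logits = gate_w @ x                    # one score per expert
    top = np.argsort(logits)[-k:]          # indices of the k largest scores
    w = np.exp(logits[top] - logits[top].max())
    w /= w.sum()                           # softmax over the selected experts only
    # Only the selected experts are evaluated; their outputs are mixed by w.
    y = sum(wi * (experts[i] @ x) for wi, i in zip(w, top))
    return y, top

rng = np.random.default_rng(0)
d_in, d_out, num_experts = 8, 4, 3        # 3 experts, as in the config above
experts = [rng.standard_normal((d_out, d_in)) for _ in range(num_experts)]
gate_w = rng.standard_normal((num_experts, d_in))
x = rng.standard_normal(d_in)

y, chosen = top_k_moe(x, gate_w, experts, k=2)
print(y.shape, sorted(chosen.tolist()))
```

Because the gradient of the output flows through `w`, training only `gate_w` lets the model learn which frozen experts to prefer for each token.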

In [3]:
# load the composed checkpoint
import torch
from mergoo.models.modeling_llama import LlamaForCausalLM

model = LlamaForCausalLM.from_pretrained(
    "data/llama3_moe",
    device_map="auto",
    torch_dtype=torch.bfloat16,
)  # the 'gate' (router) layers are untrained, so a "newly initialized" warning is expected for them

Loading checkpoint shards: 100%|██████████| 5/5 [00:08<00:00,  1.67s/it]
Some weights of LlamaForCausalLM were not initialized from the model checkpoint at data/llama3_moe and are newly initialized: ['model.layers.0.mlp.down_proj.gate.weight', 'model.layers.0.mlp.gate_proj.gate.weight', 'model.layers.0.mlp.up_proj.gate.weight', 'model.layers.1.mlp.down_proj.gate.weight', 'model.layers.1.mlp.gate_proj.gate.weight', 'model.layers.1.mlp.up_proj.gate.weight', 'model.layers.10.mlp.down_proj.gate.weight', 'model.layers.10.mlp.gate_proj.gate.weight', 'model.layers.10.mlp.up_proj.gate.weight', 'model.layers.11.mlp.down_proj.gate.weight', 'model.layers.11.mlp.gate_proj.gate.weight', 'model.layers.11.mlp.up_proj.gate.weight', 'model.layers.12.mlp.down_proj.gate.weight', 'model.layers.12.mlp.gate_proj.gate.weight', 'model.layers.12.mlp.up_proj.gate.weight', 'model.layers.13.mlp.down_proj.gate.weight', 'model.layers.13.mlp.gate_proj.gate.weight', 'model.layers.13.mlp.up_proj.gate.weight', 'model.l

In [4]:
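# train only the router (gating) layers and freeze everything else.
# Router weights are named like "model.layers.0.mlp.gate_proj.gate.weight"
# (see the warning above), so match ".gate." rather than the bare substring
# "gate", which would also match the gate_proj expert weights.
n_weights, n_frozen_weights = 0, 0
for name, weight in model.named_parameters():
    if ".gate." not in name:
        weight.requires_grad_(False)
        n_frozen_weights += 1
    n_weights += 1
n_weights, n_frozen_weights

(579, 483)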
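A bare substring test on `"gate"` is an easy filter to get wrong: it matches the router weights but also any parameter under `gate_proj`. The sketch below compares it with matching `".gate."` on a handful of parameter names. The router names follow the loading warning above; the expert names (`...experts.0.weight`) are an assumption about mergoo's module layout, used only for illustration.

```python
# Illustrative parameter names; the "experts.N" naming is assumed, not taken from mergoo.
param_names = [
    "model.layers.0.self_attn.q_proj.weight",
    "model.layers.0.mlp.gate_proj.experts.0.weight",
    "model.layers.0.mlp.gate_proj.gate.weight",
    "model.layers.0.mlp.up_proj.experts.0.weight",
    "model.layers.0.mlp.up_proj.gate.weight",
    "model.layers.0.mlp.down_proj.gate.weight",
]

loose = [n for n in param_names if "gate" in n]     # also matches the gate_proj expert
strict = [n for n in param_names if ".gate." in n]  # only the router weights

print(len(loose), len(strict))  # 4 3
```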

In [5]:
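import datasets
import random

dataset = datasets.load_dataset("tatsu-lab/alpaca")['train']
texts = list(dataset['text'])  # plain Python list of the formatted instruction/response strings
random.seed(42)                # make the shuffle, and hence the split, reproducible
random.shuffle(texts)
# hold out the last 1,000 shuffled examples for evaluation
dataset_train = datasets.Dataset.from_dict(dict(prompt=texts[:-1000]))
dataset_test = datasets.Dataset.from_dict(dict(prompt=texts[-1000:]))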

Downloading readme: 100%|██████████| 7.47k/7.47k [00:00<00:00, 22.7MB/s]
Downloading data: 100%|██████████| 24.2M/24.2M [00:00<00:00, 56.3MB/s]
Generating train split: 52002 examples [00:00, 531424.92 examples/s]


In [6]:
dataset_train, dataset_test

(Dataset({
     features: ['prompt'],
     num_rows: 51002
 }),
 Dataset({
     features: ['prompt'],
     num_rows: 1000
 }))
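The split above holds out the last 1,000 of the 52,002 shuffled examples, leaving 51,002 for training. Seeding the shuffle makes that split reproducible across runs. `holdout_split` below is a hypothetical pure-Python helper that sketches the idea with a local `random.Random` instance, so it leaves the global random state alone.

```python
import random

def holdout_split(items, n_test, seed=42):
    """Hypothetical helper: seeded shuffle of a copy of `items`, last `n_test` held out."""
    items = list(items)                 # copy, so the caller's data is not reordered
    if not 0 < n_test < len(items):
        raise ValueError("n_test must be between 1 and len(items) - 1")
    random.Random(seed).shuffle(items)  # local RNG: reproducible, no global side effects
    return items[:-n_test], items[-n_test:]

train, test = holdout_split(range(52002), n_test=1000)
print(len(train), len(test))  # 51002 1000
```

With the `datasets` library, `Dataset.train_test_split(test_size=1000, seed=42)` gives a comparable seeded split, returning a `DatasetDict` with `train` and `test` keys.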

In [7]:
from trl import SFTTrainer
from transformers import TrainingArguments

trainer_args = TrainingArguments(
    output_dir="checkpoints/llama_moe",
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    learning_rate=1e-5,
    save_total_limit=1,
    num_train_epochs=1,
    evaluation_strategy="steps",  # required for eval_steps to take effect; renamed eval_strategy in newer transformers
    eval_steps=5000,
    logging_strategy="steps",
    logging_steps=25,
    gradient_accumulation_steps=4,
    bf16=True,
)

trainer = SFTTrainer(
    model,
    args=trainer_args,
    train_dataset=dataset_train,
    eval_dataset=dataset_test,
    dataset_text_field="prompt",  # newer TRL releases expect this in SFTConfig instead
)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Map: 100%|██████████| 51002/51002 [00:03<00:00, 16822.74 examples/s]
Map: 100%|██████████| 1000/1000 [00:00<00:00, 18198.36 examples/s]
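These arguments determine how many optimizer steps one epoch takes. The rough calculation below assumes a single training process: with `device_map="auto"`, one copy of the model is spread across the available GPUs rather than replicated per GPU, so there is no data parallelism.

```python
# Values copied from the TrainingArguments and dataset sizes above.
per_device_train_batch_size = 1
gradient_accumulation_steps = 4
num_processes = 1              # assumption: one process, no data parallelism
num_train_examples = 51002
eval_steps = 5000

effective_batch_size = per_device_train_batch_size * gradient_accumulation_steps * num_processes
optimizer_steps = num_train_examples // effective_batch_size  # approximate steps per epoch
num_evals = optimizer_steps // eval_steps

print(effective_batch_size, optimizer_steps, num_evals)  # 4 12750 2
```

If evaluation runs on a `"steps"` strategy, `eval_steps=5000` therefore gives about two evaluations over the epoch; the Trainer's own step count may differ slightly depending on how it rounds the final partial accumulation.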


In [None]:
trainer.train()