## Install Mergoo

In [None]:
!pip install mergoo

## Create Mergoo-MOE Checkpoint

**Selecting Experts:**

Mergoo lets us reuse existing fine-tuned models, each of which excels in its own domain. In this notebook, we merge three experts into a unified MOE architecture that we will later fine-tune.

The Code Llama Python model is Llama 2 7B fine-tuned on Python code, so it performs well on Python queries. The WikiChat model uses Wikipedia and a 7-stage pipeline to keep its responses factual and reduce hallucinations. Check out their [repository](https://github.com/stanford-oval/WikiChat) for a detailed explanation.

We also add the base model Llama-2-7b-hf as an expert so that the final MOE model stays balanced across diverse domains. The unified model combines the strengths of all the experts and is expected to handle both Python programming and Wikipedia-style questions well.

Here we are creating an MOE model using the following models:
- [meta-llama/Llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b-hf)
- [codellama/CodeLlama-7b-Python-hf](https://huggingface.co/codellama/CodeLlama-7b-Python-hf)
- [stanford-oval/Llama-2-7b-WikiChat-fused](https://huggingface.co/stanford-oval/Llama-2-7b-WikiChat-fused)

**Preparing Config:**
- `model_type`: `llama`, `mistral`, or `bert`. This is the base model family of the experts. At the moment, all experts must come from the same base model family.
- `num_experts_per_tok`: Number of experts that are active for each token. The router selects these experts sparsely.
- `experts`: List of dictionaries describing the seed models to merge. For each expert, `model_id` is mandatory; it can be either a local path or a Hugging Face model id.
- `router_layers`: Names of the layers that are replaced with MOE layers. Weights of the remaining layers are aggregated by averaging. In the future, we will support multiple aggregation methods from MergeKit.
- `router_layers_index`: List of transformer block indexes. Layers in these blocks are converted to MOE. The default is an empty list, which applies the MOE conversion to all blocks, provided the `router_layers` identifier matches. Use `[None]` when no MOE layer should be kept, following the [BTM](https://arxiv.org/abs/2208.03306) architecture.
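
The split between averaged layers and MOE layers described above can be illustrated with a toy merge in plain Python. This is only a sketch of the idea, not Mergoo's implementation: the function `merge_state_dicts`, the tiny list-based "tensors", and the `experts.{i}` key naming are all hypothetical.

```python
from statistics import fmean

def merge_state_dicts(expert_states, router_layers):
    """Toy merge: average shared tensors, keep one copy per expert for router layers.

    expert_states: list of dicts mapping parameter name -> list of floats.
    router_layers: substrings identifying the layers that become MOE layers.
    The `experts.{i}` key naming is hypothetical, for illustration only.
    """
    merged = {}
    for name in expert_states[0]:
        if any(layer in name for layer in router_layers):
            # MOE layer: every expert keeps its own weights.
            prefix, _, suffix = name.rpartition(".")
            for i, state in enumerate(expert_states):
                merged[f"{prefix}.experts.{i}.{suffix}"] = list(state[name])
        else:
            # Shared layer: element-wise mean across all experts.
            columns = zip(*(state[name] for state in expert_states))
            merged[name] = [fmean(col) for col in columns]
    return merged

experts = [
    {"attn.q_proj.weight": [1.0, 2.0], "mlp.up_proj.weight": [0.1, 0.2]},
    {"attn.q_proj.weight": [3.0, 4.0], "mlp.up_proj.weight": [0.3, 0.4]},
]
merged = merge_state_dicts(experts, router_layers=["up_proj"])
print(merged["attn.q_proj.weight"])  # [2.0, 3.0]
print(sorted(merged))
```

The attention weight is averaged into a single tensor, while `up_proj` ends up with one entry per expert, which is what the router later chooses between.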

In [None]:
import torch
from mergoo.compose_experts import ComposeExperts

model_id = "data/checkpoint_demo"
config = {
    "model_type": "llama",
    "num_experts_per_tok": 2,
    "experts":[
        {
            "expert_name" : "base_expert",
            "model_id" : "meta-llama/Llama-2-7b-hf"
        },
        {
            "expert_name" : "expert_1",
            "model_id" : "codellama/CodeLlama-7b-Python-hf"
        },
        {
            "expert_name" : "expert_2",
            "model_id" : "stanford-oval/Llama-2-7b-WikiChat-fused"
        }
    ],
    "router_layers":[
        "gate_proj",
        "up_proj",
        "down_proj"
    ],
}
# create checkpoint
expertmerger = ComposeExperts( config, torch_dtype=torch.float16 )
expertmerger.compose()
expertmerger.save_checkpoint(model_id)
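
To convert only some transformer blocks to MOE, the same config can carry the `router_layers_index` key described earlier. The variant below is a sketch: the index values are illustrative and assumed to be 0-based, and the name `config_partial_moe` is ours. The two assertions are cheap sanity checks worth running before an expensive merge.

```python
# Variant of the config above: convert only selected transformer blocks to MOE.
# The index values are illustrative and assumed to be 0-based.
config_partial_moe = {
    "model_type": "llama",
    "num_experts_per_tok": 2,
    "experts": [
        {"expert_name": "base_expert", "model_id": "meta-llama/Llama-2-7b-hf"},
        {"expert_name": "expert_1", "model_id": "codellama/CodeLlama-7b-Python-hf"},
        {"expert_name": "expert_2", "model_id": "stanford-oval/Llama-2-7b-WikiChat-fused"},
    ],
    "router_layers": ["gate_proj", "up_proj", "down_proj"],
    "router_layers_index": [0, 8, 16, 24],
}

# The router cannot activate more experts than exist.
assert config_partial_moe["num_experts_per_tok"] <= len(config_partial_moe["experts"])
# `model_id` is mandatory for every expert.
assert all("model_id" in expert for expert in config_partial_moe["experts"])
```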

## Training

Now that we have created an MOE checkpoint, every layer of the model is pretrained except for the gating/routing layers that we added. The routing layer selects the top K experts; in our case K=2. We support the HuggingFace trainers `Trainer` and `SFTTrainer`. In this example, we fine-tune on the Python_code_instructions_18k_alpaca dataset and train only the router layers, keeping all the other layers frozen.
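
To make the top-K selection concrete, here is a minimal sketch of routing for a single token in plain Python, so it runs without the model weights. It assumes the common formulation in which a softmax is taken over the K selected router logits; Mergoo's actual routing code may differ, and `top_k_route` is a hypothetical helper.

```python
import math

def top_k_route(router_logits, expert_outputs, k=2):
    """Combine the outputs of the k highest-scoring experts for one token.

    router_logits: one score per expert, as produced by a gating layer.
    expert_outputs: one output vector per expert, all of equal length.
    """
    # Indices of the k largest logits (ties keep their original order).
    order = sorted(range(len(router_logits)), key=lambda i: router_logits[i], reverse=True)
    top = order[:k]
    # Softmax over the selected logits only (assumed renormalization).
    peak = max(router_logits[i] for i in top)
    exps = {i: math.exp(router_logits[i] - peak) for i in top}
    total = sum(exps.values())
    weights = {i: e / total for i, e in exps.items()}
    # Weighted sum of the selected experts' outputs.
    dim = len(expert_outputs[0])
    mixed = [sum(weights[i] * expert_outputs[i][d] for i in top) for d in range(dim)]
    return top, weights, mixed

logits = [0.2, 1.5, 1.5]                          # three experts, as in this notebook
outputs = [[1.0, 0.0], [0.0, 2.0], [4.0, 0.0]]    # stand-in expert outputs
selected, weights, mixed = top_k_route(logits, outputs, k=2)
print(selected, mixed)
```

Only the gating scores are learned during the training below; the expert outputs come from the pretrained weights.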

In [1]:
# load the composed checkpoint
import torch
from mergoo.models.modeling_llama import LlamaForCausalLM

model = LlamaForCausalLM.from_pretrained(
    model_id, 
    device_map="auto", 
    torch_dtype=torch.bfloat16,
)  # the 'gate' (router) layers are untrained, so a "newly initialized" warning appears for them

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Some weights of LlamaForCausalLM were not initialized from the model checkpoint at data/llama_moe and are newly initialized: ['model.layers.0.mlp.down_proj.gate.weight', 'model.layers.0.mlp.gate_proj.gate.weight', 'model.layers.0.mlp.up_proj.gate.weight', 'model.layers.1.mlp.down_proj.gate.weight', 'model.layers.1.mlp.gate_proj.gate.weight', 'model.layers.1.mlp.up_proj.gate.weight', 'model.layers.10.mlp.down_proj.gate.weight', 'model.layers.10.mlp.gate_proj.gate.weight', 'model.layers.10.mlp.up_proj.gate.weight', 'model.layers.11.mlp.down_proj.gate.weight', 'model.layers.11.mlp.gate_proj.gate.weight', 'model.layers.11.mlp.up_proj.gate.weight', 'model.layers.12.mlp.down_proj.gate.weight', 'model.layers.12.mlp.gate_proj.gate.weight', 'model.layers.12.mlp.up_proj.gate.weight', 'model.layers.13.mlp.down_proj.gate.weight', 'model.layers.13.mlp.gate_proj.gate.weight', 'model.layers.13.mlp.up_proj.gate.weight', 'model.layers.14.mlp.down_proj.gate.weight', 'model.layers.14.mlp.gate_proj.gate.w

In [2]:
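# train only router (gating) layers
n_weights, n_frozen_weights = 0, 0
for name, weight in model.named_parameters():
    # NOTE: the substring "gate" also matches parameters under `gate_proj`,
    # so those stay trainable along with the routers.
    if "gate" not in name:
        weight.requires_grad_(False)
        n_frozen_weights += 1
    n_weights += 1
n_weights, n_frozen_weights  # (all parameter tensors, frozen parameter tensors)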

(579, 387)
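
The substring test `"gate" not in name` in the cell above deserves a second look, because Llama's MLP also contains a projection called `gate_proj`. The sketch below compares it with a stricter suffix test. The router names follow the "newly initialized" warning shown earlier; the expert weight name is hypothetical and only serves the illustration.

```python
# Router names follow the "newly initialized" warning shown earlier;
# the expert weight name is hypothetical, chosen only for illustration.
names = [
    "model.layers.0.mlp.gate_proj.gate.weight",       # router
    "model.layers.0.mlp.up_proj.gate.weight",         # router
    "model.layers.0.mlp.down_proj.gate.weight",       # router
    "model.layers.0.mlp.gate_proj.experts.0.weight",  # expert weight (hypothetical name)
    "model.layers.0.self_attn.q_proj.weight",         # attention weight
]

substring_match = [n for n in names if "gate" in n]
suffix_match = [n for n in names if n.endswith(".gate.weight")]

print(len(substring_match))  # 4: also catches the gate_proj expert weight
print(len(suffix_match))     # 3: routers only
```

If the goal is to train strictly the routers, a suffix test along these lines is the safer filter, provided the parameter names in your checkpoint follow the same pattern.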

In [3]:
import datasets
import random

dataset = datasets.load_dataset("iamtarun/python_code_instructions_18k_alpaca")['train']
# materialize the prompt column as a plain list so it can be shuffled in place
dataset = list(dataset['prompt'])
random.shuffle(dataset)
# hold out the last 1000 shuffled prompts for evaluation
dataset_train = datasets.Dataset.from_dict(dict(prompt=dataset[:-1000]))
dataset_test = datasets.Dataset.from_dict(dict(prompt=dataset[-1000:]))

In [4]:
dataset_train, dataset_test

(Dataset({
     features: ['prompt'],
     num_rows: 25085
 }),
 Dataset({
     features: ['prompt'],
     num_rows: 1000
 }))
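
The shuffle above is unseeded, so the train/test split changes on every run. If a reproducible split matters, one option is a seeded local random generator, sketched here with stand-in data. `split_prompts` is a hypothetical helper, not part of `datasets` or Mergoo.

```python
import random

def split_prompts(prompts, n_test, seed=0):
    """Shuffle a copy of `prompts` with a fixed seed and hold out `n_test` items."""
    if not 0 < n_test < len(prompts):
        raise ValueError("n_test must be between 1 and len(prompts) - 1")
    shuffled = list(prompts)                # copy, so the input is left untouched
    random.Random(seed).shuffle(shuffled)   # local RNG, global random state unaffected
    return shuffled[:-n_test], shuffled[-n_test:]

prompts = [f"prompt {i}" for i in range(10)]  # stand-in for the real prompt list
train, test = split_prompts(prompts, n_test=3, seed=42)
print(len(train), len(test))  # 7 3
```

The two returned lists can be passed to `datasets.Dataset.from_dict` exactly as in the cell above.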

In [7]:
from trl import SFTTrainer
from transformers import TrainingArguments

trainer_args = TrainingArguments(
    output_dir= "checkpoints/llama_moe",
    per_device_train_batch_size = 1,
    per_device_eval_batch_size = 1, 
    learning_rate= 1e-5,
    save_total_limit=1,
    num_train_epochs=1,
    eval_steps= 5000,  # only takes effect when the evaluation strategy is set to "steps"
    logging_strategy="steps",
    logging_steps= 25,
    gradient_accumulation_steps=4,
    bf16=True
)

trainer = SFTTrainer(
    model,
    args= trainer_args,
    train_dataset= dataset_train,
    eval_dataset= dataset_test,
    dataset_text_field="prompt",
)

Map:   0%|          | 0/25085 [00:00<?, ? examples/s]

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)
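
The batch settings above interact through gradient accumulation. The arithmetic below is a rough sketch that assumes a single training process; the exact number of optimizer steps depends on how the trainer rounds the last partial batch.

```python
import math

per_device_train_batch_size = 1
gradient_accumulation_steps = 4
num_devices = 1          # assumption: a single training process
num_train_rows = 25085   # from the dataset summary printed earlier
logging_steps = 25

# Samples consumed per optimizer update.
effective_batch = per_device_train_batch_size * gradient_accumulation_steps * num_devices
# Approximate optimizer updates in one epoch.
approx_updates = math.ceil(num_train_rows / effective_batch)
# Samples seen between two logged loss values.
samples_per_log = logging_steps * effective_batch

print(effective_batch, approx_updates, samples_per_log)  # 4 6272 100
```

Each row of the training-loss log therefore corresponds to roughly 100 training prompts under these assumptions.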


In [8]:
trainer.train()

Step,Training Loss
25,2.7135
50,1.8588
75,1.4371
100,1.2607
