## Install mergoo

In [None]:
!pip install mergoo

## Build Mixture of Adapters (MoE on LoRAs) Checkpoint

### Selecting Experts:
The advantage of mergoo is that you can leverage fine-tuned experts that excel in their respective domains. In this notebook, we merge three LoRA fine-tuned experts into a unified mixture-of-adapters architecture, which can later be fine-tuned further.

We choose experts that are trained on customer support datasets: 
- [predibase/customer_support](https://huggingface.co/predibase/customer_support)
- [predibase/customer_support_accounts](https://huggingface.co/predibase/customer_support_accounts)
- [predibase/customer_support_orders](https://huggingface.co/predibase/customer_support_orders)

### Preparing Config:
- `model_type`: ```llama/mistral/bert```. The base model family of the experts. At the moment, all experts must come from the same base model family.
- `num_experts_per_tok`: Number of experts that are active for each token. These experts are selected sparsely by the router.
- `base_model`: Model id of the base model. Make sure that all LoRA experts share a common base model, since the mixture-of-adapters configuration trains a router on top of the LoRA adapters while keeping all other layers frozen.
- `experts`: List of dictionaries describing the seed LoRA models to merge. For each expert, `model_id` is mandatory; it can be either a local path or a Hugging Face model id.

In [1]:
import torch
from mergoo.compose_experts import ComposeExperts

model_id = "data/mistral_lora_moe"
config = {
    "model_type": "mistral",
    "num_experts_per_tok": 2,
    "base_model": "mistralai/Mistral-7B-v0.1",
    "experts": [
        {
            "expert_name": "adapter_1", 
            "model_id": "predibase/customer_support"
        },
        {
            "expert_name": "adapter_2", 
            "model_id": "predibase/customer_support_accounts"
        },
        {
            "expert_name": "adapter_3", 
            "model_id": "predibase/customer_support_orders"
        }
    ],
}
# create the merged checkpoint and save it under model_id
expertmerger = ComposeExperts(config, torch_dtype=torch.bfloat16)
expertmerger.compose()
expertmerger.save_checkpoint(model_id)

MoE Layer Index : [*]


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

100%|██████████| 675/675 [00:00<00:00, 300962.60it/s]


count_averaged_layers : 227
count_router_layers : 448
count_total_router_layers : 448
The model is bigger than the maximum size per checkpoint (9GB) and is going to be split in 2 checkpoint shards. You can find where each parameters has been saved in the index located at data/mistral_lora_moe/model.safetensors.index.json.
checkpoint saved at data/mistral_lora_moe


### Training

Now that we have created the checkpoint of the merged model, all layers are pretrained except for the newly added gating/routing layers. The routing layer selects the top K experts (K=2 here). We support the Hugging Face trainers `Trainer` and `SFTTrainer`.  
In this example, we use the ```bitext/Bitext-customer-support-llm-chatbot-training-dataset``` dataset for fine-tuning. We train only the router layers and keep all other layers frozen.
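To make the top-K selection concrete, here is a minimal pure-Python sketch of sparse routing for a single token. It is an illustration only and not mergoo's implementation: the function name `top_k_route` is hypothetical, and renormalizing the kept weights so they sum to 1 is an assumed (though common) design choice.

```python
import math


def top_k_route(router_logits, k=2):
    """Illustrative only: map the k highest-scoring experts to mixing weights."""
    # Softmax over all experts (shifted by the max for numerical stability).
    shift = max(router_logits)
    exps = [math.exp(x - shift) for x in router_logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Keep the indices of the k most probable experts.
    top = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    # Assumption: renormalize the kept probabilities so they sum to 1.
    kept = sum(probs[i] for i in top)
    return {i: probs[i] / kept for i in top}


# Hypothetical router scores for three experts on one token.
weights = top_k_route([2.0, 0.5, 1.0], k=2)
print(weights)
```

With these scores, experts 0 and 2 are kept and expert 1 contributes nothing for this token. Training the router amounts to learning the function that produces such scores.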

In [1]:
# load the composed checkpoint
import torch
from mergoo.models.modeling_mistral import MistralForCausalLM

# the 'gate' (router) layers are untrained, so a "newly initialized" warning is expected for them
model = MistralForCausalLM.from_pretrained(
    "data/mistral_lora_moe",
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Some weights of MistralForCausalLM were not initialized from the model checkpoint at data/mistral_lora_moe and are newly initialized: ['model.layers.0.self_attn.q_proj.gate.weight', 'model.layers.0.self_attn.v_proj.gate.weight', 'model.layers.1.self_attn.q_proj.gate.weight', 'model.layers.1.self_attn.v_proj.gate.weight', 'model.layers.10.self_attn.q_proj.gate.weight', 'model.layers.10.self_attn.v_proj.gate.weight', 'model.layers.11.self_attn.q_proj.gate.weight', 'model.layers.11.self_attn.v_proj.gate.weight', 'model.layers.12.self_attn.q_proj.gate.weight', 'model.layers.12.self_attn.v_proj.gate.weight', 'model.layers.13.self_attn.q_proj.gate.weight', 'model.layers.13.self_attn.v_proj.gate.weight', 'model.layers.14.self_attn.q_proj.gate.weight', 'model.layers.14.self_attn.v_proj.gate.weight', 'model.layers.15.self_attn.q_proj.gate.weight', 'model.layers.15.self_attn.v_proj.gate.weight', 'model.layers.16.self_attn.q_proj.gate.weight', 'model.layers.16.self_attn.v_proj.gate.weight', 'mode

In [2]:
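# train only the router (gating) layers; freeze everything else
# note: the counters below count parameter tensors, not individual weights
n_weights, n_router_weights = 0, 0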
for name, weight in model.named_parameters():
    if "gate" in name:
        weight.requires_grad_(True)
        n_router_weights += 1
    else:
        weight.requires_grad_(False)
    n_weights += 1
n_weights, n_router_weights

(739, 96)
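The pair above counts parameter tensors (96 router tensors out of 739). To see how many scalar weights are actually trainable, a small helper along the following lines may be useful. It is a sketch: `summarize_parameters` is a hypothetical name, and it relies only on the `requires_grad` attribute and `numel()` method that PyTorch parameters provide. The `FakeParam` class is a stand-in so the example runs without loading a model.

```python
def summarize_parameters(named_parameters):
    """Count tensors and scalar elements, split by whether they are trainable."""
    summary = {
        "trainable_tensors": 0,
        "trainable_elements": 0,
        "frozen_tensors": 0,
        "frozen_elements": 0,
    }
    for _name, param in named_parameters:
        prefix = "trainable" if param.requires_grad else "frozen"
        summary[f"{prefix}_tensors"] += 1
        summary[f"{prefix}_elements"] += param.numel()
    return summary


class FakeParam:
    """Stand-in exposing the two members used above, for an offline check."""

    def __init__(self, n_elements, requires_grad):
        self._n_elements = n_elements
        self.requires_grad = requires_grad

    def numel(self):
        return self._n_elements


# Hypothetical parameter names and sizes, for illustration only.
demo = [
    ("layers.0.q_proj.gate.weight", FakeParam(12, True)),
    ("layers.0.q_proj.weight", FakeParam(100, False)),
]
demo_summary = summarize_parameters(demo)
print(demo_summary)
```

On the real model, the call would be `summarize_parameters(model.named_parameters())`.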

In [19]:
import datasets
import random

dataset = datasets.load_dataset("bitext/Bitext-customer-support-llm-chatbot-training-dataset")['train']
# 90% train, 10% test + validation
train_testvalid = dataset.train_test_split(test_size=0.1)
# split the held-out 10% into 70% test and 30% validation
test_valid = train_testvalid['test'].train_test_split(test_size=0.7)
# gather the splits into a single DatasetDict
dataset = datasets.DatasetDict({
    'train': train_testvalid['train'],
    'test': test_valid['test'],
    'val': test_valid['train']})

In [20]:
dataset

DatasetDict({
    train: Dataset({
        features: ['flags', 'instruction', 'category', 'intent', 'response'],
        num_rows: 24184
    })
    test: Dataset({
        features: ['flags', 'instruction', 'category', 'intent', 'response'],
        num_rows: 1882
    })
    val: Dataset({
        features: ['flags', 'instruction', 'category', 'intent', 'response'],
        num_rows: 806
    })
})
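The row counts above follow from the two nested splits. The arithmetic can be reproduced with the sketch below, which assumes that the test side of each split is rounded up; that assumption is consistent with the counts printed above but is not taken from the library's documentation. The helper name `split_sizes` is hypothetical.

```python
import math


def split_sizes(n_rows, holdout=0.1, test_share=0.7):
    """Reproduce the train/test/val sizes of the two nested splits."""
    # Assumption: the test side of each split is rounded up.
    n_holdout = math.ceil(holdout * n_rows)
    n_train = n_rows - n_holdout
    n_test = math.ceil(test_share * n_holdout)
    n_val = n_holdout - n_test
    return n_train, n_test, n_val


total_rows = 24184 + 1882 + 806  # sum of the three splits shown above
sizes = split_sizes(total_rows)
print(sizes)  # (24184, 1882, 806)
```

Overall, this corresponds to roughly 90% train, 7% test, and 3% validation.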

In [21]:
def formatting_prompts_func(example):
    output_texts = []
    for i in range(len(example['instruction'])):
        text = f"### Question: {example['instruction'][i]}\n ### Answer: {example['response'][i]}"
        output_texts.append(text)
    return output_texts
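To check what the trainer will see, the formatting function can be called on a small batch by hand. The function is repeated here so the snippet runs on its own, and the example row is made up for illustration; it is not taken from the dataset.

```python
def formatting_prompts_func(example):
    # Same logic as the cell above: one prompt string per row of the batch.
    output_texts = []
    for i in range(len(example['instruction'])):
        text = f"### Question: {example['instruction'][i]}\n ### Answer: {example['response'][i]}"
        output_texts.append(text)
    return output_texts


# Hypothetical batch in the column-oriented layout that batched mapping provides.
batch = {
    "instruction": ["How do I reset my password?"],
    "response": ["Open the account settings and choose the reset option."],
}
formatted = formatting_prompts_func(batch)
print(formatted[0])
```

Each batch is a dictionary of columns, so the function returns a list with one formatted string per row.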

In [23]:
from trl import SFTTrainer
from transformers import TrainingArguments

trainer_args = TrainingArguments(
    output_dir="data/mistral_cs_lora_moe",
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    learning_rate=1e-5,
    save_total_limit=1,
    num_train_epochs=1,
    eval_steps=5000,
    logging_strategy="steps",
    logging_steps=25,
    gradient_accumulation_steps=4,
    bf16=True,
)

trainer = SFTTrainer(
    model,
    args=trainer_args,
    train_dataset=dataset['train'],
    eval_dataset=dataset['val'],
    formatting_func=formatting_prompts_func,
    max_seq_length=512,
)

Map:   0%|          | 0/24184 [00:00<?, ? examples/s]

Map:   0%|          | 0/806 [00:00<?, ? examples/s]
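The batch settings above determine how many optimizer steps one epoch takes. The sketch below works through that arithmetic, assuming a single training process and one training example per dataset row (no packing); `optimizer_steps` is a hypothetical helper, not a Trainer API.

```python
import math


def optimizer_steps(n_examples, per_device_batch=1, grad_accum=4, n_devices=1, epochs=1):
    """Return (effective batch size, total optimizer steps)."""
    # Gradients from grad_accum micro-batches are combined into one update.
    effective_batch = per_device_batch * grad_accum * n_devices
    steps_per_epoch = math.ceil(n_examples / effective_batch)
    return effective_batch, steps_per_epoch * epochs


# 24184 training rows, batch size 1, gradient accumulation 4, one epoch.
effective_batch, total_steps = optimizer_steps(24184)
print(effective_batch, total_steps)  # 4 6046
```

With `logging_steps=25`, the loss is therefore logged roughly every 100 training examples.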



In [24]:
trainer.train()

Step,Training Loss
25,1.0562
50,0.8857
75,0.842
100,0.8277
