# Federated Tuning With FedMKT methods in FATE-LLM

In this tutorial, we demonstrate how to efficiently train federated large language models with the FATE-LLM framework. FATE-LLM provides the "FedMKT" module, designed specifically for federated learning with large language models. FedMKT introduces a novel federated mutual knowledge transfer framework that enables effective knowledge transfer between a large language model (LLM) deployed on the server and small language models (SLMs) residing on the clients.



The algorithm is based on the paper ["FedMKT: Federated Mutual Knowledge Transfer for Large and Small Language Models"](https://arxiv.org/pdf/2406.02224); we have integrated its code into the FATE-LLM framework.
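
Each transfer direction is, at its core, a form of knowledge distillation across two different vocabularies: the teacher's next-token distribution is projected onto the student's vocabulary through a token mapping and compared with the student's distribution. The toy, plain-Python sketch below illustrates only that general idea, assuming a dict-based token mapping and a KL-divergence loss. It is not FedMKT's exact loss nor FATE-LLM's implementation, and every name in it is hypothetical.

```python
import math


def softmax(logits):
    """Convert a dict {token: logit} into a dict {token: probability}."""
    m = max(logits.values())
    exps = {t: math.exp(v - m) for t, v in logits.items()}
    total = sum(exps.values())
    return {t: e / total for t, e in exps.items()}


def project(teacher_probs, vocab_mapping):
    """Move teacher probability mass onto student tokens.

    Several teacher tokens may map to the same student token; their mass is summed.
    """
    projected = {}
    for token, p in teacher_probs.items():
        student_token = vocab_mapping[token]
        projected[student_token] = projected.get(student_token, 0.0) + p
    return projected


def kl_divergence(p, q, eps=1e-12):
    """KL(p || q); q is floored at eps to avoid log(0)."""
    return sum(pv * math.log(pv / max(q.get(t, 0.0), eps)) for t, pv in p.items() if pv > 0)


# Toy example: the teacher and the student tokenize the same text differently.
teacher_logits = {"hello": 2.0, "world": 1.0, "worl": 0.5}
mapping = {"hello": "hello", "world": "world", "worl": "world"}  # teacher token -> student token
student_logits = {"hello": 1.5, "world": 1.2}

target = project(softmax(teacher_logits), mapping)
loss = kl_divergence(target, softmax(student_logits))
print(target, loss)
```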


## Experiments

Chapter list:
* Settings
  1. Dataset: ARC-Challenge
  2. Models Used in the "FEDMKT" Paper
  3. Prepare Optimal Vocabulary Mapping Tables
  4. Training LLMs with LoRA
* Experiment examples:
  1. Running FEDMKT With Launcher (Experimental Use): 4-SLMs
  2. Running FEDMKT With Launcher (Experimental Use): 1-SLM (One To One)
  3. Running FEDMKT With Launcher (Experimental Use): 1-SLM And SLM Trains Only (LLM2SLM)
  4. Running FEDMKT With Launcher (Experimental Use): 4-SLMs Homogeneous SFT
  5. Running FEDMKT with Pipeline (Industrial Use)

### Dataset: ARC-Challenge

ARC-Challenge is the Challenge Set of the AI2 Reasoning Challenge (ARC) dataset, which contains 7,787 genuine grade-school level, multiple-choice science questions assembled to encourage research in advanced question answering.

You can refer to the following link for more details about [ARC-Challenge](https://huggingface.co/datasets/allenai/ai2_arc).
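
Each ARC record on the Hugging Face Hub carries an `id`, a `question`, a `choices` dict with parallel `text` and `label` lists, and an `answerKey`. The sketch below shows one way to flatten such a record into a prompt/answer pair. The prompt template is an assumption for illustration and is not necessarily the format FATE-LLM's `QaDataset` uses; the record itself is made up.

```python
def format_arc_example(example):
    """Turn an ARC record into a (prompt, answer) pair using a hypothetical template."""
    labels = example["choices"]["label"]
    texts = example["choices"]["text"]
    options = "\n".join(f"{label}. {text}" for label, text in zip(labels, texts))
    prompt = f"Question: {example['question']}\n{options}\nAnswer:"
    # answerKey holds the label of the correct choice, e.g. "B"
    answer_idx = labels.index(example["answerKey"])
    answer = f" {example['answerKey']}. {texts[answer_idx]}"
    return prompt, answer


# A made-up record with the same structure as the dataset's rows.
record = {
    "id": "example-1",
    "question": "Which gas do plants absorb from the air for photosynthesis?",
    "choices": {
        "text": ["oxygen", "carbon dioxide", "nitrogen", "helium"],
        "label": ["A", "B", "C", "D"],
    },
    "answerKey": "B",
}
prompt, answer = format_arc_example(record)
print(prompt + answer)
```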

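In this section, we download the ARC-Challenge dataset from Hugging Face and split its training set into five equally sized parts: the "common" part serves as the public dataset, and the other four parts (`client_0` to `client_3`) serve as private training data for the four SLMs (OPT, GPT-2, LLaMA, BLOOM).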

In [None]:
import datasets


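# `verification_mode` replaces the deprecated `ignore_verifications` argument in recent `datasets` releases
data = datasets.load_dataset("ai2_arc", "ARC-Challenge", download_mode="force_redownload", verification_mode="no_checks")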
train_data = data.pop("train")

seed=123
n = train_data.shape[0]
client_num = 4
process_data_output_dir = "" # directory for the processed splits; it must be specified because later sections load data from it

client_data_num = n // (client_num + 1)

for i in range(client_num):
    splits = train_data.train_test_split(train_size=client_data_num, shuffle=True, seed=seed)
    client_name = f"client_{i}"
    data[client_name] = splits["train"]
    train_data = splits["test"]

if train_data.shape[0] == client_data_num:
    data["common"] = train_data
else:
    data["common"] = train_data.train_test_split(
        train_size=client_data_num, shuffle=True, seed=seed
    )["train"]

data.save_to_disk(process_data_output_dir)
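
To make the sizes of the splits explicit, here is an index-level sketch of the same partitioning. `split_indices` is a hypothetical helper: it shuffles once instead of calling `train_test_split` repeatedly, so the exact membership of each part differs from the code above, but the part sizes (`n // (client_num + 1)` each) and the disjointness of the parts are the same.

```python
import random


def split_indices(n, client_num, seed=123):
    """Split range(n) into client_num private parts plus one public "common" part,
    all of size n // (client_num + 1); any remainder is dropped."""
    part_size = n // (client_num + 1)
    indices = list(range(n))
    random.Random(seed).shuffle(indices)
    parts = {}
    for i in range(client_num):
        parts[f"client_{i}"] = indices[i * part_size:(i + 1) * part_size]
    # Whatever is left (at least part_size items) supplies the public part.
    parts["common"] = indices[client_num * part_size:(client_num + 1) * part_size]
    return parts


parts = split_indices(n=1003, client_num=4)
print({name: len(idx) for name, idx in parts.items()})
```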


### Models Used in the "FEDMKT" Paper

LLM: [Llama-2-7B](https://huggingface.co/meta-llama/Llama-2-7b-hf)  
SLM-0: [opt-1.3b](https://huggingface.co/facebook/opt-1.3b)  
SLM-1: [gpt2-xlarge](https://huggingface.co/openai-community/gpt2-xl)  
SLM-2: [Llama-1.3b](https://huggingface.co/princeton-nlp/Sheared-LLaMA-1.3B)  
SLM-3: [bloom-1.1B](https://huggingface.co/bigscience/bloom-1b1)

Download the models from Hugging Face and save them to local directories before proceeding: the models are large, so re-downloading them for every run would take a long time.
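
One way to fetch all five models is `huggingface_hub.snapshot_download(repo_id=..., local_dir=...)`. The sketch below is an assumption-laden helper (the name `download_models` and the `dry_run` flag are hypothetical) that maps the repositories linked above to the local directory names used in the next cell. Note that Llama-2 is a gated repository, so you must accept its license on Hugging Face and authenticate (for example with `huggingface-cli login`) before downloading it.

```python
import os

# Local directory name -> Hugging Face repo id, taken from the links above.
MODEL_REPOS = {
    "llama-2-7b-hf": "meta-llama/Llama-2-7b-hf",
    "opt-1.3b": "facebook/opt-1.3b",
    "gpt2-xl": "openai-community/gpt2-xl",
    "Sheared-LLaMA-1.3B": "princeton-nlp/Sheared-LLaMA-1.3B",
    "bloom-1b1": "bigscience/bloom-1b1",
}


def download_models(target_root=".", dry_run=True):
    """Return the planned (repo_id, local_dir) pairs; download them when dry_run=False."""
    plan = [(repo_id, os.path.join(target_root, local_name))
            for local_name, repo_id in MODEL_REPOS.items()]
    if not dry_run:
        # Imported lazily so the dry run works without huggingface_hub installed.
        from huggingface_hub import snapshot_download
        for repo_id, local_dir in plan:
            snapshot_download(repo_id=repo_id, local_dir=local_dir)
    return plan


for repo_id, local_dir in download_models(target_root="./models"):
    print(f"{repo_id} -> {local_dir}")
```

Call `download_models("./models", dry_run=False)` once to actually download, then point the paths in the next cell at those directories.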


In [14]:
# replace the model names below with the local directories where the models are saved
llm_pretrained_path = "llama-2-7b-hf"
slm_0_pretrained_path = "opt-1.3b"
slm_1_pretrained_path = "gpt2-xl"
slm_2_pretrained_path = "Sheared-LLaMA-1.3B"
slm_3_pretrained_path = "bloom-1b1"


### Prepare Optimal Vocabulary Mapping Tables

To use "FEDMKT" for federated knowledge transfer, we first need to build optimal vocabulary mapping tables.
The "FEDMKT" paper uses one LLM and four SLMs, so eight optimal vocabulary mapping tables are needed: because the LLM and each SLM are co-trained, every (LLM, SLM) pair needs one table for each direction (SLM to LLM and LLM to SLM).
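
Conceptually, a vocabulary mapping table assigns every token of a source tokenizer to a token of a target tokenizer. The toy sketch below assumes a minimum-edit-distance criterion (exact matches first, otherwise the closest string) and tiny hand-written vocabularies; FATE-LLM's `get_vocab_mappings` may normalize tokens, break ties, or parallelize the search differently. The last lines spell out why four (LLM, SLM) pairs yield eight tables.

```python
def edit_distance(a, b):
    """Levenshtein distance computed with a single rolling DP row."""
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, start=1):
        cur = [i]
        for j, cb in enumerate(b, start=1):
            cur.append(min(prev[j] + 1,                # deletion
                           cur[j - 1] + 1,             # insertion
                           prev[j - 1] + (ca != cb)))  # substitution
        prev = cur
    return prev[-1]


def build_vocab_mapping(src_vocab, dst_vocab):
    """Map each source token to an identical target token, else to the nearest one."""
    dst_set = set(dst_vocab)
    mapping = {}
    for tok in src_vocab:
        if tok in dst_set:
            mapping[tok] = tok
        else:
            # min() keeps the first candidate on ties, so dst_vocab order matters
            mapping[tok] = min(dst_vocab, key=lambda cand: edit_distance(tok, cand))
    return mapping


print(build_vocab_mapping(["hello", "wor", "ld"], ["hello", "world", "word", "l"]))

# One table per direction for every (LLM, SLM) pair: 4 pairs x 2 directions = 8 tables.
slms = ["opt", "gpt2", "llama_small", "bloom"]
tables = [f"{a}_to_{b}" for s in slms for a, b in (("llama", s), (s, "llama"))]
print(len(tables))
```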


In [None]:
from fate_llm.algo.fedmkt.token_alignment.vocab_mapping import get_vocab_mappings


llm_slm_pairs = [
    (llm_pretrained_path, slm_0_pretrained_path),
    (llm_pretrained_path, slm_1_pretrained_path),
    (llm_pretrained_path, slm_2_pretrained_path),
    (llm_pretrained_path, slm_3_pretrained_path)
]

vocab_mapping_directory = "" # replace with the directory where the mapping tables should be saved

slm_to_llm_vocab_mapping_paths = ["opt_to_llama.json", "gpt2_to_llama.json", "llama_small_to_llama.json", "bloom_to_llama.json"]
llm_to_slm_vocab_mapping_paths = ["llama_to_opt.json", "llama_to_gpt2.json", "llama_to_llama_small", "llama_to_bloom.json"]

for idx in range(4):
    slm_to_llm_vocab_mapping_paths[idx] = vocab_mapping_directory + "/" + slm_to_llm_vocab_mapping_paths[idx]
    llm_to_slm_vocab_mapping_paths[idx] = vocab_mapping_directory + "/" + llm_to_slm_vocab_mapping_paths[idx]

for idx, (llm_pretrained, slm_pretrained) in enumerate(llm_slm_pairs):
    slm_to_llm_vocab_mapping_path = slm_to_llm_vocab_mapping_paths[idx]
    llm_to_slm_vocab_mapping_path = llm_to_slm_vocab_mapping_paths[idx]
    # SLM -> LLM table, later loaded by the LLM side (FedMKTLLM)
    _ = get_vocab_mappings(slm_pretrained, llm_pretrained, slm_to_llm_vocab_mapping_path, num_processors=16)
    # LLM -> SLM table, later loaded by the SLM side (FedMKTSLM)
    _ = get_vocab_mappings(llm_pretrained, slm_pretrained, llm_to_slm_vocab_mapping_path, num_processors=16)
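
Before launching training, it can save time to confirm that all eight tables were actually written. A minimal sketch, assuming only that the tables are plain files on local disk (`missing_files` is a hypothetical helper):

```python
import os


def missing_files(paths):
    """Return the subset of paths that do not point to an existing regular file."""
    return [p for p in paths if not os.path.isfile(p)]


# With the path lists defined in the cell above:
# missing = missing_files(slm_to_llm_vocab_mapping_paths + llm_to_slm_vocab_mapping_paths)
# assert not missing, f"vocabulary mapping tables not found: {missing}"
```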
    

### Training LLMs with LoRA

In this section, we introduce the LoRA configs used for the five models listed in the paper: one LLM (Llama-2-7B) and four SLMs (opt-1.3B, gpt2-xlarge, Llama-1.3B, bloom-1.1B).


The PEFT-enabled model wrappers are located in `fate_llm/model_zoo`; the following cells show how to configure them.
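
For readers new to LoRA: instead of updating a frozen weight `W`, LoRA trains a low-rank pair `A` (r x d_in) and `B` (d_out x r) and adds `(lora_alpha / r) * B A x` to the layer output; `B` starts at zero, so training begins from the pretrained model's behavior. The NumPy sketch below illustrates that formulation only (dropout omitted); it is not how PEFT or FATE-LLM implement it internally.

```python
import numpy as np


def lora_linear(x, W, A, B, lora_alpha=16, r=8):
    """y = W x + (lora_alpha / r) * B (A x), for a single input vector x."""
    return W @ x + (lora_alpha / r) * (B @ (A @ x))


rng = np.random.default_rng(0)
d_in, d_out, r = 32, 16, 8
W = rng.normal(size=(d_out, d_in))     # frozen pretrained weight
A = rng.normal(size=(r, d_in)) * 0.01  # trainable, randomly initialised
B = np.zeros((d_out, r))               # trainable, zero-initialised
x = rng.normal(size=d_in)

y = lora_linear(x, W, A, B, lora_alpha=16, r=r)
print(np.allclose(y, W @ x))  # B is zero, so the adapted layer matches the base layer
```

With `r=8` and `lora_alpha=16`, as in the configs in this tutorial, the low-rank update is scaled by 2.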

#### Initialize the LoRA Config for the LLM (Llama-2-7B)

In [None]:
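from peft import LoraConfig, TaskType

# LoRA adapters on all four attention projections of Llama-2-7B
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False, r=8, lora_alpha=16, lora_dropout=0.05,
    target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj']
)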
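
Each adapted `d_out x d_in` weight adds `r * (d_in + d_out)` trainable parameters. The arithmetic below estimates the trainable-parameter count of this config, assuming Llama-2-7B's 32 decoder layers with hidden size 4096 (square q/k/v/o projections). It is a back-of-the-envelope check; PEFT's `print_trainable_parameters()` on a `PeftModel` reports the real figure.

```python
def lora_param_count(shapes, r):
    """Trainable LoRA parameters: r * (d_in + d_out) per adapted (d_out x d_in) weight."""
    return sum(r * (d_in + d_out) for d_out, d_in in shapes)


hidden, layers = 4096, 32
per_layer = [(hidden, hidden)] * 4  # q_proj, k_proj, v_proj, o_proj
total = layers * lora_param_count(per_layer, r=8)
print(f"{total:,}")
```

This comes to roughly 8.4M trainable parameters, about 0.1% of the 7B base model.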


#### Initialize the LoRA Configs for the SLMs

In [None]:
from peft import LoraConfig, TaskType

slm_pretrained_paths = [slm_0_pretrained_path, slm_1_pretrained_path, slm_2_pretrained_path, slm_3_pretrained_path]
# attention modules to adapt in each SLM:
# OPT, GPT-2 (fused c_attn), Sheared-LLaMA, BLOOM (fused query_key_value)
slm_lora_target_modules = [
    ["q_proj", "v_proj"],
    ["c_attn"],
    ['q_proj', 'k_proj', 'v_proj', 'o_proj'],
    ["query_key_value"]
]

def get_slm_conf(slm_idx):
    """Return the LoRA config for SLM number slm_idx."""
    return LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        inference_mode=False, r=8, lora_alpha=16, lora_dropout=0.05,
        target_modules=slm_lora_target_modules[slm_idx]
    )
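
PEFT matches each `target_modules` entry against module names in the model (a name equal to the entry or ending in `.<entry>`), which is why GPT-2's fused `c_attn` and BLOOM's fused `query_key_value` appear instead of separate q/k/v projections. The hypothetical helper below is a simplified version of that suffix check (PEFT also accepts regex strings and other options), handy for validating a target list against the names returned by `named_modules()` on the underlying Hugging Face model.

```python
def unmatched_targets(module_names, target_modules):
    """Return the target entries that match no module name (exact or '.'-suffix match)."""
    return [
        target for target in target_modules
        if not any(name == target or name.endswith(f".{target}") for name in module_names)
    ]


# Module names as a GPT-2 style model reports them (abbreviated, illustrative).
gpt2_like = ["transformer.h.0.attn.c_attn", "transformer.h.0.attn.c_proj", "transformer.h.0.mlp.c_fc"]
print(unmatched_targets(gpt2_like, ["c_attn"]))            # []
print(unmatched_targets(gpt2_like, ["q_proj", "v_proj"]))  # ['q_proj', 'v_proj']
```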

### Running FEDMKT With Launcher (Experimental Use): 4-SLMs

Starting jobs with the launcher is mainly intended for experiments. Before running this section, make sure that [FATE-LLM Standalone](https://github.com/FederatedAI/FATE-LLM?tab=readme-ov-file#standalone-deployment) has been deployed.

#### Global Settings

In [18]:
process_data_output_dir = ""
llm_pretrained_path = "Llama-2-7b-hf"
slm_0_pretrained_path = "opt-1.3b"
slm_1_pretrained_path = "gpt2-xl"
slm_2_pretrained_path = "Sheared-LLaMa-1.3B"
slm_3_pretrained_path = "bloom-1b1"
llm_slm_pairs = [
    (llm_pretrained_path, slm_0_pretrained_path),
    (llm_pretrained_path, slm_1_pretrained_path),
    (llm_pretrained_path, slm_2_pretrained_path),
    (llm_pretrained_path, slm_3_pretrained_path)
]

vocab_mapping_directory = ""

slm_to_llm_vocab_mapping_paths = ["opt_to_llama.json", "gpt2_to_llama.json", "llama_small_to_llama.json", "bloom_to_llama.json"]
llm_to_slm_vocab_mapping_paths = ["llama_to_opt.json", "llama_to_gpt2.json", "llama_to_llama_small", "llama_to_bloom.json"]

for idx in range(4):
    slm_to_llm_vocab_mapping_paths[idx] = vocab_mapping_directory + "/" + slm_to_llm_vocab_mapping_paths[idx]
    llm_to_slm_vocab_mapping_paths[idx] = vocab_mapping_directory + "/" + llm_to_slm_vocab_mapping_paths[idx]

# the following variables are defined the same way as above

slm_pretrained_paths = [slm_0_pretrained_path, slm_1_pretrained_path, slm_2_pretrained_path, slm_3_pretrained_path]
slm_lora_target_modules = [
    ["q_proj", "v_proj"],
    ["c_attn"],
    ['q_proj', 'k_proj', 'v_proj', 'o_proj'],
    ["query_key_value"]
]

global_epochs = 1
batch_size=4
llm_lr = 3e-5
slm_lrs = [3e-5, 3e-4, 3e-5, 3e-5, 3e-5]

#### Init FEDMKTLLM Runner

In this section, we introduce how to initialize the "FEDMKTLLM" object.

##### Step1: Initialize LLM With LoraConfig

In [None]:
from peft import LoraConfig, TaskType
from fate_llm.model_zoo.pellm.llama import LLaMa
from fate_llm.algo.fedmkt import FedMKTTrainingArguments, FedMKTLLM
from fate.ml.nn.homo.fedavg import FedAVGArguments
from fate_llm.dataset.qa_dataset import QaDataset
from fate_llm.data.tokenizers.cust_tokenizer import get_tokenizer
from transformers import AutoConfig

lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False, r=8, lora_alpha=16, lora_dropout=0.05,
    target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj']
)

model = LLaMa(
    pretrained_path=llm_pretrained_path,
    peft_type="LoraConfig",
    peft_config=lora_config.to_dict(),
    torch_dtype="bfloat16"    
)


##### Step2: Specify Public Dataset

In [None]:
pub_data = QaDataset(tokenizer_name_or_path=llm_pretrained_path,
                     dataset_name="arc_challenge",
                     data_part="common",
                     seq_max_len=512,
                     need_preprocess=True)
pub_data.load(process_data_output_dir)

##### Step3: Initialize FEDMKT Training Args

In [None]:
training_args = FedMKTTrainingArguments(
    global_epochs=global_epochs,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=batch_size,
    learning_rate=llm_lr,
    output_dir="./",
    dataloader_num_workers=4,
    remove_unused_columns=False,
    warmup_ratio=0.008,
    lr_scheduler_type="cosine",
    optim="adamw_torch",
    adam_beta1=0.9,
    adam_beta2=0.95,
    weight_decay=0.1,
    max_grad_norm=1.0,
    use_cpu=False,
    # vocab_size must be set explicitly (here from the model config) to avoid
    # a dimension mismatch with the tokenizer's vocabulary size
    vocab_size=AutoConfig.from_pretrained(llm_pretrained_path).vocab_size,
)
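
With `per_device_train_batch_size=1` and `gradient_accumulation_steps=batch_size`, each optimizer update sees `batch_size` examples. The arithmetic sketch below makes that explicit, assuming these arguments follow Hugging Face `TrainingArguments` semantics, where warmup steps derived from `warmup_ratio` are rounded up; the helper names and the 1,000-example dataset size are hypothetical, and the real trainer may round per epoch differently.

```python
import math


def effective_batch_size(per_device_train_batch_size, gradient_accumulation_steps, num_devices=1):
    """Number of examples that contribute to one optimizer update."""
    return per_device_train_batch_size * gradient_accumulation_steps * num_devices


def approx_optimizer_steps(num_examples, epochs, per_device_train_batch_size,
                           gradient_accumulation_steps, num_devices=1):
    """Rough count of optimizer updates over all epochs."""
    per_update = effective_batch_size(per_device_train_batch_size, gradient_accumulation_steps, num_devices)
    return epochs * math.ceil(num_examples / per_update)


def warmup_steps(total_steps, warmup_ratio):
    """Warmup length when it is derived from warmup_ratio (rounded up)."""
    return math.ceil(total_steps * warmup_ratio)


steps = approx_optimizer_steps(1000, epochs=1, per_device_train_batch_size=1, gradient_accumulation_steps=4)
print(effective_batch_size(1, 4), steps, warmup_steps(steps, 0.008))
```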

##### Step4: Initialize Other Variables

In [None]:
import json

fed_args = FedAVGArguments(
    aggregate_strategy='epoch',
    aggregate_freq=1
)

slm_to_llm_vocab_mapping = []
for path in slm_to_llm_vocab_mapping_paths:
    with open(path, "r") as fin:
        vocab_mapping = json.loads(fin.read())
        slm_to_llm_vocab_mapping.append(vocab_mapping)

slm_tokenizers = [get_tokenizer(slm_pretrained_path) for slm_pretrained_path in slm_pretrained_paths]
tokenizer = get_tokenizer(llm_pretrained_path)

##### Step5: New FEDMKTLLM Object

In [None]:
trainer = FedMKTLLM(
    ctx=ctx,  # FATE Context object; in the complete script below it is provided by the launcher
    model=model,
    training_args=training_args,
    fed_args=fed_args,
    train_set=pub_data,
    tokenizer=tokenizer,
    slm_tokenizers=slm_tokenizers,
    slm_to_llm_vocab_mappings=slm_to_llm_vocab_mapping,
    save_trainable_weights_only=True, # save lora weights only
)

##### Step6: Training And Save Results

In [None]:
trainer.train()
trainer.save_model(output_dir="fill the path to save llm finetuning result")

#### Init FEDMKTSLM Runner

The FEDMKTSLM runner differs only slightly from the FEDMKTLLM runner, so we only introduce the variables that differ.

##### Import the SLMs to Run (Here, the Four SLMs Used in the Original Paper)

In [None]:
import transformers
from peft import LoraConfig, TaskType    
from fate_llm.model_zoo.pellm.llama import LLaMa
from fate_llm.model_zoo.pellm.gpt2 import GPT2CLM
from fate_llm.model_zoo.pellm.opt import OPT
from fate_llm.model_zoo.pellm.bloom import Bloom
from fate_llm.algo.fedmkt import FedMKTTrainingArguments, FedMKTSLM
from fate_llm.dataset.qa_dataset import QaDataset
from fate_llm.data.tokenizers.cust_tokenizer import get_tokenizer
from transformers import AutoConfig

slm_idx = 0

slm_model_class = [
    OPT,
    GPT2CLM,
    LLaMa,
    Bloom
]
    
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False, r=8, lora_alpha=16, lora_dropout=0.05,
    target_modules=slm_lora_target_modules[slm_idx]
)

model = slm_model_class[slm_idx](
    pretrained_path=slm_pretrained_paths[slm_idx],
    peft_type="LoraConfig",
    peft_config=lora_config.to_dict(),
    torch_dtype="bfloat16"
)

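##### Specify Private and Public Datasets

In [None]:
priv_data = QaDataset(tokenizer_name_or_path=slm_pretrained_paths[slm_idx],
                      dataset_name="arc_challenge",
                      data_part=f"client_{slm_idx}",
                      seq_max_len=512,
                      need_preprocess=True)
priv_data.load(process_data_output_dir)

# the public dataset must be tokenized with the SLM's own tokenizer, not the LLM's
pub_data = QaDataset(tokenizer_name_or_path=slm_pretrained_paths[slm_idx],
                     dataset_name="arc_challenge",
                     data_part="common",
                     seq_max_len=512,
                     need_preprocess=True)
pub_data.load(process_data_output_dir)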

##### Other Variables 

In [None]:
tokenizer = get_tokenizer(slm_pretrained_paths[slm_idx])

import json
with open(llm_to_slm_vocab_mapping_paths[slm_idx], "r") as fin:
    vocab_mapping = json.loads(fin.read())

##### New FEDMKTSLM Object

In [None]:
trainer = FedMKTSLM(
    ctx=ctx,  # FATE Context object, provided by the launcher
    model=model,
    training_args=training_args,  # build with the SLM's learning rate and vocab_size, as in train_slm below
    fed_args=fed_args,
    pub_train_set=pub_data,
    priv_train_set=priv_data,
    tokenizer=tokenizer,
    save_trainable_weights_only=True, # save LoRA weights only
    llm_tokenizer=get_tokenizer(llm_pretrained_path), # differs from the LLM setting
    llm_to_slm_vocab_mapping=vocab_mapping, # differs from the LLM setting
    data_collator=transformers.DataCollatorForSeq2Seq(tokenizer) # used to train on the private dataset
)
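
`transformers.DataCollatorForSeq2Seq` pads the model inputs with the tokenizer and pads `labels` with `label_pad_token_id` (default -100), a value PyTorch's cross-entropy ignores, so padded positions do not contribute to the private-data SFT loss. The plain-Python sketch below only imitates that right-padding behavior for illustration: the real collator also returns tensors and respects the tokenizer's padding side and `pad_to_multiple_of`, and the pad id 0 here is an arbitrary assumption.

```python
def pad_batch(features, pad_token_id=0, label_pad_token_id=-100):
    """Right-pad input_ids, attention_mask and labels to the longest example in the batch."""
    max_len = max(len(f["input_ids"]) for f in features)
    batch = {"input_ids": [], "attention_mask": [], "labels": []}
    for f in features:
        pad = max_len - len(f["input_ids"])
        batch["input_ids"].append(f["input_ids"] + [pad_token_id] * pad)
        batch["attention_mask"].append([1] * len(f["input_ids"]) + [0] * pad)
        batch["labels"].append(f["labels"] + [label_pad_token_id] * pad)
    return batch


features = [
    {"input_ids": [5, 6, 7], "labels": [-100, 6, 7]},  # prompt token already masked out
    {"input_ids": [8, 9], "labels": [-100, 9]},
]
print(pad_batch(features))
```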

#### Complete Code To Do SFT With 4 SLMs

Save the following code as "fedmkt_4_slms.py" and run it with the following command:

```shell
python fedmkt_4_slms.py --parties guest:9999 host:9999 host:10000 host:10001 arbiter:9999 --log_level INFO
```

In [None]:
# fedmkt_4_slms.py

import os

from fate.arch import Context
from fate.arch.launchers.multiprocess_launcher import launch
import json

process_data_output_dir = ""
llm_pretrained_path = "Llama-2-7b-hf"
slm_0_pretrained_path = "opt-1.3b"
slm_1_pretrained_path = "gpt2-xl"
slm_2_pretrained_path = "Sheared-LLaMa-1.3B"
slm_3_pretrained_path = "bloom-1b1"
llm_slm_pairs = [
    (llm_pretrained_path, slm_0_pretrained_path),
    (llm_pretrained_path, slm_1_pretrained_path),
    (llm_pretrained_path, slm_2_pretrained_path),
    (llm_pretrained_path, slm_3_pretrained_path)
]

vocab_mapping_directory = ""

slm_to_llm_vocab_mapping_paths = ["opt_to_llama.json", "gpt2_to_llama.json", "llama_small_to_llama.json", "bloom_to_llama.json"]
llm_to_slm_vocab_mapping_paths = ["llama_to_opt.json", "llama_to_gpt2.json", "llama_to_llama_small", "llama_to_bloom.json"]

for idx in range(4):
    slm_to_llm_vocab_mapping_paths[idx] = vocab_mapping_directory + "/" + slm_to_llm_vocab_mapping_paths[idx]
    llm_to_slm_vocab_mapping_paths[idx] = vocab_mapping_directory + "/" + llm_to_slm_vocab_mapping_paths[idx]

slm_pretrained_paths = [slm_0_pretrained_path, slm_1_pretrained_path, slm_2_pretrained_path, slm_3_pretrained_path]
slm_lora_target_modules = [
    ["q_proj", "v_proj"],
    ["c_attn"],
    ['q_proj', 'k_proj', 'v_proj', 'o_proj'],
    ["query_key_value"]
]

global_epochs = 5
batch_size=4
llm_lr = 3e-5
slm_lrs = [3e-5, 3e-4, 3e-5, 3e-5, 3e-5]

llm_model_saved_directory = "./models/fedmkt_4_slms_llm_model"
slm_models_saved_directory = [
    "./models/fedmkt_4_slms_slm_0", 
    "./models/fedmkt_4_slms_slm_1", 
    "./models/fedmkt_4_slms_slm_2", 
    "./models/fedmkt_4_slms_slm_3"
]


def train_llm(ctx):
    from peft import LoraConfig, TaskType
    from fate_llm.model_zoo.pellm.llama import LLaMa
    from fate_llm.algo.fedmkt import FedMKTTrainingArguments, FedMKTLLM
    from fate_llm.dataset.qa_dataset import QaDataset
    from fate_llm.data.tokenizers.cust_tokenizer import get_tokenizer
    from transformers import AutoConfig

    lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        inference_mode=False, r=8, lora_alpha=16, lora_dropout=0.05,
        target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj']
    )

    model = LLaMa(
        pretrained_path=llm_pretrained_path,
        peft_type="LoraConfig",
        peft_config=lora_config.to_dict(),
        torch_dtype="bfloat16"
    )

    pub_data = QaDataset(tokenizer_name_or_path=llm_pretrained_path,
                         dataset_name="arc_challenge",
                         data_part="common",
                         seq_max_len=512,
                         need_preprocess=True)
    pub_data.load(process_data_output_dir)

    training_args = FedMKTTrainingArguments(
        global_epochs=global_epochs,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=batch_size,
        learning_rate=llm_lr,
        output_dir="./",
        dataloader_num_workers=4,
        remove_unused_columns=False,
        warmup_ratio=0.008,
        lr_scheduler_type="cosine",
        optim="adamw_torch",
        adam_beta1=0.9,
        adam_beta2=0.95,
        weight_decay=0.1,
        max_grad_norm=1.0,
        use_cpu=False,
        vocab_size=AutoConfig.from_pretrained(llm_pretrained_path).vocab_size,
    )

    slm_to_llm_vocab_mapping = []
    for path in slm_to_llm_vocab_mapping_paths:
        with open(path, "r") as fin:
            vocab_mapping = json.loads(fin.read())
            slm_to_llm_vocab_mapping.append(vocab_mapping)

    slm_tokenizers = [get_tokenizer(slm_pretrained_path) for slm_pretrained_path in slm_pretrained_paths]

    tokenizer = get_tokenizer(llm_pretrained_path)
    trainer = FedMKTLLM(
        ctx=ctx,
        model=model,
        training_args=training_args,
        train_set=pub_data,
        tokenizer=tokenizer,
        slm_tokenizers=slm_tokenizers,
        slm_to_llm_vocab_mappings=slm_to_llm_vocab_mapping,
        save_trainable_weights_only=True,
    )

    trainer.train()
    trainer.save_model(llm_model_saved_directory)


def train_slm(ctx, slm_idx):
    import transformers
    from peft import LoraConfig, TaskType
    from fate_llm.model_zoo.pellm.llama import LLaMa
    from fate_llm.model_zoo.pellm.gpt2 import GPT2CLM
    from fate_llm.model_zoo.pellm.opt import OPT
    from fate_llm.model_zoo.pellm.bloom import Bloom
    from fate_llm.algo.fedmkt import FedMKTTrainingArguments, FedMKTSLM
    from fate_llm.dataset.qa_dataset import QaDataset
    from fate_llm.data.tokenizers.cust_tokenizer import get_tokenizer
    from transformers import AutoConfig

    slm_model_class = [
        OPT,
        GPT2CLM,
        LLaMa,
        Bloom
    ]

    lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        inference_mode=False, r=8, lora_alpha=16, lora_dropout=0.05,
        target_modules=slm_lora_target_modules[slm_idx]
    )

    model = slm_model_class[slm_idx](
        pretrained_path=slm_pretrained_paths[slm_idx],
        peft_type="LoraConfig",
        peft_config=lora_config.to_dict(),
        torch_dtype="bfloat16"
    )

    priv_data = QaDataset(tokenizer_name_or_path=slm_pretrained_paths[slm_idx],
                          dataset_name="arc_challenge",
                          data_part=f"client_{slm_idx}",
                          seq_max_len=512,
                          need_preprocess=True)
    priv_data.load(process_data_output_dir)

    pub_data = QaDataset(tokenizer_name_or_path=slm_pretrained_paths[slm_idx],
                         dataset_name="arc_challenge",
                         data_part="common",
                         seq_max_len=512,
                         need_preprocess=True)
    pub_data.load(process_data_output_dir)

    training_args = FedMKTTrainingArguments(
        global_epochs=global_epochs,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=batch_size,
        learning_rate=slm_lrs[slm_idx],
        output_dir="./",
        dataloader_num_workers=4,
        remove_unused_columns=False,
        warmup_ratio=0.008,
        lr_scheduler_type="cosine",
        optim="adamw_torch",
        adam_beta1=0.9,
        adam_beta2=0.95,
        weight_decay=0.1,
        max_grad_norm=1.0,
        use_cpu=False,
        vocab_size=AutoConfig.from_pretrained(slm_pretrained_paths[slm_idx]).vocab_size,
    )

    tokenizer = get_tokenizer(slm_pretrained_paths[slm_idx])

    import json
    with open(llm_to_slm_vocab_mapping_paths[slm_idx], "r") as fin:
        vocab_mapping = json.loads(fin.read())

    trainer = FedMKTSLM(
        ctx=ctx,
        model=model,
        training_args=training_args,
        pub_train_set=pub_data,
        priv_train_set=priv_data,
        tokenizer=tokenizer,
        save_trainable_weights_only=True,
        llm_tokenizer=get_tokenizer(llm_pretrained_path),
        llm_to_slm_vocab_mapping=vocab_mapping,
        data_collator=transformers.DataCollatorForSeq2Seq(tokenizer)
    )

    trainer.train()
    trainer.save_model(slm_models_saved_directory[slm_idx])


def run(ctx: Context):
    if ctx.is_on_arbiter:
        os.environ["CUDA_VISIBLE_DEVICES"] = "0"
        train_llm(ctx)
    elif ctx.is_on_guest:
        os.environ["CUDA_VISIBLE_DEVICES"] = "1"
        train_slm(ctx, slm_idx=0)
    else:
        if ctx.local.party[1] == "9999":
            os.environ["CUDA_VISIBLE_DEVICES"] = "2"
            slm_idx = 1
        elif ctx.local.party[1] == "10000":
            os.environ["CUDA_VISIBLE_DEVICES"] = "3"
            slm_idx = 2
        elif ctx.local.party[1] == "10001":
            os.environ["CUDA_VISIBLE_DEVICES"] = "4"
            slm_idx = 3
        else:
            raise ValueError(f"party_id={ctx.local.party[1]} is illegal")

        train_slm(ctx, slm_idx=slm_idx)


if __name__ == "__main__":
    launch(run)
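
The `run` function above dispatches work by role: the arbiter hosts the LLM, the guest trains SLM-0, and each host's party id selects one of SLM-1 to SLM-3 along with a GPU. The pure-Python sketch below restates that dispatch table so it can be checked without FATE installed; `assign_party` is a hypothetical helper, and the GPU indices simply mirror the script.

```python
def assign_party(role, party_id):
    """Return (CUDA_VISIBLE_DEVICES value, workload) for a launcher party, mirroring run()."""
    if role == "arbiter":
        return "0", "llm"
    if role == "guest":
        return "1", "slm_0"
    host_table = {"9999": ("2", "slm_1"), "10000": ("3", "slm_2"), "10001": ("4", "slm_3")}
    if role == "host" and party_id in host_table:
        return host_table[party_id]
    raise ValueError(f"party {role}:{party_id} is not part of this setup")


for party in ["guest:9999", "host:9999", "host:10000", "host:10001", "arbiter:9999"]:
    role, party_id = party.split(":")
    print(party, "->", assign_party(role, party_id))
```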


### Running FEDMKT With Launcher (Experimental Use): 1-SLM (One To One)

Only slight modifications to the 4-SLMs code are needed to do SFT with a single client. They are listed in the sections below, taking SLM-0 (OPT) as an example.

#### Use Only a Single Optimal Vocabulary Mapping Table

In [None]:
slm_idx = 0
slm_to_llm_vocab_mapping = []
with open(slm_to_llm_vocab_mapping_paths[slm_idx], "r") as fin:
    vocab_mapping = json.loads(fin.read())
    slm_to_llm_vocab_mapping.append(vocab_mapping)

slm_tokenizers = [get_tokenizer(slm_pretrained_paths[slm_idx])]

#### Complete Code To Do SFT With 1 SLM

Save the following code as "fedmkt_1_slm.py" and run it with the following command:

```shell
python fedmkt_1_slm.py --parties guest:9999 arbiter:9999 --log_level INFO
```
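
In the commands used in this tutorial, the SLM clients are one guest plus hosts, and the LLM runs on the arbiter, so the `--parties` value grows with the number of SLMs. The hypothetical helper below reproduces the two command lines shown so far; the party ids are simply the ones used in these examples, not required values.

```python
def build_parties(num_slms, base_port=9999):
    """Build the --parties value: one guest, (num_slms - 1) hosts, one arbiter."""
    parties = [f"guest:{base_port}"]
    parties += [f"host:{base_port + i}" for i in range(num_slms - 1)]
    parties.append(f"arbiter:{base_port}")
    return " ".join(parties)


print(build_parties(4))  # guest:9999 host:9999 host:10000 host:10001 arbiter:9999
print(build_parties(1))  # guest:9999 arbiter:9999
```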

In [None]:
# fedmkt_1_slm.py

import os

from fate.arch import Context
from fate.arch.launchers.multiprocess_launcher import launch
import json

process_data_output_dir = ""
llm_pretrained_path = "Llama-2-7b-hf"
slm_0_pretrained_path = "opt-1.3b"
slm_1_pretrained_path = "gpt2-xl"
slm_2_pretrained_path = "Sheared-LLaMa-1.3B"
slm_3_pretrained_path = "bloom-1b1"
llm_slm_pairs = [
    (llm_pretrained_path, slm_0_pretrained_path),
    (llm_pretrained_path, slm_1_pretrained_path),
    (llm_pretrained_path, slm_2_pretrained_path),
    (llm_pretrained_path, slm_3_pretrained_path)
]

vocab_mapping_directory = ""

slm_to_llm_vocab_mapping_paths = ["opt_to_llama.json", "gpt2_to_llama.json", "llama_small_to_llama.json", "bloom_to_llama.json"]
llm_to_slm_vocab_mapping_paths = ["llama_to_opt.json", "llama_to_gpt2.json", "llama_to_llama_small", "llama_to_bloom.json"]

for idx in range(4):
    slm_to_llm_vocab_mapping_paths[idx] = vocab_mapping_directory + "/" + slm_to_llm_vocab_mapping_paths[idx]
    llm_to_slm_vocab_mapping_paths[idx] = vocab_mapping_directory + "/" + llm_to_slm_vocab_mapping_paths[idx]

slm_pretrained_paths = [slm_0_pretrained_path, slm_1_pretrained_path, slm_2_pretrained_path, slm_3_pretrained_path]
slm_lora_target_modules = [
    ["q_proj", "v_proj"],
    ["c_attn"],
    ['q_proj', 'k_proj', 'v_proj', 'o_proj'],
    ["query_key_value"]
]

global_epochs = 5
batch_size = 4
llm_lr = 3e-5
slm_lrs = [3e-5]

llm_model_saved_directory = "./models/fedmkt_single_slm_llm"
slm_models_saved_directory = [
    "./models/fedmkt_single_slm_opt",
]


def train_llm(ctx, slm_idx):
    from peft import LoraConfig, TaskType
    from fate_llm.model_zoo.pellm.llama import LLaMa
    from fate_llm.algo.fedmkt import FedMKTTrainingArguments, FedMKTLLM
    from fate_llm.dataset.qa_dataset import QaDataset
    from fate_llm.data.tokenizers.cust_tokenizer import get_tokenizer
    from transformers import AutoConfig

    lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        inference_mode=False, r=8, lora_alpha=16, lora_dropout=0.05,
        target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj']
    )

    model = LLaMa(
        pretrained_path=llm_pretrained_path,
        peft_type="LoraConfig",
        peft_config=lora_config.to_dict(),
        torch_dtype="bfloat16"
    )

    pub_data = QaDataset(tokenizer_name_or_path=llm_pretrained_path,
                         dataset_name="arc_challenge",
                         data_part="common",
                         seq_max_len=512,
                         need_preprocess=True)
    pub_data.load(process_data_output_dir)

    training_args = FedMKTTrainingArguments(
        global_epochs=global_epochs,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=batch_size,
        learning_rate=llm_lr,
        output_dir="./",
        dataloader_num_workers=4,
        remove_unused_columns=False,
        warmup_ratio=0.008,
        lr_scheduler_type="cosine",
        optim="adamw_torch",
        adam_beta1=0.9,
        adam_beta2=0.95,
        weight_decay=0.1,
        max_grad_norm=1.0,
        use_cpu=False,
        vocab_size=AutoConfig.from_pretrained(llm_pretrained_path).vocab_size,
    )

    slm_to_llm_vocab_mapping = []
    with open(slm_to_llm_vocab_mapping_paths[slm_idx], "r") as fin:
        vocab_mapping = json.loads(fin.read())
        slm_to_llm_vocab_mapping.append(vocab_mapping)

    slm_tokenizers = [get_tokenizer(slm_pretrained_paths[slm_idx])]

    tokenizer = get_tokenizer(llm_pretrained_path)
    trainer = FedMKTLLM(
        ctx=ctx,
        model=model,
        training_args=training_args,
        train_set=pub_data,
        tokenizer=tokenizer,
        slm_tokenizers=slm_tokenizers,
        slm_to_llm_vocab_mappings=slm_to_llm_vocab_mapping,
        save_trainable_weights_only=True,
    )

    trainer.train()
    trainer.save_model(llm_model_saved_directory)


def train_slm(ctx, slm_idx):
    import transformers
    from peft import LoraConfig, TaskType
    from fate_llm.model_zoo.pellm.llama import LLaMa
    from fate_llm.model_zoo.pellm.gpt2 import GPT2CLM
    from fate_llm.model_zoo.pellm.opt import OPT
    from fate_llm.model_zoo.pellm.bloom import Bloom
    from fate_llm.algo.fedmkt import FedMKTTrainingArguments, FedMKTSLM
    from fate_llm.dataset.qa_dataset import QaDataset
    from fate_llm.data.tokenizers.cust_tokenizer import get_tokenizer
    from transformers import AutoConfig

    slm_model_class = [
        OPT,
        GPT2CLM,
        LLaMa,
        Bloom
    ]

    lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        inference_mode=False, r=8, lora_alpha=16, lora_dropout=0.05,
        target_modules=slm_lora_target_modules[slm_idx]
    )

    model = slm_model_class[slm_idx](
        pretrained_path=slm_pretrained_paths[slm_idx],
        peft_type="LoraConfig",
        peft_config=lora_config.to_dict(),
        torch_dtype="bfloat16"
    )

    priv_data = QaDataset(tokenizer_name_or_path=slm_pretrained_paths[slm_idx],
                          dataset_name="arc_challenge",
                          data_part=f"client_{slm_idx}",
                          seq_max_len=512,
                          need_preprocess=True)
    priv_data.load(process_data_output_dir)

    pub_data = QaDataset(tokenizer_name_or_path=slm_pretrained_paths[slm_idx],
                         dataset_name="arc_challenge",
                         data_part="common",
                         seq_max_len=512,
                         need_preprocess=True)
    pub_data.load(process_data_output_dir)

    training_args = FedMKTTrainingArguments(
        global_epochs=global_epochs,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=batch_size,
        learning_rate=slm_lrs[slm_idx],
        output_dir="./",
        dataloader_num_workers=4,
        remove_unused_columns=False,
        warmup_ratio=0.008,
        lr_scheduler_type="cosine",
        optim="adamw_torch",
        adam_beta1=0.9,
        adam_beta2=0.95,
        weight_decay=0.1,
        max_grad_norm=1.0,
        use_cpu=False,
        vocab_size=AutoConfig.from_pretrained(slm_pretrained_paths[slm_idx]).vocab_size,
    )

    tokenizer = get_tokenizer(slm_pretrained_paths[slm_idx])

    import json
    with open(llm_to_slm_vocab_mapping_paths[slm_idx], "r") as fin:
        vocab_mapping = json.loads(fin.read())

    trainer = FedMKTSLM(
        ctx=ctx,
        model=model,
        training_args=training_args,
        pub_train_set=pub_data,
        priv_train_set=priv_data,
        tokenizer=tokenizer,
        save_trainable_weights_only=True,
        llm_tokenizer=get_tokenizer(llm_pretrained_path),
        llm_to_slm_vocab_mapping=vocab_mapping,
        data_collator=transformers.DataCollatorForSeq2Seq(tokenizer)
    )

    trainer.train()
    trainer.save_model(slm_models_saved_directory[slm_idx])


def run(ctx: Context):
    if ctx.is_on_arbiter:
        os.environ["CUDA_VISIBLE_DEVICES"] = "0"
        train_llm(ctx, slm_idx=0)
    else:
        os.environ["CUDA_VISIBLE_DEVICES"] = "1"
        train_slm(ctx, slm_idx=0)


if __name__ == "__main__":
    launch(run)


### Running FEDMKT With Launcher (Experimental Use): 1-SLM And SLM Trains Only (LLM2SLM)

In this section, we introduce how to do SFT with the FEDMKT algorithm when only a single SLM is trained and the LLM is not: the SLM distills knowledge from the LLM in one direction only, without co-training.

#### Difference From Section "Running FEDMKT With Launcher (Experimental Use): 1-SLM (One To One)"

Adding `llm_training=False` to the `FedMKTTrainingArguments` of both the LLM and the SLM is enough.

#### Complete Code To Do SFT With 1 SLM And SLM Trains Only

Save the following code as "fedmkt_llm_to_slm.py" and run it with the following command:

```shell
python fedmkt_llm_to_slm.py --parties guest:9999 arbiter:9999 --log_level INFO
```

In [None]:
# fedmkt_llm_to_slm.py

import os

from fate.arch import Context
from fate.arch.launchers.multiprocess_launcher import launch
import json

process_data_output_dir = ""
llm_pretrained_path = "Llama-2-7b-hf"
slm_0_pretrained_path = "opt-1.3b"
slm_1_pretrained_path = "gpt2-xl"
slm_2_pretrained_path = "Sheared-LLaMa-1.3B"
slm_3_pretrained_path = "bloom-1b1"
llm_slm_pairs = [
    (llm_pretrained_path, slm_0_pretrained_path),
    (llm_pretrained_path, slm_1_pretrained_path),
    (llm_pretrained_path, slm_2_pretrained_path),
    (llm_pretrained_path, slm_3_pretrained_path)
]

vocab_mapping_directory = ""

slm_to_llm_vocab_mapping_paths = ["opt_to_llama.json", "gpt2_to_llama.json", "llama_small_to_llama.json", "bloom_to_llama.json"]
llm_to_slm_vocab_mapping_paths = ["llama_to_opt.json", "llama_to_gpt2.json", "llama_to_llama_small", "llama_to_bloom.json"]

for idx in range(4):
    slm_to_llm_vocab_mapping_paths[idx] = vocab_mapping_directory + "/" + slm_to_llm_vocab_mapping_paths[idx]
    llm_to_slm_vocab_mapping_paths[idx] = vocab_mapping_directory + "/" + llm_to_slm_vocab_mapping_paths[idx]

slm_pretrained_paths = [slm_0_pretrained_path, slm_1_pretrained_path, slm_2_pretrained_path, slm_3_pretrained_path]
slm_lora_target_modules = [
    ["q_proj", "v_proj"],
    ["c_attn"],
    ['q_proj', 'k_proj', 'v_proj', 'o_proj'],
    ["query_key_value"]
]

global_epochs = 5
batch_size = 4
llm_lr = 3e-5
slm_lrs = [3e-5]

slm_models_saved_directory = [
    "./models/fedmkt_llm_to_slm_opt",
]


def train_llm(ctx, slm_idx):
    from peft import LoraConfig, TaskType
    from fate_llm.model_zoo.pellm.llama import LLaMa
    from fate_llm.algo.fedmkt import FedMKTTrainingArguments, FedMKTLLM
    from fate_llm.dataset.qa_dataset import QaDataset
    from fate_llm.data.tokenizers.cust_tokenizer import get_tokenizer
    from transformers import AutoConfig

    lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        inference_mode=False, r=8, lora_alpha=16, lora_dropout=0.05,
        target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj']
    )

    model = LLaMa(
        pretrained_path=llm_pretrained_path,
        peft_type="LoraConfig",
        peft_config=lora_config.to_dict(),
        torch_dtype="bfloat16"
    )

    pub_data = QaDataset(tokenizer_name_or_path=llm_pretrained_path,
                         dataset_name="arc_challenge",
                         data_part="common",
                         seq_max_len=512,
                         need_preprocess=True)
    pub_data.load(process_data_output_dir)

    training_args = FedMKTTrainingArguments(
        global_epochs=global_epochs,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=batch_size,
        learning_rate=llm_lr,
        output_dir="./",
        dataloader_num_workers=4,
        remove_unused_columns=False,
        warmup_ratio=0.008,
        lr_scheduler_type="cosine",
        optim="adamw_torch",
        adam_beta1=0.9,
        adam_beta2=0.95,
        weight_decay=0.1,
        max_grad_norm=1.0,
        use_cpu=False,
        vocab_size=AutoConfig.from_pretrained(llm_pretrained_path).vocab_size,
        llm_training=False
    )

    slm_to_llm_vocab_mapping = []
    with open(slm_to_llm_vocab_mapping_paths[slm_idx], "r") as fin:
        vocab_mapping = json.loads(fin.read())
        slm_to_llm_vocab_mapping.append(vocab_mapping)

    slm_tokenizers = [get_tokenizer(slm_pretrained_paths[slm_idx])]

    tokenizer = get_tokenizer(llm_pretrained_path)
    trainer = FedMKTLLM(
        ctx=ctx,
        model=model,
        training_args=training_args,
        train_set=pub_data,
        tokenizer=tokenizer,
        slm_tokenizers=slm_tokenizers,
        slm_to_llm_vocab_mappings=slm_to_llm_vocab_mapping,
        save_trainable_weights_only=True,
    )

    trainer.train()


def train_slm(ctx, slm_idx):
    import transformers
    from peft import LoraConfig, TaskType
    from fate_llm.model_zoo.pellm.llama import LLaMa
    from fate_llm.model_zoo.pellm.gpt2 import GPT2CLM
    from fate_llm.model_zoo.pellm.opt import OPT
    from fate_llm.model_zoo.pellm.bloom import Bloom
    from fate_llm.algo.fedmkt import FedMKTTrainingArguments, FedMKTSLM
    from fate_llm.dataset.qa_dataset import QaDataset
    from fate_llm.data.tokenizers.cust_tokenizer import get_tokenizer
    from transformers import AutoConfig

    slm_model_class = [
        OPT,
        GPT2CLM,
        LLaMa,
        Bloom
    ]

    lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        inference_mode=False, r=8, lora_alpha=16, lora_dropout=0.1,
        target_modules=slm_lora_target_modules[slm_idx]
    )

    model = slm_model_class[slm_idx](
        pretrained_path=slm_pretrained_paths[slm_idx],
        peft_type="LoraConfig",
        peft_config=lora_config.to_dict(),
        torch_dtype="bfloat16"
    )

    priv_data = QaDataset(tokenizer_name_or_path=slm_pretrained_paths[slm_idx],
                          dataset_name="arc_challenge",
                          data_part=f"client_{slm_idx}",
                          seq_max_len=512,
                          need_preprocess=True)
    priv_data.load(process_data_output_dir)

    pub_data = QaDataset(tokenizer_name_or_path=slm_pretrained_paths[slm_idx],
                         dataset_name="arc_challenge",
                         data_part="common",
                         seq_max_len=512,
                         need_preprocess=True)
    pub_data.load(process_data_output_dir)

    training_args = FedMKTTrainingArguments(
        global_epochs=global_epochs,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=batch_size,
        learning_rate=slm_lrs[slm_idx],
        output_dir="./",
        dataloader_num_workers=4,
        remove_unused_columns=False,
        warmup_ratio=0.008,
        lr_scheduler_type="cosine",
        optim="adamw_torch",
        adam_beta1=0.9,
        adam_beta2=0.95,
        weight_decay=0.1,
        max_grad_norm=1.0,
        use_cpu=False,
        vocab_size=AutoConfig.from_pretrained(slm_pretrained_paths[slm_idx]).vocab_size,
        llm_training=False
    )

    tokenizer = get_tokenizer(slm_pretrained_paths[slm_idx])

    import json
    with open(llm_to_slm_vocab_mapping_paths[slm_idx], "r") as fin:
        vocab_mapping = json.loads(fin.read())

    trainer = FedMKTSLM(
        ctx=ctx,
        model=model,
        training_args=training_args,
        pub_train_set=pub_data,
        priv_train_set=priv_data,
        tokenizer=tokenizer,
        save_trainable_weights_only=True,
        llm_tokenizer=get_tokenizer(llm_pretrained_path),
        llm_to_slm_vocab_mapping=vocab_mapping,
        data_collator=transformers.DataCollatorForSeq2Seq(tokenizer)
    )

    trainer.train()
    trainer.save_model(slm_models_saved_directory[slm_idx])


def run(ctx: Context):
    if ctx.is_on_arbiter:
        os.environ["CUDA_VISIBLE_DEVICES"] = "0"
        train_llm(ctx, slm_idx=0)
    else:
        os.environ["CUDA_VISIBLE_DEVICES"] = "1"
        train_slm(ctx, slm_idx=0)


if __name__ == "__main__":
    launch(run)
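
The scripts in this tutorial build vocabulary-mapping paths with `vocab_mapping_directory + "/" + name`. If `vocab_mapping_directory` is left as the empty string, this produces paths such as `/opt_to_llama.json` at the filesystem root, and training only fails once a trainer opens the file. The sketch below is a hypothetical pre-flight helper (`resolve_mapping_paths` and `load_mappings` are not FATE-LLM APIs) that joins paths with `os.path.join`, reports missing files before launching, and loads the tables in order. It assumes the mapping tables are plain JSON files, as the `json.loads` calls in the scripts suggest.

```python
import json
import os
import tempfile


def resolve_mapping_paths(directory, file_names):
    """Hypothetical helper: join mapping file names onto a directory and
    return (all_paths, missing_paths)."""
    if not directory:
        raise ValueError("vocab_mapping_directory must point to a real directory")
    paths = [os.path.join(directory, name) for name in file_names]
    missing = [p for p in paths if not os.path.isfile(p)]
    return paths, missing


def load_mappings(paths):
    """Load each JSON mapping table in order, so index i lines up with slm_idx i."""
    mappings = []
    for path in paths:
        with open(path, "r") as fin:
            mappings.append(json.load(fin))
    return mappings


if __name__ == "__main__":
    # Demo: a temporary directory that contains only one of the two expected tables.
    with tempfile.TemporaryDirectory() as tmp:
        with open(os.path.join(tmp, "opt_to_llama.json"), "w") as fout:
            json.dump({"0": 0}, fout)
        paths, missing = resolve_mapping_paths(tmp, ["opt_to_llama.json", "gpt2_to_llama.json"])
        print("missing:", [os.path.basename(p) for p in missing])
```

Running the check on every party before `launch(run)` surfaces a misconfigured directory immediately instead of partway through federated setup.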


### Running FEDMKT With Launcher (Experimental Use): 4-SLMs Homogeneous SFT

To run homogeneous experiments, where all four clients use the same SLM (OPT-1.3B in the script below), two changes are needed:
1. Add `post_fedavg=True` to the `FedMKTTrainingArguments` of both the LLM and the SLMs.
2. Pass `fed_args` to `FedMKTLLM` and `FedMKTSLM`.

In [None]:
# initialize fed args
from fate.ml.nn.homo.fedavg import FedAVGArguments

fed_args = FedAVGArguments(
    aggregate_strategy='epoch',
    aggregate_freq=1
)
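
Because the homogeneous SLMs share one architecture, their trainable weights can be averaged. The sketch below illustrates standard FedAvg, a sample-count-weighted average of parameters, on plain Python lists, together with one reading of `aggregate_strategy='epoch', aggregate_freq=1` as "aggregate after every epoch". It is not FATE-LLM's implementation: the function names, the dict-of-lists weight format, and the exact aggregation schedule are assumptions for illustration only.

```python
from typing import Dict, List

Weights = Dict[str, List[float]]  # parameter name -> flattened values


def fedavg(client_weights: List[Weights], num_samples: List[int]) -> Weights:
    """FedAvg: average each parameter across clients, weighted by local sample count."""
    if not client_weights or len(client_weights) != len(num_samples):
        raise ValueError("need one sample count per client")
    total = sum(num_samples)
    averaged: Weights = {}
    for name in client_weights[0]:
        length = len(client_weights[0][name])
        averaged[name] = [
            sum(w[name][i] * n for w, n in zip(client_weights, num_samples)) / total
            for i in range(length)
        ]
    return averaged


def aggregation_epochs(global_epochs: int, aggregate_freq: int) -> List[int]:
    """Assumed schedule for aggregate_strategy='epoch': aggregate every
    `aggregate_freq` epochs, counting epochs from 1."""
    return [e for e in range(1, global_epochs + 1) if e % aggregate_freq == 0]


if __name__ == "__main__":
    clients = [{"lora_A": [1.0, 2.0]}, {"lora_A": [3.0, 4.0]}]
    print(fedavg(clients, num_samples=[1, 3]))  # {'lora_A': [2.5, 3.5]}
    print(aggregation_epochs(global_epochs=5, aggregate_freq=1))  # [1, 2, 3, 4, 5]
```

Weighting by sample count keeps a client with a larger private split from being under-represented; with equal splits it reduces to a plain mean.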

#### Complete Code for 4-SLMs Homogeneous SFT

Paste the code below into "fedmkt_4_slms_homo.py" and run it with the following command:

```bash
python fedmkt_4_slms_homo.py --parties guest:9999 host:9999 host:10000 host:10001 arbiter:9999 --log_level INFO
```
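
In this command, `--parties` starts one process per party: the arbiter trains the LLM, the guest trains SLM 0, and the three hosts train SLMs 1 to 3. The `run()` function in the script below selects the SLM index and GPU from the role and the party id. The helper here restates that dispatch as a pure function so the mapping can be checked without starting FATE; `assign_party` is hypothetical, and the GPU ids simply mirror the script.

```python
from typing import Optional, Tuple

# Mirrors run(): host party id -> (slm_idx, CUDA_VISIBLE_DEVICES).
HOST_ASSIGNMENT = {"9999": (1, "2"), "10000": (2, "3"), "10001": (3, "4")}


def assign_party(role: str, party_id: str) -> Tuple[Optional[int], str]:
    """Return (slm_idx, GPU id) for a party; slm_idx is None for the LLM on the arbiter."""
    if role == "arbiter":
        return None, "0"  # the arbiter trains the LLM
    if role == "guest":
        return 0, "1"
    if role == "host":
        if party_id not in HOST_ASSIGNMENT:
            raise ValueError(f"party_id={party_id} is illegal")
        return HOST_ASSIGNMENT[party_id]
    raise ValueError(f"unknown role: {role}")


if __name__ == "__main__":
    parties = "guest:9999 host:9999 host:10000 host:10001 arbiter:9999".split()
    for spec in parties:
        role, party_id = spec.split(":")
        print(spec, "->", assign_party(role, party_id))
```

The role is checked before the party id, which is why the same id `9999` can appear for the guest, a host, and the arbiter without ambiguity.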

In [None]:
# fedmkt_4_slms_homo.py

import os

from fate.arch import Context
from fate.arch.launchers.multiprocess_launcher import launch
import json

process_data_output_dir = ""
llm_pretrained_path = "Llama-2-7b-hf"
slm_0_pretrained_path = "opt-1.3b"
slm_1_pretrained_path = "opt-1.3b"
slm_2_pretrained_path = "opt-1.3b"
slm_3_pretrained_path = "opt-1.3b"
llm_slm_pairs = [
    (llm_pretrained_path, slm_0_pretrained_path),
    (llm_pretrained_path, slm_1_pretrained_path),
    (llm_pretrained_path, slm_2_pretrained_path),
    (llm_pretrained_path, slm_3_pretrained_path)
]

vocab_mapping_directory = ""

slm_to_llm_vocab_mapping_paths = ["opt_to_llama.json"] * 4
llm_to_slm_vocab_mapping_paths = ["llama_to_opt.json"] * 4

for idx in range(4):
    slm_to_llm_vocab_mapping_paths[idx] = vocab_mapping_directory + "/" + slm_to_llm_vocab_mapping_paths[idx]
    llm_to_slm_vocab_mapping_paths[idx] = vocab_mapping_directory + "/" + llm_to_slm_vocab_mapping_paths[idx]

slm_pretrained_paths = [slm_0_pretrained_path] * 4
slm_lora_target_modules = [["q_proj", "v_proj"]] * 4

global_epochs = 5
batch_size = 4
llm_lr = 3e-5
slm_lrs = [3e-5] * 4

llm_model_saved_directory = "./models/fedmkt_homo_4_slms_llm_model"
slm_models_saved_directory = [
    "./models/fedmkt_homo_4_slms_slm_0",
]


def train_llm(ctx):
    from peft import LoraConfig, TaskType
    from fate_llm.model_zoo.pellm.llama import LLaMa
    from fate_llm.algo.fedmkt import FedMKTTrainingArguments, FedMKTLLM
    from fate.ml.nn.homo.fedavg import FedAVGArguments
    from fate_llm.dataset.qa_dataset import QaDataset
    from fate_llm.data.tokenizers.cust_tokenizer import get_tokenizer
    from transformers import AutoConfig

    lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1,
        target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj']
    )

    model = LLaMa(
        pretrained_path=llm_pretrained_path,
        peft_type="LoraConfig",
        peft_config=lora_config.to_dict(),
        torch_dtype="bfloat16"
    )

    pub_data = QaDataset(tokenizer_name_or_path=llm_pretrained_path,
                         dataset_name="arc_challenge",
                         data_part="common",
                         seq_max_len=512,
                         need_preprocess=True)
    pub_data.load(process_data_output_dir)

    training_args = FedMKTTrainingArguments(
        global_epochs=global_epochs,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=batch_size,
        learning_rate=llm_lr,
        output_dir="./",
        dataloader_num_workers=4,
        remove_unused_columns=False,
        warmup_ratio=0.008,
        lr_scheduler_type="cosine",
        optim="adamw_torch",
        adam_beta1=0.9,
        adam_beta2=0.95,
        weight_decay=0.1,
        max_grad_norm=1.0,
        use_cpu=False,
        vocab_size=AutoConfig.from_pretrained(llm_pretrained_path).vocab_size,
        post_fedavg=True, # difference
    )

    # difference
    fed_args = FedAVGArguments(
        aggregate_strategy='epoch',
        aggregate_freq=1
    )

    slm_to_llm_vocab_mapping = []
    for path in slm_to_llm_vocab_mapping_paths:
        with open(path, "r") as fin:
            vocab_mapping = json.loads(fin.read())
            slm_to_llm_vocab_mapping.append(vocab_mapping)

    slm_tokenizers = [get_tokenizer(slm_pretrained_path) for slm_pretrained_path in slm_pretrained_paths]

    tokenizer = get_tokenizer(llm_pretrained_path)
    trainer = FedMKTLLM(
        ctx=ctx,
        model=model,
        training_args=training_args,
        fed_args=fed_args, # difference
        train_set=pub_data,
        tokenizer=tokenizer,
        slm_tokenizers=slm_tokenizers,
        slm_to_llm_vocab_mappings=slm_to_llm_vocab_mapping,
        save_trainable_weights_only=True,
    )

    trainer.train()
    trainer.save_model(llm_model_saved_directory)


def train_slm(ctx, slm_idx):
    import transformers
    from peft import LoraConfig, TaskType
    from fate_llm.model_zoo.pellm.opt import OPT
    from fate_llm.algo.fedmkt import FedMKTTrainingArguments, FedMKTSLM
    from fate.ml.nn.homo.fedavg import FedAVGArguments
    from fate_llm.dataset.qa_dataset import QaDataset
    from fate_llm.data.tokenizers.cust_tokenizer import get_tokenizer
    from transformers import AutoConfig

    slm_model_class = [OPT] * 4

    lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1,
        target_modules=slm_lora_target_modules[slm_idx]
    )

    model = slm_model_class[slm_idx](
        pretrained_path=slm_pretrained_paths[slm_idx],
        peft_type="LoraConfig",
        peft_config=lora_config.to_dict(),
        torch_dtype="bfloat16"
    )

    priv_data = QaDataset(tokenizer_name_or_path=slm_pretrained_paths[slm_idx],
                          dataset_name="arc_challenge",
                          data_part=f"client_{slm_idx}",
                          seq_max_len=512,
                          need_preprocess=True)
    priv_data.load(process_data_output_dir)

    pub_data = QaDataset(tokenizer_name_or_path=slm_pretrained_paths[slm_idx],
                         dataset_name="arc_challenge",
                         data_part="common",
                         seq_max_len=512,
                         need_preprocess=True)
    pub_data.load(process_data_output_dir)

    training_args = FedMKTTrainingArguments(
        global_epochs=global_epochs,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=batch_size,
        learning_rate=slm_lrs[slm_idx],
        output_dir="./",
        dataloader_num_workers=4,
        remove_unused_columns=False,
        warmup_ratio=0.008,
        lr_scheduler_type="cosine",
        optim="adamw_torch",
        adam_beta1=0.9,
        adam_beta2=0.95,
        weight_decay=0.1,
        max_grad_norm=1.0,
        use_cpu=False,
        vocab_size=AutoConfig.from_pretrained(slm_pretrained_paths[slm_idx]).vocab_size,
        post_fedavg=True, # difference
    )

    # difference
    fed_args = FedAVGArguments(
        aggregate_strategy='epoch',
        aggregate_freq=1
    )

    tokenizer = get_tokenizer(slm_pretrained_paths[slm_idx])

    import json
    with open(llm_to_slm_vocab_mapping_paths[slm_idx], "r") as fin:
        vocab_mapping = json.loads(fin.read())

    trainer = FedMKTSLM(
        ctx=ctx,
        model=model,
        training_args=training_args, 
        fed_args=fed_args, # difference
        pub_train_set=pub_data,
        priv_train_set=priv_data,
        tokenizer=tokenizer,
        save_trainable_weights_only=True,
        llm_tokenizer=get_tokenizer(llm_pretrained_path),
        llm_to_slm_vocab_mapping=vocab_mapping,
        data_collator=transformers.DataCollatorForSeq2Seq(tokenizer)
    )

    trainer.train()
    if slm_idx == 0:
        trainer.save_model(slm_models_saved_directory[slm_idx])


def run(ctx: Context):
    if ctx.is_on_arbiter:
        os.environ["CUDA_VISIBLE_DEVICES"] = "0"
        train_llm(ctx)
    elif ctx.is_on_guest:
        os.environ["CUDA_VISIBLE_DEVICES"] = "1"
        train_slm(ctx, slm_idx=0)
    else:
        if ctx.local.party[1] == "9999":
            os.environ["CUDA_VISIBLE_DEVICES"] = "2"
            slm_idx = 1
        elif ctx.local.party[1] == "10000":
            os.environ["CUDA_VISIBLE_DEVICES"] = "3"
            slm_idx = 2
        elif ctx.local.party[1] == "10001":
            os.environ["CUDA_VISIBLE_DEVICES"] = "4"
            slm_idx = 3
        else:
            raise ValueError(f"party_id={ctx.local.party[1]} is illegal")

        train_slm(ctx, slm_idx=slm_idx)


if __name__ == "__main__":
    launch(run)
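
Both launcher scripts set `per_device_train_batch_size=1` and `gradient_accumulation_steps=batch_size` (4), so gradients from four single-sample micro-batches are accumulated before each optimizer step. Under the usual Hugging Face `Trainer` semantics, which `FedMKTTrainingArguments` presumably follows, the effective batch size per party is the product computed below. The helper is illustrative and not part of FATE-LLM.

```python
def effective_batch_size(per_device_train_batch_size: int,
                         gradient_accumulation_steps: int,
                         num_devices: int = 1) -> int:
    """Samples contributing to one optimizer step with Trainer-style gradient accumulation."""
    for value in (per_device_train_batch_size, gradient_accumulation_steps, num_devices):
        if value < 1:
            raise ValueError("all arguments must be positive integers")
    return per_device_train_batch_size * gradient_accumulation_steps * num_devices


if __name__ == "__main__":
    # Settings from the scripts above: one GPU per party via CUDA_VISIBLE_DEVICES.
    print(effective_batch_size(1, 4))  # 4
```

Trading batch size for accumulation steps keeps the per-step memory footprint of the 7B LLM low while preserving the same effective batch.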


### Running FEDMKT with Pipeline (Industrial Use)

Make sure that a [FATE-LLM Cluster](https://github.com/FederatedAI/FATE/wiki/Download#llm%E9%83%A8%E7%BD%B2%E5%8C%85) has been deployed in cluster mode across multiple machines. Then paste the following code into "test_fedmkt_4_slms.py" and execute "python test_fedmkt_4_slms.py".

In [13]:
from fate_client.pipeline.components.fate.homo_nn import HomoNN, get_config_of_fedmkt_runner
from fate_client.pipeline.components.fate.nn.algo_params import FedMKTTrainingArguments, FedAVGArguments
from fate_client.pipeline.components.fate.nn.loader import LLMModelLoader, LLMDatasetLoader, LLMDataFuncLoader
from peft import LoraConfig, TaskType
from fate_client.pipeline import FateFlowPipeline
from fate_client.pipeline.components.fate.reader import Reader
from transformers import AutoConfig

guest = '9999'  # replace with the actual guest party id in your environment
host = ['9999', '10000', '10001']  # replace with the actual host party ids in your environment
arbiter = '9999'  # replace with the actual arbiter party id in your environment


process_data_output_dir = ""  # replace with the actual processed-data directory
# replace the model names below with their local directories
llm_pretrained_path = "llama-2-7b-hf"
slm_0_pretrained_path = "opt-1.3b"
slm_1_pretrained_path = "gpt2-xl"
slm_2_pretrained_path = "Sheared-LLaMA-1.3B"
slm_3_pretrained_path = "bloom-1b1"
llm_slm_pairs = [
    (llm_pretrained_path, slm_0_pretrained_path),
    (llm_pretrained_path, slm_1_pretrained_path),
    (llm_pretrained_path, slm_2_pretrained_path),
    (llm_pretrained_path, slm_3_pretrained_path)
]

vocab_mapping_directory = ""  # replace with the actual vocab mapping directory

slm_to_llm_vocab_mapping_paths = ["opt_to_llama.json", "gpt2_to_llama.json", "llama_small_to_llama.json", "bloom_to_llama.json"]
llm_to_slm_vocab_mapping_paths = ["llama_to_opt.json", "llama_to_gpt2.json", "llama_to_llama_small.json", "llama_to_bloom.json"]

for idx in range(4):
    slm_to_llm_vocab_mapping_paths[idx] = vocab_mapping_directory + "/" + slm_to_llm_vocab_mapping_paths[idx]
    llm_to_slm_vocab_mapping_paths[idx] = vocab_mapping_directory + "/" + llm_to_slm_vocab_mapping_paths[idx]

slm_pretrained_paths = [slm_0_pretrained_path, slm_1_pretrained_path, slm_2_pretrained_path, slm_3_pretrained_path]
slm_lora_target_modules = [
    ["q_proj", "v_proj"],
    ["c_attn"],
    ['q_proj', 'k_proj', 'v_proj', 'o_proj'],
    ["query_key_value"]
]
slm_models = [
    ("pellm.opt", "OPT"),
    ("pellm.gpt2", "GPT2CLM"),
    ("pellm.llama", "LLaMa"),
    ("pellm.bloom", "Bloom")
]


def get_llm_conf():
    lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        inference_mode=False, r=8, lora_alpha=16, lora_dropout=0.05,
        target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj']
    )
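    # LoraConfig may store target_modules as a set; convert it back to a list so the runner config stays JSON-serializable
    lora_config.target_modules = list(lora_config.target_modules)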

    llm_model = LLMModelLoader(
        "pellm.llama",
        "LLaMa",
        pretrained_path=llm_pretrained_path,
        peft_type="LoraConfig",
        peft_config=lora_config.to_dict(),
        torch_dtype="bfloat16"
    )

    pub_dataset = LLMDatasetLoader(
        "qa_dataset",
        "QaDataset",
        tokenizer_name_or_path=llm_pretrained_path,
        need_preprocess=True,
        dataset_name="arc_challenge",
        data_part="common",
        seq_max_len=512
    )

    training_args = FedMKTTrainingArguments(
        global_epochs=5,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        learning_rate=3e-5,
        output_dir="./",
        dataloader_num_workers=4,
        remove_unused_columns=False,
        warmup_ratio=0.008,
        lr_scheduler_type="cosine",
        optim="adamw_torch",
        adam_beta1=0.9,
        adam_beta2=0.95,
        weight_decay=0.1,
        max_grad_norm=1.0,
        use_cpu=False,
        vocab_size=AutoConfig.from_pretrained(llm_pretrained_path).vocab_size,
    )

    fed_args = FedAVGArguments(
        aggregate_strategy='epoch',
        aggregate_freq=1
    )

    tokenizer = LLMDataFuncLoader(
        "tokenizers.cust_tokenizer",
        "get_tokenizer",
        tokenizer_name_or_path=llm_pretrained_path
    )

    slm_tokenizers = list()
    for slm_pretrained_path in slm_pretrained_paths:
        slm_tokenizers.append(
            LLMDataFuncLoader("tokenizers.cust_tokenizer", "get_tokenizer", tokenizer_name_or_path=slm_pretrained_path)
        )

    return get_config_of_fedmkt_runner(
        model=llm_model,
        training_args=training_args,
        fed_args=fed_args,
        pub_dataset=pub_dataset,
        tokenizer=tokenizer,
        slm_tokenizers=slm_tokenizers,
        slm_to_llm_vocab_mapping_paths=slm_to_llm_vocab_mapping_paths,
        pub_dataset_path=process_data_output_dir,
        save_trainable_weights_only=True,
    )


def get_slm_conf(slm_idx):
    slm_pretrained_path = slm_pretrained_paths[slm_idx]
    lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1,
        target_modules=slm_lora_target_modules[slm_idx]
    )
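    # LoraConfig may store target_modules as a set; convert it back to a list so the runner config stays JSON-serializable
    lora_config.target_modules = list(lora_config.target_modules)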
    llm_to_slm_vocab_mapping = llm_to_slm_vocab_mapping_paths[slm_idx]

    slm_model = LLMModelLoader(
        slm_models[slm_idx][0],
        slm_models[slm_idx][1],
        pretrained_path=slm_pretrained_path,
        peft_type="LoraConfig",
        peft_config=lora_config.to_dict(),
    )
    vocab_size = AutoConfig.from_pretrained(slm_pretrained_path).vocab_size

    pub_dataset = LLMDatasetLoader(
        "qa_dataset",
        "QaDataset",
        tokenizer_name_or_path=slm_pretrained_path,
        need_preprocess=True,
        dataset_name="arc_challenge",
        data_part="common",
        seq_max_len=512
    )

    priv_dataset = LLMDatasetLoader(
        "qa_dataset",
        "QaDataset",
        tokenizer_name_or_path=slm_pretrained_path,
        need_preprocess=True,
        dataset_name="arc_challenge",
        data_part=f"client_{slm_idx}",
        seq_max_len=512
    )

    training_args = FedMKTTrainingArguments(
        global_epochs=5,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        learning_rate=3e-5 if slm_idx != 1 else 3e-4,
        output_dir="./",
        dataloader_num_workers=4,
        remove_unused_columns=False,
        warmup_ratio=0.008,
        lr_scheduler_type="cosine",
        optim="adamw_torch",
        adam_beta1=0.9,
        adam_beta2=0.95,
        weight_decay=0.1,
        max_grad_norm=1.0,
        use_cpu=False,
        vocab_size=vocab_size,
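        # post_fedavg=True,  # uncomment for homogeneous SLMs (also set it in the LLM's training args)
        # llm_training=False,  # uncomment for the LLM2SLM setting, where only the SLMs are trained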
    )

    fed_args = FedAVGArguments(
        aggregate_strategy='epoch',
        aggregate_freq=1
    )

    tokenizer = LLMDataFuncLoader(
        "tokenizers.cust_tokenizer",
        "get_tokenizer",
        tokenizer_name_or_path=slm_pretrained_path
    )

    llm_tokenizer = LLMDataFuncLoader(
        "tokenizers.cust_tokenizer", "get_tokenizer", tokenizer_name_or_path=llm_pretrained_path
    )

    data_collator = LLMDataFuncLoader(module_name='data_collator.cust_data_collator',
                                      item_name='get_seq2seq_data_collator', tokenizer_name_or_path=slm_pretrained_path)

    return get_config_of_fedmkt_runner(
        model=slm_model,
        training_args=training_args,
        fed_args=fed_args,
        pub_dataset=pub_dataset,
        priv_dataset=priv_dataset,
        tokenizer=tokenizer,
        llm_tokenizer=llm_tokenizer,
        llm_to_slm_vocab_mapping_path=llm_to_slm_vocab_mapping,
        pub_dataset_path=process_data_output_dir,
        save_trainable_weights_only=True,
        data_collator=data_collator
    )


pipeline = FateFlowPipeline().set_parties(guest=guest, arbiter=arbiter, host=host)
pipeline.bind_local_path(path=process_data_output_dir, namespace="experiment", name="arc_challenge")


reader_0 = Reader("reader_0", runtime_parties=dict(guest=guest, host=host))
reader_0.guest.task_parameters(
    namespace="experiment",
    name="arc_challenge"
)
reader_0.hosts[[0, 1, 2]].task_parameters(
    namespace="experiment",
    name="arc_challenge"
)


homo_nn_0 = HomoNN(
    'nn_0',
    train_data=reader_0.outputs["output_data"],
    runner_module="fedmkt_runner",
    runner_class="FedMKTRunner",
)

homo_nn_0.arbiter.task_parameters(
    runner_conf=get_llm_conf()
)

homo_nn_0.guest.task_parameters(
    runner_conf=get_slm_conf(slm_idx=0)
)

for idx in range(3):
    homo_nn_0.hosts[idx].task_parameters(
        runner_conf=get_slm_conf(slm_idx=idx + 1)
    )

homo_nn_0.guest.conf.set("launcher_name", "deepspeed") # tell schedule engine to run task with deepspeed
homo_nn_0.hosts[[0, 1, 2]].conf.set("launcher_name", "deepspeed") # tell schedule engine to run task with deepspeed
homo_nn_0.arbiter.conf.set("launcher_name", "deepspeed") # tell schedule engine to run task with deepspeed

pipeline.add_tasks([reader_0, homo_nn_0])
pipeline.conf.set("task", dict(engine_run={"cores": 1})) # the number of gpus of each party

pipeline.compile()
pipeline.fit()
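
The pipeline configuration indexes several parallel lists by `slm_idx` (`slm_pretrained_paths`, `slm_lora_target_modules`, `slm_models`, and both vocabulary-mapping lists), and `get_slm_conf(slm_idx=idx + 1)` is called for each host. A list that is one entry short or out of order silently pairs an SLM with the wrong LoRA targets or mapping table. The hypothetical helper below (`check_slm_config` is not a FATE-LLM API) can be run before compiling the pipeline to confirm every per-SLM list has one entry per SLM party.

```python
from typing import List, Sequence


def check_slm_config(num_guests: int, hosts: Sequence[str], **per_slm_lists: Sequence) -> List[str]:
    """Return a list of problems; an empty list means every per-SLM list has
    exactly one entry per SLM party (guests plus hosts)."""
    expected = num_guests + len(hosts)
    problems = []
    for name, values in per_slm_lists.items():
        if len(values) != expected:
            problems.append(f"{name} has {len(values)} entries, expected {expected}")
    return problems


if __name__ == "__main__":
    host = ["9999", "10000", "10001"]
    issues = check_slm_config(
        1, host,
        slm_pretrained_paths=["opt-1.3b", "gpt2-xl", "Sheared-LLaMA-1.3B", "bloom-1b1"],
        slm_models=[("pellm.opt", "OPT"), ("pellm.gpt2", "GPT2CLM"), ("pellm.llama", "LLaMa")],
    )
    print(issues)  # ['slm_models has 3 entries, expected 4']
```

The check only validates lengths; the order of the entries still has to match the guest-then-hosts order used when calling `get_slm_conf`.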

