# PDSS Tutorial

## Introduction to PDSS

PDSS is a novel framework designed to distill knowledge from large language models (LLMs) to small language models (SLMs) while ensuring data privacy. The framework addresses two major challenges faced by LLM deployment in real-world applications: the privacy of domain-specific knowledge and resource constraints.

PDSS adopts a server-client architecture where the client sends perturbed prompts to the server-side LLM for inference, generating perturbed rationales. The client then decodes these rationales and uses them to enrich the training of its task-specific SLM, ultimately enhancing its performance.

PDSS introduces two privacy protection strategies: 
- **the Exponential Mechanism Strategy**
- **the Encoder-Decoder Strategy**
  
The Exponential Mechanism Strategy uses an exponential mechanism based on differential privacy (DP) to obfuscate user prompts, while the Encoder-Decoder Strategy employs a specialized encoder-decoder SLM to encode raw prompts into perturbed prompts and to decode the perturbed rationales. Both strategies balance user privacy against the usability of the rationales, so the client's SLM can be trained on richer data without exposing the raw prompts.
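
To illustrate the idea behind the Exponential Mechanism Strategy, the following is a minimal, self-contained sketch of a generic exponential mechanism applied to token replacement. It is not the InferDPT implementation: the toy vocabulary, the similarity scores, and the function names are assumptions made purely for illustration.

```python
import math
import random


def selection_probabilities(scores, epsilon, sensitivity=1.0):
    """Probability of picking each candidate: proportional to exp(eps * score / (2 * sensitivity))."""
    top = max(scores)  # subtracting the max keeps exp() numerically stable
    weights = [math.exp(epsilon * (s - top) / (2.0 * sensitivity)) for s in scores]
    total = sum(weights)
    return [w / total for w in weights]


def exponential_mechanism(candidates, scores, epsilon, sensitivity=1.0, rng=None):
    """Sample one replacement token according to the exponential mechanism."""
    rng = rng or random.Random()
    probs = selection_probabilities(scores, epsilon, sensitivity)
    return rng.choices(candidates, weights=probs, k=1)[0]


# Hypothetical candidates for replacing the token "fever", with made-up similarity scores in [0, 1].
candidates = ["fever", "temperature", "illness", "banana"]
similarity = [1.0, 0.8, 0.6, 0.0]

for eps in (0.1, 3.0, 20.0):
    probs = selection_probabilities(similarity, eps)
    print(f"epsilon={eps}: {[round(p, 3) for p in probs]}")

print(exponential_mechanism(candidates, similarity, epsilon=3.0, rng=random.Random(0)))
```

A smaller epsilon flattens the distribution (more randomness, stronger privacy), while a larger epsilon concentrates the probability on the candidates most similar to the original token (better utility, weaker privacy).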

In experiments on various text generation tasks, PDSS improves the performance of task-specific SLMs while protecting data privacy. For more details, please refer to the [original paper](https://arxiv.org/pdf/2406.12403).

**Before reading this tutorial, we strongly recommend that you first read [the InferDPT](./) tutorial.**

## Use the Infer Client & Server

This section introduces the inference part, the key component of PDSS, which generates useful rationales while preserving privacy. You can use either InferDPT (which uses the Exponential Mechanism Strategy) or a specifically trained SLM as the text encoder and decoder. We take a sample from the arc-easy dataset as an example:

In [10]:
test_example = {'id': 'Mercury_7220990',
'question': 'Which factor will most likely cause a person to develop a fever?',
'choices': {'text': ['a leg muscle relaxing after exercise',
'a bacterial population in the bloodstream',
'several viral particles on the skin',
'carbohydrates being digested in the stomach'],
'label': ['A', 'B', 'C', 'D']},
'answerKey': 'B'}

### FATE Context

We need to create a FATE context to enable communication between the client and the server. Then we can initialize the infer client (which encodes the raw prompt and decodes the perturbed response) and the server (which deploys the LLM) to enable secure inference.

In [6]:
arbiter = ("arbiter", 10000)
guest = ("guest", 10000)
host = ("host", 9999)
name = "fed1"

def create_ctx(local):
    from fate.arch import Context
    from fate.arch.computing.backends.standalone import CSession
    from fate.arch.federation.backends.standalone import StandaloneFederation
    import logging

    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)

    formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
    console_handler.setFormatter(formatter)

    logger.addHandler(console_handler)
    computing = CSession(data_dir="./session_dir")
    return Context(computing=computing, federation=StandaloneFederation(computing, name, local, [guest, host, arbiter]))

### The DP-based Strategy (InferDPT)

As outlined in the [InferDPT tutorial](./), you can initialize the InferDPT client and server to facilitate secure and private inference. Prior to executing the InferDPT component, it is recommended to generate the InferDPT kit by following the step-by-step instructions provided in the tutorial.

#### Client Side Code

On the client side, we load the pre-computed inferdpt-kit and deploy a local SLM as the decoding model.

In [None]:
from fate_llm.algo.inferdpt.inference.api import APICompletionInference
from fate_llm.algo.inferdpt import inferdpt
from fate_llm.algo.inferdpt.utils import InferDPTKit
from fate_llm.algo.inferdpt.inferdpt import InferDPTClient, InferDPTServer
from jinja2 import Template
from fate.arch import Context
import sys

arbiter = ("arbiter", 10000)
guest = ("guest", 10000)
host = ("host", 9999)
name = "fed1"

def create_ctx(local):
    from fate.arch import Context
    from fate.arch.computing.backends.standalone import CSession
    from fate.arch.federation.backends.standalone import StandaloneFederation
    import logging

    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)

    formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
    console_handler.setFormatter(formatter)

    logger.addHandler(console_handler)
    computing = CSession(data_dir="./session_dir")
    return Context(computing=computing, federation=StandaloneFederation(computing, name, local, [guest, host, arbiter]))

ctx = create_ctx(guest)
save_kit_path = 'your path'
kit = InferDPTKit.load_from_path(save_kit_path)
# locally deployed small model used as the decoding model
inference = APICompletionInference(api_url="http://127.0.0.1:8887/v1", model_name='./Qwen1.5-0.5B', api_key='EMPTY')

test_example = {'id': 'Mercury_7220990',
'question': 'Which factor will most likely cause a person to develop a fever?',
'choices': {'text': ['a leg muscle relaxing after exercise',
'a bacterial population in the bloodstream',
'several viral particles on the skin',
'carbohydrates being digested in the stomach'],
'label': ['A', 'B', 'C', 'D']},
'answerKey': 'B'}


doc_template = """{{question}} 
Choices:{{choices.text}}
"""

instruction_template="""
<s>[INST]
Select Answer from Choices and explain it in "Rationale" with few words. Please refer to the example to write the rationale.
Use <end> to finish your rationale.

Example(s):
Question:George wants to warm his hands quickly by rubbing them. Which skin surface will produce the most heat?
Choices:['dry palms', 'wet palms', 'palms covered with oil', 'palms covered with lotion']
Rationale:Friction between two surfaces generates heat due to the conversion of kinetic energy into thermal energy. Dry palms produce the most heat when rubbed together as they create higher friction compared to wet or lubricated palms, which reduce friction.  Therefore, the answer is 'dry palms'.<end>

Please explain:
Question:{{perturbed_doc}}
Rationale:
[/INST]
"""

decode_template = """Select Answer from Choices and explain it in "Rationale" with few words. Please refer to the example to write the rationale.
Use <end> to finish your rationale.

Example(s):
Question:George wants to warm his hands quickly by rubbing them. Which skin surface will produce the most heat?
Choices:['dry palms', 'wet palms', 'palms covered with oil', 'palms covered with lotion']
Rationale:Friction between two surfaces generates heat due to the conversion of kinetic energy into thermal energy. Dry palms produce the most heat when rubbed together as they create higher friction compared to wet or lubricated palms, which reduce friction.  Therefore, the answer is 'dry palms'.<end>

Question:{{perturbed_doc}}
Rationale:{{perturbed_response | replace('\n', '')}}<end>

Please explain:
Question:{{question}} 
Choices:{{choices.text}}
Rationale:
"""

inferdpt_client = inferdpt.InferDPTClient(ctx, kit, inference, epsilon=3.0)
result = inferdpt_client.inference([test_example], doc_template, instruction_template, decode_template,
                                 remote_inference_kwargs={
                                    'stop': ['</s>'],  # end-of-sequence token of the remote LLM
                                    'temperature': 0.01,
                                    'max_tokens': 256
                                 },
                                 local_inference_kwargs={
                                    'stop': ['<|im_end|>', '<end>', '<end>\n', '<end>\n\n', '.\n\n\n\n\n', '<|end_of_text|>', '>\n\n\n'],
                                    'temperature': 0.01,
                                    'max_tokens': 256
                                 })
print('result is {}'.format(result[0]['inferdpt_result']))

#### Server Side Code

In [9]:
from fate_llm.algo.inferdpt.utils import InferDPTKit
from fate_llm.algo.inferdpt.inferdpt import InferDPTClient, InferDPTServer
from jinja2 import Template
from fate.arch import Context
import sys
from fate_llm.algo.inferdpt.inference.api import APICompletionInference

arbiter = ("arbiter", 10000)
guest = ("guest", 10000)
host = ("host", 9999)
name = "fed1"

def create_ctx(local):
    from fate.arch import Context
    from fate.arch.computing.backends.standalone import CSession
    from fate.arch.federation.backends.standalone import StandaloneFederation
    import logging

    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)

    formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
    console_handler.setFormatter(formatter)

    logger.addHandler(console_handler)
    computing = CSession(data_dir="./session_dir")
    return Context(computing=computing, federation=StandaloneFederation(computing, name, local, [guest, host, arbiter]))

ctx = create_ctx(arbiter)
# API of the deployed LLM
inference_server = APICompletionInference(api_url="http://127.0.0.1:8888/v1", model_name='./Mistral-7B-Instruct-v0.2', api_key='EMPTY')
inferdpt_server = InferDPTServer(ctx, inference_server)
inferdpt_server.inference()

Start two terminals and launch the client and server scripts simultaneously. On the client side, we get the following answer:

```
The given question asks which factor will most likely cause a person to develop a fever. The factors mentioned are a leg muscle relaxing after exercise, a bacterial population in the bloodstream, several viral particles on the skin, and carbohydrates being digested in the stomach. The question is asking which factor is most likely to cause a person to develop a fever. The factors are all related to the body's internal environment, but the most likely factor is a bacterial population in the bloodstream. This is because bacteria can cause a fever, and the body's immune system responds to the infection by producing antibodies that can fight off the bacteria. Therefore, the answer is 'a bacterial population in the bloodstream'
```
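
The exchange performed by the client and server scripts above can be summarized in a few lines. This is only a toy sketch of the message flow: `ToyClient`, `ToyServer`, and the substitution table are hypothetical stand-ins that are not part of FATE-LLM, and no real model or differential privacy mechanism is involved.

```python
from dataclasses import dataclass, field


@dataclass
class ToyServer:
    """Stand-in for the server-side LLM; records every prompt it receives."""
    seen: list = field(default_factory=list)

    def infer(self, perturbed_prompt: str) -> str:
        self.seen.append(perturbed_prompt)
        return f"[perturbed rationale for: {perturbed_prompt}]"


class ToyClient:
    # Hypothetical word-level substitutions standing in for a real perturbation step.
    SUBSTITUTIONS = {"fever": "temperature", "person": "individual"}

    def __init__(self, server: ToyServer):
        self.server = server

    def perturb(self, prompt: str) -> str:
        return " ".join(self.SUBSTITUTIONS.get(word, word) for word in prompt.split())

    def decode(self, prompt: str, perturbed_prompt: str, perturbed_rationale: str) -> str:
        # A real client would prompt its local SLM with all three texts here.
        return f"rationale for '{prompt}' recovered from {perturbed_rationale}"

    def inference(self, prompt: str) -> str:
        perturbed_prompt = self.perturb(prompt)                      # 1. perturb locally
        perturbed_rationale = self.server.infer(perturbed_prompt)    # 2. remote inference
        return self.decode(prompt, perturbed_prompt, perturbed_rationale)  # 3. decode locally


server = ToyServer()
client = ToyClient(server)
prompt = "what causes a person to develop a fever"
result = client.inference(prompt)
print("server saw:", server.seen)
print("client got:", result)
```

The point of the sketch is that only the perturbed prompt ever reaches the server, while the raw prompt is used exclusively on the client during decoding.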

### The Encoder-Decoder Model Strategy

Similar to InferDPT, we can initialize SLMEncoderDecoderClient and SLMEncoderDecoderServer to enable secure inference.
The client encodes the raw prompt using a local SLM and then decodes the perturbed response with the same model.

#### Client Side Code

In [4]:
from fate_llm.algo.inferdpt.inference.api import APICompletionInference
from fate_llm.algo.pdss.encoder_decoder.slm_encoder_decoder import SLMEncoderDecoderClient

arbiter = ("arbiter", 10000)
guest = ("guest", 10000)
host = ("host", 9999)
name = "fed1"

def create_ctx(local):
    from fate.arch import Context
    from fate.arch.computing.backends.standalone import CSession
    from fate.arch.federation.backends.standalone import StandaloneFederation
    import logging

    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)

    formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
    console_handler.setFormatter(formatter)

    logger.addHandler(console_handler)
    computing = CSession(data_dir="./session_dir")
    return Context(computing=computing, federation=StandaloneFederation(computing, name, local, [guest, host, arbiter]))


test_example = {'id': 'Mercury_7220990',
'question': 'Which factor will most likely cause a person to develop a fever?',
'choices': {'text': ['a leg muscle relaxing after exercise',
'a bacterial population in the bloodstream',
'several viral particles on the skin',
'carbohydrates being digested in the stomach'],
'label': ['A', 'B', 'C', 'D']},
'answerKey': 'B'
}


encode_prompt = """Disrupt the main words in the original text so that it becomes difficult to recognize, but at the same time, try to maintain the original meaning as much as possible. Use <end> to end your reply.
Origin Doc:Question:{{question}}
Choices:{{choices.text}}
Perturb Doc: 
"""

decode_prompt = """This is a perturbed question and its corresponding answer(rationale). And following is the original question. Try to recover the correct rationale from docs provided.

Perturbed doc and rationale:
{{perturbed_doc}}
Rationale:{{perturbed_response}}

Original Doc:
Question:{{question}}
Choices:{{choices.text}}

Recover Rationale:
"""

instruction_template = """<|im_start|>system
You are a helpful assistant<|im_end|>
<|im_start|>user
Select Answer from Choices and explain it in "Rationale" with few words. Please refer to the example to write the rationale.
Use <end> to finish your rationale.

Example(s):
Question:Which factor will most likely cause a person to develop a fever?
Choices:['a leg muscle relaxing after exercise', 'a bacterial population in the bloodstream', 'several viral particles on the skin', 'carbohydrates being digested in the stomach']
Rationale:A bacterial infection in the bloodstream triggers the immune system to respond, therefore often causing a fever as the body tries to fight off the bacteria. Therefore, the answer is 'a bacterial population in the bloodstream'

Please explain:
{{perturbed_doc}}
Rationale:
<|im_end|>
<|im_start|>assistant
"""

ctx = create_ctx(guest)
model_name = 'Deploy your encoder decoder model'
# api_url to your locally deployed encoder decoder
api = APICompletionInference(api_url='http://127.0.0.1:8887/v1', api_key='EMPTY', model_name=model_name)
client = SLMEncoderDecoderClient(ctx, api)
result = client.inference([test_example], encode_prompt, instruction_template, decode_prompt,
                                 remote_inference_kwargs={
                                    'stop': ['</s>'],
                                    'temperature': 0.01,
                                    'max_tokens': 256
                                 },
                                 local_inference_kwargs={
                                    'stop': ['<|im_end|>', '<end>', '<end>\n', '<end>\n\n', '.\n\n\n\n\n', '<|end_of_text|>', '>\n\n\n'],
                                    'temperature': 0.01,
                                    'max_tokens': 256
                                 })
print('result is {}'.format(result[0]['inferdpt_result']))

#### Server Side Code

In [7]:
from fate_llm.algo.inferdpt.inference.api import APICompletionInference
from fate_llm.algo.pdss.encoder_decoder.slm_encoder_decoder import SLMEncoderDecoderServer

arbiter = ("arbiter", 10000)
guest = ("guest", 10000)
host = ("host", 9999)
name = "fed1"

def create_ctx(local):
    from fate.arch import Context
    from fate.arch.computing.backends.standalone import CSession
    from fate.arch.federation.backends.standalone import StandaloneFederation
    import logging

    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)

    formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
    console_handler.setFormatter(formatter)

    logger.addHandler(console_handler)
    computing = CSession(data_dir="./session_dir")
    return Context(computing=computing, federation=StandaloneFederation(computing, name, local, [guest, host, arbiter]))

ctx = create_ctx(arbiter)
# api_url and model_name point to the deployed LLM
model_name = '/data/cephfs/llm/models/Qwen1.5-14B-Chat/'
api = APICompletionInference(api_url='http://127.0.0.1:8888/v1', api_key='EMPTY', model_name=model_name)
server = SLMEncoderDecoderServer(ctx, api)
server.inference()

Start two terminals and launch the client and server scripts simultaneously. On the client side, we get the following answer:

```
A fever is typically caused by a bacterial population in the bloodstream, as it is a response to an infection. So the answer is 'a bacterial population in the bloodstream'.
```

## Prefix Dataset & PDSS Trainer

Now that we can carry out privacy-preserving inference and acquire rationales, the next step is to train a new task-specific model, enhanced by the rationales generated by the LLMs.

In this section, we will introduce the PrefixDataset and PDSSTrainer, which facilitate training tasks with the added benefit of supplementary rationales. The PrefixDataset allows you to assign various text prefixes, guiding the model to produce different text targets. With PDSSTrainer, the model is trained to generate both text labels and text rationales at each update step, ultimately leading to superior performance compared to training on the raw dataset alone.
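
To make the prefix idea concrete, here is a minimal sketch in plain Python that builds the two input/target pairs for one record. It does not use FATE-LLM: `build_pairs` is a hypothetical helper, and the "Predict:" and "Explain:" prefixes simply mirror the templates used later in this tutorial.

```python
record = {
    "id": "Mercury_7220990",
    "question": "Which factor will most likely cause a person to develop a fever?",
    "choices": {
        "text": [
            "a leg muscle relaxing after exercise",
            "a bacterial population in the bloodstream",
            "several viral particles on the skin",
            "carbohydrates being digested in the stomach",
        ],
        "label": ["A", "B", "C", "D"],
    },
    "answerKey": "B",
}


def build_pairs(record, infer_result=""):
    """Build one 'predict' pair and one 'rationale' pair from the same record."""
    choices = record["choices"]
    # Map the answer key (e.g. 'B') to the text of the matching choice.
    answer = choices["text"][choices["label"].index(record["answerKey"])]
    body = f"Question:{record['question']}\nChoices:{choices['text']}\n"
    return {
        "predict": {"input": "Predict:\n" + body + "Answer:\n", "output": answer + "<end>"},
        "rationale": {"input": "Explain:\n" + body + "Rationale:\n", "output": infer_result + "<end>"},
    }


pairs = build_pairs(record)
print(pairs["predict"]["output"])    # a bacterial population in the bloodstream<end>
print(pairs["rationale"]["output"])  # <end>
```

The rationale target stays empty until rationales have been generated, which is why inference has to run before training can benefit from them.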

### Prepare dataset
In this tutorial, we will use the arc-easy dataset.

In [None]:
from datasets import load_dataset
# ARC-Easy configuration of the AI2 ARC dataset on the Hugging Face Hub
dataset = load_dataset("allenai/ai2_arc", "ARC-Easy")
dataset.save_to_disk('path_to_save/arce')

Let’s proceed with testing the PrefixDataset. We can utilize Jinja2 templates to structure the text and append prefixes or suffixes to our training data.

Please note that at this stage the dataset does not contain rationales. In 'rationale_output_template', the inference results are referenced by the key 'infer_result'. We can perform secure inference using the PDSSTrainer and then integrate the rationale results, keyed as 'infer_result', into the PrefixDataset.

In [17]:
from fate_llm.dataset.pdss_dataset import PrefixDataset

pds = PrefixDataset(
        tokenizer_path='/data/cephfs/llm/models/Qwen1.5-0.5B/',
        predict_input_template="""Predict:
Question:{{question}}
Choices:{{choices.text}}
Answer:
    """,
        predict_output_template="""{{choices.text[choices.label.index(answerKey)]}}<end>""",
        rationale_input_template="""Explain:
Question:{{question}}
Choices:{{choices.text}}
Rationale:
    """,
        rationale_output_template="""{{infer_result}}<end>""",
        max_input_length=128,
        max_target_length=128,
        split_key='train'
    )


pds.load('path_to_save/arce')

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [27]:
pds.dataset[0] # the structure is the same as hf dataset

{'id': 'Mercury_7220990',
 'question': 'Which factor will most likely cause a person to develop a fever?',
 'choices': {'text': ['a leg muscle relaxing after exercise',
   'a bacterial population in the bloodstream',
   'several viral particles on the skin',
   'carbohydrates being digested in the stomach'],
  'label': ['A', 'B', 'C', 'D']},
 'answerKey': 'B'}

In [21]:
pds.get_str_item(0)  # we can see that the output of rationale term is empty

{'predict': {'input': "Predict:\nQuestion:Which factor will most likely cause a person to develop a fever?\nChoices:['a leg muscle relaxing after exercise', 'a bacterial population in the bloodstream', 'several viral particles on the skin', 'carbohydrates being digested in the stomach']\nAnswer:\n    ",
  'output': 'a bacterial population in the bloodstream<end>'},
 'rationale': {'input': "Explain:\nQuestion:Which factor will most likely cause a person to develop a fever?\nChoices:['a leg muscle relaxing after exercise', 'a bacterial population in the bloodstream', 'several viral particles on the skin', 'carbohydrates being digested in the stomach']\nRationale:\n    ",
  'output': '<end>\n    '}}

In [25]:
print(pds[0])  # show the tokenized sample; for the sake of brevity, the output is not shown in this tutorial

### The PDSSTrainer

Here we introduce the PDSSTrainer, which is built on the Hugging Face Trainer and supports training a task jointly with the raw labels and the additional rationales. The loss is computed as follows:

In [None]:
def compute_loss(self, model, inputs, return_outputs=False):

    label_outputs = model(**inputs['predict'])
    cot_outputs = model(**inputs['rationale'])
    loss = self.alpha * cot_outputs.loss + (1. - self.alpha) * label_outputs.loss
    return (loss, {'rationale_loss': cot_outputs, 'predict_loss': label_outputs}) if return_outputs else loss
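
As a small worked example of the weighting in `compute_loss`, the sketch below evaluates the same formula on plain numbers. The loss values and the `combined_loss` helper are made up for illustration; only the formula itself comes from the function above.

```python
def combined_loss(rationale_loss: float, predict_loss: float, alpha: float) -> float:
    """Weighted sum used by compute_loss: alpha weights the rationale term."""
    return alpha * rationale_loss + (1.0 - alpha) * predict_loss


# Made-up loss values for one batch.
rationale_loss, predict_loss = 2.0, 1.0

print(combined_loss(rationale_loss, predict_loss, alpha=0.5))  # 1.5
print(combined_loss(rationale_loss, predict_loss, alpha=0.0))  # 1.0 (labels only)
print(combined_loss(rationale_loss, predict_loss, alpha=1.0))  # 2.0 (rationales only)
```

With alpha between 0 and 1 the result is a convex combination of the two terms, so alpha directly controls how much the rationale objective contributes to each update.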

You can choose from three modes, 'infer_only', 'train_only', and 'infer_and_train', depending on your requirements:
- infer_only: Only generate the rationales and they will be saved to the output_dir
- train_only: Local training only
- infer_and_train: Generate rationales, and then load them into PrefixDataset and start training
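
The three modes can be summarized as a simple lookup. This is a hypothetical summary derived from the descriptions above, not code from PDSSTrainer; `MODE_STEPS` and `needs_infer_client` are names invented for this sketch.

```python
# Hypothetical summary of the three modes, derived from the list above.
MODE_STEPS = {
    "infer_only": ["generate rationales", "save rationales to output_dir"],
    "train_only": ["train locally"],
    "infer_and_train": ["generate rationales", "load rationales into PrefixDataset", "train locally"],
}


def needs_infer_client(mode: str) -> bool:
    """Return True when the mode includes a rationale generation step."""
    if mode not in MODE_STEPS:
        raise ValueError(f"unknown mode: {mode!r}")
    return "generate rationales" in MODE_STEPS[mode]


for mode, steps in MODE_STEPS.items():
    print(f"{mode}: {' -> '.join(steps)} (infer client needed: {needs_infer_client(mode)})")
```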
  
In this example, we use the 'infer_and_train' mode, which first generates rationales with the help of the remote LLM. To enable inference, you need to initialize the infer client and the infer server and pass them to the client-side and server-side trainers, as demonstrated in the preceding sections.

Below is a PDSS example. We ran it on a machine equipped with 4 V100-32G GPUs and launched the client script with deepspeed. The LLM is deployed on another machine.

## PDSS Example

### Client Script (deepspeed_run.py)

This script shows how to set up a PDSS task on the client side.

In [None]:
import logging
import os
import sys
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    HfArgumentParser,
    Seq2SeqTrainingArguments,
)
from fate_llm.algo.inferdpt.utils import InferDPTKit
from fate_llm.dataset.pdss_dataset import PrefixDataset
from fate_llm.algo.pdss.pdss_trainer import PDSSTrainerClient
from fate_llm.data.data_collator.pdss_collator import PrefixDataCollator
from fate_llm.algo.inferdpt import inferdpt


arbiter = ("arbiter", 10000)
guest = ("guest", 10000)
host = ("host", 9999)
name = "fed1"


def create_ctx(local):
    from fate.arch import Context
    from fate.arch.computing.backends.standalone import CSession
    from fate.arch.federation.backends.standalone import StandaloneFederation
    import logging

    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)

    formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
    console_handler.setFormatter(formatter)

    logger.addHandler(console_handler)
    computing = CSession(data_dir="./session_dir")
    return Context(computing=computing, federation=StandaloneFederation(computing, name, local, [guest, host, arbiter]))


logger = logging.getLogger(__name__)


doc_template = """{{question}} 
Choices:{{choices.text}}
"""

instruction_template="""
<s>[INST]
Select Answer from Choices and explain it in "Rationale" with few words. Please refer to the example to write the rationale.
Use <end> to finish your rationale.

Example(s):
Question:George wants to warm his hands quickly by rubbing them. Which skin surface will produce the most heat?
Choices:['dry palms', 'wet palms', 'palms covered with oil', 'palms covered with lotion']
Rationale:Friction between two surfaces generates heat due to the conversion of kinetic energy into thermal energy. Dry palms produce the most heat when rubbed together as they create higher friction compared to wet or lubricated palms, which reduce friction.  Therefore, the answer is 'dry palms'.<end>

Please explain:
Question:{{perturbed_doc}}
Rationale:
[/INST]
"""

decode_template = """Select Answer from Choices and explain it in "Rationale" with few words. Please refer to the example to write the rationale.
Use <end> to finish your rationale.

Example(s):
Question:George wants to warm his hands quickly by rubbing them. Which skin surface will produce the most heat?
Choices:['dry palms', 'wet palms', 'palms covered with oil', 'palms covered with lotion']
Rationale:Friction between two surfaces generates heat due to the conversion of kinetic energy into thermal energy. Dry palms produce the most heat when rubbed together as they create higher friction compared to wet or lubricated palms, which reduce friction.  Therefore, the answer is 'dry palms'.<end>

Question:{{perturbed_doc}}
Rationale:{{perturbed_response | replace('\n', '')}}<end>

Please explain:
Question:{{question}} 
Choices:{{choices.text}}
Rationale:
"""
    

if __name__ == "__main__":
    
    parser = HfArgumentParser(Seq2SeqTrainingArguments)
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))[0]
    else:
        training_args = parser.parse_args_into_dataclasses()[0]

    model_path = '/data/cephfs/llm/models/Qwen1.5-0.5B/'
    pds = PrefixDataset(
        tokenizer_path=model_path,
        predict_input_template="""Predict:
Question:{{question}}
Choices:{{choices.text}}
Answer:
    """,
        predict_output_template="""{{choices.text[choices.label.index(answerKey)]}}<end>""",
        rationale_input_template="""Explain:
Question:{{question}}
Choices:{{choices.text}}
Rationale:
    """,
        rationale_output_template="""{{infer_result}}<end>
    """,
        max_input_length=128,
        max_target_length=128,
        split_key='train'
    )
    pds.load('/data/cephfs/llm/datasets/arce/')
    
    model = AutoModelForCausalLM.from_pretrained(model_path).half().cuda()
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model.gradient_checkpointing_enable()
    model.enable_input_require_grads()

    ctx = create_ctx(guest)
    if training_args.local_rank == 0:
        # only rank 0 needs to create the inference client
        save_kit_path = 'your path'
        kit = InferDPTKit.load_from_path(save_kit_path)
        # locally deployed small model used as the decoding model
        from fate_llm.algo.inferdpt.inference.api import APICompletionInference
        inference = APICompletionInference(api_url="http://xxxx/v1", model_name='./Qwen1.5-0.5B', api_key='EMPTY')
        client = inferdpt.InferDPTClient(ctx, kit, inference, epsilon=3.0)
    else:
        client = None
    
    trainer = PDSSTrainerClient(
        ctx=ctx,
        model=model,
        training_args=training_args,
        tokenizer=tokenizer,    
        train_set=pds,
        data_collator=PrefixDataCollator(tokenizer),
        mode='infer_and_train',
        infer_client=client,
        encode_template=doc_template,
        decode_template=decode_template,
        instruction_template=instruction_template,
        remote_inference_kwargs={
            'stop': ['</s>'],
            'temperature': 0.01,
            'max_tokens': 256
        },
        local_inference_kwargs={
            'stop': ['<|im_end|>', '<end>', '<end>\n', '<end>\n\n', '.\n\n\n\n\n', '<|end_of_text|>', '>\n\n\n'],
            'temperature': 0.01,
            'max_tokens': 256
        }
    )

    trainer.train()

    if training_args.local_rank == 0:
        model.save_pretrained(training_args.output_dir)
        tokenizer.save_pretrained(training_args.output_dir)

### Server Script (server.py)

This script shows how to set up a PDSS task on the server side.

In [None]:
from fate_llm.algo.inferdpt.inferdpt import InferDPTServer
from fate_llm.algo.pdss.pdss_trainer import PDSSTraineServer
from jinja2 import Template
from fate.arch import Context
import sys


arbiter = ("arbiter", 10000)
guest = ("guest", 10000)
host = ("host", 9999)
name = "fed1"


def create_ctx(local):
    from fate.arch import Context
    from fate.arch.computing.backends.standalone import CSession
    from fate.arch.federation.backends.standalone import StandaloneFederation
    import logging

    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)

    formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
    console_handler.setFormatter(formatter)

    logger.addHandler(console_handler)
    computing = CSession(data_dir="./session_dir")
    return Context(computing=computing, federation=StandaloneFederation(computing, name, local, [guest, host, arbiter]))


from fate_llm.algo.inferdpt.inference.api import APICompletionInference
api = APICompletionInference(api_url='http://xxxx:8080/v1', api_key='EMPTY', model_name='/data/cephfs/llm/models/Qwen1.5-14B-Chat')

ctx = create_ctx(arbiter)
server_api = InferDPTServer(ctx, api)
server = PDSSTraineServer(ctx, server_api)
server.train()

### Start script

You can launch the client-side training with the following script:

```
deepspeed --num_nodes 1 --num_gpus 4 deepspeed_run.py \
    --output_dir "./" \
    --per_device_train_batch_size "1" \
    --gradient_accumulation_steps "8" \
    --max_steps "750" \
    --fp16 \
    --logging_steps 10 \
    --save_only_model \
    --deepspeed "./ds_config.json" 
```

The corresponding ds_config.json is:
```
{   
    "train_micro_batch_size_per_gpu": 1,
    "gradient_accumulation_steps": 8,
    "optimizer": {
        "type": "AdamW",
        "params": {
             "lr": 5e-5
        }
    },
    "fp16": {
        "enabled": true
    },
    "zero_optimization": {
        "stage": 0
    }
}
```
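
As a quick sanity check on these settings, the arithmetic below works out the effective batch size and the number of training sequences consumed. The numbers come from the launch command and ds_config.json above; the helper name is hypothetical, and the calculation assumes that `max_steps` counts optimizer updates, as it does in the Hugging Face Trainer.

```python
def effective_batch_size(num_gpus: int, micro_batch_per_gpu: int, grad_accum_steps: int) -> int:
    """Number of sequences that contribute to a single optimizer update."""
    return num_gpus * micro_batch_per_gpu * grad_accum_steps


num_gpus = 4              # --num_gpus 4
micro_batch_per_gpu = 1   # train_micro_batch_size_per_gpu / --per_device_train_batch_size
grad_accum_steps = 8      # gradient_accumulation_steps
max_steps = 750           # --max_steps

batch = effective_batch_size(num_gpus, micro_batch_per_gpu, grad_accum_steps)
total_sequences = batch * max_steps
print(batch)            # 32
print(total_sequences)  # 24000
```

Note that the batch size and gradient accumulation values appear both on the command line and in ds_config.json; keeping the two in agreement avoids configuration mismatches between the trainer and DeepSpeed.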

On the server side, run:

```
python server.py
```

## PDSS Pipeline Example

You can also submit a PDSS task through the FATE pipeline. With the appropriate configuration, PDSS can be run in a production environment.

In [None]:
from fate_llm.runner.pdss_runner import PDSSRunner
from fate.components.components.nn.nn_runner import loader_load_from_conf
from fate.components.components.nn.loader import Loader
from fate_llm.dataset.pdss_dataset import PrefixDataset
from fate_client.pipeline.components.fate.nn.loader import ModelLoader, DatasetLoader, CustFuncLoader, Loader
from transformers import (
    AutoConfig,
    AutoModel,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    HfArgumentParser,
    Seq2SeqTrainingArguments,
    set_seed,
    Trainer
)
import argparse
from fate_client.pipeline.utils import test_utils
from fate_client.pipeline.components.fate.evaluation import Evaluation
from fate_client.pipeline.components.fate.reader import Reader
from fate_client.pipeline import FateFlowPipeline
from fate_client.pipeline.components.fate.nn.torch import nn, optim
from fate_client.pipeline.components.fate.nn.torch.base import Sequential
from fate_client.pipeline.components.fate.homo_nn import HomoNN, get_config_of_default_runner
from fate_client.pipeline.components.fate.nn.algo_params import TrainingArguments, FedAVGArguments

def main(config="../../config.yaml", namespace=""):
    # obtain config
    if isinstance(config, str):
        config = test_utils.load_job_config(config)
    parties = config.parties
    guest = '9999'
    host = parties.host[0]
    arbiter = '10000'

    pipeline = FateFlowPipeline().set_parties(guest=guest, arbiter=arbiter)

    reader_0 = Reader("reader_0", runtime_parties=dict(guest=guest))
    reader_0.guest.task_parameters(
        namespace="experiment",
        name="arc_e_example"
    )

    model_conf = Loader(module_name='fate_llm.model_zoo.hf_model', item_name='HFAutoModelForCausalLM', 
                        pretrained_model_name_or_path='/data/cephfs/llm/models/Qwen1.5-0.5B/').to_dict()
    data_collator_conf = Loader(module_name='fate_llm.data.data_collator.pdss_collator', item_name='get_prefix_data_collator', tokenizer_name_or_path='/data/cephfs/llm/models/Qwen1.5-0.5B/').to_dict()

    infer_init_conf_client = {
        'module_name': 'fate_llm.algo.inferdpt.init.default_init',
        'item_name': 'InferDPTAPIClientInit'
    }

    infer_init_conf_server = {
        'module_name': 'fate_llm.algo.inferdpt.init.default_init',
        'item_name': 'InferDPTAPIServerInit'
    }

    dataset_conf = {
        'module_name': 'fate_llm.dataset.pdss_dataset',
        'item_name': 'PrefixDataset',
        'kwargs':dict(
            tokenizer_path='/data/cephfs/llm/models/Qwen1.5-0.5B/',
            predict_input_template="""Predict:
    Question:{{question}}
    Choices:{{choices.text}}
    """,
            predict_output_template="""{{choices.text[choices.label.index(answerKey)]}}<end>""",
            rationale_input_template="""Explain:
    Question:{{question}}
    Choices:{{choices.text}}
    """,
            rationale_output_template="""{{infer_result}}<end>
        """,
            max_input_length=128,
            max_target_length=128,
            split_key='train'
        )
    }

    encoder_prompt = """{{question}}
Choices:{{choices.text}}
"""

    decoder_prompt = """Select Answer from Choices and explain it in "Rationale" with few words. Please refer to the example to write the rationale. Use <end> to finish your rationale.

Example(s):
Question:George wants to warm his hands quickly by rubbing them. Which skin surface will produce the most heat?
Choices:['dry palms', 'wet palms', 'palms covered with oil', 'palms covered with lotion']
Rationale:Friction between two surfaces generates heat due to the conversion of kinetic energy into thermal energy. Dry palms produce the most heat when rubbed together as they create higher friction compared to wet or lubricated palms, which reduce friction.  Therefore, the answer is 'dry palms'.<end>

Question:{{perturbed_doc}}
Rationale:{{perturbed_response | replace('\n', '')}}<end>

Please explain:
Question:{{question}} 
Choices:{{choices.text}}
    """

    instruction_prompt = """<|im_start|>system
You are a helpful assistant<|im_end|>
<|im_start|>user
Select Answer from Choices and explain it in "Rationale" with few words. Please refer to the example to write the rationale.
Use <end> to finish your rationale.

Example(s):
Question:Which factor will most likely cause a person to develop a fever?
Choices:['a leg muscle relaxing after exercise', 'a bacterial population in the bloodstream', 'several viral particles on the skin', 'carbohydrates being digested in the stomach']
Rationale:A bacterial infection in the bloodstream triggers the immune system to respond, therefore often causing a fever as the body tries to fight off the bacteria. Therefore, the answer is 'a bacterial population in the bloodstream'

Please explain:
Question:{{perturbed_doc}}
Rationale:
<|im_end|>
<|im_start|>assistant
    """

    remote_inference_kwargs={
        'stop': ['<|im_end|>', '<end>', '<end>\n', '<end>\n\n', '.\n\n\n\n\n', '<|end_of_text|>', '>\n\n\n'],
        'temperature': 0.01,
        'max_tokens': 256
    }

    local_inference_kwargs={
        'stop': ['<|im_end|>', '<end>', '<end>\n', '<end>\n\n', '.\n\n\n\n\n', '<|end_of_text|>', '>\n\n\n'],
        'temperature': 0.01,
        'max_tokens': 256
    }

    # DeepSpeed settings, shown for reference; this script does not pass ds_config to the runner
    ds_config = {
        "train_micro_batch_size_per_gpu": 1,
        "gradient_accumulation_steps": 8,
        "optimizer": {
            "type": "AdamW",
            "params": {
                "lr": 5e-5
            }
        },
        "fp16": {
            "enabled": True
        },
        "zero_optimization": {
            "stage": 0
        }
    }

    training_args_dict = dict(
        # effective batch size per device = 1 * 8 = 8
        per_device_train_batch_size=1,
        gradient_accumulation_steps=8,
        logging_steps=10,
        max_steps=30,
        fp16=True,
        log_level='debug'
    )

    # The same mode value is passed to both the client and the server configs below
    mode = 'infer_and_train'

    client_conf = dict(
        model_conf=model_conf,
        dataset_conf=dataset_conf,
        training_args_conf=training_args_dict,
        data_collator_conf=data_collator_conf,
        mode=mode,
        infer_inst_init_conf=infer_init_conf_client,
        encode_template=encoder_prompt,
        instruction_template=instruction_prompt,
        decode_template=decoder_prompt,
        remote_inference_kwargs=remote_inference_kwargs,
        local_inference_kwargs=local_inference_kwargs,
        perturb_doc_key='perturbed_doc',
        perturbed_response_key='perturbed_response',
        result_key='infer_result'
    )

    server_conf = dict(
        infer_inst_init_conf=infer_init_conf_server,
        mode=mode
    )

    homo_nn_0 = HomoNN(
        'nn_0',
        train_data=reader_0.outputs["output_data"],
        runner_module="pdss_runner",
        runner_class="PDSSRunner"
    )

    homo_nn_0.guest.task_parameters(runner_conf=client_conf)
    homo_nn_0.arbiter.task_parameters(runner_conf=server_conf)

    homo_nn_0.guest.conf.set("launcher_name", "deepspeed")

    pipeline.add_tasks([reader_0, homo_nn_0])
    pipeline.conf.set("task", dict(engine_run={"cores": 4}))
    pipeline.compile()
    pipeline.fit()
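

# Illustration only (not called by the pipeline above): a minimal sketch of how
# the `stop` lists in remote_inference_kwargs / local_inference_kwargs are
# assumed to behave, following the common OpenAI-style convention that
# generation is cut at the earliest stop sequence and the stop sequence itself
# is dropped. The helper name `truncate_at_stop` is hypothetical and is not
# part of FATE-LLM; the actual inference backend may differ in detail.
def truncate_at_stop(text, stop):
    """Return `text` cut at the earliest occurrence of any stop sequence."""
    cut = len(text)
    for seq in stop:
        if not seq:
            continue  # ignore empty stop strings, which would match at index 0
        idx = text.find(seq)
        if idx != -1:
            cut = min(cut, idx)
    return text[:cut]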

if __name__ == "__main__":
    parser = argparse.ArgumentParser("PIPELINE DEMO")
    parser.add_argument("--config", type=str, default="../config.yaml",
                        help="config file")
    parser.add_argument("--namespace", type=str, default="",
                        help="namespace for data stored in FATE")
    args = parser.parse_args()
    main(config=args.config, namespace=args.namespace)
