# PDSS - Train an SLM Encoder-Decoder

PDSS is a framework for distilling knowledge from large language models (LLMs) into small language models (SLMs) while preserving data privacy. It trains an SLM on perturbed and recovered texts. The trained SLM can then encode raw text, producing results similar to those of differential privacy mechanisms, and return higher-quality recovered text.

This tutorial shows how to train such an SLM using the built-in trainer.

## Prepare Data

Preparing the data for training an SLM encoder-decoder model takes several steps:
- Sample data from the original dataset (for example, 50%)
- Organize the raw text and get direct rationale replies from a remote LLM
- Perturb the docs using InferDPTKit to get perturbed docs
- Get perturbed replies from the remote LLM
- Organize the training data

### Sample Data
Here we use the ARC-Easy data as an example and take the first 50% of the original training set.


In [3]:
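from datasets import load_dataset
# 'arc_easy' is assumed to resolve to a local copy of ARC-Easy; adjust the name or path to your setup
ds = load_dataset('arc_easy')['train']
# Keep the first half of the training split as a plain list of dicts
ds = [ds[i] for i in range(len(ds)//2)]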
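
Taking the first half keeps whatever ordering the dataset ships with. If a random subset is preferred, a seeded sample is one alternative. The sketch below uses only the standard library; `sample_fraction` is a hypothetical helper written for this illustration, not part of `datasets` or FATE-LLM, and the toy records stand in for real ARC-Easy rows.

```python
import random

def sample_fraction(records, fraction=0.5, seed=42):
    """Return a reproducible random subset containing `fraction` of `records`."""
    if not 0.0 < fraction <= 1.0:
        raise ValueError("fraction must be in (0, 1]")
    k = int(len(records) * fraction)
    rng = random.Random(seed)  # local RNG so the global random state is untouched
    # Sample indices and sort them so the subset keeps the original relative order
    indices = sorted(rng.sample(range(len(records)), k))
    return [records[i] for i in indices]

# Toy records standing in for dataset rows
toy = [{"id": i} for i in range(10)]
subset = sample_fraction(toy, fraction=0.5, seed=0)
print(len(subset))
```

Using the same seed reproduces the same subset, which matters here because the direct and perturbed replies must later be matched to the same records.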

### Get Direct Replies from a Remote LLM

We use the built-in inference class to call the remote LLM's API; you can also implement this part on your own.

In [42]:
from fate_llm.inference.api import APICompletionInference
from jinja2 import Template
from transformers import AutoTokenizer

# We are using a Qwen 14B model as the remote model
# You can change the setting
api = APICompletionInference(
    api_url='http://172.21.140.2:8081/v1',
    api_key='EMPTY',
    model_name='/data/cephfs/llm/models/Qwen1.5-14B-Chat'
)

tokenizer = AutoTokenizer.from_pretrained('/data/cephfs/llm/models/Qwen1.5-0.5B-Chat/')

arc_e_template_r = """Select Answer from Choices and explain it in "Rationale" with few words. Please refer to the example to write the rationale.
Use <end> to finish your rationale.

Example(s):
Question:Which factor will most likely cause a person to develop a fever?
Choices:['a leg muscle relaxing after exercise', 'a bacterial population in the bloodstream', 'several viral particles on the skin', 'carbohydrates being digested in the stomach']
Rationale:A bacterial infection in the bloodstream triggers the immune system to respond, therefore often causing a fever as the body tries to fight off the bacteria. Therefore, the answer is 'a bacterial population in the bloodstream'

Please explain:
Question:{{question}}
Choices:{{choices.text}}
Rationale:
"""

template = Template(arc_e_template_r)
docs_to_infer = [tokenizer.apply_chat_template([{'role':'system', 'content': 'you are a helpful assistant'}, {'role':'user', 'content': template.render(i)}], add_generation_prompt=True, tokenize=False) for i in ds]
results = api.inference(docs_to_infer, {
    'stop': ['<|im_end|>', '<end>', '<end>\n', '<end>\n\n', '.\n\n\n\n\n', '<|end_of_text|>', '>\n\n\n'],
    'temperature': 0.01,
    'max_tokens': 256
})

for i, r in zip(ds, results):
    i['rationale'] = r

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [6]:
print(results[0])

A fever is a response to an infection, typically caused by bacteria or viruses. So, the answer is 'a bacterial population in the bloodstream' because it indicates an immune response to a foreign invader. 'Several viral particles on the skin' could also lead to a fever if they enter the body, but bloodstream presence is more direct. The other choices are unrelated to fever development.
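
If you prefer to implement the remote call yourself, the sketch below shows one way to do it with the standard library only. It assumes the server exposes an OpenAI-style `/completions` route under the `/v1` base URL, which this tutorial does not confirm; `build_completion_request` and `complete` are hypothetical helpers, and the address and model name are placeholders.

```python
import json
import urllib.request

def build_completion_request(api_url, api_key, model_name, prompt, **sampling):
    """Build (but do not send) a POST request for an OpenAI-style /completions route."""
    payload = {"model": model_name, "prompt": prompt, **sampling}
    return urllib.request.Request(
        url=api_url.rstrip("/") + "/completions",
        data=json.dumps(payload).encode("utf-8"),
        headers={
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_key}",
        },
        method="POST",
    )

def complete(request, timeout=60):
    """Send the request and return the text of the first choice (assumed response shape)."""
    with urllib.request.urlopen(request, timeout=timeout) as resp:
        body = json.loads(resp.read().decode("utf-8"))
    return body["choices"][0]["text"]

req = build_completion_request(
    api_url="http://localhost:8081/v1",  # placeholder address
    api_key="EMPTY",
    model_name="my-model",               # placeholder model name
    prompt="Say hello.",
    stop=["<end>"],
    temperature=0.01,
    max_tokens=256,
)
print(req.full_url)
```

Building the request separately from sending it keeps the payload easy to inspect before any network traffic happens; call `complete(req)` once a server is reachable.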


### Perturb Docs & Replies

You can refer to the InferDPT tutorial for guidance on using InferDPTKit to generate perturbed documents: [InferDPT Document](./)

We can produce perturbed docs with InferDPTKit:

In [8]:
from fate_llm.algo.inferdpt.utils import InferDPTKit
path_to_kit = '/data/projects/inferdpt/test_fate_llm/'
kit = InferDPTKit.load_from_path(path_to_kit)

In [22]:
import copy
tmp_ds = copy.deepcopy(ds)

q_doc = [kit.perturb(i, epsilon=1.0) for i in [Template("""{{question}}""").render(i) for i in tmp_ds]]
c_doc = [kit.perturb(i, epsilon=1.0) for i in [Template("""{{choices.text}}""").render(i) for i in tmp_ds]]
for i,q,c in zip(tmp_ds,q_doc,c_doc):
    i['question'] = q
    i['choices']['text'] = c

In [23]:
tmp_ds[6]

{'id': 'Mercury_7179953',
 'question': 'stuff two alpha Rogers are today chap in Department?',
 'choices': {'text': "['muscular and skeletal', 'digestive and muscular', 'skeletal and pasteiratory', 'respiratory and exhibive']",
  'label': ['A', 'B', 'C', 'D']},
 'answerKey': 'A',
 'rationale': {...}}
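
The garbled question above is the intended effect of perturbation: tokens are replaced at random, and a smaller `epsilon` gives noisier text. This tutorial does not show how InferDPTKit works internally, so the sketch below is only a generic illustration of the exponential mechanism, a standard differential privacy building block, and not the kit's actual algorithm. The candidate table, the similarity scores, and the helper names are all invented for the example.

```python
import math
import random

def exponential_mechanism_probs(scores, epsilon, sensitivity=1.0):
    """Return probabilities proportional to exp(epsilon * score / (2 * sensitivity))."""
    weights = [math.exp(epsilon * s / (2.0 * sensitivity)) for s in scores]
    total = sum(weights)
    return [w / total for w in weights]

def perturb_token(token, candidates, epsilon, rng):
    """Replace `token` with one of its candidates, favouring higher-scored ones."""
    words = list(candidates[token].keys())
    scores = list(candidates[token].values())
    probs = exponential_mechanism_probs(scores, epsilon)
    return rng.choices(words, weights=probs, k=1)[0]

# Invented similarity scores in [0, 1]; a real system would derive them from embeddings
candidates = {
    "fever": {"fever": 1.0, "illness": 0.8, "heat": 0.5, "river": 0.1},
}

rng = random.Random(0)
fever_scores = list(candidates["fever"].values())
low_eps = exponential_mechanism_probs(fever_scores, epsilon=0.1)
high_eps = exponential_mechanism_probs(fever_scores, epsilon=20.0)
print([round(p, 3) for p in low_eps])   # nearly uniform: stronger privacy, noisier text
print([round(p, 3) for p in high_eps])  # mass concentrates on the highest-scored token
print(perturb_token("fever", candidates, epsilon=1.0, rng=rng))
```

The two printed distributions show the trade-off that `epsilon` controls: at low values every candidate is almost equally likely, while at high values the original token is usually kept.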

We then send the formatted perturbed docs to the remote LLM to get perturbed responses:

In [33]:
template = Template(arc_e_template_r)
docs_to_infer = [tokenizer.apply_chat_template([{'role':'system', 'content': 'you are a helpful assistant'}, {'role':'user', 'content': template.render(i)}], add_generation_prompt=True, tokenize=False) for i in tmp_ds]
p_results = api.inference(docs_to_infer, {
    'stop': ['<|im_end|>', '<end>', '<end>\n', '<end>\n\n', '.\n\n\n\n\n', '<|end_of_text|>', '>\n\n\n'],
    'temperature': 0.01,
    'max_tokens': 256
})

In [37]:
for i, r in zip(ds, p_results):
    i['p_rationale'] = r

for i,q,c in zip(ds, q_doc, c_doc):
    i['p_question'] = q
    i['p_choice'] = c

### Organize Training Data

As described in the original paper, the encoder and decoder are trained within a single model.
We can organize the training data using the templates below:

In [47]:
train_data = []

encoder_prompt = Template("""Disrupt the main words in the original text so that it becomes difficult to recognize, but at the same time, try to maintain the original meaning as much as possible. Use <end> to end your reply.
Origin Doc: 
Question:{{question}}
Choices:{{choices.text}}

Perturbed Doc:
""")

encoder_out = Template("""
Question:{{p_question}}
Choices:{{p_choice}}<end>
""")

decoder_in = Template("""This is a perturbed question and its corresponding answer(rationale). And following is the original question. Try to recover the correct rationale from docs provided.

Perturbed doc and rationale:
Question:{{p_question}}
Choices:{{p_choice}}
Rationale:{{p_rationale}}

Original Doc:
Question:{{question}}
Choices:{{choices.text}}

Recover Rationale:
""")

decoder_out = Template("""{{rationale}}<end>""")


for i in ds:
    a = {}
    a['encoder_in'] = encoder_prompt.render(i)
    a['encoder_out'] = encoder_out.render(i)
    a['decoder_in'] = decoder_in.render(i)
    a['decoder_out'] = decoder_out.render(i)
    train_data.append(a)

import torch
torch.save(train_data, './slm_ed_train_data.pkl')
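
Before training, a quick check can catch records whose templates rendered with missing fields. The sketch below is a hypothetical helper (`validate_record` is not part of FATE-LLM) that verifies each record has the four expected keys and that both targets end with the `<end>` marker used in the templates above. The two sample records are made up for the demonstration.

```python
REQUIRED_KEYS = ("encoder_in", "encoder_out", "decoder_in", "decoder_out")

def validate_record(record, end_marker="<end>"):
    """Return a list of problems found in one training record (empty list means OK)."""
    problems = []
    for key in REQUIRED_KEYS:
        value = record.get(key)
        if not isinstance(value, str) or not value.strip():
            problems.append(f"{key} is missing or empty")
    for key in ("encoder_out", "decoder_out"):
        value = record.get(key)
        # Only check the marker on targets that exist, to avoid duplicate reports
        if isinstance(value, str) and value.strip() and not value.rstrip().endswith(end_marker):
            problems.append(f"{key} does not end with {end_marker}")
    return problems

good = {
    "encoder_in": "Origin Doc: what causes a fever?",
    "encoder_out": "Question: stuff heat cause?<end>",
    "decoder_in": "Perturbed doc and rationale: sample text",
    "decoder_out": "A bacterial infection triggers a fever.<end>",
}
bad = {"encoder_in": "Origin Doc: what causes a fever?", "encoder_out": "no marker here"}

print(validate_record(good))
print(validate_record(bad))
```

Running it over `train_data` and printing only the non-empty results gives a short list of records to inspect or drop.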

## Train Script

With the data prepared, the key step is done. We can now train an SLM on it using the dataset and trainer classes below. Here we use Qwen1.5-0.5B as the example.

In [51]:
from transformers import AutoModelForCausalLM, AutoTokenizer

In [52]:
model = AutoModelForCausalLM.from_pretrained('/data/cephfs/llm/models/Qwen1.5-0.5B/').half().cuda()

In [75]:
from torch.utils.data import Dataset

class EDDataset(Dataset):

    def __init__(self, tokenizer, train_data, max_input_length=64, max_target_length=64):
        self.tokenizer = tokenizer
        self.dataset = train_data
        self.max_input_length = max_input_length
        self.max_target_length = max_target_length
        self.max_seq_length = max_input_length + max_target_length + 1

    def get_str_item(self, i) -> dict:

        data_item = self.dataset[i]
        ret_dict = {
            'encoder':{
                'input': data_item['encoder_in'],
                'output': data_item['encoder_out']
            },
            'decoder':{
                'input': data_item['decoder_in'],
                'output': data_item['decoder_out']
            }
        }
        return ret_dict

    def _process_item(self, data_item):

        a_ids = self.tokenizer.encode(text=data_item['input'], add_special_tokens=True, truncation=True,
                                      max_length=self.max_input_length)
        b_ids = self.tokenizer.encode(text=data_item['output'], add_special_tokens=False, truncation=True,
                                      max_length=self.max_target_length)
        context_length = len(a_ids)
        input_ids = a_ids + b_ids + [self.tokenizer.eos_token_id]
        # Use -100 (ignored by the loss) for prompt positions so only the target and EOS are learned
        labels = [-100] * context_length + b_ids + [self.tokenizer.eos_token_id]
        pad_len = self.max_seq_length - len(input_ids)
        input_ids = input_ids + [self.tokenizer.pad_token_id] * pad_len
        # Mask padding by position; masking by value would also hide EOS when pad and EOS share an id
        labels = labels + [-100] * pad_len

        assert len(input_ids) == len(labels), f"length mismatch: {len(input_ids)} vs {len(labels)}"

        return {
            "input_ids": input_ids,
            "labels": labels
        }

    def get_tokenized_item(self, i) -> dict:   

        str_item = self.get_str_item(i)
        ret_dict = {
            'encoder': self._process_item(str_item['encoder']),
            'decoder': self._process_item(str_item['decoder'])
        }
        return ret_dict

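    def __len__(self) -> int:
        # Map-style datasets need a length so samplers and trainers know the dataset size
        return len(self.dataset)

    def __getitem__(self, i) -> dict:
        # Return the tokenized encoder and decoder examples for record i
        return self.get_tokenized_item(i)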
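
To make the layout that `_process_item` is meant to produce easier to see, here is a standalone sketch on toy token ids, with no tokenizer involved. `build_example` and the id values are made up for illustration; `-100` is the default `ignore_index` of PyTorch's cross-entropy loss, so positions carrying it do not contribute to training.

```python
IGNORE_INDEX = -100  # label value skipped by PyTorch's cross-entropy loss by default

def build_example(prompt_ids, target_ids, eos_id, pad_id, max_seq_length):
    """Concatenate prompt and target, then pad; only target and EOS positions keep labels."""
    input_ids = prompt_ids + target_ids + [eos_id]
    labels = [IGNORE_INDEX] * len(prompt_ids) + target_ids + [eos_id]
    pad_len = max_seq_length - len(input_ids)
    if pad_len < 0:
        raise ValueError("sequence is longer than max_seq_length")
    return {
        "input_ids": input_ids + [pad_id] * pad_len,
        "labels": labels + [IGNORE_INDEX] * pad_len,
    }

example = build_example([11, 12, 13], [21, 22], eos_id=2, pad_id=0, max_seq_length=8)
print(example["input_ids"])  # [11, 12, 13, 21, 22, 2, 0, 0]
print(example["labels"])     # [-100, -100, -100, 21, 22, 2, -100, -100]
```

The prompt and padding positions are masked, so the loss is computed only on the target tokens and the end-of-sequence token.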

In [76]:
train_ds = EDDataset(AutoTokenizer.from_pretrained('/data/cephfs/llm/models/Qwen1.5-0.5B/'), train_data)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [None]:
print(train_ds.get_str_item(0))
print(train_ds[0])

In [82]:
from fate_llm.algo.pdss.slm_encoder_decoder_trainer import EncoderDecoderPrefixTrainer, EDPrefixDataCollator

After completing the setup, you can use EncoderDecoderPrefixTrainer, EDPrefixDataCollator, and the training dataset to train an SLM encoder-decoder model, following the usual Hugging Face training workflow.
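
As a rough sketch of that workflow, the code below shows the usual wiring pattern: a collator built from the tokenizer, and a trainer that receives the model, arguments, dataset, and collator as keyword arguments. The constructor signatures of `EncoderDecoderPrefixTrainer` and `EDPrefixDataCollator` are not shown in this tutorial, so the sketch assumes they accept the same keyword arguments as the standard Hugging Face `Trainer` and data collators. `build_trainer` is a hypothetical helper, and stub classes stand in for the real ones so that the example runs without a GPU.

```python
def build_trainer(trainer_cls, collator_cls, model, tokenizer, train_dataset, training_args):
    """Assemble a trainer from its parts using Hugging Face Trainer-style keyword arguments."""
    return trainer_cls(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        data_collator=collator_cls(tokenizer),
    )

# Stand-ins so the sketch runs without the real libraries
class StubCollator:
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

class StubTrainer:
    def __init__(self, model, args, train_dataset, data_collator):
        self.model, self.args = model, args
        self.train_dataset, self.data_collator = train_dataset, data_collator

    def train(self):
        return f"would train on {len(self.train_dataset)} examples"

trainer = build_trainer(StubTrainer, StubCollator, model="model", tokenizer="tok",
                        train_dataset=[1, 2, 3], training_args={"num_train_epochs": 1})
print(trainer.train())

# With the real classes (assumed, unverified signatures; "./out" is a placeholder path):
# trainer = build_trainer(EncoderDecoderPrefixTrainer, EDPrefixDataCollator,
#                         model, tokenizer, train_ds, TrainingArguments(output_dir="./out"))
# trainer.train()
```

Check the FATE-LLM source for any extra constructor arguments before replacing the stubs with the real classes.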