# <center>使用LoRA低资源指令微调LLaMA(中文)</center>

## 载入模型

In [None]:
import os
import warnings

# Expose only GPU 0 to CUDA and silence library warnings for a cleaner notebook.
os.environ.update(CUDA_VISIBLE_DEVICES='0')
warnings.filterwarnings(action='ignore')

In [None]:
import torch
from transformers import LlamaForCausalLM, LlamaTokenizer, GenerationConfig

In [None]:
# Load the LLaMA-13B checkpoint with 8-bit weights (fp16 for non-quantized
# tensors); device_map='auto' lets transformers place the layers on devices.
model = LlamaForCausalLM.from_pretrained(
    '../llama-13b',
    torch_dtype=torch.float16,
    load_in_8bit=True,
    device_map='auto',
)

In [None]:
# Tokenizer from the same checkpoint directory as the model.
tokenizer = LlamaTokenizer.from_pretrained(pretrained_model_name_or_path='../llama-13b')

In [None]:
# copied from fastchat/train.py
def smart_tokenizer_and_embedding_resize(special_tokens_dict, tokenizer, model):
    """Resize tokenizer and embedding.

    Adds ``special_tokens_dict`` to the tokenizer and resizes the model's token
    embeddings to the new vocabulary size. Every newly added row of both the
    input and the output embedding matrix is set to the mean of the rows that
    existed before.

    Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
    """
    added = tokenizer.add_special_tokens(special_tokens_dict)
    model.resize_token_embeddings(len(tokenizer))
    if added <= 0:
        return
    for layer in (model.get_input_embeddings(), model.get_output_embeddings()):
        table = layer.weight.data
        # The last `added` rows are the freshly appended ones; the mean is taken
        # over the original rows only.
        table[-added:] = table[:-added].mean(dim=0, keepdim=True)

# 2023.04.06 add pad token and resize embedding
# The tokenizer gets a dedicated '[PAD]' token; the embedding matrices grow to match.
smart_tokenizer_and_embedding_resize(
    special_tokens_dict=dict(pad_token='[PAD]'),
    tokenizer=tokenizer,
    model=model,
)
# Fill in only the special tokens the tokenizer is missing (None or empty string).
# Unconditionally forcing bos_token='</s>' (as before) makes the tokenizer
# prepend the end-of-sequence id to every encoded text, and unk_token='</s>'
# maps unknown pieces to EOS.
default_special_tokens = {
    "eos_token": "</s>",
    "bos_token": "<s>",
    "unk_token": "<unk>",
}
missing_special_tokens = {
    name: value
    for name, value in default_special_tokens.items()
    if not getattr(tokenizer, name)
}
if missing_special_tokens:
    tokenizer.add_special_tokens(missing_special_tokens)

In [None]:
# Show the tokenizer's vocabulary size, which now includes the added '[PAD]' token.
len(tokenizer)

## 处理对话类数据(使用非对话类数据时此节不用执行)

In [None]:
import pandas as pd
from datasets import Dataset
# Load the dialogue dataset and keep only its first 10,000 rows.
df = pd.read_json('../datasets/sg_90k_part1_html_cleaned.json').head(10000)
df

In [None]:
# 2023.04.13 用于对话的有监督训练数据的处理
import copy
from dataclasses import dataclass

class Conversation:
    '''
    Turn multi-turn dialogues into supervised training examples.

    A dialogue is a list of ``{'from': 'human' | 'gpt', 'value': str}`` turns.
    ``generate_conversation_prompt`` renders it as::

        <system prompt></s>Human: ...</s>Assistant: ...</s>Human: ...</s>...

    ``preprocess`` tokenizes that text and builds labels in which only the
    assistant replies (including their closing ``</s>``) contribute to the
    loss. The system prompt, every human turn and all padding get
    ``ignore_index``.

    Args:
        tokenizer: tokenizer used for the separator id and for the dialogues.
        max_length: sequences are padded/truncated to this many tokens.
    '''

    # Label value skipped by the cross-entropy loss.
    ignore_index = -100

    def __init__(self, tokenizer, max_length=512):
        # System prompt; it already ends with the turn separator.
        self.prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.</s>"
        self.sep = '</s>'  # appended after every turn
        self.r1 = 'Human: '  # prefix of human turns
        self.r2 = 'Assistant: '  # prefix of assistant ('gpt') turns
        self.max_length = max_length
        self.tokenizer = tokenizer
        self.sep_id = self.get_sep_id()

    def get_sep_id(self):
        '''
        Return the token id of the turn separator ``self.sep``.

        The last id of the encoding is used, so this works whether or not
        the tokenizer prepends a BOS token.
        '''
        return self.tokenizer(self.sep).input_ids[-1]

    def generate_conversation_prompt(self, example):
        '''
        Concatenate one dialogue (list of turns) into a single training string.

        Returns '' when the dialogue does not start with a human turn or
        contains a speaker other than 'human' / 'gpt'.
        '''
        conversation = self.prompt
        for idx, turn in enumerate(example):
            speaker = turn['from'].lower()
            if speaker == 'human':
                conversation += self.r1 + turn['value'] + self.sep
            elif speaker == 'gpt' and idx > 0:
                conversation += self.r2 + turn['value'] + self.sep
            else:
                return ''
        return conversation

    def preprocess(self, examples):
        '''
        Tokenize a batch of rendered dialogues (``examples['conversations']``).

        Returns a dict with ``input_ids``, ``labels`` and ``attention_mask``,
        each a list of 1-D tensors of length ``max_length``.
        '''
        input_ids, labels, attention_mask = [], [], []
        for text in examples['conversations']:
            # Use the tokenizer this object was built with, not a global one.
            enc = self.tokenizer(
                text,
                return_tensors='pt',
                max_length=self.max_length,
                padding='max_length',
                truncation=True,
            )
            ids = enc.input_ids[0]
            mask = enc.attention_mask[0]
            input_ids.append(ids)
            attention_mask.append(mask)
            labels.append(self._build_labels(ids, mask))
        return dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=attention_mask,
        )

    def _build_labels(self, ids, mask):
        '''Copy ``ids`` into labels, masking padding, the system prompt and human turns.'''
        label = ids.clone()
        label[mask == 0] = self.ignore_index
        sep_idx = torch.where(ids == self.sep_id)[0].tolist()
        # If the tokenizer's BOS token is the separator itself (bos_token='</s>'),
        # position 0 is not a turn boundary. Keeping it would shift the
        # alternation below and mask the assistant replies instead of the human turns.
        if (
            sep_idx
            and sep_idx[0] == 0
            and getattr(self.tokenizer, 'bos_token_id', None) == self.sep_id
        ):
            sep_idx = sep_idx[1:]
        # Too few turn boundaries to locate a reply: only padding is masked.
        if len(sep_idx) < 3:
            return label
        # sep_idx[0] ends the system prompt and sep_idx[1] the first human turn.
        # After that, separators alternate: sep_idx[2], [4], ... end assistant
        # replies (kept), and sep_idx[3], [5], ... end human turns (masked).
        label[:sep_idx[1] + 1] = self.ignore_index
        for reply_end, human_end in zip(sep_idx[2::2], sep_idx[3::2]):
            label[reply_end + 1:human_end + 1] = self.ignore_index
        return label
                

In [None]:
conv = Conversation(tokenizer, max_length=512)
# Render every dialogue (a list of turns) into one training string.
df['conversations'] = df['conversations'].apply(conv.generate_conversation_prompt)
df.loc[98, 'conversations']

In [None]:
# Convert the pandas frame into a Hugging Face dataset.
data = Dataset.from_pandas(df=df)

In [None]:
# Reproducible, shuffled 90% / 10% train/test split.
data = data.train_test_split(seed=42, shuffle=True, train_size=0.9)

In [None]:
# Tokenize the rendered dialogues 1000 at a time, adding input_ids / labels / attention_mask.
data = data.map(conv.preprocess, batch_size=1000, batched=True)
data

In [None]:
# The held-out 'test' split serves as the evaluation set during training.
train_data, val_data = data['train'], data['test']

## 处理数据(使用对话类数据时此节不用执行)

In [None]:
import pandas as pd
from datasets import Dataset
from datasets import load_dataset

In [None]:
# 读取alpaca类数据集
# Load the Alpaca-style records and hold out 10% of them (reproducibly) for evaluation.
df = pd.read_json('../datasets/goat_50k.json')
data = Dataset.from_pandas(df).train_test_split(seed=42, shuffle=True, train_size=0.9)
data

In [None]:
# 2023.04.04 用于有监督训练数据的处理
def generate_alpaca_prompt(example):
    '''
    Build the supervised-training prompt for one Alpaca-style record.

    ``example`` must provide 'instruction', 'input' and 'output'. The
    'input' text is appended on its own line only when it is non-empty.

    Returns ``{'example': (full_text, source)}``. ``source`` is the prompt up
    to and including "### Assistant: ", and ``full_text`` is ``source``
    followed by the reference output.
    '''
    system = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."
    if example['input']:
        request = f"{example['instruction']}\n{example['input']}"
    else:
        request = f"{example['instruction']}"
    source = f"{system}\n\n### Human: {request}\n\n### Assistant: "
    target = f'{example["output"]}'
    return dict(example=(source + target, source))

In [None]:
# Add an 'example' column holding the (full_text, source) prompt pair of each record.
data = data.map(generate_alpaca_prompt)

In [None]:
import copy
ignore_index = -100


def preprocess(examples, max_length=512, tok=None):
    '''
    Tokenize Alpaca-style examples and build labels that mask the prompt.

    Each entry of ``examples['example']`` is a ``(full_text, source)`` pair.
    Both strings are tokenized together and padded/truncated to ``max_length``.
    Labels are the full text's input ids, except that the first ``len(source)``
    tokens (the prompt) and all padding positions are set to ``ignore_index``.
    The loss therefore only covers the response.

    Args:
        examples: batch dict passed by ``Dataset.map(batched=True)``.
        max_length: pad/truncate length (defaults to the previous fixed 512).
        tok: tokenizer to use; defaults to the notebook-global ``tokenizer``.

    Returns:
        dict with ``input_ids``, ``attention_mask`` and ``labels`` (lists of 1-D tensors).
    '''
    tok = tokenizer if tok is None else tok
    input_ids, attention_mask, labels = [], [], []
    for example in examples['example']:
        enc = tok(
            example,
            return_tensors='pt',
            max_length=max_length,
            padding='max_length',
            truncation=True,
        )
        full_ids, source_ids = enc.input_ids[0], enc.input_ids[1]
        mask = enc.attention_mask[0]
        # NOTE(review): masking the first `source_len` positions assumes right
        # padding, so that the prompt tokens come first — confirm padding_side.
        source_len = source_ids.ne(tok.pad_token_id).sum().item()
        label = full_ids.clone()
        label[:source_len] = ignore_index
        # Padding must never contribute to the loss.
        label[mask == 0] = ignore_index
        input_ids.append(full_ids)
        attention_mask.append(mask)
        labels.append(label)
    return dict(
        input_ids=input_ids,
        attention_mask=attention_mask,
        labels=labels,
    )

In [None]:
# Tokenize 1000 examples at a time, adding input_ids / attention_mask / labels.
data = data.map(preprocess, batch_size=1000, batched=True)
data

In [None]:
train_data, val_data = data['train'], data['test']
# Optional: have the datasets return torch tensors instead of Python lists.
# train_data.set_format(type='torch', columns=['input_ids', 'labels', 'attention_mask'])
# val_data.set_format(type='torch', columns=['input_ids', 'labels', 'attention_mask'])
# Inspect the (full_text, source) prompt pair of the first training example.
train_data[0]['example']

In [None]:
# Decode the first training example's token ids back to text to eyeball the tokenization.
tokenizer.batch_decode([train_data[0]['input_ids']])

## 训练模型

In [None]:
from transformers import TrainingArguments, Trainer, DataCollatorForLanguageModeling
from peft import prepare_model_for_int8_training, LoraConfig, get_peft_model, get_peft_model_state_dict

In [None]:
trainArgs = TrainingArguments(
    # Output location; nothing is pushed to the Hub.
    output_dir= '../ckps',
    do_train=True,
    push_to_hub=False,
    # Optimisation: 32 sequences per device x 4 accumulation steps per update.
    per_device_train_batch_size=32,
    gradient_accumulation_steps=4,
    learning_rate=3e-4,
    warmup_steps=100,
    num_train_epochs=5,
    fp16=True,
    # Evaluate every 100 steps, checkpoint every 1000, log every 10,
    # and reload the best checkpoint when training ends.
    evaluation_strategy="steps",
    eval_steps=100,
    save_strategy="steps",
    save_steps=1000,
    logging_steps=10,
    load_best_model_at_end=True,
)

In [None]:
# Prepare the 8-bit base model for training, then wrap it with rank-8 LoRA
# adapters on the q_proj / v_proj modules.
model = prepare_model_for_int8_training(model)
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],
    bias="none",
)
model = get_peft_model(model, lora_config)

In [None]:
from transformers import default_data_collator


def supervised_data_collator(features):
    '''
    Stack pre-padded examples while keeping their precomputed ``labels``.

    ``DataCollatorForLanguageModeling(mlm=False)`` overwrites ``labels`` with a
    copy of ``input_ids``. That discarded the prompt / human-turn masking built
    during preprocessing, so the model was trained on the prompts too. The
    examples are already padded to a fixed length, so plain stacking is enough.
    Padding positions are additionally excluded from the loss.
    '''
    batch = default_data_collator(features)
    batch['labels'] = batch['labels'].masked_fill(batch['attention_mask'] == 0, -100)
    return batch


trainer = Trainer(
    model=model,
    args=trainArgs,
    train_dataset=train_data,
    eval_dataset=val_data,
    data_collator=supervised_data_collator,
)
# The KV cache is only useful for generation; disable it while training.
model.config.use_cache = False
# Make model.state_dict() return only the PEFT (LoRA) weights, so checkpoints
# hold just the adapter.
# NOTE(review): newer peft versions already filter in save_pretrained; combined
# with this patch that may produce an empty adapter file — verify with the
# installed peft version.
old_state_dict = model.state_dict
model.state_dict = (
    lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
).__get__(model, type(model))

In [None]:
# Run LoRA fine-tuning with the Trainer configured in the previous cell.
trainer.train()

In [None]:
from peft import get_peft_model_state_dict
# Write the trained PEFT adapter to its own directory.
model.save_pretrained(save_directory='../ckps/GOAT_002')

## 测试模型

In [None]:
# Inference prompt: with an empty output, the full text equals the source
# prompt ending in "### Assistant: ".
text = generate_alpaca_prompt(
    dict(instruction="介绍一下中国的首都", input="", output="")
)['example'][0]
text

In [None]:
# Encode the prompt and move its ids to the first visible GPU.
inputs = tokenizer(text, return_tensors='pt')
input_ids = inputs.input_ids.to('cuda:0')

In [None]:
# Decoding settings: 4-beam search.
# NOTE(review): do_sample is not set, so temperature / top_p / top_k are not
# applied by transformers (they only affect sampling). Add do_sample=True if
# sampling was intended.
generation_config = GenerationConfig(
    num_beams=4,
    temperature=0.1,
    top_p=0.7,
    top_k=40,
)

In [None]:
# Generate up to 256 new tokens without tracking gradients; the strong
# repetition penalty discourages looping output.
model.eval()
with torch.no_grad():
    preds = model.generate(
        input_ids=input_ids,
        generation_config=generation_config,
        max_new_tokens=256,
        repetition_penalty=2.0,
    )

In [None]:
# Decode the generated sequences (special tokens kept) for inspection.
output = tokenizer.batch_decode(sequences=preds)
output