In [None]:
# System prompt for the RAG assistant. The text is Korean on purpose (it is part of the
# model input). `{search_result}` is the only placeholder and is filled with str.format.
# Fixes over the previous text: a stray jamo after the closing quote in rule 3 was
# removed, and the single-document citation example now reads [[ref1]] so that it
# matches "document 1" and the [[ref1]], [[ref5]] example in rule 5.
system_message = """당신은 검색 결과를 바탕으로 질문에 답변해야 합니다.

다음의 지시사항을 따르십시오.

1. 질문과 검색 결과를 바탕으로 답변하십시오.
2. 검색 결과에 없는 내용을 답변하려고 하지 마십시오.
3. 질문에 대한 답이 검색 결과에 없다면 검색 결과에는 "해당 질문 ~에 대한 내용이 없습니다."라고 답변하십시오.
4. 답변할 때 특정 문서를 참고하여 문장 또는 문단을 작성했다면 뒤에 출처는 이중 리스트로 해당 문서 번호를 남기십시오.
예를 들어서 특정 문장이나 문단을 1번 문서에서 인용했다면 뒤에 [[ref1]]이라고 기재하십시오.
5. 예를 들어서 특정 문장이나 문단을 1번 문서와 5번 문서에서 동시에 인용했다면 뒤에 [[ref1]], [[ref5]]이라고 기재하십시오.
6. 최대한 다수의 문서를 인용하여 답변하십시오.

검색 결과 :
----
{search_result}"""

In [None]:
!pip install torch==2.4.0 transformers==4.45.1 datasets==3.0.1 accelerate==0.34.2 trl==0.11.1 peft==0.13.0

In [None]:
import torch

from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from peft import LoraConfig
from trl import SFTConfig, SFTTrainer

In [None]:
# Load the "train" split of the iamjoon/klue-mrc-ko-rag-dataset dataset via datasets.load_dataset.
dataset = load_dataset("iamjoon/klue-mrc-ko-rag-dataset", split = "train")

In [None]:
# Print the dataset object for a quick look at what was loaded.
print(dataset)

In [None]:
# Report how many samples carry each value of the "type" column.
# The column is read once (not once per type), and the types are visited in sorted
# order so the printed order is the same on every kernel run.
print("원본 데이터의 type 분포 :")
type_column = dataset["type"]
for type_name in sorted(set(type_column), key = str):
    print(f"{type_name} : {type_column.count(type_name)}")

In [None]:
# Stratified split by "type": within each type, the first `ratio` share of its rows
# (in dataset order) goes to train and the remainder to test.
# Both lists hold row indices into `dataset`, not the rows themselves.
ratio = 0.8
train_data = []
test_data = []

# Read the column once instead of looking up dataset[i]["type"] for every index and type.
type_column = dataset["type"]

# Iterate the types in sorted order: a bare set of strings has no stable order across
# kernel restarts, which would make index-based lookups into the split non-reproducible.
for type_name in sorted(set(type_column), key = str):
    curr_type = [idx for idx, value in enumerate(type_column) if value == type_name]

    # Number of rows of this type that go to the training split.
    train_size = int(len(curr_type) * ratio)

    train_data.extend(curr_type[:train_size])
    test_data.extend(curr_type[train_size:])

In [None]:
# Number of row indices assigned to the train and test splits, respectively.
print(len(train_data))
print(len(test_data))

In [None]:
# Print the dataset object again before formatting the rows.
print(dataset)

In [None]:
def format_data(sample):
    """Turn one dataset row into a three-turn chat example.

    Args:
        sample: Mapping that provides "search_result" (an iterable of document
            texts), "question" and "answer".

    Returns:
        A dict with a single "messages" key holding the system, user and assistant
        turns. The system turn is `system_message` with the numbered documents
        substituted for its `search_result` placeholder.
    """
    # Number the documents from 1 and separate them with a dashed line.
    numbered_docs = []
    for doc_number, doc_text in enumerate(sample["search_result"], start = 1):
        numbered_docs.append(f"문서 {doc_number} : {doc_text}")
    search_result = "\n-----\n".join(numbered_docs)

    system_turn = {"role" : "system", "content" : system_message.format(search_result = search_result)}
    user_turn = {"role" : "user", "content" : sample["question"]}
    assistant_turn = {"role" : "assistant", "content" : sample["answer"]}

    return {"messages" : [system_turn, user_turn, assistant_turn]}

In [None]:
# Build the chat-formatted examples of each split by looking up every selected row index.
train_dataset = [format_data(dataset[i]) for i in train_data]
test_dataset = [format_data(dataset[i]) for i in test_data]

In [None]:
# Check the container type of both splits at this point.
print(type(train_dataset))
print(type(test_dataset))

In [None]:
# Wrap both splits with Dataset.from_list, rebinding the same names.
train_dataset = Dataset.from_list(train_dataset)
test_dataset = Dataset.from_list(test_dataset)

# Print the types again to confirm the conversion.
print(type(train_dataset))
print(type(test_dataset))

In [None]:
# Checkpoint used for both the model and its tokenizer.
model_name = "Qwen/Qwen2-7B-Instruct"

# Load the causal LM in bfloat16 and let device_map = "auto" decide weight placement.
model = AutoModelForCausalLM.from_pretrained(model_name,
                                             device_map = "auto",
                                             torch_dtype = torch.bfloat16)

# Tokenizer from the same checkpoint; the collator and the evaluation cells use it.
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [None]:
# LoRA adapter configuration that is handed to the SFT trainer.
peft_config = LoraConfig(
    lora_alpha = 32,                        # LoRA scaling factor
    lora_dropout = 0.1,                     # dropout applied on the LoRA path
    r = 8,                                  # rank of the low-rank update matrices
    bias = "none",                          # no bias terms are trained
    target_modules = ["q_proj", "v_proj"],  # adapters are attached to modules with these names
    task_type = "CAUSAL_LM"
)

In [None]:
# Training hyper-parameters for supervised fine-tuning.
args = SFTConfig(
    output_dir = "qwen2-7b-rag-ko",                    # checkpoints are written under this directory
    num_train_epochs = 3,
    per_device_train_batch_size = 2,
    gradient_accumulation_steps = 2,                   # 2 x 2 = 4 examples per device per optimizer step
    gradient_checkpointing = True,
    optim = "adamw_torch_fused",
    logging_steps = 10,
    save_strategy = "steps",
    save_steps = 50,
    bf16 = True,
    learning_rate = 1e-4,
    max_grad_norm = 0.3,
    warmup_ratio = 0.03,
    # NOTE(review): the plain "constant" schedule appears to ignore warmup_ratio;
    # "constant_with_warmup" would honour it - confirm which one is intended.
    lr_scheduler_type = "constant",
    push_to_hub = False,
    # The batches are built entirely by the custom collator, so the raw "messages"
    # column must be kept and the trainer's own dataset preparation skipped.
    remove_unused_columns = False,
    dataset_kwargs = {"skip_prepare_dataset" : True},
    # Use the string "none" to disable experiment trackers: passing None falls back
    # to reporting to every installed integration in this transformers version.
    report_to = "none"
)

In [None]:
# Upper bound on tokens per training example: the tokenizer call in collate_fn truncates
# anything longer. Lower this value if GPU memory is tight.
max_seq_length = 8192


def _build_assistant_labels(input_ids, im_start_token, im_end_token, assistant_token):
    """Return labels that keep only assistant-turn tokens; everything else is -100.

    For every `im_start_token` that is immediately followed by `assistant_token`, the
    tokens after that header up to and including the next `im_end_token` keep their
    ids. If no closing marker is found, the tokens scanned so far stay labelled.
    """
    labels = [-100] * len(input_ids)

    i = 0
    while i <= len(input_ids) - len(im_start_token):
        if input_ids[i : i + len(im_start_token)] == im_start_token:

            assistant_pos = i + len(im_start_token)
            header_fits = assistant_pos <= len(input_ids) - len(assistant_token)
            if header_fits and (input_ids[assistant_pos : assistant_pos + len(assistant_token)] == assistant_token):

                current_pos = assistant_pos + len(assistant_token)
                while current_pos <= len(input_ids) - len(im_end_token):
                    if input_ids[current_pos : current_pos + len(im_end_token)] == im_end_token:
                        # The end marker itself is supervised so the model learns to stop.
                        for j in range(len(im_end_token)):
                            labels[current_pos + j] = input_ids[current_pos + j]
                        break

                    labels[current_pos] = input_ids[current_pos]
                    current_pos += 1

                # Resume scanning after the span that was just handled.
                i = current_pos

        i += 1

    return labels


def collate_fn(batch):
    """Tokenize a batch of chat examples and build assistant-only training labels.

    Args:
        batch: Iterable of examples, each with a "messages" list whose items
            provide "role" and "content".

    Returns:
        Dict with "input_ids", "attention_mask" and "labels" tensors. Every row is
        right-padded to the longest example in the batch: input ids with
        `tokenizer.pad_token_id`, the attention mask with 0 and the labels with -100.
        Labels are -100 everywhere except on assistant turns.
    """
    # The marker ids do not depend on the example, so encode them once per batch.
    im_start_token = tokenizer.encode("<|im_start|>", add_special_tokens = False)
    im_end_token = tokenizer.encode("<|im_end|>", add_special_tokens = False)
    assistant_token = tokenizer.encode("assistant", add_special_tokens = False)

    new_batch = {
        "input_ids" : [],
        "attention_mask" : [],
        "labels" : []
    }

    for example in batch:
        # Keep only the two keys the chat template needs.
        clean_messages = [
            {"role" : msg["role"], "content" : msg["content"]}
            for msg in example["messages"]
        ]

        text = tokenizer.apply_chat_template(
            clean_messages,
            tokenize = False,
            add_generation_prompt = False
        ).strip()

        tokenized = tokenizer(
            text,
            truncation = True,
            max_length = max_seq_length,
            padding = False,
            return_tensors = None
        )

        input_ids = tokenized["input_ids"]
        attention_mask = tokenized["attention_mask"]
        labels = _build_assistant_labels(input_ids, im_start_token, im_end_token, assistant_token)

        new_batch["input_ids"].append(input_ids)
        new_batch["attention_mask"].append(attention_mask)
        new_batch["labels"].append(labels)

    # Right-pad every row to the longest sequence in this batch.
    max_length = max(len(ids) for ids in new_batch["input_ids"])
    for i in range(len(new_batch["input_ids"])):
        pad_len = max_length - len(new_batch["input_ids"][i])
        new_batch["input_ids"][i].extend([tokenizer.pad_token_id] * pad_len)
        new_batch["attention_mask"][i].extend([0] * pad_len)
        new_batch["labels"][i].extend([-100] * pad_len)

    for k, v in new_batch.items():
        new_batch[k] = torch.tensor(v)

    return new_batch

In [None]:
# Build the SFT trainer from the model, the training arguments, the LoRA config and the
# custom collator.
# NOTE(review): `max_seq_length` has to exist in the kernel namespace before this cell
# runs; the same name is also read inside collate_fn.
trainer = SFTTrainer(
    model = model,
    args = args,
    max_seq_length = max_seq_length,
    train_dataset = train_dataset,
    data_collator = collate_fn,
    peft_config = peft_config
)

# Run fine-tuning.
trainer.train()

# Save the trained model through the trainer (presumably into args.output_dir - confirm).
trainer.save_model()

In [None]:
# Split every test conversation into a generation prompt (everything up to and including
# the assistant header) and a reference answer (the text that follows that header).
prompt_list = []
label_list = []

assistant_header = "<|im_start|>assistant"

for messages in test_dataset["messages"]:
    text = tokenizer.apply_chat_template(
        messages, tokenize = False, add_generation_prompt = False
    )
    # Split once and reuse the pieces. No variable is named `input` here, so the
    # built-in input() stays usable for the rest of the notebook.
    segments = text.split(assistant_header)

    prompt_list.append(segments[0] + assistant_header)
    # The reference keeps whatever the chat template renders after the header.
    label_list.append(segments[1])

In [None]:
def test_inference(pipe, prompt):
    outputs = pipe(prompt, max_new_tokens = 1024, eos_token_id = eos_token, do_sample = False)
    return outputs[0]["generated_text"][len(prompt) : ].strip()

In [None]:
# Load the original checkpoint again as the baseline for comparison.
base_model_name = "Qwen/Qwen2-7B-Instruct"

base_model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    device_map = "auto",
    torch_dtype = torch.bfloat16
)

# Text-generation pipeline around the baseline model, reusing the tokenizer loaded earlier.
base_pipe = pipeline("text-generation", model = base_model, tokenizer = tokenizer)
# First token id produced for "<|im_end|>"; test_inference reads it as the end-of-sequence id.
eos_token = tokenizer("<|im_end|>", add_special_tokens = False)["input_ids"][0]

In [None]:
# Take one test example (index 42) and print the baseline model's output next to the reference.
base_prompt = prompt_list[42]
base_label = label_list[42]

print(f"모델의 예측 : \n {test_inference(base_pipe, base_prompt)}")
print(f"정답 : \n {base_label}")

In [None]:
# Path of the fine-tuned checkpoint under the training output directory.
# NOTE(review): the step number 285 is hard-coded and depends on the training run -
# confirm this directory exists before running the cell.
peft_model_name = "qwen2-7b-rag-ko/checkpoint-285"

peft_model = AutoModelForCausalLM.from_pretrained(
    peft_model_name,
    device_map = "auto",
    torch_dtype = torch.bfloat16
)

# Text-generation pipeline around the fine-tuned model, reusing the same tokenizer.
peft_pipe = pipeline("text-generation", model = peft_model, tokenizer = tokenizer)

In [None]:
# Same test example (index 42) as in the baseline cell, now answered by the fine-tuned model.
peft_prompt = prompt_list[42]
peft_label = label_list[42]

print(f"모델의 예측 : \n {test_inference(peft_pipe, peft_prompt)}")
print(f"정답 : \n {peft_label}")