In [None]:
import argparse
import json
import os

import torch
import tqdm
from peft import PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM

from src.data import CustomDataset

# Identifier handed to AutoModelForCausalLM.from_pretrained (and used as the
# tokenizer fallback); "MODEL_ID" is a placeholder to replace before running.
MODEL_ID = "MODEL_ID"
# Device the model is loaded onto and each prompt is moved to before generation.
DEVICE = "cuda:0"
# Adapter checkpoint passed to PeftModel.from_pretrained; empty placeholder to fill in.
adapter_checkpoint_path = ""
# Optional tokenizer name/path; when left as None the loading cell uses MODEL_ID instead.
tokenizer = None

In [None]:
# From the test dataset, pick the longest dialogue, the shortest one, and a mid-length one.
import json
import numpy as np

# Character length of every dialogue processed by make_data(); each call appends to it.
utterance_length = []

def make_data(data_path):
    """Load a dialogue JSON file and return all dialogues joined into one string.

    The file must contain a list of examples shaped like
    ``{"input": {"conversation": [{"utterance": str}, ...]}}``.

    Side effect: the character length of each dialogue is appended to the
    module-level ``utterance_length`` list (it is NOT reset here, so callers
    that want a fresh measurement must clear it first).

    Args:
        data_path: Path of the JSON file to read.

    Returns:
        The per-example dialogue texts joined with a single space.
    """
    global utterance_length
    # Explicit UTF-8: the data is Korean text, and the platform default
    # encoding is not guaranteed to decode it correctly.
    with open(data_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    def make_chat(inp):
        # Concatenate the utterances of one conversation without a separator.
        return "".join(turn['utterance'] for turn in inp['conversation'])

    text_list = []
    for example in data:
        total_chat = make_chat(example['input'])
        text_list.append(total_chat)
        utterance_length.append(len(total_chat))

    return ' '.join(text_list)

In [None]:
# Start from an empty list so re-running this cell does not append the same lengths twice.
utterance_length = []
# NOTE(review): absolute, machine-specific path. The dataset cell loads
# "resource/data/일상대화요약_test.json" instead; confirm both files hold the same
# examples in the same order, otherwise the indices derived from
# utterance_length do not line up with the dataset.
TEST_DATA_PATH = "/mnt/c/Users/hwyew/Downloads/korean_dialogue/korean_dialog/resource/data/test.json"
# Bind the return value: as a bare last expression the notebook would echo the
# whole concatenated corpus into the cell output.
test_text = make_data(TEST_DATA_PATH)

In [None]:
# Position of the longest dialogue (first one on ties).
max(enumerate(utterance_length), key=lambda pair: pair[1])[0]

In [None]:
# Position of the shortest dialogue (first one on ties).
min(enumerate(utterance_length), key=lambda pair: pair[1])[0]

In [None]:
# Position of the dialogue whose length is closest to the mean length.
# utterance_length is a plain list (no .mean()), and the rounded mean is
# generally not a member of the list, so search for the nearest value instead
# of calling .index() on it. argmin returns the first position on ties.
int(np.abs(np.asarray(utterance_length) - np.mean(utterance_length)).argmin())

In [None]:
# Positions of the three probe dialogues, in order: shortest, closest to the
# mean length, longest. The middle entry uses a nearest-value search because
# utterance_length is a plain list (no .mean()) and the exact mean is usually
# not one of its elements, so .index() on it would raise.
test_index = [
    utterance_length.index(min(utterance_length)),
    int(np.abs(np.asarray(utterance_length) - np.mean(utterance_length)).argmin()),
    utterance_length.index(max(utterance_length)),
]

In [None]:
# Load the base causal LM in bfloat16 directly onto DEVICE.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map=DEVICE,
    low_cpu_mem_usage=True,
)

# Attach the PEFT adapter, merge it into the base model and drop the wrapper,
# then make sure the merged model is in bfloat16 and in inference mode.
model = PeftModel.from_pretrained(model, adapter_checkpoint_path)
model = model.merge_and_unload()
model.to(dtype=torch.bfloat16)
model.eval()

# Fall back to the base model's tokenizer when none was configured.
if tokenizer is None:
    tokenizer = MODEL_ID

tokenizer = AutoTokenizer.from_pretrained(tokenizer)
# Reuse EOS as the padding token.
tokenizer.pad_token = tokenizer.eos_token
# Token ids that end generation: the tokenizer's EOS plus the "<|eot_id|>" token.
terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>"),
]

dataset = CustomDataset("resource/data/일상대화요약_test.json", tokenizer)
# Keep only the probe examples, in test_index order. Subset needs nothing but
# integer indexing and len(), which is how the generation loop uses the dataset.
# NOTE(review): CustomDataset lives in src.data (not shown); the previous
# `.loc[...]` is a pandas-style accessor. Restore it only if CustomDataset
# really provides `.loc`.
dataset = torch.utils.data.Subset(dataset, test_index)

In [None]:
# Load the raw test records; generated outputs are written onto them later.
# Explicit UTF-8 so the Korean content decodes the same on every platform
# (the results file is also written as UTF-8).
with open("resource/data/일상대화요약_test.json", "r", encoding="utf-8") as f:
    result = json.load(f)

In [None]:
# Sampling settings forwarded to model.generate() and embedded in the results filename.
top_p = 0.95
top_k = 50
temperature = 1

In [None]:
# Sample one generation per probe example and store it on the matching record.
# `dataset` holds the examples selected by `test_index` (same order), so the
# position in `dataset` must be mapped back to the original record index.
assert len(dataset) == len(test_index), "dataset must be the test_index subset"

for idx in tqdm.tqdm(range(len(dataset))):
    inp = dataset[idx]
    outputs = model.generate(
        inp.to(DEVICE).unsqueeze(0),
        max_new_tokens=1024,
        eos_token_id=terminators,
        pad_token_id=tokenizer.eos_token_id,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
    )
    record = result[test_index[idx]]
    # Store the prompt as decoded text: the raw dataset item is a tensor, which
    # json cannot serialize.
    record["input"] = tokenizer.decode(inp, skip_special_tokens=True)
    # Drop the prompt tokens so only the newly generated continuation is kept.
    record["output"] = tokenizer.decode(outputs[0][inp.shape[-1]:], skip_special_tokens=True)

# Make sure the target directory exists so the finished generations are not
# lost to a FileNotFoundError.
os.makedirs("results", exist_ok=True)
with open(f'results/test_top_p_{top_p}_top_k_{top_k}_temperature_{temperature}.json', "w", encoding="utf-8") as f:
    json.dump(result, f, ensure_ascii=False, indent=4)