In [1]:
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, BitsAndBytesConfig
import torch

# Checkpoint to load. The commented-out id is the alternative model this
# notebook was also written against.
model_id = "rtzr/ko-gemma-2-9b-it"
# model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(model_id)

# Token ids on which generation stops.
terminators = [tokenizer.eos_token_id]
# Extra stop token kept for the alternative checkpoint above:
# terminators.append(tokenizer.convert_tokens_to_ids("<|eot_id|>"))

# 4-bit NF4 weights with double quantization; computation runs in bfloat16.
quantization_options = {
    "load_in_4bit": True,
    "bnb_4bit_use_double_quant": True,
    "bnb_4bit_quant_type": "nf4",
    "bnb_4bit_compute_dtype": torch.bfloat16,
}
bnb_config = BitsAndBytesConfig(**quantization_options)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    low_cpu_mem_usage=True,
    quantization_config=bnb_config,
)

# Generation defaults applied to every call of the pipeline object.
# (Optional, currently disabled: pad_token_id=tokenizer.eos_token_id)
generation_defaults = {
    "max_new_tokens": 256,
    "eos_token_id": terminators,
}

# NOTE: this rebinds the imported `pipeline` factory to the constructed
# pipeline object; later cells use `pipeline` as the object.
pipeline = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    **generation_defaults,
)


  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████| 10/10 [00:18<00:00,  1.82s/it]


In [2]:
from faiss_module import  make_db, make_fewshot_db
import pandas as pd
from utils_module import make_dict, format_docs

# Input CSV files, read relative to the notebook's working directory.
TRAIN_CSV = 'train.csv'
TEST_CSV = 'test.csv'

train_df = pd.read_csv(TRAIN_CSV)
test_df = pd.read_csv(TEST_CSV)

# Vector stores, created in the order train -> test -> few-shot.
# Judging by the recorded output, each call loads an existing index from the
# given directory; whether it builds one when missing is defined in
# faiss_module -- TODO confirm.
train_db = make_db(train_df, './train_faiss_db')
test_db = make_db(test_df, './test_faiss_db')

# Records iterated by the generation loop; they are indexed by position and
# read via the 'Question' and 'Source' keys there.
# train_dict = make_dict(TRAIN_CSV)
dataset = make_dict(TEST_CSV)

# Store backing the few-shot example selector.
fewshot_db = make_fewshot_db(train_df, './fewshot_faiss_db')

Loading FAISS DB from: ./train_faiss_db
Loading FAISS DB from: ./test_faiss_db
Loading FAISS DB from: ./fewshot_faiss_db


In [3]:
from langchain_core.example_selectors import SemanticSimilarityExampleSelector

In [4]:
from tqdm import tqdm
# Retrievers over the FAISS stores. Both use the "similarity_score_threshold"
# search type with a 0.77 threshold; the train retriever asks for 1 document
# per query and the test retriever for 2.
train_retriever = train_db.as_retriever(
    search_type="similarity_score_threshold",
    search_kwargs={'score_threshold': 0.77, 'k': 1},
)
test_retriver = test_db.as_retriever(
    search_type="similarity_score_threshold",
    search_kwargs={'score_threshold': 0.77, 'k': 2},
)

# Number of few-shot examples selected per question.
fewshot_num = 3
example_selector = SemanticSimilarityExampleSelector(
    vectorstore=fewshot_db,
    k=fewshot_num,
)

# Instruction block shared by every conversation. It does not depend on the
# question, so it is built once instead of on every loop iteration.
system_prompt = """You are the financial expert who helps me with my financial information Q&As.
    You earn 10 points when you answer me and follow the rules and lose 7 points when you don't.

    12,500 백만원 = 125 억원 = 12,500,000,000 원
    5,400 백만원 = 54 억원 = 5,400,000,000 원

    Here are some rules you should follow.
    - Please use contexts to answer the question.
    - Please your answers should be concise.
    - Please answers must be written in Korean.
    - Please answer the question in 1-3 sentences.

    - Use the three examples below to learn how to follow the rules and reference information in context.
        """

results = []
pipeline.model.eval()
for i in tqdm(range(len(dataset))):
    question = dataset[i]['Question']
    # NOTE(review): some chat templates (Gemma-style ones in particular) may
    # reject the "system" role or require strictly alternating user/assistant
    # turns -- confirm against the template shipped with `model_id`.
    messages = [{"role": "system", "content": system_prompt}]

    # Select the `fewshot_num` examples closest to the question.
    # Returns a list shaped like [{'Question': ..., 'Answer': ...}].
    exs = example_selector.select_examples({'Question': question})

    # The inner loop must not reuse `i`: doing so overwrote the dataset index
    # and made every later `dataset[i]` lookup point at the wrong row.
    for ex in exs:
        example_docs = train_retriever.invoke(ex['Question'])
        if example_docs:
            # Show the example together with its retrieved context.
            example_content = f"{ex['Question']}\n\n{format_docs(example_docs)}"
        else:
            # Nothing passed the score threshold: show the bare question.
            example_content = f"{ex['Question']}"
        messages.append({"role": "user", "content": example_content})
        messages.append({"role": "assistant", "content": f"{ex['Answer']}"})

    retrieved_docs = test_retriver.invoke(question)
    messages.append({"role": "system", "content": "Now do it for real."})
    messages.append({"role": "user", "content": f"{question}\n\n{format_docs(retrieved_docs)}"})

    # Pass the message list itself rather than a pre-rendered prompt string.
    # With chat input the pipeline applies the chat template and returns the
    # conversation as a list of messages whose last entry is the assistant
    # reply; with a string input `generated_text` is a plain string and the
    # `[-1]['content']` lookup below would raise TypeError.
    outputs = pipeline(messages)

    results.append({
        "Question": question,
        "Answer": outputs[-1]["generated_text"][-1]['content'],
        "Source": dataset[i]['Source'],
    })
    print(results[-1]['Question'])
    print(results[-1]['Answer'])
    

  attn_output = torch.nn.functional.scaled_dot_product_attention(


In [None]:
# Show the final element of the most recent generation result left over from
# the generation loop.
last_generation = outputs[-1]["generated_text"]
print(last_generation[-1])

{'role': 'assistant', 'content': "혁신창업사업화자금(융자) 사업은 '중소기업진흥에 관한 법률 제66조, 제66조, 제67조, 제74조'에 근거하고 있습니다. \n\n\n"}
