In [1]:
!export TOKENIZERS_PARALLELISM=false
!python -m pip install --upgrade pip
!pip install typing_extensions pydantic openai
!pip install datasets transformers peft trl bitsandbytes

[0m

In [2]:
# Load the base model, attach the fine-tuned PEFT adapter, and load the tokenizer.
from peft import PeftModel
from transformers import AutoModelForCausalLM,AutoTokenizer

# One definition of each checkpoint id so the model and tokenizer cannot drift apart.
BASE_MODEL_ID = "google/gemma-3-12b-it"
ADAPTER_ID = "ohdyo/trip_qa_gem13b_v2"

base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL_ID, torch_dtype="auto", device_map="cuda")
model = PeftModel.from_pretrained(base_model, ADAPTER_ID)
# `add_bos_token` is the keyword the tokenizer understands; the earlier
# `add_bos=True` was not a recognised option and was silently ignored.
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID, add_bos_token=True)

Loading checkpoint shards:   0%|          | 0/5 [00:00<?, ?it/s]

In [3]:
import torch
# Default to the CPU and switch to the GPU only when PyTorch can see one.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
# Bare expression so the cell displays the chosen device.
device

'cuda'

In [4]:
!pip install flask --ignore-installed blinker

Collecting flask
  Using cached flask-3.1.0-py3-none-any.whl.metadata (2.7 kB)
Collecting blinker
  Using cached blinker-1.9.0-py3-none-any.whl.metadata (1.6 kB)
Collecting Werkzeug>=3.1 (from flask)
  Using cached werkzeug-3.1.3-py3-none-any.whl.metadata (3.7 kB)
Collecting Jinja2>=3.1.2 (from flask)
  Using cached jinja2-3.1.6-py3-none-any.whl.metadata (2.9 kB)
Collecting itsdangerous>=2.2 (from flask)
  Using cached itsdangerous-2.2.0-py3-none-any.whl.metadata (1.9 kB)
Collecting click>=8.1.3 (from flask)
  Using cached click-8.1.8-py3-none-any.whl.metadata (2.3 kB)
Collecting MarkupSafe>=2.0 (from Jinja2>=3.1.2->flask)
  Using cached MarkupSafe-3.0.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.0 kB)
Using cached flask-3.1.0-py3-none-any.whl (102 kB)
Using cached blinker-1.9.0-py3-none-any.whl (8.5 kB)
Using cached click-8.1.8-py3-none-any.whl (98 kB)
Using cached itsdangerous-2.2.0-py3-none-any.whl (16 kB)
Using cached jinja2-3.1.6-py3-none-any.whl (134 k

In [5]:
from flask import Flask, request, jsonify
# The Flask application object; the route cells below register their handlers on it.
app = Flask(__name__)

In [21]:
# Module-level conversation memory: a list of {"question": ..., "answer": ...} dicts.
# It is a single global, so every request handled by this process reads and writes
# the same conversation.
chat_history = []

# Prompt construction
def make_prompt(history, new_question):
    """Assemble the text prompt sent to the model.

    The prompt is the fixed system instruction, followed by every earlier turn
    rendered as a "### 질문:" / "### 답변:" pair, followed by the new question
    with an open "### 답변:" header for the model to complete.

    Args:
        history: Iterable of dicts, each with "question" and "answer" keys.
        new_question: The question to append as the final, unanswered turn.

    Returns:
        The complete prompt as one string.
    """
    instruction = """너는 항공기 반입 금지 물품 및 여행 정보를 전문적으로 알려주는 비서야.
답변이 없는 질문에 대해서 답변해줘. 추가 질문은 만들지마"""

    answered_turns = "".join(
        f"### 질문:\n{turn['question']}\n\n### 답변:\n{turn['answer']}\n\n"
        for turn in history
    )
    open_turn = f"### 질문:\n{new_question}\n\n### 답변:"

    # NOTE(review): no separator is inserted between the instruction and the first
    # "### 질문:" header. Left as-is because the adapter may have been trained on
    # exactly this layout — confirm before changing.
    return instruction + answered_turns + open_turn

# Model query
def ask_with_generate(new_question, max_new_tokens=200):
    """Generate an answer to `new_question` and record the turn in `chat_history`.

    The prompt is built from the global `chat_history` plus the new question,
    sampled from the model, and the resulting answer is appended to
    `chat_history` so later calls see it as context.

    Args:
        new_question: The user's question.
        max_new_tokens: Upper bound on the number of tokens to generate.

    Returns:
        The generated answer with surrounding whitespace removed. If the model
        went on to start another "### 질문:" section, only the text before the
        first "###" is kept.
    """
    prompt = make_prompt(chat_history, new_question)
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    prompt_length = inputs["input_ids"].shape[1]

    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            top_p=0.9,
            temperature=0.7,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
        )

    # The returned sequence is the prompt tokens followed by the continuation.
    # Slice by token count instead of string-replacing the prompt out of the
    # decoded text: decoding does not always reproduce the prompt byte-for-byte
    # (special tokens, whitespace normalisation), and when it does not, the
    # replace silently fails and the whole prompt leaks into the answer.
    generated_tokens = output[0][prompt_length:]
    answer = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()

    # The model occasionally keeps going and invents a follow-up question;
    # keep only the text that precedes it.
    if "### 질문:" in answer:
        answer = answer.split("###")[0].strip()

    # NOTE(review): the history grows without bound, so the prompt will eventually
    # exceed the model's context window — consider capping the number of turns.
    chat_history.append({"question": new_question, "answer": answer})

    return answer

In [25]:
@app.route('/refresh', methods=['GET'])
def refresh_history():
    """Discard every stored conversation turn and confirm with a JSON message."""
    global chat_history
    # Rebind the global to a fresh list rather than mutating the old one.
    chat_history = list()
    confirmation = {"message": "대화 히스토리가 초기화되었습니다."}
    return jsonify(confirmation)

54.180.8.8 - - [22/Apr/2025 01:23:31] "POST /input HTTP/1.1" 200 -


In [28]:
@app.route('/refresh', methods=['POST'])
def add_history():
    """Replace the stored conversation with the one supplied in the request body.

    The body is expected to be a JSON object with an optional "country" value
    and an optional "messages" list. Each message is an object with a "role"
    ("user" or "assistant") and a "message" text. A user message opens a new
    question; the next assistant message closes it and the pair is stored.
    User messages that never receive an assistant reply are dropped.

    Returns:
        200 with the number of loaded pairs, the echoed country, and the new
        history; 400 when the body is not a JSON object or "messages" is
        malformed; 500 on any unexpected failure.
    """
    global chat_history

    # silent=True yields None for a missing/invalid JSON body instead of raising,
    # so client mistakes are reported as 400 rather than falling into the 500 path.
    new_history = request.get_json(silent=True)
    if not isinstance(new_history, dict):
        return jsonify({"error": "JSON 객체 형식의 요청 본문이 필요합니다."}), 400

    # Country value is only echoed back in the response.
    country = new_history.get('country', '')
    messages = new_history.get('messages', [])
    if not isinstance(messages, list):
        return jsonify({"error": "'messages'는 리스트여야 합니다."}), 400

    try:
        # Build the replacement off to the side so a bad payload can never leave
        # the global history wiped or half-populated.
        rebuilt_history = []
        # Holds the question that is still waiting for its answer.
        current_item = {}

        for message in messages:
            if not isinstance(message, dict):
                return jsonify({"error": "'messages'의 각 항목은 객체여야 합니다."}), 400

            role = message.get('role')
            content = message.get('message')

            if role == "user":
                # Start a new question (an unanswered previous one is discarded).
                current_item = {"question": content}
            elif role == "assistant" and "question" in current_item:
                # Attach the answer to the pending question and store the pair.
                current_item["answer"] = content
                rebuilt_history.append(current_item)
                current_item = {}

        chat_history = rebuilt_history

        return jsonify({
            "message": f"{len(chat_history)}개의 대화가 불러와졌습니다.",
            "country" : country,
            "history": chat_history
        })

    except Exception as e:
        return jsonify({"error": str(e)}), 500

In [7]:
@app.route('/input', methods=['POST'])
def handle_input():
    """Answer the question contained in the JSON request body.

    The body is expected to be a JSON object with a "question" entry.

    Returns:
        200 with the question and the generated answer; 400 when the body is
        not a JSON object or the question is empty/blank; 500 when generation
        fails.
    """
    # silent=True yields None for a missing/invalid JSON body instead of raising,
    # so client mistakes are reported as 400 rather than as a server error.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "JSON 객체 형식의 요청 본문이 필요합니다."}), 400

    question = data.get("question", "")
    # Reject a whitespace-only string as well as an empty/missing question.
    is_blank = isinstance(question, str) and not question.strip()
    if not question or is_blank:
        return jsonify({"error": "질문이 비어있습니다."}), 400

    try:
        result = ask_with_generate(new_question=question)
        return jsonify({"question": question, "answer": result})
    except Exception as e:
        return jsonify({"error": str(e)}), 500

In [None]:
# Start Flask's built-in server on every network interface, port 8080. The call
# blocks this cell until the server is interrupted.
# NOTE(review): binding to 0.0.0.0 makes the endpoints reachable from other hosts,
# and no authentication is visible in this notebook — confirm the exposure is intended.
app.run(host='0.0.0.0', port=8080)

 * Serving Flask app '__main__'
 * Debug mode: off


 * Running on all addresses (0.0.0.0)
 * Running on http://127.0.0.1:8080
 * Running on http://192.168.4.2:8080
[33mPress CTRL+C to quit[0m
222.112.208.69 - - [22/Apr/2025 01:58:09] "POST /refresh HTTP/1.1" 200 -
222.112.208.69 - - [22/Apr/2025 01:58:42] "POST /refresh HTTP/1.1" 200 -
222.112.208.69 - - [22/Apr/2025 01:58:52] "[35m[1mPOST /refresh HTTP/1.1[0m" 500 -
222.112.208.69 - - [22/Apr/2025 01:59:07] "POST /refresh HTTP/1.1" 200 -
222.112.208.69 - - [22/Apr/2025 01:59:44] "POST /refresh HTTP/1.1" 200 -
222.112.208.69 - - [22/Apr/2025 02:00:30] "POST /input HTTP/1.1" 200 -
222.112.208.69 - - [22/Apr/2025 02:01:26] "GET /refresh HTTP/1.1" 200 -
222.112.208.69 - - [22/Apr/2025 02:01:47] "POST /input HTTP/1.1" 200 -
