In [1]:
from datasets import load_dataset

dataset = load_dataset("shangrilar/ko_text2sql", "origin")
dataset

DatasetDict({
    train: Dataset({
        features: ['db_id', 'context', 'question', 'answer'],
        num_rows: 38246
    })
    test: Dataset({
        features: ['db_id', 'context', 'question', 'answer'],
        num_rows: 112
    })
})

In [2]:
def make_prompt(ddl, question, query=""):
    prompt=f"""당신은 SQL을 생성하는 SQL 봇입니다. DDL의 테이블을 활용한 Question을 해결할 수 있는 SQL을 생성해주세요.
    
    ### DDL:
    {ddl}

    ### Question:
    {question}

    ### SQL:
    {query}
    """
    return prompt

In [6]:
# 6.3.1. 기초모델 평가하기

import torch
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM

def make_inference_pipeline(model_name):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
    pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
    return pipe


In [8]:
model_name = "beomi/Yi-Ko-6B"

hf_pipe = make_inference_pipeline(model_name)

example = """당신은 SQL을 생성하는 SQL 봇입니다. DDL의 테이블을 활용한 Question을 해결할 수 있는 SQL을 생성해주세요.
    
    ### DDL:
    CREATE TABLE players (
        player_id INT PRIMARY KEY, AUTO_INCREMENT,
        username VARCHAR(255) UNIQUE NOT NULL,
        email VARCHAR(255) UNIQUE NOT NULL,
        password_hash VARCHAR(255) NOT NULL,
        date_joined DATETIME NOT NULL,
        last_login DATETIME NOT NULL
    );

    ### Question:
    사용자의 이름에 'admin'이 포함되어 있는 계정의 수를 알려주세요.

    ### SQL:
    """

results = hf_pipe(example, do_sample=False, return_full_text=False, max_length=1024, truncation=True)
import json
print(json.dumps(results, indent=2, ensure_ascii=False))


The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.


Loading checkpoint shards:   0%|          | 0/5 [00:00<?, ?it/s]

[
  {
    "generated_text": "SELECT COUNT(*) FROM players WHERE username LIKE '%admin%';\n\n    ### SQL 봇:\n    SELECT COUNT(*) FROM players WHERE username LIKE '%admin%';\n\n    ### SQL 봇의 결과:\n    SELECT COUNT(*) FROM players WHERE username LIKE '%admin%';\n\n    ### SQL 봇의 결과:\n    SELECT COUNT(*) FROM players WHERE username LIKE '%admin%';\n\n    ### SQL 봇의 결과:\n    SELECT COUNT(*) FROM players WHERE username LIKE '%admin%';\n\n    ### SQL 봇의 결과:\n    SELECT COUNT(*) FROM players WHERE username LIKE '%admin%';\n\n    ### SQL 봇의 결과:\n    SELECT COUNT(*) FROM players WHERE username LIKE '%admin%';\n\n    ### SQL 봇의 결과:\n    SELECT COUNT(*) FROM players WHERE username LIKE '%admin%';\n\n    ### SQL 봇의 결과:\n    SELECT COUNT(*) FROM players WHERE username LIKE '%admin%';\n\n    ### SQL 봇의 결과:\n    SELECT COUNT(*) FROM players WHERE username LIKE '%admin%';\n\n    ### SQL 봇의 결과:\n    SELECT COUNT(*) FROM players WHERE username LIKE '%admin%';\n\n    ### SQL 봇의 결과:\n    SELECT COUNT(*) FR