In [18]:
# Mount Google Drive at /content/drive so the notebook can read and write files stored there.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [19]:
import sys
# Put the Drive folder first on the import path so it wins over any installed package
# of the same name (presumably a trl checkout lives there - verify).
sys.path.insert(0, '/content/drive/My Drive/trl')

In [20]:
from trl import GRPOTrainer, GRPOConfig

In [21]:
!pip install "vllm==0.10.2"
!pip install -U bitsandbytes



In [22]:
import json
import os
import torch
import psycopg2
from datasets import Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainerCallback,
    TrainerState,
    TrainerControl,
)
from peft import LoraConfig, get_peft_model
import re
from typing import List, Sequence, Dict, Any, Tuple
import warnings
import numpy as np
warnings.filterwarnings('ignore')


In [23]:
class ConsoleLoggingCallback(TrainerCallback):
    """Trainer callback that echoes every logged metrics dict to stdout."""

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Print the current global step together with the logged values, if any."""
        if logs is None:
            return
        print(f"Step {state.global_step}: {logs}")

In [24]:
grammar = r"""
root ::= hint_comment

# /*+ SeqScan(t) IndexScan(t idx) Leading((t1 t2) t3) */
hint_comment ::= "/*+" wsp? hint_list wsp? "*/"

hint_list ::= hint (wsp hint)*

hint ::= scan_hint
       | join_hint
       | leading_hint

# --- Scan hints --------------------------------------------------------

scan_hint ::=
      "SeqScan" "(" wsp? rel_name wsp? ")"
    | "IndexScan" "(" wsp? rel_name (wsp index_name)* wsp? ")"
    | "IndexOnlyScan" "(" wsp? rel_name (wsp index_name)* wsp? ")"
    | "BitmapScan" "(" wsp? rel_name (wsp index_name)* wsp? ")"
    | "NoSeqScan" "(" wsp? rel_name wsp? ")"
    | "NoIndexScan" "(" wsp? rel_name wsp? ")"
    | "NoIndexOnlyScan" "(" wsp? rel_name wsp? ")"
    | "NoBitmapScan" "(" wsp? rel_name wsp? ")"

# --- Join method hints -------------------------------------------------

join_hint ::= join_method_hint | no_join_method_hint

join_method_hint ::=
      "NestLoop"  "(" wsp? join_rel_list wsp? ")"
    | "HashJoin"  "(" wsp? join_rel_list wsp? ")"
    | "MergeJoin" "(" wsp? join_rel_list wsp? ")"

no_join_method_hint ::=
      "NoNestLoop"  "(" wsp? join_rel_list wsp? ")"
    | "NoHashJoin"  "(" wsp? join_rel_list wsp? ")"
    | "NoMergeJoin" "(" wsp? join_rel_list wsp? ")"

# минимум две таблицы, дальше [ table... ]
join_rel_list ::= rel_name wsp rel_name (wsp rel_name)*

# --- Leading -----------------------------------------------------------

# Две формы:
#  1) Leading(a b c)
#  2) Leading((a b) c) и рекурсивно вложенные пары
leading_hint ::=
      "Leading" "(" wsp leading_list wsp ")"
    | "Leading" "(" wsp join_member wsp ")"

# Простая линейная форма: Leading(t1 t2 [t3...])
leading_list ::= rel_name wsp rel_name (wsp rel_name)*

# Рекурсивный join-pair:
# join_member -> "t" | "(" join_member join_member ")"
join_member ::=
      rel_name
    | "(" wsp? join_member wsp join_member wsp? ")"

# --- Idents, whitespace ------------------------------------------------

rel_name   ::= ident
index_name ::= ident

ident ::= ident_start ident_part*
ident_start ::= [A-Za-z_]
ident_part  ::= [A-Za-z0-9_$.]

# один пробельный символ
ws ::= " " | "\t" | "\n" | "\r"

# один или больше пробельных символов
wsp ::= ws ws*

"""

In [25]:
# System message placed at the start of every chat prompt (training and inference).
# NOTE(review): it advertises TidScan and Rows(...) and asks for one hint per line,
# whereas the `grammar` used for guided decoding defines only scan/join/Leading hints
# inside a single /*+ ... */ comment - confirm which format the SFT model was trained on.
SYSTEM_PROMPT = """You are an expert PostgreSQL query optimizer specialized in generating pg_hint_plan commands.

Your task: Analyze SQL queries with database statistics and generate optimal hint commands.

Available hints:
- Scan methods: SeqScan(table), IndexScan(table), BitmapScan(table), TidScan(table)
- Join methods: NestLoop(t1 t2), HashJoin(t1 t2), MergeJoin(t1 t2)
- Join order: Leading((t1 t2) t3)
- Row estimates: Rows(table_name #rows)

Output: Generate ONLY valid pg_hint_plan() hints, one per line. No explanations."""

# User message template; the placeholders are filled with the SQL text and the
# per-query statistics (Card_Tb, NDV, Main_Value, Min_Max) of a dataset record.
USER_PROMPT_TEMPLATE = """Query: {query}

Cardinalities: {card_tb}
Statistics (NDV): {ndv}
Frequent values: {main_value}
Ranges: {min_max}

Output: One hint per line, no explanation."""


In [26]:
# Connection parameters passed to psycopg2.connect().
# Each value can be overridden through the standard libpq environment variables so that
# credentials do not have to be edited into (or committed with) the notebook.
# NOTE(review): the literal fallbacks are kept only for backward compatibility - rotate
# this password and drop the defaults once the environment variables are in place.
DB_CONFIG = {
    "host": os.environ.get("PGHOST", "91.219.226.145"),
    "port": int(os.environ.get("PGPORT", "5432")),
    "database": os.environ.get("PGDATABASE", "testdb"),
    "user": os.environ.get("PGUSER", "admin"),
    "password": os.environ.get("PGPASSWORD", "superadmin"),
}

In [27]:
def _format_stat_lines(raw_text):
    """Turn 'name : value' lines into an indented bullet list.

    Lines without the ' : ' separator are dropped. Only the first separator on a line
    is rewritten to ': ', so a value that itself contains ' : ' is passed through
    unchanged. Empty or missing input yields an empty string.
    """
    if not raw_text:
        return ""
    bullets = []
    for line in raw_text.strip().split('\n'):
        if ' : ' in line:
            bullets.append(f"  - {line.replace(' : ', ': ', 1)}")
    return '\n'.join(bullets)


def format_ndv_stats(ndv_string):
    """Format the NDV (number of distinct values) statistics block for the prompt."""
    return _format_stat_lines(ndv_string)


def format_main_values(main_value_string):
    """Format the most-frequent-values statistics block for the prompt."""
    return _format_stat_lines(main_value_string)


def format_min_max(min_max_string):
    """Format the min/max range statistics block for the prompt."""
    return _format_stat_lines(min_max_string)

In [28]:
def prepare_grpo_dataset(json_path: str, tokenizer, skip_records: int = 63) -> Dataset:
    """Build the prompt dataset used for GRPO training from a JSON file of records.

    Args:
        json_path: Path to a JSON file holding a list of records. Each record is read
            for the keys 'query', 'Card_Tb', 'NDV', 'Main_Value', 'Min_Max' and
            'best_hints'.
        tokenizer: Tokenizer whose chat template renders the system/user messages
            into a single prompt string.
        skip_records: Number of leading records to drop before building prompts
            (expected to be non-negative). Defaults to 63, the value that used to be
            hard-coded here.

    Returns:
        A Dataset with the columns 'prompt', 'query' and 'hints_gt'.
    """
    with open(json_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    data = data[skip_records:]

    # The system message is the same for every record, so build it once.
    system_content = SYSTEM_PROMPT.strip()

    prompts = []
    queries = []
    hints_gt = []

    for record in data:
        user_content = USER_PROMPT_TEMPLATE.format(
            query=record['query'],
            card_tb=record['Card_Tb'],
            ndv=format_ndv_stats(record['NDV']),
            main_value=format_main_values(record['Main_Value']),
            min_max=format_min_max(record['Min_Max'])
        ).strip()

        messages = [
            {"role": "system", "content": system_content},
            {"role": "user", "content": user_content},
        ]

        # Render to text (not token ids) and leave the assistant turn open for generation.
        prompt = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )

        prompts.append(prompt)
        queries.append(record['query'])
        hints_gt.append(record['best_hints'])

    return Dataset.from_dict({
        'prompt': prompts,
        'query': queries,
        'hints_gt': hints_gt
    })


In [29]:
import json
from datetime import datetime

def extract_hints_from_completion(completion: str) -> str:
    """Return the hint text contained in a model completion.

    Every call appends one JSON line describing the raw completion to
    ``completions_debug.jsonl`` in the current working directory. The result is the
    text after the last literal 'assistant' marker (or the whole completion when that
    marker is absent), with the '<|im_end|>' / '<|endoftext|>' tokens removed and
    surrounding whitespace stripped.
    """
    record = {
        "timestamp": datetime.now().isoformat(),
        "completion_raw": completion,
        "completion_length": len(completion),
        "has_assistant": 'assistant' in completion.lower(),
    }
    with open("completions_debug.jsonl", 'a', encoding='utf-8') as debug_log:
        debug_log.write(json.dumps(record, ensure_ascii=False) + '\n')

    # The marker match is case-sensitive: without a lowercase 'assistant' the whole
    # completion is kept as it is.
    _, marker, tail = completion.rpartition('assistant')
    cleaned = tail.strip() if marker else completion

    for stop_token in ('<|im_end|>', '<|endoftext|>'):
        cleaned = cleaned.replace(stop_token, '')
    return cleaned.strip()

In [30]:
def get_db_connection():
    """Open and return a new psycopg2 connection using the settings in DB_CONFIG."""
    connection = psycopg2.connect(**DB_CONFIG)
    return connection

def _rollback_quietly(cursor) -> None:
    """Best-effort rollback of the cursor's connection so later statements can run."""
    connection = getattr(cursor, "connection", None)
    if connection is None:
        return
    try:
        connection.rollback()
    except Exception:
        # Nothing more can be done here; the caller already reports the original error.
        pass


def execute_query_with_hints(query: str, hints: str = None, cursor=None) -> tuple:
    """Run ``EXPLAIN (ANALYZE, FORMAT JSON)`` for a query, optionally with planner hints.

    Args:
        query: SQL text to execute.
        hints: Either bare hints (``"SeqScan(t)"``) or a complete hint comment
            (``"/*+ SeqScan(t) */"``). A complete comment is used as is, because
            wrapping it a second time would produce a nested block comment.
        cursor: Open DB-API cursor to execute on.

    Returns:
        A ``(success, execution_time, error)`` tuple. On success ``execution_time`` is
        the 'Execution Time' value of the returned plan and ``error`` is None. On
        failure it is ``(False, inf, message)``.
    """
    if cursor is None:
        return False, float('inf'), "cursor is None"

    try:
        try:
            cursor.execute("LOAD 'pg_hint_plan';")
        except Exception:
            # LOAD is best-effort, but a failed statement leaves the transaction
            # aborted, so clear that state before running the real query.
            _rollback_quietly(cursor)

        hint_block = (hints or "").strip()
        if hint_block:
            if not hint_block.startswith("/*+"):
                hint_block = f"/*+ {hint_block} */"
            full_query = f"{hint_block} {query}"
        else:
            full_query = query

        cursor.execute("EXPLAIN (ANALYZE, FORMAT JSON) " + full_query)
        result = cursor.fetchone()

        if result and result[0]:
            exec_time = result[0][0].get('Execution Time', 0)
            return True, exec_time, None

        print(f"no execution plan")
        return False, float('inf'), "no execution plan"

    except Exception as e:
        print(str(e))
        # Without this, every later statement on the same connection would fail with
        # "current transaction is aborted".
        _rollback_quietly(cursor)
        return False, float('inf'), str(e)



In [31]:
def reward_func(completions: Sequence[str], **kwargs) -> List[float]:
    """Score completions by executing their hints against the database.

    Completions are processed in consecutive groups of ``num_generations``
    (assumes each group belongs to one prompt - TODO confirm against the trainer's
    batching). For every group the group's query is run once per completion with the
    extracted hints and once without hints as a baseline; rewards combine the rank
    inside the group with the improvement over the baseline.

    Args:
        completions: Generated texts.
        **kwargs: Reads 'query' (one SQL string per completion) and
            'num_generations' (group size, default 4).

    Returns:
        One reward per completion, in input order.

    Raises:
        Whatever ``get_db_connection`` raises when the database is unreachable.
    """
    queries = kwargs.get('query', [None] * len(completions))
    num_generations = kwargs.get('num_generations', 4)

    all_rewards = []

    conn = get_db_connection()
    try:
        # Run every statement in its own transaction: otherwise one failing EXPLAIN
        # aborts the transaction and all later runs, baseline included, fail too.
        conn.autocommit = True
        cursor = conn.cursor()
        try:
            for i in range(0, len(completions), num_generations):
                group_completions = completions[i:i+num_generations]
                group_queries = queries[i:i+num_generations]

                query = group_queries[0]

                execution_results = []
                for completion in group_completions:
                    hints = extract_hints_from_completion(completion)
                    success, exec_time, error = execute_query_with_hints(query, hints, cursor)
                    print(success, exec_time, error)
                    execution_results.append({
                        'success': success,
                        'time': exec_time if success else float('inf'),
                        'hints': hints,
                        'completion': completion
                    })

                success_baseline, time_baseline, error = execute_query_with_hints(query, None, cursor)
                print(success_baseline, time_baseline, error)
                if not success_baseline:
                    time_baseline = float('inf')

                ranks = compute_ranks([r['time'] for r in execution_results])
                group_rewards = []
                for rank, result in zip(ranks, execution_results):
                    reward = compute_reward_hybrid(
                        rank=rank,
                        time_execution=result['time'],
                        time_baseline=time_baseline,
                        num_generations=num_generations,
                        beta=0.7
                    )

                    group_rewards.append(reward)

                    print(f"Rank {rank}: time={result['time']:.2f}ms, "
                          f"reward={reward:.3f}")

                all_rewards.extend(group_rewards)
        finally:
            cursor.close()
    finally:
        # Always release the connection; this function is called once per batch.
        conn.close()

    return all_rewards


def compute_ranks(times: List[float]) -> List[int]:
    """Rank execution times, with 1 as the fastest and ties sharing a rank.

    Tied entries all receive the rank of the first of them, and the next distinct
    value skips ahead by the size of the tie (1, 1, 3, ...). NaN values are ranked
    like ``inf`` (worst), because NaN never compares equal to itself and would
    otherwise prevent the tie grouping from making progress.

    Args:
        times: Execution times; ``inf`` marks a failed run.

    Returns:
        A list of ranks aligned with ``times``.
    """
    # `t != t` is true only for NaN.
    keys = [float('inf') if t != t else t for t in times]
    order = sorted(range(len(keys)), key=keys.__getitem__)

    ranks = [0] * len(keys)
    for position, idx in enumerate(order):
        if position > 0 and keys[idx] == keys[order[position - 1]]:
            # Same value as the previous entry in sorted order: share its rank.
            ranks[idx] = ranks[order[position - 1]]
        else:
            ranks[idx] = position + 1

    return ranks


def compute_reward_hybrid(
    rank: int,
    time_execution: float,
    time_baseline: float,
    num_generations: int,
    beta: float = 0.7
) -> float:
    """Blend a rank-based reward with a speed-up-based reward.

    The rank part maps rank 1..num_generations linearly onto +0.5..-0.5 (0 when there
    is a single generation). The magnitude part is ``0.5 * tanh`` of the relative
    improvement over the baseline, -0.5 for a failed run, and 0 when no finite
    positive baseline is available. ``beta`` weights the rank part, ``1 - beta`` the
    magnitude part.
    """
    infinity = float('inf')

    if num_generations == 1:
        rank_part = 0.0
    else:
        rank_part = ((num_generations - rank) / (num_generations - 1)) - 0.5

    baseline_usable = time_baseline > 0 and time_baseline != infinity
    if not baseline_usable:
        magnitude_part = 0.0
    elif time_execution == infinity:
        magnitude_part = -0.5
    else:
        speedup = (time_baseline - time_execution) / time_baseline
        magnitude_part = np.tanh(speedup) * 0.5

    return beta * rank_part + (1 - beta) * magnitude_part



In [32]:
def split_dataset(dataset, test_size=0.1, seed=42):
    """Split ``dataset`` with its own ``train_test_split`` and return the result."""
    split_options = {"test_size": test_size, "seed": seed}
    return dataset.train_test_split(**split_options)

In [33]:
def main(json_file: str, sft_model_path: str, output_dir: str,
         chat_template_path: str = "chat_template.jinja"):
    """Fine-tune the SFT model with GRPO (sequence-level importance sampling) and save it.

    Steps: load the 8-bit quantized model and tokenizer, build the prompt dataset,
    attach LoRA adapters, split train/validation, configure GRPO with vLLM guided
    decoding, train, and save the result under ``output_dir``.

    Args:
        json_file: Path to the JSON dataset consumed by ``prepare_grpo_dataset``.
        sft_model_path: Directory of the supervised fine-tuned model to start from.
        output_dir: Directory for checkpoints and the final model.
        chat_template_path: Jinja chat template file installed on the tokenizer.
            Defaults to the previously hard-coded "chat_template.jinja".

    Returns:
        ``(trainer, final_model_path)``.
    """
    print("GSPO Training")

    print("\n[1/6] загрузка модели")

    # Only `load_in_8bit` is set: BitsAndBytesConfig has no bnb_8bit_* options (the
    # compute-dtype / double-quant settings exist for 4-bit only), so the former
    # bnb_8bit_compute_dtype / bnb_8bit_use_double_quant kwargs were silently ignored.
    bnb_config = BitsAndBytesConfig(
        load_in_8bit=True,
    )

    model = AutoModelForCausalLM.from_pretrained(
        sft_model_path,
        quantization_config=bnb_config,
        device_map='auto',
        trust_remote_code=True
    )

    tokenizer = AutoTokenizer.from_pretrained(
        sft_model_path,
        trust_remote_code=True
    )

    # Fall back to EOS as the padding token when the tokenizer defines none.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id

    model.config.pad_token_id = tokenizer.pad_token_id

    with open(chat_template_path, "r", encoding="utf-8") as f:
        chat_template = f.read()
    tokenizer.chat_template = chat_template
    tokenizer.padding_side = 'left'

    print(tokenizer.chat_template)

    print("\n[2/6] загрузка датасета")

    raw_dataset = prepare_grpo_dataset(json_file, tokenizer)
    print(f"загружено {len(raw_dataset)} примеров")

    print("\n[3/6] LoRA")

    lora_config = LoraConfig(
        r=16,
        lora_alpha=32,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                       "gate_proj", "up_proj", "down_proj"],
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM"
    )

    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()

    print("\n[4/6] train val для GRPO")

    train_val = split_dataset(raw_dataset)
    train_dataset, val_dataset = train_val['train'], train_val['test']

    print(f"train: {len(train_dataset)}, val: {len(val_dataset)}")
    print(f"колонки: {train_dataset.column_names}")

    print("\n[5/6] конфиг GRPO")

    training_args = GRPOConfig(
        output_dir=output_dir,

        num_train_epochs=1,
        per_device_train_batch_size=4,
        per_device_eval_batch_size=4,
        gradient_accumulation_steps=4,
        learning_rate=5e-6,

        num_generations=4,
        max_completion_length=150,
        max_prompt_length=3000,
        temperature=0.7,
        top_p=0.9,
        top_k=20,
        #scale_rewards="group",
        #loss_type="dapo",
        beta=0.0,
        importance_sampling_level="sequence",

        max_grad_norm=1.0,
        warmup_steps=20,

        logging_steps=1,
        save_steps=5,
        eval_steps=5,
        save_total_limit=5,

        eval_strategy="steps",

        bf16=False,
        fp16=True,
        gradient_checkpointing=True,
        # Keep the 'query' column so it reaches reward_func through **kwargs.
        remove_unused_columns=False,

        use_vllm=True,
        vllm_mode="colocate",
        vllm_gpu_memory_utilization=0.6,
        vllm_tensor_parallel_size=1,
        # Constrain generations to the pg_hint_plan comment grammar.
        vllm_guided_decoding_grammar=grammar,

        report_to="none",
    )

    print(f"Num generations: {training_args.num_generations}")
    print(f"Train batch size: {training_args.per_device_train_batch_size}")
    print(f"Max prompt length: {training_args.max_prompt_length}")
    print(f"Max completion length: {training_args.max_completion_length}")

    print("\n[6/6] запуск обучения")

    trainer = GRPOTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        processing_class=tokenizer,
        reward_funcs=reward_func,
        callbacks=[ConsoleLoggingCallback()]
    )

    trainer.train()

    print("сохранение модели")

    final_model_path = os.path.join(output_dir, "final-gspo-model2")
    trainer.save_model(final_model_path)
    tokenizer.save_pretrained(final_model_path)

    print(f"модель сохранена: {final_model_path}")

    return trainer, final_model_path

In [34]:
# Launch the training run: make sure Drive is mounted, point at the dataset and model
# locations, then call main().
drive.mount('/content/drive')
JSON_FILE = "output.json"  # relative path, resolved against the current working directory
SFT_MODEL_PATH = "/content/drive/My Drive/sql_models/final-model_1"
OUTPUT_DIR = "/content/drive/My Drive/sql_models/gspo-model"

main(JSON_FILE, SFT_MODEL_PATH, OUTPUT_DIR)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
GSPO Training

[1/6] загрузка модели


ValueError: Some modules are dispatched on the CPU or the disk. Make sure you have enough GPU RAM to fit the quantized model. If you want to dispatch the model on the CPU or the disk while keeping these modules in 32-bit, you need to set `llm_int8_enable_fp32_cpu_offload=True` and pass a custom `device_map` to `from_pretrained`. Check https://huggingface.co/docs/transformers/main/en/main_classes/quantization#offload-between-cpu-and-gpu for more details. 

In [None]:
# Load the trained model and its tokenizer for inference.
drive.mount('/content/drive')
gspo_model_path = "/content/drive/My Drive/sql_models/gspo-model/final-gspo-model2"
SFT_MODEL_PATH = "/content/drive/My Drive/sql_models/final-model_1"


# Only `load_in_8bit` is set: BitsAndBytesConfig has no bnb_8bit_* options (the
# compute-dtype / double-quant settings exist for 4-bit only), so the former
# bnb_8bit_compute_dtype / bnb_8bit_use_double_quant kwargs were silently ignored.
bnb_config = BitsAndBytesConfig(
    load_in_8bit=True,
)

model = AutoModelForCausalLM.from_pretrained(
    gspo_model_path,
    quantization_config=bnb_config,
    device_map='auto',
    trust_remote_code=True
)

tokenizer = AutoTokenizer.from_pretrained(
    gspo_model_path,
    trust_remote_code=True
)

# Fall back to EOS as the padding token when the tokenizer defines none.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id

model.config.pad_token_id = tokenizer.pad_token_id

# Install the same chat template file that the training code reads.
with open("chat_template.jinja", "r", encoding="utf-8") as f:
    chat_template = f.read()
tokenizer.chat_template = chat_template
tokenizer.padding_side = 'left'
print(tokenizer.chat_template)


In [None]:
def generate_hints(query, statistics, model, tokenizer):
    """Generate hint text for one query with the given model.

    Builds the same system/user chat prompt as the training data from ``query`` and the
    ``statistics`` dict (keys 'Card_Tb', 'NDV', 'Main_Value', 'Min_Max'), samples a
    continuation, and returns the decoded output sequence with special tokens skipped.
    """
    user_message = USER_PROMPT_TEMPLATE.format(
        query=query,
        card_tb=statistics["Card_Tb"],
        ndv=format_ndv_stats(statistics["NDV"]),
        main_value=format_main_values(statistics["Main_Value"]),
        min_max=format_min_max(statistics["Min_Max"])
    ).strip()

    chat = [
        {"role": "system", "content": SYSTEM_PROMPT.strip()},
        {"role": "user", "content": user_message},
    ]
    prompt_text = tokenizer.apply_chat_template(
        chat,
        tokenize=False,
        add_generation_prompt=True
    )

    encoded = tokenizer(prompt_text, return_tensors="pt").to(model.device)
    sampling_options = {
        "max_new_tokens": 3000,
        "do_sample": True,
        "temperature": 0.2,
        "top_p": 0.9,
        "pad_token_id": tokenizer.eos_token_id,
    }
    sequences = model.generate(**encoded, **sampling_options)

    # Only the first returned sequence is decoded.
    return tokenizer.decode(sequences[0], skip_special_tokens=True)

In [None]:
#29a.sql
# Sample statistics and SQL text for a manual inference check.
# The "NDV" literal was previously broken across two physical lines inside the string,
# which is a SyntaxError; it is a single line again with the same content.
test_statistics = {"Card_Tb": "k : 1 (236627)\nit : 1 (113)\nit : 1 (113)\nt : 11 (3299359)\ncc : 2 (135086)\nmi : 1 (29881566)\nci : 1 (63525648)\nn : 1 (6379436)\nchn : 1 (4314624)\ncct1 : 1 (4)\nrt : 1 (12)\npi : 36 (4135549)\nan : 2 (1312273)\nmk : 64 (7480081)\ncct1 : 1 (4)\nmc : 6 (4958465)\ncn : 1 (362131)",
    "NDV": "aka_name.md5sum : -0.42351857\naka_name.name_pcode_cf : 12213.0\naka_name.imdb_index : -3.3318996e-05\naka_name.id : -1.0\naka_name.person_id : -0.40863448\naka_name.name_pcode_nf : 10619.0\naka_name.name : -0.42351857\naka_name.surname_pcode : 2575.0\ncomplete_cast.movie_id : -0.62968034\ncomplete_cast.id : -1.0\ncomplete_cast.status_id : 2.0\ncomplete_cast.subject_id : 2.0\ncomp_cast_type.id : -1.0\ncomp_cast_type.kind : -1.0\ncomp_cast_type.id : -1.0\ncomp_cast_type.kind : -1.0\nchar_name.name_pcode_nf : 12349.0\nchar_name.id : -1.0\nchar_name.name : -1.0\nchar_name.imdb_index : 0.0\nchar_name.imdb_id : 0.0\nchar_name.md5sum : -1.0\nchar_name.surname_pcode : 4931.0\ncast_info.person_id : 66808.0\ncast_info.movie_id : 1733126.0\ncast_info.note : 4054.0\ncast_info.person_role_id : 20048.0\ncast_info.id : -1.0\ncast_info.nr_order : 147.0\ncast_info.role_id : 11.0\ncompany_name.country_code : 174.0\ncompany_name.name_pcode_sf : 11103.0\ncompany_name.md5sum : -1.0\ncompany_name.name : -0.86129606\ncompany_name.id : -1.0\ncompany_name.imdb_id : 0.0\ncompany_name.name_pcode_nf : 11853.0\ninfo_type.info : -1.0\ninfo_type.id : -1.0\ninfo_type.info : -1.0\ninfo_type.id : -1.0\nkeyword.phonetic_code : 10374.0\nkeyword.keyword : -1.0\nkeyword.id : -1.0\nmovie_companies.company_id : 20568.0\nmovie_companies.note : 6522.0\nmovie_companies.movie_id : -0.16302344\nmovie_companies.id : -1.0\nmovie_companies.company_type_id : 2.0\nmovie_info.id : -1.0\nmovie_info.note : 2190.0\nmovie_info.info : 24949.0\nmovie_info.info_type_id : 75.0\nmovie_info.movie_id : 549133.0\nmovie_keyword.keyword_id : 18044.0\nmovie_keyword.id : -1.0\nmovie_keyword.movie_id : 116740.0\nname.name_pcode_nf : 10820.0\nname.surname_pcode : 2644.0\nname.name : -0.29187077\nname.imdb_id : 0.0\nname.gender : 2.0\nname.name_pcode_cf : 12306.0\nname.id : -1.0\nname.md5sum : -1.0\nname.imdb_index : 110.0\nperson_info.id : -1.0\nperson_info.person_id : 113614.0\nperson_info.info_type_id : 21.0\nperson_info.info : 160562.0\nperson_info.note : 4985.0\nrole_type.id : -1.0\nrole_type.role : -1.0\ntitle.imdb_index : -0.000100016594\ntitle.kind_id : 6.0\ntitle.imdb_id : 0.0\ntitle.id : -1.0\ntitle.episode_nr : 2045.0\ntitle.title : 78983.0\ntitle.episode_of_id : 21104.0\ntitle.series_years : 286.0\ntitle.md5sum : -1.0\ntitle.phonetic_code : 10765.0\ntitle.season_nr : 80.0\ntitle.production_year : 116.0",
    "Main_Value": "aka_name.md5sum : ['1b83d5da74032b6a750ef12210642eea', 525]\naka_name.name_pcode_cf : ['J5252', 3937]\naka_name.name_pcode_nf : ['A4253', 6080]\naka_name.name : ['Mike', 525]\naka_name.surname_pcode : ['R2', 8136]\ncomplete_cast.status_id : ['3', 110626]\ncomplete_cast.subject_id : ['1', 86320]\nchar_name.name_pcode_nf : ['H5241', 148135]\nchar_name.surname_pcode : ['M5', 32935]\ncast_info.person_id : ['186155', 29645]\ncast_info.note : ['(producer)', 2530438]\ncast_info.person_role_id : ['2', 2214928]\ncast_info.nr_order : ['1', 2094229]\ncast_info.role_id : ['1', 20495491]\ncompany_name.country_code : ['[us]', 138250]\ncompany_name.name_pcode_sf : ['P6325', 1171]\ncompany_name.name : ['Sony Pictures Releasing', 97]\ncompany_name.name_pcode_nf : ['P6325', 1243]\nkeyword.phonetic_code : ['R1652', 25201]\nmovie_companies.company_id : ['67', 72559]\nmovie_companies.note : ['(in association with)', 39172]\nmovie_companies.company_type_id : ['2', 2915743]\nmovie_info.note : ['Anonymous', 131479]\nmovie_info.info : ['Color', 1729147]\nmovie_info.info_type_id : ['16', 5471315]\nmovie_keyword.keyword_id : ['1078', 81284]\nmovie_keyword.movie_id : ['3361480', 1995]\nname.name_pcode_nf : ['A5362', 36150]\nname.surname_pcode : ['R2', 49547]\nname.gender : ['m', 2682340]\nname.name_pcode_cf : ['J5252', 19138]\nname.imdb_index : ['I', 502487]\nperson_info.person_id : ['2425605', 2343]\nperson_info.info_type_id : ['17', 882940]\nperson_info.info : ['Los Angeles', 15439]\nperson_info.note : ['Anonymous', 14474]\ntitle.kind_id : ['7', 3022103]\ntitle.episode_nr : ['1', 142752]\ntitle.title : ['(#1.1)', 27495]\ntitle.episode_of_id : ['628404', 11988]\ntitle.series_years : ['2015-????', 9128]\ntitle.phonetic_code : ['A1416', 8248]\ntitle.season_nr : ['1', 1344269]\ntitle.production_year : ['2015', 191363]",
    "Min_Max": "aka_name.id : [1, 1312273]\naka_name.person_id : [5, 6379735]\ncomplete_cast.id : [1, 135086]\ncomplete_cast.movie_id : [1781, 4736782]\ncomplete_cast.subject_id : [1, 2]\ncomplete_cast.status_id : [3, 4]\ncomp_cast_type.id : [1, 4]\ncomp_cast_type.id : [1, 4]\nchar_name.id : [1, 4314864]\ncast_info.id : [1, 63475835]\ncast_info.person_id : [1, 6226526]\ncast_info.movie_id : [1, 4730370]\ncast_info.person_role_id : [1, 4314864]\ncast_info.nr_order : [-2068070866, 1776839230]\ncast_info.role_id : [1, 11]\ncompany_name.id : [1, 362131]\ninfo_type.id : [1, 113]\ninfo_type.id : [1, 113]\nkeyword.id : [1, 236627]\nmovie_companies.id : [1, 4958296]\nmovie_companies.movie_id : [2, 4698791]\nmovie_companies.company_id : [1, 362131]\nmovie_companies.company_type_id : [1, 2]\nmovie_info.id : [1, 29774686]\nmovie_info.movie_id : [1, 4730846]\nmovie_info.info_type_id : [1, 113]\nmovie_keyword.id : [1, 7480087]\nmovie_keyword.movie_id : [2, 4730753]\nmovie_keyword.keyword_id : [1, 236627]\nname.id : [1, 6379740]\nperson_info.id : [1, 4130207]\nperson_info.person_id : [1, 6379740]\nperson_info.info_type_id : [15, 39]\nrole_type.id : [1, 12]\ntitle.id : [100000, 3399999]\ntitle.kind_id : [1, 8]\ntitle.production_year : [1888, 2115]\ntitle.episode_of_id : [99685, 3300011]\ntitle.season_nr : [1, 2015]\ntitle.episode_nr : [1, 91334]"}

sql = "SELECT MIN(chn.name) AS voiced_char,\n       MIN(n.name) AS voicing_actress,\n       MIN(t.title) AS voiced_animation\nFROM aka_name AS an,\n     complete_cast AS cc,\n     comp_cast_type AS cct1,\n     comp_cast_type AS cct2,\n     char_name AS chn,\n     cast_info AS ci,\n     company_name AS cn,\n     info_type AS it,\n     info_type AS it3,\n     keyword AS k,\n     movie_companies AS mc,\n     movie_info AS mi,\n     movie_keyword AS mk,\n     name AS n,\n     person_info AS pi,\n     role_type AS rt,\n     title AS t\nWHERE cct1.kind ='cast'\n  AND cct2.kind ='complete+verified'\n  AND chn.name = 'Queen'\n  AND ci.note IN ('(voice)',\n                  '(voice) (uncredited)',\n                  '(voice: English version)')\n  AND cn.country_code ='[us]'\n  AND it.info = 'release dates'\n  AND it3.info = 'trivia'\n  AND k.keyword = 'computer-animation'\n  AND mi.info IS NOT NULL\n  AND (mi.info LIKE 'Japan:%200%'\n       OR mi.info LIKE 'USA:%200%')\n  AND n.gender ='f'\n  AND n.name LIKE '%An%'\n  AND rt.role ='actress'\n  AND t.title = 'Shrek 2'\n  AND t.production_year BETWEEN 2000 AND 2010\n  AND t.id = mi.movie_id\n  AND t.id = mc.movie_id\n  AND t.id = ci.movie_id\n  AND t.id = mk.movie_id\n  AND t.id = cc.movie_id\n  AND mc.movie_id = ci.movie_id\n  AND mc.movie_id = mi.movie_id\n  AND mc.movie_id = mk.movie_id\n  AND mc.movie_id = cc.movie_id\n  AND mi.movie_id = ci.movie_id\n  AND mi.movie_id = mk.movie_id\n  AND mi.movie_id = cc.movie_id\n  AND ci.movie_id = mk.movie_id\n  AND ci.movie_id = cc.movie_id\n  AND mk.movie_id = cc.movie_id\n  AND cn.id = mc.company_id\n  AND it.id = mi.info_type_id\n  AND n.id = ci.person_id\n  AND rt.id = ci.role_id\n  AND n.id = an.person_id\n  AND ci.person_id = an.person_id\n  AND chn.id = ci.person_role_id\n  AND n.id = pi.person_id\n  AND ci.person_id = pi.person_id\n  AND it3.id = pi.info_type_id\n  AND k.id = mk.keyword_id\n  AND cct1.id = cc.subject_id\n  AND cct2.id = cc.status_id;"


In [None]:
# Generate hints for the sample query and print what remains after the text up to the
# last 'assistant' marker and the end-of-turn tokens are removed.
print(extract_hints_from_completion(generate_hints(sql, test_statistics, model, tokenizer)))