In [1]:
import re
import torch
import warnings
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig

warnings.filterwarnings("ignore")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Local DeepSeek-MoE 16B chat checkpoint (absolute path on this machine).
model_name = "/data/gongzheng/llm/deepseek-moe-16b-chat"

tokenizer = AutoTokenizer.from_pretrained(model_name)

# Weights are loaded in bfloat16; trust_remote_code permits any custom modeling
# code shipped with the checkpoint to run. The model is then placed on GPU 4.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)
model = model.to("cuda:4")

# Decoding defaults come from the checkpoint's own generation config instead of
# the library defaults; EOS is reused as the pad id so `generate` has one.
# NOTE(review): if that config enables sampling, replies differ run-to-run unless
# a seed is fixed — confirm before comparing results across runs.
model.generation_config = GenerationConfig.from_pretrained(model_name)
model.generation_config.pad_token_id = model.generation_config.eos_token_id

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Loading checkpoint shards: 100%|██████████| 7/7 [00:08<00:00,  1.19s/it]


In [1]:
# One hand-picked example: the natural-language question, then '|'-separated
# table descriptions ("table : table.col , ...") followed by "a.col = b.col"
# link entries. Presumably the same layout as `input_sequence` in the
# preprocessed dev file. The literal is split across lines only for readability;
# its value is a single string.
input_sequence = (
    'what is the name and nation of the singer who have a song having "Hey" in its name? '
    '| singer : singer.country , singer.song_name , singer.name , singer.singer_id , singer.age '
    '| stadium : stadium.location , stadium.name , stadium.capacity , stadium.highest , stadium.lowest '
    '| singer_in_concert : singer_in_concert.concert_id , singer_in_concert.singer_id '
    '| concert : concert.theme , concert.year , concert.concert_id , concert.concert_name , concert.stadium_id '
    '| concert.stadium_id = stadium.stadium_id '
    '| singer_in_concert.singer_id = singer.singer_id '
    '| singer_in_concert.concert_id = concert.concert_id'
)

# Text before the first '|' is the question; everything after it is the schema.
question, schemas = input_sequence.split(sep='|', maxsplit=1)

In [2]:
# A single user turn asking for SQL that answers `question` over the `schemas` text.
messages = [
    dict(
        role="user",
        content=(
            "There are a database has several tables: "
            f"{schemas}, can you give me the SQL about {question}"
        ),
    ),
]

In [3]:
# Display the single-turn chat payload that the generation cell feeds to the chat template.
messages

[{'role': 'user',
  'content': 'There are a database has several tables:  singer : singer.country , singer.song_name , singer.name , singer.singer_id , singer.age | stadium : stadium.location , stadium.name , stadium.capacity , stadium.highest , stadium.lowest | singer_in_concert : singer_in_concert.concert_id , singer_in_concert.singer_id | concert : concert.theme , concert.year , concert.concert_id , concert.concert_name , concert.stadium_id | concert.stadium_id = stadium.stadium_id | singer_in_concert.singer_id = singer.singer_id | singer_in_concert.concert_id = concert.concert_id, can you give me the SQL about what is the name and nation of the singer who have a song having "Hey" in its name? '}]

In [5]:
# Render `messages` with the tokenizer's chat template (appending the assistant
# turn prefix), generate up to 250 new tokens on the model's device, and decode
# only what the model produced.
# NOTE(review): the captured output shows the tokenizer has no chat template and
# falls back to the LlamaTokenizerFast default — confirm that matches the prompt
# format DeepSeek-MoE chat expects.
input_tensor = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    return_tensors="pt",
)
outputs = model.generate(
    input_tensor.to(model.device),
    max_new_tokens=250,
)

# The first input_tensor.shape[1] positions of the output are the prompt; skip them.
result = tokenizer.decode(
    outputs[0][input_tensor.shape[1]:],
    skip_special_tokens=True,
)
result



No chat template is defined for this tokenizer - using the default template for the LlamaTokenizerFast class. If the default is not appropriate for your model, please set `tokenizer.chat_template` to an appropriate template. See https://huggingface.co/docs/transformers/main/chat_templating for more information.



'\n\n```sql\nSELECT singer.name, singer.country\nFROM singer\nJOIN singer_in_concert ON singer.singer_id = singer_in_concert.singer_id\nJOIN concert ON singer_in_concert.concert_id = concert.concert_id\nJOIN stadium ON concert.stadium_id = stadium.stadium_id\nWHERE singer_in_concert.concert_id IN (\n    SELECT concert_id\n    FROM singer_in_concert\n    WHERE song_name LIKE \'%Hey%\'\n)\nAND singer.singer_id IN (\n    SELECT singer_id\n    FROM singer_in_concert\n    WHERE song_name LIKE \'%Hey%\'\n);\n```\n\nThis SQL query will return the name and nation of the singer who have a song having "Hey" in its name. It first joins the tables together to get the necessary information, then it uses a subquery to find the concert_id of the concerts that have a song with "Hey" in its name. Finally, it uses these concert_id to find the singer who performed in these concerts.'

In [6]:
# Take the first ```sql fenced block from the reply and flatten it to one line
# (newlines and ';' become spaces).
# The non-greedy `.*?` stops at the FIRST closing fence. A greedy `.*` (with
# DOTALL) would run to the LAST ``` in the reply, so any later code block or
# prose between fences would be swallowed into the SQL.
re.findall(r"```sql(.*?)```", result, re.DOTALL)[0].replace('\n', ' ').replace(';', ' ')

" SELECT singer.name, singer.country FROM singer JOIN singer_in_concert ON singer.singer_id = singer_in_concert.singer_id JOIN concert ON singer_in_concert.concert_id = concert.concert_id JOIN stadium ON concert.stadium_id = stadium.stadium_id WHERE singer_in_concert.concert_id IN (     SELECT concert_id     FROM singer_in_concert     WHERE song_name LIKE '%Hey%' ) AND singer.singer_id IN (     SELECT singer_id     FROM singer_in_concert     WHERE song_name LIKE '%Hey%' )  "

In [None]:
import json
import re
from pathlib import Path
from typing import Optional

from tqdm import tqdm

# Batch evaluation on the preprocessed dev set. Each record's `input_sequence`
# is "<question> | <schema>", and the reference SQL is the text after the first
# '|' of its `output_sequence`.
EVAL_DATA_PATH = Path("data/preprocessed_data/resdsql_dev_natsql.json")
PRED_PATH = Path("data/output/deepseek-moe-16b-chat.json")
FAIL_PATH = Path("data/output/deepseek-moe-16b-chat-fail.json")
N_EVAL = 5            # examples to run (debug subset); set to None for the full dev set
MAX_NEW_TOKENS = 200  # per-question generation budget

# First ```sql fenced block. Non-greedy so it ends at the first closing fence
# rather than the last ``` anywhere in the reply.
SQL_BLOCK_RE = re.compile(r"```sql(.*?)```", re.DOTALL)


def extract_sql(text: str) -> Optional[str]:
    """Return the first ```sql block of `text` flattened to one line, or None if absent.

    Newlines and semicolons are replaced by spaces so each prediction occupies
    exactly one line of the prediction file.
    """
    match = SQL_BLOCK_RE.search(text)
    if match is None:
        return None
    return match.group(1).replace('\n', ' ').replace(';', ' ')


def generate_sql(question: str, schema: str, max_new_tokens: int = MAX_NEW_TOKENS) -> str:
    """Prompt the chat model (globals `model` / `tokenizer`) and return its raw reply text.

    The prompt wording is identical to the single-example cells so exploratory
    and batch results are comparable.
    """
    messages = [{"role": "user", "content": f"There are a database has several tables: {schema}, can you give me the SQL about {question}"}]
    input_tensor = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")
    outputs = model.generate(input_tensor.to(model.device), max_new_tokens=max_new_tokens)
    # The first input_tensor.shape[1] output positions are the prompt; decode only the rest.
    return tokenizer.decode(outputs[0][input_tensor.shape[1]:], skip_special_tokens=True)


with EVAL_DATA_PATH.open(encoding="utf-8") as f:
    eval_data = json.load(f)

input_sequences = [x['input_sequence'] for x in eval_data]
# maxsplit=1 keeps the whole reference SQL even if it contains '|' (e.g. the `||` operator).
output_sqls = [x['output_sequence'].split('|', 1)[1] for x in eval_data]
eval_pairs = list(zip(input_sequences, output_sqls))[:N_EVAL]

fail_question = []
fail_sql = []
fail_error = []  # repr of the failure reason, aligned with fail_question / fail_sql

PRED_PATH.parent.mkdir(parents=True, exist_ok=True)
# "w", not append: re-running the cell must not duplicate earlier predictions.
# Format: for each successful example, one line of predicted SQL then one line of reference SQL.
with PRED_PATH.open("w", encoding="utf-8") as f:
    for input_seq, output_sql in tqdm(eval_pairs):
        try:
            question, schema = input_seq.split('|', 1)
            # `result` stays a global so the following cell can inspect the last reply.
            result = generate_sql(question, schema)
            result_sql = extract_sql(result)
            if result_sql is None:
                raise ValueError("no ```sql block in model reply")
            f.write(result_sql + '\n')
            f.write(output_sql + '\n')
        except Exception as e:  # record and continue so one bad example does not abort the run
            fail_question.append(input_seq)
            fail_sql.append(output_sql)
            fail_error.append(repr(e))

# Failed examples: one line of input sequence then one line of reference SQL each.
with FAIL_PATH.open("w", encoding="utf-8") as f:
    for input_seq, output_sql in zip(fail_question, fail_sql):
        f.write(input_seq + '\n')
        f.write(output_sql + '\n')

print(f"{len(eval_pairs) - len(fail_question)} / {len(eval_pairs)} examples produced SQL")

In [None]:
# Last raw model reply bound to `result`. It comes from the evaluation loop if
# that loop reached generation; otherwise it is left over from the
# single-example generation cell (depends on kernel state).
result