In [None]:
import os

# Route HTTP(S) traffic (e.g. model downloads) through a local proxy.
# `setdefault` keeps a proxy the user has already exported instead of silently
# overwriting it; set PROXY_URL in the environment to use a different proxy.
PROXY_URL = os.environ.get('PROXY_URL', 'http://127.0.0.1:7890')
os.environ.setdefault('http_proxy', PROXY_URL)
os.environ.setdefault('https_proxy', PROXY_URL)
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import Qwen2ForCausalLM

model_name = "Qwen/Qwen2.5-0.5B-Instruct"

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="auto"
)

In [None]:
# Model download: fetch the Qwen2.5-Coder-1.5B-Instruct snapshot from ModelScope.
# `model_dir` holds the value returned by snapshot_download (presumably the local
# directory the snapshot was written to).
from modelscope import snapshot_download

model_dir = snapshot_download(
    'qwen/Qwen2.5-Coder-1.5B-Instruct'
)

In [2]:
import os

from transformers import AutoModelForCausalLM, AutoTokenizer

# Local ModelScope snapshot of Qwen2.5-Coder-1.5B-Instruct (presumably where the
# download cell stored it). Resolve `~` at run time instead of hard-coding one
# user's absolute home directory.
model_name = os.path.expanduser(
    "~/.cache/modelscope/hub/qwen/Qwen2___5-Coder-1___5B-Instruct"
)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

prompt = "write a quick sort algorithm."
messages = [
    {"role": "system", "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."},
    {"role": "user", "content": prompt}
]
# Render the conversation through the tokenizer's chat template as plain text
# (tokenize=False); tokenization happens in the next step.
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True
)
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

generated_ids = model.generate(
    **model_inputs,
    max_new_tokens=512
)
# Each generated sequence starts with its prompt tokens; slice them off so only
# the newly generated continuation is decoded.
generated_ids = [
    output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
]

response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

In [None]:
# print() renders the generated code with real line breaks; the bare string
# repr would show escaped "\n" sequences instead.
print(response)

In [None]:
import torch
import math


class multiHeadAttention(torch.nn.Module):
    """Grouped-query self-attention over a (batch, seq_len, hidden_dim) input.

    ``head_num`` query heads share ``head_num // group_num`` key/value heads;
    every key/value head is repeated ``group_num`` times so each query head has a
    matching key/value head. ``group_num=1`` gives ordinary multi-head attention,
    ``group_num=head_num`` gives a single shared key/value head.

    Args:
        hidden_dim: Size of the last input/output dimension.
        head_num: Number of query heads; must divide ``hidden_dim``.
        group_num: Query heads per key/value head; must divide ``head_num``.
        dropout: Dropout probability applied to the attention weights
            (default 0.5, the same as ``torch.nn.Dropout()``). Only active in
            training mode.

    Raises:
        ValueError: If ``hidden_dim % head_num`` or ``head_num % group_num`` is
            non-zero (such configs previously failed later with shape errors).
    """

    def __init__(self, hidden_dim, head_num, group_num, dropout=0.5) -> None:
        # Must run before assigning submodules so the Linear/Dropout layers are
        # registered (needed for .parameters(), .to(device), .train()/.eval()).
        super().__init__()
        if hidden_dim % head_num != 0:
            raise ValueError(
                f'hidden_dim ({hidden_dim}) must be divisible by head_num ({head_num})'
            )
        if head_num % group_num != 0:
            raise ValueError(
                f'head_num ({head_num}) must be divisible by group_num ({group_num})'
            )

        self.hidden_dim = hidden_dim
        self.head_num = head_num
        self.head_dim = self.hidden_dim // self.head_num
        self.group_num = group_num
        self.key_value_head_num = self.head_num // group_num

        self.q = torch.nn.Linear(hidden_dim, self.head_num * self.head_dim)
        self.k = torch.nn.Linear(hidden_dim, self.key_value_head_num * self.head_dim)
        self.v = torch.nn.Linear(hidden_dim, self.key_value_head_num * self.head_dim)
        self.o = torch.nn.Linear(hidden_dim, hidden_dim)

        self.softmax = torch.nn.Softmax(dim=-1)
        self.dropout = torch.nn.Dropout(dropout)

    def forward(self, x, attention_mask=None):
        """Apply grouped-query self-attention.

        Args:
            x: Tensor of shape (batch, seq_len, hidden_dim).
            attention_mask: Optional tensor broadcastable to
                (batch, head_num, seq_len, seq_len); non-zero/True entries mark
                key positions a query may attend to. A query row that is fully
                masked produces NaN.

        Returns:
            Tensor of shape (batch, seq_len, hidden_dim).
        """
        bz, seq_len, hidden_dim = x.shape

        # Project and split into heads: (batch, heads, seq_len, head_dim).
        query = self.q(x).view(bz, seq_len, self.head_num, self.head_dim).transpose(1, 2)
        key = self.k(x).view(bz, seq_len, self.key_value_head_num, self.head_dim).transpose(1, 2)
        value = self.v(x).view(bz, seq_len, self.key_value_head_num, self.head_dim).transpose(1, 2)

        # Expand the shared key/value heads so their count equals head_num.
        key = torch.repeat_interleave(key, repeats=self.group_num, dim=1)
        value = torch.repeat_interleave(value, repeats=self.group_num, dim=1)

        score = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if attention_mask is not None:
            score = score.masked_fill(attention_mask.logical_not(), float('-inf'))

        prob = self.dropout(self.softmax(score))

        # Merge heads back: (batch, seq_len, head_num * head_dim) == hidden_dim.
        output = torch.matmul(prob, value).transpose(1, 2).contiguous().view(bz, seq_len, hidden_dim)
        return self.o(output)

# Smoke test: a batch of 12 sequences, 24 tokens each, 768 features; 2 query
# heads sharing a single key/value head (group_num=2). Seeding makes the random
# input, weight init and dropout reproducible on Restart & Run All.
torch.manual_seed(0)
x = torch.randn(12, 24, 768)
attention = multiHeadAttention(768, 2, 2)

output = attention.forward(x)

assert output.shape == (12, 24, 768)
output.shape





# 预处理spider数据

In [53]:
import sys

# Make the project-level `utils` package importable. Guarded so re-running this
# cell does not keep appending the same entry to sys.path.
if '../' not in sys.path:
    sys.path.append('../')
from utils.common_utils import read_json, write_json
from utils.schema_utils import scm_dict2text 
from utils.prompt_utils import nl2sqlte_template, gen_train_prompt

data = read_json('spider/train_spider.json')

table = read_json('spider/tables.json')

# Text layout of one database schema inside a training prompt.
sql_template = """【DB_ID】 {db_id}
【Schema】
{tables_info}
【Foreign keys】
{fks_info}"""


def _format_foreign_keys(schema_dict):
    """Render every foreign key of one tables.json entry as a `t1.c1=t2.c2` line.

    Each item of ``schema_dict['foreign_keys']`` is a pair of indices into
    ``column_names_original``, whose items are ``(table_index, column_name)``.
    """
    columns = schema_dict['column_names_original']
    table_names = schema_dict['table_names_original']
    lines = []
    for left, right in schema_dict['foreign_keys']:
        left_table_idx, left_name = columns[left]
        right_table_idx, right_name = columns[right]
        lines.append(
            f'{table_names[left_table_idx]}.{left_name}='
            f'{table_names[right_table_idx]}.{right_name}\n'
        )
    return ''.join(lines)


def _format_tables(schema_dict):
    """Render each table as a `# Table: name` header plus one `(column:type)` row per column."""
    columns = schema_dict['column_names_original']
    column_types = schema_dict['column_types']
    table_names = schema_dict['table_names_original']

    # Pre-seed every table so a table without columns renders as an empty block
    # instead of raising KeyError.
    columns_by_table = {name: [] for name in table_names}
    for idx, (table_idx, column_name) in enumerate(columns):
        if table_idx == -1:
            # Entries with table index -1 (Spider's '*' wildcard) belong to no table.
            continue
        # column_types is index-aligned with column_names_original (both include
        # the leading wildcard entry). The old `column_types[idx-1]` lookup shifted
        # every type by one column (e.g. it produced `concert_ID:others`).
        columns_by_table[table_names[table_idx]].append((column_name, column_types[idx]))

    rendered = []
    for table_name in table_names:
        rows = ''.join(
            f'  ({name}:{col_type}),\n' for name, col_type in columns_by_table[table_name]
        )
        rendered.append(f'# Table: {table_name}\n[\n{rows}]\n')
    return ''.join(rendered)


def build_schema_text(schema_dict):
    """Format one tables.json entry with `sql_template`."""
    return sql_template.format(
        db_id=schema_dict['db_id'],
        tables_info=_format_tables(schema_dict),
        fks_info=_format_foreign_keys(schema_dict),
    )


# db_id -> rendered schema text.
schema_map = {schema_dict['db_id']: build_schema_text(schema_dict) for schema_dict in table}

# One chat-format training sample per Spider question. The loop uses its own
# name (`train_prompt`) so it does not clobber the `prompt` string defined in the
# generation cell.
result = []
for idx, row in enumerate(data):
    data_item = {
        'question': row['question'],
        'db_schema': schema_map[row['db_id']],
        'sql': row['query'],
        'evidence': '',
    }
    train_prompt = gen_train_prompt(idx, data_item, 'sqlite')
    train_prompt['db_name'] = row['db_id']
    result.append(train_prompt)

write_json('spider/train_spider_chat.json', result)
    

In [35]:
# Sample record captured from an earlier run of the Spider preprocessing cell,
# kept here to eyeball how a rendered schema looks.
# NOTE(review): these column types were generated with the `column_types[idx-1]`
# lookup, which is off by one (note `concert_ID:others`), so they appear shifted
# by one column — regenerate before relying on them.
t = dict(
    question="How many singers do we have?",
    schema="【DB_ID】 concert_singer\n【Schema】\n# Table: stadium\n[\n  (Stadium_ID:text),\n  (Location:number),\n  (Name:text),\n  (Capacity:text),\n  (Highest:number),\n  (Lowest:number),\n  (Average:number),\n]\n# Table: singer\n[\n  (Singer_ID:number),\n  (Name:number),\n  (Country:text),\n  (Song_Name:text),\n  (Song_release_year:text),\n  (Age:text),\n  (Is_male:number),\n]\n# Table: concert\n[\n  (concert_ID:others),\n  (concert_Name:number),\n  (Theme:text),\n  (Stadium_ID:text),\n  (Year:text),\n]\n# Table: singer_in_concert\n[\n  (concert_ID:text),\n  (Singer_ID:number),\n]\n\n【Foreign keys】\nconcert.Stadium_ID=stadium.Stadium_ID\nsinger_in_concert.Singer_ID=singer.Singer_ID\nsinger_in_concert.concert_ID=concert.concert_ID\n",
    sql="SELECT count(*) FROM singer",
    db_id="concert_singer",
)

print(t['schema'])

【DB_ID】 concert_singer
【Schema】
# Table: stadium
[
  (Stadium_ID:text),
  (Location:number),
  (Name:text),
  (Capacity:text),
  (Highest:number),
  (Lowest:number),
  (Average:number),
]
# Table: singer
[
  (Singer_ID:number),
  (Name:number),
  (Country:text),
  (Song_Name:text),
  (Song_release_year:text),
  (Age:text),
  (Is_male:number),
]
# Table: concert
[
  (concert_ID:others),
  (concert_Name:number),
  (Theme:text),
  (Stadium_ID:text),
  (Year:text),
]
# Table: singer_in_concert
[
  (concert_ID:text),
  (Singer_ID:number),
]

【Foreign keys】
concert.Stadium_ID=stadium.Stadium_ID
singer_in_concert.Singer_ID=singer.Singer_ID
singer_in_concert.concert_ID=concert.concert_ID

