In [1]:
from train_bys import dgl_node_and_edge_vectorization,dgl_to_pyg,ImprovementPredictionModelGNN,database
import torch
import pickle
import pandas as pd


# Path of the pickled data loader object that later cells read through `data_loader`.
data_loader_file = '/tmp/tpcds_1_data_loader_gen.pkl'
# Load the DataLoader from disk.
# NOTE(review): pickle.load can execute arbitrary code while deserializing, so this
# file must come from a trusted source.
with open(data_loader_file, 'rb') as f:
    data_loader = pickle.load(f)

print('DataLoader loaded from disk.')

DataLoader loaded from disk.


In [2]:
def predict_query_improvement(sql, index_config, plan, model, attribute_dict, device):
    """
    Featurize a SQL query plus an index configuration and predict the improvement.

    Args:
        sql (str): SQL query text.
        index_config (str): Index configuration. The notebook passes the string form
            of a list, e.g. "['I(C customer_demographics.cd_demo_sk)']".
        plan: Query plan, forwarded unchanged to ``dgl_node_and_edge_vectorization``.
            The notebook passes a value from the ``query_plan_no_index`` column.
        model (nn.Module): Loaded model exposing a ``predict`` method.
        attribute_dict (dict): Attribute dictionary used for vectorization.
        device (torch.device | str): Device on which inference runs.

    Returns:
        float: The predicted improvement value.

    Note:
        As a side effect ``model`` is moved to ``device`` and left in evaluation mode.
    """
    # Featurize the query / index configuration / plan into a graph.
    graph = dgl_node_and_edge_vectorization(sql, index_config, plan, attribute_dict)

    # Convert to PyG data. The label 0 is only a placeholder because this is
    # inference, not training. The inputs are then placed on the target device.
    pyg_data = dgl_to_pyg(graph, 0).to(device)

    # Keep the model on the same device as its inputs. Previously only the data was
    # moved, so any device other than the one the model already lived on would fail
    # with a device-mismatch error inside ``predict``.
    model = model.to(device)

    # Inference only: evaluation mode and no gradient tracking.
    model.eval()
    with torch.no_grad():
        prediction = model.predict(
            pyg_data.x,
            pyg_data.edge_attr,
            pyg_data.edge_index,
            pyg_data.configuration_vector,
            pyg_data.batch,
        )

    return prediction.item()

def load_model_and_attribute_dict(model_path, attribute_dict_path):
    """
    Rebuild the GNN, load its trained weights, and load the attribute dictionary.

    The model's input sizes are read from the first sample of the notebook-level
    ``data_loader``, so the cell that defines ``data_loader`` must run first.

    Args:
        model_path: Path to the saved model ``state_dict``.
        attribute_dict_path: Path to the pickled attribute dictionary.

    Returns:
        tuple: ``(model, attribute_dict)``. The weights are mapped to the CPU and the
        model is returned in evaluation mode.
    """
    # Read the first sample once instead of indexing the dataset for every dimension.
    sample = data_loader.dataset[0]

    # Create a model instance whose layer sizes match the sample's feature widths.
    model = ImprovementPredictionModelGNN(
        sample.x.shape[1],
        sample.edge_attr.shape[1],
        graph_embedding_size=32,
        config_vector_size=sample.configuration_vector.shape[0],
    )

    # Load the state_dict with every tensor mapped to the CPU.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(model_path, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)

    # Switch the model to evaluation mode.
    model.eval()

    # Load the attribute dictionary.
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    with open(attribute_dict_path, 'rb') as f:
        attribute_dict = pickle.load(f)

    return model, attribute_dict

In [3]:
# Paths to the saved model state_dict and the pickled attribute dictionary.
# NOTE(review): both are absolute, machine-specific paths; the notebook will not run
# elsewhere without editing them.
model_path='/home/ubuntu/project/mayang/LOGER/core/infer_model/trained_model.pth'
attr_path='/home/ubuntu/project/mayang/LOGER/data/tpcds_1/attribute_dict.pkl'
# Build the model, load its weights, and load the attribute dictionary.
model,attribute_dict = load_model_and_attribute_dict(model_path,attr_path)

start ImprovementPredictionModelGNN


In [4]:
# Display the loaded attribute dictionary. The recorded output shows
# 'table.column' keys mapped to integer indices.
attribute_dict

{'customer_demographics.cd_demo_sk': 0,
 'customer_address.ca_address_sk': 1,
 'store_sales.ss_net_profit': 2,
 'date_dim.d_year': 3,
 'store.s_store_sk': 4,
 'household_demographics.hd_demo_sk': 5,
 'date_dim.d_date_sk': 6,
 'customer_address.ca_state': 7,
 'web_page.wp_web_page_sk': 8,
 'web_sales.ws_order_number': 9,
 'web_sales.ws_item_sk': 10,
 'store_sales.ss_sales_price': 11,
 'income_band.ib_income_band_sk': 12,
 'income_band.ib_lower_bound': 13,
 'customer_address.ca_city': 14,
 'store_returns.sr_cdemo_sk': 15,
 'date_dim.d_month_seq': 16,
 'item.i_item_sk': 17,
 'customer.c_customer_sk': 18,
 'catalog_sales.cs_sold_date_sk': 19,
 'promotion.p_promo_sk': 20,
 'promotion.p_channel_email': 21,
 'catalog_sales.cs_ship_date_sk': 22,
 'time_dim.t_hour': 23,
 'household_demographics.hd_dep_count': 24,
 'store_sales.ss_store_sk': 25,
 'store_sales.ss_sold_time_sk': 26,
 'store_sales.ss_sold_date_sk': 27,
 'store_sales.ss_item_sk': 28,
 'catalog_returns.cr_returned_date_sk': 29,
 'ite

In [5]:
# Read the generated training CSV and drop its first two columns (`.iloc[:, 2:]`).
# NOTE(review): absolute, machine-specific path.
df=pd.read_csv('/home/ubuntu/project/mayang/Classification/process_data/gene_luo/run_tpcds_1_gen_train.csv').iloc[:,2:]
# Preview the first rows.
df.head()

Unnamed: 0,query,index,c0,c1,t0,t1,error,query_plan_no_index,query_plan_index,improvement_whatif,improvement
0,"\n\n\n\n\nselect avg(ss_quantity)\n ,avg...",['I(C customer_demographics.cd_demo_sk)'],146893.67,104818.78,878.538,691.15,0.073136,"[{'Plan': {'Node Type': 'Aggregate', 'Strategy...","[{'Plan': {'Node Type': 'Aggregate', 'Strategy...",0.286431,0.213295
1,"\n\n\n\n\nselect avg(ss_quantity)\n ,avg...","['I(C customer_demographics.cd_demo_sk)', 'I(C...",146893.67,103012.62,878.538,424.949,-0.217573,"[{'Plan': {'Node Type': 'Aggregate', 'Strategy...","[{'Plan': {'Node Type': 'Aggregate', 'Strategy...",0.298727,0.5163
2,"\n\n\n\n\nselect avg(ss_quantity)\n ,avg...","['I(C customer_demographics.cd_demo_sk)', 'I(C...",146893.67,71299.61,878.538,342.197,-0.095875,"[{'Plan': {'Node Type': 'Aggregate', 'Strategy...","[{'Plan': {'Node Type': 'Aggregate', 'Strategy...",0.514618,0.610493
3,"\n\n\n\n\nselect avg(ss_quantity)\n ,avg...","['I(C customer_demographics.cd_demo_sk)', 'I(C...",146893.67,102903.64,878.538,691.551,0.08663,"[{'Plan': {'Node Type': 'Aggregate', 'Strategy...","[{'Plan': {'Node Type': 'Aggregate', 'Strategy...",0.299469,0.212839
4,"\n\n\n\n\nselect avg(ss_quantity)\n ,avg...","['I(C customer_demographics.cd_demo_sk)', 'I(C...",146893.67,104817.66,878.538,752.692,0.143194,"[{'Plan': {'Node Type': 'Aggregate', 'Strategy...","[{'Plan': {'Node Type': 'Aggregate', 'Strategy...",0.286439,0.143245


In [6]:
# Show the first ten index configurations. Each one is stored as the string form of a list.
df['index'].values[:10]

array(["['I(C customer_demographics.cd_demo_sk)']",
       "['I(C customer_demographics.cd_demo_sk)', 'I(C customer_address.ca_address_sk)']",
       "['I(C customer_demographics.cd_demo_sk)', 'I(C store_sales.ss_net_profit)']",
       "['I(C customer_demographics.cd_demo_sk)', 'I(C date_dim.d_year)']",
       "['I(C customer_demographics.cd_demo_sk)', 'I(C store.s_store_sk)']",
       "['I(C customer_demographics.cd_demo_sk)', 'I(C household_demographics.hd_demo_sk)']",
       "['I(C customer_demographics.cd_demo_sk)', 'I(C customer_address.ca_address_sk)', 'I(C store_sales.ss_net_profit)']",
       "['I(C customer_demographics.cd_demo_sk)', 'I(C store_sales.ss_net_profit)', 'I(C date_dim.d_year)']",
       "['I(C customer_demographics.cd_demo_sk)', 'I(C store_sales.ss_net_profit)', 'I(C store.s_store_sk)']",
       "['I(C customer_demographics.cd_demo_sk)', 'I(C store_sales.ss_net_profit)', 'I(C household_demographics.hd_demo_sk)']"],
      dtype=object)

In [7]:
# Print the SQL text of the first row; print() renders its embedded newlines.
print(df['query'].values[0])






select avg(ss_quantity)
       ,avg(ss_ext_sales_price)
       ,avg(ss_ext_wholesale_cost)
       ,sum(ss_ext_wholesale_cost)
 from store_sales
     ,store
     ,customer_demographics
     ,household_demographics
     ,customer_address
     ,date_dim
 where s_store_sk = ss_store_sk
 and  ss_sold_date_sk = d_date_sk and d_year = 2001
 and((ss_hdemo_sk=hd_demo_sk
  and cd_demo_sk = ss_cdemo_sk
  and cd_marital_status = 'M'
  and cd_education_status = 'Secondary'
  and ss_sales_price between 100.00 and 150.00
  and hd_dep_count = 3   
     )or
     (ss_hdemo_sk=hd_demo_sk
  and cd_demo_sk = ss_cdemo_sk
  and cd_marital_status = 'S'
  and cd_education_status = 'Advanced Degree'
  and ss_sales_price between 50.00 and 100.00   
  and hd_dep_count = 1
     ) or 
     (ss_hdemo_sk=hd_demo_sk
  and cd_demo_sk = ss_cdemo_sk
  and cd_marital_status = 'U'
  and cd_education_status = 'Primary'
  and ss_sales_price between 150.00 and 200.00 
  and hd_dep_count = 1  
     ))
 and((ss_addr_sk = c

In [9]:
# Configure the project's `database` helper before featurizing; the recorded output
# shows a connection and a cursor being created.
# NOTE(review): the credentials are hard-coded here and echoed in the cell output;
# consider reading them from the environment before sharing this notebook.
database.setup(
    dbname='indexselection_tpcds___1',
    user='postgres',
    password='password',
    host='127.0.0.1',
    cache=False,
)

# Row of `df` that supplies the query text and the no-index plan.
idx = 0

# Predict the improvement for a single-index configuration on the CPU.
# NOTE(review): the index configuration is a hard-coded literal, not
# df['index'].values[idx]; the two only agree for idx == 0 — confirm this is intended.
predict_query_improvement(
    df['query'].values[idx],
    "['I(C customer_demographics.cd_demo_sk)']",
    df['query_plan_no_index'].values[idx],
    model,
    attribute_dict,
    'cpu',
)

{'dbname': 'indexselection_tpcds___1', 'user': 'postgres', 'password': 'password', 'host': '127.0.0.1', 'cache': False}
<connection object at 0x7f21b4ff72c0; dsn: 'user=postgres password=xxx dbname=indexselection_tpcds___1 host=127.0.0.1 port=5432', closed: 0>
<cursor object at 0x7f21bada95e0; closed: 0>
error cmp ( ( store_sales.ss_hdemo_sk = household_demographics.hd_demo_sk AND customer_demographics.cd_demo_sk = store_sales.ss_cdemo_sk AND customer_demographics.cd_marital_status = 'M' AND customer_demographics.cd_education_status = 'Secondary' AND store_sales.ss_sales_price BETWEEN 100.0 AND 150.0 AND household_demographics.hd_dep_count = 3) OR ( store_sales.ss_hdemo_sk = household_demographics.hd_demo_sk AND customer_demographics.cd_demo_sk = store_sales.ss_cdemo_sk AND customer_demographics.cd_marital_status = 'S' AND customer_demographics.cd_education_status = 'Advanced Degree' AND store_sales.ss_sales_price BETWEEN 50.0 AND 100.0 AND household_demographics.hd_dep_count = 1) OR (

0.19708728790283203

In [19]:
# Show the index configuration stored for row `idx`.
# NOTE(review): the recorded output is the two-index configuration listed at position 4
# of the In [6] output, not position 0, so `idx` was evidently changed before this cell
# ran (execution count 19). Re-run the notebook top to bottom to confirm.
df['index'].values[idx]

"['I(C customer_demographics.cd_demo_sk)', 'I(C store.s_store_sk)']"