In [8]:
import numpy as np
import os
import torch
import torch.nn as nn
import time
import pandas as pd
from scipy.stats import pearsonr

In [9]:
from model.util import Normalizer
from model.database_util import get_hist_file, get_job_table_sample, collator
from model.model import QueryFormer
from model.database_util import Encoding
from model.dataset import PlanTreeDataset
from model.trainer import eval_workload, train

In [10]:
# Directory holding the IMDB data files; the histogram CSV is read from here.
data_path = './data/imdb/'

In [11]:
class Args:
    """Hyper-parameters and run settings shared by model construction and training.

    The instance is passed to ``train(...)`` and its model-related fields
    (embed_size, ffn_dim, head_size, dropout, n_layers, pred_hid) are passed
    to ``QueryFormer``.
    """
    # Training-loop settings (presumably read by train() -- confirm against model.trainer).
    bs = 1024
    lr = 0.001
    epochs = 200
    clip_size = 50
    sch_decay = 0.6

    # Model architecture settings forwarded to QueryFormer.
    embed_size = 64
    pred_hid = 128
    ffn_dim = 128
    head_size = 12
    n_layers = 8
    dropout = 0.1

    # Prefer the first GPU, but fall back to CPU so the notebook still runs
    # on machines where CUDA is unavailable.
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    # Output directory and prediction target for this run.
    newpath = './results/full/cost/'
    to_predict = 'cost'
args = Args()

# exist_ok=True makes the directory creation idempotent and avoids the
# check-then-create race of a separate os.path.exists() test.
os.makedirs(args.newpath, exist_ok=True)

In [12]:
# Author's note: I expected these statistics to be fetched by connecting to the
# database, but they turn out to be stored in a CSV file (possibly just to make
# them easier to present).
hist_file = get_hist_file(data_path + 'histogram_string.csv')
# NOTE(review): the Normalizer arguments look like (min, max) bounds for the
# cost and cardinality labels -- confirm against model.util.Normalizer.
cost_norm = Normalizer(-3.61192, 12.290855)
card_norm = Normalizer(1,100)

In [13]:
# Encoder part: load the already-trained encoding directly from disk.
# Both checkpoints are mapped to CPU so loading does not depend on a GPU
# being present.
# NOTE(review): torch.load unpickles arbitrary Python objects -- only load
# checkpoints from a trusted source. Newer PyTorch releases may default to
# weights_only=True, which can reject pickled custom objects such as the
# encoding; confirm against the installed version.
encoding_ckpt = torch.load('checkpoints/encoding.pt', map_location='cpu')
encoding = encoding_ckpt['encoding']
# NOTE(review): `checkpoint` is not referenced again in the visible cells --
# confirm whether it is still needed.
checkpoint = torch.load('checkpoints/cost_model.pt', map_location='cpu')

In [20]:
# Show the mapping from join-predicate strings to integer ids held by the
# loaded encoding. (A preceding bare dir(encoding) was dropped: its result
# was discarded because only a cell's last expression is displayed.)
encoding.join2idx

{None: 0,
 'mi_idx.movie_id = t.id': 1,
 'mc.movie_id = t.id': 2,
 'mi.movie_id = t.id': 3,
 'ci.movie_id = t.id': 4,
 'mk.movie_id = t.id': 5,
 'ci.movie_id = mk.movie_id': 6,
 'mi.movie_id = mk.movie_id': 7,
 'mi_idx.movie_id = mk.movie_id': 8,
 'mc.movie_id = mk.movie_id': 9,
 'ci.movie_id = mi_idx.movie_id': 10,
 'ci.movie_id = mc.movie_id': 11,
 'ci.movie_id = mi.movie_id': 12,
 'mi.movie_id = mi_idx.movie_id': 13,
 'mc.movie_id = mi_idx.movie_id': 14,
 'mc.movie_id = mi.movie_id': 15}

In [7]:
# Call seed_everything with its default arguments before the model is built
# (presumably fixes the random seeds for reproducibility -- confirm against
# model.util.seed_everything).
from model.util import seed_everything
seed_everything()

In [8]:
# Instantiate the QueryFormer from the settings on `args`, with both the
# table-sample and histogram inputs switched on.
model = QueryFormer(
    emb_size=args.embed_size,
    ffn_dim=args.ffn_dim,
    head_size=args.head_size,
    dropout=args.dropout,
    n_layers=args.n_layers,
    use_sample=True,
    use_hist=True,
    pred_hid=args.pred_hid,
)

In [9]:
# Move the model to the configured device; binding the result to `_` keeps the
# returned module's repr out of the cell output.
_ = model.to(args.device)

In [10]:
# Reuse the prediction target declared on Args ('cost') instead of repeating
# the literal, so the two values cannot drift apart.
to_predict = args.to_predict

In [11]:
# Load the IMDB plan/cost CSV shards: parts 0-17 form the training frame and
# parts 18-19 the validation frame.
imdb_path = './data/imdb/'


def _load_plan_parts(part_ids):
    """Read the given ``train_plan_part{i}.csv`` shards and stack them into one DataFrame.

    The shards are concatenated once with ``pd.concat`` rather than grown with
    ``DataFrame.append`` in a loop: ``append`` was removed in pandas 2.0 and
    re-copies the accumulated frame on every iteration. Each shard keeps its
    own row index, exactly as the append-based loop did.
    """
    frames = [
        pd.read_csv(imdb_path + 'plan_and_cost/train_plan_part{}.csv'.format(part))
        for part in part_ids
    ]
    return pd.concat(frames)


full_train_df = _load_plan_parts(range(18))
val_df = _load_plan_parts(range(18, 20))
table_sample = get_job_table_sample(imdb_path+'train')

Loaded queries with len  100000
Loaded bitmaps


In [12]:
# Display the concatenated training frame (the notebook renders a truncated preview).
full_train_df

Unnamed: 0,id,json
0,0,"{""Plan"": {""Node Type"": ""Gather"", ""Parallel Awa..."
1,1,"{""Plan"": {""Node Type"": ""Seq Scan"", ""Parallel A..."
2,2,"{""Plan"": {""Node Type"": ""Seq Scan"", ""Parallel A..."
3,3,"{""Plan"": {""Node Type"": ""Gather"", ""Parallel Awa..."
4,4,"{""Plan"": {""Node Type"": ""Bitmap Heap Scan"", ""Pa..."
...,...,...
4995,89995,"{""Plan"": {""Node Type"": ""Nested Loop"", ""Paralle..."
4996,89996,"{""Plan"": {""Node Type"": ""Index Scan"", ""Parallel..."
4997,89997,"{""Plan"": {""Node Type"": ""Gather"", ""Parallel Awa..."
4998,89998,"{""Plan"": {""Node Type"": ""Seq Scan"", ""Parallel A..."


In [12]:
# Wrap the training and validation frames in PlanTreeDataset objects that share
# the same encoding, histogram file, normalizers, prediction target and table
# samples.
# NOTE(review): the meaning of the second positional argument (None) is not
# visible here -- confirm against model.dataset.PlanTreeDataset.
train_ds = PlanTreeDataset(full_train_df, None, encoding, hist_file, card_norm, cost_norm, to_predict, table_sample)
val_ds = PlanTreeDataset(val_df, None, encoding, hist_file, card_norm, cost_norm, to_predict, table_sample)

  'features' : torch.FloatTensor(features),


In [13]:
# Train with a mean-squared-error criterion. train() returns the model together
# with `best_path` (presumably the path of the best checkpoint -- confirm
# against model.trainer.train).
crit = nn.MSELoss()
model, best_path = train(model, train_ds, val_ds, crit, cost_norm, args)

Epoch: 0  Avg Loss: 1.7558196942425435e-05, Time: 34.624067306518555
Median: 2.612529754759801
Mean: 408.87085068470975
Epoch: 20  Avg Loss: 9.952369812203364e-07, Time: 691.9167795181274
Median: 1.1413855878989845
Mean: 1.5584855022427344
Epoch: 40  Avg Loss: 8.243279602740788e-07, Time: 1360.6300423145294
Median: 1.1049955399020734
Mean: 1.4659079084318005
Epoch: 60  Avg Loss: 7.719535037823435e-07, Time: 2071.088265657425
Median: 1.093797157958328
Mean: 1.4326517437497097
Epoch: 80  Avg Loss: 7.282165991556313e-07, Time: 2774.7058358192444
Median: 1.0826038085533964
Mean: 1.4045129295801333
Epoch: 100  Avg Loss: 6.846819868466507e-07, Time: 3473.196921825409
Median: 1.0788684341139938
Mean: 1.3818407950326816
Epoch: 120  Avg Loss: 6.433513170729081e-07, Time: 4168.938290834427
Median: 1.073025687601357
Mean: 1.3610806416361587
Epoch: 140  Avg Loss: 6.067492885954885e-07, Time: 4859.934057474136
Median: 1.070786768600955
Mean: 1.3437374549677246
Epoch: 160  Avg Loss: 5.82302518887445

In [15]:
# Bundle everything eval_workload needs into one dict. Evaluation uses a batch
# size of 512, independent of the training batch size stored on `args`.
methods = dict(
    get_sample=get_job_table_sample,
    encoding=encoding,
    cost_norm=cost_norm,
    hist_file=hist_file,
    model=model,
    device=args.device,
    bs=512,
)

In [16]:
# Evaluate on the 'job-light' workload; the return value is discarded and only
# the metrics printed by eval_workload are kept in the cell output.
_ = eval_workload('job-light', methods)

Loaded queries with len  70
Loaded bitmaps
Median: 1.6015447359157347
Mean: 15.04861380976482
Corr:  0.8955015382416885


In [17]:
# Evaluate on the 'synthetic' workload; the return value is discarded and only
# the metrics printed by eval_workload are kept in the cell output.
_ = eval_workload('synthetic', methods)

Loaded queries with len  5000
Loaded bitmaps
Median: 1.0554397104507522
Mean: 1.7017223965744472
Corr:  0.9835725288032631
