# 0. Import Packages

In [3]:
import pandas as pd
import json
import re
import math
import psycopg2
import torch
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import time
import datetime
from itertools import chain
from random import shuffle
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import DataLoader
from gensim.models import Word2Vec
from gensim.models import KeyedVectors
torch.set_printoptions(precision=10, threshold=None, edgeitems=None, linewidth=None, profile=None)

# 1. Load Data

In [4]:
# load (1) table data from csv files, (2) table column names and (3) index names
def prepare_dataset(prefix_path):
    data={}
    data["customer"] = pd.read_csv(prefix_path+'/customer.csv',header=None,sep='|')
    data["lineitem"] = pd.read_csv(prefix_path+'/lineitem.csv',header=None,sep='|')
    data["nation"] = pd.read_csv(prefix_path+'/nation.csv',header=None,sep='|')
    data["orders"] = pd.read_csv(prefix_path+'/orders.csv',header=None,sep='|')
    data["part"] = pd.read_csv(prefix_path+'/part.csv',header=None,sep='|')
    data["partsupp"] = pd.read_csv(prefix_path+'/partsupp.csv',header=None,sep='|')
    data["region"] = pd.read_csv(prefix_path+'/region.csv',header=None,sep='|')
    data["supplier"] = pd.read_csv(prefix_path+'/supplier.csv',header=None,sep='|')
    customer_column = [
        'c_custkey',
        'c_name',
        'c_address',
        'c_nationkey',
        'c_phone',
        'c_acctbal',
        'c_mktsegment',
        'c_comment'
    ]
    lineitem_column = [
        'l_orderkey',
        'l_partkey',
        'l_suppkey',
        'l_linenumber',
        'l_quantity',
        'l_extendedprice',
        'l_discount',
        'l_tax',
        'l_returnflag',
        'l_linestatus',
        'l_shipdate',
        'l_commitdate',
        'l_receiptdate',
        'l_shipinstruct',
        'l_shipmode',
        'l_comment'
    ]
    nation_column = [
        'n_nationkey',
        'n_name',
        'n_regionkey',
        'n_comment'
    ]
    orders_column = [
        'o_orderkey',
        'o_custkey',
        'o_orderstatus',
        'o_totalprice',
        'o_orderdate',
        'o_orderpriority',
        'o_clerk',
        'o_shippriority',
        'o_comment'
    ]
    part_column = [
        'p_partkey',
        'p_name',
        'p_mfgr',
        'p_brand',
        'p_type',
        'p_size',
        'p_container',
        'p_retailprice',
        'p_comment'
    ]
    partsupp_column = [
        'ps_partkey',
        'ps_suppkey',
        'ps_availqty',
        'ps_supplycost',
        'ps_comment'
    ]
    region_column = [
        'r_regionkey',
        'r_name',
        'r_comment'
    ]
    supplier_column = [
        's_suppkey',
        's_name',
        's_address',
        's_nationkey',
        's_phone',
        's_acctbal',
        's_comment'
    ]
    data["customer"].columns = customer_column
    data["lineitem"].columns = lineitem_column
    data["nation"].columns = nation_column
    data["orders"].columns = orders_column
    data["part"].columns = part_column
    data["partsupp"].columns = partsupp_column
    data["region"].columns = region_column
    data["supplier"].columns = supplier_column

    column2pos = {}
    column2pos["customer"] = customer_column
    column2pos["lineitem"] = lineitem_column
    column2pos["nation"] = nation_column
    column2pos["orders"] = orders_column
    column2pos["part"] = part_column
    column2pos["partsupp"] = partsupp_column
    column2pos["region"] = region_column
    column2pos["supplier"] = supplier_column

    tables = ['customer', 'lineitem', 'nation', 'orders', 'part', 'partsupp', 'region', 'supplier']
    indexes = ['customer_pkey', 'idx_customer_nationkey','customer_c_nationkey_fkey',  # customer
               'lineitem_pkey', 'idx_lineitem_orderkey', 'idx_lineitem_part_supp', 'idx_lineitem_shipdate', 'lineitem_l_orderkey_fkey', 'lineitem_l_partkey_l_suppkey_fkey',  # lineitem
               'nation_pkey', 'idx_nation_regionkey', 'nation_n_regionkey_fkey'  # nation
               'orders_pkey', 'idx_orders_custkey', 'idx_orders_orderdate', 'orders_o_custkey_fkey',  # orders           
               'part_pkey',  # part
               'partsupp_pkey', 'idx_partsupp_partkey', 'idx_partsupp_suppkey', 'partsupp_ps_partkey_fkey', 'partsupp_ps_suppkey_fkey',  # partsupp
               'region_pkey',  # region
               'supplier_pkey', 'idx_supplier_nation_key', 'supplier_s_nationkey_fkey'  # supplier
               ]
    indexes_id = {}
    for idx, index in enumerate(indexes):
        indexes_id[index] = idx + 1
    physic_ops_id = {'Materialize':1, 'Sort':2, 'Hash':3, 'Merge Join':4, 'Bitmap Index Scan':5,
     'Index Only Scan':6, 'BitmapAnd':7, 'Nested Loop':8, 'Aggregate':9, 'Result':10,
     'Hash Join':11, 'Seq Scan':12, 'Bitmap Heap Scan':13, 'Index Scan':14, 'BitmapOr':15, 'Memoize': 16, 'Gather': 17}
    strategy_id = {'Plain':1}
    compare_ops_id = {'=':1, '>':2, '<':3, '!=':4, '~~':5, '!~~':6, '!Null': 7, '>=':8, '<=':9}
    bool_ops_id = {'AND':1,'OR':2}
    tables_id = {}
    columns_id = {}
    columns_list = customer_column + lineitem_column + nation_column + orders_column + part_column + partsupp_column + region_column + supplier_column
    table_id = 1
    column_id = 1
    for table_name in tables:
        tables_id[table_name] = table_id
        table_id += 1
        for column in column2pos[table_name]:
            columns_id[table_name+'.'+column] = column_id
            column_id += 1
    return data, indexes_id, tables_id, columns_id, columns_list,  physic_ops_id, compare_ops_id, bool_ops_id

In [5]:
# load word_vetctors which is a KeyedVectors object
def load_dictionary(path):
    word_vectors = KeyedVectors.load(path, mmap='r')
    return word_vectors

In [6]:
# load min_max_column which is a dictionary
def load_numeric_min_max(path):
    with open(path,'r') as f:
        min_max_column = json.loads(f.read())
    return min_max_column

# 2. Data Processing

In [77]:
# generate (table) prefix based on column name
def determine_prefix(column):
    relation_name = column.split('.')[0]
    column_name = column.split('.')[1]
    if relation_name == 'customer':
        if column_name == 'c_mktsegment':
            return 'mktsegment_'
        else:
            print (column)
            raise
    elif relation_name == 'lineitem':
        if column_name == 'l_returnflag':
            return 'returnflag_'
        elif column_name == 'l_shipmode':
            return 'shipmode_'
        elif column_name == 'l_shipinstruct':
            return 'shipinstruct_'
        else:
            print (column)
            raise
    elif relation_name == 'nation':
        if column_name == 'n_name':
            return 'name_'
    elif relation_name == 'orders':
        if column_name == 'o_orderstatus':
            return 'orderstatus_'
        elif column_name == 'o_orderpriority':
            return 'orderpriority_'
        elif column_name == 'o_comment':
            return 'comment_'
        else:
            print (column)
            raise
    elif relation_name == 'part':
        if column_name == 'p_name':
            return 'name_'
        elif column_name == 'p_brand':
            return 'brand_'
        elif column_name == 'p_type':
            return 'type_'
        elif column_name == 'p_container':
            return 'container_'
        else:
            print (column)
            raise
    elif relation_name == 'region':
        if column_name == 'r_name':
            return 'name_'
        else:
            print (column)
            raise
    elif relation_name == 'supplier':
        if column_name == 's_comment':
            return 'comment_'
    else:
        print (column)
        raise

In [8]:
# generate vector representation (word vector + hash vector) for input value
def get_representation(value):
    if value in word_vectors:
        embedded_result = np.array(list(word_vectors[value]))
    else:
        embedded_result = np.array([0.0 for _ in range(500)])
    hash_result = np.array([0.0 for _ in range(500)])
    for t in value:
        hash_result[hash(t) % 500] = 1.0
    return np.concatenate((embedded_result, hash_result), 0)

In [9]:
# generate vector representation based on get_representation()
def get_str_representation(value, column):
    vec = np.array([])
    count = 0
    prefix = determine_prefix(column)
    for v in value.split('%'):
        if len(v) > 0:
            if len(vec) == 0:
                vec = get_representation(prefix+v)
                count = 1
            else:
                new_vec = get_representation(prefix+v)
                vec = vec + new_vec
                count += 1
    if count > 0:
        vec = vec / float(count)
    return vec

In [90]:
# code to check if the right value is a date
def is_valid_date(date_string, date_format):
    try:
        datetime.strptime(date_string, date_format)
        return True
    except ValueError:
        return False

# mapping from abbreviation to table name
tpch_alias2table = {'l': 'lineitem', 'c': 'customer', 'o': 'orders', 'n1': 'nation', 'n2': 'nation', 's': 'supplier', 'p': 'part', 'ps': 'partsupp', 'r': 'region'}

# generate vector representation for condition operator
def encode_condition_op(condition_op, relation_name, index_name):
    # bool_operator + left_value + compare_operator + right_value
    if condition_op == None:
        vec = [0 for _ in range(condition_op_dim)]
    elif condition_op['op_type'] == 'Bool':
        idx = bool_ops_id[condition_op['operator']]
        vec = [0 for _ in range(bool_ops_total_num)]
        vec[idx-1] = 1
    else:
        operator = condition_op['operator']
        left_value = condition_op['left_value']
        if re.match(r'.+\..+', left_value) == None:
            if relation_name == None:
                relation_name = index_name.split(left_value)[1].strip('_')
            left_value = relation_name + '.' + left_value
        else:
            relation_name = left_value.split('.')[0]
        left_value_idx = columns_id[left_value]
        left_value_vec = [0 for _ in range(column_total_num)]
        left_value_vec[left_value_idx-1] = 1
        right_value = condition_op['right_value']
        print(right_value)
        column_name = left_value.split('.')[1]
        if re.match(r'^[a-z][a-zA-Z0-9_]*\.[a-z][a-zA-Z0-9_]*$', right_value) != None and right_value.split('.')[0] in data:
            operator_idx = compare_ops_id[operator]
            operator_vec = [0 for _ in range(compare_ops_total_num)]
            operator_vec[operator_idx-1] = 1
            right_value_idx = columns_id[right_value]
            right_value_vec = [0]
            left_value_vec[right_value_idx-1] = 1
        # lql: add case when the right value is a column without explicit relation name 
        elif right_value in columns_list:
            operator_idx = compare_ops_id[operator]
            operator_vec = [0 for _ in range(compare_ops_total_num)]
            operator_vec[operator_idx-1] = 1
            right_relation_name = tpch_alias2table[right_value.split('_')[0]]
            right_value_idx = columns_id[right_relation_name + '.' + right_value]
            right_value_vec = [0]
            left_value_vec[right_value_idx-1] = 1
        # lql: add case when right_value is a date
        elif is_valid_date(right_value, '%Y-%m-%d %H:%M:%S') or is_valid_date(right_value, '%Y-%m-%d'):
            if is_valid_date(right_value, '%Y-%m-%d %H:%M:%S'):
                right_value = datetime.strptime(right_value, '%Y-%m-%d %H:%M:%S').date()
            else:
                right_value = datetime.strptime(right_value, '%Y-%m-%d').date()
            min_date = datetime.strptime(min_max_column[relation_name][column_name]['min'], '%Y-%m-%d').date()
            max_date = datetime.strptime(min_max_column[relation_name][column_name]['max'], '%Y-%m-%d').date()
            right_value_vec = [(right_value - min_date).days / float((max_date - min_date).days)]
            operator_idx = compare_ops_id[operator]
            operator_vec = [0 for _ in range(compare_ops_total_num)]
            operator_vec[operator_idx-1] = 1
        elif data[relation_name].dtypes[column_name] == 'int64' or data[relation_name].dtypes[column_name] == 'float64':
            right_value = float(right_value)
            value_max = float(min_max_column[relation_name][column_name]['max'])
            value_min = float(min_max_column[relation_name][column_name]['min'])
            right_value_vec = [(right_value - value_min) / (value_max - value_min)]
            operator_idx = compare_ops_id[operator]
            operator_vec = [0 for _ in range(compare_ops_total_num)]
            operator_vec[operator_idx-1] = 1
        elif re.match(r'^__LIKE__', right_value) != None:
            operator_idx = compare_ops_id['~~']
            operator_vec = [0 for _ in range(compare_ops_total_num)]
            operator_vec[operator_idx-1] = 1
            right_value = right_value.strip('\'')[8:]
            right_value_vec = get_str_representation(right_value, left_value).tolist()
        elif re.match(r'^__NOTLIKE__', right_value) != None:
            operator_idx = compare_ops_id['!~~']
            operator_vec = [0 for _ in range(compare_ops_total_num)]
            operator_vec[operator_idx-1] = 1
            right_value = right_value.strip('\'')[11:]
            right_value_vec = get_str_representation(right_value, left_value).tolist()
        elif re.match(r'^__NOTEQUAL__', right_value) != None:
            operator_idx = compare_ops_id['!=']
            operator_vec = [0 for _ in range(compare_ops_total_num)]
            operator_vec[operator_idx-1] = 1
            right_value = right_value.strip('\'')[12:]
            right_value_vec = get_str_representation(right_value, left_value).tolist()
        elif re.match(r'^__ANY__', right_value) != None:
            operator_idx = compare_ops_id['=']
            operator_vec = [0 for _ in range(compare_ops_total_num)]
            operator_vec[operator_idx-1] = 1
            right_value = right_value.strip('\'')[7:].strip('{}')
            right_value_vec = []
            count = 0
            for v in right_value.split(','):
                v = v.strip('"').strip('\'')
                if len(v) > 0:
                    count += 1
                    vec = get_str_representation(v, left_value).tolist()
                    if len(right_value_vec) == 0:
                        right_value_vec = [0 for _ in vec]
                    for idx, vv in enumerate(vec):
                        right_value_vec[idx] += vv
            for idx in range(len(right_value_vec)):
                right_value_vec[idx] /= len(right_value.split(','))
        elif right_value == 'None':
            operator_idx = compare_ops_id['!Null']
            operator_vec = [0 for _ in range(compare_ops_total_num)]
            operator_vec[operator_idx-1] = 1
            if operator == 'IS':
                right_value_vec = [1]
            elif operator == '!=':
                right_value_vec = [0]
            else:
                print (operator)
                raise
        else:
#             print (left_value, operator, right_value)
            operator_idx = compare_ops_id[operator]
            operator_vec = [0 for _ in range(compare_ops_total_num)]
            operator_vec[operator_idx-1] = 1
            right_value_vec = get_str_representation(right_value, left_value).tolist()
        vec = [0 for _ in range(bool_ops_total_num)]
        vec = vec + left_value_vec + operator_vec + right_value_vec
    num_pad = condition_op_dim - len(vec)
    result = np.pad(vec, (0, num_pad), 'constant')
#     print 'condition op: ', result
    return result

In [11]:
# generate vector representation for condition
def encode_condition(condition, relation_name, index_name, condition_max_num):
    if len(condition) == 0:
        vecs = [[0 for _ in range(condition_op_dim)]]
    else:
        vecs = [encode_condition_op(condition_op, relation_name, index_name) for condition_op in condition]
    num_pad = condition_max_num - len(vecs)
    result = np.pad(vecs, ((0, num_pad),(0,0)), 'constant')
    return result

In [12]:
# generate vector representation for samples
def encode_sample(sample):
    return np.array([int(i) for i in sample])

In [13]:
# bitand operation for samples
def bitand(sample1, sample2):
    return np.minimum(sample1, sample2)

In [93]:
# generate all vector representations for a node
def encode_node_job(node, condition_max_num):
    # operator + first_condition + second_condition + relation
    extra_info_num = max(column_total_num, table_total_num, index_total_num)
    operator_vec = np.array([0 for _ in range(physic_op_total_num)])

    extra_info_vec = np.array([0 for _ in range(extra_info_num)])
    condition1_vec = np.array([[0 for _ in range(condition_op_dim)] for _ in range(condition_max_num)])
    condition2_vec = np.array([[0 for _ in range(condition_op_dim)] for _ in range(condition_max_num)])
    ### Samples Starts
    sample_vec = np.array([1 for _ in range(1000)])
    ### Samples Ends
    has_condition = 0
    if node != None:
        operator = node['node_type']
        operator_idx = physic_ops_id[operator]
        operator_vec[operator_idx-1] = 1
        if operator == 'Materialize' or operator == 'BitmapAnd' or operator == 'Result':
            pass
        elif operator == 'Sort':
            for key in node['sort_keys']:
                # lql: add a check statement
                if key in columns_id:
                    extra_info_inx = columns_id[key]
                    extra_info_vec[extra_info_inx-1] = 1
        elif operator == 'Hash Join' or operator == 'Merge Join' or operator == 'Nested Loop':
            condition1_vec = encode_condition(node['condition'], None, None, condition_max_num)
        elif operator == 'Aggregate':
            for key in node['group_keys']:
                # lql: add a check statement
                if key in columns_id:
                    extra_info_inx = columns_id[key]
                    extra_info_vec[extra_info_inx-1] = 1
        elif operator == 'Seq Scan' or operator == 'Bitmap Heap Scan' or operator == 'Index Scan' or operator == 'Bitmap Index Scan' or operator == 'Index Only Scan':
            relation_name = node['relation_name']
            index_name = node['index_name']
            if relation_name != None:
                extra_info_inx = tables_id[relation_name]
            else:
                extra_info_inx = indexes_id[index_name]
            extra_info_vec[extra_info_inx-1] = 1
            condition1_vec = encode_condition(node['condition_filter'], relation_name, index_name, condition_max_num)
            condition2_vec = encode_condition(node['condition_index'], relation_name, index_name, condition_max_num)
            if 'bitmap' in node:
                ### Samples Starts
                sample_vec = encode_sample(node['bitmap'])
                ### Samples Ends
                has_condition = 1
            if 'bitmap_filter' in node:
                ### Samples Starts
                sample_vec = bitand(encode_sample(node['bitmap_filter']), sample_vec)
                ### Samples Ends
                has_condition = 1
            if 'bitmap_index' in node:
                ### Samples Starts
                sample_vec = bitand(encode_sample(node['bitmap_index']), sample_vec)
                ### Samples Ends
                has_condition = 1

#     print 'operator: ', operator_vec
#     print 'extra_infos: ', extra_info_vec
    return operator_vec, extra_info_vec, condition1_vec, condition2_vec, sample_vec, has_condition

# 3. Util Functions related to Tree Structure

In [15]:
# define TreeNode class
class TreeNode(object):
    def __init__(self, current_vec, parent, idx, level_id):
        self.item = current_vec
        self.idx = idx
        self.level_id = level_id
        self.parent = parent
        self.children = []
    def get_parent(self):
        return self.parent
    def get_item(self):
        return self.item
    def get_children(self):
        return self.children
    def add_child(self, child):
        self.children.append(child)
    def get_idx(self):
        return self.idx
    def __str__(self):
        return 'level_id: ' + self.level_id + '; idx: ' + self.idx

In [16]:
# generate tree (consisted of TreeNode) from vector representation
def recover_tree(vecs, parent, start_idx):
    if len(vecs) == 0:
        return vecs, start_idx
    if vecs[0] == None:
        return vecs[1:], start_idx+1
    node = TreeNode(vecs[0], parent, start_idx, -1)
    while True:
        vecs, start_idx = recover_tree(vecs[1:], node, start_idx+1)
        parent.add_child(node)
        if len(vecs) == 0:
            return vecs, start_idx
        if vecs[0] == None:
            return vecs[1:], start_idx+1
        node = TreeNode(vecs[0], parent, start_idx, -1)

In [17]:
# level limited DFS, append nodes to nodes_by_level
def dfs_tree_to_level(root, level_id, nodes_by_level):
    root.level_id = level_id
    if len(nodes_by_level) <= level_id:
        nodes_by_level.append([])
    nodes_by_level[level_id].append(root)
    root.idx = len(nodes_by_level[level_id])
    for c in root.get_children():
        dfs_tree_to_level(c, level_id+1, nodes_by_level)

In [18]:
# debug function to test level limited DFS
def debug_nodes_by_level(nodes_by_level):
    for nodes in nodes_by_level:
        for node in nodes:
            whitespace = ''
            for i in range(node.level_id):
                whitespace += ' '
            print (whitespace + 'level_id: ' + str(node.level_id))
            print (whitespace + 'idx: ' + str(node.idx))

# 4. Util Functions related to Plan

In [19]:
# generate corresponding tree structure for an input plan
def encode_plan_job(plan, condition_max_num):
    operators, extra_infos, condition1s, condition2s, samples, condition_masks = [], [], [], [], [], []
    mapping = []

    nodes_by_level = []
    node = TreeNode(plan[0], None, 0, -1)
    recover_tree(plan[1:], node, 1)
    dfs_tree_to_level(node, 0, nodes_by_level)

#     print (plan)
#     debug_nodes_by_level(nodes_by_level)


    for level in nodes_by_level:
        operators.append([])
        extra_infos.append([])
        condition1s.append([])
        condition2s.append([])
        samples.append([])
        condition_masks.append([])
        mapping.append([])
        for node in level:
            operator, extra_info, condition1, condition2, sample, condition_mask = encode_node_job(node.item, condition_max_num)
            operators[-1].append(operator)
            extra_infos[-1].append(extra_info)
            condition1s[-1].append(condition1)
            condition2s[-1].append(condition2)
            samples[-1].append(sample)
            condition_masks[-1].append(condition_mask)
            if len(node.children) == 2:
                mapping[-1].append([n.idx for n in node.children])
            elif len(node.children) == 1:
                mapping[-1].append([node.children[0].idx, 0])
            else:
                mapping[-1].append([0, 0])
#     num_pad = plan_node_max_num - len(operators)
#     masks = [0 for _ in range(plan_node_max_num)]
#     for i in range(len(operators)):
#         if operators[i].sum() > 0:
#             masks[i] = 1
#         else:
#             masks[i] = 0
#     masks = np.array(masks)
#     condition_masks = np.array(condition_masks)
#     operators, extra_infos, condition1s, condition2s = np.pad(operators, ((0, num_pad), (0,0)), 'constant'), np.pad(extra_infos, ((0, num_pad), (0,0)), 'constant'),np.pad(condition1s, ((0, num_pad), (0,0), (0,0)), 'constant'),np.pad(condition2s, ((0, num_pad), (0,0), (0,0)), 'constant')
#     samples = np.pad(samples, ((0, num_pad), (0,0)), 'constant')
#     condition_masks = np.pad(condition_masks, (0, num_pad), 'constant')
    return operators, extra_infos, condition1s, condition2s, samples, condition_masks, mapping

In [20]:
# normalize label
def normalize_label(labels, mini, maxi):
    labels_norm = (torch.log(labels) - mini) / (maxi - mini)
    labels_norm = torch.min(labels_norm, torch.ones_like(labels_norm))
    labels_norm = torch.max(labels_norm, torch.zeros_like(labels_norm))
    return labels_norm

In [21]:
# unnormalize values
def unnormalize(vecs, mini, maxi):
    return torch.exp(vecs * (maxi - mini) + mini)

In [22]:
# generate plan_node_max_num, condition_max_num, cost_label_min, cost_label_max, card_label_min, card_label_max
def obtain_upper_bound_query_size(path):
    plan_node_max_num = 0
    condition_max_num = 0
    cost_label_max = 0.0
    cost_label_min = 9999999999.0
    card_label_max = 0.0
    card_label_min = 9999999999.0
    plans = []
    with open(path, 'r') as f:
        for plan in f.readlines():
            plan = json.loads(plan)
            plans.append(plan)
            cost = plan['cost']
            cardinality = plan['cardinality']
            if cost > cost_label_max:
                cost_label_max = cost
            elif cost < cost_label_min:
                cost_label_min = cost
            if cardinality > card_label_max:
                card_label_max = cardinality
            elif cardinality < card_label_min:
                card_label_min = cardinality
            sequence = plan['seq']
            plan_node_num = len(sequence)
            if plan_node_num > plan_node_max_num:
                plan_node_max_num = plan_node_num
            for node in sequence:
                if node == None:
                    continue
                if 'condition_filter' in node:
                    condition_num = len(node['condition_filter'])
                    if condition_num > condition_max_num:
                        condition_max_num = condition_num
                if 'condition_index' in node:
                    condition_num = len(node['condition_index'])
                    if condition_num > condition_max_num:
                        condition_max_num = condition_num
    cost_label_min, cost_label_max = math.log(cost_label_min), math.log(cost_label_max)
    # lql: adpat to the min card is 0
    if card_label_min <= 0:
        card_label_min = 1
    card_label_min, card_label_max = math.log(card_label_min), math.log(card_label_max)
    
    
    print('plan_node_max_num: ', plan_node_max_num)
    print('condition_max_num: ', condition_max_num)
    print('cost_label_min: ', cost_label_min)
    print('cost_label_max: ', cost_label_max)
    print('card_label_min: ', card_label_min)
    print('card_label_max: ', card_label_max)
    return plan_node_max_num, condition_max_num, cost_label_min, cost_label_max, card_label_min, card_label_max

In [23]:
# merge plan level 2 into level1?
def merge_plans_level(level1, level2, isMapping=False):
    for idx, level in enumerate(level2):
        if idx >= len(level1):
            level1.append([])
        if isMapping:
            if idx < len(level1) - 1:
                base = len(level1[idx+1])
                for i in range(len(level)):
                    if level[i][0] > 0:
                        level[i][0] += base
                    if level[i][1] > 0:
                        level[i][1] += base
        level1[idx] += level
    return level1

In [73]:
# generate training data from plans loaded from json file
def make_data_job(plans):
    target_cost_batch = []
    target_card_batch = []
    operators_batch = []
    extra_infos_batch = []
    condition1s_batch = []
    condition2s_batch = []
    node_masks_batch = []
    samples_batch = []
    condition_masks_batch = []
    mapping_batch = []

    # for plan in plans:
    # lql: print idx for debug
    for i in range(len(plans)):
        print('idx: ', i)
        plan = plans[i]
        target_cost = plan['cost']
        target_cardinality = plan['cardinality']
        target_cost_batch.append(target_cost)
        target_card_batch.append(target_cardinality)
        plan = plan['seq']
        operators, extra_infos, condition1s, condition2s, samples, condition_masks, mapping = encode_plan_job(plan, condition_max_num)

        operators_batch = merge_plans_level(operators_batch, operators)
        extra_infos_batch = merge_plans_level(extra_infos_batch, extra_infos)
        condition1s_batch = merge_plans_level(condition1s_batch, condition1s)
        condition2s_batch = merge_plans_level(condition2s_batch, condition2s)
        samples_batch = merge_plans_level(samples_batch, samples)
        condition_masks_batch = merge_plans_level(condition_masks_batch, condition_masks)
        mapping_batch = merge_plans_level(mapping_batch, mapping, True)
    max_nodes = 0
    for o in operators_batch:
        if len(o) > max_nodes:
            max_nodes = len(o)
    print (max_nodes)
    print (len(condition2s_batch))
    operators_batch = np.array([np.pad(v, ((0, max_nodes - len(v)),(0,0)), 'constant') for v in operators_batch])
    extra_infos_batch = np.array([np.pad(v, ((0, max_nodes - len(v)),(0,0)), 'constant') for v in extra_infos_batch])
    condition1s_batch = np.array([np.pad(v, ((0, max_nodes - len(v)),(0,0),(0,0)), 'constant') for v in condition1s_batch])
    condition2s_batch = np.array([np.pad(v, ((0, max_nodes - len(v)),(0,0),(0,0)), 'constant') for v in condition2s_batch])
    samples_batch = np.array([np.pad(v, ((0, max_nodes - len(v)),(0,0)), 'constant') for v in samples_batch])
    condition_masks_batch = np.array([np.pad(v, (0, max_nodes - len(v)), 'constant') for v in condition_masks_batch])
    mapping_batch = np.array([np.pad(v, ((0, max_nodes - len(v)),(0,0)), 'constant') for v in mapping_batch])

    print ('operators_batch: ', operators_batch.shape)

    target_cost_batch = torch.FloatTensor(target_cost_batch)
    target_card_batch = torch.FloatTensor(target_card_batch)
    operators_batch = torch.FloatTensor([operators_batch])
    extra_infos_batch = torch.FloatTensor([extra_infos_batch])
    condition1s_batch = torch.FloatTensor([condition1s_batch])
    condition2s_batch = torch.FloatTensor([condition2s_batch])
    samples_batch = torch.FloatTensor([samples_batch])
    condition_masks_batch = torch.FloatTensor([condition_masks_batch])
    mapping_batch = torch.FloatTensor([mapping_batch])

    target_cost_batch = normalize_label(target_cost_batch, cost_label_min, cost_label_max)
    target_card_batch = normalize_label(target_card_batch, card_label_min, card_label_max)

    return (target_cost_batch, target_card_batch, operators_batch, extra_infos_batch, condition1s_batch, condition2s_batch, samples_batch, condition_masks_batch, mapping_batch)

In [25]:
# split data into chunks
def chunks(arr, batch_size):
    return [arr[i:i+batch_size] for i in range(0, len(arr), batch_size)]

In [26]:
# save training data into npy files
def save_data_job(plans, istest = False, batch_size=64, directory='/home/sunji/learnedcardinality/job'):
    if istest:
        suffix = 'test_'
    else:
        suffix = ''
    batch_id = 0
    for batch_id, plans_batch in enumerate(chunks(plans, batch_size)):
        print ('batch_id', batch_id, len(plans_batch))
        target_cost_batch, target_cardinality_batch, operators_batch, extra_infos_batch, condition1s_batch, condition2s_batch, samples_batch, condition_masks_batch, mapping_batch = make_data_job(plans_batch)
        np.save(directory+'/target_cost_'+suffix+str(batch_id)+'.np', target_cost_batch.numpy())
        np.save(directory+'/target_cardinality_'+suffix+str(batch_id)+'.np', target_cardinality_batch.numpy())
        np.save(directory+'/operators_'+suffix+str(batch_id)+'.np', operators_batch.numpy())
        np.save(directory+'/extra_infos_'+suffix+str(batch_id)+'.np', extra_infos_batch.numpy())
        np.save(directory+'/condition1s_'+suffix+str(batch_id)+'.np', condition1s_batch.numpy())
        np.save(directory+'/condition2s_'+suffix+str(batch_id)+'.np', condition2s_batch.numpy())
        np.save(directory+'/samples_'+suffix+str(batch_id)+'.np', samples_batch.numpy())
        np.save(directory+'/condition_masks_'+suffix+str(batch_id)+'.np', condition_masks_batch.numpy())
        np.save(directory+'/mapping_'+suffix+str(batch_id)+'.np', mapping_batch.numpy())
        print ('saved: ', str(batch_id))

In [27]:
# load training data from npy files
def get_batch_job(batch_id, istest=False, directory='tlstm_files/job'):
    if istest:
        suffix = 'test_'
    else:
        suffix = ''
    target_cost_batch = np.load(directory+'/target_cost_'+suffix+str(batch_id)+'.np.npy')
    target_cardinality_batch = np.load(directory+'/target_cardinality_'+suffix+str(batch_id)+'.np.npy')
    operators_batch = np.load(directory+'/operators_'+suffix+str(batch_id)+'.np.npy')
    extra_infos_batch = np.load(directory+'/extra_infos_'+suffix+str(batch_id)+'.np.npy')
    condition1s_batch = np.load(directory+'/condition1s_'+suffix+str(batch_id)+'.np.npy')
    condition2s_batch = np.load(directory+'/condition2s_'+suffix+str(batch_id)+'.np.npy')
    samples_batch = np.load(directory+'/samples_'+suffix+str(batch_id)+'.np.npy')
    condition_masks_batch = np.load(directory+'/condition_masks_'+suffix+str(batch_id)+'.np.npy')
    mapping_batch = np.load(directory+'/mapping_'+suffix+str(batch_id)+'.np.npy')
    return target_cost_batch, target_cardinality_batch, operators_batch, extra_infos_batch, condition1s_batch, condition2s_batch, samples_batch, condition_masks_batch, mapping_batch

In [96]:
# store train data from json file in path into npy files
def encode_train_plan_seq_save(path, batch_size=64, directory='tlstm_files/job'):
    train_plans = []
    with open(path, 'r') as f:
        for idx, seq in enumerate(f.readlines()):
            # if idx < 2500:
            #     continue
            plan = json.loads(seq)
            train_plans.append(plan)
#     shuffle(test_plans)
    save_data_job(plans=train_plans, batch_size=batch_size, directory=directory)
# store test data from json file in path into npy files
def encode_test_plan_seq_save(path, batch_size=64, directory='tlstm_files/job'):
    test_plans = []
    with open(path, 'r') as f:
        for idx, seq in enumerate(f.readlines()):
            plan = json.loads(seq)
            test_plans.append(plan)
#     shuffle(test_plans)
    save_data_job(plans=test_plans, istest=True, batch_size=batch_size, directory=directory)

# 5. Util Functions related to Model

In [109]:
class D Representation(nn.Module):
    def __init__(self, input_dim, hidden_dim, hid_dim, middle_result_dim, task_num):
        super(Representation, self).__init__()
        self.hidden_dim = hidden_dim
        self.lstm1 = nn.LSTM(input_dim, hidden_dim, batch_first=True)
        self.batch_norm1 = nn.BatchNorm1d(hid_dim)
        # The linear layer that maps from hidden state space to tag space
        
        self.sample_mlp = nn.Linear(1000, hid_dim)
        self.condition_mlp = nn.Linear(hidden_dim, hid_dim)
        
        # 17 is the number of operators, 61 is the number of extra infos
        self.lstm2 = nn.LSTM(17+61+2*hid_dim, hidden_dim, batch_first=True)
        self.batch_norm2 = nn.BatchNorm1d(hidden_dim)
        # The linear layer that maps from hidden state space to tag space
        self.hid_mlp2_task1 = nn.Linear(hidden_dim, hid_dim)
        self.hid_mlp2_task2 = nn.Linear(hidden_dim, hid_dim)
        self.batch_norm3 = nn.BatchNorm1d(hid_dim)
        self.hid_mlp3_task1 = nn.Linear(hid_dim, hid_dim)
        self.hid_mlp3_task2 = nn.Linear(hid_dim, hid_dim)
        self.out_mlp2_task1 = nn.Linear(hid_dim, 1)
        self.out_mlp2_task2 = nn.Linear(hid_dim, 1)

    def init_hidden(self, hidden_dim, batch_size=1):
        # Before we've done anything, we dont have any hidden state.
        # Refer to the Pytorch documentation to see exactly
        # why they have this dimensionality.
        # The axes semantics are (num_layers, minibatch_size, hidden_dim)
        return (torch.zeros(1, batch_size, hidden_dim),
                torch.zeros(1, batch_size, hidden_dim))
    
    def forward(self, operators, extra_infos, condition1s, condition2s, samples, condition_masks, mapping):
        # condition1
        batch_size = 0
        for i in range(operators.size()[1]):
            if operators[0][i].sum(0) != 0:
                batch_size += 1
            else:
                break
        print ('batch_size: ', batch_size)
        
        # print (operators.size())
        # print (extra_infos.size())
        # print (condition1s.size())
        # print (condition2s.size())
        # print (samples.size())
        # print (condition_masks.size())
        # print (mapping.size())
        
        # torch.Size([5, 64, 17])
        # torch.Size([5, 64, 61])
        # torch.Size([5, 64, 80, 1072])
        # torch.Size([5, 64, 80, 1072])
        # torch.Size([5, 64, 1000])
        # torch.Size([5, 64, 1])
        # torch.Size([5, 64, 2])     
          
        num_level = condition1s.size()[0]
        num_node_per_level = condition1s.size()[1]
        num_condition_per_node = condition1s.size()[2]
        condition_op_length = condition1s.size()[3]
        
        inputs = condition1s.view(num_level * num_node_per_level, num_condition_per_node, condition_op_length)
        hidden = self.init_hidden(self.hidden_dim, num_level * num_node_per_level)
        
        out, hid = self.lstm1(inputs, hidden)
        last_output1 = hid[0].view(num_level * num_node_per_level, -1)
        
        # condition2
        num_level = condition2s.size()[0]
        num_node_per_level = condition2s.size()[1]
        num_condition_per_node = condition2s.size()[2]
        condition_op_length = condition2s.size()[3]
        
        inputs = condition2s.view(num_level * num_node_per_level, num_condition_per_node, condition_op_length)
        hidden = self.init_hidden(self.hidden_dim, num_level * num_node_per_level)
        
        out, hid = self.lstm1(inputs, hidden)
        last_output2 = hid[0].view(num_level * num_node_per_level, -1)
        
        last_output1 = F.relu(self.condition_mlp(last_output1))
        last_output2 = F.relu(self.condition_mlp(last_output2))
        last_output = (last_output1 + last_output2) / 2
        last_output = self.batch_norm1(last_output).view(num_level, num_node_per_level, -1)
        
#         print (last_output.size())
#         torch.Size([14, 133, 256])
        
        sample_output = F.relu(self.sample_mlp(samples))
        sample_output = sample_output * condition_masks

        out = torch.cat((operators, extra_infos, last_output, sample_output), 2)
#         print (out.size())
#         torch.Size([14, 133, 635])
#         out = out * node_masks
        start = time.time()
        hidden = self.init_hidden(self.hidden_dim, num_node_per_level)
        last_level = out[num_level-1].view(num_node_per_level, 1, -1)
#         torch.Size([133, 1, 635])
        _, (hid, cid) = self.lstm2(last_level, hidden)
        mapping = mapping.long()
        for idx in reversed(range(0, num_level-1)):
            mapp_left = mapping[idx][:,0]
            mapp_right = mapping[idx][:,1]
            pad = torch.zeros_like(hid)[:,0].unsqueeze(1)
            next_hid = torch.cat((pad, hid), 1)
            pad = torch.zeros_like(cid)[:,0].unsqueeze(1)
            next_cid = torch.cat((pad, cid), 1)
            hid_left = torch.index_select(next_hid, 1, mapp_left)
            cid_left = torch.index_select(next_cid, 1, mapp_left)
            hid_right = torch.index_select(next_hid, 1, mapp_right)
            cid_right = torch.index_select(next_cid, 1, mapp_right)
            hid = (hid_left + hid_right) / 2
            cid = (cid_left + cid_right) / 2
            last_level = out[idx].view(num_node_per_level, 1, -1)
            _, (hid, cid) = self.lstm2(last_level, (hid, cid))
        output = hid[0]
#         print (output.size())
#         torch.Size([133, 128])
        end = time.time()
        print ('Forest Evaluate Running Time: ', end - start)
        last_output = output[0:batch_size]
        out = self.batch_norm2(last_output)
        
        out_task1 = F.relu(self.hid_mlp2_task1(out))
        out_task1 = self.batch_norm3(out_task1)
        out_task1 = F.relu(self.hid_mlp3_task1(out_task1))
        out_task1 = self.out_mlp2_task1(out_task1)
        out_task1 = F.sigmoid(out_task1)
        
        out_task2 = F.relu(self.hid_mlp2_task2(out))
        out_task2 = self.batch_norm3(out_task2)
        out_task2 = F.relu(self.hid_mlp3_task2(out_task2))
        out_task2 = self.out_mlp2_task2(out_task2)
        out_task2 = F.sigmoid(out_task2)
#         print 'out: ', out.size()
        # batch_size * task_num
        return out_task1, out_task2

In [30]:
# implement q-error loss function
def qerror_loss(preds, targets, mini, maxi):
    qerror = []
    preds = unnormalize(preds, mini, maxi)
    targets = unnormalize(targets, mini, maxi)
    for i in range(len(targets)):
        if (preds[i] > targets[i]).cpu().data.numpy()[0]:
            qerror.append(preds[i]/targets[i])
        else:
            qerror.append(targets[i]/preds[i])
    return torch.mean(torch.cat(qerror)), torch.median(torch.cat(qerror)), torch.max(torch.cat(qerror)), torch.argmax(torch.cat(qerror))

In [105]:
# train function
def train(train_start, train_end, validate_start, validate_end, num_epochs, data_dir='tlstm_files/job'):
    input_dim = condition_op_dim #  1072
    hidden_dim = 128
    hid_dim = 256
    middle_result_dim = 128
    task_num = 2
    model = Representation(input_dim, hidden_dim, hid_dim, middle_result_dim, task_num)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    model.train()
    start = time.time()
    for epoch in range(num_epochs):
        cost_loss_total = 0.
        card_loss_total = 0.
        model.train()
        for batch_idx in range(train_start, train_end):
            print ('batch_idx: ', batch_idx)
            target_cost, target_cardinality, operatorss, extra_infoss, condition1ss, condition2ss, sampless, condition_maskss, mapping = get_batch_job(batch_idx, directory=data_dir)
            target_cost, target_cardinality, operatorss, extra_infoss, condition1ss, condition2ss, sampless, condition_maskss, mapping = torch.FloatTensor(target_cost), torch.FloatTensor(target_cardinality),torch.FloatTensor(operatorss),torch.FloatTensor(extra_infoss),torch.FloatTensor(condition1ss),torch.FloatTensor(condition2ss), torch.FloatTensor(sampless), torch.FloatTensor(condition_maskss), torch.FloatTensor(mapping)
            operatorss, extra_infoss, condition1ss, condition2ss, condition_maskss = operatorss.squeeze(0), extra_infoss.squeeze(0), condition1ss.squeeze(0), condition2ss.squeeze(0), condition_maskss.squeeze(0).unsqueeze(2)
            sampless = sampless.squeeze(0)
            mapping = mapping.squeeze(0)
            target_cost, target_cardinality, operatorss, extra_infoss, condition1ss, condition2ss = Variable(target_cost), Variable(target_cardinality), Variable(operatorss), Variable(extra_infoss), Variable(condition1ss), Variable(condition2ss)
            sampless = Variable(sampless)
            optimizer.zero_grad()
            estimate_cost,estimate_cardinality = model(operatorss, extra_infoss, condition1ss, condition2ss, sampless, condition_maskss, mapping)
            target_cost = target_cost
            target_cardinality = target_cardinality
            cost_loss,cost_loss_median,cost_loss_max,cost_max_idx = qerror_loss(estimate_cost, target_cost, cost_label_min, cost_label_max)
            card_loss,card_loss_median,card_loss_max,card_max_idx = qerror_loss(estimate_cardinality, target_cardinality, card_label_min, card_label_max)
            print (card_loss.item(),card_loss_median.item(),card_loss_max.item(),card_max_idx.item())
            loss = cost_loss + card_loss
            cost_loss_total += cost_loss.item()
            card_loss_total += card_loss.item()
            start = time.time()
            loss.backward()
            optimizer.step()
            end = time.time()
            print ('batchward time: ',end - start)
        batch_num = train_end - train_start
        print("Epoch {}, training cost loss: {}, training card loss: {}".format(epoch, cost_loss_total/batch_num, card_loss_total/batch_num))

        cost_loss_total = 0.
        card_loss_total = 0.
        for batch_idx in range(validate_start, validate_end):
            print ('batch_idx: ', batch_idx)
            target_cost, target_cardinality, operatorss, extra_infoss, condition1ss, condition2ss, sampless, condition_maskss, mapping = get_batch_job(batch_idx, directory=data_dir)
            target_cost, target_cardinality, operatorss, extra_infoss, condition1ss, condition2ss, sampless, condition_maskss, mapping = torch.FloatTensor(target_cost), torch.FloatTensor(target_cardinality),torch.FloatTensor(operatorss),torch.FloatTensor(extra_infoss),torch.FloatTensor(condition1ss),torch.FloatTensor(condition2ss), torch.FloatTensor(sampless), torch.FloatTensor(condition_maskss), torch.FloatTensor(mapping)
            operatorss, extra_infoss, condition1ss, condition2ss, condition_maskss = operatorss.squeeze(0), extra_infoss.squeeze(0), condition1ss.squeeze(0), condition2ss.squeeze(0), condition_maskss.squeeze(0).unsqueeze(2)
            sampless = sampless.squeeze(0)
            mapping = mapping.squeeze(0)
            target_cost, target_cardinality, operatorss, extra_infoss, condition1ss, condition2ss = Variable(target_cost), Variable(target_cardinality), Variable(operatorss), Variable(extra_infoss), Variable(condition1ss), Variable(condition2ss)
            sampless = Variable(sampless)
            estimate_cost,estimate_cardinality = model(operatorss, extra_infoss, condition1ss, condition2ss, sampless, condition_maskss, mapping)
            target_cost = target_cost
            target_cardinality = target_cardinality
            cost_loss,cost_loss_median,cost_loss_max,cost_max_idx = qerror_loss(estimate_cost, target_cost, cost_label_min, cost_label_max)
            card_loss,card_loss_median,card_loss_max,card_max_idx = qerror_loss(estimate_cardinality, target_cardinality, card_label_min, card_label_max)
            print (card_loss.item(),card_loss_median.item(),card_loss_max.item(),card_max_idx.item())
            loss = cost_loss + card_loss
            cost_loss_total += cost_loss.item()
            card_loss_total += card_loss.item()
        batch_num = validate_end - validate_start
        print("Epoch {}, validating cost loss: {}, validating card loss: {}".format(epoch, cost_loss_total/batch_num, card_loss_total/batch_num))
    end = time.time()
    print (end-start)
    return model

# Main Functions

In [32]:
t = time.time()
data, indexes_id, tables_id, columns_id, columns_list, physic_ops_id, compare_ops_id, bool_ops_id = prepare_dataset('tlstm_files/TPCH_10/tpch10_data_csv')
print ('data prepared, use {} ms', time.time() - t)

data prepared, use {} ms 161.4234824180603


In [33]:
t = time.time()
word_vectors = load_dictionary('tlstm_files/Initial_Dataset/wordvectors_updated.kv')
print ('word_vectors loaded, use {} ms', time.time() - t)

word_vectors loaded, use {} ms 43.678524017333984


In [34]:
t = time.time()
min_max_column = load_numeric_min_max('tlstm_files/TPCH_10/tpch_10_min_max_vals.json')
print ('min_max loaded, use {} ms', time.time() - t)

min_max loaded, use {} ms 0.005000591278076172


In [55]:
index_total_num = len(indexes_id)
table_total_num = len(tables_id)
column_total_num = len(columns_id)
physic_op_total_num = len(physic_ops_id)
compare_ops_total_num = len(compare_ops_id)
bool_ops_total_num = len(bool_ops_id)
condition_op_dim = bool_ops_total_num + compare_ops_total_num+column_total_num+1000
condition_op_dim_pro = bool_ops_total_num + column_total_num + 3
t = time.time()
plan_node_max_num, condition_max_num, cost_label_min, cost_label_max, card_label_min, card_label_max = obtain_upper_bound_query_size('tlstm_files/TPCH_10/plans_seq.json')
print ('plan loaded and obtain query upper size prepared, use {} ms', time.time() - t)

plan_node_max_num:  44
condition_max_num:  80
cost_label_min:  5.676514047502008
cost_label_max:  11.65474849922678
card_label_min:  0.0
card_label_max:  12.853563247110298
plan loaded and obtain query upper size prepared, use {} ms 0.3966364860534668


In [97]:
t = time.time()
encode_train_plan_seq_save('tlstm_files/TPCH_10/plans_seq.json', batch_size=64, directory='tlstm_files/TPCH_10/encoded_data_no_sample')
print ('data encoded, use {} ms', time.time() - t)

batch_id 0 64
idx:  0
1998-09-20 00:00:00
idx:  1
1998-09-09 00:00:00
idx:  2
1998-09-23 00:00:00
idx:  3
1998-08-06 00:00:00
idx:  4
1998-09-24 00:00:00
idx:  5
1998-08-29 00:00:00
idx:  6
1998-08-15 00:00:00
idx:  7
1998-08-04 00:00:00
idx:  8
1998-08-19 00:00:00
idx:  9
1998-08-18 00:00:00
idx:  10
1998-08-07 00:00:00
idx:  11
1998-09-30 00:00:00
idx:  12
1998-09-01 00:00:00
idx:  13
1998-09-20 00:00:00
idx:  14
1998-08-13 00:00:00
idx:  15
1998-09-05 00:00:00
idx:  16
1998-08-16 00:00:00
idx:  17
1998-09-02 00:00:00
idx:  18
1998-08-19 00:00:00
idx:  19
1998-09-23 00:00:00
idx:  20
1998-09-07 00:00:00
idx:  21
1998-08-13 00:00:00
idx:  22
1998-09-28 00:00:00
idx:  23
1998-09-10 00:00:00
idx:  24
1998-08-08 00:00:00
idx:  25
1998-09-16 00:00:00
idx:  26
1998-08-31 00:00:00
idx:  27
1998-08-24 00:00:00
idx:  28
1998-09-08 00:00:00
idx:  29
1998-09-28 00:00:00
idx:  30
1998-09-16 00:00:00
idx:  31
1998-09-09 00:00:00
idx:  32
1998-09-04 00:00:00
idx:  33
1998-09-21 00:00:00
idx:  34
1

In [110]:
model = train(0, 90, 90, 100, 20, data_dir='tlstm_files/TPCH_10/encoded_data_no_sample')

input_dim:  1072
batch_idx:  0
batch_size:  64
Forest Evaluate Running Time:  0.1490018367767334
132.32513427734375 132.32589721679688 132.33901977539062 10




batchward time:  1.399421215057373
batch_idx:  1
batch_size:  64
Forest Evaluate Running Time:  0.0069997310638427734
127.25428009033203 127.25456237792969 127.26074981689453 12
batchward time:  0.8459994792938232
batch_idx:  2
batch_size:  64
Forest Evaluate Running Time:  0.005999565124511719
121.85597229003906 121.84911346435547 121.95048522949219 30
batchward time:  0.828000545501709
batch_idx:  3
batch_size:  64
Forest Evaluate Running Time:  0.006999969482421875
116.11135864257812 116.11199951171875 116.1408462524414 30
batchward time:  0.801002025604248
batch_idx:  4
batch_size:  64
Forest Evaluate Running Time:  0.006000518798828125
110.09410095214844 110.09231567382812 110.1309585571289 31
batchward time:  0.7980005741119385
batch_idx:  5
batch_size:  64
Forest Evaluate Running Time:  0.004999637603759766
103.80342864990234 103.80530548095703 103.81411743164062 20
batchward time:  0.8080003261566162
batch_idx:  6
batch_size:  64
Forest Evaluate Running Time:  0.005000114440917

In [None]:
import pickle

with open('tlstm_files/TPCH_10/model', 'wb') as f:
    pickle.dump(model, f)