In [None]:
from joinboost_disk import *

In [None]:
def create_jg(con, unique_id=0, sample=True, learning_rate=0.01, max_leaves=8, sample_percent=1):
    """Build the Favorita join graph (fact table `sales` plus dimension tables) on `con`.

    Parameters
    ----------
    con :
        Database connection that holds the sales, holidays, oil, transactions,
        stores and items tables.
    unique_id : int
        Appended to the graph name ("favorita<unique_id>") and used as the
        seed of the fact sample. The name must differ between graphs that
        create different views, so concurrently built graphs need distinct ids.
    sample : bool
        If True, call create_sample_fact (as a view, seeded by `unique_id`)
        before the joins are attached to `jg.fact`.
    learning_rate : float
        Passed to joinGraph. By this notebook's convention it should be
        1 / number of trees (default 0.01 for 100 trees).
    max_leaves : int
        Passed to joinGraph as `max_leaves`.
    sample_percent :
        Passed to create_sample_fact when `sample` is True.

    Returns
    -------
    The configured joinGraph instance.
    """
    jg = joinGraph("favorita" + str(unique_id), con, log=False, max_leaves=max_leaves,
                   learning_rate=learning_rate, target_variable="Y")
    jg.add_table("sales", [], [], fact=True)
    # NOTE(review): the second list carries one value (2) per feature column;
    # presumably a feature-type code in joinboost_disk -- confirm there.
    jg.add_table("holidays", ["htype", "locale", "locale_name", "transferred", "f2"], [2, 2, 2, 2, 2])
    jg.add_table("oil", ["dcoilwtico", "f3"], [2, 2])
    jg.add_table("transactions", ["transactions", "f5"], [2, 2])
    jg.add_table("stores", ["city", "state", "stype", "cluster", "f4"], [2, 2, 2, 2, 2])
    jg.add_table("items", ["family", "class", "perishable", "f1"], [2, 2, 2, 2])
    if sample:
        # Must happen before the joins below, which are attached to jg.fact.
        jg.create_sample_fact(sample_percent=sample_percent, sample_seed=unique_id, view=True)

    # Snowflake: fact -> items / transactions; transactions -> stores / holidays;
    # holidays -> oil.
    jg.join(jg.fact, "items", ["item_nbr"], ["item_nbr"])
    jg.join(jg.fact, "transactions", ["tid"], ["tid"])
    jg.join("transactions", "stores", ["store_nbr"], ["store_nbr"])
    jg.join("transactions", "holidays", ["date"], ["date"])
    jg.join("holidays", "oil", ["date"], ["date"])
    return jg

In [None]:
# Load the raw Favorita CSVs into the on-disk database fav_2.duckdb and create
# the dummy (initial) model on the full, unsampled fact table.
# (table name, CSV file under data/) -- note the `sales` fact table comes from train.csv.
TABLE_SOURCES = [
    ("holidays", "holidays.csv"),
    ("oil", "oil.csv"),
    ("transactions", "transactions.csv"),
    ("stores", "stores.csv"),
    ("items", "items.csv"),
    ("sales", "train.csv"),
    ("test", "test.csv"),
]
con = duckdb.connect(database='fav_2.duckdb',check_same_thread=False)
try:
    for table_name, csv_file in TABLE_SOURCES:
        con.execute(f"CREATE OR REPLACE TABLE {table_name} AS SELECT * FROM 'data/{csv_file}';")
    jg = create_jg(con, sample=False)
    # NOTE(review): build_tree's set_ts_tc values (0.0, 80318105) were taken from
    # create_dummy_model; update them if the data or this step changes.
    jg.create_dummy_model(replace=False)
finally:
    # Always release the connection, even when a load step fails, so no open
    # connection to fav_2.duckdb is leaked into the worker cells below.
    con.close()

In [None]:
# Shared state for the parallel training cells below.
# cons: worker id -> DuckDB connection, filled in by init().
cons = {}
# tree_queries: build_tree() appends the first entry of each finished tree's
# jg.tree_queries here.
tree_queries = list()
# trees: ids of the trees still to train (specify the number of trees here).
# Workers pop() from the end, so ids are handed out from highest to lowest.
trees = [tree_id for tree_id in range(100)]

def init(unique_id):
    """Open a connection to fav_2.duckdb for worker `unique_id`.

    On success the connection is stored in the global `cons` dict under
    `unique_id`. Any exception raised while connecting is printed instead of
    propagated; in that case no entry is added to `cons`.
    """
    try:
        connection = duckdb.connect(database='fav_2.duckdb',check_same_thread=False)
        cons[unique_id] = connection
    except Exception as err:
        print(err)

def train_tree(worker_id):
    """Worker loop: repeatedly take a tree id from the global `trees` and build it.

    Uses the connection ``cons[worker_id]`` opened by ``init`` (a KeyError is
    raised if that worker's connection failed to open). The loop ends normally
    once `trees` is empty. If configuring the connection or building a tree
    fails, the error is reported together with the worker and tree id, and this
    worker stops; the other worker threads keep draining `trees`.
    """
    import traceback

    con = cons[worker_id]
    try:
        # Same statement for every tree, so running it once before the loop is
        # equivalent to re-running it per tree.
        con.execute("PRAGMA threads=4;")
    except Exception:
        print("worker " + str(worker_id) + " failed to configure its connection:")
        traceback.print_exc()
        return

    while True:
        try:
            # list.pop() is atomic under the CPython GIL, so no two workers
            # receive the same tree id.
            tree_id = trees.pop()
        except IndexError:
            # No trees left: normal termination, not an error.
            return
        print(str(worker_id) + " trains tree " + str(tree_id))
        try:
            build_tree(con, tree_id)
        except Exception:
            print("worker " + str(worker_id) + " failed on tree " + str(tree_id) + ":")
            traceback.print_exc()
            return

def build_tree(con, tree_id, ts=0.0, tc=80318105):
    """Train one gradient tree on its own join graph and record its query.

    Parameters
    ----------
    con :
        Connection used for every query of this tree.
    tree_id : int
        Passed to create_jg as the graph's unique id (graph-name suffix and
        sample seed).
    ts, tc :
        Values passed to jg.set_ts_tc. The defaults were obtained from
        create_dummy_model; NOTE(review): update them if the data or the dummy
        model changes.

    Side effects
    ------------
    Appends ``jg.tree_queries[0]`` to the global `tree_queries` and prints the
    elapsed time since the global `initial_time` (set in the training cell;
    a NameError is raised if that cell has not run).
    """
    jg = create_jg(con, tree_id)
    jg.set_ts_tc(ts, tc)
    jg.create_base_node()
    jg.build_gradient_tree()
    jg.clean_leaves()
    jg.clean_table()
    tree_queries.append(jg.tree_queries[0])
    elapsed = time.time() - initial_time
    print("Tree " + str(tree_id) + " finishes: " + str(elapsed))

In [None]:
def function_threading(func, num_threads):
    """Run ``func(i)`` for every i in range(num_threads), each in its own thread.

    All threads are started before any is joined, so the calls run
    concurrently; this returns (None) once every thread has finished.
    Exceptions raised inside ``func`` are not propagated to the caller; the
    threading module reports them instead.
    """
    threads = [threading.Thread(target=func, args=(i,)) for i in range(num_threads)]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()

In [None]:
%%time
# Open one connection per worker; init(i) stores it in cons[i].
# Must open at least as many connections as worker threads started in the
# training cell, since train_tree(worker_id) reads cons[worker_id].
function_threading(init, 16)

In [None]:
%%time
# build_tree prints each tree's finish time relative to this global.
initial_time = time.time()
# Number of worker threads; must not exceed the number of connections opened
# by the init cell above. NOTE(review): workers consume the global `trees`
# list, which is not reset here -- re-running this cell without re-running the
# cell that defines `trees` trains nothing.
function_threading(train_tree, 16)