In [1]:
from botorch.acquisition import LogExpectedImprovement
from botorch.fit import fit_gpytorch_mll
from botorch.models import SingleTaskGP
from botorch.models.transforms import Normalize, Standardize
from botorch.optim import optimize_acqf
from gpytorch.mlls import ExactMarginalLogLikelihood

import datetime
import polars as pl
import torch

# Reading in Data

Start with TPCH Q12 here. For reference
```SQL
SELECT
  l_shipmode,
  SUM(CASE
      WHEN o_orderpriority = '1-URGENT'
          OR o_orderpriority = '2-HIGH'
          THEN 1
      ELSE 0
  END) as high_line_count,
  SUM(CASE
      WHEN o_orderpriority <> '1-URGENT'
          AND o_orderpriority <> '2-HIGH'
          THEN 1
      ELSE 0
  END) AS low_line_count
FROM
    orders,
    lineitem
WHERE
  o_orderkey = l_orderkey
  AND l_shipmode IN ('AIR', 'REG AIR')
  AND l_commitdate < l_receiptdate
  AND l_shipdate < l_commitdate
  AND l_receiptdate >= $receiptdate1
  AND l_receiptdate < $receiptdate2
GROUP BY
      l_shipmode 
ORDER BY
    l_shipmode;
```

In [2]:
data = pl.read_csv('../data/tpch_q12_sweep.csv').rename({'': 'id'})
d = pl.read_json('../data/q12_result.json')
data = data.with_columns(mean_elapsed=pl.mean_horizontal('elapsed_0', 'elapsed_1', 'elapsed_2'))
data = data.with_columns(
    pl.col("receiptdate1").str.to_datetime("%Y-%m-%d", time_unit='ms'),
    pl.col("receiptdate2").str.to_datetime("%Y-%m-%d", time_unit='ms')
)
table_rd1 = pl.DataFrame(
    {
        'receiptdate1': d['sel_receiptdate1'][0].keys(),
        'selectivity1': d['sel_receiptdate1'][0].values()
    }
).with_columns(pl.col("receiptdate1").str.to_datetime("%Y-%m-%d", time_unit='ms'))
table_rd2 = pl.DataFrame(
    {
        'receiptdate2': d['sel_receiptdate2'][0].keys(),
        'selectivity2': d['sel_receiptdate2'][0].values()
    }
).with_columns(pl.col("receiptdate2").str.to_datetime("%Y-%m-%d", time_unit='ms'))
data = data.join(table_rd1, on='receiptdate1').join(table_rd2, on='receiptdate2')
data

id,receiptdate1,receiptdate2,elapsed_0,elapsed_1,elapsed_2,mean_elapsed,selectivity1,selectivity2
i64,datetime[ms],datetime[ms],f64,f64,f64,f64,f64,f64
0,1992-01-03 00:00:00,1992-01-03 00:00:00,0.006372,0.001012,0.000931,0.002771,1.0,0.0
1,1992-01-03 00:00:00,1992-02-02 00:00:00,0.231195,0.088566,0.088825,0.136195,1.0,0.000567
2,1992-01-03 00:00:00,1992-03-03 00:00:00,0.193612,0.157787,0.157573,0.169657,1.0,0.003764
3,1992-01-03 00:00:00,1992-04-02 00:00:00,0.201443,0.199245,0.201111,0.2006,1.0,0.010051
4,1992-01-03 00:00:00,1992-05-02 00:00:00,0.230072,0.228478,0.229138,0.229229,1.0,0.019428
…,…,…,…,…,…,…,…,…
7391,1998-12-27 00:00:00,1998-08-29 00:00:00,0.000568,0.000582,0.000557,0.000569,0.000004,0.978716
7392,1998-12-27 00:00:00,1998-09-28 00:00:00,0.000556,0.000566,0.000572,0.000564,0.000004,0.988609
7393,1998-12-27 00:00:00,1998-10-28 00:00:00,0.000571,0.000582,0.00056,0.000571,0.000004,0.995406
7394,1998-12-27 00:00:00,1998-11-27 00:00:00,0.000557,0.000567,0.000569,0.000564,0.000004,0.999114


In [3]:
deviation_log = d['deviation_log'][0]
table_dev = pl.DataFrame({'deviation': d['deviation_log'][0]}).with_columns(pl.col("deviation").list.to_struct()).unnest("deviation")
table_dev = table_dev.rename(
        {'field_0': 'row1_id',
        'field_1': 'row2_id',
        'field_2': 'slope'}
    ).with_columns(pl.col('row1_id').cast(pl.Int64), pl.col('row2_id').cast(pl.Int64))
table_dev

row1_id,row2_id,slope
i64,i64,f64
0,1,49.143525
0,86,3.71849
0,87,23.907892
1,2,1.245692
1,86,182.73971
…,…,…
7390,7391,1.005479
7391,7392,1.007602
7392,7393,1.011122
7393,7394,1.011407


# Training Initial GP

We'll select 10 starting points to train our GP model initially.

The model acts as a function that takes in two queries (represented by 2 receipt dates) and outputs an estimated slope.

In [5]:
sample = (
    table_dev
    .sample(10)
    .join(data, left_on='row1_id', right_on='id')
    .join(data, left_on='row2_id', right_on='id', suffix='_b')
)
train_x = sample.select('receiptdate1', 'receiptdate2', 'receiptdate1_b', 'receiptdate2_b').cast(pl.Float64).to_torch()
train_y = sample.select('slope').to_torch()

In [6]:
gp = SingleTaskGP(
    train_X=train_x,
    train_Y=train_y,
    input_transform=Normalize(d=len(train_x[0])),
    outcome_transform=Standardize(m=1),
)
mll = ExactMarginalLogLikelihood(gp.likelihood, gp)
fit_gpytorch_mll(mll)

ExactMarginalLogLikelihood(
  (likelihood): GaussianLikelihood(
    (noise_covar): HomoskedasticNoise(
      (noise_prior): LogNormalPrior()
      (raw_noise_constraint): GreaterThan(1.000E-04)
    )
  )
  (model): SingleTaskGP(
    (likelihood): GaussianLikelihood(
      (noise_covar): HomoskedasticNoise(
        (noise_prior): LogNormalPrior()
        (raw_noise_constraint): GreaterThan(1.000E-04)
      )
    )
    (mean_module): ConstantMean()
    (covar_module): RBFKernel(
      (lengthscale_prior): LogNormalPrior()
      (raw_lengthscale_constraint): GreaterThan(2.500E-02)
    )
    (outcome_transform): Standardize()
    (input_transform): Normalize()
  )
)

### Acquisition Function

For now, as our GP estimates the slope from two queries we will directly try to maximize this value.

The bounds will allow us to control the area in which we search, this is where we can decide how to divide up the search space.

We can try something dumb where we search over the whole space for now

In [7]:
bounds = torch.stack([
    torch.concat([
        data.select(pl.col('receiptdate1', 'receiptdate2').min()).cast(pl.Float64).to_torch().flatten(), # minimum bounds for query 1
        data.select(pl.col('receiptdate1', 'receiptdate2').min()).cast(pl.Float64).to_torch().flatten()  # minimum bounds for query 2
    ]),
    torch.concat([
        data.select(pl.col('receiptdate1', 'receiptdate2').max()).cast(pl.Float64).to_torch().flatten(), # maximum bounds for query 1
        data.select(pl.col('receiptdate1', 'receiptdate2').max()).cast(pl.Float64).to_torch().flatten()  # maximum bounds for query 2
    ]),
]).to(torch.double)

logNEI = LogExpectedImprovement(model=gp, best_f=train_y.max())
candidate, acq_value = optimize_acqf(
    logNEI, bounds=bounds, q=1, num_restarts=4, raw_samples=100,
)
candidate = candidate.round().to(torch.int64) # Note we round off the item as we're dealing in ms
candidate, acq_value

(tensor([[799770733576, 709737847556, 808019773963, 810552194928]]),
 tensor(-5.9125, dtype=torch.float64))

Once we have the candidate we rebuild the query into something we could potentially pass into the DB

TODO: we will want to map the closest query to this for experimentation, however, we will need to round to the nearest whole number as we're dealing in milliseconds

In [8]:
# For some reason datetime does not take in ms, so we'll convert to seconds losing a little more precision
predicates_query_a = {
    'receiptdate1': datetime.datetime.fromtimestamp(candidate[0][0].item() // 1000), # datetime ms -> s
    'receiptdate2': datetime.datetime.fromtimestamp(candidate[0][1].item() // 1000), # datetime ms -> s
}
predicates_query_b = {
    'receiptdate1': datetime.datetime.fromtimestamp(candidate[0][2].item() // 1000), # datetime ms -> s
    'receiptdate2': datetime.datetime.fromtimestamp(candidate[0][3].item() // 1000), # datetime ms -> s
}
predicates_query_a, predicates_query_b

({'receiptdate1': datetime.datetime(1995, 5, 6, 10, 32, 13),
  'receiptdate2': datetime.datetime(1992, 6, 28, 9, 24, 7)},
 {'receiptdate1': datetime.datetime(1995, 8, 9, 21, 56, 13),
  'receiptdate2': datetime.datetime(1995, 9, 8, 5, 23, 14)})

## Gaining feedback by running on the DB

Now we have some queries to test, we can run it back on the DB to gain feedback on whether or not this worked out

In [9]:
import duckdb
import time

In [10]:
class DBWrapperTpchQ12:
    def __init__(self, db_path: str) -> None:
        con = duckdb.connect(database=db_path)
        con.execute('SET enable_progress_bar = false')
        self.con: duckdb.DuckDBPyConnection = con
        self.query_template = """
            SELECT
              l_shipmode,
              SUM(CASE
                  WHEN o_orderpriority = '1-URGENT'
                      OR o_orderpriority = '2-HIGH'
                      THEN 1
                  ELSE 0
              END) as high_line_count,
              SUM(CASE
                  WHEN o_orderpriority <> '1-URGENT'
                      AND o_orderpriority <> '2-HIGH'
                      THEN 1
                  ELSE 0
              END) AS low_line_count
            FROM
                orders,
                lineitem
            WHERE
              o_orderkey = l_orderkey
              AND l_shipmode IN ('AIR', 'REG AIR')
              AND l_commitdate < l_receiptdate
              AND l_shipdate < l_commitdate
              AND l_receiptdate >= $receiptdate1
              AND l_receiptdate < $receiptdate2
            GROUP BY
                  l_shipmode 
            ORDER BY
                l_shipmode;
        """

    def run_query(self, predicates: dict):
        start = time.time()
        self.con.sql(self.query_template, params=predicates)
        elapsed = time.time() - start
        
        return elapsed

In [11]:
def calculate_qerror(a, b):
    return max(a/b, b/a)

In [12]:
db = DBWrapperTpchQ12(db_path='../data/tpch_sf100.db')

In [13]:
a = db.run_query(predicates_query_a)
b = db.run_query(predicates_query_b)

In [14]:
calculate_qerror(a, b)

16.54903763208361

# Doing the training loop

With one point of feedback, we can now do the whole training loop. Simply placing back the data point we just got into the original dataset and rebuilding the new GP model we can incorporate this new piece of data.