In [None]:
import json
import time
import os
import re
import shlex
import subprocess
import platform
from docker import from_env as docker_client
from dataclasses import dataclass
from datetime import datetime
from configparser import ConfigParser

import pandas as pd
import psycopg

In [None]:
#loading config
def load_config(filename="database.ini", section="postgresql"):
    """Read one section of an INI file into a plain ``dict``.

    Args:
        filename: Path of the INI file to read.
        section: Name of the section whose key/value pairs are returned.

    Returns:
        A dict mapping each option name of ``section`` to its string value.

    Raises:
        Exception: If ``section`` is not present after parsing. This is also
            what a missing ``filename`` produces, because
            ``ConfigParser.read`` skips files it cannot open.
    """
    parser = ConfigParser()
    parser.read(filename)

    if not parser.has_section(section):
        raise Exception("Section {0} not found in the {1} file".format(section, filename))

    return {key: value for key, value in parser.items(section)}

#connecting to db
def connect(connect_timeout=5):
    """Open a new psycopg connection from the ``[postgresql]`` section of ``database.ini``.

    Args:
        connect_timeout: Seconds to wait for the server, forwarded to
            ``psycopg.connect`` (default 5, the previously hard-coded value).

    Returns:
        A new connection; the caller must close it (see ``disconnect``).

    Raises:
        psycopg.DatabaseError subclasses (e.g. ``psycopg.OperationalError``)
        when the connection cannot be established, and whatever
        ``load_config`` raises when the config section is missing.
    """
    # No try/except: the old `except DatabaseError as error: raise error`
    # re-raised the same exception unchanged and only lengthened tracebacks.
    return psycopg.connect(**load_config(), connect_timeout = connect_timeout)

#disconnecting from db
def disconnect(pg_conn, cursor = None):
    """Close ``cursor`` (when given) and then ``pg_conn``.

    Args:
        pg_conn: An open connection object exposing ``close()``.
        cursor: Optional cursor to close first; ``None`` skips it.
    """
    try:
        if cursor is not None:
            cursor.close()
    finally:
        # Release the connection even if closing the cursor raised.
        pg_conn.close()

In [None]:
#restart db and remove cache if possible
def restart_db():
    """Restart PostgreSQL between cold runs, dropping caches where possible.

    Behaviour depends on the ``[docker]`` section of ``database.ini``:

    * ``container_name`` set and non-empty: stop the container, wait for it to
      exit, and start it again. The host page cache is not dropped here.
    * otherwise, on Linux: ``sudo systemctl stop postgresql``, ``sync``, drop
      the kernel page cache via ``/proc/sys/vm/drop_caches``, then
      ``sudo systemctl start postgresql``.

    Errors inside either branch are printed rather than raised (best effort);
    callers follow up with ``wait_for_db`` to find out whether the server is back.

    Raises:
        Exception: If neither branch applies (no container name and not Linux),
            or from ``load_config`` when the ``[docker]`` section is missing.
    """
    env = load_config(section = "docker")
    # A missing key is treated like an empty one instead of raising KeyError.
    container_name = env.get("container_name", "")

    if container_name != "":
        client = None
        try:
            client = docker_client()
            container = client.containers.get(container_name)

            container.stop()
            container.wait()
            container.start()

        except Exception as e:
            print(e)
        finally:
            # A new Docker SDK client is created on every call (once per
            # cold-run repeat); close it so its connections are released.
            if client is not None:
                client.close()

    elif platform.system() == "Linux":
        try:
            subprocess.run(['sudo', 'systemctl', 'stop', 'postgresql'], check=True)
            try:
                subprocess.run(["sync"], check=True)
                subprocess.run(["sudo", "sh", "-c", "echo 3 > /proc/sys/vm/drop_caches"], check=True)
            finally:
                # Start the server again even if flushing/dropping caches
                # failed; otherwise it stays down and wait_for_db() times out.
                subprocess.run(['sudo', 'systemctl', 'start', 'postgresql'], check=True)

        except Exception as e:
            print(e)

    else:
        raise Exception("System configuration not supported")

#wait for db to accept connections
def wait_for_db(timeout = 15):
    """Block until a connection to the database can be opened and closed.

    Each attempt calls ``connect()``, opens a cursor and disconnects again.
    ``psycopg.OperationalError`` counts as "not ready yet" and is followed by
    a one-second sleep; any other error propagates immediately.

    Args:
        timeout: Seconds after which no new attempt is started.

    Raises:
        TimeoutError: If no attempt succeeded within ``timeout`` seconds.
    """
    started_at = time.time()

    while time.time() - started_at < timeout:
        try:
            conn = connect()
            disconnect(conn, conn.cursor())
        except psycopg.OperationalError:
            time.sleep(1)
        else:
            return

    raise TimeoutError("Could not connect to the database")

In [None]:
@dataclass()
class Query:
    """A benchmark query loaded from one ``.sql`` file.

    ``string`` may contain several ``;``-separated statements; the runners
    split it before execution.
    """

    label: str          # name of the .sql file the query came from
    string: str         # SQL text as returned by format_query()
    groups: list[str]   # directory names between the working dir and the file

#formatting query
def format_query(query):
    """Normalize raw SQL text into a single line.

    * ``--`` line comments and ``/* ... */`` block comments are removed and
      replaced by a space, so the tokens on either side stay separate;
    * every run of whitespace/comments outside quoted text collapses to one space;
    * ``EXPLAIN ANALYZE`` / ``EXPLAIN (...)`` (any case) is stripped so it
      does not clash with the runner's own ``analyze_prefix``;
    * single-quoted literals (including ``''`` escapes) and double-quoted
      identifiers are kept byte-for-byte.

    Args:
        query: Raw contents of a ``.sql`` file.

    Returns:
        The normalized SQL text without leading/trailing whitespace.
    """
    # Group 1 is a quoted literal/identifier and is kept verbatim; otherwise
    # the match is a run of whitespace and comments, replaced by one space.
    # Matching quotes first keeps "--", "/*" and repeated spaces inside
    # strings intact (the old shlex.split(posix=False) broke 'a  b' into
    # two tokens and the old comment regexes also cut into literals).
    token_re = re.compile(
        r"('(?:[^']|'')*'|\"(?:[^\"]|\"\")*\")"
        r"|(?:\s|--[^\n]*|/\*.*?\*/)+",
        re.DOTALL,
    )
    query = token_re.sub(
        lambda m: m.group(1) if m.group(1) is not None else " ",
        query,
    )

    # Fallback for files that contain their own EXPLAIN. `[^)]*` stops at the
    # first ")" -- the old greedy `\(.*\)` swallowed everything up to the last
    # ")" on the line, e.g. the "SELECT count(*)" following "EXPLAIN (ANALYZE)".
    query = re.sub(
        r"\bEXPLAIN\s+(?:ANALYZE\b|\([^)]*\))\s*",
        "",
        query,
        flags = re.IGNORECASE,
    )

    return query.strip()

queries: list[Query] = []

#loading all queries (in sorted order, so runs and result files are reproducible)
for root, dirs, files in os.walk(os.curdir):
    # os.walk descends into sub-directories in list order; sorting in place
    # makes the traversal deterministic across file systems.
    dirs.sort()
    for file in sorted(files):
        if not file.endswith(".sql"):
            continue

        path = os.path.join(root, file)

        # `with` closes the handle (the bare open(...).read() leaked it).
        with open(path, "r", encoding = "utf-8") as sql_file:
            query = sql_file.read()
        query = format_query(query)

        # Directory components below the working directory; files located
        # directly in it get no groups (previously [""]).
        relative_dir = os.path.relpath(root, os.curdir)
        groups = [] if relative_dir == os.curdir else relative_dir.split(os.sep)

        query = Query(file, query, groups)

        queries.append(query)
        print(query)

In [None]:
@dataclass
class QueryResult:
    """Timing record for one execution of one SQL statement.

    A list of these is converted with ``pd.DataFrame`` and written to CSV,
    one column per field.
    """

    label: str            # file name of the originating query
    groups: list[str]     # directory groups of the originating query
    query: str            # executed statement, without the analyze prefix
    bench_time: datetime  # wall-clock time taken just before execution
    # JSON text of cursor.fetchall() (it is stored via json.dumps, not as a dict)
    result_set: str
    # execute + fetchall duration in integer nanoseconds (perf_counter_ns delta)
    exec_time: int

#executes given query and returns QueryResult object
def run_query(query, cursor, analyze_prefix = ""):
    """Execute one statement, fetch all rows and time it.

    Args:
        query: A ``Query``; callers pass one statement per call.
        cursor: An open cursor on which the statement is executed.
        analyze_prefix: Text prepended to the statement before execution
            (e.g. an ``EXPLAIN (...)`` clause); it is not stored in the result.

    Returns:
        A ``QueryResult`` whose ``exec_time`` is the elapsed nanoseconds of
        ``execute`` + ``fetchall`` and whose ``result_set`` is the fetched rows
        serialized as JSON text.
    """
    bench_start = datetime.now()
    query_start = time.perf_counter_ns()

    cursor.execute(analyze_prefix + query.string)
    result_set = cursor.fetchall()

    query_end = time.perf_counter_ns()

    return QueryResult(
        label = query.label,
        groups = query.groups,
        query = query.string,
        bench_time = bench_start,
        # default=str: plain (non-EXPLAIN) queries can return Decimal, date,
        # UUID, ... values that json.dumps cannot encode by itself.
        result_set = json.dumps(result_set, default = str),
        exec_time = query_end - query_start,
    )

#executes query with precaching and returns list of results
def hot_run_query(query, precache_repeats = 3, query_repeats = 1, analyze_prefix = ""):
    """Benchmark ``query`` on a single warm connection.

    ``query.string`` is split on ``;`` (NOTE: a ``;`` inside a string literal
    would be split as well). For each non-empty statement:

    * statements starting (case-insensitively) with CREATE, REPLACE, REFRESH
      or DROP are executed once and not timed;
    * any other statement is executed ``precache_repeats`` times without
      ``analyze_prefix`` to warm caches, then timed ``query_repeats`` times
      through ``run_query``.

    Args:
        query: The ``Query`` to run.
        precache_repeats: Untimed warm-up executions per statement.
        query_repeats: Timed executions per statement.
        analyze_prefix: Prefix handed to ``run_query`` for the timed runs.

    Returns:
        The ``QueryResult`` objects in execution order.
    """
    results: list[QueryResult] = []

    pg_conn = connect()
    cursor = None
    # try/finally: a failing statement used to leave the connection open.
    try:
        cursor = pg_conn.cursor()

        for statement in query.string.split(";"):
            statement = statement.strip()

            if statement == "":
                continue

            if statement.upper().startswith(("CREATE", "REPLACE", "REFRESH", "DROP")):
                cursor.execute(statement)
                continue

            for _ in range(precache_repeats):
                cursor.execute(statement)

            new_query = Query(query.label, statement, query.groups)

            for _ in range(query_repeats):
                results.append(run_query(new_query, cursor, analyze_prefix))
    finally:
        disconnect(pg_conn, cursor)

    return results

#executes query without precaching and returns list of results
def cold_run_query(query, query_repeats = 1, analyze_prefix = ""):
    """Benchmark ``query`` after restarting the database for every repeat.

    Each of the ``query_repeats`` rounds calls ``restart_db()`` and
    ``wait_for_db()``, opens a fresh connection and runs every ``;``-separated
    statement of ``query.string`` once. CREATE/REPLACE/REFRESH/DROP statements
    (case-insensitive prefix) are executed untimed; all others are timed via
    ``run_query`` with ``analyze_prefix``.

    Args:
        query: The ``Query`` to run.
        query_repeats: Number of restart + execute rounds.
        analyze_prefix: Prefix handed to ``run_query`` for the timed runs.

    Returns:
        The ``QueryResult`` objects in execution order.
    """
    results: list[QueryResult] = []

    for _ in range(query_repeats):

        restart_db()
        wait_for_db()

        pg_conn = connect()
        cursor = None
        # try/finally: a failing statement used to leave this round's
        # connection open.
        try:
            cursor = pg_conn.cursor()

            for statement in query.string.split(";"):
                statement = statement.strip()

                if statement == "":
                    continue

                if statement.upper().startswith(("CREATE", "REPLACE", "REFRESH", "DROP")):
                    cursor.execute(statement)
                    continue

                new_query = Query(query.label, statement, query.groups)
                results.append(run_query(new_query, cursor, analyze_prefix))
        finally:
            disconnect(pg_conn, cursor)

    return results

In [None]:
#execution settings
# Prepended to every timed statement; the JSON plan it produces is where the
# analysis cells later find "Total Cost".
analyze_prefix: str = "EXPLAIN (ANALYZE, BUFFERS, FORMAT JSON, SETTINGS) "
# Restrict the run to these query labels (file names) or group (directory)
# names; leave empty to run every query.
only_run: list[str] = []
precache_repeats: int = 3   # untimed warm-up executions per statement (hot runs only)
query_repeats: int = 10     # timed executions per statement
cold_run: bool = False      # True -> restart the database before every repeat

In [None]:
# executing queries
results: list[QueryResult] = []

for query in queries:
    # An empty only_run selects everything; otherwise the query's label or at
    # least one of its groups must be listed.
    selected = (
        not only_run
        or query.label in only_run
        or any(group in only_run for group in query.groups)
    )
    if not selected:
        continue

    print(query.label)

    runner_output = (
        cold_run_query(query, query_repeats, analyze_prefix)
        if cold_run
        else hot_run_query(query, precache_repeats, query_repeats, analyze_prefix)
    )
    results.extend(runner_output)

In [None]:
#saving results
# results_hot / results_cold, depending on the run mode
file_name = "results" + ("_cold" if cold_run else "_hot")

# One row per QueryResult, one column per dataclass field.
df = pd.DataFrame(results)
csv_path = file_name + ".csv"
df.to_csv(csv_path, index = False)

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
#plotting results
# set_theme() resets every argument it is not given to seaborn's defaults
# (context="notebook", palette="deep"), so calling it after set_context() /
# set_palette() silently undid both. Configure everything in one call.
sns.set_theme(context = "talk", style = "whitegrid", palette = "viridis")

In [None]:
# The saving cell writes results_hot.csv / results_cold.csv ("results.csv" is
# never produced), so read back the file it actually wrote.
df = pd.read_csv(file_name + ".csv")

def getTotalCost(text):
    """Extract the planner's ``Total Cost`` from a stored EXPLAIN JSON result.

    Args:
        text: A ``result_set`` value, i.e. the JSON text of the fetched rows
            of an ``EXPLAIN (..., FORMAT JSON)`` statement.

    Returns:
        The first ``"Total Cost"`` value in ``text`` (presumably the top plan
        node, whose keys precede its child ``Plans``) as a float, or ``None``
        when ``text`` is not a string or holds no such key.
    """
    if not isinstance(text, str):  # e.g. NaN from an empty CSV cell
        return None
    # Costs may be fractional (and very large ones are printed with an
    # exponent); the old `(\d+)` + int() truncated 1234.56 to 1234.
    match = re.search(r'"Total Cost": (\d+(?:\.\d+)?(?:[eE][+-]?\d+)?)', text)
    return float(match.group(1)) if match else None

df["Total Cost"] = df["result_set"].apply(getTotalCost)

# exec_time is in nanoseconds, so these ratios are cost/ns and ns/cost.
df["Cost per Time"] = df["Total Cost"]/df["exec_time"]
df["Time per Cost"] = df["exec_time"]/df["Total Cost"]

# `"json" in series` checks the *index* labels, not the values, so it gave a
# single False for every row; test each label instead.
df["Json"] = df["label"].str.contains("json", regex = False)

df.describe()

In [None]:
g = sns.regplot(
    data = df,
    x = "Total Cost",
    y = "exec_time",
    scatter = True,
)

# exec_time is stored in nanoseconds (perf_counter_ns in run_query).
g.set(
    title = "Runtime per Cost",
    xlabel = "Total Cost",
    ylabel = "Runtime [ns]",
    xscale = "log",
    yscale = "log",
)

g.get_figure().set_size_inches(6, 6)
g.get_figure().tight_layout()

# `plt.show` without parentheses only evaluated the function object.
plt.show()

In [None]:
g = sns.regplot(
    data = df,
    x = "exec_time",
    y = "Time per Cost",
    scatter = True,
)

# The y column is "Time per Cost" (the old label said "Cost per Runtime");
# exec_time is in nanoseconds.
g.set(
    title = "Time per Cost to Runtime",
    xlabel = "Runtime [ns]",
    ylabel = "Time per Cost",
    xscale = "log",
    yscale = "log",
)

g.get_figure().set_size_inches(6, 6)
g.get_figure().tight_layout()

# `plt.show` without parentheses only evaluated the function object.
plt.show()

In [None]:
g = sns.regplot(
    data = df,
    x = "Total Cost",
    y = "Time per Cost",
    scatter = True,
)

g.set(
    title = "Time per Cost to Total Cost",
    xlabel = "Total Cost",
    ylabel = "Time per Cost",
    xscale = "log",
    yscale = "log",
)

g.get_figure().set_size_inches(6, 6)
g.get_figure().tight_layout()

# `plt.show` without parentheses only evaluated the function object.
plt.show()

In [None]:
g = sns.lmplot(
    data = df,
    x = "Total Cost",
    y = "Time per Cost",
    hue = "label",
)

# FacetGrid.set applies to every facet; plt.xscale/yscale only touched the
# current axes.
g.set(
    title = "Time per Cost to Total Cost",
    xlabel = "Total Cost",
    ylabel = "Time per Cost",
    xscale = "log",
    yscale = "log",
)

# `plt.show` without parentheses only evaluated the function object.
plt.show()

In [None]:
# Read the file the saving cell wrote ("results.csv" is never produced) into
# a separate frame, so `df` keeps nanoseconds and the cost columns used above.
runtimes_df = pd.read_csv(file_name + ".csv")
runtimes_df["exec_time"] = runtimes_df["exec_time"] / 1_000_000  # ns -> ms

g = sns.barplot(
    data = runtimes_df,
    x = "exec_time",
    y = "label",
)

g.set(
    title = "Query runtimes",
    xlabel = "Execution time [ms]",
    ylabel = "Query",
    xscale = "log",
)

g.get_figure().set_size_inches(10, 10)
g.get_figure().tight_layout()

plt.show()