In [0]:
import math
import sys
import time

import requests


from typing import Callable

def meazure_duration(df_func: Callable, *args, **kwargs) -> tuple:
    """
    Time one call of ``df_func`` followed by ``count()`` on its result.

    ``df_func(*args, **kwargs)`` is invoked and ``count()`` is called on the
    object it returns; both steps are inside the timed region (the ``count()``
    presumably forces a lazily evaluated DataFrame to be computed).

    Any ``Exception`` raised by either step is printed and reported through
    the success flag instead of being propagated. ``KeyboardInterrupt`` and
    ``SystemExit`` are not intercepted.

    Returns:
        A ``(duration_seconds, success)`` tuple. The duration is rounded to
        two decimals; when ``success`` is False it covers the time up to the
        failure.
    """
    # perf_counter is monotonic, so wall-clock adjustments cannot skew the
    # measured duration.
    start = time.perf_counter()
    success = True
    try:
        df = df_func(*args, **kwargs)
        df.count()
    except Exception as ex:
        print(f"Error measuring benchmark {df_func}: {ex}")
        success = False
    # The return is deliberately outside a ``finally`` clause: returning from
    # ``finally`` would discard an in-flight KeyboardInterrupt/SystemExit and
    # report the interrupted run as successful.
    return round(time.perf_counter() - start, 2), success

def convert_size_bytes(size_bytes: int) -> float:
    """
    Scale a byte count down to its largest 1024-based unit and return the number.

    Only the magnitude is returned, rounded to two decimals; the unit it is
    expressed in (B, KB, MB, ...) is not part of the result. For example
    ``1536`` gives ``1.5`` and ``5 * 1024 ** 2`` gives ``5.0``. Zero gives
    ``0.0``.

    A non-``int`` argument is not interpreted as a byte count: the shallow
    ``sys.getsizeof`` of the object itself is used instead.

    Raises:
        ValueError: if the byte count is negative.
    """
    if not isinstance(size_bytes, int):
        size_bytes = sys.getsizeof(size_bytes)

    if size_bytes < 0:
        raise ValueError(f"size_bytes must be non-negative, got {size_bytes}")
    if size_bytes == 0:
        # Keep the return type numeric on every path.
        return 0.0

    # floor(log1024(n)) computed in exact integer arithmetic: 1024 == 2 ** 10,
    # so every 10 bits above the first move up one unit. A floating-point
    # logarithm can land just below an exact power of 1024 and pick the
    # wrong unit.
    exponent = (size_bytes.bit_length() - 1) // 10
    return round(size_bytes / 1024 ** exponent, 2)


def estimate_df_size(df) -> float:
    """
    Estimate the size of ``df`` from Spark's optimized-plan statistics.

    The DataFrame is cached and counted, ``sizeInBytes`` is read from the
    statistics of the optimized plan, and the cache is released again. The
    function relies on the module-level ``spark`` session and on private
    PySpark attributes (``_jdf``, ``_jsparkSession``).

    Returns:
        The value produced by ``convert_size_bytes`` for the reported size.
    """
    df.cache()
    try:
        df.count()
        query_execution = df._jdf.queryExecution()
        catalyst_plan = query_execution.logical()
        jvm_size = (
            spark._jsparkSession.sessionState()
            .executePlan(catalyst_plan, query_execution.mode())
            .optimizedPlan()
            .stats()
            .sizeInBytes()
        )
        # sizeInBytes() is produced on the JVM side. If it reaches Python as
        # a py4j proxy object rather than an int, convert_size_bytes would
        # measure the proxy itself with sys.getsizeof. Parsing the string
        # form yields the byte count in either case.
        # NOTE(review): assumes the string form is a plain decimal integer
        # (presumably a scala BigInt) - confirm against the Spark version used.
        size_bytes = int(str(jvm_size))
    finally:
        # Release the cache even when counting or the stats lookup fails.
        df.unpersist()
    return convert_size_bytes(size_bytes)

def get_data_info(data_path: str):
    """
    Summarize the CSV file(s) at ``data_path`` as a one-row DataFrame.

    The data is read with a header row through the module-level ``spark``
    session. The resulting row carries the columns ``size`` (from
    ``estimate_df_size``), ``row_count`` and ``column_count``, in that order.
    """
    frame = spark.read.csv(data_path, header=True)

    n_columns = len(frame.columns)
    size = estimate_df_size(frame)
    n_rows = frame.count()

    summary = {"column_count": n_columns, "size": size, "row_count": n_rows}
    summary_df = spark.createDataFrame([summary])
    return summary_df.select("size", "row_count", "column_count")

def get_cluster_info(*, token: str = "", timeout: float = 30.0):
    """
    Describe the current cluster's node type as a one-row DataFrame.

    The workspace URL and cluster id are read from the module-level ``spark``
    session's configuration. The cluster's ``node_type_id`` is fetched from
    ``/api/2.0/clusters/get`` and matched against the entries returned by
    ``/api/2.0/clusters/list-node-types`` to obtain core count and memory.

    Args:
        token: Bearer token sent in the ``Authorization`` header. Defaults to
            an empty string.
        timeout: Seconds to wait for each HTTP request before giving up.

    Returns:
        A DataFrame with the columns ``cluster_instance_type``,
        ``cluster_instance_cpu``, ``cluster_instance_memory`` (``memory_mb``
        divided by 1024) and ``cluster_instance_count``.

    Raises:
        requests.HTTPError: if either API call returns an error status.
        LookupError: if the cluster's node type is not in the node-type list.
    """
    host = f"https://{spark.conf.get('spark.databricks.workspaceUrl')}"
    headers = {"Authorization": f"Bearer {token}"}

    cluster_id = spark.conf.get("spark.databricks.clusterUsageTags.clusterId")

    # Without a timeout, requests waits indefinitely on an unresponsive host.
    cluster_response = requests.get(
        f"{host}/api/2.0/clusters/get",
        params={"cluster_id": cluster_id},
        headers=headers,
        timeout=timeout,
    )
    # Fail on HTTP errors here rather than with a KeyError on the error body.
    cluster_response.raise_for_status()
    node_type = cluster_response.json()["node_type_id"]

    node_types_response = requests.get(
        f"{host}/api/2.0/clusters/list-node-types",
        headers=headers,
        timeout=timeout,
    )
    node_types_response.raise_for_status()
    node_types = node_types_response.json()

    node_spec = next(
        (x for x in node_types["node_types"] if x["node_type_id"] == node_type),
        None,
    )
    if node_spec is None:
        raise LookupError(f"Node type {node_type!r} not found in list-node-types response")

    cpu = node_spec["num_cores"]
    memory_gb = node_spec["memory_mb"] / 1024

    # TODO: the instance count is fixed at 1; the cluster response is not
    # consulted for it (e.g. its autoscale / max_workers settings).
    data = [{
        "cluster_instance_type": node_type,
        "cluster_instance_cpu": cpu,
        "cluster_instance_memory": memory_gb,
        "cluster_instance_count": 1
    }]

    return (
        spark.createDataFrame(data)
        .select("cluster_instance_type", "cluster_instance_cpu", "cluster_instance_memory", "cluster_instance_count")
    )

def run_benchmark(task_name: str, task_func: Callable, *task_args) -> list[tuple]:
    """
    Run ``task_func`` repeatedly and collect one timing record per run.

    The task is executed ``MAX_EXECUTION_COUNT`` times (a module-level name
    defined outside this function) through ``meazure_duration``, with
    ``task_args`` passed positionally on every run.

    Returns:
        A list of ``(task_name, duration_seconds, success)`` tuples, one per
        run, in execution order.
    """
    print(f"Running benchmark for {task_name}")
    return [
        (task_name, *meazure_duration(task_func, *task_args))
        for _ in range(MAX_EXECUTION_COUNT)
    ]
