In [None]:
!pip install -U polars pyarrow 

In [None]:
!pip install -U narwhals

In [None]:
import pandas as pd
import polars as pl

pd.options.mode.copy_on_write = True
pd.options.future.infer_string = True

In [None]:
from typing import Any
from datetime import datetime, date
import narwhals as nw

def q8_pandas_native(
    nation_ds,
    customer_ds,
    line_item_ds,
    orders_ds,
    supplier_ds,
) -> pd.DataFrame:
    """Sum discounted line-item revenue traded between FRANCE and GERMANY.

    Two join pipelines are built (FRANCE customers buying from GERMANY
    suppliers, and the reverse), concatenated, restricted to line items shipped
    between 1995-01-01 and 1996-12-31 inclusive, and aggregated per
    (supp_nation, cust_nation, l_year).

    NOTE(review): despite the ``q8`` prefix this function takes no part/region
    tables and looks like the TPC-H Q7 shape rather than Q8 -- confirm which
    query the native baseline is supposed to run before comparing timings.

    Parameters
    ----------
    nation_ds : DataFrame with ``n_nationkey`` and ``n_name``.
    customer_ds : DataFrame with ``c_custkey`` and ``c_nationkey``.
    line_item_ds : DataFrame with ``l_orderkey``, ``l_suppkey``,
        ``l_shipdate``, ``l_extendedprice`` and ``l_discount``.
        Assumes ``l_shipdate`` compares against ``datetime.date`` and exposes
        the ``.dt`` accessor (e.g. a pyarrow-backed date column) -- TODO confirm.
    orders_ds : DataFrame with ``o_orderkey`` and ``o_custkey``.
    supplier_ds : DataFrame with ``s_suppkey`` and ``s_nationkey``.

    Returns
    -------
    pd.DataFrame
        Columns ``supp_nation``, ``cust_nation``, ``l_year``, ``revenue``,
        sorted by the first three.
    """
    first_nation = "FRANCE"
    second_nation = "GERMANY"
    ship_from = date(1995, 1, 1)
    ship_to = date(1996, 12, 31)

    first_nation_df = nation_ds[(nation_ds["n_name"] == first_nation)]
    second_nation_df = nation_ds[(nation_ds["n_name"] == second_nation)]

    def _trade_lane(cust_nation_df, supp_nation_df):
        """Join customers of one nation to suppliers of the other via orders and line items."""
        return (
            customer_ds.merge(cust_nation_df, left_on="c_nationkey", right_on="n_nationkey")
            .merge(orders_ds, left_on="c_custkey", right_on="o_custkey")
            # Rename before the second nation merge so its n_name does not collide.
            .rename({"n_name": "cust_nation"}, axis="columns")
            .merge(line_item_ds, left_on="o_orderkey", right_on="l_orderkey")
            .merge(supplier_ds, left_on="l_suppkey", right_on="s_suppkey")
            .merge(supp_nation_df, left_on="s_nationkey", right_on="n_nationkey")
            .rename({"n_name": "supp_nation"}, axis="columns")
        )

    # Both trade directions, stacked into one frame.
    total = pd.concat(
        [
            _trade_lane(first_nation_df, second_nation_df),
            _trade_lane(second_nation_df, first_nation_df),
        ]
    )

    in_window = (total["l_shipdate"] >= ship_from) & (total["l_shipdate"] <= ship_to)
    total = total[in_window]
    # assign() builds a new frame, so this does not rely on copy-on-write being enabled.
    total = total.assign(
        volume=total["l_extendedprice"] * (1.0 - total["l_discount"]),
        l_year=total["l_shipdate"].dt.year,
    )

    group_keys = ["supp_nation", "cust_nation", "l_year"]
    revenue = total.groupby(group_keys, as_index=False).agg(
        revenue=pd.NamedAgg(column="volume", aggfunc="sum")
    )

    return revenue.sort_values(by=group_keys)

In [None]:
from typing import Any
from datetime import datetime
import narwhals as nw

def q8(
    nation_ds_raw: Any,
    customer_ds_raw: Any,
    line_item_ds_raw: Any,
    orders_ds_raw: Any,
    supplier_ds_raw: Any,
    part_ds_raw: Any,
) -> None:
    nation_ds = nw.from_native(nation_ds_raw)
    customer_ds = nw.from_native(customer_ds_raw)
    line_item_ds = nw.from_native(line_item_ds_raw)
    orders_ds = nw.from_native(orders_ds_raw)
    supplier_ds = nw.from_native(supplier_ds_raw)
    part_ds = nw.from_native(part_ds_raw)

    n1 = nation_ds.select("n_nationkey", "n_regionkey")
    n2 = nation_ds.select("n_nationkey", "n_name")

    result = (
        part_ds.join(line_item_ds, left_on="p_partkey", right_on="l_partkey")
        .join(supplier_ds, left_on="l_suppkey", right_on="s_suppkey")
        .join(orders_ds, left_on="l_orderkey", right_on="o_orderkey")
        .join(customer_ds, left_on="o_custkey", right_on="c_custkey")
        .join(n1, left_on="c_nationkey", right_on="n_nationkey")
        .join(region_ds, left_on="n_regionkey", right_on="r_regionkey")
        .filter(nw.col("r_name") == "AMERICA")
        .join(n2, left_on="s_nationkey", right_on="n_nationkey")
        .filter(
            nw.col("o_orderdate")>= date(1995, 1, 1),
            nw.col('o_orderdate')<=date(1996, 12, 31)
        )
        .filter(nw.col("p_type") == "ECONOMY ANODIZED STEEL")
        .select(
            nw.col("o_orderdate").dt.year().alias("o_year"),
            (nw.col("l_extendedprice") * (1 - nw.col("l_discount"))).alias("volume"),
            nw.col("n_name").alias("nation"),
        )
        .with_columns(
            nw.when(nw.col("nation") == "BRAZIL")
            .then(nw.col("volume"))
            .otherwise(0)
            .alias("_tmp")
        )
        .group_by("o_year")
        .agg((nw.sum("_tmp") / nw.sum("volume")).round(2).alias("mkt_share"))
        .sort("o_year")
    )
    
    return nw.to_native(result)


In [None]:
dir_ = "/kaggle/input/tpc-h-data-parquet-s-2/"
region = dir_ + 'region.parquet'
nation = dir_ + 'nation.parquet'
customer = dir_ + 'customer.parquet'
lineitem = dir_ + 'lineitem.parquet'
orders = dir_ + 'orders.parquet'
supplier = dir_ + 'supplier.parquet'
part = dir_ + 'part.parquet'
partsupp = dir_ + 'partsupp.parquet'

In [None]:
IO_FUNCS = {
    'pandas': lambda x: pd.read_parquet(x, engine='pyarrow'),
    'pandas[pyarrow]': lambda x: pd.read_parquet(x, engine='pyarrow', dtype_backend='pyarrow'),
    'polars[eager]': lambda x: pl.read_parquet(x),
    'polars[lazy]': lambda x: pl.scan_parquet(x),
}

In [None]:
results = {}

## pandas, pyarrow dtypes, native

In [None]:
tool = 'pandas[pyarrow]'
fn = IO_FUNCS[tool]
timings = %timeit -o q7_pandas_native(fn(nation), fn(customer), fn(lineitem), fn(orders), fn(supplier))
results[tool+'[native]'] = timings.all_runs

## pandas via Narwhals

In [None]:
tool = 'pandas'
fn = IO_FUNCS[tool]
timings = %timeit -o q7(fn(nation), fn(customer), fn(lineitem), fn(orders), fn(supplier))
results[tool] = timings.all_runs

## pandas, pyarrow dtypes, via Narwhals

In [None]:
tool = 'pandas[pyarrow]'
fn = IO_FUNCS[tool]
timings = %timeit -o q7(fn(nation), fn(customer), fn(lineitem), fn(orders), fn(supplier))
results[tool] = timings.all_runs

## Polars read_parquet

In [None]:
tool = 'polars[eager]'
fn = IO_FUNCS[tool]
timings = %timeit -o q7(fn(nation), fn(customer), fn(lineitem), fn(orders), fn(supplier))
results[tool] = timings.all_runs

## Polars scan_parquet

In [None]:
tool = 'polars[lazy]'
fn = IO_FUNCS[tool]
timings = %timeit -o q7(fn(nation), fn(customer), fn(lineitem), fn(orders), fn(supplier)).collect()
results[tool] = timings.all_runs

## Save

In [None]:
import json
# Write the collected timing runs to results.json in the working directory.
with open('results.json', 'w') as out_file:
    out_file.write(json.dumps(results))
