In [4]:
import pyarrow.parquet as pq
from pyarrow.csv import write_csv
from pgpq import ArrowToPostgresBinaryEncoder

In [5]:
from pathlib import Path
import requests

# Local cache location for the dataset: two directories above the current
# working directory.
file = Path(".").resolve().parent.parent / "yellow_tripdata_2022-01.parquet"
if not file.exists():
    # Stream into a sibling ".part" file and rename it only after the whole
    # body has been written. Otherwise an interrupted download would leave a
    # truncated parquet file that the `exists()` check above would accept as a
    # valid cache on the next run.
    partial_file = file.with_name(file.name + ".part")
    # The timeout bounds connection setup and each socket read (not the total
    # transfer time), so a stalled server cannot hang the notebook forever.
    with requests.get("https://d37ci6vzurychx.cloudfront.net/trip-data/yellow_tripdata_2022-01.parquet", stream=True, timeout=60) as r:
        r.raise_for_status()
        with partial_file.open("wb") as f:
            # 1 MiB chunks keep memory use flat regardless of file size.
            for chunk in r.iter_content(chunk_size=1024 * 1024):
                f.write(chunk)
    partial_file.replace(file)

In [14]:
# Output paths sit next to the parquet file and share its stem; only the
# suffix differs (".csv" for the text export, ".bin" for the binary export).
binary_file = file.with_suffix(".bin")
csv_file = file.with_suffix(".csv")

# Read the parquet file into an Arrow table, then write that table out as CSV
# so the CSV load below has an input file ready.
arrow_table = pq.read_table(file)
write_csv(arrow_table, csv_file)


def encode_file():
    """Encode ``arrow_table`` with pgpq and write the result to ``binary_file``.

    The output is written in order: the encoder's header, one encoded chunk
    per record batch of ``arrow_table``, then the encoder's finishing bytes.
    Any existing ``binary_file`` is overwritten.
    """
    pg_encoder = ArrowToPostgresBinaryEncoder(arrow_table.schema)
    with binary_file.open("wb") as sink:
        sink.write(pg_encoder.write_header())
        # The generator is consumed lazily, so each batch is encoded and
        # written before the next one is touched.
        sink.writelines(
            pg_encoder.write_batch(record_batch)
            for record_batch in arrow_table.to_batches()
        )
        sink.write(pg_encoder.finish())


In [31]:
import subprocess
from contextlib import contextmanager
from time import perf_counter, time
from typing import Iterator

import psycopg
from testing.postgresql import Postgresql


@contextmanager
def get_dsn() -> Iterator[str]:
    """Yield the connection URL of a freshly created ``Postgresql`` instance.

    The instance's ``terminate()`` method is called when the ``with`` block
    exits, whether the body finishes normally or raises.

    Yields:
        The string returned by the instance's ``url()`` method.
    """
    server = Postgresql()
    url = server.url()
    try:
        yield url
    finally:
        server.terminate()


def clean(dsn: str):
    """Drop and recreate the ``data`` table so each load starts from empty.

    The column list is taken from the pgpq encoder schema built for
    ``arrow_table``: each column entry supplies a name and a DDL type string.

    Args:
        dsn: Connection string passed to ``psycopg.connect``.
    """
    # Build the CREATE TABLE statement up front; it needs no database access.
    columns = ArrowToPostgresBinaryEncoder(arrow_table.schema).schema()["columns"]
    cols = []
    for col in columns:
        # Names are double-quoted so they are used verbatim as identifiers.
        cols.append(f"\"{col['name']}\" {col['data_type']['ddl']}")
    create_stmt = f"CREATE TABLE data ({','.join(cols)})"

    with psycopg.connect(dsn) as connection, connection.cursor() as cur:
        cur.execute("DROP TABLE IF EXISTS data")
        cur.execute(create_stmt)
        connection.commit()


# Benchmark: load the same table through psql's \copy, first from the CSV
# export and then from the binary file produced by `encode_file`. Each run
# prints its elapsed time in seconds.
#
# Timing uses `perf_counter` rather than `time`: it is a monotonic,
# high-resolution clock, so a system clock adjustment during a run cannot
# distort the measurement.
#
# using psql just because it's the "gold standard" for this
with get_dsn() as dsn:
    # CSV load: only the \copy itself is timed.
    clean(dsn)
    start = perf_counter()
    subprocess.run(["psql", dsn, "-c", f"\\copy data FROM '{csv_file}' WITH (FORMAT CSV, HEADER);"], check=True, capture_output=True)
    end = perf_counter()
    print(f"{end-start:.2f}")

    # Binary load: the encoding step is inside the timed region, so this
    # figure covers encoding plus the \copy.
    clean(dsn)
    start = perf_counter()
    encode_file()
    subprocess.run(["psql", dsn, "-c", f"\\copy data FROM '{binary_file}' WITH (FORMAT BINARY);"], check=True, capture_output=True)
    end = perf_counter()
    print(f"{end-start:.2f}")

4.85
3.64
