From 86b7ad4eb0bb76d616eeea2ccbc0b33b80c18e2c Mon Sep 17 00:00:00 2001 From: Allison Wang Date: Mon, 21 Jul 2025 20:35:44 -0700 Subject: [PATCH 1/3] add CLAUDE.md --- CLAUDE.md | 220 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 220 insertions(+) create mode 100644 CLAUDE.md diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..adb21ec --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,220 @@ +# PySpark Data Sources - Project Context for Claude + +## Project Overview +This is a demonstration library showcasing custom Spark data sources built using Apache Spark 4.0's new Python Data Source API. The project provides various data source connectors for reading from external APIs and services. + +**Important**: This is a demo/educational project and not intended for production use. + +## Tech Stack +- **Language**: Python (3.9-3.12) +- **Framework**: Apache Spark 4.0+ (PySpark) +- **Package Management**: Poetry +- **Documentation**: MkDocs with Material theme +- **Testing**: pytest +- **Dependencies**: PyArrow, requests, faker, and optional extras + +## Project Structure +``` +pyspark_datasources/ +├── __init__.py # Main package exports +├── fake.py # Fake data generator using Faker +├── github.py # GitHub repository data connector +├── googlesheets.py # Public Google Sheets reader +├── huggingface.py # Hugging Face datasets connector +├── kaggle.py # Kaggle datasets connector +├── lance.py # Lance vector database connector +├── opensky.py # OpenSky flight data connector +├── simplejson.py # JSON writer for Databricks DBFS +├── stock.py # Alpha Vantage stock data reader +└── weather.py # Weather data connector +``` + +## Available Data Sources +| Short Name | File | Description | Dependencies | +|---------------|------|-------------|--------------| +| `fake` | fake.py | Generate fake data using Faker | faker | +| `github` | github.py | Read GitHub repository PRs | None | +| `googlesheets`| googlesheets.py | Read public Google Sheets | None | +| `huggingface` | huggingface.py | Access Hugging Face datasets | datasets | +| `kaggle` | kaggle.py | Read Kaggle datasets | kagglehub, pandas | +| `opensky` | opensky.py | Flight data from OpenSky Network | None | +| `simplejson` | simplejson.py | Write JSON to Databricks DBFS | databricks-sdk | +| `stock` | stock.py | Stock data from Alpha Vantage | None | + +## Development Commands + +### Environment Setup +```bash +poetry install # Install dependencies +poetry install --extras all # Install all optional dependencies +poetry shell # Activate virtual environment +``` + +### Testing +```bash +# Note: On macOS, set this environment variable to avoid fork safety issues +export OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES + +pytest # Run all tests +pytest tests/test_data_sources.py # Run specific test file +``` + +### Documentation +```bash +mkdocs serve # Start local docs server +mkdocs build # Build static documentation +``` + +### Package Management +```bash +poetry build # Build package +poetry publish # Publish to PyPI (requires auth) +poetry add # Add new dependency +poetry update # Update dependencies +``` + +## Usage Patterns +All data sources follow the Spark Data Source API pattern: + +```python +from pyspark_datasources import FakeDataSource + +# Register the data source +spark.dataSource.register(FakeDataSource) + +# Batch reading +df = spark.read.format("fake").option("numRows", 100).load() + +# Streaming (where supported) +stream = spark.readStream.format("fake").load() +``` + +## Testing Strategy +- Tests use pytest 
with PySpark session fixtures +- Each data source has basic functionality tests +- Tests verify data reading and schema validation +- Some tests may require external API access + +## Key Implementation Details +- All data sources inherit from Spark's DataSource base class +- Implements reader() method for batch reading +- Some implement streamReader() for streaming +- Schema is defined using PySpark StructType +- Options are passed via Spark's option() method + +## External Dependencies +- **GitHub API**: Uses public API, no auth required +- **Alpha Vantage**: Stock data API (may require API key) +- **Google Sheets**: Public sheets only, no auth +- **Kaggle**: Requires Kaggle API credentials +- **Databricks**: SDK for DBFS access +- **OpenSky**: Public flight data API + +## Common Issues +- Ensure PySpark >= 4.0.0 is installed +- Some data sources require API keys/credentials +- Network connectivity required for external APIs +- Rate limiting may affect some external services + +## Python Data Source API Specification + +### Core Abstract Base Classes + +#### DataSource +Primary abstract base class for custom data sources supporting read/write operations. + +**Key Methods:** +- `__init__(self, options: Dict[str, str])` - Initialize with user options +- `name() -> str` - Return format name (defaults to class name) +- `schema() -> StructType` - Define data source schema +- `reader(schema: StructType) -> DataSourceReader` - Create batch reader +- `writer(schema: StructType, overwrite: bool) -> DataSourceWriter` - Create batch writer +- `streamReader(schema: StructType) -> DataSourceStreamReader` - Create streaming reader +- `streamWriter(schema: StructType, overwrite: bool) -> DataSourceStreamWriter` - Create streaming writer +- `simpleStreamReader(schema: StructType) -> SimpleDataSourceStreamReader` - Create simple streaming reader + +#### DataSourceReader +Abstract base class for reading data from sources. + +**Key Methods:** +- `read(partition) -> Iterator` - Read data from partition, returns tuples/Rows/pyarrow.RecordBatch +- `partitions() -> List[InputPartition]` - Return input partitions for parallel reading + +#### DataSourceStreamReader +Abstract base class for streaming data sources with offset management. + +**Key Methods:** +- `initialOffset() -> Offset` - Return starting offset +- `latestOffset() -> Offset` - Return latest available offset +- `partitions(start: Offset, end: Offset) -> List[InputPartition]` - Get partitions for offset range +- `read(partition) -> Iterator` - Read data from partition +- `commit(end: Offset)` - Mark offsets as processed +- `stop()` - Clean up resources + +#### DataSourceWriter +Abstract base class for writing data to external sources. + +**Key Methods:** +- `write(iterator) -> WriteResult` - Write data from iterator +- `abort(messages: List[WriterCommitMessage])` - Handle write failures +- `commit(messages: List[WriterCommitMessage]) -> WriteResult` - Commit successful writes + +#### DataSourceArrowWriter +Optimized writer using PyArrow RecordBatch for improved performance. + +### Implementation Requirements + +1. **Serialization**: All classes must be pickle serializable +2. **Schema Definition**: Use PySpark StructType for schema specification +3. **Data Types**: Support standard Spark SQL data types +4. **Error Handling**: Implement proper exception handling +5. **Resource Management**: Clean up resources properly in streaming sources +6. 
**Use load() for paths**: Specify file paths in `load("/path")`, not `option("path", "/path")` + +### Usage Patterns + +```python +# Custom data source implementation +class MyDataSource(DataSource): + def __init__(self, options): + self.options = options + + def name(self): + return "myformat" + + def schema(self): + return StructType([StructField("id", IntegerType(), True)]) + + def reader(self, schema): + return MyDataSourceReader(self.options, schema) + +# Registration and usage +spark.dataSource.register(MyDataSource) +df = spark.read.format("myformat").option("key", "value").load() +``` + +### Performance Optimizations + +1. **Arrow Integration**: Return `pyarrow.RecordBatch` for better serialization +2. **Partitioning**: Implement `partitions()` for parallel processing +3. **Lazy Evaluation**: Defer expensive operations until read time + +## Documentation +- Main docs: https://allisonwang-db.github.io/pyspark-data-sources/ +- Individual data source docs in `docs/datasources/` +- Spark Data Source API: https://spark.apache.org/docs/4.0.0/api/python/tutorial/sql/python_data_source.html +- API Source Code: https://github.com/apache/spark/blob/master/python/pyspark/sql/datasource.py + +### Data Source Docstring Guidelines +When creating new data sources, include these sections in the class docstring: + +**Required Sections:** +- Brief description and `Name: "format_name"` +- `Options` section documenting all parameters with types/defaults +- `Examples` section with registration and basic usage + +**Key Guidelines:** +- **Include schema output**: Show `df.printSchema()` results for clarity +- **Document error cases**: Show what happens with missing files/invalid options +- **Document partitioning strategy**: Show how data sources leverage partitioning to increase performance +- **Document Arrow optimization**: Show how data data sources use Arrow to transmit data From 37ef88fa1d0e5a756f73e9fef32f880b4a24e1b2 Mon Sep 17 00:00:00 2001 From: Allison Wang Date: Mon, 21 Jul 2025 20:36:24 -0700 Subject: [PATCH 2/3] add arrow data source --- RELEASE.md | 15 ++ pyspark_datasources/__init__.py | 1 + pyspark_datasources/arrow.py | 210 +++++++++++++++++++++++ tests/create_test_data.py | 128 ++++++++++++++ tests/data/employees.arrow | Bin 0 -> 1514 bytes tests/data/products.parquet | Bin 0 -> 2032 bytes tests/data/sales/sales_q1.arrow | Bin 0 -> 1282 bytes tests/data/sales/sales_q2.arrow | Bin 0 -> 1282 bytes tests/data/sales/sales_q3.arrow | Bin 0 -> 1282 bytes tests/test_arrow_datasource.py | 141 ++++++++++++++++ tests/test_arrow_datasource_updated.py | 222 +++++++++++++++++++++++++ 11 files changed, 717 insertions(+) create mode 100644 pyspark_datasources/arrow.py create mode 100644 tests/create_test_data.py create mode 100644 tests/data/employees.arrow create mode 100644 tests/data/products.parquet create mode 100644 tests/data/sales/sales_q1.arrow create mode 100644 tests/data/sales/sales_q2.arrow create mode 100644 tests/data/sales/sales_q3.arrow create mode 100644 tests/test_arrow_datasource.py create mode 100644 tests/test_arrow_datasource_updated.py diff --git a/RELEASE.md b/RELEASE.md index 1480068..6fe91c1 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -119,6 +119,21 @@ Then follow steps 2-4 above. 
- Authenticate: `gh auth login` - Check repository access: `gh repo view` +### PyArrow Compatibility Issues + +If you see `objc_initializeAfterForkError` crashes on macOS, set this environment variable: + +```bash +# For single commands +OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES python your_script.py + +# For Poetry environment +OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES poetry run python your_script.py + +# To set permanently in your shell (add to ~/.zshrc or ~/.bash_profile): +export OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES +``` + ## Useful Poetry Version Commands ```bash diff --git a/pyspark_datasources/__init__.py b/pyspark_datasources/__init__.py index 680dcda..3076200 100644 --- a/pyspark_datasources/__init__.py +++ b/pyspark_datasources/__init__.py @@ -1,3 +1,4 @@ +from .arrow import ArrowDataSource from .fake import FakeDataSource from .github import GithubDataSource from .googlesheets import GoogleSheetsDataSource diff --git a/pyspark_datasources/arrow.py b/pyspark_datasources/arrow.py new file mode 100644 index 0000000..28d05ed --- /dev/null +++ b/pyspark_datasources/arrow.py @@ -0,0 +1,210 @@ +from typing import List, Iterator, Union +import os +import glob + +import pyarrow as pa +import pyarrow.parquet as pq +from pyspark.sql.datasource import DataSource, DataSourceReader, InputPartition +from pyspark.sql.types import StructType + + +class ArrowDataSource(DataSource): + """ + A data source for reading Apache Arrow files (.arrow, .parquet) using PyArrow. + + This data source supports reading Arrow/Parquet files from local filesystem or + cloud storage, leveraging PyArrow's efficient columnar format and returning + PyArrow RecordBatch objects for optimal performance with PySpark's Arrow integration. + + Name: `arrow` + + Options + ------- + - `format` : str, default "auto" + File format: "arrow", "parquet", or "auto" to detect from extension. + + Path Support + ----------- + Supports various path patterns in the load() method: + - Single file: "/path/to/file.arrow" or "/path/to/file.parquet" + - Glob patterns: "/path/to/*.arrow" or "/path/to/data*.parquet" + - Directory: "/path/to/directory" (reads all .arrow/.parquet files) + + Partitioning Strategy + -------------------- + The data source creates one partition per file for parallel processing: + - Single file: 1 partition + - Multiple files: N partitions (one per file) + - Directory: N partitions (one per .arrow/.parquet file found) + + This enables Spark to process multiple files in parallel across different + executor cores, improving performance for large datasets. 
+ + Performance Notes + ---------------- + - Returns PyArrow RecordBatch objects for zero-copy data transfer + - Leverages PySpark 4.0's enhanced Arrow integration + - For DataFrames created in Spark, consider using the new df.to_arrow() method + in PySpark 4.0+ for efficient Arrow conversion + + Examples + -------- + Register the data source: + + >>> from pyspark_datasources import ArrowDataSource + >>> spark.dataSource.register(ArrowDataSource) + + Read a single Arrow file: + + >>> df = spark.read.format("arrow").load("/path/to/employees.arrow") + >>> df.show() + +---+-----------+---+-------+----------+------+ + | id| name|age| salary|department|active| + +---+-----------+---+-------+----------+------+ + | 1|Alice Smith| 28|65000.0| Tech| true| + +---+-----------+---+-------+----------+------+ + + Read a single Parquet file: + + >>> df = spark.read.format("arrow").load("/path/to/products.parquet") + >>> df.show() + + Read multiple files with glob pattern (creates multiple partitions): + + >>> df = spark.read.format("arrow").load("/data/sales/sales_*.arrow") + >>> df.show() + >>> print(f"Number of partitions: {df.rdd.getNumPartitions()}") + + Read all Arrow/Parquet files in a directory: + + >>> df = spark.read.format("arrow").load("/data/warehouse/") + >>> df.show() + + Read with explicit format specification: + + >>> df = spark.read.format("arrow").option("format", "parquet").load("/path/to/data") + >>> df.show() + + Read mixed Arrow and Parquet files in a directory: + + >>> df = spark.read.format("arrow").load("/path/to/mixed_format_dir/") + >>> df.show() + + Working with the result DataFrame and PySpark 4.0 Arrow integration: + + >>> df = spark.read.format("arrow").load("/path/to/data.arrow") + >>> + >>> # Process with Spark + >>> result = df.filter(df.age > 25).groupBy("department").count() + >>> result.show() + >>> + >>> # Convert back to Arrow using PySpark 4.0+ feature + >>> arrow_table = result.to_arrow() # New in PySpark 4.0+ + >>> print(f"Arrow table: {arrow_table}") + + Schema inference example: + + >>> # Schema is automatically inferred from the first file + >>> df = spark.read.format("arrow").load("/path/to/*.parquet") + >>> df.printSchema() + root + |-- product_id: long (nullable = true) + |-- product_name: string (nullable = true) + |-- price: double (nullable = true) + """ + + @classmethod + def name(cls): + return "arrow" + + def schema(self) -> StructType: + path = self.options.get("path") + if not path: + raise ValueError("Path option is required for Arrow data source") + + # Get the first file to determine schema + files = self._get_files(path) + if not files: + raise ValueError(f"No files found at path: {path}") + + # Read schema from first file + file_format = self._detect_format(files[0]) + if file_format == "parquet": + table = pq.read_table(files[0]) + else: # arrow format + with pa.ipc.open_file(files[0]) as reader: + table = reader.read_all() + + # Convert PyArrow schema to Spark schema using PySpark utility + from pyspark.sql.pandas.types import from_arrow_schema + return from_arrow_schema(table.schema) + + def reader(self, schema: StructType) -> "ArrowDataSourceReader": + return ArrowDataSourceReader(schema, self.options) + + def _get_files(self, path: str) -> List[str]: + """Get list of files matching the path pattern.""" + if os.path.isfile(path): + return [path] + elif os.path.isdir(path): + # Find all arrow/parquet files in directory + arrow_files = glob.glob(os.path.join(path, "*.arrow")) + parquet_files = glob.glob(os.path.join(path, "*.parquet")) + 
return sorted(arrow_files + parquet_files) + else: + # Treat as glob pattern + return sorted(glob.glob(path)) + + def _detect_format(self, file_path: str) -> str: + """Detect file format from extension or option.""" + format_option = self.options.get("format", "auto") + if format_option != "auto": + return format_option + + ext = os.path.splitext(file_path)[1].lower() + if ext == ".parquet": + return "parquet" + elif ext in [".arrow", ".ipc"]: + return "arrow" + else: + # Default to arrow format + return "arrow" + + +class ArrowDataSourceReader(DataSourceReader): + """Reader for Arrow data source.""" + + def __init__(self, schema: StructType, options: dict) -> None: + self.schema = schema + self.options = options + self.path = options.get("path") + if not self.path: + raise ValueError("Path option is required") + + def partitions(self) -> List[InputPartition]: + """Create partitions, one per file for parallel reading.""" + data_source = ArrowDataSource(self.options) + files = data_source._get_files(self.path) + return [InputPartition(file_path) for file_path in files] + + def read(self, partition: InputPartition) -> Iterator[pa.RecordBatch]: + """Read data from a single file partition, returning PyArrow RecordBatch.""" + file_path = partition.value + data_source = ArrowDataSource(self.options) + file_format = data_source._detect_format(file_path) + + try: + if file_format == "parquet": + # Read Parquet file + table = pq.read_table(file_path) + # Convert to RecordBatch iterator + for batch in table.to_batches(): + yield batch + else: + # Read Arrow IPC file + with pa.ipc.open_file(file_path) as reader: + for i in range(reader.num_record_batches): + batch = reader.get_batch(i) + yield batch + except Exception as e: + raise RuntimeError(f"Failed to read file {file_path}: {str(e)}") \ No newline at end of file diff --git a/tests/create_test_data.py b/tests/create_test_data.py new file mode 100644 index 0000000..a713348 --- /dev/null +++ b/tests/create_test_data.py @@ -0,0 +1,128 @@ +#!/usr/bin/env python3 +"""Create persistent test data files for Arrow data source testing.""" + +import os +import pyarrow as pa +import pyarrow.parquet as pq + + +def create_test_data_files(): + """Create persistent test data files in the tests directory.""" + tests_dir = os.path.dirname(os.path.abspath(__file__)) + data_dir = os.path.join(tests_dir, "data") + + # Create data directory if it doesn't exist + os.makedirs(data_dir, exist_ok=True) + + # Create sample datasets + + # Dataset 1: Employee data (Arrow format) + employee_data = { + 'id': [1, 2, 3, 4, 5], + 'name': ['Alice Johnson', 'Bob Smith', 'Charlie Brown', 'Diana Prince', 'Eve Wilson'], + 'age': [28, 34, 42, 26, 31], + 'salary': [65000.0, 78000.0, 85000.0, 58000.0, 72000.0], + 'department': ['Engineering', 'Sales', 'Engineering', 'Marketing', 'Sales'], + 'active': [True, True, False, True, True] + } + employee_table = pa.table(employee_data) + + # Write as Arrow IPC file + arrow_path = os.path.join(data_dir, "employees.arrow") + with pa.ipc.new_file(open(arrow_path, 'wb'), employee_table.schema) as writer: + writer.write_table(employee_table) + print(f"Created: {arrow_path}") + + # Dataset 2: Product data (Parquet format) + product_data = { + 'product_id': [101, 102, 103, 104, 105, 106], + 'product_name': ['Laptop Pro', 'Desktop Elite', 'Tablet Max', 'Phone X', 'Watch Sport', 'Headphones Ultra'], + 'category': ['Computer', 'Computer', 'Tablet', 'Phone', 'Wearable', 'Audio'], + 'price': [1299.99, 899.99, 499.99, 799.99, 299.99, 149.99], + 'in_stock': 
[True, False, True, True, False, True], + 'rating': [4.5, 4.2, 4.7, 4.3, 4.1, 4.6] + } + product_table = pa.table(product_data) + + # Write as Parquet file + parquet_path = os.path.join(data_dir, "products.parquet") + pq.write_table(product_table, parquet_path) + print(f"Created: {parquet_path}") + + # Dataset 3: Multiple partition files (for testing multiple partitions) + sales_data_q1 = { + 'quarter': ['Q1'] * 4, + 'month': ['Jan', 'Feb', 'Mar', 'Jan'], + 'sales_id': [1001, 1002, 1003, 1004], + 'amount': [15000.0, 18000.0, 22000.0, 16500.0], + 'region': ['North', 'South', 'East', 'West'] + } + + sales_data_q2 = { + 'quarter': ['Q2'] * 4, + 'month': ['Apr', 'May', 'Jun', 'Apr'], + 'sales_id': [2001, 2002, 2003, 2004], + 'amount': [19000.0, 21000.0, 25000.0, 18500.0], + 'region': ['North', 'South', 'East', 'West'] + } + + sales_data_q3 = { + 'quarter': ['Q3'] * 4, + 'month': ['Jul', 'Aug', 'Sep', 'Jul'], + 'sales_id': [3001, 3002, 3003, 3004], + 'amount': [23000.0, 26000.0, 28000.0, 24000.0], + 'region': ['North', 'South', 'East', 'West'] + } + + # Create sales directory for partitioned data + sales_dir = os.path.join(data_dir, "sales") + os.makedirs(sales_dir, exist_ok=True) + + # Write multiple Arrow files for partition testing + for quarter_data, quarter_name in [(sales_data_q1, 'Q1'), (sales_data_q2, 'Q2'), (sales_data_q3, 'Q3')]: + table = pa.table(quarter_data) + sales_path = os.path.join(sales_dir, f"sales_{quarter_name.lower()}.arrow") + with pa.ipc.new_file(open(sales_path, 'wb'), table.schema) as writer: + writer.write_table(table) + print(f"Created: {sales_path}") + + # Create mixed format directory (Arrow + Parquet files) + mixed_dir = os.path.join(data_dir, "mixed") + os.makedirs(mixed_dir, exist_ok=True) + + # Write one file as Arrow + customer_data_arrow = { + 'customer_id': [1, 2, 3], + 'name': ['Customer A', 'Customer B', 'Customer C'], + 'email': ['a@example.com', 'b@example.com', 'c@example.com'] + } + customer_table = pa.table(customer_data_arrow) + customer_arrow_path = os.path.join(mixed_dir, "customers.arrow") + with pa.ipc.new_file(open(customer_arrow_path, 'wb'), customer_table.schema) as writer: + writer.write_table(customer_table) + print(f"Created: {customer_arrow_path}") + + # Write another as Parquet with same schema + customer_data_parquet = { + 'customer_id': [4, 5, 6], + 'name': ['Customer D', 'Customer E', 'Customer F'], + 'email': ['d@example.com', 'e@example.com', 'f@example.com'] + } + customer_table_2 = pa.table(customer_data_parquet) + customer_parquet_path = os.path.join(mixed_dir, "customers_extra.parquet") + pq.write_table(customer_table_2, customer_parquet_path) + print(f"Created: {customer_parquet_path}") + + print(f"\nTest data files created in: {data_dir}") + print("Directory structure:") + for root, dirs, files in os.walk(data_dir): + level = root.replace(data_dir, '').count(os.sep) + indent = ' ' * 2 * level + print(f"{indent}{os.path.basename(root)}/") + sub_indent = ' ' * 2 * (level + 1) + for file in files: + print(f"{sub_indent}{file}") + + +if __name__ == "__main__": + create_test_data_files() \ No newline at end of file diff --git a/tests/data/employees.arrow b/tests/data/employees.arrow new file mode 100644 index 0000000000000000000000000000000000000000..3f56215133d276b7937a60567bd60567202be989 GIT binary patch literal 1514 zcmeHHJ5L)y5FWq4f+M0p2tlNr0tKXjbfL2eNC+iLWD43WW@WE(UZe8~lFmOOh0B!u ziAa=`6#R%3k)nu{k;wP$BRP?lGGontJFlDB+3~hJoi|6%MeabS6_I%%%}63s(vT*0 zph_LGKx$&GhgIMK9b&G46Yx9m7Pt;h@jV1i@q0@_2XHZ}TbMH~@(Mot-mgKw$0a}J 
zJ<+EsF9zBcA|K#ac;6a*(~>Xi7^zI8&aI3a=tm}7s0FG zO)z`c&P-RQ2gAM{4eieGI6WMgqW`k5^33RTCm){MU8Af@-{!`4_1>9Ik4y$%#F98m zq^BO_S-|ChPXapl^1bD7_SWI}@~4+O-#mW$;c!;`%&Y*{!0h>+?HQ{zis&7xOpi+P zYn8v(MKr*d#8?vNm&h4RFva}$C6M@DuuCtnat7Eh_HrB=`Z?g=xrPz+|BZRrH->LZ I+r9C>0Uao&-v9sr literal 0 HcmV?d00001 diff --git a/tests/data/products.parquet b/tests/data/products.parquet new file mode 100644 index 0000000000000000000000000000000000000000..877c646428a5d07dec0a31d0394e64bb6e414bbe GIT binary patch literal 2032 zcmb7FPi)&%7=Lz?)JYr&5zn%xh~`qOlh%=LEuyFsw$rp(X|pu6OAAPw>v&1pB=#I< zX(uEQ5~zntTz2BjffEM~h!c}I?6h&%0XT5kWf!<2af9!BiMyuBgm_Bc``-8G{eIu~ zJ=-O{JV!(HHZ81Ek&+Z4*^o`59Jv#n43SzwAoYZ>)FAz_KUgWD61{KQ39v?8HH=n7&odyze@mpM-6}HmfdLJDKgK@0sUGzeQ%^ zXF}uz;vg&of&J8KNP?r=d|3eZ+j7t`>etsF-F{!R0m%Wlxj!G@Qc>f@S)4&GC)`WemK-uGg?^ zLJ7!1qkXI6JJtcAXo`5I-)Pr}Br0j8sgg7eDGMZ$RL){jUJ|7fUizC;o2u}Co)zU1 zRFqI2u;D@D$UdxwU)|XS2zKKhsDo79F zL`ln{UM}CLywktL@=54T9P_%N>-r}AalEJNJ>973E4p6Ob2X+jz+nTWoUUKT752dj zyIoXm;t5A6tznRHfMOA+{?WS#1sg*@WJ(r-Fjn}B40qeAW|2iCZ1HJEG||+!GSK& zTyV#5fnVeebO(w+OCSf-MlB?jS!>?dAW%^x4NrPWj9A1(I^4wpnJ&GvmU|1 z-s~O>nx iWEs9Di1Weql07ou(#Y|B`hTZhsdCSB1A#qA_o9OLFp7o2#H@PQX)%4IEdKbj|3f}qT&P;lpKL0 z&=OMQ0F>MT-rEmJ6H(CeqM6zG?#_5O?M~;QyDstt_yv)M5Or~-B4w#z1r!|U0<8w_ z7%uROe1M)oA!rTcfPBUM~feBbf7Mx|XzK9J)Mx;XHuR6{5{SMzknO*Hnd}8OyLp5K%<{&M3k;2XH*#d8ay|@w zCGSRBo{zO$)OPmA_p<=9eecIf79PiW7H$V=*6pPkoB*yJ-i-eCE^bdn-I#_spCg;_ gwLqK?u9qB=0k=j@?$iG}<@@>VoraiOTmAQa10giEVPN_ z;Ha70YBy*Fx`$jn>rYtNn>~YJ z3nVkZp2Im4I0r&a(J8UIjo&sdc1YOO?1x&tW-Odd7OY+~#>=b|veoT**T%n;iR0X} zebV(;oi*~N4tK_-0dTabKgtCleRDwD|6V$1q`g)!h5)nN&6A4U8~$MUqv6VV_FPxI z`= 0 + + # Verify the data + batch = data_batches[0] + assert batch.num_rows == 5 + assert batch.num_columns == 6 + + print(f"\nDirect access test results:") + print(f"Schema: {schema}") + print(f"Data shape: {batch.num_rows} rows, {batch.num_columns} columns") + print(f"Data preview:\n{batch.to_pandas()}") \ No newline at end of file From 05ebc9d5fbac402a856e1dbb12e1353df9ef973b Mon Sep 17 00:00:00 2001 From: Allison Wang Date: Mon, 4 Aug 2025 11:55:10 -0700 Subject: [PATCH 3/3] update --- CLAUDE.md | 4 +- pyspark_datasources/arrow.py | 85 ++-------- tests/create_test_data.py | 128 -------------- tests/data/employees.arrow | Bin 1514 -> 0 bytes tests/data/products.parquet | Bin 2032 -> 0 bytes tests/data/sales/sales_q1.arrow | Bin 1282 -> 0 bytes tests/data/sales/sales_q2.arrow | Bin 1282 -> 0 bytes tests/data/sales/sales_q3.arrow | Bin 1282 -> 0 bytes tests/test_arrow_datasource.py | 141 ---------------- tests/test_arrow_datasource_updated.py | 222 ------------------------- tests/test_data_sources.py | 87 ++++++++++ 11 files changed, 108 insertions(+), 559 deletions(-) delete mode 100644 tests/create_test_data.py delete mode 100644 tests/data/employees.arrow delete mode 100644 tests/data/products.parquet delete mode 100644 tests/data/sales/sales_q1.arrow delete mode 100644 tests/data/sales/sales_q2.arrow delete mode 100644 tests/data/sales/sales_q3.arrow delete mode 100644 tests/test_arrow_datasource.py delete mode 100644 tests/test_arrow_datasource_updated.py diff --git a/CLAUDE.md b/CLAUDE.md index adb21ec..3705425 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -55,8 +55,9 @@ poetry shell # Activate virtual environment # Note: On macOS, set this environment variable to avoid fork safety issues export OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES -pytest # Run all tests +pytest # Run all tests pytest tests/test_data_sources.py # Run specific test file +pytest 
tests/test_data_sources.py::test_arrow_datasource_single_file -v # Run a specific test ``` ### Documentation @@ -66,6 +67,7 @@ mkdocs build # Build static documentation ``` ### Package Management +Please refer to RELEASE.md for more details. ```bash poetry build # Build package poetry publish # Publish to PyPI (requires auth) diff --git a/pyspark_datasources/arrow.py b/pyspark_datasources/arrow.py index 28d05ed..899175c 100644 --- a/pyspark_datasources/arrow.py +++ b/pyspark_datasources/arrow.py @@ -3,39 +3,33 @@ import glob import pyarrow as pa -import pyarrow.parquet as pq from pyspark.sql.datasource import DataSource, DataSourceReader, InputPartition from pyspark.sql.types import StructType class ArrowDataSource(DataSource): """ - A data source for reading Apache Arrow files (.arrow, .parquet) using PyArrow. + A data source for reading Apache Arrow files (.arrow) using PyArrow. - This data source supports reading Arrow/Parquet files from local filesystem or + This data source supports reading Arrow IPC files from local filesystem or cloud storage, leveraging PyArrow's efficient columnar format and returning PyArrow RecordBatch objects for optimal performance with PySpark's Arrow integration. Name: `arrow` - Options - ------- - - `format` : str, default "auto" - File format: "arrow", "parquet", or "auto" to detect from extension. - Path Support ----------- Supports various path patterns in the load() method: - - Single file: "/path/to/file.arrow" or "/path/to/file.parquet" - - Glob patterns: "/path/to/*.arrow" or "/path/to/data*.parquet" - - Directory: "/path/to/directory" (reads all .arrow/.parquet files) + - Single file: "/path/to/file.arrow" + - Glob patterns: "/path/to/*.arrow" or "/path/to/data*.arrow" + - Directory: "/path/to/directory" (reads all .arrow files) Partitioning Strategy -------------------- The data source creates one partition per file for parallel processing: - Single file: 1 partition - Multiple files: N partitions (one per file) - - Directory: N partitions (one per .arrow/.parquet file found) + - Directory: N partitions (one per .arrow file found) This enables Spark to process multiple files in parallel across different executor cores, improving performance for large datasets. 
@@ -64,32 +58,17 @@ class ArrowDataSource(DataSource): | 1|Alice Smith| 28|65000.0| Tech| true| +---+-----------+---+-------+----------+------+ - Read a single Parquet file: - - >>> df = spark.read.format("arrow").load("/path/to/products.parquet") - >>> df.show() - Read multiple files with glob pattern (creates multiple partitions): >>> df = spark.read.format("arrow").load("/data/sales/sales_*.arrow") >>> df.show() >>> print(f"Number of partitions: {df.rdd.getNumPartitions()}") - Read all Arrow/Parquet files in a directory: + Read all Arrow files in a directory: >>> df = spark.read.format("arrow").load("/data/warehouse/") >>> df.show() - Read with explicit format specification: - - >>> df = spark.read.format("arrow").option("format", "parquet").load("/path/to/data") - >>> df.show() - - Read mixed Arrow and Parquet files in a directory: - - >>> df = spark.read.format("arrow").load("/path/to/mixed_format_dir/") - >>> df.show() - Working with the result DataFrame and PySpark 4.0 Arrow integration: >>> df = spark.read.format("arrow").load("/path/to/data.arrow") @@ -105,7 +84,7 @@ class ArrowDataSource(DataSource): Schema inference example: >>> # Schema is automatically inferred from the first file - >>> df = spark.read.format("arrow").load("/path/to/*.parquet") + >>> df = spark.read.format("arrow").load("/path/to/*.arrow") >>> df.printSchema() root |-- product_id: long (nullable = true) @@ -127,13 +106,9 @@ def schema(self) -> StructType: if not files: raise ValueError(f"No files found at path: {path}") - # Read schema from first file - file_format = self._detect_format(files[0]) - if file_format == "parquet": - table = pq.read_table(files[0]) - else: # arrow format - with pa.ipc.open_file(files[0]) as reader: - table = reader.read_all() + # Read schema from first file (Arrow IPC format) + with pa.ipc.open_file(files[0]) as reader: + table = reader.read_all() # Convert PyArrow schema to Spark schema using PySpark utility from pyspark.sql.pandas.types import from_arrow_schema @@ -147,28 +122,13 @@ def _get_files(self, path: str) -> List[str]: if os.path.isfile(path): return [path] elif os.path.isdir(path): - # Find all arrow/parquet files in directory + # Find all arrow files in directory arrow_files = glob.glob(os.path.join(path, "*.arrow")) - parquet_files = glob.glob(os.path.join(path, "*.parquet")) - return sorted(arrow_files + parquet_files) + return sorted(arrow_files) else: # Treat as glob pattern return sorted(glob.glob(path)) - def _detect_format(self, file_path: str) -> str: - """Detect file format from extension or option.""" - format_option = self.options.get("format", "auto") - if format_option != "auto": - return format_option - - ext = os.path.splitext(file_path)[1].lower() - if ext == ".parquet": - return "parquet" - elif ext in [".arrow", ".ipc"]: - return "arrow" - else: - # Default to arrow format - return "arrow" class ArrowDataSourceReader(DataSourceReader): @@ -190,21 +150,12 @@ def partitions(self) -> List[InputPartition]: def read(self, partition: InputPartition) -> Iterator[pa.RecordBatch]: """Read data from a single file partition, returning PyArrow RecordBatch.""" file_path = partition.value - data_source = ArrowDataSource(self.options) - file_format = data_source._detect_format(file_path) try: - if file_format == "parquet": - # Read Parquet file - table = pq.read_table(file_path) - # Convert to RecordBatch iterator - for batch in table.to_batches(): + # Read Arrow IPC file + with pa.ipc.open_file(file_path) as reader: + for i in range(reader.num_record_batches): + 
batch = reader.get_batch(i) yield batch - else: - # Read Arrow IPC file - with pa.ipc.open_file(file_path) as reader: - for i in range(reader.num_record_batches): - batch = reader.get_batch(i) - yield batch except Exception as e: - raise RuntimeError(f"Failed to read file {file_path}: {str(e)}") \ No newline at end of file + raise RuntimeError(f"Failed to read Arrow file {file_path}: {str(e)}") \ No newline at end of file diff --git a/tests/create_test_data.py b/tests/create_test_data.py deleted file mode 100644 index a713348..0000000 --- a/tests/create_test_data.py +++ /dev/null @@ -1,128 +0,0 @@ -#!/usr/bin/env python3 -"""Create persistent test data files for Arrow data source testing.""" - -import os -import pyarrow as pa -import pyarrow.parquet as pq - - -def create_test_data_files(): - """Create persistent test data files in the tests directory.""" - tests_dir = os.path.dirname(os.path.abspath(__file__)) - data_dir = os.path.join(tests_dir, "data") - - # Create data directory if it doesn't exist - os.makedirs(data_dir, exist_ok=True) - - # Create sample datasets - - # Dataset 1: Employee data (Arrow format) - employee_data = { - 'id': [1, 2, 3, 4, 5], - 'name': ['Alice Johnson', 'Bob Smith', 'Charlie Brown', 'Diana Prince', 'Eve Wilson'], - 'age': [28, 34, 42, 26, 31], - 'salary': [65000.0, 78000.0, 85000.0, 58000.0, 72000.0], - 'department': ['Engineering', 'Sales', 'Engineering', 'Marketing', 'Sales'], - 'active': [True, True, False, True, True] - } - employee_table = pa.table(employee_data) - - # Write as Arrow IPC file - arrow_path = os.path.join(data_dir, "employees.arrow") - with pa.ipc.new_file(open(arrow_path, 'wb'), employee_table.schema) as writer: - writer.write_table(employee_table) - print(f"Created: {arrow_path}") - - # Dataset 2: Product data (Parquet format) - product_data = { - 'product_id': [101, 102, 103, 104, 105, 106], - 'product_name': ['Laptop Pro', 'Desktop Elite', 'Tablet Max', 'Phone X', 'Watch Sport', 'Headphones Ultra'], - 'category': ['Computer', 'Computer', 'Tablet', 'Phone', 'Wearable', 'Audio'], - 'price': [1299.99, 899.99, 499.99, 799.99, 299.99, 149.99], - 'in_stock': [True, False, True, True, False, True], - 'rating': [4.5, 4.2, 4.7, 4.3, 4.1, 4.6] - } - product_table = pa.table(product_data) - - # Write as Parquet file - parquet_path = os.path.join(data_dir, "products.parquet") - pq.write_table(product_table, parquet_path) - print(f"Created: {parquet_path}") - - # Dataset 3: Multiple partition files (for testing multiple partitions) - sales_data_q1 = { - 'quarter': ['Q1'] * 4, - 'month': ['Jan', 'Feb', 'Mar', 'Jan'], - 'sales_id': [1001, 1002, 1003, 1004], - 'amount': [15000.0, 18000.0, 22000.0, 16500.0], - 'region': ['North', 'South', 'East', 'West'] - } - - sales_data_q2 = { - 'quarter': ['Q2'] * 4, - 'month': ['Apr', 'May', 'Jun', 'Apr'], - 'sales_id': [2001, 2002, 2003, 2004], - 'amount': [19000.0, 21000.0, 25000.0, 18500.0], - 'region': ['North', 'South', 'East', 'West'] - } - - sales_data_q3 = { - 'quarter': ['Q3'] * 4, - 'month': ['Jul', 'Aug', 'Sep', 'Jul'], - 'sales_id': [3001, 3002, 3003, 3004], - 'amount': [23000.0, 26000.0, 28000.0, 24000.0], - 'region': ['North', 'South', 'East', 'West'] - } - - # Create sales directory for partitioned data - sales_dir = os.path.join(data_dir, "sales") - os.makedirs(sales_dir, exist_ok=True) - - # Write multiple Arrow files for partition testing - for quarter_data, quarter_name in [(sales_data_q1, 'Q1'), (sales_data_q2, 'Q2'), (sales_data_q3, 'Q3')]: - table = pa.table(quarter_data) - sales_path = 
os.path.join(sales_dir, f"sales_{quarter_name.lower()}.arrow") - with pa.ipc.new_file(open(sales_path, 'wb'), table.schema) as writer: - writer.write_table(table) - print(f"Created: {sales_path}") - - # Create mixed format directory (Arrow + Parquet files) - mixed_dir = os.path.join(data_dir, "mixed") - os.makedirs(mixed_dir, exist_ok=True) - - # Write one file as Arrow - customer_data_arrow = { - 'customer_id': [1, 2, 3], - 'name': ['Customer A', 'Customer B', 'Customer C'], - 'email': ['a@example.com', 'b@example.com', 'c@example.com'] - } - customer_table = pa.table(customer_data_arrow) - customer_arrow_path = os.path.join(mixed_dir, "customers.arrow") - with pa.ipc.new_file(open(customer_arrow_path, 'wb'), customer_table.schema) as writer: - writer.write_table(customer_table) - print(f"Created: {customer_arrow_path}") - - # Write another as Parquet with same schema - customer_data_parquet = { - 'customer_id': [4, 5, 6], - 'name': ['Customer D', 'Customer E', 'Customer F'], - 'email': ['d@example.com', 'e@example.com', 'f@example.com'] - } - customer_table_2 = pa.table(customer_data_parquet) - customer_parquet_path = os.path.join(mixed_dir, "customers_extra.parquet") - pq.write_table(customer_table_2, customer_parquet_path) - print(f"Created: {customer_parquet_path}") - - print(f"\nTest data files created in: {data_dir}") - print("Directory structure:") - for root, dirs, files in os.walk(data_dir): - level = root.replace(data_dir, '').count(os.sep) - indent = ' ' * 2 * level - print(f"{indent}{os.path.basename(root)}/") - sub_indent = ' ' * 2 * (level + 1) - for file in files: - print(f"{sub_indent}{file}") - - -if __name__ == "__main__": - create_test_data_files() \ No newline at end of file diff --git a/tests/data/employees.arrow b/tests/data/employees.arrow deleted file mode 100644 index 3f56215133d276b7937a60567bd60567202be989..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 1514 zcmeHHJ5L)y5FWq4f+M0p2tlNr0tKXjbfL2eNC+iLWD43WW@WE(UZe8~lFmOOh0B!u ziAa=`6#R%3k)nu{k;wP$BRP?lGGontJFlDB+3~hJoi|6%MeabS6_I%%%}63s(vT*0 zph_LGKx$&GhgIMK9b&G46Yx9m7Pt;h@jV1i@q0@_2XHZ}TbMH~@(Mot-mgKw$0a}J zJ<+EsF9zBcA|K#ac;6a*(~>Xi7^zI8&aI3a=tm}7s0FG zO)z`c&P-RQ2gAM{4eieGI6WMgqW`k5^33RTCm){MU8Af@-{!`4_1>9Ik4y$%#F98m zq^BO_S-|ChPXapl^1bD7_SWI}@~4+O-#mW$;c!;`%&Y*{!0h>+?HQ{zis&7xOpi+P zYn8v(MKr*d#8?vNm&h4RFva}$C6M@DuuCtnat7Eh_HrB=`Z?g=xrPz+|BZRrH->LZ I+r9C>0Uao&-v9sr diff --git a/tests/data/products.parquet b/tests/data/products.parquet deleted file mode 100644 index 877c646428a5d07dec0a31d0394e64bb6e414bbe..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 2032 zcmb7FPi)&%7=Lz?)JYr&5zn%xh~`qOlh%=LEuyFsw$rp(X|pu6OAAPw>v&1pB=#I< zX(uEQ5~zntTz2BjffEM~h!c}I?6h&%0XT5kWf!<2af9!BiMyuBgm_Bc``-8G{eIu~ zJ=-O{JV!(HHZ81Ek&+Z4*^o`59Jv#n43SzwAoYZ>)FAz_KUgWD61{KQ39v?8HH=n7&odyze@mpM-6}HmfdLJDKgK@0sUGzeQ%^ zXF}uz;vg&of&J8KNP?r=d|3eZ+j7t`>etsF-F{!R0m%Wlxj!G@Qc>f@S)4&GC)`WemK-uGg?^ zLJ7!1qkXI6JJtcAXo`5I-)Pr}Br0j8sgg7eDGMZ$RL){jUJ|7fUizC;o2u}Co)zU1 zRFqI2u;D@D$UdxwU)|XS2zKKhsDo79F zL`ln{UM}CLywktL@=54T9P_%N>-r}AalEJNJ>973E4p6Ob2X+jz+nTWoUUKT752dj zyIoXm;t5A6tznRHfMOA+{?WS#1sg*@WJ(r-Fjn}B40qeAW|2iCZ1HJEG||+!GSK& zTyV#5fnVeebO(w+OCSf-MlB?jS!>?dAW%^x4NrPWj9A1(I^4wpnJ&GvmU|1 z-s~O>nx iWEs9Di1Weql07ou(#Y|B`hTZhsdCSB1A#qA_o9OLFp7o2#H@PQX)%4IEdKbj|3f}qT&P;lpKL0 z&=OMQ0F>MT-rEmJ6H(CeqM6zG?#_5O?M~;QyDstt_yv)M5Or~-B4w#z1r!|U0<8w_ z7%uROe1M)oA!rTcfPBUM~feBbf7Mx|XzK9J)Mx;XHuR6{5{SMzknO*Hnd}8OyLp5K%<{&M3k;2XH*#d8ay|@w 
zCGSRBo{zO$)OPmA_p<=9eecIf79PiW7H$V=*6pPkoB*yJ-i-eCE^bdn-I#_spCg;_ gwLqK?u9qB=0k=j@?$iG}<@@>VoraiOTmAQa10giEVPN_ z;Ha70YBy*Fx`$jn>rYtNn>~YJ z3nVkZp2Im4I0r&a(J8UIjo&sdc1YOO?1x&tW-Odd7OY+~#>=b|veoT**T%n;iR0X} zebV(;oi*~N4tK_-0dTabKgtCleRDwD|6V$1q`g)!h5)nN&6A4U8~$MUqv6VV_FPxI z`= 0 - - # Verify the data - batch = data_batches[0] - assert batch.num_rows == 5 - assert batch.num_columns == 6 - - print(f"\nDirect access test results:") - print(f"Schema: {schema}") - print(f"Data shape: {batch.num_rows} rows, {batch.num_columns} columns") - print(f"Data preview:\n{batch.to_pandas()}") \ No newline at end of file diff --git a/tests/test_data_sources.py b/tests/test_data_sources.py index 54cb133..836c416 100644 --- a/tests/test_data_sources.py +++ b/tests/test_data_sources.py @@ -1,4 +1,7 @@ import pytest +import tempfile +import os +import pyarrow as pa from pyspark.sql import SparkSession from pyspark_datasources import * @@ -91,3 +94,87 @@ def test_salesforce_datasource_registration(spark): error_msg = str(e).lower() # The error can be about unsupported mode or missing writer assert "unsupported" in error_msg or "writer" in error_msg or "not implemented" in error_msg + + +def test_arrow_datasource_single_file(spark): + """Test reading a single Arrow file.""" + spark.dataSource.register(ArrowDataSource) + + # Create test data + test_data = pa.table({ + 'id': [1, 2, 3], + 'name': ['Alice', 'Bob', 'Charlie'], + 'age': [25, 30, 35] + }) + + # Write to temporary Arrow file + with tempfile.NamedTemporaryFile(suffix='.arrow', delete=False) as tmp_file: + tmp_path = tmp_file.name + + try: + with pa.ipc.new_file(tmp_path, test_data.schema) as writer: + writer.write_table(test_data) + + # Read using Arrow data source + df = spark.read.format("arrow").load(tmp_path) + + # Verify results + assert df.count() == 3 + assert len(df.columns) == 3 + assert set(df.columns) == {'id', 'name', 'age'} + + # Verify data content + rows = df.collect() + assert len(rows) == 3 + assert rows[0]['name'] == 'Alice' + + finally: + # Clean up + if os.path.exists(tmp_path): + os.unlink(tmp_path) + + +def test_arrow_datasource_multiple_files(spark): + """Test reading multiple Arrow files from a directory.""" + spark.dataSource.register(ArrowDataSource) + + # Create test data for multiple files + test_data1 = pa.table({ + 'id': [1, 2], + 'name': ['Alice', 'Bob'], + 'department': ['Engineering', 'Sales'] + }) + + test_data2 = pa.table({ + 'id': [3, 4], + 'name': ['Charlie', 'Diana'], + 'department': ['Marketing', 'HR'] + }) + + # Create temporary directory + with tempfile.TemporaryDirectory() as tmp_dir: + # Write multiple Arrow files + file1_path = os.path.join(tmp_dir, 'data1.arrow') + file2_path = os.path.join(tmp_dir, 'data2.arrow') + + with pa.ipc.new_file(file1_path, test_data1.schema) as writer: + writer.write_table(test_data1) + + with pa.ipc.new_file(file2_path, test_data2.schema) as writer: + writer.write_table(test_data2) + + # Read using Arrow data source from directory + df = spark.read.format("arrow").load(tmp_dir) + + # Verify results + assert df.count() == 4 # 2 rows from each file + assert len(df.columns) == 3 + assert set(df.columns) == {'id', 'name', 'department'} + + # Verify partitioning (should have 2 partitions, one per file) + assert df.rdd.getNumPartitions() == 2 + + # Verify all data is present + rows = df.collect() + names = {row['name'] for row in rows} + assert names == {'Alice', 'Bob', 'Charlie', 'Diana'}
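
For quick manual verification of the `ArrowDataSource` introduced in this patch series, here is a minimal sketch, not part of the patches themselves. It assumes PySpark 4.0+, `pyarrow`, and an importable `pyspark_datasources` package; the temporary directory, file names, and column names are illustrative only. It mirrors what the new tests exercise: registration via `spark.dataSource.register`, reading a directory of Arrow IPC files, and the one-partition-per-file behavior.

```python
# Minimal usage sketch for the ArrowDataSource added in this patch series.
# Assumes PySpark 4.0+, pyarrow, and that pyspark_datasources is importable;
# the temporary paths and column names below are illustrative only.
import os
import tempfile

import pyarrow as pa
from pyspark.sql import SparkSession

from pyspark_datasources import ArrowDataSource

spark = SparkSession.builder.appName("arrow-datasource-demo").getOrCreate()
spark.dataSource.register(ArrowDataSource)

# Write two small Arrow IPC files so the reader creates one partition per file.
tmp_dir = tempfile.mkdtemp()
for file_name, ids in [("part1.arrow", [1, 2]), ("part2.arrow", [3, 4])]:
    table = pa.table({"id": ids, "name": [f"user_{i}" for i in ids]})
    with pa.ipc.new_file(os.path.join(tmp_dir, file_name), table.schema) as writer:
        writer.write_table(table)

df = spark.read.format("arrow").load(tmp_dir)
df.show()
print("partitions:", df.rdd.getNumPartitions())  # expected: 2 (one per file)
```

Because `ArrowDataSourceReader.partitions()` emits one `InputPartition` per file, the partition count printed above should match the number of `.arrow` files in the directory, which is exactly what `test_arrow_datasource_multiple_files` asserts.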