<!---
  Licensed to the Apache Software Foundation (ASF) under one
  or more contributor license agreements.  See the NOTICE file
  distributed with this work for additional information
  regarding copyright ownership.  The ASF licenses this file
  to you under the Apache License, Version 2.0 (the
  "License"); you may not use this file except in compliance
  with the License.  You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

  Unless required by applicable law or agreed to in writing,
  software distributed under the License is distributed on an
  "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  KIND, either express or implied.  See the License for the
  specific language governing permissions and limitations
  under the License.
-->

# DataFrame API with Ballista

This notebook demonstrates the DataFrame API available in Ballista.

The DataFrame API provides a programmatic way to build queries, which can be
more convenient than writing SQL for complex transformations.

In [None]:
from ballista import BallistaSessionContext, setup_test_cluster
from datafusion import col, lit
from datafusion import functions as f

# Bring up the test cluster; the call hands back the host and port to dial
host, port = setup_test_cluster()
cluster_url = f"df://{host}:{port}"
ctx = BallistaSessionContext(cluster_url)

# Register the sample Parquet and CSV files as tables test_data and csv_data
ctx.register_parquet("test_data", "../testdata/test.parquet")
ctx.register_csv("csv_data", "../testdata/test.csv", has_header=True)

# Report the session id as confirmation that the context was created
print(f"Connected! Session ID: {ctx.session_id}")

## Basic Operations

In [None]:
# Look up the registered table by name; the handle is reused by later cells
df = ctx.table("test_data")

# Fetch the schema once, then list each field's name and type on its own line
table_schema = df.schema()
print("Schema:")
for field in table_schema:
    print(f"  {field.name}: {field.type}")

In [None]:
# Show the first rows of the table; the argument caps the output at 5 rows
df.show(5)

## Selecting Columns

In [None]:
# Select specific columns by name, passing each column name as a plain string.
# The selection is the cell's last expression, so the notebook displays it.
df.select("id", "bool_col", "tinyint_col")

In [None]:
# Build the derived columns as named expressions first, then select them
# alongside the untouched id column.
tiny = col("tinyint_col").alias("tiny")  # same column under a new name
id_times_10 = (col("id") * lit(10)).alias("id_times_10")  # computed column
df.select(col("id"), tiny, id_times_10)

## Filtering Data

In [None]:
# Simple filter: keep only the rows whose id is greater than 4.
# lit() wraps the Python value so it can be compared with a column expression.
df.filter(col("id") > lit(4))

In [None]:
# Compound filter: two conditions joined with & (logical AND); use | for OR.
# Each comparison needs its own parentheses because & binds more tightly
# than >= and <= in Python.
df.filter(
    (col("id") >= lit(2)) & (col("id") <= lit(5))
)

In [None]:
# Filter with boolean column: keep the rows where bool_col equals True.
# The == builds a comparison expression that is handed to filter().
df.filter(col("bool_col") == lit(True))

## Sorting

In [None]:
# Sort ascending: order the rows by id, with the direction set on the
# column's sort expression via ascending=True
df.sort(col("id").sort(ascending=True))

In [None]:
# Sort descending: same sort key (id), with ascending=False to reverse the order
df.sort(col("id").sort(ascending=False))

## Limiting Results

In [None]:
# Limit number of rows: the result holds at most 3 rows
df.limit(3)

## Aggregations

In [None]:
# An empty group-by list means the count covers every row of the table
total_count = f.count_star().alias("total_count")
result = df.aggregate([], [total_count])

# Display the aggregated DataFrame as the cell output
result

In [None]:
# Group the rows by bool_col and compute, per group, a row count plus the
# sum and the average of id.
group_keys = [col("bool_col")]
aggregates = [
    f.count_star().alias("count"),
    f.sum(col("id")).alias("sum_id"),
    f.avg(col("id")).alias("avg_id"),
]
df.aggregate(group_keys, aggregates)

## Distinct Values

In [None]:
# Get distinct values: project down to bool_col first, then drop duplicate
# rows so each value of that column appears once
df.select("bool_col").distinct()

## Chaining Operations

DataFrame operations can be chained together to build complex transformations.

In [None]:
# Complex chained query: every call is made on the value returned by the
# previous one, so the steps read top to bottom.
result = (
    ctx.table("test_data")
    .select("id", "bool_col", "tinyint_col")  # keep three columns
    .filter(col("id") > lit(2))  # rows with id greater than 2
    .sort(col("id").sort(ascending=False))  # largest id first
    .limit(5)  # at most 5 rows
)

# Display the query result; later cells reuse this `result` variable
result

In [None]:
# View the execution plan for the chained query
# NOTE(review): in datafusion-python, DataFrame.explain() prints the plan itself
# and returns None. If this DataFrame behaves the same way, the outer print()
# emits an extra "None" line and a bare result.explain() would suffice - confirm.
print(result.explain())

In [None]:
# Visual execution plan for the same chained query; whatever explain_visual()
# returns is the cell's last expression, so the notebook renders it
result.explain_visual()

## Collecting Results

In [None]:
# Collect as Arrow batches: the result comes back as a sequence of batches
batches = result.collect()
print(f"Got {len(batches)} batch(es)")
# len(batch) is the row count of one batch; summing gives the overall total
print(f"Total rows: {sum(len(batch) for batch in batches)}")

In [None]:
# Collect as Arrow table: a single table object instead of separate batches
table = result.to_arrow_table()
# Report the table's dimensions
print(f"Arrow table: {table.num_rows} rows, {table.num_columns} columns")

In [None]:
# Convert to Pandas: materialize the query result as a pandas DataFrame
pdf = result.to_pandas()
# Display the pandas DataFrame as the cell output
pdf

In [None]:
# Get the count without collecting all data: count() is called on the
# DataFrame itself, so only the row count is brought back into the notebook
count = ctx.table("test_data").count()
print(f"Total rows in test_data: {count}")

## Working with CSV Data

In [None]:
# Load the CSV file straight into a DataFrame; no table name is registered.
# has_header=True treats the first line of the file as column names.
csv_df = ctx.read_csv("../testdata/test.csv", has_header=True)

# Fetch the schema once, then list each field's name and type
csv_schema = csv_df.schema()
print("CSV Schema:")
for field in csv_schema:
    print(f"  {field.name}: {field.type}")

# Display the CSV data as the cell output
csv_df

In [None]:
# Filter CSV data: keep the rows whose column "a" is greater than 2.
# NOTE(review): assumes test.csv has a numeric column named "a" - confirm
# against the file's header.
csv_df.filter(col("a") > lit(2))

## Next Steps

- See `distributed_queries.ipynb` for examples of distributed query execution
- Check the [DataFusion Python documentation](https://datafusion.apache.org/python/) for more DataFrame operations
- Review the SQL magic commands in `getting_started.ipynb` for interactive querying