# Getting Started with Snowpark Connect for Apache Spark

Snowpark Connect allows you to run the **PySpark DataFrame API** on **Snowflake infrastructure**. 

**In this guide, you will learn:**
- How Snowpark Connect executes PySpark on Snowflake infrastructure
- How to build production pipelines with Snowpark Connect
    - Data ingestion patterns (tables, stages, cloud storage)
    - Transformations, joins, and aggregations
    - Writing data with partitioning and compression
    - Integrating with telemetry and best practices


## Snowpark Connect - Key Concepts

**Execution Model:**
- Your DataFrame operations are translated to Snowflake SQL
- Computation happens in Snowflake warehouses
- Results stream back via Apache Arrow format
- No Spark cluster, driver, or executors

**Query Pushdown:**
- ✅ **Fully Optimized:** DataFrame operations, SQL functions, aggregations push down to Snowflake
- ⚠️ **Performance Impact:** Python UDFs run client-side (fetch data → process → send back)
- 💡 **Better Alternative:** Use built-in SQL functions instead of UDFs

**CLI: Snowpark Submit**
- Run Spark workloads as batch jobs using Snowpark Submit CLI
- To install: `pip install snowpark-submit`
- To submit a job, run: 

``` 
snowpark-submit \
  --snowflake-workload-name MY_JOB \
  --snowflake-connection-name MY_CONNECTION \
  --compute-pool MY_COMPUTE_POOL \
  app.py
```
[CLI examples](https://docs.snowflake.com/en/developer-guide/snowpark-connect/snowpark-submit-examples)
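
When the submit step is driven from a scheduler or CI script, it can help to assemble the command programmatically. The sketch below only builds the argument list from the flags shown above; `build_submit_command` is a hypothetical helper, not part of the Snowpark Submit CLI, and nothing is executed.

```python
import shlex
from typing import List


def build_submit_command(
    workload_name: str,
    connection_name: str,
    compute_pool: str,
    app_path: str,
) -> List[str]:
    """Assemble the snowpark-submit invocation as an argument list (hypothetical helper)."""
    return [
        "snowpark-submit",
        "--snowflake-workload-name", workload_name,
        "--snowflake-connection-name", connection_name,
        "--compute-pool", compute_pool,
        app_path,
    ]


cmd = build_submit_command("MY_JOB", "MY_CONNECTION", "MY_COMPUTE_POOL", "app.py")
print(shlex.join(cmd))  # shell-safe rendering of the same command
# To actually run it, pass the list to subprocess.run(cmd, check=True)
```

Keeping the command as a list rather than a single string avoids shell-quoting problems if a workload or connection name ever contains spaces.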

## Demo Pipeline: Tasty Bytes Menu Analytics

In this section, we will cover key aspects of using Snowpark Connect through a production-ready pipeline that analyzes menu profitability for Tasty Bytes, a fictitious global food truck network. 

### Pipeline structure
 1. Setup & Configuration
 2. Telemetry Initialization
 3. Data Ingestion
 4. Data Validation
 5. Transformations (profit analysis)
 6. Data Quality Checks
 7. Write Output
 8. Cleanup & Summary


In [None]:
# =============================================================================
# STEP 0: INITIALIZE SESSION
# =============================================================================

# import packages - select "snowpark_connect" from the packages dropdown
from snowflake import snowpark_connect
from snowflake.snowpark.context import get_active_session

# initialize session
session = get_active_session()
spark = snowpark_connect.server.init_spark_session()

# print session info
print(session)

In [None]:
# =============================================================================
# STEP 1: SETUP & CONFIGURATION
# =============================================================================

import os
import uuid
import logging
from datetime import datetime
from dataclasses import dataclass
from typing import Optional, List

from pyspark.sql import DataFrame
from pyspark.sql.functions import (
    col, lit, when, coalesce, trim, upper,
    sum, avg, count, min, max, countDistinct,
    current_timestamp, round
)

# --- Configuration for Tasty Bytes Menu Pipeline ---
@dataclass
class PipelineConfig:
    """Centralized configuration for the menu analytics pipeline."""
    
    # Pipeline metadata
    pipeline_name: str = "tastybytes_menu_analytics"
    pipeline_version: str = "1.0.0"
    environment: str = os.getenv("ENVIRONMENT", "dev")
    
    # Snowflake context
    role: str = "SNOWFLAKE_LEARNING_ROLE"
    warehouse: str = "SNOWFLAKE_LEARNING_WH"
    database: str = "SNOWFLAKE_LEARNING_DB"
    
    # Source configuration
    stage_url: str = "s3://sfquickstarts/tastybytes/"
    stage_path: str = "@blob_stage/raw_pos/menu/"
    source_table: str = "MENU_RAW"
    
    # Output configuration
    output_table: str = "MENU_BRAND_SUMMARY"
    output_mode: str = "overwrite"
    
    # Quality thresholds
    min_expected_rows: int = 50
    max_expected_rows: int = 10_000
    max_null_percentage: float = 0.05
    min_profit_margin: float = 0.0  # No negative margins allowed

config = PipelineConfig()

# --- Initialize sessions ---
# Note: session and spark should already be initialized
print(f"🚀 Pipeline: {config.pipeline_name} v{config.pipeline_version}")
print(f"📍 Environment: {config.environment}")

# =============================================================================
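
One detail of the dataclass above: a plain default such as `os.getenv("ENVIRONMENT", "dev")` is evaluated once, when the class is defined, so later changes to the environment variable are not picked up by new instances. A minimal sketch of an alternative using `field(default_factory=...)`; `EnvAwareConfig` is a hypothetical name used only for illustration.

```python
import os
from dataclasses import dataclass, field


@dataclass
class EnvAwareConfig:
    """Hypothetical variant: read ENVIRONMENT each time an instance is created."""
    pipeline_name: str = "tastybytes_menu_analytics"
    environment: str = field(
        default_factory=lambda: os.getenv("ENVIRONMENT", "dev")
    )


# Setting the variable here is only for demonstration
os.environ["ENVIRONMENT"] = "staging"
cfg_staging = EnvAwareConfig()

os.environ["ENVIRONMENT"] = "prod"
cfg_prod = EnvAwareConfig()

print(cfg_staging.environment, cfg_prod.environment)
```

Whether per-instance lookup is preferable depends on how the notebook is run; for a single-shot batch job the original form behaves the same.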

In [None]:
# =============================================================================
# STEP 2: TELEMETRY INITIALIZATION
# =============================================================================

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(levelname)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S"
)
logger = logging.getLogger(config.pipeline_name)

@dataclass
class PipelineMetrics:
    """Track metrics throughout pipeline execution."""
    run_id: str
    start_time: datetime
    end_time: Optional[datetime] = None
    status: str = "running"
    
    # Stage metrics
    rows_ingested: int = 0
    rows_after_validation: int = 0
    rows_after_transform: int = 0
    rows_written: int = 0
    
    # Business metrics
    total_menu_items: int = 0
    total_brands: int = 0
    avg_profit_margin: float = 0.0
    
    # Quality metrics
    validation_passed: bool = False
    quality_score: float = 0.0
    warnings: Optional[List[str]] = None  # replaced with [] in __post_init__
    
    def __post_init__(self):
        self.warnings = self.warnings or []
    
    def add_warning(self, message: str):
        self.warnings.append(message)
        logger.warning(f"⚠️  {message}")
    
    def duration_seconds(self) -> float:
        end = self.end_time or datetime.now()
        return (end - self.start_time).total_seconds()

metrics = PipelineMetrics(
    run_id=str(uuid.uuid4())[:8],
    start_time=datetime.now()
)

logger.info(f"{'='*60}")
logger.info(f"🚀 PIPELINE START | Run ID: {metrics.run_id}")
logger.info(f"{'='*60}")

# =============================================================================
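
Per-stage timings are a natural addition to these metrics. The following is a sketch, not part of the pipeline above: `timed_stage`, `stage_logger`, and `stage_durations` are hypothetical names, and only the standard library is used.

```python
import logging
import time
from contextlib import contextmanager
from typing import Dict, Iterator

stage_logger = logging.getLogger("stage_timing")
stage_durations: Dict[str, float] = {}


@contextmanager
def timed_stage(name: str) -> Iterator[None]:
    """Record the wall-clock duration of a pipeline stage, even if it raises."""
    start = time.perf_counter()
    try:
        yield
    finally:
        stage_durations[name] = time.perf_counter() - start
        stage_logger.info("Stage %s took %.3fs", name, stage_durations[name])


with timed_stage("ingestion"):
    time.sleep(0.01)  # stand-in for real work
```

Because the duration is recorded in a `finally` block, a failing stage still leaves a timing entry, which is useful when diagnosing where a run stopped.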

In [None]:
# =============================================================================
# STEP 3: DATA INGESTION
# =============================================================================

def setup_snowflake_context(session, config: PipelineConfig) -> str:
    """Set up Snowflake role, warehouse, database, and schema."""
    logger.info("Setting up Snowflake context...")
    
    session.sql(f"USE ROLE {config.role}").collect()
    session.sql(f"USE WAREHOUSE {config.warehouse}").collect()
    session.sql(f"USE DATABASE {config.database}").collect()
    
    # Create user-specific schema
    current_user = session.sql("SELECT current_user()").collect()[0][0]
    schema_name = f"{current_user}_MENU_ANALYTICS"
    session.sql(f"CREATE SCHEMA IF NOT EXISTS {schema_name}").collect()
    session.sql(f"USE SCHEMA {schema_name}").collect()
    
    logger.info(f"  Context: {config.database}.{schema_name}")
    return schema_name

def create_stage_and_table(session, config: PipelineConfig):
    """Create external stage and target table."""
    logger.info("Creating stage and table...")
    
    # Create stage for S3 access
    session.sql(f"""
        CREATE OR REPLACE STAGE blob_stage
        URL = '{config.stage_url}'
        FILE_FORMAT = (TYPE = CSV)
    """).collect()
    
    # Create raw table with proper schema
    session.sql(f"""
        CREATE OR REPLACE TABLE {config.source_table} (
            MENU_ID NUMBER(19,0),
            MENU_TYPE_ID NUMBER(38,0),
            MENU_TYPE VARCHAR,
            TRUCK_BRAND_NAME VARCHAR,
            MENU_ITEM_ID NUMBER(38,0),
            MENU_ITEM_NAME VARCHAR,
            ITEM_CATEGORY VARCHAR,
            ITEM_SUBCATEGORY VARCHAR,
            COST_OF_GOODS_USD NUMBER(38,4),
            SALE_PRICE_USD NUMBER(38,4),
            MENU_ITEM_HEALTH_METRICS_OBJ VARIANT
        )
    """).collect()
    
    logger.info("  ✓ Stage and table created")

def load_data_from_stage(session, config: PipelineConfig) -> int:
    """Load CSV data using COPY INTO."""
    logger.info(f"Loading data from: {config.stage_path}")
    
    result = session.sql(f"""
        COPY INTO {config.source_table}
        FROM {config.stage_path}
    """).collect()
    
    # Get loaded row count
    count_result = session.sql(f"SELECT COUNT(*) FROM {config.source_table}").collect()
    row_count = count_result[0][0]
    
    logger.info(f"  ✓ Loaded {row_count:,} rows")
    return row_count

# Execute ingestion
logger.info("STEP 3: Data Ingestion")
schema_name = setup_snowflake_context(session, config)
create_stage_and_table(session, config)
metrics.rows_ingested = load_data_from_stage(session, config)

# Read into Spark DataFrame for processing
df_raw = spark.read.table(config.source_table)
logger.info(f"  ✓ DataFrame created with {len(df_raw.columns)} columns")

# =============================================================================
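
`setup_snowflake_context` interpolates the user name directly into an unquoted schema identifier. If a user name contains characters such as `.` or `@`, the `CREATE SCHEMA` statement may fail. A minimal sketch of one way to normalize the name first, assuming a conservative identifier rule (letters, digits, and underscores, not starting with a digit); `safe_schema_name` is a hypothetical helper.

```python
import re


def safe_schema_name(user_name: str, suffix: str = "MENU_ANALYTICS") -> str:
    """Build an identifier-safe schema name from a user name (hypothetical helper)."""
    # Replace anything that is not a letter, digit, or underscore
    cleaned = re.sub(r"[^A-Za-z0-9_]", "_", user_name).upper()
    # Avoid an identifier that is empty or starts with a digit
    if not cleaned or cleaned[0].isdigit():
        cleaned = f"U_{cleaned}"
    return f"{cleaned}_{suffix}"


print(safe_schema_name("jane.doe@example.com"))
```

Note that two different user names can normalize to the same schema name under this scheme, so it suits a learning environment better than a shared production account.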

## Other Data Ingestion Methods

### Read from Snowflake Tables
```python
df = spark.read.table("MY_DATABASE.MY_SCHEMA.MY_TABLE")
```

```python
df = spark.sql("SELECT * FROM MY_TABLE WHERE status = 'active' LIMIT 1000")
```

---

### Read from Snowflake Stages

```python
# parquet
df = spark.read.parquet("@MY_STAGE/path/to/data.parquet")
```

```python
# CSV
df = spark.read.format("csv") \
    .option("header", True) \
    .option("inferSchema", False) \
    .option("delimiter", ",") \
    .option("quote", '"') \
    .option("escape", "\\") \
    .option("nullValue", "") \
    .option("dateFormat", "yyyy-MM-dd") \
    .load("@MY_STAGE/path/to/data.csv")
```

```python
# JSON
df = spark.read.format("json") \
    .option("multiLine", True) \
    .option("allowComments", True) \
    .option("dateFormat", "yyyy-MM-dd") \
    .load("@MY_STAGE/path/to/data.json")
```

---

### Direct Cloud Storage Access (requires credentials)

```python
# S3
df = spark.read.parquet("s3://my-bucket/path/to/data/")
```

```python
# GCP
df = spark.read.parquet("gs://my-bucket/path/to/data/")
```

```python
# Azure
df = spark.read.parquet("wasbs://container@account.blob.core.windows.net/path/to/data/")
```

---

### Read Multiple Files

```python
# Wildcard pattern
# Match all parquet files in a directory
df = spark.read.parquet("@MY_STAGE/data/*.parquet")

# Match files with a naming pattern
df = spark.read.csv("@MY_STAGE/logs/2024-*/events_*.csv")

# Recursive directory search
df = spark.read.option("recursiveFileLookup", True).parquet("@MY_STAGE/nested_data/")
```
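
Wildcard patterns are easy to get subtly wrong. As a quick local sanity check, Python's `fnmatch` module can approximate how a pattern matches file names. This is only an approximation (for example, `fnmatch`'s `*` also matches `/`), and the file names below are made up for illustration.

```python
from fnmatch import fnmatchcase

# Hypothetical relative paths, for illustration only
candidate_files = [
    "logs/2024-01/events_001.csv",
    "logs/2024-02/events_002.csv",
    "logs/2023-12/events_003.csv",
    "logs/2024-01/errors_001.csv",
]

pattern = "logs/2024-*/events_*.csv"
matched = [path for path in candidate_files if fnmatchcase(path, pattern)]
print(matched)
```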


### Handling Large CSVs

#### ❌ Slow Pattern
Scans the entire file just to infer column types.

```python
df = spark.read.format("csv") \
    .option("header", True) \
    .option("inferSchema", True) \
    .load("@MY_STAGE/large_file.csv")
```

#### ✅ Fast Pattern (define schema explicitly)

```python
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, DoubleType, TimestampType, BooleanType

schema = StructType([
    StructField("id", IntegerType(), False),
    StructField("name", StringType(), True),
    StructField("amount", DoubleType(), True),
    StructField("created_at", TimestampType(), True),
    StructField("is_active", BooleanType(), True)
])

df = spark.read.format("csv") \
    .option("header", True) \
    .schema(schema) \
    .load("@MY_STAGE/large_file.csv")
```

#### ✅ Compressed Files (auto-decompressed on read)

```python
df = spark.read.format("csv") \
    .option("header", True) \
    .schema(schema) \
    .load("@MY_STAGE/data.csv.gz")
```

#### ✅ Best Practice: Convert CSV to Parquet

```python
# Read CSV once
df_csv = spark.read.format("csv") \
    .option("header", True) \
    .schema(schema) \
    .load("@MY_STAGE/raw.csv")

# Write as Parquet
df_csv.write.mode("overwrite").parquet("@MY_STAGE/optimized/data.parquet")

# Read from Parquet for all future queries
df = spark.read.parquet("@MY_STAGE/optimized/data.parquet")
```


In [None]:
# =============================================================================
# STEP 4: DATA VALIDATION
# =============================================================================

class ValidationError(Exception):
    """Raised when critical validation fails."""
    pass

def validate_menu_data(
    df: DataFrame,
    config: PipelineConfig,
    metrics: PipelineMetrics
) -> DataFrame:
    """
    Validate menu data before processing.
    
    Checks:
    - Row count within expected range
    - Required columns present
    - No nulls in key columns
    - No duplicate menu items
    - Prices are positive
    """
    logger.info("STEP 4: Data Validation")
    errors = []
    row_count = metrics.rows_ingested
    
    # Check 1: Row count bounds
    if row_count < config.min_expected_rows:
        errors.append(f"Row count {row_count:,} below minimum {config.min_expected_rows:,}")
    elif row_count > config.max_expected_rows:
        errors.append(f"Row count {row_count:,} exceeds maximum {config.max_expected_rows:,}")
    else:
        logger.info(f"  ✓ Row count ({row_count:,}) within bounds")
    
    # Check 2: Required columns
    required = ["MENU_ITEM_ID", "MENU_ITEM_NAME", "TRUCK_BRAND_NAME", 
                "COST_OF_GOODS_USD", "SALE_PRICE_USD"]
    missing = [c for c in required if c not in df.columns]
    if missing:
        errors.append(f"Missing required columns: {missing}")
    else:
        logger.info(f"  ✓ All {len(required)} required columns present")
    
    # Check 3: Null check on key columns
    key_columns = ["MENU_ITEM_ID", "MENU_ITEM_NAME", "COST_OF_GOODS_USD", "SALE_PRICE_USD"]
    for col_name in key_columns:
        null_count = df.filter(col(col_name).isNull()).count()
        null_pct = null_count / row_count if row_count > 0 else 0
        if null_pct > config.max_null_percentage:
            errors.append(f"Column '{col_name}' has {null_pct:.1%} nulls")
        elif null_pct > 0:
            metrics.add_warning(f"Column '{col_name}' has {null_pct:.1%} nulls")
    
    if not any("nulls" in str(e) for e in errors):
        logger.info("  ✓ Null check passed")
    
    # Check 4: No duplicate menu items
    unique_items = df.select("MENU_ITEM_ID").distinct().count()
    if unique_items < row_count:
        duplicates = row_count - unique_items
        metrics.add_warning(f"Found {duplicates:,} duplicate menu items")
    else:
        logger.info("  ✓ No duplicate menu items")
    
    # Check 5: Non-negative prices (zero is allowed by this check)
    negative_prices = df.filter(
        (col("COST_OF_GOODS_USD") < 0) | (col("SALE_PRICE_USD") < 0)
    ).count()
    if negative_prices > 0:
        errors.append(f"Found {negative_prices:,} items with negative prices")
    else:
        logger.info("  ✓ All prices are non-negative")
    
    # Handle validation results
    if errors:
        for error in errors:
            logger.error(f"  ✗ VALIDATION FAILED: {error}")
        raise ValidationError(f"Validation failed with {len(errors)} error(s)")
    
    metrics.validation_passed = True
    metrics.rows_after_validation = row_count
    logger.info("  ✅ Validation PASSED")
    
    return df

# Execute validation
df_validated = validate_menu_data(df_raw, config, metrics)

# =============================================================================
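
The threshold logic in Check 3 is easier to unit-test when it is separated from the DataFrame calls. A sketch of that idea; `classify_null_rate` is a hypothetical helper that mirrors the branching above (an error above the threshold, a warning for any other non-zero rate).

```python
def classify_null_rate(null_count: int, row_count: int, max_null_percentage: float) -> str:
    """Return 'error', 'warning', or 'ok' for a column's null rate."""
    null_pct = null_count / row_count if row_count > 0 else 0.0
    if null_pct > max_null_percentage:
        return "error"
    if null_pct > 0:
        return "warning"
    return "ok"


print(classify_null_rate(10, 100, 0.05))  # 10% nulls, above the 5% threshold
print(classify_null_rate(2, 100, 0.05))   # 2% nulls, below the threshold
print(classify_null_rate(0, 100, 0.05))   # no nulls
```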

## Data Transformations
### Feature Support Matrix:

Understanding what PySpark features are supported helps you write efficient code.

#### ✅ Fully Supported DataFrame Operations:
- `select`, `filter`, `where`
- `groupBy`, `agg` (all aggregation functions)
- `join` (inner, left, right, outer, broadcast)
- `orderBy`, `sort`
- `distinct`, `dropDuplicates`
- Window functions (`row_number`, `rank`, `lag`, `lead`, etc.)
- Built-in functions (95%+ coverage)
- `cache`, `persist` (creates temp tables in Snowflake)

#### ⚠️ Limited Support:
- `repartition` (logical operation only)
- `coalesce` (similar to repartition)
- Python UDFs (work but slow - avoid if possible)
- Pandas UDFs (work but slow - avoid if possible)
- MLlib (partial - transformers work, estimators limited)

#### ❌ NOT Supported:
- The RDD API (entirely)
- `.rdd`, `.foreach()`, `.foreachPartition()`
- Structured Streaming
- GraphX
- Custom data sources
- `.checkpoint()`


### Data Types Support

**✅ Supported:**
- String, Integer, Long, Float, Double, Decimal
- Boolean, Date, Timestamp
- Array, Map, Struct
- Binary

**❌ Not Supported:**
- DayTimeIntervalType
- YearMonthIntervalType
- UserDefinedTypes

#### Supported File Formats:
- ✅ Parquet, CSV, JSON, Avro, ORC
- ❌ Delta Lake, Hudi not supported


In [None]:
# =============================================================================
# STEP 5: TRANSFORMATIONS
# =============================================================================

def transform_menu_data(
    df: DataFrame,
    metrics: PipelineMetrics
) -> DataFrame:
    """
    Apply business transformations to menu data.
    
    Stages:
    5A. Data Cleaning
    5B. Profit Calculations
    5C. Categorization
    5D. Aggregation by Brand
    """
    logger.info("STEP 5: Transformations")
    
    # -------------------------------------------------------------------------
    # STAGE 5A: Data Cleaning
    # -------------------------------------------------------------------------
    logger.info("  5A: Data Cleaning")
    
    df_clean = df \
        .withColumn("TRUCK_BRAND_NAME", trim(upper(col("TRUCK_BRAND_NAME")))) \
        .withColumn("ITEM_CATEGORY", trim(upper(col("ITEM_CATEGORY")))) \
        .withColumn("MENU_TYPE", trim(upper(col("MENU_TYPE")))) \
        .filter(col("COST_OF_GOODS_USD").isNotNull()) \
        .filter(col("SALE_PRICE_USD").isNotNull())
    
    # -------------------------------------------------------------------------
    # STAGE 5B: Profit Calculations
    # -------------------------------------------------------------------------
    logger.info("  5B: Profit Calculations")
    
    df_with_profit = df_clean \
        .withColumn(
            "PROFIT_USD",
            round(col("SALE_PRICE_USD") - col("COST_OF_GOODS_USD"), 2)
        ) \
        .withColumn(
            "PROFIT_MARGIN_PCT",
            round(
                (col("SALE_PRICE_USD") - col("COST_OF_GOODS_USD")) / 
                col("SALE_PRICE_USD") * 100, 
                2
            )
        )
    
    # -------------------------------------------------------------------------
    # STAGE 5C: Categorization
    # -------------------------------------------------------------------------
    logger.info("  5C: Profit Categorization")
    
    df_categorized = df_with_profit \
        .withColumn(
            "PROFIT_TIER",
            when(col("PROFIT_MARGIN_PCT") >= 70, "Premium")
            .when(col("PROFIT_MARGIN_PCT") >= 50, "High")
            .when(col("PROFIT_MARGIN_PCT") >= 30, "Medium")
            .otherwise("Low")
        ) \
        .withColumn(
            "PRICE_TIER",
            when(col("SALE_PRICE_USD") >= 10, "Premium")
            .when(col("SALE_PRICE_USD") >= 5, "Mid-Range")
            .otherwise("Value")
        )
    
    # -------------------------------------------------------------------------
    # STAGE 5D: Aggregation by Brand
    # -------------------------------------------------------------------------
    logger.info("  5D: Aggregation by Brand")
    
    df_brand_summary = df_categorized.groupBy(
        "TRUCK_BRAND_NAME",
        "MENU_TYPE"
    ).agg(
        count("*").alias("ITEM_COUNT"),
        round(avg("COST_OF_GOODS_USD"), 2).alias("AVG_COST_USD"),
        round(avg("SALE_PRICE_USD"), 2).alias("AVG_PRICE_USD"),
        round(avg("PROFIT_USD"), 2).alias("AVG_PROFIT_USD"),
        round(avg("PROFIT_MARGIN_PCT"), 2).alias("AVG_MARGIN_PCT"),
        round(min("PROFIT_USD"), 2).alias("MIN_PROFIT_USD"),
        round(max("PROFIT_USD"), 2).alias("MAX_PROFIT_USD"),
        round(sum("PROFIT_USD"), 2).alias("TOTAL_POTENTIAL_PROFIT_USD")
    ).orderBy(col("AVG_MARGIN_PCT").desc())
    
    # Add metadata
    df_final = df_brand_summary \
        .withColumn("PIPELINE_RUN_ID", lit(metrics.run_id)) \
        .withColumn("PROCESSED_AT", current_timestamp())
    
    # Update metrics
    metrics.rows_after_transform = df_final.count()
    metrics.total_brands = df_final.select("TRUCK_BRAND_NAME").distinct().count()
    metrics.total_menu_items = df_categorized.count()
    
    # Calculate overall average margin
    avg_margin = df_final.agg(avg("AVG_MARGIN_PCT")).collect()[0][0]
    metrics.avg_profit_margin = float(avg_margin) if avg_margin else 0.0
    
    logger.info(f"  ✓ Aggregated to {metrics.rows_after_transform:,} brand/menu-type combinations")
    logger.info(f"  ✓ {metrics.total_brands} unique brands analyzed")
    logger.info(f"  ✓ Average profit margin: {metrics.avg_profit_margin:.1f}%")
    
    return df_final

# Execute transformations
df_transformed = transform_menu_data(df_validated, metrics)
df_transformed.cache()

# Preview results
logger.info("\n📊 Brand Summary Preview:")
df_transformed.select(
    "TRUCK_BRAND_NAME", "MENU_TYPE", "ITEM_COUNT", 
    "AVG_PRICE_USD", "AVG_PROFIT_USD", "AVG_MARGIN_PCT"
).show(10, truncate=False)

# =============================================================================
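
To make the margin formula and tier thresholds from stages 5B and 5C concrete, here is a plain-Python worked example with made-up prices. `profit_margin_pct` and `profit_tier` are hypothetical helpers; the zero-price guard is an assumption of this sketch, since how the pipeline should treat a zero sale price is left open above.

```python
import builtins
from typing import Optional


def profit_margin_pct(cost_usd: float, price_usd: float) -> Optional[float]:
    """Margin as a percentage of sale price; None when the price is zero."""
    if price_usd == 0:
        return None  # assumption: no meaningful margin for a zero price
    # builtins.round is explicit because the pipeline imports PySpark's round(),
    # which shadows the built-in inside the notebook
    return builtins.round((price_usd - cost_usd) / price_usd * 100, 2)


def profit_tier(margin_pct: float) -> str:
    """Same thresholds as the PROFIT_TIER column."""
    if margin_pct >= 70:
        return "Premium"
    if margin_pct >= 50:
        return "High"
    if margin_pct >= 30:
        return "Medium"
    return "Low"


# (cost, price) pairs invented for illustration
for cost, price in [(2.0, 8.0), (2.5, 5.0), (5.0, 8.0), (6.0, 8.0)]:
    margin = profit_margin_pct(cost, price)
    print(cost, price, margin, profit_tier(margin))
```

For example, an item that costs 2.00 and sells for 8.00 has a margin of (8 - 2) / 8 × 100 = 75%, which lands in the "Premium" tier.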

## Data Transformation API Reference

All transformations push down to Snowflake SQL - no data leaves the warehouse until you explicitly collect results.

### Required Imports

```python
from pyspark.sql.functions import (
    col, lit, when, coalesce,
    sum, avg, count, min, max, countDistinct, collect_list,
    year, month, dayofmonth, hour, minute, second, dayofweek, dayofyear, weekofyear, quarter,
    to_date, to_timestamp, date_add, date_sub, months_between, datediff, date_trunc,
    upper, lower, initcap, trim, ltrim, rtrim, lpad, rpad,
    concat, substring, regexp_replace, regexp_extract,
    row_number, rank, dense_rank, percent_rank, lag, lead, first, last,
    broadcast
)
from pyspark.sql.window import Window
```

---

### 1. Selecting and Filtering

**Select columns:**

```python
df.select("id", "name", "amount")
df.select(col("id"), col("name").alias("customer_name"))  # with rename
```

**Filter rows:**

```python
df.filter(col("status") == "active")
df.where(col("amount") > 100)
df.filter((col("status") == "active") & (col("amount") > 100))  # AND
df.filter((col("status") == "active") | (col("amount") > 1000)) # OR
```

**Remove duplicates:**

```python
df.distinct()                                    # all columns
df.dropDuplicates(["customer_id", "order_date"]) # specific columns
```

---

### 2. Adding and Modifying Columns

**Add calculated columns:**

```python
df.withColumn("total", col("price") * col("quantity"))
df.withColumn("discounted_price", col("price") * 0.9)
df.withColumn("status_flag", when(col("status") == "active", 1).otherwise(0))
```

**Rename columns:**

```python
df.withColumnRenamed("old_name", "new_name")
df.toDF("col1", "col2", "col3")  # rename all at once
```

**Cast data types:**

```python
df.withColumn("amount", col("amount").cast("double"))
df.withColumn("event_date", col("timestamp_col").cast("date"))
```

**Handle null values:**

```python
df.fillna(0, subset=["amount"])                         # single column
df.fillna({"amount": 0, "name": "Unknown"})             # multiple columns
df.withColumn("value", coalesce(col("value"), lit(0)))  # using coalesce
```

---

### 3. Aggregations

**Group and aggregate:**

```python
df.groupBy("category").agg(
    count("*").alias("row_count"),
    countDistinct("customer_id").alias("unique_customers"),
    sum("amount").alias("total_amount"),
    avg("amount").alias("avg_amount"),
    min("amount").alias("min_amount"),
    max("amount").alias("max_amount"),
    collect_list("product_name").alias("products")
)
```

**Multiple grouping columns:**

```python
df.groupBy("category", "region").agg(
    sum("amount").alias("total_sales"),
    avg("amount").alias("average_sale")
)
```

---

### 4. Joins

| Join Type | Syntax |
|-----------|--------|
| **Inner** | `df1.join(df2, "id", "inner")` |
| **Left** | `df1.join(df2, "customer_id", "left")` |
| **Right** | `df1.join(df2, "customer_id", "right")` |
| **Outer** | `df1.join(df2, "customer_id", "outer")` |
| **Cross** | `df1.crossJoin(df2)` |

**Explicit join condition:**

```python
df1.join(df2, df1["id"] == df2["id"], "inner")
```

**Broadcast small tables** (performance optimization):

```python
df_large.join(broadcast(df_small), "key_column", "left")
```

**Handle column name conflicts:**

```python
df_joined = df1.alias("a").join(df2.alias("b"), col("a.id") == col("b.id"))
df_result = df_joined.select(col("a.id"), col("a.name"), col("b.value"))
```

---

### 5. Window Functions

**Define window specification:**

```python
window_spec = Window.partitionBy("customer_id").orderBy("order_date")

# With frame bounds
window_frame = Window.partitionBy("customer_id") \
    .orderBy("order_date") \
    .rowsBetween(-2, 0)
```

**Ranking functions:**

```python
df.withColumn("row_num", row_number().over(window_spec))
df.withColumn("rank", rank().over(window_spec))              # gaps for ties
df.withColumn("dense_rank", dense_rank().over(window_spec))  # no gaps
df.withColumn("pct_rank", percent_rank().over(window_spec))
```
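
The difference between `rank` (gaps after ties) and `dense_rank` (no gaps) is easiest to see on a small list. The sketch below reimplements both in plain Python for a single, already-ordered partition; `rank_values` and `dense_rank_values` are hypothetical helpers for illustration, not PySpark functions.

```python
from typing import List


def rank_values(sorted_values: List[int]) -> List[int]:
    """rank(): ties share a rank, and the next rank skips ahead."""
    ranks: List[int] = []
    for position, value in enumerate(sorted_values, start=1):
        if position > 1 and value == sorted_values[position - 2]:
            ranks.append(ranks[-1])   # tie: reuse the previous rank
        else:
            ranks.append(position)    # otherwise rank equals row position
    return ranks


def dense_rank_values(sorted_values: List[int]) -> List[int]:
    """dense_rank(): ties share a rank, and the next rank is consecutive."""
    ranks: List[int] = []
    for index, value in enumerate(sorted_values):
        if index > 0 and value == sorted_values[index - 1]:
            ranks.append(ranks[-1])
        else:
            ranks.append(ranks[-1] + 1 if ranks else 1)
    return ranks


amounts = [100, 90, 90, 80]  # one partition, already ordered
print(rank_values(amounts))        # [1, 2, 2, 4]
print(dense_rank_values(amounts))  # [1, 2, 2, 3]
```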

**Analytic functions:**

```python
df.withColumn("prev_amount", lag("amount", 1).over(window_spec))
df.withColumn("next_amount", lead("amount", 1).over(window_spec))
df.withColumn("first_order", first("amount").over(window_spec))
df.withColumn("last_order", last("amount").over(window_spec))
```

**Running calculations:**

```python
window_running = Window.partitionBy("customer_id") \
    .orderBy("order_date") \
    .rowsBetween(Window.unboundedPreceding, 0)

df.withColumn("running_total", sum("amount").over(window_running))
df.withColumn("running_avg", avg("amount").over(window_running))
```
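
Row-based frames can be checked by hand on a short list. This plain-Python sketch computes a sum over the frame `[current + start, current + end]` for one ordered partition, which is the idea behind `rowsBetween`; `rows_between_sum` is a hypothetical helper, not a PySpark function.

```python
import builtins
from typing import List


def rows_between_sum(values: List[float], start: int, end: int) -> List[float]:
    """Sum over a row-based frame [i + start, i + end], clipped to the partition."""
    result = []
    for i in range(len(values)):
        # builtins.* is explicit because the pipeline's PySpark imports
        # shadow sum/min/max inside the notebook
        lo = builtins.max(0, i + start)
        hi = builtins.min(len(values) - 1, i + end)
        result.append(builtins.sum(values[lo:hi + 1]))
    return result


amounts = [10, 20, 30, 40]  # one partition, already ordered

moving_sum = rows_between_sum(amounts, -2, 0)                # like rowsBetween(-2, 0)
running_total = rows_between_sum(amounts, -len(amounts), 0)  # like unboundedPreceding to current row
print(moving_sum)     # [10, 30, 60, 90]
print(running_total)  # [10, 30, 60, 100]
```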

---

### 6. Date and Time Operations

**Extract components:**

| Function | Example |
|----------|---------|
| `year()` | `df.withColumn("year", year("ts"))` |
| `month()` | `df.withColumn("month", month("ts"))` |
| `dayofmonth()` | `df.withColumn("day", dayofmonth("ts"))` |
| `hour()` | `df.withColumn("hour", hour("ts"))` |
| `dayofweek()` | `df.withColumn("dow", dayofweek("ts"))` |
| `quarter()` | `df.withColumn("qtr", quarter("ts"))` |

**Parse strings to dates:**

```python
df.withColumn("date_col", to_date("date_string", "yyyy-MM-dd"))
df.withColumn("ts_col", to_timestamp("ts_string", "yyyy-MM-dd HH:mm:ss"))
```

**Date arithmetic:**

```python
df.withColumn("next_week", date_add("date_col", 7))
df.withColumn("last_week", date_sub("date_col", 7))
df.withColumn("days_diff", datediff("end_date", "start_date"))
df.withColumn("months_diff", months_between("end_date", "start_date"))
df.withColumn("month_start", date_trunc("month", "date_col"))
```
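
The argument order of `datediff` (end date first, then start date) is a common source of sign errors. The plain-Python equivalents below, using `datetime` with made-up dates, show the expected results for the day-based functions.

```python
from datetime import date, timedelta

start_date = date(2024, 1, 1)
end_date = date(2024, 3, 1)

next_week = start_date + timedelta(days=7)  # like date_add("date_col", 7)
last_week = start_date - timedelta(days=7)  # like date_sub("date_col", 7)
days_diff = (end_date - start_date).days    # like datediff("end_date", "start_date")

print(next_week, last_week, days_diff)  # 2024-01-08 2023-12-25 60
```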

---

### 7. String Operations

**Case transformation:**

```python
df.withColumn("upper_name", upper("name"))
df.withColumn("lower_name", lower("name"))
df.withColumn("title_name", initcap("name"))
```

**Pattern matching:**

```python
df.withColumn("has_email", col("text").contains("@"))
df.withColumn("cleaned", regexp_replace("text", "[^a-zA-Z0-9]", ""))
df.withColumn("domain", regexp_extract("email", "@(.+)$", 1))
```
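
The two regex patterns above can be tried locally with Python's `re` module before running them against a table. This is a rough stand-in: Spark documents `regexp_extract` as returning an empty string when nothing matches, which the sketch imitates, but Spark uses Java regex syntax, so more advanced patterns may behave differently. `regexp_extract_py` and `regexp_replace_py` are hypothetical helpers.

```python
import re


def regexp_extract_py(text: str, pattern: str, group: int) -> str:
    """Mimic regexp_extract: return the captured group, or "" when nothing matches."""
    match = re.search(pattern, text)
    return match.group(group) if match else ""


def regexp_replace_py(text: str, pattern: str, replacement: str) -> str:
    """Mimic regexp_replace: replace every match of the pattern."""
    return re.sub(pattern, replacement, text)


print(regexp_extract_py("jane@example.com", "@(.+)$", 1))       # example.com
print(repr(regexp_extract_py("no-at-sign", "@(.+)$", 1)))       # ''
print(regexp_replace_py("Order #42: OK!", "[^a-zA-Z0-9]", ""))  # Order42OK
```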

**Trim and pad:**

```python
df.withColumn("trimmed", trim("text"))
df.withColumn("padded_id", lpad("id", 10, "0"))    # "42" → "0000000042"
df.withColumn("padded_right", rpad("code", 5, "X")) # "AB" → "ABXXX"
```

---

### 8. Sorting and Limiting

**Sort results:**

```python
df.orderBy("name")                                            # ascending
df.orderBy(col("amount").desc())                              # descending
df.orderBy(col("category").asc(), col("amount").desc())       # multiple
df.orderBy(col("value").asc_nulls_first())                    # nulls first
df.orderBy(col("value").desc_nulls_last())                    # nulls last
```

**Limit rows:**

```python
df.limit(10)                                    # first 10 rows
df.orderBy(col("amount").desc()).limit(10)     # top 10 by amount
```


In [None]:
# =============================================================================
# STEP 6: DATA QUALITY CHECKS
# =============================================================================

def quality_checks(
    df: DataFrame,
    config: PipelineConfig,
    metrics: PipelineMetrics
) -> float:
    """Validate output data quality before writing."""
    logger.info("STEP 6: Data Quality Checks")
    checks_passed = 0
    total_checks = 4
    
    row_count = metrics.rows_after_transform
    
    # Check 1: Non-zero output
    if row_count > 0:
        checks_passed += 1
        logger.info(f"  ✓ Output has {row_count:,} rows")
    else:
        logger.error("  ✗ Output is empty!")
    
    # Check 2: No duplicate keys
    key_cols = ["TRUCK_BRAND_NAME", "MENU_TYPE"]
    distinct_keys = df.select(key_cols).distinct().count()
    if distinct_keys == row_count:
        checks_passed += 1
        logger.info("  ✓ No duplicate brand/menu-type combinations")
    else:
        logger.error(f"  ✗ Found {row_count - distinct_keys:,} duplicate keys")
    
    # Check 3: All margins are reasonable (not negative)
    negative_margins = df.filter(col("AVG_MARGIN_PCT") < config.min_profit_margin).count()
    if negative_margins == 0:
        checks_passed += 1
        logger.info("  ✓ All profit margins are positive")
    else:
        metrics.add_warning(f"Found {negative_margins:,} brands with negative margins")
    
    # Check 4: Data completeness
    null_brands = df.filter(col("TRUCK_BRAND_NAME").isNull()).count()
    if null_brands == 0:
        checks_passed += 1
        logger.info("  ✓ All records have brand names")
    else:
        logger.error(f"  ✗ Found {null_brands:,} records without brand names")
    
    quality_score = checks_passed / total_checks
    logger.info(f"\n  📈 Quality Score: {quality_score:.0%} ({checks_passed}/{total_checks} checks)")
    
    return quality_score

# Execute quality checks
metrics.quality_score = quality_checks(df_transformed, config, metrics)

# =============================================================================
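
The pipeline computes a quality score but does not act on it before writing. One possible gate, as a sketch: `QualityGateError`, `enforce_quality_gate`, and the `min_quality_score` threshold are all hypothetical and not part of the pipeline above.

```python
class QualityGateError(Exception):
    """Raised when the quality score is below the required minimum (hypothetical)."""


def enforce_quality_gate(quality_score: float, min_quality_score: float = 1.0) -> None:
    """Stop the pipeline before the write step if too many checks failed."""
    if quality_score < min_quality_score:
        raise QualityGateError(
            f"Quality score {quality_score:.0%} is below the required {min_quality_score:.0%}"
        )


enforce_quality_gate(1.0)  # 4 of 4 checks passed: no exception

try:
    enforce_quality_gate(0.75)  # 3 of 4 checks passed
except QualityGateError as exc:
    gate_message = str(exc)
    print(gate_message)
```

Requiring a perfect score is the strictest choice; a lower threshold trades safety for availability and should be a deliberate decision.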

In [None]:
# =============================================================================
# STEP 7: WRITE OUTPUT
# =============================================================================

def write_output(
    df: DataFrame,
    config: PipelineConfig,
    metrics: PipelineMetrics
) -> int:
    """Write transformed data to Snowflake table."""
    logger.info("STEP 7: Write Output")
    logger.info(f"  Destination: {config.output_table}")
    logger.info(f"  Mode: {config.output_mode}")
    
    # Write to Snowflake
    df.write.mode(config.output_mode).saveAsTable(config.output_table)
    
    # Verify
    written_count = spark.read.table(config.output_table).count()
    logger.info(f"  ✓ Rows written: {written_count:,}")
    
    return written_count

# Execute write
metrics.rows_written = write_output(df_transformed, config, metrics)

# =============================================================================


## Other Ways to Write Data

Write transformed data back to Snowflake tables, stages, or cloud storage.

---

### Write to Snowflake Tables

```python
# Overwrite replaces all existing data
df.write.mode("overwrite").saveAsTable("MY_TABLE")

# Append adds new rows to existing data (use for incremental loads)
df.write.mode("append").saveAsTable("MY_TABLE")

# Ignore skips write if destination exists (use for idempotent operations)
df.write.mode("ignore").saveAsTable("MY_TABLE")

# Error fails if destination exists
df.write.mode("error").saveAsTable("MY_TABLE")
```
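
The four save modes differ only in what happens when the destination already exists. The toy in-memory model below illustrates those semantics as described in the comments above; it is not a real writer, and `save_as_table` here is a hypothetical stand-in that operates on a plain dictionary.

```python
from typing import Dict, List

Table = List[dict]


def save_as_table(catalog: Dict[str, Table], name: str, rows: Table, mode: str) -> None:
    """Toy in-memory model of the four save modes (not a real writer)."""
    exists = name in catalog
    if mode == "overwrite":
        catalog[name] = list(rows)                 # replace all existing data
    elif mode == "append":
        catalog.setdefault(name, []).extend(rows)  # add rows to existing data
    elif mode == "ignore":
        if not exists:
            catalog[name] = list(rows)             # skip the write if it exists
    elif mode == "error":
        if exists:
            raise ValueError(f"Table {name} already exists")
        catalog[name] = list(rows)
    else:
        raise ValueError(f"Unknown mode: {mode}")


catalog: Dict[str, Table] = {}
save_as_table(catalog, "MY_TABLE", [{"id": 1}], "error")      # creates the table
save_as_table(catalog, "MY_TABLE", [{"id": 2}], "append")     # now 2 rows
save_as_table(catalog, "MY_TABLE", [{"id": 3}], "ignore")     # no-op, table exists
save_as_table(catalog, "MY_TABLE", [{"id": 4}], "overwrite")  # replaced with 1 row
print(catalog["MY_TABLE"])
```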

### Write to Snowflake Stages

```python
# Parquet (recommended)
df.write.mode("overwrite").parquet("@MY_STAGE/output/data.parquet")

# CSV
df.write.mode("overwrite") \
    .option("header", True) \
    .option("compression", "gzip") \
    .option("delimiter", ",") \
    .option("quote", '"') \
    .csv("@MY_STAGE/output/data.csv")

# JSON
df.write.mode("overwrite") \
    .option("compression", "gzip") \
    .json("@MY_STAGE/output/data.json")
```


### Write to Cloud Storage

```python
# S3
df.write.mode("overwrite").parquet("s3://my-bucket/output/data/")
```

```python
# GCP
df.write.mode("overwrite").parquet("gs://my-bucket/output/data/")
```

```python
# Azure
df.write.mode("overwrite").parquet("wasbs://container@account.blob.core.windows.net/output/data/")
```

---

## Performance and Compression


### Partitioning for Performance



```python
# Multi-column partitioning
df.write.mode("overwrite") \
    .partitionBy("year", "month") \
    .parquet("@MY_STAGE/partitioned_data/")


# Single column partitioning
df.write.mode("overwrite") \
    .partitionBy("event_date") \
    .parquet("@MY_STAGE/daily_data/")
```

**Benefits:**
- Faster queries that filter on partition columns
- Enables parallel reads of different partitions
- Easier data management (delete old partitions)

**Considerations:**
- Too many partitions = too many small files
- Choose columns with moderate cardinality
- Date-based partitioning is very common
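
To make the small-files trade-off concrete, here is a rough back-of-the-envelope check in plain Python. `estimate_partition_layout` is a hypothetical helper, the 128 MiB target is an assumed rule of thumb rather than a Snowflake recommendation, and the partition count is an upper bound that assumes every combination of values occurs.

```python
from math import prod
from typing import Dict

MIB = 1024 * 1024
GIB = 1024 * MIB


def estimate_partition_layout(
    total_bytes: int,
    cardinalities: Dict[str, int],
    target_bytes: int = 128 * MIB,  # assumed target size per partition
) -> dict:
    """Estimate partition count and average data per partition."""
    partitions = prod(cardinalities.values())  # upper bound on output directories
    avg_bytes = total_bytes / partitions
    return {
        "partitions": partitions,
        "avg_mib": avg_bytes / MIB,
        "too_small": avg_bytes < target_bytes,
    }


# 10 GiB partitioned by year (3 values) and month (12 values)
print(estimate_partition_layout(10 * GIB, {"year": 3, "month": 12}))

# The same data partitioned by a daily date over three years (1,095 values)
print(estimate_partition_layout(10 * GIB, {"event_date": 1095}))
```

Under these assumptions the year/month layout yields 36 partitions of roughly 284 MiB each, while daily partitioning yields 1,095 partitions of under 10 MiB each, which is where the small-files concern starts to apply.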

---

### Compression Options

| Compression | Speed | Ratio | Use Case |
|-------------|-------|-------|----------|
| **Snappy** | Fast | Moderate | Default for Parquet, frequently accessed data |
| **Gzip** | Slow | High | Archival, infrequent access |
| **LZ4** | Very Fast | Moderate | High-throughput workloads |
| **Zstd** | Medium | High | Good balance of speed and compression |
| **None** | N/A | None | Rarely recommended |

```python
# Snappy (default)
df.write.option("compression", "snappy").parquet("@MY_STAGE/data/")

# Gzip
df.write.option("compression", "gzip").parquet("@MY_STAGE/data/")

# LZ4
df.write.option("compression", "lz4").parquet("@MY_STAGE/data/")

# Zstd
df.write.option("compression", "zstd").parquet("@MY_STAGE/data/")

# None
df.write.option("compression", "none").parquet("@MY_STAGE/data/")
```
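
The speed and ratio columns above are general tendencies; the actual numbers depend on your data. One way to get a feel for the trade-off is to measure it on a sample. The sketch below uses only Python's standard-library `gzip` at different compression levels as a stand-in, since Snappy, LZ4, and Zstd need third-party packages. `measure_gzip` and the sample rows are hypothetical, and the results say nothing about how Snowflake compresses Parquet files.

```python
import gzip
import time
from typing import Dict, Tuple


def measure_gzip(payload: bytes, levels=(1, 6, 9)) -> Dict[int, Tuple[int, float]]:
    """Return {level: (compressed_size_bytes, elapsed_seconds)} for each gzip level."""
    results = {}
    for level in levels:
        start = time.perf_counter()
        compressed = gzip.compress(payload, compresslevel=level)
        elapsed = time.perf_counter() - start
        results[level] = (len(compressed), elapsed)
    return results


# Hypothetical sample: repetitive CSV-like rows
sample = b"".join(
    f"{i},truck_{i % 50},item_{i % 200},{(i % 17) * 1.25:.2f}\n".encode()
    for i in range(20_000)
)

for level, (size, seconds) in measure_gzip(sample).items():
    print(f"level={level}  ratio={len(sample) / size:.1f}x  time={seconds * 1000:.1f} ms")
```

Higher levels generally trade more CPU time for a smaller output, which mirrors the Snappy-versus-Gzip choice in the table.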


In [None]:
# =============================================================================
# STEP 8: CLEANUP & SUMMARY
# =============================================================================

def cleanup_and_summarize(df_cached: DataFrame, metrics: PipelineMetrics):
    """Release resources and log final summary."""
    logger.info("STEP 8: Cleanup & Summary")
    
    # Release cache
    try:
        df_cached.unpersist()
        logger.info("  ✓ Cache released")
    except Exception as e:
        logger.warning(f"  Could not unpersist: {e}")
    
    # Finalize metrics
    metrics.end_time = datetime.now()
    metrics.status = "SUCCESS" if metrics.quality_score >= 0.75 else "COMPLETED_WITH_WARNINGS"
    
    # Print summary
    print(f"\n{'='*60}")
    print("🎉 PIPELINE SUMMARY")
    print(f"{'='*60}")
    print(f"  Run ID:            {metrics.run_id}")
    print(f"  Status:            {metrics.status}")
    print(f"  Duration:          {metrics.duration_seconds():.1f} seconds")
    print(f"  {'─'*56}")
    print(f"  Menu Items:        {metrics.total_menu_items:,}")
    print(f"  Brands Analyzed:   {metrics.total_brands}")
    print(f"  Avg Profit Margin: {metrics.avg_profit_margin:.1f}%")
    print(f"  {'─'*56}")
    print(f"  Rows In:           {metrics.rows_ingested:,}")
    print(f"  Rows Out:          {metrics.rows_written:,}")
    print(f"  Quality Score:     {metrics.quality_score:.0%}")
    
    if metrics.warnings:
        print(f"  {'─'*56}")
        print(f"  ⚠️  Warnings ({len(metrics.warnings)}):")
        for w in metrics.warnings:
            print(f"      • {w}")
    
    print(f"{'='*60}\n")

# Execute cleanup
cleanup_and_summarize(df_transformed, metrics)

# =============================================================================
# BEST PRACTICES
# =============================================================================
#
# ✅ PERFORMANCE:
#    - Use COPY INTO for bulk loading (faster than spark.read for CSV)
#    - Cache DataFrame for multiple operations
#    - Avoid UDFs
#
# ✅ RELIABILITY:
#    - Quality gates before writing output
#    - Type-safe configuration with dataclasses
#    - Proper error handling with custom exceptions
#
# ✅ OBSERVABILITY:
#    - Structured logging at each stage
#    - Unique run ID for tracing
#    - Row counts and business metrics tracked
#
# =============================================================================


## Best Practices

Follow these best practices to get optimal performance from Snowpark Connect.

1. **Use SQL Functions Over UDFs:** Python UDFs require data to be transferred to the client, processed, and sent back, which can be 10-100x slower than native operations that push down to Snowflake.
2. **Broadcast Joins for Small Tables:** When joining a large table with a small dimension table, use `broadcast()` to optimize the join.
3. **Cache Frequently Accessed DataFrames:** Caching creates temporary tables in Snowflake for faster repeated access. Remember to `unpersist()` when done!
4. **Minimize Data Movement:** Process data in Snowflake and only transfer final results. Avoid `collect()` on large datasets!
5. **Partition Awareness:** Filter on partitioned columns to enable partition pruning and reduce data scanned.

## Additional Resources

**Official Documentation:**
- [Snowflake Documentation](https://docs.snowflake.com/)
- [PySpark API Reference](https://spark.apache.org/docs/latest/api/python/)