# 01 – Data Overview (PySpark)

This notebook introduces the Taobao CTR data set using **PySpark**.  Since each CSV file contains millions of rows, we limit the ingestion to one million records per file to make the analysis tractable.  Using Spark allows us to handle large data sizes efficiently while still performing exploratory analysis.

## Initialize Spark

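We start by creating a Spark session, which manages the distributed computation context.  Ensure that the `pyspark` package is installed in your environment; if it is not, install it via `pip install pyspark`.  PySpark also needs a Java runtime (JDK) on the machine; without one, creating the session fails with a `JAVA_GATEWAY_EXITED` error.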

In [3]:
%pip install pyspark pyarrow


Collecting pyspark
  Using cached pyspark-4.0.1-py2.py3-none-any.whl
Collecting py4j==0.10.9.9 (from pyspark)
  Using cached py4j-0.10.9.9-py2.py3-none-any.whl.metadata (1.3 kB)
Using cached py4j-0.10.9.9-py2.py3-none-any.whl (203 kB)
Installing collected packages: py4j, pyspark


In [None]:
from pyspark.sql import SparkSession

# Create a Spark session
spark = (
    SparkSession.builder
        .appName("CTR_Data_Overview")
        .config("spark.sql.shuffle.partitions", "200")        # 200 is Spark's default; a smaller value often suits small local runs
        .config("spark.driver.memory", "4g")                  # only takes effect if set before the JVM starts
        .config("spark.executor.memory", "4g")                # relevant on a cluster; in local mode the driver memory is what counts
        .getOrCreate()
)

# Reduce log noise
spark.sparkContext.setLogLevel("WARN")

print("Spark version:", spark.version)


PySparkRuntimeError: [JAVA_GATEWAY_EXITED] Java gateway process exited before sending its port number.
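
The `JAVA_GATEWAY_EXITED` error above usually means that PySpark could not start a JVM, most often because no Java runtime is installed or `JAVA_HOME` points to the wrong location.  Recent Spark releases also need a fairly new JDK (Java 17 or later for Spark 4.x, as far as the release documentation states).  The following is a minimal pre-flight sketch using only the standard library; `find_java` and `java_version_banner` are hypothetical helper names introduced here, not part of PySpark.

```python
import os
import shutil
import subprocess


def find_java():
    """Return the path to a java executable, or None if none is found.

    JAVA_HOME is checked first, then the PATH.
    """
    java_home = os.environ.get("JAVA_HOME")
    if java_home:
        exe = "java.exe" if os.name == "nt" else "java"
        candidate = os.path.join(java_home, "bin", exe)
        if os.path.isfile(candidate):
            return candidate
    return shutil.which("java")


def java_version_banner(java_path):
    """Return the first line of `java -version` (normally written to stderr)."""
    try:
        result = subprocess.run(
            [java_path, "-version"], capture_output=True, text=True, check=False
        )
    except OSError:
        return ""  # the file exists but could not be executed
    output = (result.stderr or result.stdout).strip()
    return output.splitlines()[0] if output else ""


java_path = find_java()
if java_path is None:
    print("No Java runtime found: install a JDK and set JAVA_HOME before starting Spark.")
else:
    print("Java found at:", java_path)
    print(java_version_banner(java_path))
```

If no runtime is reported, installing a JDK and restarting the notebook kernel before re-running the session cell is a reasonable first step.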

## Load raw data (1 million rows each)

We load each CSV file from the `data/raw` directory using Spark’s `read.csv` method.  With `header=True` Spark treats the first line as column names, and with `inferSchema=True` it infers the data types automatically, at the cost of an extra pass over the file.  To keep the analysis manageable, we cap each DataFrame at one million rows using `.limit(1_000_000)` and cache the result so that later cells reuse it instead of re-reading the CSV.

In [None]:
import os
from pyspark.sql.functions import col

raw_dir = os.path.join('..', 'data', 'raw')

# File paths
user_path = os.path.join(raw_dir, 'user_profile.csv')
ad_path = os.path.join(raw_dir, 'ad_feature.csv')
click_path = os.path.join(raw_dir, 'raw_sample.csv')
behaviour_path = os.path.join(raw_dir, 'behavior_log.csv')

# Read up to 1 million rows from each file
user_df = spark.read.csv(user_path, header=True, inferSchema=True).limit(1_000_000).cache()
ad_df = spark.read.csv(ad_path, header=True, inferSchema=True).limit(1_000_000).cache()
click_df = spark.read.csv(click_path, header=True, inferSchema=True).limit(1_000_000).cache()
behaviour_df = spark.read.csv(behaviour_path, header=True, inferSchema=True).limit(1_000_000).cache()

print('User profile rows:', user_df.count())
print('Ad feature rows:', ad_df.count())
print('Click log rows:', click_df.count())
print('Behaviour log rows:', behaviour_df.count())
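
Schema inference makes Spark scan the file before any rows are returned.  One way to skip that pass is to supply an explicit schema, which `spark.read.csv` accepts as a DDL string.  The sketch below rests on assumptions: the column names and types for `raw_sample.csv` are taken from the public description of the Taobao data and should be checked against the actual header row, and `to_ddl` and `read_limited` are hypothetical helpers, not part of PySpark.

```python
# ASSUMED layout of raw_sample.csv; verify against the real header before use.
RAW_SAMPLE_COLUMNS = {
    "user": "INT",
    "time_stamp": "BIGINT",
    "adgroup_id": "INT",
    "pid": "STRING",
    "nonclk": "INT",
    "clk": "INT",
}


def to_ddl(columns):
    """Turn a {column_name: sql_type} mapping into a Spark DDL schema string."""
    return ", ".join(f"{name} {sql_type}" for name, sql_type in columns.items())


def read_limited(spark, path, ddl_schema, n_rows=1_000_000):
    """Read a CSV with a header row and an explicit schema, keeping at most n_rows rows."""
    return spark.read.csv(path, header=True, schema=ddl_schema).limit(n_rows)


RAW_SAMPLE_DDL = to_ddl(RAW_SAMPLE_COLUMNS)
print(RAW_SAMPLE_DDL)

# With an active session, the click log could then be loaded as:
# click_df = read_limited(spark, click_path, RAW_SAMPLE_DDL).cache()
```

An explicit schema also pins the column types, so a stray non-numeric value shows up as a null in that column rather than silently turning the whole column into strings.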

## Inspect schemas and sample records

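Spark infers data types for each column.  We print the schema of each DataFrame and display a few sample rows.  Note that Spark DataFrames are evaluated lazily: transformations such as `.limit()` only build a query plan, while actions such as `.count()` and `.show()` trigger computation.  Because the DataFrames were cached and counted in the previous cell, `.show()` reads from the cache here.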

In [None]:
# Display schema information
for name, df in [('user_profile', user_df), ('ad_feature', ad_df), ('raw_sample', click_df), ('behavior_log', behaviour_df)]:
    print(f'Schema for {name}:')
    df.printSchema()
    print('Sample rows:')
    df.show(5, truncate=False)


## Basic statistics

For a quick numerical summary, we compute descriptive statistics on numeric columns using the `describe()` method.  This method operates on Spark DataFrames, returning counts, means, standard deviations, minima and maxima.

In [None]:
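# Compute descriptive statistics for selected numeric columns.
# Each column is paired with the DataFrame expected to contain it: in the Taobao data,
# age_level belongs to user_profile and price to ad_feature, while raw_sample (click_df)
# has neither.  Adjust the pairs to your own schema.
numeric_cols = [(user_df, 'age_level'), (ad_df, 'price')]
for df, col_name in numeric_cols:
    print(f'Descriptive stats for {col_name}:')
    df.select(col(col_name)).describe().show()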
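
Instead of hard-coding column names, the numeric columns can be picked out of `DataFrame.dtypes`, which returns `(column_name, type_string)` pairs.  The sketch below shows one way to do that; `numeric_columns` is a hypothetical helper, and the example list is a hand-written stand-in for a real `dtypes` value with assumed column names.

```python
# Type strings that DataFrame.dtypes reports for Spark's numeric types.
NUMERIC_TYPE_NAMES = {"tinyint", "smallint", "int", "bigint", "float", "double"}


def numeric_columns(dtypes):
    """Return the names of columns whose Spark type string is numeric.

    `dtypes` is a list of (column_name, type_string) pairs, the shape
    returned by DataFrame.dtypes.
    """
    return [
        name
        for name, type_name in dtypes
        if type_name in NUMERIC_TYPE_NAMES or type_name.startswith("decimal")
    ]


# Hand-written stand-in for click_df.dtypes (column names and types are assumed).
example_dtypes = [
    ("user", "int"),
    ("time_stamp", "bigint"),
    ("pid", "string"),
    ("clk", "int"),
]
print(numeric_columns(example_dtypes))  # ['user', 'time_stamp', 'clk']
```

In the notebook this could be used as `click_df.select(*numeric_columns(click_df.dtypes)).describe().show()`.  Identifier columns such as user or ad IDs are numeric too, so their means and standard deviations should be read with care.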
