In [2]:
from pyspark.sql import functions as F
from pyspark.sql import SparkSession
from pyspark.sql.window import Window

spark = SparkSession.builder.appName("CSV Reads Optimization").master("local[*]").getOrCreate()

spark

In [3]:
input_path = "/opt/data/ncr_ride_bookings.csv"
df = spark.read.csv(input_path, header=True, inferSchema=False)
df.head()

Row(Date='2024-03-23', Time='12:29:38', Booking ID='"""CNR5884300"""', Booking Status='No Driver Found', Customer ID='"""CID1982111"""', Vehicle Type='eBike', Pickup Location='Palam Vihar', Drop Location='Jhilmil', Avg VTAT='null', Avg CTAT='null', Cancelled Rides by Customer='null', Reason for cancelling by Customer='null', Cancelled Rides by Driver='null', Driver Cancellation Reason='null', Incomplete Rides='null', Incomplete Rides Reason='null', Booking Value='null', Ride Distance='null', Driver Ratings='null', Customer Rating='null', Payment Method='null')

In [4]:
def clean_col(colname: str) -> str:
    return colname.strip().lower().replace(" ", "_").replace(".", "_")

df = df.toDF(*[clean_col(c) for c in df.columns])
df.head()

Row(date='2024-03-23', time='12:29:38', booking_id='"""CNR5884300"""', booking_status='No Driver Found', customer_id='"""CID1982111"""', vehicle_type='eBike', pickup_location='Palam Vihar', drop_location='Jhilmil', avg_vtat='null', avg_ctat='null', cancelled_rides_by_customer='null', reason_for_cancelling_by_customer='null', cancelled_rides_by_driver='null', driver_cancellation_reason='null', incomplete_rides='null', incomplete_rides_reason='null', booking_value='null', ride_distance='null', driver_ratings='null', customer_rating='null', payment_method='null')

In [5]:
df = df.select([
    F.trim(F.regexp_replace(F.col(c), '^"+|"+$', '')).alias(c)
    for c in df.columns
])

df.cache()
row_count = df.count()
print(row_count)


150000


In [7]:
exprs = []
for c in df.columns:
    exprs.append(F.count(F.col(c)).alias(f"{c}_non_nulls"))
    exprs.append(
        F.sum(F.when(F.col(c).cast("double").isNotNull(), 1).otherwise(0))
        .alias(f"{c}_cast_success")
    )

stats = df.agg(*exprs).collect()[0]

numeric_cols, categorical_cols = [], []
for c in df.columns:
    non_nulls = stats[f"{c}_non_nulls"]
    cast_success = stats[f"{c}_cast_success"]

    if non_nulls > 0 and (cast_success / non_nulls) > 0.9:
        numeric_cols.append(c)
    else:
        categorical_cols.append(c)
        
print("NUMERIC COL: ",numeric_cols)
print("CATEGORICAL COL: ",categorical_cols)

NUMERIC COL:  ['avg_vtat']
CATEGORICAL COL:  ['date', 'time', 'booking_id', 'booking_status', 'customer_id', 'vehicle_type', 'pickup_location', 'drop_location', 'avg_ctat', 'cancelled_rides_by_customer', 'reason_for_cancelling_by_customer', 'cancelled_rides_by_driver', 'driver_cancellation_reason', 'incomplete_rides', 'incomplete_rides_reason', 'booking_value', 'ride_distance', 'driver_ratings', 'customer_rating', 'payment_method']


In [12]:
agg_exprs = []
for c in df.columns:
    agg_exprs.append(
        F.sum(F.when(F.col(c).isNull() | (F.trim(F.col(c)) == ""), 1).otherwise(0))
        .alias(f"{c}_nulls")
    )
    agg_exprs.append(F.approx_count_distinct(c).alias(f"{c}_distinct"))

nulls_distinct = df.agg(*agg_exprs).collect()[0].asDict()


In [13]:
df_numeric = df.select([F.col(c).cast("double").alias(c) for c in numeric_cols])

num_exprs = []
for c in numeric_cols:
    num_exprs += [
        F.min(c).alias(f"{c}_min"),
        F.max(c).alias(f"{c}_max"),
        F.mean(c).alias(f"{c}_mean"),
        F.stddev(c).alias(f"{c}_stddev")
    ]

numeric_stats = df_numeric.agg(*num_exprs).collect()[0].asDict()

print("NUMERICA STATS: ", numeric_stats)

NUMERICA STATS:  {'avg_vtat_min': 2.0, 'avg_vtat_max': 20.0, 'avg_vtat_mean': 8.456351971326171, 'avg_vtat_stddev': 3.7735638264095708}


In [14]:
percentiles_dict = {}
for c in numeric_cols:
    percentiles_dict[c] = df_numeric.na.drop().approxQuantile(
        c, [0.01, 0.25, 0.5, 0.75, 0.99], 0.01
    )
print("PERCENTILES DICT: ", percentiles_dict)

PERCENTILES DICT:  {'avg_vtat': [2.0, 5.3, 8.2, 11.2, 20.0]}


In [15]:
if categorical_cols:
    stack_expr = F.expr(
        "stack({0}, {1}) as (column_name, value)".format(
            len(categorical_cols),
            ", ".join([f"'{c}', {c}" for c in categorical_cols])
        )
    )

    cat_df = df.select(stack_expr)
    freq_df = cat_df.groupBy("column_name", "value").count()

    w = Window.partitionBy("column_name").orderBy(F.desc("count"))
    top_values_df = freq_df.withColumn("rank", F.row_number().over(w)) \
                           .filter(F.col("rank") <= 3)

    top_values = {}
    for r in top_values_df.collect():
        top_values.setdefault(r["column_name"], []).append(
            (r["value"], r["count"])
        )
else:
    top_values = {}

In [17]:
skew_info = {}

# ---- Categorical skew (hot key detection)
for c in categorical_cols:
    vals = top_values.get(c)
    if vals:
        top_count = vals[0][1]
        dominance_ratio = round(top_count / row_count, 4) if row_count else None

        if dominance_ratio is None:
            skew_label = "unknown"
        elif dominance_ratio > 0.9:
            skew_label = "HOT_KEY_RISK"
        elif dominance_ratio > 0.7:
            skew_label = "HIGH_SKEW"
        elif dominance_ratio > 0.3:
            skew_label = "SKEWED"
        else:
            skew_label = "BALANCED"

        skew_info[c] = (dominance_ratio, skew_label)

# ---- Numeric skew (distribution skew)
for c in numeric_cols:
    p = percentiles_dict.get(c)
    if p and len(p) == 5:
        p01, p25, p50, p75, p99 = p

        if (p50 - p01) != 0:
            skew_score = round((p99 - p50) / (p50 - p01), 3)

            if skew_score > 1.5:
                skew_label = "RIGHT_SKEWED"
            elif skew_score < 0.7:
                skew_label = "LEFT_SKEWED"
            else:
                skew_label = "SYMMETRIC"
        else:
            skew_score, skew_label = None, "unknown"

        skew_info[c] = (skew_score, skew_label)


In [18]:
report_rows = []

for c in df.columns:
    null_count = nulls_distinct[f"{c}_nulls"]
    null_pct = round((null_count / row_count) * 100, 2) if row_count else None
    distinct_count = nulls_distinct[f"{c}_distinct"]

    cardinality_ratio = round(distinct_count / row_count, 4) if row_count else None

    if cardinality_ratio is None:
        cardinality_level = "unknown"
    elif cardinality_ratio < 0.1:
        cardinality_level = "low"
    elif cardinality_ratio < 0.5:
        cardinality_level = "mid"
    else:
        cardinality_level = "high"

    dtype = "numeric" if c in numeric_cols else "categorical"

    min_val = max_val = mean_val = stddev_val = None
    percentiles = None
    outlier_risk = None
    top_vals = None

    skew_score, skew_label = skew_info.get(c, (None, None))

    if c in numeric_cols:
        min_val = numeric_stats.get(f"{c}_min")
        max_val = numeric_stats.get(f"{c}_max")
        mean_val = numeric_stats.get(f"{c}_mean")
        stddev_val = numeric_stats.get(f"{c}_stddev")
        percentiles = percentiles_dict.get(c)

        if stddev_val is not None and mean_val is not None:
            outlier_risk = "HIGH" if stddev_val > abs(mean_val) else "LOW"

    else:
        top_vals = top_values.get(c)

    if null_pct is not None and null_pct > 30:
        quality_flag = "HIGH_NULLS"
    elif skew_label in ("HOT_KEY_RISK", "HIGH_SKEW"):
        quality_flag = "SKEW_RISK"
    elif cardinality_level == "low":
        quality_flag = "LOW_VARIANCE"
    else:
        quality_flag = "OK"

    report_rows.append((
        c, dtype, null_count, null_pct,
        distinct_count, cardinality_ratio, cardinality_level,
        skew_score, skew_label,
        min_val, max_val, mean_val, stddev_val,
        str(percentiles), outlier_risk, str(top_vals),
        quality_flag
    ))



In [12]:
report_df = spark.createDataFrame(
    report_rows,
    [
        "column_name", "data_type",
        "null_count", "null_pct",
        "distinct_count", "cardinality_ratio", "cardinality_level",
        "skew_score", "skew_label",
        "min_val", "max_val", "mean_val", "stddev_val",
        "percentiles", "outlier_risk", "top_values",
        "quality_flag"
    ]
)

report_df.show(100, truncate=False)

+---------------------------------+-----------+----------+--------+--------------+-----------------+-----------------+----------+------------+-------+-------+-----------------+------------------+---------------------------+------------+----------------------------------------------------------------------------------------------+------------+
|column_name                      |data_type  |null_count|null_pct|distinct_count|cardinality_ratio|cardinality_level|skew_score|skew_label  |min_val|max_val|mean_val         |stddev_val        |percentiles                |outlier_risk|top_values                                                                                    |quality_flag|
+---------------------------------+-----------+----------+--------+--------------+-----------------+-----------------+----------+------------+-------+-------+-----------------+------------------+---------------------------+------------+--------------------------------------------------------------------------