# TPC-DS 10 GiB: Accelerating Apache Spark SQL on GPU with the RAPIDS Accelerator for Apache Spark

Based on Luca Canali's TPCDS_PySpark getting-started notebook: https://colab.research.google.com/github/LucaCanali/Miscellaneous/blob/master/Performance_Testing/TPCDS_PySpark/Labs_and_Notes/TPCDS_PySpark_getstarted.ipynb#scrollTo=6bab7772

# Install packages

In [None]:
# Version pins used below: the PySpark wheel to install and the RAPIDS Accelerator
# release to fetch (assumed to be a release that supports this Spark version — confirm
# against the RAPIDS compatibility matrix when bumping either one).
spark_version, rapids_version = '3.5.5', '25.06.0'

In [None]:
# Install the benchmark dependencies; IPython expands {spark_version} in this magic,
# so the installed PySpark matches the version pinned in the previous cell.
%pip install --quiet \
  tpcds_pyspark==1.0.5 \
  pyspark=={spark_version} \
  pandas \
  sparkmeasure==0.23.2 \
  matplotlib

# Import modules

In [None]:
from importlib.resources import files
from pyspark.sql import SparkSession
from tpcds_pyspark import TPCDS
import glob
import os
import pandas as pd
import re
import time

# Download TPC-DS 10GiB Scale Parquet Dataset

In [None]:
# Fetch and extract the dataset only once: skip everything when the extracted
# 'tpcds_10' directory already exists, and reuse a previously downloaded zip.
if not os.path.isdir('tpcds_10'):
  if not os.path.isfile('tpcds_10.zip'):
    !wget https://sparkdltrigger.web.cern.ch/sparkdltrigger/TPCDS/tpcds_10.zip
  # NOTE(review): the zip is not validated, so a partially downloaded file is
  # reused as-is; delete tpcds_10.zip to force a fresh download.
  !unzip -q tpcds_10.zip

# Initialize a SparkSession with the RAPIDS Accelerator plugin

## Detect the Scala version used by the PySpark package

In [None]:
# Derive the Scala binary version from the spark-sql jar bundled in the pyspark
# package: its file name has the form spark-sql_<scala>-<spark>.jar.
pyspark_files = files('pyspark')
spark_sql_jar_paths = glob.glob(f"{pyspark_files}/*/spark-sql_*jar")
if not spark_sql_jar_paths:
    raise FileNotFoundError(f"No spark-sql jar found under {pyspark_files}")
spark_sql_jar_path = spark_sql_jar_paths[0]
spark_sql_jar = os.path.basename(spark_sql_jar_path)
# The dot is escaped so only "<major>.<minor>" (e.g. "2.12") is captured.
scala_version_match = re.search(r'^spark-sql_(\d+\.\d+)-.*\.jar$', spark_sql_jar)
if scala_version_match is None:
    raise ValueError(f"Cannot parse the Scala version from {spark_sql_jar!r}")
scala_version = scala_version_match.group(1)

## Find spark-measure artifact

In [None]:
# Pick the newest spark-measure jar shipped with tpcds_pyspark that was built for
# PySpark's Scala version.
tpcds_pyspark_files = files('tpcds_pyspark')
spark_measure_jar_prefix = f"spark-measure_{scala_version}-"
spark_measure_jar_paths = glob.glob(f"{tpcds_pyspark_files}/{spark_measure_jar_prefix}*.jar")
if not spark_measure_jar_paths:
    # Explicit raise instead of `assert`, which is stripped under `python -O`.
    raise FileNotFoundError(f"No spark-measure artifact built for Pyspark's Scala version {scala_version}")


def _spark_measure_version_key(jar_path):
    """Return a numeric sort key for a ``spark-measure_<scala>-<version>.jar`` path.

    Version components are compared as integers, so 0.23 ranks above 0.9; a plain
    string sort would rank "0.9" above "0.23".
    """
    version = os.path.basename(jar_path)[len(spark_measure_jar_prefix):]
    return tuple(int(part) for part in re.findall(r'\d+', version))


# Newest version first, then take it.
spark_measure_jar_paths.sort(key=_spark_measure_version_key, reverse=True)
spark_measure_jar_path = spark_measure_jar_paths[0]

In [None]:
# Build (or reuse) the SparkSession with the RAPIDS Accelerator plugin and the
# spark-measure jar; settings are applied in the order listed.
spark_builder = SparkSession.builder.appName('TPCDS PySpark RAPIDS=ON/OFF')
for conf_key, conf_value in [
    ('spark.driver.memory', '5g'),
    # Load the RAPIDS Accelerator as a Spark plugin.
    ('spark.plugins', 'com.nvidia.spark.SQLPlugin'),
    # Local spark-measure jar located in the previous cell.
    ('spark.jars', spark_measure_jar_path),
    # Maven coordinates of the RAPIDS jar matching PySpark's Scala version.
    ('spark.jars.packages', f"com.nvidia:rapids-4-spark_{scala_version}:{rapids_version}"),
]:
    spark_builder = spark_builder.config(conf_key, conf_value)
spark = spark_builder.getOrCreate()
spark


# Verify that GPU SQL acceleration can be enabled by checking the query plan

In [None]:
# Turn on RAPIDS SQL acceleration, run a tiny aggregation once, then print its
# physical plan so GPU operators can be spotted (assumed to appear with a "Gpu"
# prefix in the plan — confirm against the RAPIDS docs).
spark.conf.set('spark.rapids.sql.enabled', True)
sum_df = (
    spark
    .range(1000)
    .selectExpr('SUM(*)')
)
sum_df.collect()  # execute the query once
sum_df.explain()  # print the physical plan

# TPCDS App

In [None]:
# https://github.com/LucaCanali/Miscellaneous/tree/master/Performance_Testing/TPCDS_PySpark/tpcds_pyspark/Queries

# Set to True to run every TPC-DS query (takes much longer); otherwise only the
# subset below is run. Uncomment entries to add more queries to the subset.
run_all_queries = False
selected_queries = [
    'q14a',
    'q14b',
    'q23a',
    'q23b',
    # 'q24a',
    # 'q24b',
    # 'q88',
]
# queries=None makes TPCDS run all queries.
queries = None if run_all_queries else selected_queries

demo_start = time.time()
tpcds = TPCDS(data_path='./tpcds_10', num_runs=1, queries_repeat_times=1, queries=queries)

## Register TPC-DS tables before running queries

In [None]:
# Register the TPC-DS tables so the queries can reference them by name
# (presumably read from the data_path given to TPCDS — confirm).
tpcds.map_tables()

## Measure Apache Spark on GPU

In [None]:
# Run the selected queries with RAPIDS SQL acceleration enabled (GPU) and time the run.
# NOTE(review): this run goes first, so it may also absorb one-time warm-up costs
# (e.g. first reads of the Parquet files); consider a warm-up pass for a fairer comparison.
tpcds.spark.conf.set('spark.rapids.sql.enabled', True)
%time tpcds.run_TPCDS()
# Copy the results so a later run_TPCDS() call cannot overwrite them — presumably
# run_TPCDS replaces grouped_results_pdf on each call; TODO confirm.
gpu_grouped_results = tpcds.grouped_results_pdf.copy()
gpu_grouped_results

## Measure Apache Spark on CPU

In [None]:
# Re-run the same queries with RAPIDS SQL acceleration disabled (CPU-only Spark) and time the run.
tpcds.spark.conf.set('spark.rapids.sql.enabled', False)
%time tpcds.run_TPCDS()
# Keep an independent copy of this run's grouped results for the comparison below.
cpu_grouped_results = tpcds.grouped_results_pdf.copy()
cpu_grouped_results

## Show the speedup factors achieved on GPU


In [None]:
# Join the CPU and GPU results per query, compute speedup = CPU time / GPU time,
# and list the slowest CPU queries first.
res = (
    cpu_grouped_results
    .merge(gpu_grouped_results, on='query', how='inner', suffixes=('_cpu', '_gpu'))
    .assign(speedup=lambda merged: merged['elapsedTime_cpu'] / merged['elapsedTime_gpu'])
    .sort_values(by='elapsedTime_cpu', ascending=False)
)
res

In [None]:
# Wall-clock seconds since demo_start, which is set just before TPCDS is constructed.
# When the cells run in order, this covers table registration plus both benchmark runs.
demo_dur = time.time() - demo_start
# Format the number directly: the `{demo_dur=}` form printed "took: demo_dur=123.456789 seconds".
print(f"CPU and GPU run took: {demo_dur:.1f} seconds")

In [None]:
# Side-by-side bars of per-query elapsed time: CPU in blue, GPU in NVIDIA green.
res.plot(
    kind='bar',
    x='query',
    y=['elapsedTime_cpu', 'elapsedTime_gpu'],
    color=['blue', '#76B900'],
    title='TPC-DS query elapsedTime on CPU vs GPU (lower is better)',
)

In [None]:
# One bar per query showing the CPU/GPU elapsed-time ratio computed above.
res.plot(
    kind='bar',
    x='query',
    y='speedup',
    color='#76B900',
    title='Speedup factors of TPC-DS queries on GPU',
)

# Run Queries interactively

In [None]:
# Name of a query file in the tpcds_pyspark Queries directory (e.g. 'q14a', 'q88').
query = 'q88'
# Explicit encoding so reading the SQL text does not depend on the platform locale.
with open(f"{tpcds_pyspark_files}/Queries/{query}.sql", encoding='utf-8') as f:
    q = f.read()

In [None]:
# Show the SQL text of the selected query.
print(q)

In [None]:
# Run the selected query with RAPIDS SQL acceleration enabled. %time times the
# collect() call, which brings every result row back to the driver.
# Set the flag to False to time the same query on CPU instead.
spark.conf.set('spark.rapids.sql.enabled', True)
df  = spark.sql(q)
%time df.collect()