# GPU-Accelerated Spark Connect Demo - ETL and ML Pipeline (Spark 4.0+)

Based on the Data and AI Summit 2025 session: [GPU Accelerated Spark Connect](https://www.databricks.com/dataaisummit/session/gpu-accelerated-spark-connect)


## 1. Connect to Spark via Spark Connect


In [1]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
# Open (or reuse) a Spark Connect 4.0+ session against the remote server that
# runs the GPU-accelerated Spark backend.
# NOTE(review): with Spark Connect the app name may not be applied to an
# already-running server application — confirm it shows up where expected.
spark = SparkSession.builder.remote(
  'sc://spark-connect-server'
).appName(
  'GPU-Accelerated-ETL-ML-Demo'
).getOrCreate()

print(f'Spark version: {spark.version}')

Spark version: 4.0.0


## 2. Smoke Test GPU


In [2]:
# Smoke test: bucket 2**35 generated ids into 10 groups by `id % 10` and count
# each group. With GPU execution active, the physical plan printed below
# should consist of Gpu* operators.
df = (
  spark.range(2 ** 35)
    .withColumn('mod10', col('id') % lit(10))
    .groupBy('mod10').agg(count('*'))
    .orderBy('mod10')
)
df.show()
# workaround to get a plan with GpuOverrides applied by disabling adaptive execution.
# Remember the current AQE setting and restore it afterwards (even if explain()
# fails) instead of unconditionally forcing it back to True.
_aqe_key = 'spark.sql.adaptive.enabled'
_aqe_prev = spark.conf.get(_aqe_key)
spark.conf.set(_aqe_key, False)
try:
  df.explain(mode='formatted')
finally:
  spark.conf.set(_aqe_key, _aqe_prev)

+-----+----------+
|mod10|  count(1)|
+-----+----------+
|    0|3435973837|
|    1|3435973837|
|    2|3435973837|
|    3|3435973837|
|    4|3435973837|
|    5|3435973837|
|    6|3435973837|
|    7|3435973837|
|    8|3435973836|
|    9|3435973836|
+-----+----------+

== Physical Plan ==
GpuColumnarToRow (10)
+- GpuSort (9)
   +- GpuShuffleCoalesce (8)
      +- GpuColumnarExchange (7)
         +- GpuHashAggregate (6)
            +- GpuShuffleCoalesce (5)
               +- GpuColumnarExchange (4)
                  +- GpuHashAggregate (3)
                     +- GpuProject (2)
                        +- GpuRange (1)


(1) GpuRange
Output [1]: [id#298L]
Arguments: 0, 34359738368, 1, 64, [id#298L], 536870912

(2) GpuProject
Input [1]: [id#298L]
Arguments: [(id#298L % 10) AS mod10#333L], true

(3) GpuHashAggregate
Input [1]: [mod10#333L]
Keys [1]: [mod10#333L]
Functions [1]: [partial_gpucount(1, false)]
Aggregate Attributes [1]: [count#336L]
Results [2]: [mod10#333L, count#337L]
Lore: 

(4) G

## Should the GPU be used from the next cell on?


In [3]:
# Master switch read by the following cells: True enables the GPU settings for ETL and ML.
accelerate_on_gpu = True

### ETL on GPU?

In [4]:
# Turn the RAPIDS SQL acceleration setting on or off to follow the master switch.
spark.conf.set(
  'spark.rapids.sql.enabled',
  accelerate_on_gpu,
)

### ML on GPU?

In [5]:
# Route Spark Connect ML to the RAPIDS ML plugin when GPU mode is on;
# otherwise clear the setting so no backend override remains in effect.
if not accelerate_on_gpu:
  spark.conf.unset('spark.connect.ml.backend.classes')
else:
  spark.conf.set('spark.connect.ml.backend.classes', 'com.nvidia.rapids.ml.Plugin')

## Normalize different spellings of the same bank's name

In [6]:
import csv
# Load the raw-seller-name -> canonical-name mapping on the client.
# newline='' is what the csv module requires so that quoted fields containing
# line breaks are parsed correctly. Blank lines (which csv.reader yields as
# empty lists) are skipped because they cannot match the two-column schema.
with open('work/name_mapping.csv', 'r', newline='') as name_mapping_file:
  name_mapping = [row for row in csv.reader(name_mapping_file) if row]
name_mapping_df = spark.createDataFrame(name_mapping, ['from_seller_name', 'to_seller_name'])

# Spot-check: every raw spelling that maps to 'Wells Fargo'.
(
  name_mapping_df
    .where(col('to_seller_name') == 'Wells Fargo')
    .show(truncate=False)
)

+------------------------------------------------------+--------------+
|from_seller_name                                      |to_seller_name|
+------------------------------------------------------+--------------+
|WELLS FARGO CREDIT RISK TRANSFER SECURITIES TRUST 2015|Wells Fargo   |
|WELLS FARGO BANK,  NA                                 |Wells Fargo   |
|WELLS FARGO BANK, N.A.                                |Wells Fargo   |
|WELLS FARGO BANK, NA                                  |Wells Fargo   |
+------------------------------------------------------+--------------+



In [7]:
# Categorical (string) columns.
cate_col_names = [
  'orig_channel',
  'first_home_buyer',
  'loan_purpose',
  'property_type',
  'occupancy_status',
  'property_state',
  'product_type',
  'relocation_mortgage_indicator',
  'seller_name',
  'mod_flag',
]
# Label column; it is appended as the last entry of the numeric columns.
label_col_name = 'delinquency_12'
# Numeric columns, followed by the label.
numeric_col_names = [
  'orig_interest_rate',
  'orig_upb',
  'orig_loan_term',
  'orig_ltv',
  'orig_cltv',
  'num_borrowers',
  'dti',
  'borrower_credit_score',
  'num_units',
  'zip',
  'mortgage_insurance_percent',
  'current_loan_delinquency_status',
  'current_actual_upb',
  'interest_rate',
  'loan_age',
  'msa',
  'non_interest_bearing_upb',
] + [label_col_name]
# Categorical columns first, then the numeric ones (label last).
all_col_names = [*cate_col_names, *numeric_col_names]

## Define the ETL Process

### Functions to read the raw files and extract their columns

In [8]:
def read_raw_csv(spark, path):
  """Read header-less, pipe-delimited raw CSV files into a DataFrame.

  The column layout is the DDL schema string stored in ``csv_raw_schema.ddl``
  (opened relative to the client's working directory). Empty fields are
  loaded as null. A ``quarter`` column is appended holding each row's source
  file name without directory or extension (e.g. a file named
  ``2000Q1.csv`` yields ``2000Q1``).

  Args:
    spark: Active SparkSession used to read the files.
    path: File, directory, or glob passed to ``DataFrameReader.load``.

  Returns:
    The raw DataFrame with the extra ``quarter`` column.
  """
  def _get_quarter_from_csv_file_name():
    # Take the last path segment first, then cut at its first '.', so dots in
    # directory or bucket/host names (e.g. '/opt/spark-4.0.0/data/2000Q1.csv')
    # cannot truncate the result the way cutting the full URI at '.' did.
    file_name = substring_index(input_file_name(), '/', -1)
    return substring_index(file_name, '.', 1)

  with open('csv_raw_schema.ddl', 'r') as f:
    _csv_raw_schema_str = f.read()

  return (
    spark.read
      .format('csv')
      .option('nullValue', '')
      .option('header', False)
      .option('delimiter', '|')
      .schema(_csv_raw_schema_str)
      .load(path)
      .withColumn('quarter', _get_quarter_from_csv_file_name())
  )

def extract_perf_columns(rawDf):
  """Project the performance columns out of a raw DataFrame.

  Date-like columns stored as ``MMyyyy`` strings are parsed with ``to_date``
  and re-rendered as ``MM/dd/yyyy`` or ``MM/yyyy`` under their original names.
  ``servicer`` is upper-cased. Rows are kept only where
  ``current_actual_upb != 0.0`` holds; rows where that column is null are
  dropped too, because the comparison evaluates to null.

  Args:
    rawDf: DataFrame as returned by ``read_raw_csv``.

  Returns:
    DataFrame with the selected performance columns, in a fixed order.
  """
  def _reformat_mmyyyy(name, out_fmt):
    # Parse an MMyyyy string column and re-render it in out_fmt, keeping its name.
    return date_format(to_date(col(name), 'MMyyyy'), out_fmt).alias(name)

  projection = [
    col('loan_id'),
    _reformat_mmyyyy('monthly_reporting_period', 'MM/dd/yyyy'),
    upper(col('servicer')).alias('servicer'),
    *[col(c) for c in (
      'interest_rate',
      'current_actual_upb',
      'loan_age',
      'remaining_months_to_legal_maturity',
      'adj_remaining_months_to_maturity',
    )],
    _reformat_mmyyyy('maturity_date', 'MM/yyyy'),
    *[col(c) for c in (
      'msa',
      'current_loan_delinquency_status',
      'mod_flag',
      'zero_balance_code',
    )],
    _reformat_mmyyyy('zero_balance_effective_date', 'MM/yyyy'),
    _reformat_mmyyyy('last_paid_installment_date', 'MM/dd/yyyy'),
    _reformat_mmyyyy('foreclosed_after', 'MM/dd/yyyy'),
    _reformat_mmyyyy('disposition_date', 'MM/dd/yyyy'),
    *[col(c) for c in (
      'foreclosure_costs',
      'prop_preservation_and_repair_costs',
      'asset_recovery_costs',
      'misc_holding_expenses',
      'holding_taxes',
      'net_sale_proceeds',
      'credit_enhancement_proceeds',
      'repurchase_make_whole_proceeds',
      'other_foreclosure_proceeds',
      'non_interest_bearing_upb',
      'principal_forgiveness_upb',
      'repurchase_make_whole_proceeds_flag',
      'foreclosure_principal_write_off_amount',
      'servicing_activity_indicator',
      'quarter',
    )],
  ]
  projected = rawDf.select(*projection)
  return projected.select('*').filter('current_actual_upb != 0.0')

def extract_acq_columns(rawDf):
  """Project the acquisition columns, keeping each loan's earliest reporting period.

  ``seller_name`` is upper-cased, and ``orig_date`` / ``first_pay_date`` are
  parsed from ``MMyyyy`` strings and re-rendered as ``MM/yyyy``. Rows are
  ranked per ``loan_id`` by ``monthly_reporting_period`` (parsed as
  ``MMyyyy``) using ``dense_rank``. Only rank-1 rows are returned. Because
  ``dense_rank`` gives ties the same rank, a loan with several rows in its
  earliest period keeps all of them.

  Args:
    rawDf: DataFrame as returned by ``read_raw_csv``.

  Returns:
    DataFrame with the acquisition columns plus the helper ``rank`` column.
  """
  # `from pyspark.sql.functions import *` does not provide Window; without this
  # import the window expression below raises NameError.
  from pyspark.sql.window import Window

  # Earliest reporting period first within each loan.
  by_loan_earliest_first = Window.partitionBy('loan_id').orderBy(
    to_date(col('monthly_reporting_period'), 'MMyyyy'))

  acqDf = rawDf.select(
    col('loan_id'),
    col('orig_channel'),
    upper(col('seller_name')).alias('seller_name'),
    col('orig_interest_rate'),
    col('orig_upb'),
    col('orig_loan_term'),
    date_format(to_date(col('orig_date'), 'MMyyyy'), 'MM/yyyy').alias('orig_date'),
    date_format(to_date(col('first_pay_date'), 'MMyyyy'), 'MM/yyyy').alias('first_pay_date'),
    col('orig_ltv'),
    col('orig_cltv'),
    col('num_borrowers'),
    col('dti'),
    col('borrower_credit_score'),
    col('first_home_buyer'),
    col('loan_purpose'),
    col('property_type'),
    col('num_units'),
    col('occupancy_status'),
    col('property_state'),
    col('zip'),
    col('mortgage_insurance_percent'),
    col('product_type'),
    col('coborrow_credit_score'),
    col('mortgage_insurance_type'),
    col('relocation_mortgage_indicator'),
    dense_rank().over(by_loan_earliest_first).alias('rank'),
    col('quarter')
  )

  return acqDf.select('*').filter(col('rank') == 1)