In [2]:
import findspark
# Run findspark's initialisation before the pyspark imports below.
# NOTE(review): presumably this locates the local Spark installation and makes
# `pyspark` importable — confirm against the environment this notebook targets.
findspark.init()

from pyspark.sql import SparkSession
from pyspark.sql import functions as func
from pyspark.sql.types import StructField, StructType, IntegerType, FloatType

In [3]:
# Obtain a Spark session for this notebook (an already-running one is reused).
spark_session = (
    SparkSession.builder
    .appName('Get total spent by customer')
    .getOrCreate()
)

In [4]:
# Explicit schema for the orders CSV: three nullable columns, declared in the
# order they are applied to the file's fields.
# NOTE(review): FloatType is a 32-bit float; if `price` holds monetary values,
# DoubleType or DecimalType may be a safer choice — confirm before changing.
schema = StructType([
    StructField(name='cus_id', dataType=IntegerType(), nullable=True),
    StructField(name='product_id', dataType=IntegerType(), nullable=True),
    StructField(name='price', dataType=FloatType(), nullable=True),
])

In [5]:
# Load the orders CSV with the explicit schema defined in this notebook
# (no schema inference), then print the resulting column types.
data = spark_session.read.csv(
    '../../data/customer-orders.csv',
    schema=schema,
)
data.printSchema()

root
 |-- cus_id: integer (nullable = true)
 |-- product_id: integer (nullable = true)
 |-- price: float (nullable = true)



In [10]:
# Per customer: sum the order prices, round the sum to 2 decimal places, and
# rank customers from the highest total to the lowest.
total_spent_per_cus = (
    data
    .groupBy('cus_id')
    .agg(func.round(func.sum('price'), 2).alias('total_spent'))
    .orderBy(func.desc('total_spent'))
    .select('cus_id', 'total_spent')
)

# Preview only the five highest totals.
total_spent_per_cus.show(5)

+------+-----------+
|cus_id|total_spent|
+------+-----------+
|    68|    6375.45|
|    73|     6206.2|
|    39|    6193.11|
|    54|    6065.39|
|    71|    5995.66|
+------+-----------+
only showing top 5 rows



In [None]:
# Shut down the Spark session now that the analysis is finished.
spark_session.stop()