In [1]:
# findspark.init() is called before any pyspark import in this notebook.
# NOTE(review): presumably needed so the local Spark installation is importable
# from this kernel — confirm for your environment.
import findspark
findspark.init()

In [2]:
from pyspark.sql import SparkSession

In [3]:
# Obtain the Spark session for this notebook: getOrCreate() returns the
# already-active session if there is one, otherwise it starts a new one.
spark = (
    SparkSession.builder
    .getOrCreate()
)

In [4]:
# Bare expression: the notebook displays the session object as this cell's output.
spark

In [7]:
# Brings floor, rand, concat, col, lit and sum (used by the cells below) into scope.
# NOTE(review): the wildcard also rebinds Python builtins that share a name with a
# Spark function (e.g. `sum`). The groupBy cell passes a column name to `sum`, so it
# only works after this cell has run; prefer explicit imports or an `F` alias.
from pyspark.sql.functions import *

#### Creating a DataFrame

In [5]:
# Toy purchase records as (user_id, purchase) pairs; user "A" owns three of the
# five rows, which makes it the skewed key in this example.
data = [
    ("A", 100),
    ("A", 200),
    ("A", 300),
    ("B", 400),
    ("C", 500),
]
df = spark.createDataFrame(data, schema=["user_id", "purchase"])

In [6]:
# Print the input rows (user_id, purchase) before any salting is applied.
df.show()

+-------+--------+
|user_id|purchase|
+-------+--------+
|      A|     100|
|      A|     200|
|      A|     300|
|      B|     400|
|      C|     500|
+-------+--------+



#### Salt Column

In [8]:
# Number of salt buckets each key is spread over, and a fixed seed so the salt
# assignment does not change every time this cell is re-run.
NUM_SALTS = 3
SALT_SEED = 42

# rand() is uniform in [0, 1), so floor(rand * NUM_SALTS) yields an integer salt
# in 0 .. NUM_SALTS - 1 for every row.
# NOTE: a seeded rand() is only repeatable while the data partitioning is unchanged.
df = df.withColumn("salt_column", floor(rand(seed=SALT_SEED) * NUM_SALTS))

In [9]:
# Print the rows with the new salt_column to check the per-row salt values.
df.show()

+-------+--------+-----------+
|user_id|purchase|salt_column|
+-------+--------+-----------+
|      A|     100|          0|
|      A|     200|          1|
|      A|     300|          0|
|      B|     400|          0|
|      C|     500|          0|
+-------+--------+-----------+



#### Concatenating the original groupBy column and salt_column to create a new groupBy column

In [15]:
# Build the salted grouping key as "<user_id>-<salt>" (e.g. "A-0") and attach it
# to the frame as user_id_salt.
salted_key = concat(col("user_id"), lit("-"), col("salt_column"))
df = df.withColumn("user_id_salt", salted_key)

In [21]:
# Print the rows with the combined user_id_salt key next to its two source columns.
df.show()

+-------+--------+-----------+------------+
|user_id|purchase|salt_column|user_id_salt|
+-------+--------+-----------+------------+
|      A|     100|          0|         A-0|
|      A|     200|          1|         A-1|
|      A|     300|          0|         A-0|
|      B|     400|          0|         B-0|
|      C|     500|          0|         C-0|
+-------+--------+-----------+------------+



#### Applying groupBy on the new salted column

In [22]:
# NOTE: `sum` below is pyspark.sql.functions.sum (from the wildcard import), not
# the Python builtin.

# Stage 1: partial sums per salted key. user_id is kept in the grouping so the
# partials can be rolled up afterwards; it does not change the groups, because
# each user_id_salt value is built from exactly one user_id.
partial_sums = df.groupBy("user_id", "user_id_salt").agg(
    sum("purchase").alias("partial_purchase")
)

# Stage 2: collapse the salted partials back to one row per original key.
# Without this step a salted user (e.g. "A-0" and "A-1") never gets a single total.
user_totals = partial_sums.groupBy("user_id").agg(
    sum("partial_purchase").alias("total_purchase")
)
user_totals.show()

# Keep `df` as the stage-1 result with its previous column names
# (user_id_salt, sum(purchase)) so the following cell is unaffected.
df = partial_sums.select(
    "user_id_salt", col("partial_purchase").alias("sum(purchase)")
)

In [23]:
# Print the sums per salted key; a user whose rows received different salts
# appears on several rows here (one per user_id_salt value).
df.show()

+------------+-------------+
|user_id_salt|sum(purchase)|
+------------+-------------+
|         A-0|          400|
|         A-1|          200|
|         B-0|          400|
|         C-0|          500|
+------------+-------------+

