**Create the Dataframes**

In [0]:
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

transactions_data = [("A", 100), ("A", 200), ("A", 300), ("B", 150), ("C", 250)]
transactions_df = spark.createDataFrame(transactions_data, ["user_id", "amount"])

users_data = [("A", "India"), ("B", "USA"), ("C", "UK")]
users_df = spark.createDataFrame(users_data, ["user_id", "country"])


**Salt the transaction_df**

In [0]:
from pyspark.sql.functions import *

salt_count = 3

# Add a random salt
transactions_salted = transactions_df.withColumn("salt", floor(rand() * salt_count))

# Create salted user_id
transactions_salted = transactions_salted.withColumn("salted_user_id", concat_ws("_", "user_id", "salt"))


In [0]:
# Displaying transactions_df
transactions_salted.display()

user_id,amount,salt,salted_user_id
A,100,1,A_1
A,200,0,A_0
A,300,0,A_0
B,150,0,B_0
C,250,2,C_2


**Adding Salt to users_df (all the values of salt) and then Exploding it**

In [0]:
from pyspark.sql.functions import explode, array, lit

# Create array of salt values [0,1,2] for each row
users_df = users_df.withColumn("salt", array([lit(i) for i in range(salt_count)]))

# Exploding the values of the array 
users_expanded = users_df.withColumn("salt", explode(col('salt')))

# Create salted user_id
users_expanded = users_expanded.withColumn("salted_user_id", concat_ws("_", "user_id", "salt"))


In [0]:
# Displaying users_expanded

users_expanded.display()

user_id,country,salt,salted_user_id
A,India,0,A_0
A,India,1,A_1
A,India,2,A_2
B,USA,0,B_0
B,USA,1,B_1
B,USA,2,B_2
C,UK,0,C_0
C,UK,1,C_1
C,UK,2,C_2


**Applying JOIN**

In [0]:
joined_df = transactions_salted.join(users_expanded, on="salted_user_id", how="inner")

In [0]:
joined_df.display()

salted_user_id,user_id,amount,salt,user_id.1,country,salt.1
A_0,A,200,0,A,India,0
A_0,A,300,0,A,India,0
A_1,A,100,1,A,India,1
B_0,B,150,0,B,USA,0
C_2,C,250,2,C,UK,2
