## **Creating DataFrames and handling join skew: normal vs. broadcast vs. salted joins**

In [0]:


# Transactions: the "big" side of the join. Four of its six rows share the
# key (customer_id=101, region_id="R1"), which is what makes the join skewed.
transaction_rows = [
    # (txn_id, customer_id, region_id, amount)
    (1, 101, "R1", 100),
    (2, 101, "R1", 200),
    (3, 101, "R1", 150),
    (4, 102, "R2", 300),
    (5, 103, "R3", 400),
    (6, 101, "R1", 250),
]
Transactions = spark.createDataFrame(
    transaction_rows,
    ["txn_id", "customer_id", "region_id", "amount"],
)

# Customers: the small side, one row per (customer_id, region_id) key.
customer_rows = [
    # (customer_id, region_id, name, city)
    (101, "R1", "Ravi", "Hyderabad"),
    (102, "R2", "Priya", "Mumbai"),
    (103, "R3", "John", "Delhi"),
]
Customers = spark.createDataFrame(
    customer_rows,
    ["customer_id", "region_id", "name", "city"],
)


In [0]:

# Baseline: a plain inner join on the composite key, letting Spark choose the
# join strategy. If it picks a shuffle-based join, every row of a hot key lands
# in the same partition. Note: for tables this small Spark may broadcast
# Customers automatically (spark.sql.autoBroadcastJoinThreshold).
df_normal = Transactions.join(
    Customers,
    on=["customer_id", "region_id"],
    how="inner",
)
df_normal.show()


In [0]:
# Broadcast join: hint Spark to ship the small Customers table to every
# executor, so the large Transactions side is joined without being shuffled
# by key. This avoids the hot-key partition altogether.
from pyspark.sql.functions import broadcast

customers_hint = broadcast(Customers)
df_broadcast = Transactions.join(
    customers_hint,
    on=["customer_id", "region_id"],
    how="inner",
)
df_broadcast.show()


In [0]:
from pyspark.sql.functions import rand, explode, lit, array

# Salting: split each hot join key into several sub-keys so its rows spread
# over multiple partitions. The big side gets a random salt per row; the small
# side is replicated once per possible salt value so every salted row still
# finds exactly one match.
NUM_CUST_SALTS = 3     # salt buckets for customer_id
NUM_REGION_SALTS = 2   # salt buckets for region_id
SALT_SEED = 42         # fixed seed -> reproducible salts (and .show() output)

# Step 1: add salts to Transactions.
# rand() is uniform in [0, 1), so (rand() * N).cast("int") yields 0..N-1.
# The two salts use different seeds: with the same seed both rand() columns
# would produce the same sequence, correlating the salts and reducing the
# number of distinct (cust_salt, region_salt) combinations.
Transactions_salted = (
    Transactions
    .withColumn("cust_salt", (rand(SALT_SEED) * NUM_CUST_SALTS).cast("int"))
    .withColumn("region_salt", (rand(SALT_SEED + 1) * NUM_REGION_SALTS).cast("int"))
)

# Step 2: replicate Customers once per salt value. Both arrays are derived from
# the same constants as Step 1, so the salt ranges cannot drift apart (a
# mismatch would silently drop transactions from the inner join).
Customers_salted = (
    Customers
    .withColumn("cust_salt", explode(array(*[lit(i) for i in range(NUM_CUST_SALTS)])))
    .withColumn("region_salt", explode(array(*[lit(i) for i in range(NUM_REGION_SALTS)])))
)

# Step 3: join on the original keys plus the salts, then drop the helper salt
# columns so the result has the same columns as the normal/broadcast joins.
df_salted = (
    Transactions_salted
    .join(
        Customers_salted,
        ["customer_id", "region_id", "cust_salt", "region_salt"],
        "inner",
    )
    .drop("cust_salt", "region_salt")
)

df_salted.show()


In [0]:
# Side-by-side comparison: all three are inner joins on the same keys, so they
# should return the same matched transactions (row order from show() may differ).
join_results = [
    ("Normal Join Result:", df_normal),
    ("Broadcast Join Result:", df_broadcast),
    ("Salting Join Result (2 skewed keys):", df_salted),
]
for label, result_df in join_results:
    print(label)
    result_df.show()
