In [2]:
import findspark
findspark.init()

import pyspark

In [3]:
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName('SparkByExamples.com').getOrCreate()

22/06/29 18:13:00 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).


In [12]:
from pyspark.rdd import portable_hash

In [4]:
countries = ("CN", "AU", "US")
data = []
for i in range(1, 13):
    data.append({"ID": i, "Country": countries[i % 3],  "Amount": 10+i})
 
df = spark.createDataFrame(data)
df.show()

+------+-------+---+
|Amount|Country| ID|
+------+-------+---+
|    11|     AU|  1|
|    12|     US|  2|
|    13|     CN|  3|
|    14|     AU|  4|
|    15|     US|  5|
|    16|     CN|  6|
|    17|     AU|  7|
|    18|     US|  8|
|    19|     CN|  9|
|    20|     AU| 10|
|    21|     US| 11|
|    22|     CN| 12|
+------+-------+---+



[Stage 0:>                                                          (0 + 1) / 1]                                                                                

In [14]:
def print_partitions(df):
    numPartitions = df.rdd.getNumPartitions()
    print("Total partitions: {}\n".format(numPartitions))
    print("Partitioner: {}\n".format(df.rdd.partitioner))
    df.explain()
    print("\n")
    parts = df.rdd.glom().collect()
    i = 0
    j = 0
    for p in parts:
        print("\nPartition {}:".format(i))
        for r in p:
            print("Row {}:{}".format(j, r))
            j = j+1
        i = i+1


In [15]:
df = df.repartition(3, "Country")
print_partitions(df)

Total partitions: 3

Partitioner: None

== Physical Plan ==
Exchange hashpartitioning(Country#1, 3), REPARTITION_WITH_NUM, [id=#29]
+- *(1) Scan ExistingRDD[Amount#0L,Country#1,ID#2L]





Partition 0:

Partition 1:
Row 0:Row(Amount=12, Country='US', ID=2)
Row 1:Row(Amount=13, Country='CN', ID=3)
Row 2:Row(Amount=15, Country='US', ID=5)
Row 3:Row(Amount=16, Country='CN', ID=6)
Row 4:Row(Amount=18, Country='US', ID=8)
Row 5:Row(Amount=19, Country='CN', ID=9)
Row 6:Row(Amount=21, Country='US', ID=11)
Row 7:Row(Amount=22, Country='CN', ID=12)

Partition 2:
Row 8:Row(Amount=11, Country='AU', ID=1)
Row 9:Row(Amount=14, Country='AU', ID=4)
Row 10:Row(Amount=17, Country='AU', ID=7)
Row 11:Row(Amount=20, Country='AU', ID=10)


In [18]:
from pyspark.sql import functions as F  
from pyspark.sql import udf  
udf_portable_hash = F.udf(lambda str: portable_hash(str))
df = df.withColumn("Hash#", udf_portable_hash(df.Country))
df = df.withColumn("Partition#", df["Hash#"] % 3)
df.show()

+------+-------+---+--------------------+----------+
|Amount|Country| ID|               Hash#|Partition#|
+------+-------+---+--------------------+----------+
|    12|     US|  2|-8328537658613580243|      -1.0|
|    13|     CN|  3|-7458853143580063552|      -1.0|
|    15|     US|  5|-8328537658613580243|      -1.0|
|    16|     CN|  6|-7458853143580063552|      -1.0|
|    18|     US|  8|-8328537658613580243|      -1.0|
|    19|     CN|  9|-7458853143580063552|      -1.0|
|    21|     US| 11|-8328537658613580243|      -1.0|
|    22|     CN| 12|-7458853143580063552|      -1.0|
|    11|     AU|  1| 6593628092971972691|       0.0|
|    14|     AU|  4| 6593628092971972691|       0.0|
|    17|     AU|  7| 6593628092971972691|       0.0|
|    20|     AU| 10| 6593628092971972691|       0.0|
+------+-------+---+--------------------+----------+



In [19]:
countries = ("CN", "AU", "US")
def country_partitioning(k):
    return countries.index(k)
    
udf_country_hash = F.udf(lambda str: country_partitioning(str))
numPartitions = 3
# df = df.partitionBy(numPartitions, country_partitioning)
df = df.withColumn("Hash#", udf_country_hash(df['Country']))
df = df.withColumn("Partition#", df["Hash#"] % numPartitions)
df.orderBy('Country').show()

+------+-------+---+-----+----------+
|Amount|Country| ID|Hash#|Partition#|
+------+-------+---+-----+----------+
|    20|     AU| 10|    1|       1.0|
|    11|     AU|  1|    1|       1.0|
|    14|     AU|  4|    1|       1.0|
|    17|     AU|  7|    1|       1.0|
|    16|     CN|  6|    0|       0.0|
|    13|     CN|  3|    0|       0.0|
|    22|     CN| 12|    0|       0.0|
|    19|     CN|  9|    0|       0.0|
|    12|     US|  2|    2|       2.0|
|    18|     US|  8|    2|       2.0|
|    21|     US| 11|    2|       2.0|
|    15|     US|  5|    2|       2.0|
+------+-------+---+-----+----------+



In [20]:
print_partitions(df.repartition(3, "Partition#"))

Total partitions: 3

Partitioner: None

== Physical Plan ==
Exchange hashpartitioning(Partition##60, 3), REPARTITION_WITH_NUM, [id=#105]
+- *(2) Project [Amount#0L, Country#1, ID#2L, pythonUDF0#88 AS Hash##54, (cast(pythonUDF0#88 as double) % 3.0) AS Partition##60]
   +- BatchEvalPython [<lambda>(Country#1)], [pythonUDF0#88]
      +- Exchange hashpartitioning(Country#1, 3), REPARTITION_WITH_NUM, [id=#100]
         +- *(1) Scan ExistingRDD[Amount#0L,Country#1,ID#2L]





Partition 0:
Row 0:Row(Amount=12, Country='US', ID=2, Hash#='2', Partition#=2.0)
Row 1:Row(Amount=15, Country='US', ID=5, Hash#='2', Partition#=2.0)
Row 2:Row(Amount=18, Country='US', ID=8, Hash#='2', Partition#=2.0)
Row 3:Row(Amount=21, Country='US', ID=11, Hash#='2', Partition#=2.0)

Partition 1:
Row 4:Row(Amount=13, Country='CN', ID=3, Hash#='0', Partition#=0.0)
Row 5:Row(Amount=16, Country='CN', ID=6, Hash#='0', Partition#=0.0)
Row 6:Row(Amount=19, Country='CN', ID=9, Hash#='0', Partition#=0.0)
Row 7:Row(Amount=22, 