In [1]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import expr

In [2]:
spark = SparkSession.builder.appName('pyspark-by-examples').getOrCreate()

data = [("Banana", 1000, "USA"), ("Carrots", 1500, "USA"), ("Beans", 1600, "USA"),
        ("Orange", 2000, "USA"), ("Orange", 2000, "USA"), ("Banana", 400, "China"),
        ("Carrots", 1200, "China"), ("Beans", 1500, "China"), ("Orange", 4000, "China"),
        ("Banana", 2000, "Canada"), ("Carrots", 2000, "Canada"), ("Beans", 2000, "Mexico")]

df = spark.createDataFrame(data=data, schema=["Product", "Amount", "Country"])
df.printSchema()
df.show(truncate=False)

root
 |-- Product: string (nullable = true)
 |-- Amount: long (nullable = true)
 |-- Country: string (nullable = true)

+-------+------+-------+
|Product|Amount|Country|
+-------+------+-------+
|Banana |1000  |USA    |
|Carrots|1500  |USA    |
|Beans  |1600  |USA    |
|Orange |2000  |USA    |
|Orange |2000  |USA    |
|Banana |400   |China  |
|Carrots|1200  |China  |
|Beans  |1500  |China  |
|Orange |4000  |China  |
|Banana |2000  |Canada |
|Carrots|2000  |Canada |
|Beans  |2000  |Mexico |
+-------+------+-------+



The `pivot()` function is used to pivot/rotate the data from one DataFrame/Dataset column into 
multiple columns (transform row to column) and unpivot is used to transform it back (transform columns to rows).

## Pivot Dataframe

In [3]:
print('Pivot Dataframe')
pivot_df = df.groupBy('Product').pivot('Country').sum('Amount')
pivot_df.show()

Pivot Dataframe
+-------+------+-----+------+----+
|Product|Canada|China|Mexico| USA|
+-------+------+-----+------+----+
| Orange|  null| 4000|  null|4000|
|  Beans|  null| 1500|  2000|1600|
| Banana|  2000|  400|  null|1000|
|Carrots|  2000| 1200|  null|1500|
+-------+------+-----+------+----+



Spark 2.0 on-wards performance has been improved on Pivot, 
however, if you are using lower version; note that pivot is a 
very expensive operation hence, it is recommended to provide column data 
(if known) as an argument to function as shown below.

### Option 1

In [4]:
countries = ["USA", "China", "Canada", "Mexico"]
pivot_df_option1 = df.groupBy("Product").pivot("Country", countries).sum("Amount")
pivot_df_option1.show(truncate=False)

+-------+----+-----+------+------+
|Product|USA |China|Canada|Mexico|
+-------+----+-----+------+------+
|Orange |4000|4000 |null  |null  |
|Beans  |1600|1500 |null  |2000  |
|Banana |1000|400  |2000  |null  |
|Carrots|1500|1200 |2000  |null  |
+-------+----+-----+------+------+



### Option 2 (two-phase aggregation)

In [5]:
pivot_df_option2 = df.groupBy("Product", "Country") \
    .sum("Amount") \
    .groupBy("Product") \
    .pivot("Country") \
    .sum("sum(Amount)")
pivot_df_option2.show(truncate=False)

+-------+------+-----+------+----+
|Product|Canada|China|Mexico|USA |
+-------+------+-----+------+----+
|Orange |null  |4000 |null  |4000|
|Beans  |null  |1500 |2000  |1600|
|Banana |2000  |400  |null  |1000|
|Carrots|2000  |1200 |null  |1500|
+-------+------+-----+------+----+



## Unpivot DataFrame

In [6]:
print('Unpivot DataFrame')
unpivotExpr = "stack(3, 'Canada', Canada, 'China', China, 'Mexico', Mexico) as (Country,Total)"
unPivotDF = pivot_df.select("Product", expr(unpivotExpr)) \
    .where("Total is not null")
unPivotDF.show(truncate=False)

Unpivot DataFrame
+-------+-------+-----+
|Product|Country|Total|
+-------+-------+-----+
|Orange |China  |4000 |
|Beans  |China  |1500 |
|Beans  |Mexico |2000 |
|Banana |Canada |2000 |
|Banana |China  |400  |
|Carrots|Canada |2000 |
|Carrots|China  |1200 |
+-------+-------+-----+

