In [85]:
# findspark must run BEFORE any pyspark import: findspark.init() locates the
# local Spark installation and adds it to sys.path so `pyspark` is importable.
import findspark
findspark.init()

from pyspark.sql import SparkSession

# Local Spark session using 3 worker threads ("local[3]"); getOrCreate() returns
# the already-running session if this cell is executed again.
spark = SparkSession.builder.master("local[3]").appName("spark_pivot").getOrCreate()

In [86]:
# Sample rows of (Company, Price, Country). Companies repeat across countries
# so the groupBy/pivot cells below have several rows to aggregate per cell.
data = [
    ("Toyota", 15000, "India"),
    ("Kia", 18000, "South Korea"),
    ("Ford", 14000, "India"),
    ("Renault", 12000, "France"),
    ("Renault", 10000, "India"),
    ("Toyota", 12000, "Japan"),
    ("Kia", 16000, "South Korea"),
    ("Ford", 13000, "USA"),
    ("Renault", 12000, "South Korea"),
    ("Toyota", 15000, "Japan"),
    ("Kia", 13000, "South Korea"),
    ("Ford", 10000, "India"),
]

# A list of names as the schema gives column names; Spark infers the types.
df = spark.createDataFrame(data, schema=["Company", "Price", "Country"])
df.show()

+-------+-----+-----------+
|Company|Price|    Country|
+-------+-----+-----------+
| Toyota|15000|      India|
|    Kia|18000|South Korea|
|   Ford|14000|      India|
|Renault|12000|     France|
|Renault|10000|      India|
| Toyota|12000|      Japan|
|    Kia|16000|South Korea|
|   Ford|13000|        USA|
|Renault|12000|South Korea|
| Toyota|15000|      Japan|
|    Kia|13000|South Korea|
|   Ford|10000|      India|
+-------+-----+-----------+



## To pivot (transpose) columns, think about three things
+ The column name that you are going to group by (equivalent to <b>ROWS</b> in an Excel pivot)
+ The column name that you are going to pass to pivot (equivalent to <b>COLUMNS</b> in an Excel pivot)
+ The column name that you are going to apply an aggregate function to (equivalent to <b>VALUES</b> in an Excel pivot)


In [87]:
# ROWS = Company (groupBy), COLUMNS = each distinct Country (pivot),
# VALUES = number of input rows per Company/Country cell (count).
(
    df.groupBy("Company")
    .pivot("Country")
    .count()
    .show()
)

+-------+------+-----+-----+-----------+----+
|Company|France|India|Japan|South Korea| USA|
+-------+------+-----+-----+-----------+----+
|    Kia|  null| null| null|          3|null|
|Renault|     1|    1| null|          1|null|
| Toyota|  null|    1|    2|       null|null|
|   Ford|  null|    2| null|       null|   1|
+-------+------+-----+-----+-----------+----+



### Shape of the output
+ The number of rows in the output equals the number of unique values of the groupBy column(s): 4 in this example
+ The number of columns in the output equals the number of unique values in the pivot column (5) plus the number of groupBy columns (1)
   + e.g., 5 + 1 = 6 in this example

### Applying different aggregate functions
+ The count aggregate doesn't require a column name
+ Apart from count, the other grouped-data aggregate functions (sum, avg/mean, min, max) require a numeric column to be passed.
    + For example: sum('Price'), not sum('Company')

In [None]:
# Total Price per Company (rows) and Country (columns); Company/Country pairs
# with no input rows come out as null. pivotDF is reused in the Unpivot section.
pivotDF = (
    df.groupBy("Company")
    .pivot("Country")
    .sum('Price')
)
pivotDF.show()

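## Two-step process to improve pivot performance
##### Spark > 2.1 internally uses this process: first aggregate by the groupBy and pivot columns together, then pivot the smaller, already-aggregated result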

In [None]:
# Step 1: aggregate to one row per (Company, Country); Spark names the
#         resulting column "sum(Price)".
# Step 2: pivot that smaller, already-aggregated result by Country.
# The second aggregate must therefore reference 'sum(Price)', not 'Price'.
(
    df.groupBy("Company", "Country").sum('Price')
    .groupBy('Company').pivot("Country")
    .sum('sum(Price)')
    .show()
)

### Notes about Pivot
+ pivot is applied to grouped data (the result of groupBy), not directly to a DataFrame

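## Pivot_Table
+ pandas' `pivot_table` applies a spreadsheet-like pivot to a pandas DataFrame (hence the `toPandas()` conversion below)
  + index: what are the output rows?
  + columns (pivot column): what are the output columns?
  + values: the column(s) to aggregate
  + aggfunc: the aggregate function to apply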

In [None]:
# Collect the Spark DataFrame to the driver as a pandas DataFrame so that
# pandas' pivot_table can be demonstrated on the same data (fine here: 12 rows).
df_pandas = df.toPandas()
df_pandas

In [None]:
# Sum of Price for every (Company, Country) pair: Company values form the row
# index and Country values form the columns; pairs with no rows show as NaN.
table = df_pandas.pivot_table(
    values='Price',
    index='Company',
    columns='Country',
    aggfunc='sum',
)
table

In [None]:
# Same pivot as above, but fill_value=0 replaces the NaN entries for
# Company/Country pairs that have no rows.
table = df_pandas.pivot_table(
    values='Price',
    index='Company',
    columns='Country',
    aggfunc='sum',
    fill_value=0,
)
table

In [None]:
# Column-specific aggregation: aggfunc can be a dict that maps each value
# column to the aggregate function to apply to it.
table = df_pandas.pivot_table(
    values='Price',
    index='Company',
    columns='Country',
    aggfunc={'Price': 'sum'},
    fill_value=0,
)
table

In [None]:
# Apply aggregates on multiple columns.
# First add a second numeric column, so that the next cell can aggregate
# Price and Commission with different functions.
COMMISSION_RATE = 0.07  # commission as a fraction of Price

# assign() returns a new frame with the extra column instead of mutating the
# frame displayed in the earlier cell; re-running this cell gives the same result.
df_pandas = df_pandas.assign(Commission=df_pandas['Price'] * COMMISSION_RATE)
df_pandas

In [None]:
# Column-specific aggregation over two value columns: within each
# Company/Country cell, Price is summed and Commission is averaged;
# empty cells are filled with 0.
table = df_pandas.pivot_table(
    values=['Price', 'Commission'],
    index='Company',
    columns='Country',
    aggfunc={'Price': 'sum', 'Commission': 'mean'},
    fill_value=0,
)
table

## Unpivot

In [None]:
# Re-display the wide (pivoted) Spark result that the next cell unpivots.
pivotDF.show(n=20, truncate=True)

In [None]:
from pyspark.sql.functions import expr


def _stack_pair(col_name):
    """Return "'label', `column`" for one pivoted column inside a SQL stack() call."""
    label = col_name.replace("\\", "\\\\").replace("'", "\\'")  # SQL string literal
    ident = col_name.replace("`", "``")                         # SQL identifier (spaces OK)
    return f"'{label}', `{ident}`"


# Unpivot with stack(n, label1, value1, label2, value2, ...), which turns each
# input row into n (Country, Total) rows. The argument list is built from
# pivotDF's columns instead of hard-coding the five countries, so it stays
# correct if the pivot produces a different set of Country columns.
country_cols = [c for c in pivotDF.columns if c != "Company"]
stack_expr = "stack({}, {}) as (Country, Total)".format(
    len(country_cols), ", ".join(_stack_pair(c) for c in country_cols)
)

unPivotDF = (
    pivotDF
    .select("Company", expr(stack_expr))
    .where("Total is not null")  # drop Company/Country pairs that had no rows
)

unPivotDF.show()