In [1]:
import findspark
import pyspark
from pyspark.sql import *
from pyspark.sql.functions import col, countDistinct
from pyspark.sql.functions import struct
from pyspark.sql import SparkSession, functions as F, types
from pyspark.sql.types import StructType
from pyspark.sql.types import *
from pyspark.sql.functions import broadcast
from math import log
from pyspark.sql.functions import log10
# Let findspark locate the Spark installation, then start (or reuse) a local
# session named 'Feature_Encoder' for the rest of the notebook.
findspark.init()
spark = (
    SparkSession.builder
    .appName('Feature_Encoder')
    .master("local")
    .getOrCreate()
)

In [14]:
# Toy dataset: a categorical "Country" feature paired with a binary "Label".
schema = (
    StructType()
    .add("Country", StringType(), True)
    .add("Label", IntegerType(), True)
)

data = [
    ("US", 0), ("CA", 1), ("IN", 0), ("IN", 1), ("US", 0),
    ("CA", 1), ("IN", 0), ("CA", 1), ("CA", 0), ("CA", 1),
    ("US", 0), ("CA", 1), ("CN", 1), ("CN", 1),
]

df = spark.createDataFrame(data, schema=schema)
# Total number of rows; used further down as the denominator of the
# fallback encoding for categories that lack one of the two classes.
total_count = df.count()
df.show()


+-------+-----+
|Country|Label|
+-------+-----+
|     US|    0|
|     CA|    1|
|     IN|    0|
|     IN|    1|
|     US|    0|
|     CA|    1|
|     IN|    0|
|     CA|    1|
|     CA|    0|
|     CA|    1|
|     US|    0|
|     CA|    1|
|     CN|    1|
|     CN|    1|
+-------+-----+



In [15]:
# Configuration: the categorical column to encode, the target column, and the
# two values of the binary target.
cat_col = "Country"
# Spelled exactly like the schema's column so that resolution does not depend
# on Spark's case-insensitive analysis (spark.sql.caseSensitive=false).
label_col = "Label"
positive_class = 1
negative_class = 0

# Number of rows for every (category, class) combination.
df1 = df.groupby(cat_col, label_col).count()

# Split the counts into one frame per class, each with the columns
# (cat_col, count_c<class>). The filter refers to the column by name rather
# than through `df`, which is a different DataFrame than the one being
# filtered. groupBy already yields one row per (category, class), so no
# de-duplication is needed.
# Positive class counts.
df_class1 = (
    df1.where(F.col(label_col) == positive_class)
    .select(cat_col, F.col("count").alias("count_c1"))
)

# Negative class counts.
df_class0 = (
    df1.where(F.col(label_col) == negative_class)
    .select(cat_col, F.col("count").alias("count_c0"))
)

In [16]:
# Join the per-class counts. A category that never occurs with one of the
# classes has no row on that side of the full join, so its count is filled
# with 0.
df_join = df_class0.join(df_class1, on=[cat_col], how="full").na.fill(0)

# Overall number of negative / positive rows, computed in one aggregation job.
class_totals = df_join.agg(F.sum("count_c0"), F.sum("count_c1")).first()
total_C0, total_C1 = class_totals[0], class_totals[1]

# The ratio below is only defined when a category has rows of both classes.
has_both_classes = (F.col("count_c0") > 0) & (F.col("count_c1") > 0)

# Ratio of the category's share of positives to its share of negatives.
# The guard keeps the division from ever running with a zero denominator:
# in non-ANSI mode that would produce null, in ANSI mode it raises an error.
# Categories without both classes get 0.0, the same value the unguarded
# division followed by na.fill(0) produced.
we_col = F.when(
    has_both_classes,
    (F.col("count_c1") / total_C1) / (F.col("count_c0") / total_C0),
).otherwise(F.lit(0.0))

df_join = df_join.withColumn("WOE", we_col)
df_join.show()


# Special case handling:
# when a category has no positive or no negative rows, fall back to the
# category's share of all rows instead of the ratio.
category_share = (F.col("count_c0") + F.col("count_c1")) / total_count
df_join = df_join.withColumn(
    "WOE_prime",
    F.when(has_both_classes, F.col("WOE")).otherwise(category_share),
)

df_join.show()



+-------+--------+--------+-----+
|Country|count_c0|count_c1|  WOE|
+-------+--------+--------+-----+
|     CN|       0|       2|  0.0|
|     CA|       1|       5| 3.75|
|     US|       3|       0|  0.0|
|     IN|       2|       1|0.375|
+-------+--------+--------+-----+

+-------+--------+--------+-----+-------------------+
|Country|count_c0|count_c1|  WOE|          WOE_prime|
+-------+--------+--------+-----+-------------------+
|     CN|       0|       2|  0.0|0.14285714285714285|
|     CA|       1|       5| 3.75|               3.75|
|     US|       3|       0|  0.0|0.21428571428571427|
|     IN|       2|       1|0.375|              0.375|
+-------+--------+--------+-----+-------------------+



In [17]:
# Attach the encoding to every input row.
# A left join keeps all rows of `df`, including any whose category finds no
# match in the lookup table (their encoded value is null). It also puts the
# small lookup table on the non-preserved side of the outer join, which is
# the side Spark is able to broadcast; with a right join the broadcast hint
# on the right-hand table could not be used.
# Only the two columns the join needs are shipped with the broadcast.
woe_lookup = df_join.select(cat_col, "WOE_prime")
df_encoded = (
    df.join(broadcast(woe_lookup), on=[cat_col], how="left")
    # The final encoding is the base-10 logarithm of WOE_prime.
    .withColumn(cat_col + "_encoded", log10("WOE_prime"))
    .drop("WOE_prime")
)

In [18]:
# Display the rows of the encoded DataFrame (show() prints at most 20 rows by default).
df_encoded.show()

+-------+-----+--------------------+
|Country|label|     Country_encoded|
+-------+-----+--------------------+
|     CN|    1| -0.8450980400142569|
|     CN|    1| -0.8450980400142569|
|     CA|    1|  0.5740312677277188|
|     CA|    1|  0.5740312677277188|
|     CA|    1|  0.5740312677277188|
|     CA|    0|  0.5740312677277188|
|     CA|    1|  0.5740312677277188|
|     CA|    1|  0.5740312677277188|
|     US|    0| -0.6690067809585756|
|     US|    0| -0.6690067809585756|
|     US|    0| -0.6690067809585756|
|     IN|    0|-0.42596873227228116|
|     IN|    1|-0.42596873227228116|
|     IN|    0|-0.42596873227228116|
+-------+-----+--------------------+

