In [1]:
import findspark

# findspark.init() adds the local Spark installation to sys.path, so it has to
# run BEFORE the first `import pyspark`; otherwise the pyspark imports below
# fail on machines where pyspark is not already importable.
findspark.init()

import pyspark
from pyspark.sql import *
from pyspark.sql.functions import col, countDistinct
from pyspark.sql.functions import struct
from pyspark.sql import SparkSession, functions as F, types
from pyspark.sql.types import StructType
from pyspark.sql.types import *
from pyspark.sql.functions import broadcast

# Single local Spark session shared by every cell in this notebook.
spark = SparkSession.builder.appName('Feature_Encoder').master("local").getOrCreate()

In [7]:
# Build a small demo dataset: one categorical feature ("Country") and a
# binary target ("Label"). Both columns are declared nullable.
schema = StructType(
    [
        StructField("Country", StringType(), True),
        StructField("Label", IntegerType(), True),
    ]
)

data = [
    ("US", 0), ("CA", 1), ("IN", 0), ("IN", 1), ("US", 0),
    ("CA", 1), ("IN", 0), ("CA", 1), ("CA", 0), ("CA", 1),
    ("US", 0), ("CA", 1), ("CN", 1), ("CN", 1),
]

df = spark.createDataFrame(data, schema=schema)

# Total number of rows; used later as the denominator for the special-case encoding.
total_count = df.count()
df.show()

+-------+-----+
|Country|Label|
+-------+-----+
|     US|    0|
|     CA|    1|
|     IN|    0|
|     IN|    1|
|     US|    0|
|     CA|    1|
|     IN|    0|
|     CA|    1|
|     CA|    0|
|     CA|    1|
|     US|    0|
|     CA|    1|
|     CN|    1|
|     CN|    1|
+-------+-----+



In [3]:
# Name of the categorical column to encode and of the binary target column.
cat_col = "Country"
# Use the exact casing from the schema ("Label") so column resolution does not
# depend on Spark's default case-insensitive analysis (spark.sql.caseSensitive=false).
label_col = "Label"
positive_class = 1
negative_class = 0

# Number of rows for every (category, class) pair.
df1 = df.groupby(cat_col, label_col).count()

# Split the counts into one DataFrame per class.
# The filter references the label column by name (F.col) instead of through
# the unrelated `df` handle, and no dropDuplicates() is needed: after the
# groupBy each (category, class) pair appears exactly once.

# Per-category count of positive-class rows.
df_class1 = (
    df1.where(F.col(label_col) == positive_class)
       .select(cat_col, F.col("count").alias("count_c1"))
)

# Per-category count of negative-class rows.
df_class0 = (
    df1.where(F.col(label_col) == negative_class)
       .select(cat_col, F.col("count").alias("count_c0"))
)

df_class1.show()
df_class0.show()

+-------+--------+
|Country|count_c1|
+-------+--------+
|     CN|       2|
|     IN|       1|
|     CA|       5|
+-------+--------+

+-------+--------+
|Country|count_c0|
+-------+--------+
|     IN|       2|
|     US|       3|
|     CA|       1|
+-------+--------+



In [42]:
# Column expressions for the per-category class counts.
neg_count = F.col("count_c0")
pos_count = F.col("count_c1")

# Supervised Ratio: share of positive rows within a category.
sr_col = pos_count / (pos_count + neg_count)

# Special case: a category that has no positive rows or no negative rows is
# encoded as its share of all rows (category count / total_count) instead of SR.
single_class_only = (neg_count == 0) | (pos_count == 0)

df_join = (
    # Full outer join keeps categories that occur in only one of the two
    # classes; their missing count becomes 0.
    df_class0.join(df_class1, on=[cat_col], how="full")
    .na.fill(0)
    .withColumn("SR", sr_col)
    .na.fill(0)
    .withColumn(
        cat_col + "_encoded",
        F.when(single_class_only, (neg_count + pos_count) / total_count)
         .otherwise(F.col("SR")),
    )
)
df_join.show()

+-------+--------+--------+------------------+-------------------+
|Country|count_c0|count_c1|                SR|    Country_encoded|
+-------+--------+--------+------------------+-------------------+
|     CN|       0|       2|               1.0|0.14285714285714285|
|     CA|       1|       5|0.8333333333333334| 0.8333333333333334|
|     US|       3|       0|               0.0|0.21428571428571427|
|     IN|       2|       1|0.3333333333333333| 0.3333333333333333|
+-------+--------+--------+------------------+-------------------+



In [43]:
# Attach the per-category encoding to every original row.
# A LEFT join keeps all rows of `df` (the data being encoded) and allows Spark
# to honor the broadcast hint on the small lookup table: a right outer join
# can only broadcast its left side, so the hint on `df_join` could not be used.
# Helper columns used only to derive the encoding are dropped.
df_sr = (
    df.join(broadcast(df_join), on=[cat_col], how="left")
      .drop('count_c0', 'count_c1', 'SR')
)
df_sr.show()

+-------+-----+-------------------+
|Country|Label|    Country_encoded|
+-------+-----+-------------------+
|     CN|    1|0.14285714285714285|
|     CN|    1|0.14285714285714285|
|     CA|    1| 0.8333333333333334|
|     CA|    1| 0.8333333333333334|
|     CA|    1| 0.8333333333333334|
|     CA|    0| 0.8333333333333334|
|     CA|    1| 0.8333333333333334|
|     CA|    1| 0.8333333333333334|
|     US|    0|0.21428571428571427|
|     US|    0|0.21428571428571427|
|     US|    0|0.21428571428571427|
|     IN|    0| 0.3333333333333333|
|     IN|    1| 0.3333333333333333|
|     IN|    0| 0.3333333333333333|
+-------+-----+-------------------+

