In [0]:
df=spark.read.load("/Volumes/luffy/phase2/bronze/customer_df_without_target/")
display(df)

CustomerID,total_spent,total_transactions,total_quantity,last_purchase_date
17420.0,598.8300000000002,30,265,2011-10-20
16552.0,379.73,17,219,2011-04-11
17572.0,226.75,12,95,2011-09-29
15350.0,115.65,5,51,2010-12-01
12921.0,16784.440000000013,741,9598,2011-12-06
13090.0,9132.249999999998,161,2332,2011-12-01
14135.0,4690.31,134,3850,2011-12-08
12915.0,363.65,22,93,2011-07-14
17685.0,3191.5299999999997,130,1956,2011-11-23
17581.0,11353.969999999988,452,5993,2011-12-09


In [0]:
quantile_value = df.approxQuantile("total_spent", [0.8], 0.01)[0]
from pyspark.sql.functions import when, col

df = df.withColumn(
    "is_high_valued",
    when(col("total_spent") >= quantile_value, 1).otherwise(0)
)

In [0]:
display(df)

CustomerID,total_spent,total_transactions,total_quantity,last_purchase_date,is_high_valued
17420.0,598.8300000000002,30,265,2011-10-20,0
16552.0,379.73,17,219,2011-04-11,0
17572.0,226.75,12,95,2011-09-29,0
15350.0,115.65,5,51,2010-12-01,0
12921.0,16784.440000000013,741,9598,2011-12-06,1
13090.0,9132.249999999998,161,2332,2011-12-01,1
14135.0,4690.31,134,3850,2011-12-08,1
12915.0,363.65,22,93,2011-07-14,0
17685.0,3191.5299999999997,130,1956,2011-11-23,1
17581.0,11353.969999999988,452,5993,2011-12-09,1


In [0]:
# spark.sql("""
# create volume if not exists luffy.phase2.silver          
# """)

In [0]:
df.write.save("/Volumes/luffy/phase2/silver/df")

In [0]:
df=spark.read.load("/Volumes/luffy/phase2/silver/df/")

In [0]:
df.groupBy("is_high_valued").count().show()

+--------------+-----+
|is_high_valued|count|
+--------------+-----+
|             1|  904|
|             0| 3469|
+--------------+-----+



In [0]:
minority, majority = df.groupBy("is_high_valued").count().collect()

In [0]:
from pyspark.sql.functions import when, col

count_0 = majority['count']
count_1 = minority['count']
total = count_0 + count_1

weight_0 = total / (2 * count_0)
weight_1 = total / (2 * count_1)

df1 = df.withColumn(
    "class_weight",
    when(col("is_high_valued") == 1, weight_1).otherwise(weight_0)
)

In [0]:
display(df1)

CustomerID,total_spent,total_transactions,total_quantity,last_purchase_date,is_high_valued,class_weight
17420.0,598.8300000000002,30,265,2011-10-20,0,0.6302969155376189
16552.0,379.73,17,219,2011-04-11,0,0.6302969155376189
17572.0,226.75,12,95,2011-09-29,0,0.6302969155376189
15350.0,115.65,5,51,2010-12-01,0,0.6302969155376189
12921.0,16784.440000000013,741,9598,2011-12-06,1,2.418694690265487
13090.0,9132.249999999998,161,2332,2011-12-01,1,2.418694690265487
14135.0,4690.31,134,3850,2011-12-08,1,2.418694690265487
12915.0,363.65,22,93,2011-07-14,0,0.6302969155376189
17685.0,3191.5299999999997,130,1956,2011-11-23,1,2.418694690265487
17581.0,11353.969999999988,452,5993,2011-12-09,1,2.418694690265487


In [0]:
df1.write.format("delta").mode("overwrite").save("/Volumes/luffy/phase2/silver/df1")

In [0]:
train_df, test_df = df1.randomSplit([0.8, 0.2], seed=42)
display(train_df)
display(test_df)

CustomerID,total_spent,total_transactions,total_quantity,last_purchase_date,is_high_valued,class_weight
,2062871.1599989403,135080,689008,2011-12-09,1,2.418694690265487
12346.0,154367.2,2,148430,2011-01-18,1,2.418694690265487
12347.0,4309.999999999997,182,2458,2011-12-07,1,2.418694690265487
12348.0,1797.24,31,2341,2011-09-25,0,0.6302969155376189
12349.0,1757.55,73,631,2011-11-21,0,0.6302969155376189
12350.0,334.40000000000003,17,197,2011-02-02,0,0.6302969155376189
12352.0,3466.670000000001,95,602,2011-11-03,1,2.418694690265487
12353.0,89.0,4,20,2011-05-19,0,0.6302969155376189
12355.0,459.4,13,240,2011-05-09,0,0.6302969155376189
12356.0,2811.4300000000007,59,1591,2011-11-17,1,2.418694690265487


CustomerID,total_spent,total_transactions,total_quantity,last_purchase_date,is_high_valued,class_weight
12354,1079.4,58,530,2011-04-21,0,0.6302969155376189
12357,6207.669999999996,131,2708,2011-11-06,1,2.418694690265487
12359,6499.630000000005,254,1632,2011-12-02,1,2.418694690265487
12370,3545.689999999998,167,2353,2011-10-19,1,2.418694690265487
12372,1298.0400000000002,52,794,2011-09-29,0,0.6302969155376189
12373,364.6,14,197,2011-02-01,0,0.6302969155376189
12374,742.93,33,342,2011-11-14,0,0.6302969155376189
12377,1628.1199999999997,77,944,2011-01-28,0,0.6302969155376189
12380,2729.06,105,1128,2011-11-18,1,2.418694690265487
12386,401.9,10,354,2011-01-06,0,0.6302969155376189


In [0]:
display(
    train_df.groupBy("is_high_valued").count()
)

display(
    test_df.groupBy("is_high_valued").count()
)

is_high_valued,count
1,734
0,2751


is_high_valued,count
1,170
0,718
