In [47]:
import os

from pyspark.ml.fpm import FPGrowth
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, collect_set, desc

In [48]:
# Reuse the active Spark session if there is one, otherwise start one named 'main'.
spark = (
    SparkSession.builder
    .appName('main')
    .getOrCreate()
)

In [49]:
# Input CSV with the raw sales lines, and the directory that receives the exported CSVs.
dataset_path = 'inputs/by-group.csv'
output_path = 'outputs/by-group'

# pandas' to_csv does not create missing parent directories, so make sure the
# export directory exists before any result is written (no-op if it already does).
os.makedirs(output_path, exist_ok = True)

In [50]:
def check(df):
    """Print the row count of ``df`` followed by a preview of its first five rows.

    Args:
        df: Object exposing ``count()`` and ``show(n, truncate=...)``; in this
            notebook it is always a Spark DataFrame.

    Returns:
        None. Everything is written to standard output.
    """
    print(f"Count: {df.count()}")
    print('\nElements:')
    # show() prints the table itself and returns None, so it must not be
    # wrapped in print() -- doing so emits a stray "None" line after the table.
    df.show(5, truncate = False)

In [51]:
# Load the raw sales lines; the first CSV row supplies the column names.
# No schema is given or inferred, so by Spark's documented default every
# column is read as a string.
df = spark.read.csv(dataset_path, header = True)

In [52]:
# check(df)

In [53]:
# Build one row per sale holding the set of distinct barcodes sold in it.
# The aggregate is aliased to 'items' right here so the column name does not
# depend on Spark's auto-generated 'collect_set(barcode)' label; any later
# withColumnRenamed on that generated label is then a harmless no-op.
df = df.groupBy('sale_code').agg(collect_set('barcode').alias('items'))

In [54]:
# Rename the column labelled 'collect_set(barcode)' (the label Spark gives an
# un-aliased collect_set aggregate) to 'items'. withColumnRenamed leaves the
# frame unchanged when no column with that label exists.
df = df.withColumnRenamed(existing = 'collect_set(barcode)', new = 'items')


In [55]:
# Sanity check of the grouped frame: prints its row count and its first five rows.
check(df)

Count: 19

Elements:
+---------+------------------------------------------------------------------------------------------+
|sale_code|items                                                                                     |
+---------+------------------------------------------------------------------------------------------+
|590863   |[7891010589639, 3504105035617, 310119052440]                                              |
|590335   |[7896902215788, 7896007551088]                                                            |
|591083   |[7891317124038, 7896026306188, 7891317120863, 7891010032241, 7891721025334, 7896281152018]|
|590487   |[7896061995903]                                                                           |
|590866   |[7891010589639, 3504105035617, 310119052440]                                              |
+---------+------------------------------------------------------------------------------------------+
only showing top 5 rows

None


In [56]:
# FP-Growth estimator reading its baskets from the 'items' column.
# NOTE(review): both thresholds are extremely permissive -- with the 19 baskets
# in the saved output, any itemset seen at least once clears minSupport.
# Confirm this is intended before running on a larger dataset.
fp_growth = FPGrowth(
    minSupport = 0.0005,
    minConfidence = 0.001,
    itemsCol = 'items',
)

In [57]:
# Fit the FP-Growth estimator on the grouped frame; the resulting model is the
# source of the frequent itemsets, association rules and transform used below.
model = fp_growth.fit(df)

In [58]:
# Frequent itemsets found by the fitted model (the cells below use its 'freq' and 'items' columns).
freq_itemsets = model.freqItemsets

In [59]:
# Keep only the frequency and the itemset, with the frequency as the first column.
freq_itemsets = freq_itemsets.select(col('freq'), col('items'))

In [60]:
# Most frequent itemsets first.
freq_itemsets = freq_itemsets.orderBy(desc('freq'))

In [61]:
# check(freq_itemsets)

In [62]:
# Association rules derived by the fitted model (the next cell sorts them by 'confidence' and 'lift').
rules = model.associationRules

In [63]:
# Strongest rules first: highest confidence, ties broken by highest lift.
rules = rules.orderBy(desc('confidence'), desc('lift'))

In [64]:
# check(rules)

In [65]:
# Apply the fitted model to the same grouped frame it was trained on.
# NOTE(review): presumably this appends the model's prediction column (items
# suggested by the rules for each basket) -- confirm against the FPGrowthModel docs.
transformed = model.transform(df)

In [66]:
# check(transformed)

In [67]:
# Bring the sorted itemsets into a pandas frame and write them as CSV without the index column.
(
    freq_itemsets
    .toPandas()
    .to_csv(f"{output_path}/frequent_itemsets.csv", index = False)
)

In [68]:
# Bring the sorted rules into a pandas frame and write them as CSV without the index column.
(
    rules
    .toPandas()
    .to_csv(f"{output_path}/association_rules.csv", index = False)
)

In [69]:
# Bring the transformed baskets into a pandas frame and write them as CSV without the index column.
(
    transformed
    .toPandas()
    .to_csv(f"{output_path}/transformed.csv", index = False)
)