In [24]:
import os

from pyspark.ml.fpm import FPGrowth
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, collect_set, desc

In [25]:
# Obtain the Spark session for this notebook under the application name 'main'
# (getOrCreate reuses an already running session instead of starting a new one).
spark = (
    SparkSession.builder
    .appName('main')
    .getOrCreate()
)

In [26]:
# CSV file that is loaded into the Spark DataFrame.
dataset_path: str = 'inputs/by-manufacturer.csv'
# Folder that receives the CSV files exported at the end of the notebook.
output_path: str = 'outputs/by-manufacturer'

In [27]:
def check(df):
    """Print the row count of `df` followed by a preview of its first 5 rows.

    Args:
        df: Object exposing `count()` and `show(n, truncate=...)`; in this
            notebook it is always a Spark DataFrame.

    Returns:
        None. Everything is written to standard output.
    """
    print(f"Count: {df.count()}")
    print('\nElements:')
    # `show` prints the table itself and returns None, so it must not be
    # wrapped in `print` (that would emit a stray "None" line after the table).
    df.show(5, truncate = False)

In [28]:
# Load the input CSV, treating its first line as the header row.
df = spark.read.csv(dataset_path, header = True)

In [29]:
# check(df)

In [30]:
# Collapse the rows to one per sale_code, gathering the distinct barcodes of
# each sale into a single array column.
df = df.groupBy(col('sale_code')).agg(collect_set(col('barcode')))

In [31]:
# Rename the auto-generated aggregate column to 'items', the column name the
# FPGrowth estimator below is configured to read.
df = df.withColumnRenamed(existing = 'collect_set(barcode)', new = 'items')


In [32]:
# Inspect the per-sale baskets: row count plus a short preview.
check(df = df)

Count: 15

Elements:
+---------+--------------------------------------------------------------------------------------------------------+
|sale_code|items                                                                                                   |
+---------+--------------------------------------------------------------------------------------------------------+
|592029   |[7896422519229, 7896523216454, 7891058004668, 7896023776373, 7896023732935]                             |
|590426   |[7896023729287, 7896004706559]                                                                          |
|590968   |[7896226108926, 8902220109544, 7896023720833, 7896023733765, 7896023731891, 7891317004286]              |
|591892   |[7896111913789, 7896111996249, 7898593050716, 7896023732935, 731509975000, 7898949409458, 7896111913802]|
|591014   |[7896663325108, 7896023733765]                                                                          |
+---------+--------------------------------

In [33]:
# Configure the FP-Growth estimator: baskets are read from the 'items' column,
# with the minimum support and confidence thresholds set below.
fp_growth = FPGrowth(
    itemsCol = 'items',
    minSupport = 0.0005,
    minConfidence = 0.001,
)

In [34]:
# Fit FP-Growth on the per-sale baskets.
model = fp_growth.fit(dataset = df)

In [35]:
# Frequent itemsets found by the fitted model.
freq_itemsets = model.freqItemsets

In [36]:
# Put the frequency column first, followed by the itemset itself.
freq_itemsets = freq_itemsets.select(col('freq'), col('items'))

In [37]:
# Most frequent itemsets first.
freq_itemsets = freq_itemsets.orderBy(desc('freq'))

In [38]:
# check(freq_itemsets)

In [39]:
# Association rules derived by the fitted model.
rules = model.associationRules

In [40]:
# Strongest rules first: highest confidence, ties broken by highest lift.
rules = rules.orderBy(desc('confidence'), desc('lift'))

In [41]:
# check(rules)

In [42]:
# Apply the fitted model to the per-sale baskets.
transformed = model.transform(dataset = df)

In [43]:
# check(transformed)

In [44]:
# Create the output folder before the first export: pandas' to_csv does not
# create missing parent directories and fails when the target folder is absent
# (e.g. on a fresh checkout).
os.makedirs(output_path, exist_ok = True)
# Collect the frequent itemsets to the driver and export them as CSV without
# the pandas index column.
freq_itemsets.toPandas().to_csv(f"{output_path}/frequent_itemsets.csv", index = False)

In [45]:
# Collect the association rules to the driver and export them as CSV without
# the pandas index column.
rules.toPandas().to_csv(
    path_or_buf = f"{output_path}/association_rules.csv",
    index = False,
)

In [46]:
# Collect the model output for every sale to the driver and export it as CSV
# without the pandas index column.
transformed.toPandas().to_csv(
    path_or_buf = f"{output_path}/transformed.csv",
    index = False,
)