In [None]:
import os

from pyspark.ml.fpm import FPGrowth
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, collect_set, desc

In [None]:
# Start (or reuse) the Spark session used by every cell below.
spark = (
    SparkSession.builder
    .appName('main')
    .getOrCreate()
)

In [None]:
# Input CSV with a header row; the grouping step below reads its
# `sale_code` and `barcode` columns.
dataset_path = 'inputs/by-product.csv'
# Directory that receives the exported result CSVs.
output_path = 'outputs/by-product'

# pandas' DataFrame.to_csv does not create missing parent directories, so
# make sure the export target exists before the export cells run.
os.makedirs(output_path, exist_ok=True)

In [None]:
def check(df, n=5):
    """Print a quick sanity summary of a Spark DataFrame.

    Prints the total row count, then the first ``n`` rows rendered without
    column truncation.

    Args:
        df: The DataFrame to inspect.
        n: Number of rows to display (default 5).
    """
    print(f"Count: {df.count()}")
    print('\nElements:')
    # DataFrame.show() prints the table itself and returns None; wrapping it
    # in print() used to emit a stray "None" line after every table.
    df.show(n, truncate=False)

In [None]:
# Load the raw sales export; the first CSV line supplies the column names.
# No schema or inferSchema is given, so every column is read as a string.
df = spark.read.csv(dataset_path, header=True)

In [None]:
# check(df)

In [None]:
# One row per sale holding the set of distinct barcodes in it; FPGrowth
# expects each transaction's items to be unique, which collect_set ensures.
# Name the aggregate explicitly instead of relying on Spark's generated
# column name `collect_set(barcode)`.
df = df.groupBy('sale_code').agg(collect_set('barcode').alias('items'))

In [None]:
# Name the two columns (sale_code, items); FPGrowth reads the 'items'
# column. Renaming by position does not depend on the name Spark generated
# for the aggregate column. withColumnRenamed silently does nothing when the
# old name doesn't match; toDF instead fails immediately if the column count
# is not exactly two.
df = df.toDF('sale_code', 'items')


In [None]:
# Inspect the per-sale baskets before mining them.
check(df=df)

In [None]:
# minSupport: minimum fraction of sales an itemset must appear in (0.05%).
# minConfidence: minimum confidence for an association rule to be kept.
fp_growth = FPGrowth(
    minSupport=0.0005,
    minConfidence=0.001,
    itemsCol='items',
)

In [None]:
# Mine frequent itemsets and association rules from the per-sale baskets.
model = fp_growth.fit(dataset=df)

In [None]:
# Frequent itemsets found by the fitted model (includes 'items' and 'freq'
# columns, selected in the next cell).
freq_itemsets = model.freqItemsets

In [None]:
# Keep only the occurrence count and the itemset, count first.
freq_itemsets = freq_itemsets.select(col('freq'), col('items'))

In [None]:
# Most frequent itemsets first.
freq_itemsets = freq_itemsets.orderBy(desc('freq'))

In [None]:
# check(freq_itemsets)

In [None]:
# Association rules produced by the fitted model (includes 'confidence' and
# 'lift' columns, used for sorting in the next cell).
rules = model.associationRules

In [None]:
# Strongest rules first: highest confidence, ties broken by highest lift.
rules = rules.orderBy(desc('confidence'), desc('lift'))

In [None]:
# check(rules)

In [None]:
# Apply the mined rules to each sale, adding a per-row prediction column.
transformed = model.transform(dataset=df)

In [None]:
# check(transformed)

In [None]:
# Export the sorted frequent itemsets. toPandas() collects the full result
# onto the driver before pandas writes it out.
freq_itemsets.toPandas().to_csv(
    path_or_buf=f"{output_path}/frequent_itemsets.csv",
    index=False,
)

In [None]:
# Export the sorted association rules via pandas (collected on the driver).
rules.toPandas().to_csv(
    path_or_buf=f"{output_path}/association_rules.csv",
    index=False,
)

In [None]:
# Export the per-sale predictions via pandas (collected on the driver).
transformed.toPandas().to_csv(
    path_or_buf=f"{output_path}/transformed.csv",
    index=False,
)