In [4]:
import pandas as pd

class MarketBasket:
    """Run association-rule mining on a market-basket CSV with two libraries.

    Each line of the CSV (no header row) is one transaction and each
    comma-separated field is one purchased item. The constructor runs the
    efficient_apriori implementation (``rule1``) and the
    mlxtend.frequent_patterns implementation (``rule2``) on the same data so
    their results can be compared.
    """

    def __init__(self, path, min_support=0.05, min_confidence=0.25):
        """Load *path* and run both apriori implementations on it.

        Args:
            path: path to the transactions CSV file.
            min_support: minimum support (fraction of transactions) for an
                itemset to count as frequent.
            min_confidence: minimum confidence for an association rule.
        """
        # Keep the results so they can be inspected after construction.
        self.itemsets1, self.rules1 = self.rule1(path, min_support, min_confidence)
        self.itemsets2, self.rules2 = self.rule2(path, min_support, min_confidence)

    @staticmethod
    def _load_transactions(path):
        """Return one set of lower-cased item names per non-empty CSV row.

        The csv module is used instead of a plain ``split(',')`` so quoted
        fields containing commas are parsed correctly. Empty fields (e.g. from
        trailing commas) are dropped so they are not counted as an item.
        Rows may have any number of fields.
        """
        import csv

        transactions = []
        # newline='' is required by the csv module; 'utf-8-sig' reads plain
        # UTF-8 and also strips a leading byte-order mark if one is present.
        with open(path, 'r', newline='', encoding='utf-8-sig') as f:
            for row in csv.reader(f):
                items = {field.lower() for field in row if field != ''}
                if items:  # skip blank lines
                    transactions.append(items)
        return transactions

    @staticmethod
    def _one_hot(transactions):
        """Return a boolean DataFrame: one row per transaction, one column per item.

        Columns are sorted alphabetically. Boolean dtype is the input format
        mlxtend's apriori is designed for.
        """
        columns = sorted(set().union(*transactions))
        return pd.DataFrame(
            [[item in basket for item in columns] for basket in transactions],
            columns=columns,
            dtype=bool,
        )

    def rule1(self, path, min_support=0.05, min_confidence=0.25):
        """Mine itemsets and rules with efficient_apriori, print and return them.

        Returns:
            ``(itemsets, rules)`` exactly as produced by
            ``efficient_apriori.apriori``: a dict keyed by itemset length
            mapping item tuples to counts, and a list of rules.
        """
        from efficient_apriori import apriori
        print('rule1 loading............')
        transactions = self._load_transactions(path)
        # efficient_apriori takes a list of transactions (here: sets of items).
        itemsets, rules = apriori(transactions, min_support=min_support,
                                  min_confidence=min_confidence)
        print('itemsets of rule1：', '\n', itemsets)
        print('rules of rule1：', '\n', rules)
        return itemsets, rules

    def rule2(self, path, min_support=0.05, min_confidence=0.25):
        """Mine itemsets and rules with mlxtend, print and return them.

        Returns:
            ``(itemsets, rules)`` DataFrames, sorted by descending support and
            descending confidence respectively.
        """
        from mlxtend.frequent_patterns import apriori
        from mlxtend.frequent_patterns import association_rules
        print('\n' + 'rule2 loading............')
        # Same parsing as rule1, so both libraries see identical transactions.
        df_encoded = self._one_hot(self._load_transactions(path))
        itemsets = apriori(df_encoded, use_colnames=True, min_support=min_support)
        itemsets = itemsets.sort_values(by='support', ascending=False)
        rules = association_rules(itemsets, metric='confidence',
                                  min_threshold=min_confidence)
        rules = rules.sort_values(by='confidence', ascending=False)
        # Widen the printed column limit only for these prints, without
        # permanently changing the global pandas display options.
        with pd.option_context('display.max_columns', 100):
            print('itemsets of rule2:', '\n', itemsets)
            print('rules of rule2:', '\n', rules)
        return itemsets, rules

if __name__ == '__main__':
    # Machine-specific location of the transactions CSV; edit to run elsewhere.
    data_path = 'C:/Users/Administrator/Desktop/RS/L3Data/Market_Basket_Optimisation.csv'
    # Constructing the object runs both rule-mining passes and prints results.
    mkt_bkt = MarketBasket(data_path)

rule1 loading............
itemsets of rule1： 
 {1: {('burgers',): 654, ('cake',): 608, ('chicken',): 450, ('chocolate',): 1229, ('cookies',): 603, ('cooking oil',): 383, ('eggs',): 1348, ('escalope',): 595, ('french fries',): 1282, ('frozen smoothie',): 475, ('frozen vegetables',): 715, ('grated cheese',): 393, ('green tea',): 991, ('ground beef',): 737, ('low fat yogurt',): 574, ('milk',): 972, ('mineral water',): 1788, ('olive oil',): 494, ('pancakes',): 713, ('shrimp',): 536, ('soup',): 379, ('spaghetti',): 1306, ('tomatoes',): 513, ('turkey',): 469, ('whole wheat rice',): 439}, 2: {('chocolate', 'mineral water'): 395, ('eggs', 'mineral water'): 382, ('mineral water', 'spaghetti'): 448}}
rules of rule1： 
 [{chocolate} -> {mineral water}, {eggs} -> {mineral water}, {spaghetti} -> {mineral water}, {mineral water} -> {spaghetti}]

rule2 loading............
itemsets of rule2: 
      support                    itemsets
16  0.238368             (mineral water)
6   0.179709                