In [1]:
# 分析Market_Basket 物品分类中的频繁项集和关联规则
import pandas as pd
from mlxtend.frequent_patterns import apriori
from mlxtend.frequent_patterns import association_rules

In [2]:
# Load the transactions: one basket per row, one item per cell, no header row.
basket = pd.read_csv('./Market_Basket_Optimisation.csv', header=None)

# One-hot encode the baskets: one column per distinct item, True when the
# basket contains that item. Every cell of a row is joined with '|' (missing
# cells become empty strings, which str.get_dummies discards), then split back
# out into indicator columns.
# Iterating over basket.columns[1:] instead of a hard-coded range(1, 20) keeps
# this correct for files with any number of item columns, and building the
# joined strings in a local Series leaves `basket` itself unmodified (no
# temporary column, no in-place drop), so the cell is safe to re-run.
joined_items = basket[0].str.cat(
    [basket[col] for col in basket.columns[1:]], sep='|', na_rep=''
)
# mlxtend's apriori expects a boolean frame; int 0/1 input triggers a
# performance/deprecation warning in newer versions.
basket_hot_encoded = joined_items.str.get_dummies(sep='|').astype(bool)

# Mine frequent itemsets: keep every itemset whose support (fraction of
# baskets containing it) is at least MIN_SUPPORT.
# NOTE: the previous comment claimed 0.02, but the code has always used 0.03;
# the comment now matches the value actually applied.
MIN_SUPPORT = 0.03
itemsets = apriori(basket_hot_encoded, use_colnames=True, min_support=MIN_SUPPORT)
# Sort by support, most frequent itemsets first.
itemsets = itemsets.sort_values(by="support", ascending=False)
print('-'*20, '频繁项集', '-'*20)
print(itemsets)
print('总共%d项' % len(itemsets))

# Derive association rules from the frequent itemsets, keeping only rules whose
# lift is at least MIN_LIFT (lift > 1 means the antecedent raises the
# likelihood of the consequent above its baseline frequency).
# NOTE: the previous comment claimed a minimum lift of 2, but the code has
# always used 1.5; the comment now matches the value actually applied.
MIN_LIFT = 1.5
rules = association_rules(itemsets, metric='lift', min_threshold=MIN_LIFT)
# Sort by lift, strongest associations first.
rules = rules.sort_values(by="lift", ascending=False)
# To persist the rules, call: rules.to_csv('./rules.csv')
print('-'*20, '关联规则', '-'*20)
print(rules)


-------------------- 频繁项集 --------------------
     support                            itemsets
25  0.238368                     (mineral water)
11  0.179709                              (eggs)
31  0.174110                         (spaghetti)
13  0.170911                      (french fries)
7   0.163845                         (chocolate)
18  0.132116                         (green tea)
24  0.129583                              (milk)
19  0.098254                       (ground beef)
16  0.095321                 (frozen vegetables)
27  0.095054                          (pancakes)
2   0.087188                           (burgers)
4   0.081056                              (cake)
8   0.080389                           (cookies)
12  0.079323                          (escalope)
23  0.076523                    (low fat yogurt)
29  0.071457                            (shrimp)
33  0.068391                          (tomatoes)
26  0.065858                         (olive oil)
15  0.063325          