# Apriori Algorithm

In [1]:
import pandas as pd
from mlxtend.frequent_patterns import apriori
from mlxtend.frequent_patterns import association_rules

In [2]:
# 读入原始数据集
df1 = pd.read_excel('./retail.xlsx', sheet_name='Year 2009-2010')
df2 = pd.read_excel('./retail.xlsx', sheet_name='Year 2010-2011')
raw_data = pd.concat([df1, df2], ignore_index=True)
print(f'Number of samples: {raw_data.shape[0]}')

Number of samples: 1067371


## Data Preprocessing

首先去除已经取消的订单（Invoice字段以C开头），然后去除代表邮费的订单（StockCode中包含POST字段）。

In [3]:
# 去除已取消订单；StockCode中的POST表示邮费，将其删除
raw_data['Invoice'] = raw_data['Invoice'].astype(str)
raw_data['StockCode'] = raw_data['StockCode'].astype(str)

data1 = raw_data[~raw_data['Invoice'].str.startswith('C')]
data1 = data1[~(data1['StockCode'] == 'POST')]

print(f'Number of samples: {data1.shape[0]}')

Number of samples: 1045984


接下来统计每个属性的缺失值数量，然后将含有缺失值的样本（行）删除。

In [4]:
# 统计每个属性缺失值数量，处理缺失值（丢弃）
print(data1.isna().sum())

Invoice             0
StockCode           0
Description      4375
Quantity            0
InvoiceDate         0
Price               0
Customer ID    242202
Country             0
dtype: int64


In [5]:
data1 = data1.dropna()
data1.reset_index(drop=True, inplace=True)
print(f'Number of samples: {data1.shape[0]}')

Number of samples: 803782


由于StockCode和Description是一一对应的，因此这里删除所有不是一一对应的情况的商品所对应的行。

In [6]:
unique_count = data1.groupby('StockCode')['Description'].nunique()
multi_stock = unique_count[unique_count > 1].index
data1 = data1[~data1['StockCode'].isin(multi_stock)]
data1.reset_index(drop=True, inplace=True)
print(f'Number of samples: {data1.shape[0]}')
print(data1.StockCode.nunique(), data1.Description.nunique())

Number of samples: 602841
4006 3983


最后替换异常值，统计Quantity和Price两个属性的1%和99%分位数的值，将小于1%分位数的值替换为1%分位数，将大于99%的值替换为99%分位数。

In [7]:
quantity_1 = data1['Quantity'].quantile(0.01)
quantity_99 = data1['Quantity'].quantile(0.99)
price_1 = data1['Price'].quantile(0.01)
price_99 = data1['Price'].quantile(0.99)

quantity_up = quantity_99 + 1.5 * (quantity_99 - quantity_1)
quantity_low = quantity_1 - 1.5 * (quantity_99 - quantity_1)
price_up = price_99 + 1.5 * (price_99 - price_1)
price_low = price_1 - 1.5 * (price_99 - price_1)

quantity_outlier = data1[(data1['Quantity'] < quantity_1) | (data1['Quantity'] > quantity_99)]
price_outlier = data1[(data1['Price'] < price_1) | (data1['Price'] > price_99)]
print(f'quantity outlier: {len(quantity_outlier)}, price outlier: {len(price_outlier)}')

for i in quantity_outlier.index:
    if data1.loc[i, 'Quantity'] <= quantity_1:
        data1.loc[i, 'Quantity'] = quantity_low
    else:
        data1.loc[i, 'Quantity'] = quantity_up

for i in price_outlier.index:
    if data1.loc[i, 'Price'] <= price_1:
        data1.loc[i, 'Price'] = price_low
    else:
        data1.loc[i, 'Price'] = price_up

quantity outlier: 4458, price outlier: 9938


  data1.loc[i, 'Quantity'] = quantity_up


## Preparing Invoice-Product Matrix for ARL Data Structure

这里要将原始数据转化成适合进行关联规则分析的格式，即每行代表一笔交易，后面每一列为一种商品，在这笔交易中存在该商品，则该位置为1，否则为0。

由于原始数据量过大，因此这里首先查看每个国家的数据总量，选择一个国家的数据进行关联规则分析（这里选择France）。

In [8]:
# 每个国家数据量
country_counts = data1['Country'].value_counts()
print(country_counts)

Country
United Kingdom          548514
EIRE                     11788
Germany                  10367
France                    8719
Netherlands               3347
Spain                     2696
Switzerland               2005
Belgium                   1886
Portugal                  1655
Australia                 1223
Channel Islands           1113
Norway                    1103
Italy                      986
Sweden                     944
Cyprus                     870
Finland                    600
Denmark                    578
Austria                    569
Greece                     518
Poland                     412
Unspecified                386
Japan                      357
United Arab Emirates       276
USA                        271
Singapore                  228
Malta                      216
Israel                     200
Iceland                    199
Lithuania                  163
Canada                     161
RSA                         85
Brazil                      76


In [9]:
# 提取出法国数据
data_fr = data1[data1['Country'] == 'France']
print(data_fr.shape)

(8719, 8)


In [10]:
# 格式转化
basket = data_fr.groupby(['Invoice', 'StockCode'])['Quantity'].sum().unstack().reset_index().fillna(0)
basket = basket.set_index('Invoice')
basket[basket > 0] = 1
basket.head()

StockCode,10002,10120,10123C,10123G,10125,10135,11001,15036,15039,15044C,...,90184C,90201B,90201C,90209A,90214C,90214E,90214L,90214S,C2,M
Invoice,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
489439,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
489557,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
489883,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
490139,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
490152,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


## Determination of Association Rules

以法国为例，设置最小支持度为0.01，然后调用 `mlxtend.frequent_patterns` 中的 `apriori` 筛选出满足 `min_support` 的频繁项集及其对应支持度。

In [11]:
# 计算support values，min_support设置为0.01
freq_set = apriori(basket, min_support=0.01, use_colnames=True)
print(freq_set.shape[0])



2900


In [12]:
freq_set.head()

Unnamed: 0,support,itemsets
0,0.052632,(10002)
1,0.010187,(10125)
2,0.061121,(15056BL)
3,0.033956,(15056N)
4,0.016978,(15056P)


接下来调用 `mlxtend.frequent_patterns` 中的 `association_rules`，从刚刚的频繁项集中生成关联规则，最小支持度阈值设置为0.01，然后按照支持度降序排列，输出支持度最高的前5个关联规则。

In [13]:
# 从频繁项集中生成关联规则，评估关联规则的指标为support，最小支持度阈值为0.01
rule = association_rules(freq_set, metric='support', min_threshold=0.01)
rule.head()

Unnamed: 0,antecedents,consequents,antecedent support,consequent support,support,confidence,lift,leverage,conviction,zhangs_metric
0,(20749),(10002),0.088285,0.052632,0.011885,0.134615,2.557692,0.007238,1.094737,0.667997
1,(10002),(20749),0.052632,0.088285,0.011885,0.225806,2.557692,0.007238,1.177632,0.642857
2,(21731),(10002),0.219015,0.052632,0.011885,0.054264,1.031008,0.000357,1.001726,0.038509
3,(10002),(21731),0.052632,0.219015,0.011885,0.225806,1.031008,0.000357,1.008772,0.031746
4,(21791),(10002),0.049236,0.052632,0.011885,0.241379,4.586207,0.009293,1.248804,0.822449


这里查看支持度前五高的关联规则，其中antecendents是前项，consequents是后项，这两项共同组成关联规则。

例如，一个关联规则可以表示为：{A, B} -> {C}，其中{A, B}是前项，{C}是后项。

In [14]:
# 查看支持度最高的前五个关联规则
top5_rule = rule.sort_values(by='support', ascending=False).head(5)
top5_rule

Unnamed: 0,antecedents,consequents,antecedent support,consequent support,support,confidence,lift,leverage,conviction,zhangs_metric
423,(21086),(21094),0.144312,0.132428,0.120543,0.835294,6.307541,0.101432,5.267402,0.983372
422,(21094),(21086),0.132428,0.144312,0.120543,0.910256,6.307541,0.101432,9.534805,0.969902
1556,(22556),(22554),0.17657,0.164686,0.106961,0.605769,3.678331,0.077882,2.118846,0.884274
1557,(22554),(22556),0.164686,0.17657,0.106961,0.649485,3.678331,0.077882,2.349196,0.871693
1439,(22554),(22551),0.164686,0.134126,0.096774,0.587629,4.381182,0.074686,2.099745,0.923905
