## 数据准备

In [1]:
import numpy as np
# The transactions have different lengths, so the array must be created with
# dtype=object: recent NumPy releases refuse to build an array from ragged
# nested sequences otherwise. The result is a 1-D array of Python lists.
dataset = np.array([[1, 3, 4],
                    [2, 3, 5],
                    [1, 2, 3, 5],
                    [2, 5]], dtype=object)

## 频繁项集

In [2]:
def create_1st_itemset(dataset):
    '''
    Build the candidate 1-itemsets from the raw transactions.

    dataset: iterable of transactions, each an iterable of hashable,
             mutually comparable items.

    Returns a list of single-item frozensets, one per distinct item, in
    ascending item order. frozenset is used because it is hashable and can
    therefore serve as a dictionary key.
    '''
    # A set gives constant-time de-duplication. Transactions may have
    # different lengths, so they are walked one by one rather than being
    # flattened as a rectangular array.
    unique_items = set()
    for transaction in dataset:
        unique_items.update(transaction)

    return [frozenset([item]) for item in sorted(unique_items)]

In [3]:
def itemset_filter(dataset, itemsets, min_sup=0.5):
    '''
    Compute the support of every candidate itemset and keep the frequent ones.

    dataset: sequence of transactions, each an iterable of hashable items.
    itemsets: list of candidate itemsets (frozensets).
    min_sup: minimum support, as a fraction of the number of transactions.

    Returns (freq_sets, sup_dict):
      freq_sets - the candidates whose support is >= min_sup.
      sup_dict  - maps every candidate found in at least one transaction to
                  its support, including those below the threshold.
    '''
    # Drop repeated candidates (keeping first-seen order) so that a duplicate
    # cannot be counted more than once for the same transaction.
    candidates = list(dict.fromkeys(itemsets))

    # First pass: count the transactions that contain each candidate.
    counts = dict()
    for sample in dataset:
        # issubset() converts a non-set argument on every call, so the
        # transaction is turned into a set once instead of once per candidate.
        sample_items = set(sample)
        for itemset in candidates:
            if itemset.issubset(sample_items):
                counts[itemset] = counts.get(itemset, 0) + 1

    len_data = len(dataset)    # support = count / number of transactions
    freq_sets = list()
    sup_dict = dict()

    # Second pass: turn counts into supports and filter by the threshold.
    for itemset, count in counts.items():
        sup = count / len_data
        if sup >= min_sup:
            freq_sets.append(itemset)
        sup_dict[itemset] = sup

    return freq_sets, sup_dict

In [4]:
def extend_itemset(freq_sets):
    '''
    Generate candidate (k+1)-itemsets from the frequent k-itemsets.

    freq_sets: list of frozensets that all contain k items; the items must
               be mutually comparable.

    Returns a list of frozensets, each the union of two input itemsets
    whose first k-1 items, taken in sorted order, are identical.
    '''
    # Set iteration order is arbitrary, so every itemset is put into sorted
    # order before its "head" (all items but the last) is taken. Comparing
    # heads of unsorted listings would miss valid joins or emit the same
    # candidate more than once. The heads are computed once up front rather
    # than inside the pairwise loop.
    heads = [sorted(itemset)[:-1] for itemset in freq_sets]

    res = list()
    raw_size = len(freq_sets)

    for i in range(raw_size):
        for j in range(i + 1, raw_size):    # every unordered pair
            if heads[i] == heads[j]:
                res.append(freq_sets[i] | freq_sets[j])
    return res


# itemsets_1st= create_1st_itemset(dataset)
# freq_sets_1st,_=itemset_filter(dataset, itemsets_1st, 0.5)
# itemsets_2nd = extend_itemset(freq_sets_1st)
# freq_sets_2nd, _ = itemset_filter(dataset, itemsets_2nd, 0.5)
# itemsets_3th = extend_itemset(freq_sets_2nd)
# freq_sets_3th, _ = itemset_filter(dataset, itemsets_3th, 0.5)
# freq_sets_3th

In [5]:
def apriori(dataset, min_sup=0.5):
    '''
    Find the frequent itemsets of every size, level by level.

    dataset: sequence of transactions.
    min_sup: minimum support passed to itemset_filter at every level.

    Returns (all_freq_sets, sup_dict):
      all_freq_sets - list indexed by level; entry k holds the frequent
                      itemsets produced at level k (single items first).
                      The last entry is always an empty list, marking the
                      level at which the search stopped.
      sup_dict      - the support dictionaries returned by itemset_filter
                      at every level, merged into one.
    '''
    # Level 1: single-item candidates filtered by support.
    level_sets, sup_dict = itemset_filter(
        dataset, create_1st_itemset(dataset), min_sup)
    all_freq_sets = [level_sets]

    # Keep raising the order until a level yields no frequent itemset.
    while level_sets:
        candidates = extend_itemset(level_sets)
        level_sets, level_sup = itemset_filter(dataset, candidates, min_sup)
        sup_dict.update(level_sup)
        all_freq_sets.append(level_sets)

    return all_freq_sets, sup_dict


# all_freq_sets, sup_dict = apriori(dataset)

In [6]:
# all_freq_sets

In [7]:
# sup_dict

## 关联规则
有了能计算频繁项集的方法，接下来就是在频繁项集中找到关联规则。为简单起见，首先计算$2$阶频繁项集可能产生的所有规则。

In [8]:
def compute_conf(freq_set, conseqs, sup_dict, min_conf=0.5):
    '''
    Score the rules "(freq_set - consequent) -> consequent" for one itemset.

    freq_set: a single frequent itemset (frozenset).
    conseqs: list of candidate consequents (frozensets).
    sup_dict: mapping from itemset to support; it is indexed with freq_set
              and with every antecedent, so both must be present.
    min_conf: minimum confidence for a rule to be kept.

    Returns (strong_rules, strong_conseqs):
      strong_rules   - (antecedent, consequent, confidence) triples whose
                       confidence is >= min_conf.
      strong_conseqs - the consequents of those rules, in the same order;
                       callers use them to build higher-order consequents.
    '''
    strong_rules = []

    for consequent in conseqs:
        # confidence = support(whole itemset) / support(antecedent)
        confidence = sup_dict[freq_set] / sup_dict[freq_set - consequent]

        if confidence >= min_conf:
            strong_rules.append((freq_set - consequent, consequent, confidence))

    strong_conseqs = [consequent for _, consequent, _ in strong_rules]
    return strong_rules, strong_conseqs


# freq_set = all_freq_sets[1][0]    # 2阶频繁项集列表中的首个项集
# single_conseqs = [frozenset([item])
#                   for item in freq_set]    # 所有可能的1阶后果，用于查询支持度
# compute_conf(freq_set, single_conseqs, sup_dict)

同频繁项集的发现一样，规则也需要扩充，并且是从后果的角度来扩充。由上述函数可以得到所有$1$阶后果的强规则，那么如何得到$2$阶后果甚至更高阶后果的强规则？需要利用```extend_itemset```函数来对已有的后果列表进行升阶，然后在高阶后果列表中取出可能的后果加入已有后果。

In [9]:
# 因为这里使用递归手段，所以有些参数和变量需要从函数中独立出来


def extend_rules(freq_set, conseqs, sup_dict, tol_rules, min_conf=0.5):
    '''
    Recursively look for strong rules whose consequents are one order higher
    than the given ones.

    freq_set: the frequent itemset the rules are derived from.
    conseqs: list of consequents (frozensets) of equal size to extend.
    sup_dict: mapping from itemset to support, passed to compute_conf.
    tol_rules: list that the strong rules found are appended to (in place).
    min_conf: minimum confidence, passed to compute_conf.

    Returns None; results are delivered through tol_rules. The rules for
    the given conseqs themselves are NOT evaluated here, only those for the
    higher-order consequents built from them.
    '''
    if not conseqs:    # nothing to extend; also avoids indexing an empty list
        return

    conseq_dim = len(conseqs[0])    # order (size) of the current consequents

    # A higher-order consequent must still leave at least one item for the
    # antecedent, hence the "+1".
    if len(freq_set) > conseq_dim+1:
        high_level_conseqs = extend_itemset(conseqs)    # raise the consequent order
        cur_rules, high_level_conseqs = compute_conf(
            freq_set, high_level_conseqs, sup_dict, min_conf)
        tol_rules.extend(cur_rules)

        # At least two strong consequents are needed to join them again.
        if len(high_level_conseqs) > 1:
            extend_rules(freq_set, high_level_conseqs, sup_dict,tol_rules, min_conf)


# tol_rules = list()    # 记录所有规则的列表
# freq_set = all_freq_sets[2][0]    # 3阶频繁项集列表中的首个项集
# single_conseqs = [frozenset([item])
#                   for item in freq_set]
# extend_rules(freq_set, single_conseqs, sup_dict, tol_rules)

In [10]:
# tol_rules

接下来的任务就是由低阶往高阶遍历所有频繁项集列表，并发掘强规则。

In [11]:
def find_rules(all_freq_sets, sup_dict, min_conf=0.5):
    '''
    Mine the strong association rules from all frequent itemsets.

    all_freq_sets: list of lists of frequent itemsets, indexed by level
                   (index 0 holds the single-item sets).
    sup_dict: mapping from itemset to support.
    min_conf: minimum confidence for a rule to be reported.

    Returns a list of (antecedent, consequent, confidence) triples.
    '''
    tol_rules = list()
    # Index 0 is skipped: a single item cannot be split into a rule.
    for i in range(1, len(all_freq_sets)):
        for freq_set in all_freq_sets[i]:
            conseqs = [frozenset([item]) for item in freq_set]    # all single-item consequents

            # Rules with a single-item consequent are evaluated for every
            # itemset. extend_rules only scores consequents one order above
            # the ones it is given, so calling it directly would skip them
            # for itemsets of three or more items.
            cur_rules, strong_conseqs = compute_conf(
                freq_set, conseqs, sup_dict, min_conf)
            tol_rules.extend(cur_rules)

            # Larger consequents can only be strong if every single item in
            # them is a strong consequent, so only those are extended, and
            # at least two of them are needed to form a pair.
            if i > 1 and len(strong_conseqs) > 1:
                extend_rules(freq_set, strong_conseqs, sup_dict, tol_rules, min_conf)

    return tol_rules

# find_rules(all_freq_sets,sup_dict,min_conf=0.6)

In [12]:
# all_freq_sets, sup_dict = apriori(dataset,min_sup=0.5)
# find_rules(all_freq_sets,sup_dict,min_conf=0.5)

## Toy dataset

In [13]:
# Each line of the file is split on whitespace into one transaction.
# The with-block closes the file handle as soon as it has been read.
with open('datasets/mushroom/agaricus-lepiota.data') as data_file:
    data = [line.split() for line in data_file]
all_freq_sets, sup_dict = apriori(data,min_sup=0.7)
find_rules(all_freq_sets,sup_dict,min_conf=0.7)

[(frozenset({'34'}), frozenset({'36'}), 0.834217841799343),
 (frozenset({'36'}), frozenset({'34'}), 0.9691720493247211),
 (frozenset({'34'}), frozenset({'85'}), 1.0),
 (frozenset({'85'}), frozenset({'34'}), 0.9741506646971935),
 (frozenset({'34'}), frozenset({'86'}), 0.9989891331817033),
 (frozenset({'86'}), frozenset({'34'}), 0.997728419989904),
 (frozenset({'34'}), frozenset({'90'}), 0.9219105382865808),
 (frozenset({'90'}), frozenset({'34'}), 0.9743589743589743),
 (frozenset({'36'}), frozenset({'85'}), 1.0),
 (frozenset({'85'}), frozenset({'36'}), 0.8385032003938946),
 (frozenset({'86'}), frozenset({'36'}), 0.8354366481574962),
 (frozenset({'36'}), frozenset({'86'}), 0.9718144450968879),
 (frozenset({'36'}), frozenset({'90'}), 0.9489136817381092),
 (frozenset({'90'}), frozenset({'36'}), 0.8632478632478632),
 (frozenset({'86'}), frozenset({'85'}), 1.0),
 (frozenset({'85'}), frozenset({'86'}), 0.9753815854258986),
 (frozenset({'90'}), frozenset({'85'}), 1.0),
 (frozenset({'85'}), froz