In [23]:
from itertools import combinations

# fake_dataset = [
#     ["f", "c", "a", "m", "p"],
#     ["f", "c", "a", "b", "m"],
#     ["f", "b"],
#     ["c", "b", "p"],
#     ["f", "c", "a", "m", "p"]
# ]

# Sample transactions for the demo cell: each inner list holds the item names
# of one transaction. NOTE: reorder() rewrites this list in place.
fake_dataset = [
    ["milk", "bread", "beer"],
    ["bread", "coffee"],
    ["bread", "egg"],
    ["milk", "bread", "coffee"],
    ["milk", "egg"],
    ["bread", "egg"],
    ["milk", "egg"],
    ["milk", "bread", "egg", "beer"],
    ["milk", "bread", "egg"],
]


class node:
    """One vertex of the FP-tree.

    Attributes:
        item: the item label stored in this vertex (``None`` for the root).
        count: the support counter accumulated for this vertex.
        parent: the parent vertex, or ``None`` while unattached / for the root.
        children: list of child vertices, in insertion order.
    """

    def __init__(self, item, count=1):
        self.item = item
        self.count = count
        self.parent = None
        self.children = list()

    def __str__(self):
        # A node without a parent (the root, or a not-yet-attached node) must
        # still be printable; report its parent item as None instead of
        # raising AttributeError on `None.item`.
        parent_item = None if self.parent is None else self.parent.item
        return f'id={id(self)}, item={self.item}, count={self.count}, parent={parent_item}, children_num={len(self.children)}'


# nx3 list, [0]=Customer ID, [1]=Transaction ID, [2] Item ID
def load_data(filename):
    """Read a comma-separated transaction file into a list of item lists.

    Each non-blank line is split on "," and is expected to carry the
    transaction id in field 1 and the item id in field 2 (field 0 is not
    used). Consecutive rows are collected into one transaction until a row
    with a larger transaction id is seen.

    Args:
        filename: path of the text file to read.

    Returns:
        A list of transactions, each a list of item-id strings, in file order.
    """
    dataset = list()
    with open(filename) as f:
        tid = 1  # highest transaction id seen so far
        temp_item = list()
        for line in f:
            if not line.strip():
                continue  # tolerate blank lines (e.g. a trailing newline)
            fields = line.replace("\n", "").split(",")
            row_tid = int(fields[1])
            if tid < row_tid:
                tid = row_tid
                # Only flush a transaction that actually collected items, so
                # a file whose first id is > 1 does not yield an empty entry.
                if temp_item:
                    dataset.append(temp_item)
                    temp_item = list()
            temp_item.append(fields[2])
        # The loop only flushes when the *next* transaction starts, so the
        # final transaction has to be flushed explicitly.
        if temp_item:
            dataset.append(temp_item)
    return dataset


def first_scan(dataset):
    """Count how many times each item occurs across all transactions.

    Args:
        dataset: iterable of transactions, each an iterable of hashable items.

    Returns:
        A dict mapping item -> occurrence count, keyed in order of first
        appearance.
    """
    support = {}
    for basket in dataset:
        for label in basket:
            support[label] = support.get(label, 0) + 1
    return support


def reorder(dataset, weights, min_sup):
    """Filter and order every transaction in place for FP-tree insertion.

    Items whose weight is below ``min_sup`` are dropped; the remaining items
    are ordered by descending weight. Items with equal weight are ordered by
    their position in ``weights``, so every transaction uses the same global
    order. (Breaking ties by each transaction's own order would let the same
    itemset appear in different orders and split into separate tree paths.)

    Args:
        dataset: list of transactions (lists of items); modified in place.
        weights: dict mapping item -> support count.
        min_sup: minimum support an item needs to be kept.

    Raises:
        KeyError: if a transaction contains an item missing from ``weights``.
    """
    # sorted() is stable, also with reverse=True, so equal-weight items keep
    # their relative order from `weights`.
    by_weight = sorted(weights, key=weights.get, reverse=True)
    rank = {item: position for position, item in enumerate(by_weight)}
    for tid, transaction in enumerate(dataset):
        frequent = [item for item in transaction if weights[item] >= min_sup]
        frequent.sort(key=rank.__getitem__)
        dataset[tid] = frequent
        

def create_tree(dataset, root=None, count=1):
    """Insert every transaction of ``dataset`` into an FP-tree.

    Each transaction is walked from the root: an existing child with the same
    item has its counter increased by ``count``; otherwise a new child node
    is created with that count. Progress is printed after every transaction.

    Args:
        dataset: iterable of transactions (iterables of items).
        root: tree to extend. When omitted, a fresh empty root is created for
            this call (a default created at definition time would be shared,
            so repeated calls would keep growing the same tree).
        count: amount added to a node's counter for each transaction passing
            through it.

    Returns:
        The root node of the (extended) tree.
    """
    if root is None:
        root = node(None)

    for transaction in dataset:
        pre_node = root  # every transaction starts again at the root
        for item in transaction:
            match = next((c for c in pre_node.children if c.item == item), None)
            if match is not None:
                match.count += count
                pre_node = match
            else:  # new node
                print("new node:", item, "parent:", pre_node.item)
                current_node = node(item, count)
                current_node.parent = pre_node
                pre_node.children.append(current_node)
                pre_node = current_node

        print("========================================================")
        show_tree(root)  # current tree after insert one transaction
        print("========================================================")

    return root


def create_header_table(node, header_table=None):
    """Collect, per item, every tree node that carries that item.

    The tree is walked depth-first in pre-order starting at ``node``; nodes
    whose ``item`` is ``None`` (the root) are not recorded.

    Args:
        node: root of the (sub)tree to walk; needs ``item`` and ``children``.
        header_table: dict to fill. When omitted, a new dict is created for
            this call (a default created at definition time would be shared,
            so repeated calls would append duplicates to the same table).

    Returns:
        The dict mapping item -> list of nodes holding that item.
    """
    # Test identity, not truthiness: an explicitly passed empty dict must be
    # the one that gets filled and returned.
    if header_table is None:
        header_table = {}
    if node.item is not None:  # skip root node
        header_table.setdefault(node.item, []).append(node)
    for c in node.children:
        create_header_table(c, header_table)
    return header_table


def find_path(header_table, target):
    """Collect the prefix paths leading to every node of ``target``.

    For each node listed under ``target`` the chain of parents is followed
    until a node whose ``item`` is ``None`` (the root) is reached. The
    ancestor items, ordered from the root downwards, form the path.

    Args:
        header_table: dict mapping item -> list of tree nodes.
        target: item whose prefix paths are wanted.

    Returns:
        A dict mapping path (tuple of items) -> count of the target node at
        the end of that path. Nodes hanging directly below the root have an
        empty path and are left out. An empty dict is returned when
        ``target`` is not in ``header_table``.
    """
    path_dict = dict()
    # Direct lookup instead of scanning every entry; a missing target yields
    # an empty result rather than an implicit None.
    for occurrence in header_table.get(target, ()):
        path = list()
        ancestor = occurrence.parent
        while ancestor.item is not None:  # climb until the root
            path.append(ancestor.item)
            ancestor = ancestor.parent
        if path:  # skip direct children of the root
            path.reverse()
            path_dict[tuple(path)] = occurrence.count
    return path_dict


def mine_tree(path_dict):
    """Build and print the FP-tree made from a set of prefix paths.

    The paths are processed in order of descending count; each path is
    inserted with its count as the increment, so the node counters are the
    sums of the counts of the paths running through them.

    Args:
        path_dict: dict mapping path (tuple of items) -> count. An empty dict
            is accepted and produces an empty tree.
    """
    # sort dict by value
    ordered = sorted(path_dict.items(), key=lambda entry: entry[1], reverse=True)
    # Built separately (rather than via zip(*...)) so an empty path_dict
    # yields two empty tuples instead of failing to unpack.
    combination_dataset = tuple(path for path, _ in ordered)
    counts = tuple(support for _, support in ordered)
    print("@@@combination_dataset@@@", combination_dataset)
    print("@@@counts@@@", counts)

    root = node(None)
    for transaction, count in zip(combination_dataset, counts):
        root = create_tree([transaction], root=root, count=count)
    print("@@@Combination FP tree@@@")
    show_tree(root)
    # header_table = create_header_table(root, dict())
    # print("@@@Combination header table@@@")
    # show_header_table(header_table)


def show_tree(node):
    """Print every node of the tree rooted at ``node``, depth-first.

    A node is printed before its children (pre-order). Nodes whose ``item``
    is ``None`` (the root) are not printed, but their children are.

    Args:
        node: root of the (sub)tree; needs ``item`` and ``children``.
    """
    # Identity test: `!= None` would go through the item's own comparison
    # methods, which is not what "is this the root marker" means.
    if node.item is not None:  # skip root node
        print(node)
    for c in node.children:
        show_tree(c)


def show_header_table(header_table):
    """Print each item of the header table followed by the ids of its nodes.

    Args:
        header_table: dict mapping item -> iterable of objects; for every
            entry the item is printed on an "item:" line and then the
            ``id()`` of each object on its own line.
    """
    for label in header_table:
        print("item:", label)
        for identity in map(id, header_table[label]):
            print(identity)

In [24]:
# Demo: build the FP-tree for fake_dataset and mine the prefix paths of "egg".
# dataset = load_data("gen/output.txt")
min_sup = 2
weights = first_scan(fake_dataset)
print(weights)
print("before ordering:", fake_dataset)
reorder(fake_dataset, weights, min_sup)
print("after ordering:", fake_dataset)

# Pass a fresh root / header table explicitly so re-running this cell never
# adds to a tree or table left over from a previous run.
root = create_tree(fake_dataset, root=node(None))
print("@@@FP tree@@@")
show_tree(root)

header_table = create_header_table(root, dict())
print("@@@header table@@@")
show_header_table(header_table)

# freq_itemset = list()
print("@@@path@@@")
# The target may be missing from the header table (e.g. pruned by a higher
# min_sup); treat "no result" as "no paths" instead of failing on .items().
path_dict = find_path(header_table, "egg") or {}
for k, v in path_dict.items():
    print(k, v)
if path_dict:
    mine_tree(path_dict)

{'milk': 6, 'bread': 7, 'beer': 2, 'coffee': 2, 'egg': 6}
before ordering: [['milk', 'bread', 'beer'], ['bread', 'coffee'], ['bread', 'egg'], ['milk', 'bread', 'coffee'], ['milk', 'egg'], ['bread', 'egg'], ['milk', 'egg'], ['milk', 'bread', 'egg', 'beer'], ['milk', 'bread', 'egg']]
after ordering: [['bread', 'milk', 'beer'], ['bread', 'coffee'], ['bread', 'egg'], ['bread', 'milk', 'coffee'], ['milk', 'egg'], ['bread', 'egg'], ['milk', 'egg'], ['bread', 'milk', 'egg', 'beer'], ['bread', 'milk', 'egg']]
new node: bread parent: None
new node: milk parent: bread
new node: beer parent: milk
id=1817060590688, item=bread, count=1, parent=None, children_num=1
id=1817060589920, item=milk, count=1, parent=bread, children_num=1
id=1817059735824, item=beer, count=1, parent=milk, children_num=0
new node: coffee parent: bread
id=1817060590688, item=bread, count=2, parent=None, children_num=2
id=1817060589920, item=milk, count=1, parent=bread, children_num=1
id=1817059735824, item=beer, count=1, pare