In [5]:
from typing import List, Dict, FrozenSet

from graphviz import Digraph

In [6]:
def parse_data(data: List[str]) -> Dict[FrozenSet[str], int]:
    """
    Collapse raw transactions into distinct item sets with occurrence counts.

    Every row is converted to a frozenset, which drops repeated items inside
    the row and makes the row usable as a dictionary key. Rows that reduce to
    the same item set are merged into one entry.

    Parameters
    ----------
    data: [transactions:items]
        The transactions, each one an iterable of hashable items.

    Returns
    -------
    result:
        A dictionary mapping each distinct item set to the number of rows
        that reduced to it.
    """
    occurrences = {}
    for transaction in data:
        # A frozenset is unordered: it removes duplicates but does not sort.
        itemset = frozenset(transaction)
        if itemset in occurrences:
            occurrences[itemset] += 1
        else:
            occurrences[itemset] = 1
    return occurrences

In [7]:
# First sample data set: six transactions, compressed into
# {item set: occurrence count} form by parse_data.
data = parse_data([
    ['r', 'z', 'h', 'j', 'p'],
    ['z', 'y', 'x', 'w', 'v', 'u', 't', 's'],
    ['z'],
    ['r', 'x', 'n', 'o', 's'],
    ['y', 'r', 'x', 'z', 'q', 't', 'p'],
    ['y', 'z', 'x', 'e', 'q', 's', 't', 'm'],
])
data

{frozenset({'h', 'j', 'p', 'r', 'z'}): 1,
 frozenset({'s', 't', 'u', 'v', 'w', 'x', 'y', 'z'}): 1,
 frozenset({'z'}): 1,
 frozenset({'n', 'o', 'r', 's', 'x'}): 1,
 frozenset({'p', 'q', 'r', 't', 'x', 'y', 'z'}): 1,
 frozenset({'e', 'm', 'q', 's', 't', 'x', 'y', 'z'}): 1}

In [8]:
class Node:
    """
    One node of an FP-tree.

    Attributes
    ----------
    item:
        The item this node stands for.
    count: int
        The count accumulated on this node.
    children: dict
        Child nodes, keyed by their item.
    parent:
        The node this one hangs under (None when there is none).
    node_link:
        The next node in this node's chain of linked nodes, or None at the
        end of the chain.
    """
    # Class-level default; every instance sets its own count in __init__.
    count = 0

    def __init__(self, item, parent, count=1):
        self.item = item
        self.count = count
        self.children = {}
        self.parent = parent
        self.node_link = None

    def incr(self, count: int):
        """Add count to this node's count."""
        self.count += count

    def print(self, depth=0, spacing=2):
        """
        Print this node and all of its descendants, one node per line.

        Parameters
        ----------
        depth: int
            Nesting level of this node; the root call uses 0.
        spacing: int
            Number of spaces of indentation per nesting level.
        """
        print(' ' * depth * spacing, str(self))
        for child in self.children.values():
            # Hand spacing down so the whole subtree uses the same indent
            # width, not just the first level.
            child.print(depth + 1, spacing)

    def __str__(self):
        return f'{self.item}:{self.count}'

    def last(self):
        """
        Follow the node_link pointers and return the last node of the chain.

        Returns this node itself when it has no node_link.
        """
        curr = self
        while curr.node_link is not None:
            curr = curr.node_link
        return curr

#     def print_graphviz(self):
#         dot = Digraph()
#         queue = [self]

#         while len(queue) > 0:
#             edges, queue = queue[0], queue[1:]
#             if edges is None: continue
#             for edge in edges.children.values():
#                 dot.edge(str(edges), str(edge))
#                 queue.append(edge)

#         return dot

In [9]:
class FrequentItemHeader:
    """
    Header table that goes with an FP-tree.

    It keeps the frequent items passed in at construction time and, per
    item, the first node registered for that item. Later nodes with the
    same item are appended to that first node's node_link chain.
    """
    def __init__(self, frequent_items):
        self.node_links = {}
        self.frequent_items = frequent_items

    def link(self, node):
        """
        Register node under its item.

        The first node seen for an item becomes the head of the chain; any
        further node is attached behind the current last node of the chain.
        A node of None is ignored.
        """
        if node is None:
            return

        item = node.item
        if item not in self.node_links:
            # First node for this item: it starts the chain.
            self.node_links[item] = node
            return

        tail = self.node_links[item].last()
        tail.node_link = node

In [10]:
def construct_fp_tree(data, min_support=3):
    """
    Build an FP-tree and its header table from compressed transactions.

    Parameters
    ----------
    data: dict
        Maps each item set to the number of times it occurs.
    min_support: int
        Items whose total count is below this value are left out.

    Returns
    -------
    root, frequent_item_header:
        The root node of the tree (item None) and the FrequentItemHeader
        holding the frequent item list and the node links.
    """
    # Total up how often every item occurs over all transactions.
    totals = {}
    for transaction, occurrences in data.items():
        for item in transaction:
            totals[item] = totals.get(item, 0) + occurrences

    # Keep only the items that reach the minimum support, ordered by
    # descending count; ties are broken by the item itself.
    frequent_item_list = sorted(
        ((item, total) for item, total in totals.items()
         if total >= min_support),
        key=lambda pair: (-pair[1], pair[0]))

    # The header chains together the nodes that carry the same item.
    frequent_item_header = FrequentItemHeader(frequent_item_list)

    root = Node(None, None)
    for transaction, occurrences in data.items():
        # Descend from the root along the transaction's frequent items,
        # taken in frequent-item-list order (most frequent first).
        parent = root
        for item, _total in frequent_item_list:
            if item not in transaction:
                continue
            child = parent.children.get(item)
            if child is None:
                # New branch: create the node, then node-link it in the header.
                child = Node(item, parent, occurrences)
                parent.children[item] = child
                frequent_item_header.link(child)
            else:
                # Shared prefix: just add this transaction's count.
                child.incr(occurrences)
            parent = child
    return root, frequent_item_header

In [11]:
def ascend_tree(node):
    """
    Collect the items met while climbing from node towards the top.

    The list starts with node's own item. Climbing stops at the first node
    that has no parent, and that node's item is not included.
    """
    collected = []
    current = node

    while current and current.parent:
        collected.append(current.item)
        current = current.parent

    return collected
    
def traverse_node_link(node):
    """
    Gather the conditional patterns for a chain of linked nodes.

    Starting at node and following node_link until it runs out, each node
    contributes the set of items above it (its own item excluded) together
    with its count. Equal item sets have their counts added up.

    Returns a dictionary mapping each such item set to its summed count.
    """
    totals = {}
    current = node

    while current is not None:
        # ascend_tree puts the node's own item first; skip it.
        prefix_path = frozenset(ascend_tree(current)[1:])
        totals.setdefault(prefix_path, 0)
        totals[prefix_path] += current.count
        current = current.node_link

    return totals

# Start mining from the bottom of the tree.
def mine_tree(header, tree, prefix_base=None, result=None, minimum_support=3):
    """
    Recursively mine the frequent item sets reachable from a header table.

    Parameters
    ----------
    header: FrequentItemHeader
        Header of the tree being mined; supplies the frequent items and the
        node links.
    tree: Node
        Root of the tree that belongs to header. It is not read here and is
        kept for interface compatibility.
    prefix_base: set, optional
        The items already fixed on the way down; a fresh empty set is used
        when omitted. It is copied, never modified.
    result: list, optional
        List the found item sets are appended to; a fresh list is created
        when omitted.
    minimum_support: int
        Minimum support used for every conditional tree, at every depth.

    Returns
    -------
    result: list
        The list holding all item sets found (the same object as the
        result argument when one was passed in).
    """
    # Create the containers per call; default values written in the
    # signature would be shared by every call that relies on them.
    if prefix_base is None:
        prefix_base = set()
    if result is None:
        result = []

    # Traverse from the bottom - the ones with the least count.
    prefixes = sorted(header.frequent_items,
                      key=lambda t: t[1])

    for prefix, _count in prefixes:
        # Create a copy so that the original prefix is not affected.
        prefixset = prefix_base.copy()
        prefixset.add(prefix)

        # Found a pattern.
        result.append(prefixset)

        node_link = header.node_links[prefix]
        conditional_patterns = traverse_node_link(node_link)
        subtree, subheader = construct_fp_tree(conditional_patterns, minimum_support)
        # Pass minimum_support on so deeper levels do not fall back to the
        # default threshold.
        mine_tree(subheader, subtree, prefixset, result, minimum_support)

    return result

In [12]:
# Build the FP-tree and its header table from the parsed transactions,
# using construct_fp_tree's default minimum support.
tree, header = construct_fp_tree(data)
# tree.print_graphviz()

In [13]:
# Print the tree, one node per line as item:count, indented by depth.
tree.print()

 None:1
   z:5
     r:1
     x:3
       s:2
         t:2
           y:2
       r:1
         t:1
           y:1
   x:1
     r:1
       s:1


In [14]:
# The (item, count) pairs kept by construct_fp_tree, highest count first.
header.frequent_items

[('z', 5), ('x', 4), ('r', 3), ('s', 3), ('t', 3), ('y', 3)]

In [15]:
# Mine the tree built above; the found item sets are appended to result.
result = []
mine_tree(header, tree, set(), result)

# Reference item sets the mined result is compared against.
target = [
    {'r'},
    {'s'}, {'s', 'x'},
    {'y'}, {'x', 'y'}, {'y', 'z'}, {'x', 'y', 'z'},
    {'t'}, {'t', 'x'}, {'t', 'y'}, {'t', 'x', 'y'},
    {'t', 'z'}, {'t', 'x', 'z'}, {'t', 'y', 'z'}, {'t', 'x', 'y', 'z'},
    {'x'}, {'x', 'z'},
    {'z'},
]

def compare_sets(a, b):
    """
    Check that every element of a also occurs in b.

    The check is one-directional: b may contain elements that are not in a.
    An empty a is contained in anything, so the result is True.

    Parameters
    ----------
    a:
        Iterable of elements to look for.
    b:
        Container the elements are looked up in.

    Returns
    -------
    bool
        True when no element of a is missing from b.
    """
    # Test membership itself; filtering on membership and then testing the
    # truthiness of what is left would accept elements missing from b.
    return all(i in b for i in a)
# Compare the mined item sets with the reference list; the cell shows the
# boolean outcome.
compare_sets(result, target) 

True

In [16]:
# Second sample data set: compress the rows, then build and mine the
# FP-tree with the default minimum support.
data = parse_data([
    ['a', 'b'],
    ['b', 'c', 'd'],
    ['a', 'c', 'd', 'e'],
    ['a', 'd', 'e'],
    ['a', 'b', 'c'],
    ['a', 'b', 'c', 'd'],
    ['b', 'c'],
    ['a', 'b', 'c'],
    ['a', 'b', 'd'],
    ['b', 'c', 'e'],
])
tree, header = construct_fp_tree(data)
result = []
mine_tree(header, tree, set(), result)
result

# Reference item sets the mined result is compared against.
# NOTE(review): some entries (e.g. {'d', 'e'}, which occurs in only two of
# the rows above) fall below the default minimum support of 3, so this list
# looks broader than what the mining run above can produce - confirm which
# support threshold it was written for.
target = [
    {'e'}, {'d', 'e'}, {'a', 'd', 'e'}, {'c', 'e'}, {'a', 'e'},
    {'d'}, {'c', 'd'}, {'b', 'c', 'd'}, {'a', 'c', 'd'},
    {'b', 'd'}, {'a', 'b', 'd'}, {'a', 'd'},
    {'c'}, {'b', 'c'}, {'a', 'b', 'c'}, {'a', 'c'},
    {'b'}, {'a', 'b'},
    {'a'},
]
compare_sets(result, target)

True

In [17]:
# Print the tree for the second data set, one node per line as item:count.
tree.print()

 None:1
   b:8
     a:5
       c:3
         d:1
       d:1
     c:3
       d:1
       e:1
   a:2
     c:1
       d:1
         e:1
     d:1
       e:1
