<a href="https://colab.research.google.com/github/BlackCurrantDS/DBSE_Project/blob/main/Part1_b)Generate_Rules_FPGrowth.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [10]:
import itertools


class FPNode(object):
    """
    A node in the FP tree.
    """

    def __init__(self, value, count, parent):
        """
        Create the node.
        """
        self.value = value
        self.count = count
        self.parent = parent
        self.link = None
        self.children = []

    def has_child(self, value):
        """
        Check if node has a particular child node.
        """
        for node in self.children:
            if node.value == value:
                return True

        return False

    def get_child(self, value):
        """
        Return a child node with a particular value.
        """
        for node in self.children:
            if node.value == value:
                return node

        return None

    def add_child(self, value):
        """
        Add a node as a child node.
        """
        child = FPNode(value, 1, self)
        self.children.append(child)
        return child


class FPTree(object):
    """
    A frequent pattern tree.
    """

    def __init__(self, transactions, threshold, root_value, root_count):
        """
        Initialize the tree.
        """
        self.frequent = self.find_frequent_items(transactions, threshold)
        self.headers = self.build_header_table(self.frequent)
        self.root = self.build_fptree(
            transactions, root_value,
            root_count, self.frequent, self.headers)

    @staticmethod
    def find_frequent_items(transactions, threshold):
        """
        Create a dictionary of items with occurrences above the threshold.
        """
        items = {}

        for transaction in transactions:
            for item in transaction:
                if item in items:
                    items[item] += 1
                else:
                    items[item] = 1

        for key in list(items.keys()):
            if items[key] < threshold:
                del items[key]

        return items

    @staticmethod
    def build_header_table(frequent):
        """
        Build the header table.
        """
        headers = {}
        for key in frequent.keys():
            headers[key] = None

        return headers

    def build_fptree(self, transactions, root_value,
                     root_count, frequent, headers):
        """
        Build the FP tree and return the root node.
        """
        root = FPNode(root_value, root_count, None)

        for transaction in transactions:
            sorted_items = [x for x in transaction if x in frequent]
            sorted_items.sort(key=lambda x: frequent[x], reverse=True)
            if len(sorted_items) > 0:
                self.insert_tree(sorted_items, root, headers)

        return root

    def insert_tree(self, items, node, headers):
        """
        Recursively grow FP tree.
        """
        first = items[0]
        child = node.get_child(first)
        if child is not None:
            child.count += 1
        else:
            # Add new child.
            child = node.add_child(first)

            # Link it to header structure.
            if headers[first] is None:
                headers[first] = child
            else:
                current = headers[first]
                while current.link is not None:
                    current = current.link
                current.link = child

        # Call function recursively.
        remaining_items = items[1:]
        if len(remaining_items) > 0:
            self.insert_tree(remaining_items, child, headers)

    def tree_has_single_path(self, node):
        """
        If there is a single path in the tree,
        return True, else return False.
        """
        num_children = len(node.children)
        if num_children > 1:
            return False
        elif num_children == 0:
            return True
        else:
            return True and self.tree_has_single_path(node.children[0])

    def mine_patterns(self, threshold):
        """
        Mine the constructed FP tree for frequent patterns.
        """
        if self.tree_has_single_path(self.root):
            return self.generate_pattern_list()
        else:
            return self.zip_patterns(self.mine_sub_trees(threshold))

    def zip_patterns(self, patterns):
        """
        Append suffix to patterns in dictionary if
        we are in a conditional FP tree.
        """
        suffix = self.root.value

        if suffix is not None:
            # We are in a conditional tree.
            new_patterns = {}
            for key in patterns.keys():
                new_patterns[tuple(sorted(list(key) + [suffix]))] = patterns[key]

            return new_patterns

        return patterns

    def generate_pattern_list(self):
        """
        Generate a list of patterns with support counts.
        """
        patterns = {}
        items = self.frequent.keys()

        # If we are in a conditional tree,
        # the suffix is a pattern on its own.
        if self.root.value is None:
            suffix_value = []
        else:
            suffix_value = [self.root.value]
            patterns[tuple(suffix_value)] = self.root.count

        for i in range(1, len(items) + 1):
            for subset in itertools.combinations(items, i):
                pattern = tuple(sorted(list(subset) + suffix_value))
                patterns[pattern] = \
                    min([self.frequent[x] for x in subset])

        return patterns

    def mine_sub_trees(self, threshold):
        """
        Generate subtrees and mine them for patterns.
        """
        patterns = {}
        mining_order = sorted(self.frequent.keys(),
                              key=lambda x: self.frequent[x])

        # Get items in tree in reverse order of occurrences.
        for item in mining_order:
            suffixes = []
            conditional_tree_input = []
            node = self.headers[item]

            # Follow node links to get a list of
            # all occurrences of a certain item.
            while node is not None:
                suffixes.append(node)
                node = node.link

            # For each occurrence of the item, 
            # trace the path back to the root node.
            for suffix in suffixes:
                frequency = suffix.count
                path = []
                parent = suffix.parent

                while parent.parent is not None:
                    path.append(parent.value)
                    parent = parent.parent

                for i in range(frequency):
                    conditional_tree_input.append(path)

            # Now we have the input for a subtree,
            # so construct it and grab the patterns.
            subtree = FPTree(conditional_tree_input, threshold,
                             item, self.frequent[item])
            subtree_patterns = subtree.mine_patterns(threshold)

            # Insert subtree patterns into main patterns dictionary.
            for pattern in subtree_patterns.keys():
                if pattern in patterns:
                    patterns[pattern] += subtree_patterns[pattern]
                else:
                    patterns[pattern] = subtree_patterns[pattern]

        return patterns


def find_frequent_patterns(transactions, support_threshold):
    """
    Given a set of transactions, find the patterns in it
    over the specified support threshold.
    """
    tree = FPTree(transactions, support_threshold, None, None)
    return tree.mine_patterns(support_threshold)


def generate_association_rules(patterns, confidence_threshold):
    """
    Given a set of frequent itemsets, return a dict
    of association rules in the form
    {(left): ((right), confidence)}
    """
    rules = {}
    for itemset in patterns.keys():
        upper_support = patterns[itemset]

        for i in range(1, len(itemset)):
            for antecedent in itertools.combinations(itemset, i):
                antecedent = tuple(sorted(antecedent))
                consequent = tuple(sorted(set(itemset) - set(antecedent)))

                if antecedent in patterns:
                    lower_support = patterns[antecedent]
                    confidence = float(upper_support) / lower_support

                    if confidence >= confidence_threshold:
                        rules[antecedent] = (consequent, confidence)

    return rules

In [11]:
import pandas as pd
df = pd.read_csv('/content/breast_train_transactions.txt', header = None, sep="," , names=['a','b','c', 'd', 'e', 'f','g', 'h', 'i', 'k'])

In [12]:
df.head(5)

Unnamed: 0,a,b,c,d,e,f,g,h,i,k
0,class@no,a1@30-39,a2@premeno,a3@30-34,a4@0-2,a5@no,a6@3,a7@left,a8@left_low,a9@no
1,class@no,a1@40-49,a2@premeno,a3@20-24,a4@0-2,a5@no,a6@2,a7@right,a8@right_up,a9@no
2,class@no,a1@40-49,a2@premeno,a3@20-24,a4@0-2,a5@no,a6@2,a7@left,a8@left_low,a9@no
3,class@no,a1@60-69,a2@ge40,a3@15-19,a4@0-2,a5@no,a6@2,a7@right,a8@left_up,a9@no
4,class@no,a1@40-49,a2@premeno,a3@0-4,a4@0-2,a5@no,a6@2,a7@right,a8@right_low,a9@no


In [13]:
trans = df.values.tolist()

In [14]:
patterns = find_frequent_patterns(trans, .01)

In [15]:
patterns

{('a8@?',): 1,
 ('a3@30-34', 'a8@?'): 1,
 ('a6@3', 'a8@?'): 1,
 ('a8@?', 'class@yes'): 1,
 ('a1@50-59', 'a8@?'): 1,
 ('a2@ge40', 'a8@?'): 1,
 ('a7@left', 'a8@?'): 1,
 ('a4@0-2', 'a8@?'): 1,
 ('a8@?', 'a9@no'): 1,
 ('a5@no', 'a8@?'): 1,
 ('a3@30-34', 'a6@3', 'a8@?'): 1,
 ('a3@30-34', 'a8@?', 'class@yes'): 1,
 ('a1@50-59', 'a3@30-34', 'a8@?'): 1,
 ('a2@ge40', 'a3@30-34', 'a8@?'): 1,
 ('a3@30-34', 'a7@left', 'a8@?'): 1,
 ('a3@30-34', 'a4@0-2', 'a8@?'): 1,
 ('a3@30-34', 'a8@?', 'a9@no'): 1,
 ('a3@30-34', 'a5@no', 'a8@?'): 1,
 ('a6@3', 'a8@?', 'class@yes'): 1,
 ('a1@50-59', 'a6@3', 'a8@?'): 1,
 ('a2@ge40', 'a6@3', 'a8@?'): 1,
 ('a6@3', 'a7@left', 'a8@?'): 1,
 ('a4@0-2', 'a6@3', 'a8@?'): 1,
 ('a6@3', 'a8@?', 'a9@no'): 1,
 ('a5@no', 'a6@3', 'a8@?'): 1,
 ('a1@50-59', 'a8@?', 'class@yes'): 1,
 ('a2@ge40', 'a8@?', 'class@yes'): 1,
 ('a7@left', 'a8@?', 'class@yes'): 1,
 ('a4@0-2', 'a8@?', 'class@yes'): 1,
 ('a8@?', 'a9@no', 'class@yes'): 1,
 ('a5@no', 'a8@?', 'class@yes'): 1,
 ('a1@50-59', 'a2@ge

In [16]:
#writing patterns
with open('fp_growth_freqpatterns', 'w') as f:
  for key, value in patterns.items():
      lhs= [''.join(map(str, x)) for x in key]
      rhs=value
      print(lhs,":",rhs,file=f) 



In [17]:
#removing brackets
f = open('/content/fp_growth_freqpatterns','r')
a = ["[","]","'"," "]
lst = []
for line in f:
    for word in a:
        if word in line:
            line = line.replace(word,'')
    lst.append(line)
f.close()
f = open('/content/fp_growth_freqpatterns','w')
for line in lst:
    f.write(line)
f.close()

In [18]:
rules = generate_association_rules(patterns, 0.8)
rules

{('a8@?',): (('a1@50-59',
   'a2@ge40',
   'a3@30-34',
   'a4@0-2',
   'a5@no',
   'a6@3',
   'a7@left',
   'a9@no',
   'class@yes'),
  1.0),
 ('a3@30-34',
  'a8@?'): (('a1@50-59',
   'a2@ge40',
   'a4@0-2',
   'a5@no',
   'a6@3',
   'a7@left',
   'a9@no',
   'class@yes'), 1.0),
 ('a6@3',
  'a8@?'): (('a1@50-59',
   'a2@ge40',
   'a3@30-34',
   'a4@0-2',
   'a5@no',
   'a7@left',
   'a9@no',
   'class@yes'), 1.0),
 ('a8@?',
  'class@yes'): (('a1@50-59',
   'a2@ge40',
   'a3@30-34',
   'a4@0-2',
   'a5@no',
   'a6@3',
   'a7@left',
   'a9@no'), 1.0),
 ('a1@50-59',
  'a8@?'): (('a2@ge40',
   'a3@30-34',
   'a4@0-2',
   'a5@no',
   'a6@3',
   'a7@left',
   'a9@no',
   'class@yes'), 1.0),
 ('a2@ge40',
  'a8@?'): (('a1@50-59',
   'a3@30-34',
   'a4@0-2',
   'a5@no',
   'a6@3',
   'a7@left',
   'a9@no',
   'class@yes'), 1.0),
 ('a7@left',
  'a8@?'): (('a1@50-59',
   'a2@ge40',
   'a3@30-34',
   'a4@0-2',
   'a5@no',
   'a6@3',
   'a9@no',
   'class@yes'), 1.0),
 ('a4@0-2',
  'a8@?'): (('a1@5

In [19]:
with open('fpgrowth_out.txt', 'w') as f:
    print(rules, file=f)

In [20]:
#all rules which has class labels@yes in rhs and no class labels in lhs
for key, value in rules.items():
  if ('class@no' or  'class@yes') not in key and len(value[0])==1 and "class@yes" in value[0] : 
    print(key)
    print(value)

('a1@50-59', 'a2@ge40', 'a3@30-34', 'a4@0-2', 'a5@no', 'a6@3', 'a7@left', 'a9@no')
(('class@yes',), 1.0)
('a1@50-59', 'a2@ge40', 'a3@30-34', 'a4@0-2', 'a5@no', 'a6@3', 'a7@left', 'a8@?', 'a9@no')
(('class@yes',), 1.0)
('a2@ge40', 'a3@20-24', 'a5@yes', 'a8@left_low')
(('class@yes',), 1.0)
('a2@ge40', 'a3@20-24', 'a5@yes', 'a7@left', 'a8@left_low')
(('class@yes',), 1.0)
('a3@20-24', 'a6@3', 'a7@left', 'a8@left_low', 'a9@yes')
(('class@yes',), 1.0)
('a3@20-24', 'a5@yes', 'a6@3', 'a7@left', 'a8@left_low', 'a9@yes')
(('class@yes',), 1.0)
('a2@ge40', 'a3@20-24', 'a5@yes', 'a6@3', 'a7@left', 'a8@left_low')
(('class@yes',), 1.0)
('a2@ge40', 'a3@20-24', 'a6@3', 'a7@left', 'a8@left_low', 'a9@yes')
(('class@yes',), 1.0)
('a2@ge40', 'a5@yes', 'a6@3', 'a7@left', 'a8@left_low', 'a9@yes')
(('class@yes',), 1.0)
('a1@60-69', 'a2@ge40', 'a6@3', 'a7@left', 'a8@left_low', 'a9@yes')
(('class@yes',), 1.0)
('a2@ge40', 'a3@20-24', 'a5@yes', 'a6@3', 'a7@left', 'a8@left_low', 'a9@yes')
(('class@yes',), 1.0)
('a

In [21]:
#all rules which has class labels@no in rhs and no class labels in lhs
for key, value in rules.items():
  if ('class@no' or  'class@yes') not in key and len(value[0])==1 and  "class@no" in value[0] : 
    print(key)
    print(value)

('a1@40-49', 'a2@premeno', 'a3@15-19', 'a5@no', 'a6@3', 'a7@right')
(('class@no',), 1.0)
('a2@premeno', 'a3@15-19', 'a5@no', 'a6@3', 'a7@right', 'a9@yes')
(('class@no',), 1.0)
('a1@40-49', 'a2@premeno', 'a3@15-19', 'a5@no', 'a6@3', 'a7@right', 'a9@yes')
(('class@no',), 1.0)
('a1@40-49', 'a2@premeno', 'a5@no', 'a6@3', 'a7@right', 'a8@right_low', 'a9@yes')
(('class@no',), 1.0)
('a1@40-49', 'a2@premeno', 'a3@15-19', 'a5@no', 'a6@3', 'a7@right', 'a8@right_low', 'a9@yes')
(('class@no',), 1.0)
('a1@40-49', 'a2@premeno', 'a3@15-19', 'a4@12-14', 'a5@no', 'a6@3', 'a7@right', 'a8@right_low', 'a9@yes')
(('class@no',), 1.0)
('a1@70-79', 'a2@ge40', 'a4@0-2', 'a5@no', 'a6@1', 'a9@no')
(('class@no',), 1.0)
('a2@ge40', 'a3@40-44', 'a4@0-2', 'a5@no', 'a6@1', 'a7@right', 'a9@no')
(('class@no',), 1.0)
('a2@ge40', 'a3@40-44', 'a4@0-2', 'a5@no', 'a6@1', 'a7@right', 'a8@right_up', 'a9@no')
(('class@no',), 1.0)
('a1@70-79', 'a2@ge40', 'a3@40-44', 'a4@0-2', 'a5@no', 'a6@1', 'a7@right', 'a8@right_up', 'a9@no')

In [22]:
#write to file
#all rules which has class labels@yes in rhs and no class labels in lhs
ls = [">"]
with open('class@no_fpgrowthout.txt', 'w') as f:
  for key, value in rules.items():
    if ('class@no' or  'class@yes') not in key and len(value[0])==1 and "class@no" in value[0] :
      lhs= [''.join(map(str, x)) for x in key]
      rhs=''.join(map(str, value[0]))
      print(lhs,">",rhs.replace(",",""),file=f) 

In [23]:
#removing brackets
f = open('/content/class@no_fpgrowthout.txt','r')
a = ["[","]","'"," "]
lst = []
for line in f:
    for word in a:
        if word in line:
            line = line.replace(word,'')
    lst.append(line)
f.close()
f = open('/content/class@no_fpgrowthout.txt','w')
for line in lst:
    f.write(line)
f.close()

In [24]:
#for class@yes


In [25]:
#write to file
#all rules which has class labels@yes in rhs and no class labels in lhs
ls = [">"]
with open('class@yes_fpgrowthout.txt', 'w') as f:
  for key, value in rules.items():
    if ('class@no' or  'class@yes') not in key and len(value[0])==1 and "class@yes" in value[0] :
      lhs= [''.join(map(str, x)) for x in key]
      rhs=''.join(map(str, value[0]))
      print(lhs,">",rhs.replace(",",""),file=f) 

In [26]:
#removing brackets
f = open('/content/class@yes_fpgrowthout.txt','r')
a = ["[","]","'"," "]
lst = []
for line in f:
    for word in a:
        if word in line:
            line = line.replace(word,'')
    lst.append(line)
f.close()
f = open('/content/class@yes_fpgrowthout.txt','w')
for line in lst:
    f.write(line)
f.close()