In [3]:
import math
import graphviz
import math
import copy
import sys
import arff
import pprint

# Shared pretty-printer for inspecting nested structures in the notebook;
# depth=6 limits how many nesting levels are expanded when printing.
pp = pprint.PrettyPrinter(depth = 6)

------
### Algorytm TABU
<img src="img/tabu.png" style="height: 350px">

In [8]:
class TABUAlgorithm(object):
    """Greedy, tabu-guided search for a network structure over a set of attributes.

    A network is a list of node dicts shaped like
    ``{'name': ..., 'r': ..., 'parents': [{'name': ..., 'q': ...}, ...]}``.
    ``'r'`` is copied from the attribute's ``'states'`` entry and a parent's
    ``'q'`` is copied from that parent's ``'r'`` (presumably the number of
    states of the variable -- confirm against the metric classes).

    Every iteration adds the single best-scoring admissible edge
    ``parent -> child``. The chosen parent is then put on the tabu list, which
    holds at most ``tabu_length`` of the most recently used parents.
    """

    def __init__(self, attributes, sample_data, scoring_method, max_number_of_parents, number_of_iterations):
        """Store the search configuration.

        Args:
            attributes: sized collection of dicts, each with ``'name'`` and
                ``'states'`` keys.
            sample_data: data set handed unchanged to the scoring metric.
            scoring_method: one of ``'aic'``, ``'mdl'`` or ``'bayes'``.
            max_number_of_parents: upper bound on parents per node.
            number_of_iterations: number of edge-adding steps to attempt.
        """
        self.attributes = attributes
        self.sample_data = sample_data
        self.scoring_method = scoring_method
        self.max_number_of_parents = max_number_of_parents
        self.number_of_iterations = number_of_iterations
        # Maximum number of entries kept on the tabu list.
        self.tabu_length = len(attributes) - 1

    def compute_optimal_net(self):
        """Run the search from an edgeless network and return the final network."""
        network = self.prepare_initial_network()
        tabu_list = []
        iteration_counter = 0

        while iteration_counter < self.number_of_iterations:
            iteration_counter += 1
            network, tabu_list = self.find_optimal_connection(network, tabu_list)

        return network

    def find_optimal_connection(self, network, tabu_list):
        """Add the best-scoring admissible edge to a copy of ``network``.

        An edge ``parent -> child`` is admissible when the parent is not tabu,
        the child is below the parent limit, the edge does not exist yet and
        adding it keeps the graph acyclic. All of these are evaluated against
        the incoming ``network``, never against a candidate produced during
        the scan, so the outcome does not depend on the scan order.

        Args:
            network: current network (not modified).
            tabu_list: list of tabu parent names; updated in place.

        Returns:
            ``(best_network, tabu_list)``. ``best_network`` is an unchanged
            deep copy of ``network`` when no admissible edge improves on the
            initial ``-inf`` score.
        """
        best_network = copy.deepcopy(network)
        best_parent = None
        best_score = float("-inf")

        for parent in network:
            parent_name = parent['name']
            if self.is_in_tabu(parent_name, tabu_list):
                continue

            for child in network:
                child_name = child['name']
                if child_name == parent_name:
                    continue
                if self.has_max_parents(child_name, network):
                    continue
                if self.check_if_node_already_has_parent(child_name, parent_name, network):
                    continue
                if self._would_create_cycle(parent_name, child_name, network):
                    continue

                # add_parent_to_node mutates its argument, so every candidate
                # gets its own copy of the base network.
                candidate_network = self.add_parent_to_node(child_name, parent, copy.deepcopy(network))
                score = self.compute_metric(candidate_network)

                if score > best_score:
                    best_score = score
                    best_parent = parent_name
                    best_network = candidate_network

        # Only real moves are remembered; a placeholder entry would just push
        # genuine entries out of the bounded list.
        if best_parent is not None:
            tabu_list.append(best_parent)
        self._trim_tabu_list(tabu_list)
        return (best_network, tabu_list)

    def _trim_tabu_list(self, tabu_list):
        """Drop the oldest entries in place so at most ``tabu_length`` remain."""
        overflow = len(tabu_list) - max(self.tabu_length, 0)
        if overflow > 0:
            del tabu_list[:overflow]

    def _would_create_cycle(self, parent_name, child_name, net):
        """Return True if adding ``parent_name -> child_name`` closes a directed cycle.

        That is the case exactly when ``child_name`` is ``parent_name`` itself
        or one of its ancestors, so the ``'parents'`` links are walked upwards
        starting at ``parent_name``.
        """
        parents_of = {node['name']: [p['name'] for p in node['parents']] for node in net}
        pending = [parent_name]
        visited = set()

        while pending:
            current = pending.pop()
            if current == child_name:
                return True
            if current in visited:
                continue
            visited.add(current)
            pending.extend(parents_of.get(current, ()))
        return False

    def add_parent_to_node(self, node_name, parent, net):
        """Append ``parent`` to the parent list of ``node_name`` in ``net`` (in place).

        Returns the same ``net`` object for convenience.
        """
        for node in net:
            if node['name'] == node_name:
                node['parents'].append({'name': parent['name'], 'q': parent['r']})
        return net

    def is_in_tabu(self, node_name, tabu):
        """Return True if ``node_name`` is currently on the tabu list."""
        return node_name in tabu

    def has_max_parents(self, node_name, net):
        """Return True if ``node_name`` already has ``max_number_of_parents`` parents.

        Raises:
            ValueError: if ``net`` contains no node named ``node_name``.
        """
        for node in net:
            if node['name'] == node_name:
                return len(node['parents']) >= self.max_number_of_parents
        raise ValueError("{} is not found in given net".format(node_name))

    def prepare_initial_network(self):
        """Build the edgeless starting network, one node per attribute."""
        return [
            {'r': attribute['states'], 'name': attribute['name'], 'parents': []}
            for attribute in self.attributes
        ]

    def check_if_node_already_has_parent(self, node_name, parent_name, net):
        """Return True if ``parent_name`` is listed among the parents of ``node_name``."""
        return any(
            parent['name'] == parent_name
            for node in net
            if node['name'] == node_name
            for parent in node['parents']
        )

    def compute_metric(self, net):
        """Score ``net`` against the sample data with the configured metric.

        Raises:
            ValueError: if ``scoring_method`` is not ``'aic'``, ``'mdl'`` or ``'bayes'``.
        """
        if self.scoring_method == 'aic':
            return AICMetric().compute_aic_metric(net, self.sample_data)
        elif self.scoring_method == 'mdl':
            return MDLMetric().compute_mdl_metric(net, self.sample_data)
        elif self.scoring_method == 'bayes':
            return BayesianMetric(net, self.sample_data).compute_bayesian_metric()

        raise ValueError("{} is not a valid scoring method!".format(self.scoring_method))