In [3]:
import numpy as np
from hdi import hpd_grid
from scipy.stats import beta

In [47]:
class Tree_Inference_Object():
    """Binary-tree credible intervals for the probabilities of K multinomial outcome types.

    The outcome types (by position) are split recursively into a left and a right half.
    At every split a beta distribution is built from the discounted counts on each side
    plus one pseudocount per outcome type; the beta models the probability of going right.
    An interval for that probability is obtained from ``hdi.hpd_grid`` on random deviates,
    and each outcome's interval is the product of the intervals along its root-to-leaf path.

    TODO: support non-uniform priors (every outcome currently contributes one pseudocount).
    """

    def __init__(self, observation_array, labels, confidence, granularity):
        """
        Args:
            observation_array: numpy array of observed frequency counts, one per outcome
                type (indexed with integer arrays, so it must be an ndarray).
            labels: one name per outcome type.
            confidence: overall confidence level for the intervals, e.g. 0.99.
            granularity: number of beta deviates drawn per split for the density estimate.
        """
        # Every frequency count needs a name.
        assert len(observation_array) == len(labels)

        self.observation_array = observation_array  # observed frequency counts
        self.labels = labels                        # names for each outcome type
        self.confidence = confidence                # overall confidence for the intervals
        self.granularity = granularity              # beta deviates drawn per split
        self.level = 1                              # depth while building the tree; root = 1

        # interval_array, credmass_per_cell, outcome_size and tree_height all derive from
        # the observations and are computed in one place.
        self._refresh_derived_state()

    def _refresh_derived_state(self):
        """Recompute every attribute that depends on ``observation_array``.

        Sets ``interval_array`` (zeroed, shape (K, 2): lower/upper bound per outcome),
        ``credmass_per_cell`` (per-split mass whose product over the K - 1 splits of the
        tree equals ``confidence``), ``outcome_size`` (K) and ``tree_height`` (ceil(log2 K)).
        Fewer than two outcomes are tolerated here so that ``construct_tree`` reports the
        problem with a clear ValueError instead of this raising ZeroDivisionError.
        """
        outcome_count = len(self.observation_array)
        self.interval_array = np.zeros(shape=(outcome_count, 2))
        self.credmass_per_cell = self.confidence ** (1 / max(outcome_count - 1, 1))
        self.outcome_size = outcome_count
        self.tree_height = np.ceil(np.log2(outcome_count)) if outcome_count > 0 else 0.0

    def update_obs_array(self, new_obs_array):
        """Replace the frequency counts and reset every dependent attribute.

        The intervals are zeroed, so ``construct_tree`` must be run again before sampling.
        Labels are untouched; call ``update_labels_array`` afterwards if K changed.
        """
        self.observation_array = new_obs_array
        self._refresh_derived_state()

    def update_labels_array(self, new_labels):
        """Replace the label names; there must be exactly one per frequency count."""
        assert len(new_labels) == len(self.observation_array)
        self.labels = new_labels

    def __str__(self):
        """Summarize the labels and the observed counts."""
        return "labels {} \n observations {} \n  \n".format(self.labels, self.observation_array)

    @staticmethod
    def _normalize(weights):
        """Scale non-negative ``weights`` so they sum to 1 (for ``np.random.choice``).

        Raises:
            ValueError: if the weights do not sum to a positive number, e.g. because
                ``construct_tree`` has not been run and every interval is still zero.
        """
        total = np.sum(weights)
        # `not total > 0` also catches NaN.
        if not total > 0:
            raise ValueError("Interval weights sum to {}; run construct_tree() first.".format(total))
        return weights / total

    def return_normalized_interval_sample(self, num_samples):
        """Draw ``num_samples`` labels from a randomly perturbed multinomial.

        For each outcome a probability is drawn uniformly from within its interval; the
        draws are normalized to sum to 1 and used as multinomial weights. The weights are
        redrawn on every call, so repeated calls use different probabilities.
        """
        drawn_probs = np.array([np.random.uniform(lower, upper)
                                for lower, upper in self.interval_array])
        return np.random.choice(a=self.labels, size=num_samples, p=self._normalize(drawn_probs))

    def get_interval_sample(self, number_of_samples = 50, interval_type = "mix"):
        """Draw labels using multinomial weights derived from the interval edges.

        Args:
            number_of_samples: how many labels to return.
            interval_type: case-insensitive; one of
                "lower"    - weights are the lower interval edges,
                "upper"    - weights are the upper interval edges,
                "midrange" - weights are the interval midpoints,
                "mix"      - each sample independently picks one of the three options
                             above with equal probability.
        Returns:
            numpy array of ``number_of_samples`` labels.
        """
        interval_type = interval_type.lower()
        assert interval_type in ["lower", "upper", "midrange", "mix"]

        def weights_for(kind):
            lower_edges = self.interval_array[:, 0]
            upper_edges = self.interval_array[:, 1]
            if kind == "lower":
                return self._normalize(lower_edges)
            if kind == "upper":
                return self._normalize(upper_edges)
            return self._normalize((lower_edges + upper_edges) / 2)

        if interval_type != "mix":
            return np.random.choice(a=self.labels, size=number_of_samples,
                                    p=weights_for(interval_type))

        # "mix": precompute the three weight vectors once instead of once per sample.
        options = ["lower", "upper", "midrange"]
        option_weights = [weights_for(kind) for kind in options]

        # Size the dtype from the labels themselves; a bare `str` dtype holds a single
        # character and would truncate every label.
        samples = np.empty(shape=number_of_samples, dtype=np.asarray(self.labels).dtype)
        # randint's upper bound is exclusive, so it must be 3 for "midrange" to occur.
        choices = np.random.randint(low=0, high=len(options), size=number_of_samples)
        for position, choice in enumerate(choices):
            samples[position] = np.random.choice(a=self.labels, size=1, p=option_weights[choice])[0]
        return samples

    def __recursive_branching(self, valid_range, level = None):
        """Walk one random path down the tree and return the label of the leaf reached.

        Private: called only by itself and ``get_tree_sample``. At each node a probability
        of going right is drawn from the node's beta distribution, then a direction is
        drawn with that probability.

        Args:
            valid_range: integer positions of the outcomes under the current node.
            level: depth of the current node; defaults to ``self.level`` (the root). It is
                incremented on every step down so the count discount matches the one
                ``construct_tree`` applies at the same depth.
        """
        # NOTE: adapted from an uncited source -- TODO add the reference.
        if level is None:
            level = self.level

        midpoint = np.floor(np.mean([valid_range[0], valid_range[-1]]))
        left_range = self.__get_inclusive_range(valid_range[0], midpoint)
        right_range = self.__get_inclusive_range(midpoint + 1, valid_range[-1])

        # Counts are discounted by distance from the bottom of the tree.
        discount = 2 ** (level - self.tree_height)
        b = np.sum(self.observation_array[left_range]) * discount
        a = np.sum(self.observation_array[right_range]) * discount

        # One pseudocount per outcome type on each side; the deviate is P(go right).
        frozen_beta_dist = beta(a=(a + len(right_range)), b=(b + len(left_range)))
        beta_deviate = frozen_beta_dist.rvs(1).item()
        direction = np.random.choice(a=["left", "right"], size=1,
                                     p=[1 - beta_deviate, beta_deviate])[0]

        chosen_range = left_range if direction == "left" else right_range
        if len(chosen_range) == 1:
            return self.labels[chosen_range.item()]
        return self.__recursive_branching(chosen_range, level + 1)

    def get_tree_sample(self, num_samples):
        """Return a list of ``num_samples`` labels, each from an independent random walk down the tree."""
        # np.arange yields integer indices directly (the removed np.int alias is not needed).
        valid_range = np.arange(len(self.observation_array))
        return [self.__recursive_branching(valid_range) for _ in range(num_samples)]

    def scramble_order(self):
        """Randomly permute the outcome order, keeping counts, labels and intervals aligned.

        Outcome ``i`` moves to position ``permutation[i]``. Because the tree splits outcomes
        by position, a new order changes which outcomes share a branch. Counts keep their
        original dtype.
        """
        permutation = np.random.permutation(len(self.observation_array))

        new_obs_array = np.zeros_like(self.observation_array)
        new_obs_array[permutation] = self.observation_array

        new_labels = ["empty"] * len(self.labels)
        for old_position, new_position in enumerate(permutation):
            new_labels[new_position] = self.labels[old_position]

        new_interval_array = np.zeros(self.interval_array.shape)
        new_interval_array[permutation] = self.interval_array

        self.observation_array = new_obs_array
        self.labels = new_labels
        self.interval_array = new_interval_array

    def __get_inclusive_range(self, left_indice, right_indice):
        """Return the integer positions left_indice..right_indice, both ends included.

        Accepts integer-valued floats (e.g. results of np.floor).
        """
        # The right index must never be less than the left one.
        assert right_indice >= left_indice
        return np.arange(int(left_indice), int(right_indice) + 1)

    def display_results(self):
        """Print the overall confidence and each outcome's lower/upper interval bound."""
        print("With {}% confidence, the intervals are:".format(self.confidence * 100))
        print("--------------------------------------")

        for counter, label in enumerate(self.labels):
            lower_val, upper_val = self.interval_array[counter]
            print("{} \n lower {} \n upper  {}".format(label, lower_val, upper_val))
            print("\n")

    def _hpd_interval(self, deviates, roundto):
        """Return [lower, upper] of the credible region ``hpd_grid`` finds for ``deviates``.

        ``hpd_grid`` returns its region(s) as a list of tuples of edges; these are flattened
        and collapsed to one [min, max] envelope, so a multi-part region is reported as the
        span covering all of its parts.

        NOTE(review): ``credmass_per_cell`` is passed as hpd_grid's ``alpha``; confirm the
        ``hdi`` module treats ``alpha`` as the credible mass rather than as 1 - mass.
        """
        regions, _, _, _ = hpd_grid(sample=deviates, alpha=self.credmass_per_cell, roundto=roundto)
        endpoints = [edge for region in regions for edge in region]
        return np.array([min(endpoints), max(endpoints)])

    def _store_interval(self, index, interval):
        """Write ``interval`` into row ``index`` of ``interval_array`` as (lower, upper)."""
        self.interval_array[index, 0] = min(interval)
        self.interval_array[index, 1] = max(interval)

    def __handle_edge_case(self):
        """Compute the intervals directly for problems with two or three outcome types.

        Raises:
            ValueError: for fewer than two outcome types, or if reached with more than three.
        """
        outcome_count = len(self.observation_array)

        if outcome_count <= 1:
            raise ValueError("number of categories should be at least 2")

        if outcome_count == 2:
            # A single split: outcome 0 on the left, outcome 1 on the right.
            b = 1 + self.observation_array[0]
            a = 1 + self.observation_array[1]
            right_interval = self._hpd_interval(beta(a=a, b=b).rvs(self.granularity), roundto=6)
            self._store_interval(0, 1 - right_interval[::-1])
            self._store_interval(1, right_interval)
            return

        if outcome_count == 3:
            # Root split: outcomes {0, 1} on the left (two pseudocounts), outcome 2 on the right.
            # NOTE(review): unlike construct_tree, this root split applies no
            # 2 ** (level - tree_height) count discount -- confirm that is intended.
            b = 2 + np.sum(self.observation_array[[0, 1]])
            a = 1 + self.observation_array[2]
            root_interval = self._hpd_interval(beta(a=a, b=b).rvs(self.granularity), roundto=6)
            root_inverse_interval = 1 - root_interval[::-1]
            self._store_interval(2, root_interval)

            # Second split of the left branch: outcome 0 versus outcome 1.
            self.level += 1
            try:
                discount = 2 ** (self.level - self.tree_height)
                b = (1 + self.observation_array[0]) * discount
                a = (1 + self.observation_array[1]) * discount
                child_interval = self._hpd_interval(beta(a=a, b=b).rvs(self.granularity), roundto=6)
            finally:
                self.level -= 1
            child_inverse_interval = 1 - child_interval[::-1]

            # Both leaves of the left branch are scaled by the probability of taking it.
            self._store_interval(1, child_interval * root_inverse_interval)
            self._store_interval(0, child_inverse_interval * root_inverse_interval)
            return

        raise ValueError("edge-case handler reached with {} categories; expected at most 3".format(outcome_count))

    def construct_tree(self, valid_range = None, previous_interval = None):
        """Compute the credible interval of every outcome and store it in ``interval_array``.

        Call with no arguments; it recurses on each half of the outcome range. Problems
        with three or fewer outcome types are handled by a dedicated routine.

        Args:
            valid_range: integer positions of the outcomes under the current node
                (None = root, i.e. all outcomes).
            previous_interval: [lower, upper] interval for the probability of reaching the
                current node (None = root, treated as [1, 1]).
        Returns:
            An empty tuple; results are written to ``interval_array``.
        Raises:
            ValueError: if there are fewer than two outcome types.
        """
        outcome_count = len(self.observation_array)
        if outcome_count <= 3:
            self.__handle_edge_case()
            return ()

        if valid_range is None:
            valid_range = np.arange(outcome_count)
        if previous_interval is None:
            # The root is reached with certainty.
            previous_interval = np.array([1.0, 1.0])

        # Split the node's outcomes into a left half and a (possibly shorter) right half.
        midpoint = np.floor(np.mean([valid_range[0], valid_range[-1]]))
        left_range = self.__get_inclusive_range(valid_range[0], midpoint)
        right_range = self.__get_inclusive_range(midpoint + 1, valid_range[-1])

        # Discount counts by distance from the bottom of the tree: factor 1 at depth
        # tree_height, halving for every level closer to the root.
        discount = 2 ** (self.level - self.tree_height)
        b = np.sum(self.observation_array[left_range]) * discount
        a = np.sum(self.observation_array[right_range]) * discount

        # One pseudocount per outcome type on each side; the beta models P(go right).
        frozen_beta_dist = beta(a=(a + len(right_range)), b=(b + len(left_range)))
        beta_deviates = frozen_beta_dist.rvs(self.granularity)

        # Splits whose children are both leaves round to 6 places, all others to 4.
        both_leaves = len(left_range) == 1 and len(right_range) == 1
        current_interval = self._hpd_interval(beta_deviates, roundto=6 if both_leaves else 4)
        current_inverse_interval = 1 - current_interval[::-1]

        # P(child) = P(reach this node) * P(branch | this node), for both interval edges.
        # This applies to every child, including a lone leaf on a one-sided split.
        children = ((left_range, previous_interval * current_inverse_interval),
                    (right_range, previous_interval * current_interval))

        self.level += 1
        try:
            for child_range, child_interval in children:
                if len(child_range) == 1:
                    self._store_interval(child_range.item(), child_interval)
                else:
                    self.construct_tree(child_range, child_interval)
        finally:
            # Restore the depth even if a recursive call raises.
            self.level -= 1
        return ()


In [49]:

# Fix the global NumPy seed: the beta draws (scipy uses NumPy's global random state by
# default) and the np.random sampling in later cells then reproduce on Restart & Run All.
np.random.seed(42)

test_obj = Tree_Inference_Object(observation_array=np.array([3, 40, 61, 59]),
                                 labels=["Black", "Brown", "Blonde", "Red"],
                                 confidence=0.99, granularity=10**4)

test_obj.construct_tree()
test_obj.display_results()

With 99.0% confidence, the intervals are:
--------------------------------------
Black 
 lower 0.001051338600000002 
 upper  0.099821082


Brown 
 lower 0.11544318299999996 
 upper  0.4246149244


Blonde 
 lower 0.21447885239999998 
 upper  0.5453589688000001


Red 
 lower 0.2048894352 
 upper  0.5311289706000001




In [50]:
tree_samples = test_obj.get_tree_sample(100)
tree_samples

['Red',
 'Blonde',
 'Blonde',
 'Blonde',
 'Blonde',
 'Black',
 'Blonde',
 'Blonde',
 'Red',
 'Red',
 'Blonde',
 'Brown',
 'Red',
 'Red',
 'Red',
 'Blonde',
 'Blonde',
 'Blonde',
 'Red',
 'Red',
 'Red',
 'Red',
 'Blonde',
 'Red',
 'Blonde',
 'Blonde',
 'Blonde',
 'Blonde',
 'Red',
 'Red',
 'Red',
 'Red',
 'Red',
 'Blonde',
 'Blonde',
 'Blonde',
 'Red',
 'Brown',
 'Red',
 'Red',
 'Brown',
 'Red',
 'Red',
 'Brown',
 'Blonde',
 'Red',
 'Blonde',
 'Brown',
 'Red',
 'Blonde',
 'Red',
 'Red',
 'Red',
 'Blonde',
 'Red',
 'Red',
 'Brown',
 'Brown',
 'Red',
 'Red',
 'Blonde',
 'Red',
 'Brown',
 'Red',
 'Blonde',
 'Black',
 'Brown',
 'Blonde',
 'Red',
 'Brown',
 'Blonde',
 'Red',
 'Brown',
 'Red',
 'Brown',
 'Red',
 'Red',
 'Blonde',
 'Brown',
 'Red',
 'Red',
 'Red',
 'Brown',
 'Brown',
 'Brown',
 'Brown',
 'Red',
 'Brown',
 'Brown',
 'Brown',
 'Red',
 'Red',
 'Blonde',
 'Red',
 'Blonde',
 'Blonde',
 'Blonde',
 'Red',
 'Brown',
 'Blonde']

In [54]:
# TODO: experiment with non-uniform priors -- each split currently adds one pseudocount per outcome type on each side.