### TODO
1. Smaller stuff
    1. Proper exceptions
    1. Replace Node.is_leaf with Node.max_depth or similar
1. Optimization
    1. \_\_rmul\_\_
1. Make the Tree indexable
1. Make the Tree iterable?
1. Type hinting
1. Global behavior control

In [1]:
# imports
from copy import copy, deepcopy
from inspect import signature
from math import isclose
from warnings import warn

In [2]:
class Operation():
    """A named two-argument function used to combine outcomes in a tree.

    Args:
        function: a callable whose signature has exactly two parameters.
        name: a plain ``str`` label for the operation.

    Raises:
        TypeError: if ``function`` is not callable, does not have exactly two
            parameters, or ``name`` is not a ``str``.
    """

    def __init__(self, function, name):
        if not callable(function):
            raise TypeError('function must be a callable object')

        # Count the declared parameters; the operation is strictly binary.
        arity = len(signature(function).parameters)
        if arity != 2:
            raise TypeError(f'function must take exactly two inputs, not {arity}')

        if type(name) is not str:
            raise TypeError(f'name must be a string, not {type(name)}')

        self.function, self.name = function, name

    def __call__(self, outcome, node):
        """Combine ``outcome`` with every outcome stored in ``node``.

        Returns:
            A list holding ``function(outcome, other)`` for each of
            ``node``'s outcomes, in order.

        Raises:
            TypeError: if ``node`` is not a ``Node``.
        """
        if isinstance(node, Node):
            return [self.function(outcome, other) for other in node._outcomes]
        raise TypeError(f'node must be a Node, not {type(node)}')

    def __repr__(self):
        return f'<Operation: name={self.name}>'

In [3]:
class Node():
    """A node in a probability tree.

    A node stores parallel lists: ``outcomes``, their ``probabilities``, and,
    for each outcome, an optional child ``Node`` together with the
    ``Operation`` that combines that outcome with every outcome of the child
    when the tree is collapsed (see ``collapse_v``). A child and its
    operation are either both ``None`` or both set.
    """

    # define some general operations
    standard_operations = {
        '+': Operation(lambda x, y: x + y, 'add'),
        # needed by __sub__ (it used to raise KeyError because this was missing)
        '-': Operation(lambda x, y: x - y, 'subtract'),
        '*': Operation(lambda x, y: x * y, 'multiply'),
        'concat': Operation(lambda x, y: [x, y], 'concatenate'),
    }

    # operation paired with children that are given without an explicit operation
    _default_operation = standard_operations['+']

    # Probably a good idea to just write these as global package properties...

    # makes it possible to turn off warnings globally
    @staticmethod
    def _issue_warnings(issue_warnings, ):
        """Globally enable (``True``) or disable (``False``) Node warnings.

        Raises:
            TypeError: if ``issue_warnings`` is not a bool.
        """
        # explicit raise instead of assert: asserts vanish under ``python -O``
        if not isinstance(issue_warnings, bool):
            raise TypeError('Must be a bool')
        Node._warnings_enabled = issue_warnings

    _warnings_enabled = True

    issue_warnings = _issue_warnings

    # makes it possible to control the collapse behavior globally
    @staticmethod
    def _set_collapse_behavior(collapse_behavior, ):
        """Store a global collapse-behavior flag on the class.

        Note: no method in this class reads ``_collapse_behavior`` yet.

        Raises:
            TypeError: if ``collapse_behavior`` is not an int.
        """
        if not isinstance(collapse_behavior, int):
            raise TypeError('Collapse behavior must be an integer')
        Node._collapse_behavior = collapse_behavior

    _collapse_behavior = 0

    set_collapse_behavior = _set_collapse_behavior

    def __init__(self, outcomes, probabilities = None, nodes = None, operations = None):
        """Create a node.

        Args:
            outcomes: sized sequence of outcomes.
            probabilities: one int/float per outcome; defaults to a uniform
                distribution.
            nodes: one child ``Node`` or ``None`` per outcome; defaults to
                all ``None`` (a leaf).
            operations: one ``Operation`` or ``None`` per outcome; defaults to
                ``_default_operation`` wherever a child is given.

        Raises:
            ValueError: if ``outcomes`` is empty and no probabilities are given.
            TypeError: on wrongly typed entries, mismatched lengths, or
                child/operation entries that do not come in pairs.

        Emits a ``UserWarning`` (when warnings are enabled) if the
        probabilities do not sum to 1 within floating-point tolerance.
        """
        if probabilities is None:
            if len(outcomes) == 0:
                raise ValueError('outcomes must not be empty when probabilities are omitted')
            probabilities = [1 / len(outcomes)] * len(outcomes)

        if nodes is None:
            nodes = [None] * len(outcomes)

        if operations is None:
            operations = [None if n is None else Node._default_operation for n in nodes]

        # check validity of inputs (bool subclasses int but is rejected, as before)
        if not all(isinstance(p, (int, float)) and not isinstance(p, bool) for p in probabilities):
            raise TypeError('probabilities must be float or int')

        if not all(n is None or isinstance(n, Node) for n in nodes):
            raise TypeError('nodes must be Node or None')

        if not all(o is None or isinstance(o, Operation) for o in operations):
            raise TypeError('operations must be Operation or None')

        if not (len(outcomes) == len(probabilities) == len(nodes) == len(operations)):
            raise TypeError('all input lists must have the same length')

        if not all((n is None) == (o is None) for n, o in zip(nodes, operations)):
            raise TypeError('nodes and operations must come in pairs')

        # warnings: compare with a tolerance, since e.g. six 1/6 terms sum to
        # 0.9999999999999999 and an exact != 1 test warned for every d(6)
        total = sum(probabilities)
        if Node._warnings_enabled and not isclose(total, 1):
            warn(f'Warning: The total probability is {total:.4f}, not 1', stacklevel=2)

        # create tree
        self._outcomes = outcomes
        self._probabilities = probabilities

        self._nodes = nodes
        self._operations = operations

    def branch_at(self, node, operation, index):
        """Return a copy of this node with ``node``/``operation`` attached at ``index``.

        Raises:
            TypeError: if exactly one of ``node`` and ``operation`` is ``None``,
                or either has the wrong type.
        """
        if (node is None) != (operation is None):
            raise TypeError('nodes and operations must come in pairs')
        if node is not None and not isinstance(node, Node):
            raise TypeError('nodes must be Node or None')
        if operation is not None and not isinstance(operation, Operation):
            raise TypeError('operations must be Operation or None')

        # __copy__ builds fresh lists, so assigning into them leaves self untouched
        out = copy(self)

        # these were previously ``==`` comparisons, which made the method a no-op
        out._nodes[index] = node
        out._operations[index] = operation

        return out

    def branch_all_root(self, node, operation):
        """Return a copy with ``node``/``operation`` attached to every childless outcome of this node only."""
        out = copy(self)

        out._nodes = [node if n is None else n for n in self._nodes]
        out._operations = [operation if o is None else o for o in self._operations]

        return out

    def branch_all_leaf(self, node, operation):
        """Return a copy with ``node``/``operation`` attached at every leaf of the whole tree.

        Existing children are branched recursively; the same ``node`` object
        is referenced at every leaf.
        """
        out = copy(self)

        out._nodes = [node if n is None else n.branch_all_leaf(node, operation) for n in self._nodes]
        out._operations = [operation if o is None else o for o in self._operations]

        return out

    def collapse_v(self, ):
        """Flatten the tree into a single leaf ``Node``.

        Each outcome that has a child is replaced by the results of applying
        its operation to the outcome and each (recursively collapsed) child
        outcome, with probabilities multiplied along the path.
        """
        outcomes = []
        probabilities = []

        for outcome, probability, child, operation in zip(
                self._outcomes, self._probabilities, self._nodes, self._operations):
            if child is None:
                outcomes.append(outcome)
                probabilities.append(probability)
            else:
                collapsed = child.collapse_v()

                outcomes += operation(outcome, collapsed)
                probabilities += [probability * p for p in collapsed._probabilities]

        return Node(outcomes, probabilities)

    def collapse_h(self, ):
        """Placeholder for a not-yet-written collapse variant."""
        # previously returned the sentinel -1, which callers could mistake for a result
        raise NotImplementedError('collapse_h is not implemented yet')

    # NEEDS TO BE REPLACED BY 'depth(self, ) -> int'
    def is_leaf(self, ):
        """Return True if no outcome of this node has a child."""
        return all(n is None for n in self._nodes)

    def __add__(self, other):
        """Return ``self`` with ``other`` attached at every leaf via addition."""
        if not isinstance(other, Node):
            raise TypeError(f'Addition is only supported for Node + Node not {type(self)} + {type(other)}')

        return self.branch_all_leaf(other, Node.standard_operations['+'])

    def __sub__(self, other):
        """Return ``self`` with ``other`` attached at every leaf via subtraction."""
        if not isinstance(other, Node):
            raise TypeError(f'Subtraction is only supported for Node - Node not {type(self)} - {type(other)}')

        return self.branch_all_leaf(other, Node.standard_operations['-'])

    def __mul__(self, other):
        """Return ``self`` with ``other`` attached at every leaf via multiplication."""
        if not isinstance(other, Node):
            raise TypeError(f'Multiplication is only supported for Node * Node or Int * Node not {type(self)} * {type(other)}')

        return self.branch_all_leaf(other, Node.standard_operations['*'])

    def __rmul__(self, other):
        """Return ``other * self``: this leaf node chained ``other`` times with ``+``.

        For example ``3 * d(6)`` branches three dice with addition.

        Raises:
            TypeError: if ``other`` is not an int or this node is not a leaf.
            ValueError: if ``other`` is less than 1.
        """
        if not isinstance(other, int):
            raise TypeError(f'Multiplication is only supported for Node * Node or Int * Node not {type(other)} * {type(self)}')

        if not self.is_leaf():
            raise TypeError('The Node must be a leaf node to be multiplied with an integer')

        if other < 1:
            raise ValueError(f'The integer factor must be at least 1, not {other}')

        out = self
        for _ in range(other - 1):
            # add one copy of self per step; ``out + out`` doubled instead
            out = out + self

        return out

    def __eq__(self, node):
        """Structural equality: same outcomes, probabilities, children and operations."""
        if not isinstance(node, Node):
            return False
        return (list(self._outcomes) == list(node._outcomes)
                and list(self._probabilities) == list(node._probabilities)
                and list(self._nodes) == list(node._nodes)  # recurses via __eq__
                and list(self._operations) == list(node._operations))  # Operation compares by identity

    # defining __eq__ makes instances unhashable; state it explicitly
    __hash__ = None

    def __len__(self, ):
        """Return the number of outcomes at this level."""
        return len(self._outcomes)

    def __copy__(self, ):
        """Copy this node's lists and (recursively) its children; outcome and operation objects are shared."""
        out = Node(copy(self._outcomes),
                   copy(self._probabilities),
                   [copy(n) for n in self._nodes],
                   copy(self._operations))

        return out

    def __deepcopy__(self, memo):
        """Deep-copy every attribute, honoring ``memo``."""
        out = Node(deepcopy(self._outcomes, memo),
                   deepcopy(self._probabilities, memo),
                   deepcopy(self._nodes, memo),
                   deepcopy(self._operations, memo))

        return out

    def __repr__(self, ):
        return f'<Node: id={id(self)}>'

    def __str__(self, ):
        """Render the tree with box-drawing branches, one outcome per line."""
        lines = []
        last = len(self) - 1

        for i, (outcome, probability, child, operation) in enumerate(
                zip(self._outcomes, self._probabilities, self._nodes, self._operations)):
            branch = '└─ ' if i == last else '├─ '
            lines.append(branch + f'{probability:2.2f}' + ' ── ' + f'{repr(outcome):8s}')

            if isinstance(child, Node):  # this enables branching
                offset = '       ' if i == last else '|      '
                lines.append(offset + '| ' + operation.name)
                lines.append(offset + str(child).replace('\n', '\n' + offset))

        # joining avoids the old ``out[:-2]``, which also cut the last real character
        return '\n'.join(lines)

In [4]:
def d(sides):
    """Return a Node modelling a fair die with faces ``1`` to ``sides``.

    Every face gets the Node's default (uniform) probability.

    Args:
        sides: number of faces; must be a positive ``int``.

    Raises:
        TypeError: if ``sides`` is not an int (bools are rejected too).
        ValueError: if ``sides`` is less than 1 (previously this surfaced as
            a ZeroDivisionError, or an empty die, deep inside Node).
    """
    if isinstance(sides, bool) or not isinstance(sides, int):
        raise TypeError(f'sides must be an int, not {type(sides)}')
    if sides < 1:
        raise ValueError(f'a die needs at least one side, not {sides}')
    return Node(range(1, sides + 1))

In [5]:
# build a three-sided die with uniform default probabilities and show its tree
a = d(sides=3)
print(str(a))

├─ 0.33 ── 1       
├─ 0.33 ── 2       
└─ 0.33 ── 3      
