In [1]:
# imports
from inspect import signature

from warnings import warn

from copy import copy, deepcopy

In [10]:
class Operation:
    """A named two-argument function used to combine branch outcomes.

    Calling an Operation with ``(outcome, node)`` applies ``function`` to
    ``outcome`` and each outcome stored in ``node`` and returns the results
    as a list, in the order of ``node``'s outcomes.

    Parameters
    ----------
    function : callable
        Must accept exactly two parameters.
    name : str
        Human-readable label, shown by ``repr`` and in Node tree drawings.

    Raises
    ------
    AssertionError
        If ``function`` is not callable, does not take exactly two
        parameters, or ``name`` is not a string.
    """

    def __init__(self, function, name):
        assert callable(function), (
            'The function needs to be callable')
        assert len(signature(function).parameters) == 2, (
            'The function need to take exactly two parameters')
        # isinstance (not an exact type check) so str subclasses are accepted
        assert isinstance(name, str), (
            'The name must be a string')

        self.function = function
        self.name = name

    def __call__(self, outcome, node):
        """Return ``[function(outcome, o) for o in node's outcomes]``."""
        return [self.function(outcome, other) for other in node._outcomes]

    def __repr__(self):
        return f'<Operation: name={self.name}>'

In [55]:
class Node:
    """A probability-tree node: weighted outcomes, each optionally branching.

    Position ``i`` holds outcome ``_outcomes[i]`` with probability
    ``_probabilities[i]``. If ``_nodes[i]`` is a child ``Node``, that branch
    is combined with the child's outcomes through ``_operations[i]`` when the
    tree is collapsed. A leaf has ``None`` in both ``_nodes[i]`` and
    ``_operations[i]``.

    Parameters
    ----------
    outcomes : sequence
        Outcomes at this level; must support ``len()`` and iteration
        (e.g. a list or a range).
    probabilities : sequence of int/float, optional
        One probability per outcome; defaults to a uniform distribution.
    nodes : sequence of Node or None, optional
        One child (or None) per outcome; defaults to all leaves.
    operations : sequence of Operation or None, optional
        One operation per child; defaults to the '+' operation for every
        child that is given.

    Raises
    ------
    AssertionError
        If the inputs have different lengths, contain elements of the wrong
        type, or pair a child with no operation (or an operation with no
        child).
    """

    # Operations available by symbol; the arithmetic dunders look them up here.
    standard_operations = {
        '+': Operation(lambda x, y: x + y, 'add'),
        '-': Operation(lambda x, y: x - y, 'subtract'),
        '*': Operation(lambda x, y: x * y, 'multiply'),
        'concat': Operation(lambda x, y: [x, y], 'concatenate'),
    }

    # Operation attached to children passed without explicit operations.
    _default_operation = standard_operations['+']

    # Absolute slack allowed when checking that probabilities sum to 1, so
    # float rounding alone does not trigger the warning.
    _probability_tolerance = 1e-9

    # Global switch for the total-probability warning (see issue_warnings).
    _warnings_enabled = True

    @staticmethod
    def issue_warnings(issue_warnings):
        """Globally enable (True) or disable (False) Node construction warnings."""
        Node._warnings_enabled = issue_warnings

    # Alias kept for code that used the original private name.
    _issue_warnings = issue_warnings

    def __init__(self, outcomes, probabilities=None, nodes=None, operations=None):
        if probabilities is None:
            # uniform distribution over the outcomes
            probabilities = [1 / len(outcomes) for _ in outcomes]

        if nodes is None:
            nodes = [None for _ in outcomes]

        if operations is None:
            operations = [None if n is None else Node._default_operation
                          for n in nodes]

        # check validity of inputs
        assert all(isinstance(p, (int, float)) for p in probabilities), (
            'probabilities should be float or int')
        assert all(n is None or isinstance(n, Node) for n in nodes), (
            'nodes must be Node objects or None')
        assert all(o is None or isinstance(o, Operation) for o in operations), (
            'operations must be Operations objects or None')

        assert (len(outcomes) == len(probabilities) and
                len(outcomes) == len(nodes) and
                len(outcomes) == len(operations)), (
            'all inputs need to have the same lenght')

        assert all((n is None) == (o is None) for n, o in zip(nodes, operations)), (
            'nodes must be accompanied by an operation and the other way around')

        # warnings: compare with a tolerance, float sums rarely hit 1 exactly
        total = sum(probabilities)
        if Node._warnings_enabled and abs(total - 1) > Node._probability_tolerance:
            warn(f'Warning: The total probability is {total:.4f}, not 1')

        # create tree
        self._outcomes = outcomes
        self._probabilities = probabilities

        self._nodes = nodes
        self._operations = operations

    def branch_at(self, node, operation, index):
        """Return a copy of this node with ``node`` attached at ``index``.

        ``operation`` combines the outcome at ``index`` with ``node``'s
        outcomes. Passing ``None`` for both turns that position into a leaf.
        This node is left unchanged.
        """
        nodes = [copy(n) for n in self._nodes]
        operations = list(self._operations)
        nodes[index] = node
        operations[index] = operation
        # rebuild through the constructor so the new branch is validated
        return Node(copy(self._outcomes), copy(self._probabilities),
                    nodes, operations)

    def branch_all_root(self, node, operation):
        """Return a copy with ``node`` attached to every leaf of this level only.

        Existing children are kept as-is (shared with this node); each newly
        attached branch uses ``operation``.
        """
        nodes = [node if n is None else n for n in self._nodes]
        operations = [operation if o is None else o for o in self._operations]
        return Node(copy(self._outcomes), copy(self._probabilities),
                    nodes, operations)

    def branch_all(self, node, operation):
        """Return a copy with ``node`` attached below every leaf of the tree.

        Recurses into existing children; each new branch uses ``operation``.
        The same ``node`` object is attached at every leaf (it is not copied).
        """
        nodes = [node if n is None else n.branch_all(node, operation)
                 for n in self._nodes]
        operations = [operation if o is None else o for o in self._operations]
        return Node(copy(self._outcomes), copy(self._probabilities),
                    nodes, operations)

    def collapse_v(self):
        """Flatten the tree into a single-level Node.

        Every root-to-leaf path becomes one outcome. A branch's outcome is
        combined with each collapsed child outcome via the branch's operation,
        and its probability is multiplied by the child's probability. Equal
        outcomes reached by different paths stay as separate entries.
        """
        outcomes = []
        probabilities = []

        for outcome, probability, child, operation in zip(
                self._outcomes, self._probabilities, self._nodes, self._operations):
            if child is None:
                outcomes.append(outcome)
                probabilities.append(probability)
            else:
                flat = child.collapse_v()
                outcomes += operation(outcome, flat)
                probabilities += [probability * p for p in flat._probabilities]

        return Node(outcomes, probabilities)

    def collapse_h(self):
        """Placeholder: horizontal collapsing is not implemented.

        TODO: presumably meant to merge equal outcomes on one level (summing
        their probabilities). Confirm the intent before implementing.
        """
        raise NotImplementedError('collapse_h is not implemented yet')

    def __add__(self, node):
        """Attach ``node`` below every leaf, combining outcomes with '+'."""
        return self.branch_all(node, Node.standard_operations['+'])

    def __sub__(self, node):
        """Attach ``node`` below every leaf, combining outcomes with '-'."""
        return self.branch_all(node, Node.standard_operations['-'])

    def __mul__(self, node):
        """Attach ``node`` below every leaf, combining outcomes with '*'."""
        return self.branch_all(node, Node.standard_operations['*'])

    def __eq__(self, node):
        """Structural equality: same outcomes, probabilities, children and operations.

        Children are compared recursively. Operations are compared by
        identity (Operation defines no equality of its own).
        """
        if not isinstance(node, Node):
            return False
        if node is self:
            return True
        return (list(self._outcomes) == list(node._outcomes)
                and list(self._probabilities) == list(node._probabilities)
                and list(self._nodes) == list(node._nodes)
                and list(self._operations) == list(node._operations))

    def __len__(self):
        """Number of outcomes on this level."""
        return len(self._outcomes)

    def __copy__(self):
        """Copy the lists and, recursively, every child; operations are shared."""
        return Node(copy(self._outcomes),
                    copy(self._probabilities),
                    [copy(n) for n in self._nodes],
                    copy(self._operations))

    def __deepcopy__(self, memo):
        return Node(deepcopy(self._outcomes, memo),
                    deepcopy(self._probabilities, memo),
                    deepcopy(self._nodes, memo),
                    deepcopy(self._operations, memo))

    def __repr__(self):
        return f'<Node: id={id(self)}>'

    def __str__(self):
        """Draw the tree, one outcome per line, with its probability and operations.

        Lines are joined with '\\n', and there is no trailing newline.
        """
        lines = []
        last = len(self) - 1

        for i, (probability, outcome, child, operation) in enumerate(zip(
                self._probabilities, self._outcomes, self._nodes, self._operations)):
            branch = '└─ ' if i == last else '├─ '
            lines.append(f'{branch}{probability:2.2f} ── {outcome!r:8s}')

            if child is not None:  # draw the sub-tree indented under this branch
                offset = '       ' if i == last else '|      '
                lines.append(offset + '| ' + operation.name)
                lines.extend(offset + line for line in str(child).split('\n'))

        return '\n'.join(lines)
    
# Turn off the total-probability warning for every Node created from here on.
# The check runs in Node.__init__, which copies and collapse_v also go through.
Node.issue_warnings(False)

In [65]:
def d(sides):
    """Return a fair die with faces 1..sides, each equally likely."""
    faces = range(1, sides + 1)
    # Node spreads probability uniformly when none is given
    return Node(faces)

In [88]:
# Example tree: four equally likely root outcomes. Outcome 3 branches into
# Node([5, 6]); there, 5 branches into range(9), and 6 branches into a single
# outcome 1 that branches into range(4). With no operations given, every
# child is attached with Node._default_operation ('+').
a = Node([1,2,3,4], nodes=[None, None, Node([5,6], nodes=[Node(range(9)), Node([1], nodes=[Node(range(4))])]), None])

In [89]:
# print() uses Node.__str__ (the tree drawing); the bare repr only shows the id.
print(a)

├─ 0.25 ── 1       
├─ 0.25 ── 2       
├─ 0.25 ── 3       
|      | add
|      ├─ 0.50 ── 5       
|      |      | add
|      |      ├─ 0.11 ── 0       
|      |      ├─ 0.11 ── 1       
|      |      ├─ 0.11 ── 2       
|      |      ├─ 0.11 ── 3       
|      |      ├─ 0.11 ── 4       
|      |      ├─ 0.11 ── 5       
|      |      ├─ 0.11 ── 6       
|      |      ├─ 0.11 ── 7       
|      |      └─ 0.11 ── 8      
|      └─ 0.50 ── 6       
|             | add
|             └─ 1.00 ── 1       
|                    | add
|                    ├─ 0.25 ── 0       
|                    ├─ 0.25 ── 1       
|                    ├─ 0.25 ── 2       
|                    └─ 0.25 ── 3    
└─ 0.25 ── 4      


In [90]:
# Flatten to one level: each root-to-leaf path becomes one summed outcome,
# weighted by the product of probabilities along it (duplicates are not merged).
print(a.collapse_v())

├─ 0.25 ── 1       
├─ 0.25 ── 2       
├─ 0.01 ── 8       
├─ 0.01 ── 9       
├─ 0.01 ── 10      
├─ 0.01 ── 11      
├─ 0.01 ── 12      
├─ 0.01 ── 13      
├─ 0.01 ── 14      
├─ 0.01 ── 15      
├─ 0.01 ── 16      
├─ 0.03 ── 10      
├─ 0.03 ── 11      
├─ 0.03 ── 12      
├─ 0.03 ── 13      
└─ 0.25 ── 4      


In [84]:
# Two one-sided dice: Node.__add__ hangs the second die under every leaf of the
# first, combining outcomes with '+'.
# NOTE(review): this rebinds `a` from the example above; a distinct name would
# keep both trees available after Run All.
a = d(1) + d(1)

In [85]:
# Draw the two-level dice tree built above.
print(a)

└─ 1.00 ── 1       
       | add
       └─ 1.00 ── 1     


In [86]:
# Collapse the dice tree into the distribution of the sum.
print(a.collapse_v())

└─ 1.00 ── 2      
