#### 1. Tree class

The following link-based, general tree classes are adapted from *Data Structures and Algorithms in Python* by Goodrich, Tamassia, and Goldwasser.

In [2]:
class Tree:
    """Abstract base class for trees.

    Subclasses must implement root, parent, num_children, children and
    __len__ (plus a concrete Position); is_root, is_leaf and is_empty are
    implemented here in terms of those methods.
    """

    class Position:
        """Abstract location of an element within a tree."""

        def element(self):
            """Return the element stored at this position."""
            raise NotImplementedError('must be implemented in a sub-class')

        def __eq__(self, other):
            """Return True if other represents the same location."""
            raise NotImplementedError('must be implemented in a sub-class')

        def __ne__(self, other):
            """Return True if other does not represent the same location."""
            # Derived from __eq__ so a subclass that only implements equality
            # still gets a working (and consistent) != operator.
            return not (self == other)

    def root(self):
        """Return the position of the root of the tree."""
        raise NotImplementedError('must be implemented in a sub-class')

    def parent(self, p):
        """Return the position of p's parent."""
        raise NotImplementedError('must be implemented in a sub-class')

    def num_children(self, p):
        """Return the number of children of position p."""
        raise NotImplementedError('must be implemented in a sub-class')

    def children(self, p):
        """Return the positions of the children of p."""
        raise NotImplementedError('must be implemented in a sub-class')

    def __len__(self):
        """Return the number of elements in the tree."""
        raise NotImplementedError('must be implemented in a sub-class')

    def is_root(self, p):
        """Return True if position p is the root."""
        return self.root() == p

    def is_leaf(self, p):
        """Return True if position p has no children."""
        return self.num_children(p) == 0

    def is_empty(self):
        """Return True if the tree contains no elements."""
        return len(self) == 0

In [9]:
class LinkedTree(Tree):
    """Linked representation of a general tree structure.

    Every node stores its element, a reference to its parent node and a list
    of its child nodes.  Callers work with Position objects that wrap nodes.
    A node whose ``_parent`` is the node itself has been removed from the
    tree; positions referring to it are rejected by ``_validate``.
    """

    class _Node:
        """Nested class for storing the nodes of a general tree."""
        __slots__ = ('_element', '_parent', '_children', 'x', 'y')

        def __init__(self, element, parent=None, children=None):
            self._element = element
            self._parent = parent
            # A fresh list per node: a shared mutable default ([]) would make
            # every node created without an explicit list share its children.
            self._children = [] if children is None else children

            self.x = None                   # Coordinates for node (used when plotting)
            self.y = None

    class Position(Tree.Position):
        """Nested class for location of an element"""

        def __init__(self, container, node):
            self._container = container
            self._node = node

        def element(self):
            """Return the element stored at this position."""
            return self._node._element

        def __eq__(self, other):
            """Return True if other position represents the same location."""
            return type(other) is type(self) and other._node is self._node

        def __ne__(self, other):
            """Return True if other does not represent the same location."""
            return not (self == other)

        def __hash__(self):
            # Defining __eq__ disables the inherited hash; hash on the wrapped
            # node so equal positions hash equally (usable in sets/dict keys).
            return hash(self._node)

        def _sibling(self, offset):
            """Return the sibling `offset` places away from this one, or None."""
            parent = self._node._parent
            if parent is None:              # the root has no siblings
                return None
            sibs = parent._children
            index = sibs.index(self._node) + offset   # nodes compare by identity
            if 0 <= index < len(sibs):
                return self._container._make_position(sibs[index])
            return None

        def _prev(self):
            """Return the position of the previous sibling, or None if there is none."""
            return self._sibling(-1)

        def _next(self):
            """Return the position of the next sibling, or None if there is none."""
            return self._sibling(1)

        def __str__(self):
            """Return the element as a string."""
            return str(self._node._element)

    def _validate(self, p):
        """Return the node wrapped by position p, if p is valid for this tree.

        Raises:
            TypeError: if p is not a Position.
            ValueError: if p belongs to another tree or its node was removed.
        """
        if not isinstance(p, self.Position):
            raise TypeError('p must be a proper Position type')
        if p._container is not self:
            raise ValueError('p does not belong to this container')
        if p._node._parent is p._node:      # convention for removed nodes
            raise ValueError('p is no longer valid')
        return p._node

    def _make_position(self, node):
        """Return a Position for node, or None if node is None."""
        return self.Position(self, node) if node is not None else None

    def __init__(self):
        """Create an initially empty general tree."""
        self._root = None
        self._size = 0

    def __len__(self):
        """Return the number of nodes in the tree."""
        return self._size

    def root(self):
        """Return the root Position of the tree (or None if the tree is empty)."""
        return self._make_position(self._root)

    def parent(self, p):
        """Return the position of p's parent (None if p is the root)."""
        node = self._validate(p)
        return self._make_position(node._parent)

    def num_children(self, p):
        """Return the number of children of position p."""
        node = self._validate(p)
        return len(node._children)

    def _add_root(self, e):
        """Create a root holding e and return its position.

        Raises:
            ValueError: if the tree already has a root.
        """
        if self._root is not None:
            raise ValueError('Root Exists')
        self._size = 1
        self._root = self._Node(e)
        return self._make_position(self._root)

    def siblings(self, p):
        """Return a list of the positions of p's siblings (p itself excluded).

        The root has no siblings, so an empty list is returned for it.
        """
        node = self._validate(p)
        parent = node._parent
        if parent is None:
            return []
        return [self._make_position(sib) for sib in parent._children
                if sib is not node]

    def children(self, p):
        """Return a list of the positions of the children of p, in order."""
        node = self._validate(p)
        return list(map(self._make_position, node._children))

    def _add_child(self, p, e):
        """Append a new child holding e to the node at p; return its position."""
        node_p = self._validate(p)
        node_c = self._Node(e, node_p)
        node_p._children.append(node_c)
        self._size += 1
        return self._make_position(node_c)

    def _replace(self, p, e):
        """Replace the element at position p with e and return the old element."""
        node = self._validate(p)
        old = node._element
        node._element = e
        return old

    def _remove_leaf_node(self, node):
        """Detach leaf `node` from the tree, invalidate it, and return its element."""
        parent = node._parent
        if parent is None:                  # removing the only node empties the tree
            self._root = None
        else:
            parent._children.remove(node)   # nodes define no __eq__: identity match
        self._size -= 1
        node._parent = node                 # mark the node (and its positions) invalid
        return node._element

    def _delete(self, p):
        """Delete the node at position p.

        * leaf: the node is removed and its element returned (deleting a root
          leaf leaves the tree empty).
        * exactly one child: the child, with its whole subtree, takes p's place
          among p's parent's children (or becomes the new root); p's element is
          returned.
        * two or more children: nothing is deleted, a warning is printed and
          None is returned.
        """
        node = self._validate(p)
        num = len(node._children)
        if num > 1:
            print("Warning: 'p is not a leaf'")
            return None
        if num == 1:
            child = node._children[0]
            parent = node._parent
            child._parent = parent
            if parent is None:
                self._root = child
            else:
                # keep the child in exactly the slot p occupied
                parent._children[parent._children.index(node)] = child
            self._size -= 1
            node._parent = node             # invalidate p
            print("Deleting a non-leaf ...")
            return node._element
        return self._remove_leaf_node(node)

    def _delete_leaf(self, p):
        """Delete the leaf at position p and return its element.

        Raises:
            ValueError: if p has children.
        """
        if self.num_children(p) > 0:
            raise ValueError('p is not a leaf')
        return self._remove_leaf_node(self._validate(p))

    def _attacho(self, p, t1):
        """Attach tree t1 as a new (last) child subtree of p.

        p does not need to be a leaf.  t1 is left empty afterwards.

        Raises:
            TypeError: if t1 is not the same tree type as self.
            ValueError: if t1 is this tree.
        """
        node = self._validate(p)
        if not type(self) is type(t1):
            raise TypeError('Tree types must match')
        if t1 is self:
            raise ValueError('cannot attach a tree to itself')

        if not t1.is_empty():
            root_node = t1._root
            root_node._parent = node            # t1's root hangs below p
            node._children.append(root_node)
            self._size += len(t1)
            t1._root = None                     # t1 gives up its nodes
            t1._size = 0

    def _attach(self, p, t1):
        """Graft tree t1 onto leaf p: p takes t1's root element and children.

        t1's root node is merged into p's node rather than added, so the size
        grows by len(t1) - 1.  t1 is left empty afterwards.

        Raises:
            ValueError: if p is not a leaf, or t1 is this tree.
            TypeError: if t1 is not the same tree type as self.
        """
        node = self._validate(p)
        if not self.is_leaf(p):
            raise ValueError('p must be a leaf')
        if not type(self) is type(t1):
            raise TypeError('Tree types must match')
        if t1 is self:
            raise ValueError('cannot attach a tree to itself')

        if not t1.is_empty():
            root_node = t1._root
            self._replace(p, root_node._element)    # p takes t1's root element
            node._children = root_node._children    # ... and its children
            for child in node._children:
                child._parent = node
            self._size += len(t1) - 1                # t1's root merged into p
            root_node._parent = root_node            # discarded node is now invalid
            t1._root = None
            t1._size = 0

In [4]:
class MutableLinkedTree(LinkedTree):
    """Public, mutable interface to a LinkedTree.

    Every method here is a thin public wrapper that forwards its arguments
    unchanged to the matching underscore-prefixed method of LinkedTree and
    hands back whatever that method returns.
    """

    def add_root(self, e):
        """Create the root of the tree holding element e (forwards to _add_root)."""
        return self._add_root(e)

    def add_child(self, p, e):
        """Add a child holding element e under position p (forwards to _add_child)."""
        return self._add_child(p, e)

    def replace(self, p, e):
        """Store e at position p, returning the previous element (forwards to _replace)."""
        return self._replace(p, e)

    def delete(self, p):
        """Delete the node at position p (forwards to _delete)."""
        return self._delete(p)

    def attach(self, p, t1):
        """Graft tree t1 onto position p (forwards to _attach); returns None."""
        self._attach(p, t1)

#### 2. Print Code

The following two functions (`depth` and `get_coordinates`) compute node coordinates using an algorithm similar to the one described in chapter 8 of *Data Structures and Algorithms in Python*; `printTree` then uses those coordinates to draw the tree.

In [4]:
def depth(p):
    """Return the depth of position p: the number of edges from p up to the root.

    The tree is taken from p's own container, so no tree argument is needed.
    The root has depth 0.
    """
    T = p._container
    count = 0
    # Walk up one parent at a time, asking the tree for each parent only once.
    parent = T.parent(p)
    while parent is not None:
        count += 1
        parent = T.parent(parent)
    return count

In [5]:
def get_coordinates(T):
    """Compute plotting coordinates for every node of the general tree T.

    Nodes are numbered 1, 2, 3, ... by a generalized in-order traversal (all
    children except the last, then the node itself, then the last child);
    that number is the node's x coordinate and its depth is its y coordinate.
    Both values are also stored on the node as ``x`` and ``y``.

    Returns:
        A list, in traversal order, of ``[element, x, depth, parent_position]``
        entries (parent_position is None for the root).  An empty tree gives [].
    """
    coords = []
    next_x = 1          # closure counter instead of a module-level global

    def _in_order(p, level):
        """Visit the subtree rooted at p, whose depth is `level`."""
        nonlocal next_x
        kids = T.children(p)

        for kid in kids[:-1]:               # every child except the last
            _in_order(kid, level + 1)

        p._node.x = next_x                  # in-order visit of p itself
        p._node.y = level
        coords.append([p.element(), next_x, level, T.parent(p)])
        next_x += 1

        if kids:                            # finally the last child
            _in_order(kids[-1], level + 1)

    root = T.root()
    if root is not None:
        _in_order(root, 0)                  # the root has depth 0
    return coords

In [6]:
import matplotlib.pyplot as plt

def printTree(T, scale=1, node=60, font=12):
    """Draw the general tree T with matplotlib and display the figure.

    Args:
        T: tree to draw; node coordinates are obtained from get_coordinates(T).
        scale: multiplier applied to the 16 x 9 inch figure size.
        node: marker size passed to ``scatter`` for each node.
        font: font size of the node labels; edges are shortened by 1.5 * font
            at each end so they stop short of the labels.

    An empty tree produces no figure.
    """
    coords = get_coordinates(T)             # [element, x, depth, parent] per node
    if not coords:                          # nothing to draw (max() would fail)
        return

    labels = [row[0] for row in coords]
    xs = [row[1] for row in coords]
    depths = [row[2] for row in coords]

    top = max(depths) + 1
    ys = [top - d for d in depths]          # flip so the root is drawn at the top

    fig, ax = plt.subplots()
    ax.scatter(xs, ys, s=node)

    width = len(coords) + 1
    ax.set(xlim=(0, width), xticks=list(range(1, width)),
           ylim=(0, top + 1), yticks=list(range(top + 1)))

    shrink = 1.5 * font                     # keep edge ends clear of the labels

    for label, x, y, row in zip(labels, xs, ys, coords):
        ax.annotate(label, xy=(x, y), xytext=(x, y),
                    fontsize=font, ha='center', va='center')

        parent = row[3]
        if parent is not None:              # the root has no edge to draw
            parent_node = parent._node
            parent_xy = (parent_node.x, top - parent_node.y)
            ax.annotate("",                 # plain line from child to parent
                        xy=parent_xy, xytext=(x, y),
                        arrowprops=dict(arrowstyle="-",
                                        shrinkA=shrink, shrinkB=shrink,
                                        connectionstyle="arc3"))

    fig.set_size_inches(16 * scale, 9 * scale)
    plt.show()