# In [7]:
class WTNode:
    """One node of a ternary search tree.

    Constructor arguments are, in order: the value stored at the node
    (``d``), then the left (``l``), middle (``m``) and right (``r``)
    children.  ``mult`` starts at 0; WordTree uses it to count how many
    stored words end at this node.
    """

    def __init__(self, d, l, m, r):
        self.data = d    # value held by this node (a single character when built by WordTree)
        self.left = l    # sibling subtree whose values compare smaller than data
        self.midl = m    # subtree continuing with the next position of the word
        self.right = r   # sibling subtree whose values compare larger than data
        self.mult = 0    # multiplicity counter, maintained by the owning tree

    def __str__(self):
        """Render this node and, recursively, its children.

        Format: ``(data, mult) -> [left, midl, right]``, with the literal
        ``None`` standing in for each missing child.
        """
        rendered = ", ".join(
            "None" if child is None else str(child)
            for child in (self.left, self.midl, self.right)
        )
        return f"({self.data!s}, {self.mult!s}) -> [{rendered}]"
    
class WordTree:
    """A multiset of non-empty strings stored as a ternary search tree.

    Every node holds one character.  Following ``midl`` moves on to the next
    character of a word; ``left``/``right`` lead to alternative characters at
    the same position that compare smaller/larger (a per-position BST).  A
    node's ``mult`` counts how many stored copies of the word spelled by the
    path to it end exactly at that node.

    Invariant kept by ``add``/``remove``: every node either ends a word
    (``mult > 0``) or has a ``midl`` continuation, so each subtree contains at
    least one stored word.
    """

    # Maps the one-letter link flags used throughout to node attribute names.
    _LINKS = {'l': 'left', 'm': 'midl', 'r': 'right'}

    def __init__(self):
        self.root = None  # top node, or None when the tree is empty
        self.size = 0     # number of stored words, counting repeats

    class NodeWrapper:
        """A node on a search path together with how it was reached.

        ``parent`` is the node it hangs from (None for the root) and ``flag``
        is the link used from the parent: 'l', 'm' or 'r' (None for the root).
        """

        def __init__(self, p, n, f):
            self.parent = p
            self.actual = n
            self.flag = f

    def __str__(self):
        """Return the nested textual form of the tree ("None" when empty)."""
        return str(self.root)

    @classmethod
    def _setChild(cls, parent, flag, child):
        """Point the link of ``parent`` selected by ``flag`` ('l'/'m'/'r') at ``child``."""
        setattr(parent, cls._LINKS[flag], child)

    def count(self, st):
        """Return how many times string ``st`` is stored (0 if absent or empty)."""
        if len(st) == 0:
            return 0  # empty strings are never stored (see add)
        last = len(st) - 1
        node = self.root
        i = 0
        while node is not None:
            ch = st[i]
            if ch == node.data:
                if i == last:
                    return node.mult
                node = node.midl
                i += 1
            elif ch < node.data:
                node = node.left
            else:
                node = node.right
        return 0

    def add(self, st):
        """Store one more occurrence of ``st`` and grow ``size`` by 1.

        Empty strings are ignored.  Returns None.
        """
        if len(st) == 0:
            return
        if self.root is None:
            return self._initTree(st)
        parent = None
        flag = None
        node = self.root
        i = 0
        while i < len(st):
            ch = st[i]
            if node is None:
                # Fell off the tree: hang a fresh node where the search ended;
                # it matches ``ch`` and is consumed by the branch below.
                node = WTNode(ch, None, None, None)
                self._setChild(parent, flag, node)
            if ch == node.data:
                parent, node, flag = node, node.midl, 'm'
                i += 1
            elif ch < node.data:
                parent, node, flag = node, node.left, 'l'
            else:
                parent, node, flag = node, node.right, 'r'
        # ``parent`` is the node holding the last character of ``st``.
        parent.mult += 1
        self.size += 1

    def _initTree(self, st):
        """Build the tree from its first (non-empty) word as a chain of midl links."""
        self.root = WTNode(st[0], None, None, None)
        node = self.root
        for ch in st[1:]:
            node.midl = WTNode(ch, None, None, None)
            node = node.midl
        node.mult += 1
        self.size += 1

    def minst(self):
        """Return the lexicographically smallest stored string, or None if empty."""
        if self.root is None:
            return None
        chars = []
        node = self.root
        while node.left is not None or node.midl is not None:
            if node.left is not None:
                # A smaller character at this position always wins; the
                # invariant guarantees that subtree holds some word.
                node = node.left
            else:
                chars.append(node.data)
                if node.mult > 0:
                    # A word ending here is a prefix of everything below it.
                    return "".join(chars)
                node = node.midl
        chars.append(node.data)
        return "".join(chars)

    def remove(self, st):
        """Remove one occurrence of ``st`` and shrink ``size`` by 1.

        Does nothing if ``st`` is empty or not stored.  Returns None.
        """
        if len(st) == 0 or self.count(st) == 0:
            return
        self._removeWord(st)
        self.size -= 1

    def _removeWord(self, st):
        """Decrement ``st``'s count and prune nodes that no longer lead to any word.

        Precondition: ``st`` is stored in the tree.
        """
        wrappers = self._buildArrayOfWrappers(st)
        last = wrappers[-1].actual
        last.mult -= 1
        if last.mult > 0:
            return
        # Walk back up the word's nodes.  Once a node must be kept (it still
        # ends a word or prefixes one), every earlier node reaches it through
        # its midl link and must be kept too, so we can stop.
        for wrapper in reversed(wrappers):
            node = wrapper.actual
            if node.mult > 0 or node.midl is not None:
                break
            self._removeNode(wrapper)

    def _buildArrayOfWrappers(self, st):
        """Return one NodeWrapper per character of ``st``, in order.

        Precondition: ``st`` is stored in the tree (the search would
        dereference None otherwise).
        """
        wrappers = []
        parent = None
        flag = None
        node = self.root
        i = 0
        while i < len(st):
            ch = st[i]
            if ch == node.data:
                wrappers.append(self.NodeWrapper(parent, node, flag))
                parent, node, flag = node, node.midl, 'm'
                i += 1
            elif ch < node.data:
                parent, node, flag = node, node.left, 'l'
            else:
                parent, node, flag = node, node.right, 'r'
        return wrappers

    def _removeNode(self, wrapper):
        """Unlink ``wrapper.actual`` from the tree, keeping its BST siblings reachable.

        The node is expected to have no ``midl`` subtree (see _removeWord).
        """
        if wrapper.parent is None:
            return self._removeRoot()
        replacement = self._spliceOut(wrapper.actual)
        self._setChild(wrapper.parent, wrapper.flag, replacement)

    def _spliceOut(self, node):
        """Return the subtree that should take ``node``'s place in its sibling BST."""
        if node.left is None:
            return node.right
        if node.right is None:
            return node.left
        # Two children: promote the in-order successor (Hibbard deletion).
        successor = self._removeAndReturnMin(node)
        successor.left = node.left
        successor.right = node.right
        return successor

    def _removeAndReturnMin(self, currNode):
        """Detach and return the minimum node of ``currNode.right``'s subtree.

        The detached node's own right subtree is re-hung in its place so no
        sibling is lost.  Precondition: ``currNode.right`` is not None.
        """
        parent = currNode
        ptr = currNode.right
        while ptr.left is not None:
            parent = ptr
            ptr = ptr.left
        if parent is currNode:
            parent.right = ptr.right
        else:
            parent.left = ptr.right
        return ptr

    def _removeRoot(self):
        """Remove the root node, replacing it from its sibling BST if possible."""
        self.root = self._spliceOut(self.root)