# ECS529U Algorithms and Data Structures
# Lab sheet 7

This lab gets you to work with binary trees and binary search trees in particular.

**Marks (max 5):** Questions 1-5: 1 mark each | Question 6: formative

## Question 1

This question is about understanding binary search trees (BSTs).

a) Draw the binary search tree obtained by starting from the empty tree and adding 
the following strings one after another (use alphabetical ordering):

    egg, gem, bat, dog, fox, ink, egg, hat, ant, cat

b) Write down the strings of the tree you constructed, starting from the root, first 
using depth-first search (pre-order) and then using breadth-first search.

c) Let `t` point to the root node of the BST you constructed in part a. Draw the BST that
results by applying each of the following operations:

    1. t.left = t.left.right
    2. t.left.right.right = t.left.right.left

In each of these cases, is the resulting structure a binary tree? Is it a binary search 
tree?

d) Starting each time from the tree you constructed in part a, perform the following removals (using the algorithm we saw in the lectures) and draw the resulting trees:

1. remove the node with value dog
2. remove the node with value gem
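
For part b, the two traversal orders can be sketched as follows. This is only an illustration on a small integer tree (the example tree with root 22 used later in this sheet), not the answer to part b. The function names `preorder` and `bfs` are made up for this sketch, and it uses Python's `collections.deque` as the queue rather than the queue class from the lectures.

```python
from collections import deque

class BTNode:
    def __init__(self, d, l, r):
        self.data = d
        self.left = l
        self.right = r

def preorder(ptr, out):
    # depth-first, pre-order: visit the node, then its left and right subtrees
    if ptr is None:
        return out
    out.append(ptr.data)
    preorder(ptr.left, out)
    preorder(ptr.right, out)
    return out

def bfs(root):
    # breadth-first: visit the nodes level by level, left to right
    out = []
    queue = deque([root] if root is not None else [])
    while queue:
        ptr = queue.popleft()
        out.append(ptr.data)
        if ptr.left is not None:
            queue.append(ptr.left)
        if ptr.right is not None:
            queue.append(ptr.right)
    return out

# example tree: root 22, children 20 and 42, leaves 11, 21, 22, 44
root = BTNode(22,
              BTNode(20, BTNode(11, None, None), BTNode(21, None, None)),
              BTNode(42, BTNode(22, None, None), BTNode(44, None, None)))
print(preorder(root, []))  # [22, 20, 11, 21, 42, 22, 44]
print(bfs(root))           # [22, 20, 42, 11, 21, 22, 44]
```

Pre-order goes all the way down the left branch before touching the right one, whereas breadth-first search finishes each level before moving to the next.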

The remaining questions ask you to work with the `BST` class and variants thereof. To help you visualise trees, we have implemented the following "pretty printing" function for `BTNode` objects:

In [17]:
    def niceStr(self): # this goes in the BTNode class
        S = ["├","─","└","│"]
        angle = S[2]+S[1]+" "
        vdash = S[0]+S[1]+" "
        
        def niceRec(ptr,acc,pre):
            if ptr == None: return acc+pre+"None"
            if ptr.left==ptr.right==None: return acc+pre+str(ptr.data)
            if pre == vdash: pre2 = S[3]+"  "
            elif pre == angle: pre2 = "   "
            else: pre2 = ""
            left = niceRec(ptr.right,acc+pre2,vdash)
            right = niceRec(ptr.left,acc+pre2,angle)
            return acc+pre+str(ptr.data)+"\n"+left+"\n"+right
            
        return niceRec(self,"","")

For example, the following tree

        22
       /  \
      20   42
     / \   / \
    11 21 22 44

is converted into a string that prints as follows:

    22 
    ├─ 42
    │  ├─ 44
    │  └─ 22
    └─ 20
       ├─ 21
       └─ 11
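
As a quick check, the example can be reproduced by placing `niceStr` inside a `BTNode` class and building the tree by hand. The sketch below copies the function from above unchanged; only the hand-built tree and the variable name `root` are additions for illustration.

```python
class BTNode:
    def __init__(self, d, l, r):
        self.data = d
        self.left = l
        self.right = r

    def niceStr(self):
        S = ["├","─","└","│"]
        angle = S[2]+S[1]+" "
        vdash = S[0]+S[1]+" "

        def niceRec(ptr,acc,pre):
            if ptr == None: return acc+pre+"None"
            if ptr.left==ptr.right==None: return acc+pre+str(ptr.data)
            if pre == vdash: pre2 = S[3]+"  "
            elif pre == angle: pre2 = "   "
            else: pre2 = ""
            # the right child is printed first (upper branch), then the left
            left = niceRec(ptr.right,acc+pre2,vdash)
            right = niceRec(ptr.left,acc+pre2,angle)
            return acc+pre+str(ptr.data)+"\n"+left+"\n"+right

        return niceRec(self,"","")

# the example tree with root 22
root = BTNode(22,
              BTNode(20, BTNode(11, None, None), BTNode(21, None, None)),
              BTNode(42, BTNode(22, None, None), BTNode(44, None, None)))
print(root.niceStr())
```

Note that the upper branch under each node is its right child and the lower branch is its left child. A node with exactly one child prints `None` in place of the missing one.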

## Question 2

Add in `BST` the following functions, assuming that we work with BSTs that store integers:

a) `def min(self)`

that returns the smallest element of the tree. If the tree is empty, the function should return `None`.

b) `def max(self)`

that returns the largest element of the tree. If the tree is empty, the function should return `None`.

c) `def removeAll(self, d)`

that removes all occurrences of the element `d` in the tree and returns the number of occurrences removed. For example, if the tree `t` is:

        22
       /  \
      20   42
     / \   / \
    11 21 22 44
    
then `t.removeAll(22)` should change `t` to:

        42
       /  \
      20   44
     / \  
    11 21 
    
and return `2`.

In [11]:
class BTNode:
    def __init__(self, d, l, r):
        self.data = d
        self.left = l
        self.right = r

class BST:
    def __init__(self):
        self.root = None
        self.size = 0

        #####################################################################################################
    def min(self):
        if self.root is None:
            return None
            
        ptr = self.root
        while ptr.left is not None:
            ptr = ptr.left
        return ptr.data

    def max(self):
        if self.root is None:
            return None
            
        ptr = self.root
        while ptr.right is not None:
            ptr = ptr.right
        return ptr.data

    def removeAll(self, d):
        count = 0
        while self._removeSingleOccurrence(d):
            count += 1
        return count

    def _removeSingleOccurrence(self, d):
        if self.root is None:
            return False
            
        parentPtr = None
        ptr = self.root
        while ptr is not None:
            if ptr.data == d:
                self.size -= 1
                self._removeNode(ptr, parentPtr)
                return True
            parentPtr = ptr
            if d < ptr.data:
                ptr = ptr.left
            else:
                ptr = ptr.right
        return False
    ##################################################################################################

    def search(self, d):
        ptr = self.root
        while ptr is not None:
            if d == ptr.data:
                return True
            if d < ptr.data:
                ptr = ptr.left
            else:
                ptr = ptr.right
        return False

    def add(self, d):
        self.size += 1
        if self.root is None:
            self.root = BTNode(d, None, None)
            return
        ptr = self.root
        while True:
            if d < ptr.data:
                if ptr.left is None:
                    ptr.left = BTNode(d, None, None)
                    return
                ptr = ptr.left
            else:
                if ptr.right is None:
                    ptr.right = BTNode(d, None, None) 
                    return
                ptr = ptr.right

    def remove(self, d):
        if self.root is None: 
            return
        parentPtr = None
        ptr = self.root
        while ptr is not None:
            if ptr.data == d:
                self.size -= 1
                return self._removeNode(ptr, parentPtr)
            parentPtr = ptr
            if d < ptr.data: 
                ptr = ptr.left
            else:
                ptr = ptr.right

    def _removeNode(self, ptr, parentPtr):
        def updateChild(ptr, oldChild, newChild):
            if ptr is None:
                self.root = newChild
            elif ptr.left == oldChild:
                ptr.left = newChild
            elif ptr.right == oldChild:
                ptr.right = newChild
            else: 
                raise Exception("updateChild error")
                
        if ptr.left == ptr.right == None:
            updateChild(parentPtr, ptr, None)
        elif ptr.left is None:
            updateChild(parentPtr, ptr, ptr.right)
        elif ptr.right is None:
            updateChild(parentPtr, ptr, ptr.left)
        else:
            parentMinRNode = ptr
            minRNode = ptr.right
            while minRNode.left is not None:
                parentMinRNode = minRNode
                minRNode = minRNode.left
            ptr.data = minRNode.data
            updateChild(parentMinRNode, minRNode, minRNode.right)

## Question 3

Add in `BST` the following functions, assuming that we work with BSTs that store integers:

a) `def _sumAllRec(self, ptr)`

that <u>uses recursion</u> and returns the sum of all the elements of the subtree starting from the node `ptr`.

_Hint:_ you can simply use depth-first search, ignoring the fact that this is a BST 
rather than a simple binary tree.

b) `def sumAll(self)`

that sums all the elements of the tree (use the function from part a).

c) `def sumAllBFS(self)`

that sums all the elements of the tree using breadth-first search.

_Hint:_ you can adapt the code for breadth-first search that we saw in the lecture 
(week 6). You will need to use a queue (see lecture of week 5).

In [31]:
class Node:
    def __init__(self, d, n):
        self.data = d
        self.next = n

class LinkedList:
    def __init__(self):
        self.head = None
        self.length = 0

    def __str__(self):
        st = "--> "
        ptr = self.head
        while ptr != None:
            st = st + str(ptr.data)
            st = st+" -> "
            ptr = ptr.next
        return st+"None"
        
    def search(self, d):
        i = 0
        ptr = self.head
        while ptr != None:
            if ptr.data == d:
                return i
            ptr = ptr.next
            i += 1
        return -1
        
    def append(self, d):
        if self.head == None:      
            self.head = Node(d,None) 
        else:
            ptr = self.head
            while ptr.next != None:
                ptr = ptr.next
            ptr.next = Node(d,None)
        self.length += 1

    def insert(self, i, d):
        if self.head == None or i == 0:
            self.head = Node(d,self.head)
        else:
            ptr = self.head
            while i>1 and ptr.next != None:
                ptr = ptr.next
                i -= 1
            ptr.next = Node(d,ptr.next)
        self.length += 1

    def remove(self, i): # removes i-th element and returns it
        if self.head == None:
            return None
        if i == 0:
            val = self.head.data
            self.head = self.head.next
            self.length -= 1
            return val
        ptr = self.head
        while ptr.next != None:
            if i == 1:
                val = ptr.next.data
                ptr.next = ptr.next.next
                self.length -= 1
                return val                
            ptr = ptr.next
            i -= 1
    
    # removes the first occurrence of d if found
    # returns True if d removed, otherwise False
    def removeVal(self, d):
        if self.head == None: return False
        if self.head.data == d:
            self.head = self.head.next
            self.length -= 1
            return True
        ptr = self.head
        while ptr.next != None:
            if ptr.next.data == d:
                ptr.next = ptr.next.next
                self.length -= 1
                return True
            ptr = ptr.next
        return False
    
    def sublist(self, i):
        ptr = self.head
        ls = LinkedList()
        ls.length = self.length
        while ptr != None and i>0:
            ptr = ptr.next
            i -= 1
            ls.length -= 1
        ls.head = ptr
        return ls

class Queue:
    def __init__(self):
        self.inList = LinkedList()

    def __str__(self):
        return str(self.inList)
        
    def size(self):
        return self.inList.length

    def enq(self, e):
        self.inList.append(e)

    def deq(self):
        return self.inList.remove(0)

    def is_empty(self):
        return self.inList.length == 0
    

In [32]:
class BTNode:
    def __init__(self, d, l, r):
        self.data = d
        self.left = l
        self.right = r

class BST:
    def __init__(self):
        self.root = None
        self.size = 0
        
    def min(self):
        if self.root is None:
            return None
            
        ptr = self.root
        while ptr.left is not None:
            ptr = ptr.left
        return ptr.data

    def max(self):
        if self.root is None:
            return None
            
        ptr = self.root
        while ptr.right is not None:
            ptr = ptr.right
        return ptr.data

    def removeAll(self, d):
        count = 0
        while self._removeSingleOccurrence(d):
            count += 1
        return count

    def _removeSingleOccurrence(self, d):
        if self.root is None:
            return False
            
        parentPtr = None
        ptr = self.root
        while ptr is not None:
            if ptr.data == d:
                self.size -= 1
                self._removeNode(ptr, parentPtr)
                return True
            parentPtr = ptr
            if d < ptr.data:
                ptr = ptr.left
            else:
                ptr = ptr.right
        return False

    ##################################################################################################
    
    def _sumAllRec(self, ptr):
        if ptr is None:
            return 0
        return self._sumAllRec(ptr.left) + ptr.data + self._sumAllRec(ptr.right)

    def sumAll(self):
        return self._sumAllRec(self.root)

    def sumAllBFS(self):
        if self.root is None:
            return 0
            
        total = 0
        queue = Queue() 
        queue.enq(self.root)
        while not queue.is_empty():
            current = queue.deq()
            total += current.data
            
            if current.left is not None:
                queue.enq(current.left)
                
            if current.right is not None:
                queue.enq(current.right)
                
        return total
    ##################################################################################################

    def search(self, d):
        ptr = self.root
        while ptr is not None:
            if d == ptr.data:
                return True
            if d < ptr.data:
                ptr = ptr.left
            else:
                ptr = ptr.right
        return False

    def add(self, d):
        self.size += 1
        if self.root is None:
            self.root = BTNode(d, None, None)
            return
        ptr = self.root
        while True:
            if d < ptr.data:
                if ptr.left is None:
                    ptr.left = BTNode(d, None, None)
                    return
                ptr = ptr.left
            else:
                if ptr.right is None:
                    ptr.right = BTNode(d, None, None) 
                    return
                ptr = ptr.right

    def remove(self, d):
        if self.root is None: 
            return
        parentPtr = None
        ptr = self.root
        while ptr is not None:
            if ptr.data == d:
                self.size -= 1
                return self._removeNode(ptr, parentPtr)
            parentPtr = ptr
            if d < ptr.data: 
                ptr = ptr.left
            else:
                ptr = ptr.right

    def _removeNode(self, ptr, parentPtr):
        def updateChild(ptr, oldChild, newChild):
            if ptr is None:
                self.root = newChild
            elif ptr.left == oldChild:
                ptr.left = newChild
            elif ptr.right == oldChild:
                ptr.right = newChild
            else: 
                raise Exception("updateChild error")
                
        if ptr.left == ptr.right == None:
            updateChild(parentPtr, ptr, None)
        elif ptr.left is None:
            updateChild(parentPtr, ptr, ptr.right)
        elif ptr.right is None:
            updateChild(parentPtr, ptr, ptr.left)
        else:
            parentMinRNode = ptr
            minRNode = ptr.right
            while minRNode.left is not None:
                parentMinRNode = minRNode
                minRNode = minRNode.left
            ptr.data = minRNode.data
            updateChild(parentMinRNode, minRNode, minRNode.right)

## Question 4

Add in `BST` a function 

    def toSortedArray(self)

that returns an array containing the elements of the tree in ascending order.

_Hint:_ Use a helper function to do an inorder traversal of the BST.

In [33]:
class BTNode:
    def __init__(self, d, l, r):
        self.data = d
        self.left = l
        self.right = r

class BST:
    def __init__(self):
        self.root = None
        self.size = 0
        
    def min(self):
        if self.root is None:
            return None
            
        ptr = self.root
        while ptr.left is not None:
            ptr = ptr.left
        return ptr.data

    def max(self):
        if self.root is None:
            return None
            
        ptr = self.root
        while ptr.right is not None:
            ptr = ptr.right
        return ptr.data

    def removeAll(self, d):
        count = 0
        while self._removeSingleOccurrence(d):
            count += 1
        return count

    def _removeSingleOccurrence(self, d):
        if self.root is None:
            return False
            
        parentPtr = None
        ptr = self.root
        while ptr is not None:
            if ptr.data == d:
                self.size -= 1
                self._removeNode(ptr, parentPtr)
                return True
            parentPtr = ptr
            if d < ptr.data:
                ptr = ptr.left
            else:
                ptr = ptr.right
        return False
    
    def _sumAllRec(self, ptr):
        if ptr is None:
            return 0
        return self._sumAllRec(ptr.left) + ptr.data + self._sumAllRec(ptr.right)

    def sumAll(self):
        return self._sumAllRec(self.root)

    def sumAllBFS(self):
        if self.root is None:
            return 0
            
        total = 0
        queue = Queue() 
        queue.enq(self.root)
        while not queue.is_empty():
            current = queue.deq()
            total += current.data
            
            if current.left is not None:
                queue.enq(current.left)
                
            if current.right is not None:
                queue.enq(current.right)
                
        return total

    ##################################################################################################
    
    def _inorderTraversal(self, ptr, result):
        if ptr is None:
            return
            
        self._inorderTraversal(ptr.left, result)
        result.append(ptr.data)
        self._inorderTraversal(ptr.right, result)

    def toSortedArray(self):
        result = []
        self._inorderTraversal(self.root, result)
        return result
    ##################################################################################################

    def search(self, d):
        ptr = self.root
        while ptr is not None:
            if d == ptr.data:
                return True
            if d < ptr.data:
                ptr = ptr.left
            else:
                ptr = ptr.right
        return False

    def add(self, d):
        self.size += 1
        if self.root is None:
            self.root = BTNode(d, None, None)
            return
        ptr = self.root
        while True:
            if d < ptr.data:
                if ptr.left is None:
                    ptr.left = BTNode(d, None, None)
                    return
                ptr = ptr.left
            else:
                if ptr.right is None:
                    ptr.right = BTNode(d, None, None) 
                    return
                ptr = ptr.right

    def remove(self, d):
        if self.root is None: 
            return
        parentPtr = None
        ptr = self.root
        while ptr is not None:
            if ptr.data == d:
                self.size -= 1
                return self._removeNode(ptr, parentPtr)
            parentPtr = ptr
            if d < ptr.data: 
                ptr = ptr.left
            else:
                ptr = ptr.right

    def _removeNode(self, ptr, parentPtr):
        def updateChild(ptr, oldChild, newChild):
            if ptr is None:
                self.root = newChild
            elif ptr.left == oldChild:
                ptr.left = newChild
            elif ptr.right == oldChild:
                ptr.right = newChild
            else: 
                raise Exception("updateChild error")
                
        if ptr.left == ptr.right == None:
            updateChild(parentPtr, ptr, None)
        elif ptr.left is None:
            updateChild(parentPtr, ptr, ptr.right)
        elif ptr.right is None:
            updateChild(parentPtr, ptr, ptr.left)
        else:
            parentMinRNode = ptr
            minRNode = ptr.right
            while minRNode.left is not None:
                parentMinRNode = minRNode
                minRNode = minRNode.left
            ptr.data = minRNode.data
            updateChild(parentMinRNode, minRNode, minRNode.right)

## Question 5

Add in `BST` the function

    def __eq__(self, other)

that returns `True` if and only if the size of `self` is equal to the size of `other` and for every node in `self` with data `d` there is a _unique_ node in `other` with data `d`. For example

        3                              2
       /                              / \
      2       is equal to            1   3
     /
    1
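
One way to read this condition is that the two trees must hold the same values with the same multiplicities, regardless of shape. The sketch below illustrates that reading on plain Python lists, using `collections.Counter` from the standard library; the function name `same_multiset` is made up for this illustration and is not part of the `BST` class.

```python
from collections import Counter

def same_multiset(xs, ys):
    # equal if every value occurs the same number of times in both lists
    return Counter(xs) == Counter(ys)

print(same_multiset([3, 2, 1], [2, 1, 3]))  # True: same values, different order
print(same_multiset([1, 2, 2], [1, 1, 2]))  # False: same size, different multiplicities
print(same_multiset([1, 2, 3], [2, 3]))     # False: different sizes
```

Under this reading, comparing the sorted contents of the two trees is enough, since two lists with the same values and multiplicities sort to the same sequence.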

In [34]:
class BTNode:
    def __init__(self, d, l, r):
        self.data = d
        self.left = l
        self.right = r

class BST:
    def __init__(self):
        self.root = None
        self.size = 0

    def niceStr(self): # adapted for BST: the recursion starts from the root node
        S = ["├","─","└","│"]
        angle = S[2]+S[1]+" "
        vdash = S[0]+S[1]+" "

        def niceRec(ptr,acc,pre):
            if ptr == None: return acc+pre+"None"
            if ptr.left==ptr.right==None: return acc+pre+str(ptr.data)
            if pre == vdash: pre2 = S[3]+"  "
            elif pre == angle: pre2 = "   "
            else: pre2 = ""
            left = niceRec(ptr.right,acc+pre2,vdash)
            right = niceRec(ptr.left,acc+pre2,angle)
            return acc+pre+str(ptr.data)+"\n"+left+"\n"+right

        # a BST object has no left/right fields, so pass self.root rather than self
        return niceRec(self.root,"","")

    def __str__(self):
        return str(self.toSortedArray())
        
    def min(self):
        if self.root is None:
            return None
            
        ptr = self.root
        while ptr.left is not None:
            ptr = ptr.left
        return ptr.data

    def max(self):
        if self.root is None:
            return None
            
        ptr = self.root
        while ptr.right is not None:
            ptr = ptr.right
        return ptr.data

    def removeAll(self, d):
        count = 0
        while self._removeSingleOccurrence(d):
            count += 1
        return count

    def _removeSingleOccurrence(self, d):
        if self.root is None:
            return False
            
        parentPtr = None
        ptr = self.root
        while ptr is not None:
            if ptr.data == d:
                self.size -= 1
                self._removeNode(ptr, parentPtr)
                return True
            parentPtr = ptr
            if d < ptr.data:
                ptr = ptr.left
            else:
                ptr = ptr.right
        return False
    
    def _sumAllRec(self, ptr):
        if ptr is None:
            return 0
        return self._sumAllRec(ptr.left) + ptr.data + self._sumAllRec(ptr.right)

    def sumAll(self):
        return self._sumAllRec(self.root)

    def sumAllBFS(self):
        if self.root is None:
            return 0
            
        total = 0
        queue = Queue() 
        queue.enq(self.root)
        while not queue.is_empty():
            current = queue.deq()
            total += current.data
            
            if current.left is not None:
                queue.enq(current.left)
                
            if current.right is not None:
                queue.enq(current.right)
                
        return total

    def _inorderTraversal(self, ptr, result):
        if ptr is None:
            return
        self._inorderTraversal(ptr.left, result)
        result.append(ptr.data)
        self._inorderTraversal(ptr.right, result)

    def toSortedArray(self):
        result = []
        self._inorderTraversal(self.root, result)
        return result
    ##################################################################################################
    
    def __eq__(self, other):
        if not isinstance(other, BST):
            return False
            
        if self.size != other.size:
            return False
            
        if self.size == 0 and other.size == 0:
            return True
            
        self_sorted = self.toSortedArray()
        other_sorted = other.toSortedArray()
        
        return self_sorted == other_sorted

    ##################################################################################################

    def search(self, d):
        ptr = self.root
        while ptr is not None:
            if d == ptr.data:
                return True
            if d < ptr.data:
                ptr = ptr.left
            else:
                ptr = ptr.right
        return False

    def add(self, d):
        self.size += 1
        if self.root is None:
            self.root = BTNode(d, None, None)
            return
        ptr = self.root
        while True:
            if d < ptr.data:
                if ptr.left is None:
                    ptr.left = BTNode(d, None, None)
                    return
                ptr = ptr.left
            else:
                if ptr.right is None:
                    ptr.right = BTNode(d, None, None) 
                    return
                ptr = ptr.right

    def remove(self, d):
        if self.root is None: 
            return
        parentPtr = None
        ptr = self.root
        while ptr is not None:
            if ptr.data == d:
                self.size -= 1
                return self._removeNode(ptr, parentPtr)
            parentPtr = ptr
            if d < ptr.data: 
                ptr = ptr.left
            else:
                ptr = ptr.right

    def _removeNode(self, ptr, parentPtr):
        def updateChild(ptr, oldChild, newChild):
            if ptr is None:
                self.root = newChild
            elif ptr.left == oldChild:
                ptr.left = newChild
            elif ptr.right == oldChild:
                ptr.right = newChild
            else: 
                raise Exception("updateChild error")
                
        if ptr.left == ptr.right == None:
            updateChild(parentPtr, ptr, None)
        elif ptr.left is None:
            updateChild(parentPtr, ptr, ptr.right)
        elif ptr.right is None:
            updateChild(parentPtr, ptr, ptr.left)
        else:
            parentMinRNode = ptr
            minRNode = ptr.right
            while minRNode.left is not None:
                parentMinRNode = minRNode
                minRNode = minRNode.left
            ptr.data = minRNode.data
            updateChild(parentMinRNode, minRNode, minRNode.right)

In [35]:
# Minimal testing Questions 2-5

print("Question 2")
t = BST()
print(t)
print(t.min(),t.max())
A = [22,20,11,21,42,22,44,1]
for x in A: t.add(x)
print(t)
print(t.min(),t.max())
t.removeAll(22)
print(t)

print("\nQuestion 3")
t = BST()
print(t)
print(t.sumAll(), t.sumAllBFS())
A = [22,20,11,21,42,22,44]
for x in A: t.add(x)
print(t)
print(t.sumAll(), t.sumAllBFS())

print("\nQuestion 4")
print(BST().toSortedArray())
print(t.toSortedArray())

print("\nQuestion 5")
t1 = BST()
A = [3,2,1]
for x in A: t1.add(x)
print(t1)

t2 = BST()
B = [2,1,3]
for x in B: t2.add(x)
print(t2)

print(t1 == t2)
t2.remove(1)
print(t1 == t2)

Question 2
[]
None None
[1, 11, 20, 21, 22, 22, 42, 44]
1 44
[1, 11, 20, 21, 42, 44]

Question 3
[]
0 0
[11, 20, 21, 22, 22, 42, 44]
182 182

Question 4
[]
[11, 20, 21, 22, 22, 42, 44]

Question 5
[1, 2, 3]
[1, 2, 3]
True
False


## Question 6

You are asked to write a class `BST2` which implements a BST in which each node has a multiplicity counter (`mult`), which counts how many times the node's value is stored in the tree. This way, there is no need to store duplicate nodes in the tree: 
- adding a value that already exists in the tree simply amounts to increasing the counter of the value's node by 1; 
- removing a value from the tree amounts to reducing the counter of its node by 1, and if the counter becomes 0 then the node is removed altogether.

Below we have provided you with a class of nodes `BTNode2` to use, and we made a start in implementing `BST2`. You are asked to implement the following functions:

- `add(self,d)` for adding the value `d` to the BST2. This should use BST 
search and either increase the `mult` counter of the `BTNode2` containing `d` or, if `d` is not in the tree, create a new `BTNode2` for `d`.

- `search(self,d)` for searching the value `d` in the BST2. This should use BST 
search and return `True` if the value is found, and `False` otherwise.

- `count(self,d)` for counting the times the value `d` appears in the BST2. This 
should use BST search and return the number of times that the value appears in 
the BST2.

- `remove(self,d)` for removing one occurrence of the value `d` from the BST2.

In [None]:
class BTNode2:
    def __init__(self,d,l,r):
        self.data = d
        self.left = l
        self.right = r
        self.mult = 1
          
    # prints the node and all its children in a string
    def __str__(self):  
        st = str(self.data)+" ("+str(self.mult)+")-> ["
        if self.left != None:
            st += str(self.left)
        else: st += "None"
        if self.right != None:
            st += ", "+str(self.right)
        else: st += ", None"
        return st + "]"
    
class BST2:
    def __init__(self):
        self.root = None
        self.size = 0
        
    def __str__(self):
        if self.root == None: return "None"        
        return str(self.root)
        
print("Question 6")
t = BST2()
A = [22,20,11,21,42,11,22,44,1]
for x in A: t.add(x)
print(t)
for x in A:
    print(x,t.search(x),t.count(x),t.search(-x),t.count(-x))
for x in A:
    t.remove(x); print("take",x,":\n",t)
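
A minimal sketch of one way to fill in the four functions is given below. It is not the model solution. It assumes that `size` counts every stored value including repeats (the sheet does not say which), the helper `_find` is a hypothetical name introduced here, and the `__str__` methods from the skeleton are left out to keep the sketch short.

```python
class BTNode2:
    def __init__(self, d, l, r):
        self.data = d
        self.left = l
        self.right = r
        self.mult = 1

class BST2:
    def __init__(self):
        self.root = None
        self.size = 0   # assumption: total number of stored values, counting repeats

    def _find(self, d):
        # BST search returning (node, parent); node is None if d is absent,
        # in which case parent is the node under which d would be inserted
        parent, ptr = None, self.root
        while ptr is not None and ptr.data != d:
            parent = ptr
            ptr = ptr.left if d < ptr.data else ptr.right
        return ptr, parent

    def add(self, d):
        ptr, parent = self._find(d)
        self.size += 1
        if ptr is not None:
            ptr.mult += 1                      # value already present
        elif parent is None:
            self.root = BTNode2(d, None, None) # empty tree
        elif d < parent.data:
            parent.left = BTNode2(d, None, None)
        else:
            parent.right = BTNode2(d, None, None)

    def search(self, d):
        return self._find(d)[0] is not None

    def count(self, d):
        ptr, _ = self._find(d)
        return 0 if ptr is None else ptr.mult

    def remove(self, d):
        ptr, parent = self._find(d)
        if ptr is None:
            return
        self.size -= 1
        ptr.mult -= 1
        if ptr.mult > 0:
            return
        if ptr.left is not None and ptr.right is not None:
            # two children: copy the minimum of the right subtree into ptr,
            # then unlink that minimum node (it has no left child)
            parent, minNode = ptr, ptr.right
            while minNode.left is not None:
                parent, minNode = minNode, minNode.left
            ptr.data, ptr.mult = minNode.data, minNode.mult
            ptr = minNode
        # ptr now has at most one child, which takes its place
        child = ptr.left if ptr.left is not None else ptr.right
        if parent is None:
            self.root = child
        elif parent.left is ptr:
            parent.left = child
        else:
            parent.right = child

t = BST2()
for x in [22, 20, 11, 21, 42, 11, 22, 44, 1]:
    t.add(x)
print(t.count(22), t.count(11), t.count(5))  # 2 2 0
```

When a node with two children is removed, the multiplicity of the replacement value has to be copied along with its data; otherwise repeated values would be lost.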