# ECS529U Algorithms and Data Structures
# Lab sheet 7

This lab gets you to work with binary trees and binary search trees in particular.

**Marks (max 5):**  Question 1: 1.5 | Questions 2-4: 1 each | Question 5: 0.5

## Question 1

This question is about understanding binary search trees (BSTs).

a) Draw the binary search tree we obtain if we start from the empty tree and add 
consecutively the numbers:

    21, 40, 3, 16, 39, 58, 21, 46, 1, 10

b) Write down the numbers of the tree you constructed, starting from the root, first using 
depth-first search (in pre-order) and then using breadth-first search.
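
As a reminder of how the two traversal orders differ, here is a minimal sketch on the example tree used later in this sheet (not the tree from part a). The `Node` class and the function names `preorder` and `bfs` are assumptions made for illustration; they are not part of the lab's `BST` class.

```python
from collections import deque

class Node:
    # hypothetical minimal node, for illustration only
    def __init__(self, data, left=None, right=None):
        self.data = data
        self.left = left
        self.right = right

def preorder(ptr):
    # depth-first, pre-order: the node first, then its left subtree,
    # then its right subtree
    if ptr is None:
        return []
    return [ptr.data] + preorder(ptr.left) + preorder(ptr.right)

def bfs(root):
    # breadth-first: visit the nodes level by level using a queue
    out = []
    q = deque()
    if root is not None:
        q.append(root)
    while q:
        ptr = q.popleft()
        out.append(ptr.data)
        if ptr.left is not None:
            q.append(ptr.left)
        if ptr.right is not None:
            q.append(ptr.right)
    return out

#        22
#       /  \
#     20    42
#    /  \  /  \
#   11  21 22  44
root = Node(22, Node(20, Node(11), Node(21)), Node(42, Node(22), Node(44)))
print(preorder(root))  # [22, 20, 11, 21, 42, 22, 44]
print(bfs(root))       # [22, 20, 42, 11, 21, 22, 44]
```

Both functions visit every node exactly once; only the order of the visits changes.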

c) Let `t` point to the root node of the BST you constructed in part a. Draw the BST that
results by applying each of the following operations:

    1. t.left = t.left.right
    2. t.left.right.right = t.left.right.left

In each of these cases, is the resulting structure a binary tree? Is it a binary search 
tree?
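
One way to check your answers mechanically is sketched below, on a small tree that is deliberately different from the one in part a. The `Node` class and the helpers `is_binary_tree` and `is_bst` are hypothetical names introduced here. The BST check assumes that duplicates are stored to the right, as in the `add` method of the `BST` class used in this lab.

```python
class Node:
    # hypothetical minimal node, for illustration only
    def __init__(self, data, left=None, right=None):
        self.data = data
        self.left = left
        self.right = right

def is_binary_tree(root):
    # in a binary tree no node object is reachable along two different paths
    seen = set()
    stack = [root]
    while stack:
        ptr = stack.pop()
        if ptr is None:
            continue
        if id(ptr) in seen:
            return False
        seen.add(id(ptr))
        stack.append(ptr.left)
        stack.append(ptr.right)
    return True

def is_bst(ptr, lo=None, hi=None):
    # assumption: left subtree values < node value <= right subtree values
    # (call this only on structures that pass is_binary_tree)
    if ptr is None:
        return True
    if lo is not None and ptr.data < lo:
        return False
    if hi is not None and ptr.data >= hi:
        return False
    return is_bst(ptr.left, lo, ptr.data) and is_bst(ptr.right, ptr.data, hi)

root = Node(5, Node(3, Node(1), Node(4)), Node(8))
root.left = root.left.right                 # drops nodes 3 and 1, keeps 4
print(is_binary_tree(root), is_bst(root))   # True True

shared = Node(5, Node(3, Node(1), Node(4)), Node(8))
shared.right.left = shared.left.left        # node 1 now has two parents
print(is_binary_tree(shared))               # False
```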

d) Starting each time from the tree you constructed in part a, perform the following removals (using the algorithm we saw in the lectures) and draw the resulting trees:

1. remove the node with value 16
2. remove the node with value 40

The remaining questions ask you to work with the `BST` class and variants thereof. To help you visualise trees, we have implemented the following "pretty printing" function for `BTNode` objects:

In [93]:
def niceStr(self): # this goes in the BTNode class
    S = ["├","─","└","│"]
    angle = S[2]+S[1]+" "
    vdash = S[0]+S[1]+" "
        
    def niceRec(ptr,acc,pre):
        if ptr == None: return acc+pre+"None"
        if ptr.left==ptr.right==None: return acc+pre+str(ptr.data)
        if pre == vdash: pre2 = S[3]+"  "
        elif pre == angle: pre2 = "   "
        else: pre2 = ""
        left = niceRec(ptr.right,acc+pre2,vdash)
        right = niceRec(ptr.left,acc+pre2,angle)
        return acc+pre+str(ptr.data)+"\n"+left+"\n"+right
            
    return niceRec(self,"","")

For example, the following tree

        22
       /  \
      20   42
     / \   / \
    11 21 22 44

is converted into a string that prints as follows:

    22 
    ├─ 42
    │  ├─ 44
    │  └─ 22
    └─ 20
       ├─ 21
       └─ 11

## Question 2

Add to `BST` the following functions, assuming that we work with BSTs that store integers:

a) `def min(self)`

that returns the smallest element of the tree. If the tree is empty, the function should return `None`.

b) `def max(self)`

that returns the largest element of the tree. If the tree is empty, the function should return `None`.

c) `def removeAll(self, d)`

that removes all occurrences of the element `d` in the tree and returns the number of occurrences removed. For example, if the tree `t` is:

        22
       /  \
      20   42
     / \   / \
    11 21 22 44
    
then `t.removeAll(22)` should change `t` to:

        42
       /  \
      20   44
     / \  
    11 21 
    
and return `2`.

In [94]:
import random


class BTNode:
    def __init__(self, d,l,r):
        self.data = d
        self.left = l
        self.right = r
    
    def __str__(self):  
        st = str(self.data)+" -> ["
        if self.left != None:
            st += str(self.left)
        else: st += "None"
        if self.right != None:
            st += ", "+str(self.right)
        else: st += ", None"
        return st + "]"
    
    def niceStr(self): # this goes in the BTNode class
        S = ["├","─","└","│"]
        angle = S[2]+S[1]+" "
        vdash = S[0]+S[1]+" "
        
        def niceRec(ptr,acc,pre):
            if ptr == None: 
                return acc+pre+"None"
            if ptr.left==ptr.right==None:
                return acc+pre+str(ptr.data)
            if pre == vdash:
                pre2 = S[3]+"  "
            elif pre == angle: 
                pre2 = "   "
            else: 
                pre2 = ""
            left = niceRec(ptr.right,acc+pre2,vdash)
            right = niceRec(ptr.left,acc+pre2,angle)
            return acc+pre+str(ptr.data)+"\n"+left+"\n"+right
            
        return niceRec(self,"","")
    
    def updateChild(self, oldChild, newChild):
        if self.left == oldChild:
            self.left = newChild
        elif self.right == oldChild:
            self.right = newChild
        else: raise Exception("updateChild error")
    
class BST:
    def __init__(self):
        self.root = None
        self.size = 0

    def add(self, d):
        if self.root == None:
            self.root = BTNode(d,None,None)
        else:
            ptr = self.root
            while True:
                if d < ptr.data:
                    if ptr.left == None:
                        ptr.left = BTNode(d,None,None)
                        break
                    ptr = ptr.left
                else:
                    if ptr.right == None:
                        ptr.right = BTNode(d,None,None)
                        break
                    ptr = ptr.right
        self.size += 1
    
    def min(self):
        if self.root == None:
            return None
        
        ptr = self.root
        smallest = ptr.data
        while ptr is not None and ptr.left is not None:
            if smallest > ptr.left.data:
                smallest = ptr.left.data
            ptr = ptr.left
        
        return smallest
    
    def max(self):
        if self.root == None:
            return None
        
        ptr = self.root
        biggest = ptr.data
        while ptr is not None and ptr.right is not None:
            if biggest < ptr.right.data:
                biggest = ptr.right.data
            ptr = ptr.right
        
        return biggest
    
    def _removeNode(self, ptr, parentPtr):
        # there are 3 cases to consider:
        # 1. the node to be removed is a leaf (no children)
        if ptr.left == ptr.right == None:
            parentPtr.updateChild(ptr,None)
        # 2. the node to be removed has exactly one child            
        elif ptr.left == None:
            parentPtr.updateChild(ptr,ptr.right)
        elif ptr.right == None:
            parentPtr.updateChild(ptr,ptr.left)
        # 3. the node to be removed has both children
        else:
            # find the min node at the right of ptr -- and its parent
            parentMinRNode = ptr
            minRNode = ptr.right
            while minRNode.left != None:
                parentMinRNode = minRNode
                minRNode = minRNode.left
            # replace the data of ptr with that of the min node
            ptr.data = minRNode.data
            # bypass the min node
            parentMinRNode.updateChild(minRNode,minRNode.right)
        

    def _removeRoot(self):
        # this is essentially a hack: we are adding a dummy node at 
        # the root and call the previous method -- it allows us to
        # re-use code
        parentRoot = BTNode(None,self.root,None)
        self._removeNode(self.root,parentRoot)
        self.root = parentRoot.left

    def _searchNode(self, ptr, d):
        while ptr != None:
            if d == ptr.data:
                return ptr
            if d < ptr.data:
                ptr = ptr.left
            else:
                ptr = ptr.right
        return None
    
    def search(self, d):   
        ptr = self.root
        while ptr != None:
            if d == ptr.data:
                return True
            if d < ptr.data:
                ptr = ptr.left
            else:
                ptr = ptr.right
        return False
    
    def remove(self,d):
        if self.root == None: return
        if self.root.data == d: 
            self.size -= 1
            return self._removeRoot()
        parentPtr = None
        ptr = self.root
        while ptr != None and ptr.data != d:
            parentPtr = ptr                
            if d < ptr.data:
                ptr = ptr.left
            else:
                ptr = ptr.right
        if ptr != None:
            self.size -= 1
            self._removeNode(ptr,parentPtr)

    def removeAll(self, d):
        # remove d repeatedly until it no longer occurs in the tree,
        # and return how many occurrences were removed (0 if none)
        count = 0
        while self.search(d):
            self.remove(d)
            count += 1
        return count



#TESTING:
A = [None] * 10
for i in range(len(A)):
    A[i] = random.randint(1,100)

bTree = BST()
for j in range(len(A)):
    bTree.add(A[j])

rootNode = bTree.root

print("Array: ")
print(A) 
print("--- MIN FUNCTION ---")
print(rootNode.niceStr())
print(bTree.min())

print("--- MAX FUNCTION ---")
print(rootNode.niceStr())
print(bTree.max())

print("--- REMOVEALL FUNCTION ---")
A = [3,1,3,5,1,7,7,7,9,10]
bTree = BST()
for j in range(len(A)):
    bTree.add(A[j])

rootNode = bTree.root

print(rootNode.niceStr())
print("Removing :" + str(A[5]))
bTree.removeAll(A[5])
# re-read the root: removeAll may replace the root node of the tree
print(bTree.root.niceStr())




    

Array: 
[92, 76, 81, 16, 35, 77, 47, 88, 13, 66]
--- MIN FUNCTION ---
92
├─ None
└─ 76
   ├─ 81
   │  ├─ 88
   │  └─ 77
   └─ 16
      ├─ 35
      │  ├─ 47
      │  │  ├─ 66
      │  │  └─ None
      │  └─ None
      └─ 13
13
--- MAX FUNCTION ---
92
├─ None
└─ 76
   ├─ 81
   │  ├─ 88
   │  └─ 77
   └─ 16
      ├─ 35
      │  ├─ 47
      │  │  ├─ 66
      │  │  └─ None
      │  └─ None
      └─ 13
92
--- REMOVEALL FUNCTION ---
3
├─ 3
│  ├─ 5
│  │  ├─ 7
│  │  │  ├─ 7
│  │  │  │  ├─ 7
│  │  │  │  │  ├─ 9
│  │  │  │  │  │  ├─ 10
│  │  │  │  │  │  └─ None
│  │  │  │  │  └─ None
│  │  │  │  └─ None
│  │  │  └─ None
│  │  └─ None
│  └─ None
└─ 1
   ├─ 1
   └─ None
Removing :7
3
├─ 3
│  ├─ 5
│  │  ├─ 9
│  │  │  ├─ 10
│  │  │  └─ None
│  │  └─ None
│  └─ None
└─ 1
   ├─ 1
   └─ None


## Question 3

Add to `BST` the following functions, assuming that we work with BSTs that store integers:

a) `def _sumAllRec(self, ptr)`

that <u>uses recursion</u> and returns the sum of all the elements of the subtree starting from the node `ptr`.

_Hint:_ you can simply use depth-first search, ignoring the fact that this is a BST 
rather than a simple binary tree.

b) `def sumAll(self)`

that sums all the elements of the tree (use the function from part a).

c) `def sumAllBFS(self)`

that sums all the elements of the tree using breadth-first search.

_Hint:_ you can adapt the code for breadth-first search that we saw in the lecture 
(week 6). You will need to use a queue (see lecture of week 5).
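
The two approaches can be sketched side by side as follows. This is a minimal illustration, not the lecture code: the `Node` class and the functions `sum_rec` and `sum_bfs` are hypothetical names, and `collections.deque` from the standard library stands in for the queue class from the week 5 lecture.

```python
from collections import deque

class Node:
    # hypothetical minimal node, for illustration only
    def __init__(self, data, left=None, right=None):
        self.data = data
        self.left = left
        self.right = right

def sum_rec(ptr):
    # depth-first recursion: an empty subtree contributes 0
    if ptr is None:
        return 0
    return ptr.data + sum_rec(ptr.left) + sum_rec(ptr.right)

def sum_bfs(root):
    # breadth-first: dequeue a node, add its value, enqueue its children
    total = 0
    q = deque()
    if root is not None:
        q.append(root)
    while q:
        ptr = q.popleft()
        total += ptr.data
        for child in (ptr.left, ptr.right):
            if child is not None:
                q.append(child)
    return total

root = Node(22, Node(20, Node(11), Node(21)), Node(42, Node(22), Node(44)))
print(sum_rec(root), sum_bfs(root))  # 182 182
```

Neither function compares values, so both work on any binary tree, not only on BSTs.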

In [100]:
class Node:
    def __init__(self, d, n):
        self.data = d
        self.next = n

class Stack:
    def __init__(self):
        self.inList = LinkedList()

    def __str__(self):
        return str(self.inList)
        
    def size(self):
        return self.inList.length

    def push(self, e):
        self.inList.insert(0,e)

    def pop(self):
        return self.inList.remove(0)

class LinkedList:
    def __init__(self):
        self.head = None
        self.length = 0

    def __str__(self):
        st = ""
        ptr = self.head
        while ptr != None:
            st = st + str(ptr.data)
            st = st+" -> "
            ptr = ptr.next
        return st+"None"
        
    def append(self, d):
        if self.head == None:      
            self.head = Node(d,None) 
        else:
            ptr = self.head
            while ptr.next != None:
                ptr = ptr.next
            ptr.next = Node(d,None)
        self.length += 1
    
    def insert(self, i, d):
        if self.head == None or i == 0:
            self.head = Node(d,self.head)
        else:
            ptr = self.head
            while i>1 and ptr.next != None:
                ptr = ptr.next
                i -= 1
            ptr.next = Node(d,ptr.next)
        self.length += 1

    def remove(self, i): # removes i-th element and returns it
        if self.head == None:
            return None
        if i == 0:
            val = self.head.data
            self.head = self.head.next
            self.length -= 1
            return val
        else:
            ptr = self.head
            while i>1 and ptr.next != None:
                ptr = ptr.next
                i -= 1
            if i == 1:
                val = ptr.next.data
                ptr.next = ptr.next.next
                self.length -= 1
                return val
            return None
        
class Queue:
    def __init__(self):
        self.inList = LinkedList()

    def __str__(self):
        return str(self.inList)
        
    def size(self):
        return self.inList.length

    def enq(self, e):
        self.inList.append(e)

    def deq(self):
        return self.inList.remove(0)
    

class BSTQ3(BST):
    def _sumAllRec(self, ptr):
        # depth-first recursion: an empty subtree contributes 0, otherwise
        # add the node's value to the sums of its two subtrees
        if ptr is None:
            return 0
        return ptr.data + self._sumAllRec(ptr.left) + self._sumAllRec(ptr.right)

    def sumAll(self):
        return self._sumAllRec(self.root)
    

    def sumAllBFS(self):
        # iterative breadth-first traversal: a loop over the queue avoids
        # one recursive call per node
        if self.root is None:
            return 0

        q = Queue()
        q.enq(self.root)
        total = 0
        while q.size() > 0:
            ptr = q.deq()
            total += ptr.data
            if ptr.left is not None:
                q.enq(ptr.left)
            if ptr.right is not None:
                q.enq(ptr.right)

        return total

print("--- RECURSION-SUM FUNCTION ---")
print("Part 1")
A = [3,1,3,5,1,7,7,7,9,10]
targetSum = 0
for i in A:
    targetSum += i
print("Target: " + str(targetSum))
bTree = BSTQ3()
for j in range(len(A)):
    bTree.add(A[j])

rootNode = bTree.root

print(rootNode.niceStr())
print("Sum is:" + str(bTree._sumAllRec(rootNode.left)))

print("Part 2")
print("Sum is:" + str(bTree.sumAll()))

print("Part 3 (BFS)")
print("Sum is:" + str(bTree.sumAllBFS()))


--- RECURSION-SUM FUNCTION ---
Part 1
Target: 53
3
├─ 3
│  ├─ 5
│  │  ├─ 7
│  │  │  ├─ 7
│  │  │  │  ├─ 7
│  │  │  │  │  ├─ 9
│  │  │  │  │  │  ├─ 10
│  │  │  │  │  │  └─ None
│  │  │  │  │  └─ None
│  │  │  │  └─ None
│  │  │  └─ None
│  │  └─ None
│  └─ None
└─ 1
   ├─ 1
   └─ None
Sum is:2
Part 2
Sum is:53
Part 3 (BFS)
Sum is:53


## Question 4

Add to `BST` a function 

    def toSortedArray(self)

that returns an array containing the elements of the tree in ascending order.

_Hint:_ Use a helper function to do an inorder traversal of the BST.
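
The idea behind the hint can be sketched as follows: an inorder traversal visits the left subtree, then the node, then the right subtree, so on a BST the values come out in ascending order. The `Node` class and the names `inorder` and `to_sorted_list` are assumptions made for this illustration.

```python
class Node:
    # hypothetical minimal node, for illustration only
    def __init__(self, data, left=None, right=None):
        self.data = data
        self.left = left
        self.right = right

def inorder(ptr, out):
    # helper: append the values of the subtree at ptr to out, in order
    if ptr is None:
        return
    inorder(ptr.left, out)
    out.append(ptr.data)
    inorder(ptr.right, out)

def to_sorted_list(root):
    out = []
    inorder(root, out)
    return out

root = Node(22, Node(20, Node(11), Node(21)), Node(42, Node(22), Node(44)))
print(to_sorted_list(root))  # [11, 20, 21, 22, 22, 42, 44]
```

Passing the output list down to the helper avoids building and concatenating a new list at every node.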

In [114]:
class BSTQ4(BSTQ3):
    def toSortedArray(self):
        if self.root is None:
            return []
        
        A = [None] * self.size
        counter = 0
        stack = Stack()
        ptr = self.root
        while ptr is not None or stack.size() > 0:

            while ptr is not None:
                stack.push(ptr)
                ptr = ptr.left
            
            ptr = stack.pop()
            A[counter] = ptr.data
            counter += 1

            ptr = ptr.right
        
        return A
            

# Minimal testing Questions 2-4
print("Question 2")
t = BSTQ4()
print(str(t.root))
print(t.min(),t.max())
A = [22,20,11,21,42,22,44,1]
for x in A: t.add(x)
print(t.root.niceStr())
print(t.min(),t.max())
t.removeAll(22)
print(t.root.niceStr())

print("\nQuestion 3")
t = BSTQ4()
print(str(t.root))
print(t.sumAll(), t.sumAllBFS())
A = [22,20,11,21,42,22,44]
for x in A: t.add(x)
print(t.root.niceStr())
print(t.sumAll(), t.sumAllBFS())

print("\nQuestion 4")
print(BSTQ4().toSortedArray())
print(t.toSortedArray())

Question 2
None
None None
22
├─ 42
│  ├─ 44
│  └─ 22
└─ 20
   ├─ 21
   └─ 11
      ├─ None
      └─ 1
1 44
42
├─ 44
└─ 20
   ├─ 21
   └─ 11
      ├─ None
      └─ 1

Question 3
None
0 0
22
├─ 42
│  ├─ 44
│  └─ 22
└─ 20
   ├─ 21
   └─ 11
182 182

Question 4
[]
[11, 20, 21, 22, 22, 42, 44]


## Question 5

You are asked to write a class `BST2` which implements a BST in which each node has a multiplicity counter (`mult`), which counts how many times the node's value is stored in the tree. This way, there is no need to store duplicate nodes in the tree: 
- adding a value that already exists in the tree simply amounts to increasing the counter of the value's node by 1; 
- removing a value from the tree amounts to reducing the counter of its node by 1, and if the counter becomes 0 then the node is removed altogether.

Below we have provided you with a class of nodes `BTNode2` to use, and we made a start in implementing `BST2`. You are asked to implement the following functions:

- `add(self,d)` for adding the value `d` to the BST2. This should use BST 
search and either increase the `mult` counter of the `BTNode2` containing `d` or, if `d` is not in the tree, create a new `BTNode2` for `d`.

- `search(self,d)` for searching for the value `d` in the BST2. This should use BST 
search and return `True` if the value is found, and `False` otherwise.

- `count(self,d)` for counting the times the value `d` appears in the BST2. This 
should use BST search and return the number of times that the value appears in 
the BST2.

- `remove(self,d)` for removing one occurrence of the value `d` from the BST2.
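
The behaviour of the multiplicity counter can be sketched independently of the classes below. In this minimal illustration `MNode`, `add` and `count` are hypothetical free-standing names rather than the required methods of `BST2`; `add` is written recursively and returns the root of the updated subtree.

```python
class MNode:
    # hypothetical node with a multiplicity counter, for illustration only
    def __init__(self, data):
        self.data = data
        self.left = None
        self.right = None
        self.mult = 1

def add(ptr, d):
    # returns the root of the subtree at ptr after adding d
    if ptr is None:
        return MNode(d)
    if d < ptr.data:
        ptr.left = add(ptr.left, d)
    elif d > ptr.data:
        ptr.right = add(ptr.right, d)
    else:
        ptr.mult += 1  # value already present: only bump the counter
    return ptr

def count(ptr, d):
    # ordinary BST search; the answer is the counter of the node found
    while ptr is not None:
        if d == ptr.data:
            return ptr.mult
        ptr = ptr.left if d < ptr.data else ptr.right
    return 0

root = None
for x in [22, 20, 11, 21, 42, 11, 22, 44, 1]:
    root = add(root, x)
print(count(root, 22), count(root, 11), count(root, 44), count(root, 5))  # 2 2 1 0
```

Because equal values never create a second node, every value appears in at most one node, and a search can stop at the first match.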

In [147]:
class BTNode2:
    def __init__(self,d,l,r):
        self.data = d
        self.left = l
        self.right = r
        self.mult = 1

    def updateChild(self, oldChild, newChild):
        if self.left == oldChild:
            self.left = newChild
        elif self.right == oldChild:
            self.right = newChild
        else: raise Exception("updateChild error")


    def niceStr(self): # pretty printer; each value is followed by its multiplicity in braces
        S = ["├","─","└","│"]
        angle = S[2]+S[1]+" "
        vdash = S[0]+S[1]+" "
        
        def niceRec(ptr,acc,pre):
            if ptr == None: 
                return acc+pre+"None"
            if ptr.left==ptr.right==None:
                return acc+pre+str(ptr.data)+"{"+str(ptr.mult)+"}"
            if pre == vdash:
                pre2 = S[3]+"  "
            elif pre == angle: 
                pre2 = "   "
            else: 
                pre2 = ""
            left = niceRec(ptr.right,acc+pre2,vdash)
            right = niceRec(ptr.left,acc+pre2,angle)
            return acc+pre+str(ptr.data)+"{"+str(ptr.mult)+"}"+"\n"+left+"\n"+right
            
        return niceRec(self,"","")

    # prints the node and all its children in a string
    def __str__(self):  
        st = str(self.data)+" ("+str(self.mult)+")-> ["
        if self.left != None:
            st += str(self.left)
        else: st += "None"
        if self.right != None:
            st += ", "+str(self.right)
        else: st += ", None"
        return st + "]"


class BST2(BSTQ4):
    def add(self, d):
        if self.root == None:
            self.root = BTNode2(d,None,None)
        else:
            ptr = self.root
            while True:
                if d < ptr.data:
                    if ptr.left == None:
                        ptr.left = BTNode2(d,None,None)
                        break
                    ptr = ptr.left
                elif d > ptr.data:
                    if ptr.right == None:
                        ptr.right = BTNode2(d,None,None)
                        break
                    ptr = ptr.right
                else:
                    ptr.mult += 1
                    return
        self.size += 1
    
    def remove(self, d):
        # removes one occurrence of d; the node itself is only unlinked
        # once its multiplicity drops to 0
        ptr = self.root
        parent = None
        while ptr is not None:
            if d < ptr.data:
                parent, ptr = ptr, ptr.left
            elif d > ptr.data:
                parent, ptr = ptr, ptr.right
            else:
                ptr.mult -= 1
                if ptr.mult <= 0:
                    # _removeNodeRec already decrements self.size, so the
                    # size must not be decremented a second time here
                    newChild = self._removeNodeRec(ptr)
                    if parent is None:
                        self.root = newChild
                    else:
                        parent.updateChild(ptr, newChild)
                return
    
    def _removeRec(self, ptr, d):
        if ptr == None: return None
        if ptr.data == d: 
            return self._removeNodeRec(ptr)
        if ptr.data < d:
            ptr.right = self._removeRec(ptr.right, d)
        else:
            ptr.left = self._removeRec(ptr.left, d)
        return ptr
    
    # removes the node ptr from the tree and returns the remaining tree
    def _removeNodeRec(self, ptr):
        self.size -= 1
        # there are 3 cases to consider:
        # 1. the node to be removed is a leaf (no children)
        if ptr.left == ptr.right == None:
            return None
        # 2. the node to be removed has exactly one child
        elif ptr.right == None:
            return ptr.left
        elif ptr.left == None:
            return ptr.right
        # 3. the node to be removed has both children
        else:
            parentMinRNode = ptr
            minRNode = ptr.right
            
            while minRNode.left != None:
                parentMinRNode = minRNode
                minRNode = minRNode.left
            # move the value and multiplicity of the min node into ptr
            ptr.data = minRNode.data
            ptr.mult = minRNode.mult
            # bypass the min node; keeping it would leave the removed value
            # in the right subtree and break the BST ordering
            parentMinRNode.updateChild(minRNode, minRNode.right)
            return ptr
    
    def count(self, d):
        ptr = self.root
        ptr = self._searchNode(ptr,d)
        return ptr.mult if ptr else 0
    

    def _searchNodeRec(self, ptr, d):
        if ptr != None:
            if d == ptr.data:
                return ptr
            if d < ptr.data:
                return self._searchNodeRec(ptr.left, d)
            else:
                return self._searchNodeRec(ptr.right, d)
        return None
    
print("Question 5")
t = BST2()
A = [22,20,11,21,42,11,22,44,1]
for x in A: t.add(x)
print(t.root.niceStr())
for x in A:
    print(x,t.search(x),t.count(x),t.search(-x),t.count(-x))
for x in A:
    t.remove(x)
    if t.root is not None:
        print("take",x,":\n",t.root.niceStr())
    else:
        print("take",x,":\n",str(t.root))

Question 5
22{2}
├─ 42{1}
│  ├─ 44{1}
│  └─ None
└─ 20{1}
   ├─ 21{1}
   └─ 11{2}
      ├─ None
      └─ 1{1}
22 True 2 False 0
20 True 1 False 0
11 True 2 False 0
21 True 1 False 0
42 True 1 False 0
11 True 2 False 0
22 True 2 False 0
44 True 1 False 0
1 True 1 False 0
take 22 :
 22{1}
├─ 42{1}
│  ├─ 44{1}
│  └─ None
└─ 20{1}
   ├─ 21{1}
   └─ 11{2}
      ├─ None
      └─ 1{1}
take 20 :
 22{1}
├─ 42{1}
│  ├─ 44{1}
│  └─ None
└─ 21{1}
   ├─ None
   └─ 11{2}
      ├─ None
      └─ 1{1}
take 11 :
 22{1}
├─ 42{1}
│  ├─ 44{1}
│  └─ None
└─ 21{1}
   ├─ None
   └─ 11{1}
      ├─ None
      └─ 1{1}
take 21 :
 22{1}
├─ 42{1}
│  ├─ 44{1}
│  └─ None
└─ 11{1}
   ├─ None
   └─ 1{1}
take 42 :
 22{1}
├─ 44{1}
└─ 11{1}
   ├─ None
   └─ 1{1}
take 11 :
 22{1}
├─ 44{1}
└─ 1{1}
take 22 :
 44{1}
├─ None
└─ 1{1}
take 44 :
 1{1}
take 1 :
 None
