### Simple BST with per-node locks

The idea is to give each node its own lock: grab a node's lock before reading or writing it, and release it when finished. The code below may still have problems, because concurrent programs are very hard to test, but I believe it gives correct results for insertion. For deletion, I cannot guarantee its correctness.

You should probably also implement an iterator. I'm not yet sure how to do that, but it should work similarly to `get()` in the code below.

In [22]:
from threading import Thread
from threading import Lock

In [41]:
class TreeNode:
    """Binary-search-tree node that carries its own mutex.

    ``Tree`` reads and writes ``key``, ``val``, ``left`` and ``right`` only
    while holding this node's ``lock``.
    """

    def __init__(self, key, val, left=None, right=None, parent=None):
        self.key = key
        self.val = val
        self.left = left
        self.right = right
        # Back-link to the parent node.  Tree keeps it up to date but never
        # uses it for traversal (traversals remember the parent themselves).
        self.parent = parent
        # Per-node lock.  Tree only ever acquires these top-down (parent
        # before child), which is what rules out deadlock.
        self.lock = Lock()


class Tree:
    """Thread-safe binary search tree using hand-over-hand (lock-coupling) locking.

    Every operation walks down from the root and always acquires a child's
    lock *before* releasing its parent's, so no thread can step onto a node
    while another thread is halfway through relinking or rewriting it.
    Locks are only ever taken in root-to-leaf order, so operations cannot
    deadlock.  ``_root_lock`` plays the role of the "parent lock" for the
    ``root`` pointer itself.
    """

    def __init__(self, root=None):
        """Create a tree, optionally seeded with an existing root node."""
        self.root = root
        # Guards the ``root`` attribute (it changes on the first insert and
        # whenever the root node is spliced out by delete()).
        self._root_lock = Lock()

    def set(self, key, val):
        """Insert ``key`` with value ``val``, or overwrite the value if ``key`` exists."""
        held = self._root_lock
        held.acquire()
        try:
            node = self.root
            if node is None:
                self.root = TreeNode(key, val)
                return
            node.lock.acquire()
            held.release()
            held = node.lock
            while key != node.key:
                go_left = key < node.key
                child = node.left if go_left else node.right
                if child is None:
                    # Attach while still holding node.lock so no other
                    # writer can claim the same empty slot.
                    child = TreeNode(key, val, parent=node)
                    if go_left:
                        node.left = child
                    else:
                        node.right = child
                    return
                # Lock coupling: take the child before letting go of node.
                child.lock.acquire()
                held.release()
                held = child.lock
                node = child
            node.val = val
        finally:
            held.release()

    def delete(self, key):
        """Remove ``key`` from the tree; do nothing if it is absent."""
        # Locks currently held, ordered from the root downwards.
        held = [self._root_lock]
        self._root_lock.acquire()
        try:
            parent = None
            node = self.root
            if node is None:
                return
            node.lock.acquire()
            held.append(node.lock)
            # Invariant: we hold the lock guarding the link to ``node``
            # (``_root_lock`` or ``parent.lock``) plus ``node.lock``.
            while key != node.key:
                child = node.left if key < node.key else node.right
                if child is None:
                    return  # key absent
                child.lock.acquire()
                held.append(child.lock)
                held.pop(0).release()
                parent, node = node, child

            if node.left is None or node.right is None:
                # Zero or one child: splice ``node`` out of its parent link.
                replacement = node.left if node.left is not None else node.right
                if parent is None:
                    self.root = replacement
                elif parent.left is node:
                    parent.left = replacement
                else:
                    parent.right = replacement
                if replacement is not None:
                    replacement.parent = parent
                return

            # Two children: overwrite ``node`` with its in-order successor
            # (leftmost node of the right subtree), then unlink the successor.
            # ``node`` itself stays in place, so the lock on the link above it
            # is no longer needed.
            held.pop(0).release()
            succ_parent, succ = node, node.right
            succ.lock.acquire()
            held.append(succ.lock)
            while succ.left is not None:
                nxt = succ.left
                nxt.lock.acquire()
                held.append(nxt.lock)
                if succ_parent is not node:
                    # Keep node.lock for the whole operation; drop only the
                    # lock that is now two levels above ``nxt``.
                    held.remove(succ_parent.lock)
                    succ_parent.lock.release()
                succ_parent, succ = succ, nxt
            node.key, node.val = succ.key, succ.val
            # The successor has no left child; lift its right subtree up.
            if succ_parent is node:
                node.right = succ.right
            else:
                succ_parent.left = succ.right
            if succ.right is not None:
                succ.right.parent = succ_parent
        finally:
            for lock in reversed(held):
                lock.release()

    def get(self, key):
        """Return the value stored under ``key``.

        Raises:
            KeyError: if ``key`` is not in the tree.
        """
        held = self._root_lock
        held.acquire()
        try:
            node = self.root
            while node is not None:
                # Lock coupling: take the node before releasing the lock
                # that guarded the link we followed to reach it.
                node.lock.acquire()
                held.release()
                held = node.lock
                if key == node.key:
                    return node.val
                node = node.left if key < node.key else node.right
            raise KeyError('No such key')
        finally:
            held.release()

    

In [42]:
# Seed the shared tree with a single root node (key 0, value 0).
root = TreeNode(0, 0)
tree = Tree(root=root)

In [43]:
def write(tree, key, val):
    """Thread target: insert or overwrite ``key`` -> ``val`` via ``tree.set``."""
    setter = tree.set
    setter(key, val)
    
def read(tree, key):
    """Thread target: look ``key`` up via ``tree.get`` and print the key and its value.

    Any exception raised by ``tree.get`` propagates unchanged.
    """
    found = tree.get(key)
    print(key, found)
    
def delete(tree, key):
    """Thread target: remove ``key`` via ``tree.delete``."""
    remover = tree.delete
    remover(key)

In [44]:
# Insert keys 1..4 concurrently, one writer thread per key.
writers = [Thread(target=write, args=(tree, k, k)) for k in (1, 2, 3, 4)]
for w in writers:
    w.start()
# Wait for every insert to finish so the next cell's reads cannot race them.
for w in writers:
    w.join()

In [45]:
# Read keys 1..4 concurrently, one reader thread per key.
readers = [Thread(target=read, args=(tree, k)) for k in (1, 2, 3, 4)]
for r in readers:
    r.start()
# Wait so all printed output belongs to this cell; the print order itself
# is still up to the thread scheduler.
for r in readers:
    r.join()

(1, 1)
(2, 2)
(3, 3)
(4, 4)
