In [217]:
# Path ORAM 

from collections import defaultdict, Counter
from turtle import update
import random

class Node(object):
    '''A binary-tree node used as one ORAM bucket.

    val   -- the bucket contents (in this notebook, a list of (data, blockID) pairs)
    idx   -- the node id
    left  -- left child Node, or None
    right -- right child Node, or None
    '''

    def __init__(self, ary, idx, left, right):
        self.val, self.idx, self.left, self.right = ary, idx, left, right

    def __repr__(self):
        # Recursive repr: (idx,val,left,right) with children rendered the same way.
        return "({},{},{},{})".format(self.idx, self.val, self.left, self.right)


In [218]:
def getNewNode(val, idx,):
    '''Create a childless Node holding bucket contents `val` and node id `idx`.'''
    return Node(val, idx, left=None, right=None)

In [219]:
def buildPBT_helper(start, end, pre):
    '''
    Recursively build a binary tree from pre[start..end] (inclusive), read as a
    preorder listing of node ids.

    pre[start] becomes the root. Of the k = end - start remaining elements, the
    first k // 2 form the left subtree and the rest form the right subtree.
    Every node starts with an empty bucket ([]).
    Output: the root Node, or None when the range is empty.
    '''
    if start > end:
        return None
    node = getNewNode([], pre[start])

    remaining = end - start              # elements after the root
    if remaining == 0:
        return node

    split = start + 1 + remaining // 2   # first index of the right subtree
    node.left = buildPBT_helper(start + 1, split - 1, pre)
    node.right = buildPBT_helper(split, end, pre)
    return node


In [220]:
# Construct a simple binary tree; node ids are assigned in preorder.
# Conceptually, each bucket (node value) should point to a block of size B.
#    0    l0
#   /\
#  1  4   l1
#  /\ /\
# 2 3 5 6 l2

def buildPBT(pre, size):
    '''Build the tree from the first `size` entries of `pre` (a preorder listing) and return its root.'''
    last = size - 1
    return buildPBT_helper(0, last, pre)

rt = buildPBT(pre=range(7), size=7)
rt

(0,[],(1,[],(2,[],None,None),(3,[],None,None)),(4,[],(5,[],None,None),(6,[],None,None)))

In [221]:
def getHeight(root):
    '''
    Count the nodes on the longest root-to-leaf path of a binary tree.

    Assumption: the given tree is a binary tree (nodes expose .left / .right)
    Input: the root node of a tree (None for an empty tree)
    Output: the height of the tree (0 for an empty tree, 1 for a single node)
    '''
    return 0 if not root else 1 + max(getHeight(root.left), getHeight(root.right))

In [222]:
# Global set up

# N: number of data blocks; block IDs run from 1 to N
N = 28

# Z: number of blocks that fit in one bucket
Z = 4

# Stash of (data, blockID) pairs that currently do not fit in the tree.
# A `global` statement at module level is a no-op, so none is used here.
StashInit = []
LEVELS = getHeight(rt)

random.seed(11)
# NOTE(review): this draws from every node id 0 .. 2**LEVELS - 2, i.e. internal
# nodes as well as leaves, whereas access() remaps blocks to leaves only.
# Confirm whether internal-node positions are intended.
# Being a defaultdict(int), an unknown block ID reads as position 0.
position = defaultdict(int, {k: random.randrange(0, (pow(2, LEVELS) - 1)) for k in range(1, N + 1)})

# Invert the position map: node id -> list of block IDs mapped to that node.
sorted_pos = defaultdict(list)
for key, val in position.items():
    sorted_pos[val].append(key)
    

In [223]:
print(position, sorted_pos, sep="\n")

defaultdict(<class 'int'>, {1: 3, 2: 6, 3: 4, 4: 6, 5: 6, 6: 3, 7: 3, 8: 4, 9: 6, 10: 4, 11: 1, 12: 1, 13: 6, 14: 4, 15: 3, 16: 5, 17: 4, 18: 6, 19: 1, 20: 0, 21: 3, 22: 2, 23: 1, 24: 0, 25: 4, 26: 6, 27: 5, 28: 5})
defaultdict(<class 'list'>, {3: [1, 6, 7, 15, 21], 6: [2, 4, 5, 9, 13, 18, 26], 4: [3, 8, 10, 14, 17, 25], 1: [11, 12, 19, 23], 5: [16, 27, 28], 0: [20, 24], 2: [22]})


In [224]:
def concat_STASH(stsh, anotherL):
    '''
    Append every item of anotherL to the end of stsh, in place.

    Input: the stash list and an iterable of items to append
    Output: None (stsh is mutated)
    '''
    # list.extend snapshots a list argument before copying, so
    # concat_STASH(s, s) doubles s instead of looping forever as the
    # element-by-element append (iterating a list while growing it) did.
    stsh.extend(anotherL)
    return None

In [225]:
def initialize_tree(root, stsh):
    '''
    Fill each bucket with the blocks whose position equals the node's idx.

    Every block starts with empty data "". At most Z blocks are stored in a
    bucket; the overflow is appended to stsh. Nodes are visited in preorder
    (node, left subtree, right subtree), which fixes the order of stash entries.
    Output: None
    '''
    if root is None:
        return None

    blocks = [("", blockId) for blockId in sorted_pos[root.idx]]
    concat_STASH(stsh, blocks[Z:])
    root.val = blocks[:Z]

    for child in (root.left, root.right):
        initialize_tree(child, stsh)
    return None

initialize_tree(rt, stsh=StashInit)

In [226]:
def getPath(root, NodeId):
    '''
    Collect the nodes from the node with id NodeId up to the root.

    Input: root node, and the id of the target node (NodeId)
    Output: list of (idx, val) pairs ordered target first, root last;
            [] if NodeId does not occur in the tree.
    '''
    if not root:
        return []

    here = (root.idx, root.val)
    if root.idx == NodeId:
        return [here]

    # The right subtree is only searched when the left one does not contain the target.
    below = getPath(root.left, NodeId) or getPath(root.right, NodeId)
    return below + [here] if below else []

In [227]:
print(getPath(root=rt, NodeId=6))

[(6, [('', 2), ('', 4), ('', 5), ('', 9)]), (4, [('', 3), ('', 8), ('', 10), ('', 14)]), (0, [('', 20), ('', 24)])]


In [228]:
def getDataInBlock(root, blockId):
    '''
    Get the data associated with a block ID in a given tree.

    Buckets are searched in preorder. Within a bucket the first pair whose
    second element equals blockId wins.
    Input: root of a tree and a block ID
    Output: the data of that pair, or None if no bucket holds the block.
            A subtree result of None (including a match whose data is None)
            is treated as "not found", and the search moves on.
    '''
    if root is None:
        return None

    for elem in root.val:
        if elem[1] == blockId:
            return elem[0]

    for child in (root.left, root.right):
        found = getDataInBlock(child, blockId)
        if found is not None:
            return found
    return None

getDataInBlock(rt, blockId=16)

''

In [229]:
def getNodeAtLevel(root, leafidx, level):
    '''
    Input: the root of a tree, the leaf node index, a specific level (0 = root)
    Output: the (idx, val) pair of the node at that level on the path from the
            root to leafidx, or None when the path has no node at that level
    '''
    path = getPath(root, leafidx)
    # getPath lists the target first and the root last, so count levels from the end.
    deepest = len(path) - 1
    if level > deepest:
        return None
    return path[deepest - level]


In [230]:
def ReadBucket(root, leaf, stsh):
    '''
    Read all of the (data, blockID) pairs along the path from the root down to
    a leaf node, root bucket first, into the stash.

    Input: the root of a tree, a leaf node id, and the stash list to append to
    Output: the same stash list, extended in place. The tree is not modified.
    '''
    for level in range(LEVELS):
        node = getNodeAtLevel(root, leaf, level)
        if node is None:
            continue
        _, bucket = node
        for pair in bucket:
            stsh.append(pair)
    return stsh


In [231]:
def getCandidateBlocksHelper(root, leaf, blockID, level):
    '''
    Decide whether block `blockID` may live in the bucket at `level` on the
    path to `leaf`, i.e. whether that bucket also lies on the path to the
    block's mapped position `position[blockID]` (global position map).

    Output: (data, blockID) when both paths share the bucket at `level`, where
            data is looked up in the tree via getDataInBlock; otherwise None.
    '''
    onAccessPath = getNodeAtLevel(root, leaf, level)
    onBlockPath = getNodeAtLevel(root, position[blockID], level)
    if onAccessPath is None or onBlockPath is None:
        return None
    if onAccessPath != onBlockPath:
        return None
    return (getDataInBlock(root, blockID), blockID)

In [232]:
def testing(leaf, blockid, levels=None):
    '''
    Print getCandidateBlocksHelper(rt, leaf, blockid, level) for each level.

    Input: a leaf id, a block ID, and optionally how many levels to check
           (defaults to LEVELS, the height of the tree, instead of a
           hard-coded 3 that silently breaks for other tree sizes)
    '''
    if levels is None:
        levels = LEVELS
    for i in range(levels):
        print("Testing at Level {}".format(i))
        print(getCandidateBlocksHelper(rt, leaf, blockid, i))

In [233]:
def testall():
    '''Run testing() for every (mapped node, block ID) pair in sorted_pos, with banners.'''
    leafBanner = "---------------{}-------------"
    blockBanner = "=========={}==========="
    for leaf in sorted_pos:
        print(leafBanner.format(leaf))
        for bid in sorted_pos[leaf]:
            print(blockBanner.format(bid))
            testing(leaf, bid)

In [234]:
# sorted_pos snapshot for reference:
# {3: [1, 6, 7, 15, 21], 6: [2, 4, 5, 9, 13, 18, 26], 4: [3, 8, 10, 14, 17, 25], 1: [11, 12, 19, 23], 5: [16, 27, 28], 0: [20, 24], 2: [22]}
# Spot-check block 2 against the path to node 3 at level 1.
print(getCandidateBlocksHelper(rt, 3, 2, level=1))

None


In [235]:
def getCandidateBlocks(root, leaf, level, stsh):
    '''
    Select the stash entries that may be written into the bucket at `level`
    on the path to `leaf`.

    An entry qualifies when getCandidateBlocksHelper reports that the block's
    own mapped path passes through the same bucket.
    Output: the qualifying (data, blockID) pairs exactly as they appear in the
            stash, in stash order.
    '''
    # Return the stash's own pairs rather than the helper's (data-from-tree, id)
    # tuple: the stash holds the freshest data (update_STASH writes there), and
    # pop_STASH needs the exact stash pairs to remove them.
    return [elem for elem in stsh
            if getCandidateBlocksHelper(root, leaf, elem[1], level) is not None]

In [236]:
def readBlockFromStsh(stsh, blockID):
    '''Return the data of the first stash pair whose block ID equals blockID, or None if absent.'''
    if stsh is None:
        return None
    return next((entry[0] for entry in stsh if entry[1] == blockID), None)

In [237]:
def writeBackNodes(root, index, currlevel, tgtlevel, data):
    '''
    Replace the bucket of the node with id `index` located at depth `tgtlevel`.

    Input: subtree root, target node id, depth of `root` (0 for the tree root),
           target depth, and the new bucket contents
    Output: the node's previous bucket contents (possibly []), or None if no
            node with that id exists at that depth
    '''
    # Depth only grows while descending, so nothing below tgtlevel can match.
    if root is None or currlevel > tgtlevel:
        return None

    if root.idx == index and currlevel == tgtlevel:
        previous = root.val
        root.val = data
        return previous

    nextlevel = currlevel + 1
    for child in (root.left, root.right):
        found = writeBackNodes(child, index, nextlevel, tgtlevel, data)
        # Compare against None, not truthiness: an empty previous bucket ([])
        # still means the node was found and written.
        if found is not None:
            return found
    return None

In [238]:
def update_STASH(blockID, dataN, stsh):
    '''
    Set the data of blockID in the stash to dataN, in place.

    Every stash entry for blockID is rewritten, so readBlockFromStsh (which
    returns the first match) sees the new data. If the stash holds no entry
    for blockID, (dataN, blockID) is appended. The original -1 sentinel
    silently overwrote the last stash entry in that case, or raised
    IndexError on an empty stash.
    Output: None
    '''
    found = False
    for pos, pair in enumerate(stsh):
        if pair[1] == blockID:
            stsh[pos] = (dataN, blockID)
            found = True
    if not found:
        stsh.append((dataN, blockID))

In [239]:
def pop_STASH(stsh, items):
    '''
    Build a new stash without the given pairs; the input stash is not modified.

    Input: STASH, and list of pairs to be popped (every equal copy is dropped)
    Output: a new list with the remaining pairs, in their original order
    '''
    return list(filter(lambda pair: pair not in items, stsh))

In [240]:
def print_tree_status(status, root=None):
    '''
    Print a "<status> tree is " header followed by the tree's repr.

    Input: a label, and optionally the tree root to print; when omitted, the
           notebook's global tree `rt` is printed (the previous behaviour)
    '''
    print(status + " tree is ")
    print(rt if root is None else root)

In [241]:
def _leaf_ids(root):
    '''Return the idx of every childless node under root, in preorder.'''
    if root is None:
        return []
    if root.left is None and root.right is None:
        return [root.idx]
    return _leaf_ids(root.left) + _leaf_ids(root.right)


def access(root, opCode, blockId, dataNew=None):
    '''
    Perform one Path ORAM access on the tree rooted at `root`.

    Steps:
      1. remap blockId to a uniformly random leaf of the tree;
      2. read every bucket on the path to the block's old position into the stash;
      3. for opCode "wr", store dataNew for blockId in the stash;
      4. walk the old path from the deepest level up to the root, overwriting
         each bucket with at most Z stash pairs whose own path also passes
         through that bucket; whatever does not fit stays in the stash.

    Input: root of the tree, opCode ("wr" writes, anything else only reads),
           the block ID, and the new data for a write
    Output: the data held for blockId in the stash before any write in this
            access (None if the block was not found)
    Side effects: updates `position`, the global stash `StashInit`, and the
                  buckets on the old path.
    '''
    leafIdx = position[blockId]
    # Node ids come from a preorder numbering, so the leaves are NOT the
    # contiguous heap range 2**(LEVELS-1)-1 .. 2**LEVELS-2 (for the 7-node tree
    # that range contains internal node 4 and misses leaf 2). Use the actual leaves.
    position[blockId] = random.choice(_leaf_ids(root))

    STASH = ReadBucket(root, leafIdx, StashInit)
    dataOld = readBlockFromStsh(STASH, blockId)
    if opCode == "wr":
        update_STASH(blockId, dataNew, STASH)

    for l in reversed(range(LEVELS)):
        target = getNodeAtLevel(root, leafIdx, l)
        if target is None:
            # The old path is shorter than LEVELS (leafIdx is not a leaf).
            continue
        # Keep the stash's own (data, id) pairs so written data is not replaced
        # by stale tree data and the pairs can be removed from the stash exactly.
        candidateBlocks = [elem for elem in STASH
                           if getCandidateBlocksHelper(root, leafIdx, elem[1], l) is not None]
        writeBackBlocks = candidateBlocks[:Z]
        # Remove what is written back, so it is not written again at another level.
        STASH = pop_STASH(STASH, writeBackBlocks)
        # Address the bucket that lies on the path at this level (its own id),
        # not the leaf id, which only exists at the deepest level.
        writeBackNodes(root, target[0], 0, l, writeBackBlocks)

    # Persist the leftover blocks in the shared stash for the next access.
    StashInit[:] = STASH
    return dataOld
	

In [242]:
access(rt, "wr", 16, dataNew="x")

old mapped leaf is 5
new mapped leaf is 3


''

In [243]:
# TODO:
# 1. look for room in the whole path rather than a leaf node
# 2. automate the testing
res = access(rt, 'rd', blockId=16)

old mapped leaf is 3
new mapped leaf is 6
