In [73]:
# Path ORAM 

from collections import defaultdict, Counter, deque
from enum import Enum 
from turtle import update
import random
import string

class Node(object): 
	def __init__(self, ary, idx, left, right): 
		self.val = ary
		self.pt = 0 
		self.idx = idx
		self.left = left 
		self.right = right
		
	def __repr__(self):
		return f"({self.idx},{self.val},{self.left},{self.right})"


In [74]:
queue = deque()

def buildingPBTLevelOrder_helper(root, nodeidx):
    nNode = Node([], nodeidx, None, None) 

    if queue: 
        currNode = queue[0]
    
    if root is None: # if root is given as None
        root = nNode
    
    elif currNode.left == None:
        currNode.left = nNode
    
    elif currNode.right == None:
        currNode.right = nNode
        queue.popleft()

    queue.append(nNode)
    return root


In [75]:
# constructing a simple binary tree
# actually each of the int in the arrary should be a pointer to a block of size B
#    0    l0
#   /\
#  1  2   l1
#  /\ /\
# 3 4 5 6 l2

def buildingPBTLevelOrder(root, lst):
    for elem in lst:
        root = buildingPBTLevelOrder_helper(root, elem)
    return root
rt = buildingPBTLevelOrder(None, range(15))
print(rt)

(0,[],(1,[],(3,[],(7,[],None,None),(8,[],None,None)),(4,[],(9,[],None,None),(10,[],None,None))),(2,[],(5,[],(11,[],None,None),(12,[],None,None)),(6,[],(13,[],None,None),(14,[],None,None))))


In [76]:
def buildNodeLevelDict(root, dct, currlevel):
    if not root:
        return None
    
    dct[root.idx] = currlevel

    nextlevel = currlevel + 1 
    buildNodeLevelDict(root.left, dct, nextlevel)
    buildNodeLevelDict(root.right, dct, nextlevel)

    return dct 

levelDict = buildNodeLevelDict(rt, {}, 0)
# print(levelDict)
    

In [77]:
def getHeight(root):
	'''
	Assumption: the given tree is a binary tree 
	Input: the root node of a tree 
	Output: the height of the tree 
	'''
	if not root: 
		return 0 

	leftHeight = getHeight(root.left)
	rightHeight = getHeight(root.right) 

	return max(leftHeight, rightHeight) + 1 

In [78]:
# Global set up 
N = 28 

# Z numbers of blocks within each bucket
Z = 4 
StashInit = []
LEVELS = getHeight(rt) 

random.seed(11)
position = defaultdict(int, {k:random.randrange(0, (pow(2, LEVELS) - 1)) for k in range(1,N+1)})
Operator = ['rd', 'wr']

sorted_pos = defaultdict(list)

for key,val in position.items():
    sorted_pos[val].append(key)
    

In [79]:
print(position)
print(sorted_pos)

defaultdict(<class 'int'>, {1: 7, 2: 13, 3: 8, 4: 13, 5: 14, 6: 12, 7: 7, 8: 7, 9: 8, 10: 13, 11: 9, 12: 3, 13: 2, 14: 12, 15: 8, 16: 7, 17: 10, 18: 9, 19: 12, 20: 2, 21: 1, 22: 7, 23: 4, 24: 2, 25: 1, 26: 8, 27: 12, 28: 14})
defaultdict(<class 'list'>, {7: [1, 7, 8, 16, 22], 13: [2, 4, 10], 8: [3, 9, 15, 26], 14: [5, 28], 12: [6, 14, 19, 27], 9: [11, 18], 3: [12], 2: [13, 20, 24], 10: [17], 1: [21, 25], 4: [23]})


In [80]:
def concat_STASH(stsh, anotherL):
    for i in anotherL:
        stsh.append(i)
    return None 

In [81]:
def initialize_tree(root, stsh):
    if root is None:
        return None 

    data_ = [("", x) for x in sorted_pos[root.idx]]
    data = data_[:Z]
    concat_STASH(stsh, data_[Z:])
    
    root.val = data 
    # root.pt = 0
    initialize_tree(root.left, stsh)
    initialize_tree(root.right, stsh)
    return None
print("Before init the stash is ")
print(StashInit)
initialize_tree(rt, StashInit)
print(rt)
print("After init the stash is ")
print(StashInit)


Before init the stash is 
[]
(0,[],(1,[('', 21), ('', 25)],(3,[('', 12)],(7,[('', 1), ('', 7), ('', 8), ('', 16)],None,None),(8,[('', 3), ('', 9), ('', 15), ('', 26)],None,None)),(4,[('', 23)],(9,[('', 11), ('', 18)],None,None),(10,[('', 17)],None,None))),(2,[('', 13), ('', 20), ('', 24)],(5,[],(11,[],None,None),(12,[('', 6), ('', 14), ('', 19), ('', 27)],None,None)),(6,[],(13,[('', 2), ('', 4), ('', 10)],None,None),(14,[('', 5), ('', 28)],None,None))))
After init the stash is 
[('', 22)]


In [82]:
def getPath(root, NodeId): 
	'''
	Input: root(root) node, and leaf node(NodeId)
	Output: all nodes along the path from leaf node to root node. 
	'''
	if not root:
		return []
	
	if root.idx == NodeId:
		return [(root.idx, root.val)] 		
		# return [(root.idx, root.val, root.pt)] 

	leftPath = getPath(root.left, NodeId) 
	rightPath = getPath(root.right, NodeId)
 
	if leftPath:
		return leftPath +  [(root.idx, root.val)]
		# return leftPath +  [(root.idx, root.val, root.pt)]
	if rightPath: 
		return rightPath + [(root.idx, root.val)]
		# return rightPath + [(root.idx, root.val, root.pt)]
	return []

In [83]:
getPath(rt, 5)

[(5, []), (2, [('', 13), ('', 20), ('', 24)]), (0, [])]

In [84]:
def clearPath(root, NodeId):
    if not root: 
        return False  
        
    temp = root.val
    root.val = []
    if root.idx == NodeId: 
        return True  
    
    lp = clearPath(root.left, NodeId)
    rp  = clearPath(root.right, NodeId)
    
    if lp or rp: 
        return True 
    
    root.val = temp
    return False 


In [85]:
def getDataInBlock(root, blockId):
  ''' 
  Get the data associated with a block ID in a given tree
  Input: root of a tree and a block ID
  Output: the data associated with the block ID 
  '''
  if root is not None:
    print("root val is {}".format(root.val))
    for elem in iter(root.val):
      if elem[1] == blockId:
        # print(" inside getDataInBlock: elem[1] is {} elem[0] is{}".format(elem[1],elem[0]))
        return elem[0]

    lres = getDataInBlock(root.left, blockId)
    if lres is not None: 
      return lres 

    rres = getDataInBlock(root.right, blockId)
    if rres is not None:
      return rres

getDataInBlock(rt,13)

root val is []
root val is [('', 21), ('', 25)]
root val is [('', 12)]
root val is [('', 1), ('', 7), ('', 8), ('', 16)]
root val is [('', 3), ('', 9), ('', 15), ('', 26)]
root val is [('', 23)]
root val is [('', 11), ('', 18)]
root val is [('', 17)]
root val is [('', 13), ('', 20), ('', 24)]


''

In [86]:
def readBlockFromStsh(stsh, blockID):
    if stsh is None: 
        return None 
    for elem in stsh:
        if elem[1] == blockID:
            return elem[0]

In [87]:
def getNodeAtLevel(root, leafidx, level):
  '''
  Input: the root of a tree, the leaf node index, a specific level
  Output: the node along the path from a given leaf node to root node at the given level 
  '''
  path = getPath(root, leafidx)
  path_len = len(path)
  if level > path_len - 1:
    return None
  return path[path_len-level-1]
getNodeAtLevel(rt, 3, 1)

(1, [('', 21), ('', 25)])

In [88]:
def ReadnPopNodes(root, leaf, stsh):
    '''
    read all of the (block, data) pair along the path from a leaf node to the root node
    Input: the root of a tree, a leaf node, and a stash 
    Output: the updated stash and the updated tree
    Thoughts: consider merging getPath() and clearPath() to be the same function since they share the same recursion function. 
    '''
    for l in range(LEVELS):
        nd = getNodeAtLevel(root, leaf, l)
        # print("Inside ReadNodes function ------------------ nd is {} ".format(nd))
        if nd is not None:
            datum = nd[1] # index 1 of Node nd contains list of data 
            for elem in datum: 
                stsh.append(elem)
  
    clearPath(root, leaf)
    return stsh


In [89]:
def getCandidateBlocksHelper(root, leaf, blockID, level, stsh): 
	'''Output: (node.val, node.idx)'''
	lhs = getNodeAtLevel(root, leaf, level)
	rhs = getNodeAtLevel(root, position[blockID], level)
	if lhs is not None and rhs is not None:
		if lhs == rhs:
			data = readBlockFromStsh(stsh, blockID)
			res = (data, blockID)
			return res
		else:
			print("LHS{} and RHS{} not eq according to the creteria $$".format(lhs[0],rhs[0]))
			return None 
	return None

In [90]:
def testing(leaf, blockid):
    for i in range(3):
        print("Testing at Level {}".format(i))
        print(getCandidateBlocksHelper(rt, leaf, blockid, i))        

In [91]:
def testall(): 
    for leaf,bucket in sorted_pos.items():
        print("---------------{}-------------".format(leaf))
        for bid in bucket:
            print("=========={}===========".format(bid))
            testing(leaf, bid)

In [92]:
def getCandidateBlocks(root, leaf, level, stsh): 
    acc = []
    for elem in stsh:
        blocks = getCandidateBlocksHelper(root, leaf, elem[1], level, stsh)
        acc.append(blocks)
    res = [x for x in acc if x is not None]
    return res

In [93]:
def writeBackNodes(root, leafIdx, tgtlevel, data):
    if not root:
        return False 
    
    temp = root.val 
    if levelDict[root.idx] == tgtlevel: 
        root.val = data
    
    if root.idx == leafIdx:
        return True
    
    lb = writeBackNodes(root.left, leafIdx, tgtlevel, data)
    rb = writeBackNodes(root.right, leafIdx, tgtlevel, data)

    if lb or rb: 
        return True

    root.val = temp 
    return False


current problem is that the leaf node always get updated. Node leaf is always at the last level, so leaf node should not be assigned to new data, or equivalently, the leaf node should be assgined back to temp, which is the originial data in the node. But under what conditions? 

In [95]:
def update_STASH(blockID, dataN, stsh):
    blockIndex = -1
    for (idx, pair) in enumerate(stsh):
        temp = list(pair)
        if temp[1] == blockID:
            blockIndex = idx 
    if blockIndex != -1: 
        stsh[blockIndex] = (dataN, blockID)
    else:
        print("not finding the key!!!!!!")

In [96]:
def pop_STASH(stsh, items):
    '''
    Input: STASH, and list of pairs to be popped 
    Output: updated STASH 
    '''
    popIds = [] 
    for item in items:
        popIds.append(item[1])
    updatedSTASH = [x for x in stsh if x[1] not in popIds]
    return updatedSTASH

In [97]:
def print_tree_status(status):
    print(status +" tree is ")
    print(rt)

In [98]:
def getBlockIdsFromLst(lst):
    return [x[1] for x in lst]

In [99]:
def access(root, opCode, blockId, dataNew=None): 
	leafIdx = position[blockId]
	position[blockId] = random.randrange(0, (pow(2, LEVELS) - 1))
	print("Operating on path identified with leaf {}".format(leafIdx))
	global StashInit
	StashInit = ReadnPopNodes(root, leafIdx, StashInit)
	print("after first reading into the stash is ")
	print(StashInit)
	print("after first reading, the tree is {}".format(root))
	# assert blockId in getBlockIdsFromLst(StashInit),"invariant not upheld!"
	dataOld = readBlockFromStsh(StashInit, blockId)
	if opCode == "wr": 
		# updateNode = getNodeAtLevel(root, leafIdx, LEVELS - 1)
		# print("Update Node {}".format(updateNode))
		print(" block ID is {}, ids in Stash is {}".format(blockId, getBlockIdsFromLst(StashInit)))
		# assert (blockId in getBlockIdsFromLst(StashInit)), "BlockID not in StashInit"
		update_STASH(blockId, dataNew, StashInit)
	for l in reversed(range(LEVELS)):		
		if len(StashInit) != 0:
			candidateBlocks = getCandidateBlocks(root, leafIdx, l, StashInit)
			if (len(candidateBlocks) >= Z):
				writeBackSize = Z
			else:
				writeBackSize = len(candidateBlocks)
		else:
			candidateBlocks = []
			writeBackSize = 0
		writeBackBlocks = candidateBlocks[:writeBackSize] 
		print("before WB StashInit is {}".format(StashInit))
		print("before WB tree is {}".format(root))
		print("write back blocks are")
		print(writeBackBlocks)
		updatedSTASH = pop_STASH(StashInit, writeBackBlocks)
		print("updated STASH is {}".format(updatedSTASH))
		StashInit = updatedSTASH
		# writeBackNode = getNodeAtLevel(root, leafIdx, l)
		# if writeBackNode is None: 
		# 	print(">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>YELLING ")
		# print("writeBackNode is {} at level {} ".format(writeBackNode, l))
		print("after WB tree is {}".format(root))
		writeBackNodes(root, leafIdx, 0, l, writeBackBlocks)
	return dataOld
	

Development notes:
*1 inspect the STASH 
*2 inspect why there are "None"s in the STASH after initialization
*2 automate the testing 
    - record the sequence (accessIdx, treeRoot, Op, blockId, dataN)
    - parameterized by N 

In [100]:
# Another option is to make Operator a list
def validRAM(testN, root):
    # accesses = []
    # index = random.randrange(0,1)
    for i in range(testN):
        randBlockId = random.randrange(1,N+1)
        dataN = random.choice(string.ascii_letters)
        print("*****************writing access {}******************".format(i))
        dataO = access(root, 'wr', randBlockId, dataN)
        print("*****************reading access {}******************".format(i))
        dataR = access(root, 'rd', randBlockId, )
        print("blockID is {}, dataN is \"{}\", dataR is \"{}\" ".format(randBlockId, dataN,dataR))
        if dataN == dataR:
            print("access # {} True".format(i))
        else:
            print("access # {} False".format(i))
validRAM(30,rt)

*****************writing access 0******************
Operating on path identified with leaf 4
after first reading into the stash is 
[('', 22), '0', '0', '0', '0', '0', '0', '0', '0', '0', ('', 21), ('', 25), ('', 23)]
after first reading, the tree is (0,[],(1,[],(3,[('', 12)],(7,[('', 1), ('', 7), ('', 8), ('', 16)],None,None),(8,[('', 3), ('', 9), ('', 15), ('', 26)],None,None)),(4,[],(9,[('', 11), ('', 18)],None,None),(10,[('', 17)],None,None))),(2,11111111111,(5,55555555555,(11,[],None,None),(12,[('', 6), ('', 14), ('', 19), ('', 27)],None,None)),(6,666666666,(13,[('', 2), ('', 4), ('', 10)],None,None),(14,[('', 5), ('', 28)],None,None))))


IndexError: string index out of range

assertions 
weakened version of the main invariant: assert it as either the block lives in the path or the stash. 