In [1183]:
# Path ORAM 

from collections import defaultdict, Counter
from enum import Enum 
from turtle import update
import random
import string

class Node(object): 
	def __init__(self, ary, idx, left, right): 
		self.val = ary
		self.idx = idx
		self.left = left 
		self.right = right
		
	def __repr__(self):
		return f"({self.idx},{self.val},{self.left},{self.right})"


In [1184]:
def getNewNode(val, idx,):
    newNode = Node(val, idx,None, None)
    return newNode

In [1185]:
def buildPBT_helper(start, end, pre): 
    if start > end:
        return None 
    root = getNewNode([], pre[start])

    if(start == end):
        return root 
    
    leftStart = start + 1 
    rightStart = leftStart + int((end - leftStart + 1) /2)
    leftEnd = rightStart - 1
    rightEnd = end 

    root.left = buildPBT_helper(leftStart, leftEnd, pre)
    root.right = buildPBT_helper(rightStart, rightEnd, pre)
    return root 


In [1186]:
# constructing a simple binary tree
# actually each of the int in the arrary should be a pointer to a block of size B
#    0    l0
#   /\
#  1  4   l1
#  /\ /\
# 2 3 5 6 l2

def buildPBT(pre, size):
    return buildPBT_helper(0, size - 1, pre)

rt = buildPBT(range(7),7)
rt

(0,[],(1,[],(2,[],None,None),(3,[],None,None)),(4,[],(5,[],None,None),(6,[],None,None)))

In [1187]:
def getHeight(root):
	'''
	Assumption: the given tree is a binary tree 
	Input: the root node of a tree 
	Output: the height of the tree 
	'''
	if not root: 
		return 0 

	leftHeight = getHeight(root.left)
	rightHeight = getHeight(root.right) 

	return max(leftHeight, rightHeight) + 1 

In [1188]:
# Global set up 
N = 28 

# Z numbers of blocks within each bucket
Z = 4 
StashInit = []
LEVELS = getHeight(rt) 

random.seed(11)
position = defaultdict(int, {k:random.randrange(0, (pow(2, LEVELS) - 1)) for k in range(1,N+1)})
Operator = ['rd', 'wr']

sorted_pos = defaultdict(list)

for key,val in position.items():
    sorted_pos[val].append(key)
    

In [1189]:
print(position)
print(sorted_pos)

defaultdict(<class 'int'>, {1: 3, 2: 6, 3: 4, 4: 6, 5: 6, 6: 3, 7: 3, 8: 4, 9: 6, 10: 4, 11: 1, 12: 1, 13: 6, 14: 4, 15: 3, 16: 5, 17: 4, 18: 6, 19: 1, 20: 0, 21: 3, 22: 2, 23: 1, 24: 0, 25: 4, 26: 6, 27: 5, 28: 5})
defaultdict(<class 'list'>, {3: [1, 6, 7, 15, 21], 6: [2, 4, 5, 9, 13, 18, 26], 4: [3, 8, 10, 14, 17, 25], 1: [11, 12, 19, 23], 5: [16, 27, 28], 0: [20, 24], 2: [22]})


In [1190]:
def concat_STASH(stsh, anotherL):
    for i in anotherL:
        stsh.append(i)
    return None 

In [1191]:
def initialize_tree(root, stsh):
    if root is None:
        return None 

    data_ = [("", x) for x in sorted_pos[root.idx]]
    data = data_[:Z]
    concat_STASH(stsh, data_[Z:])
    
    root.val = data 
    initialize_tree(root.left, stsh)
    initialize_tree(root.right, stsh)
    return None
print("Before init the stash is ")
print(StashInit)
initialize_tree(rt, StashInit)
print(rt)
print("After init the stash is ")
print(StashInit)


Before init the stash is 
[]
(0,[('', 20), ('', 24)],(1,[('', 11), ('', 12), ('', 19), ('', 23)],(2,[('', 22)],None,None),(3,[('', 1), ('', 6), ('', 7), ('', 15)],None,None)),(4,[('', 3), ('', 8), ('', 10), ('', 14)],(5,[('', 16), ('', 27), ('', 28)],None,None),(6,[('', 2), ('', 4), ('', 5), ('', 9)],None,None)))
After init the stash is 
[('', 21), ('', 17), ('', 25), ('', 13), ('', 18), ('', 26)]


In [1192]:
def getPath(root, NodeId): 
	'''
	Input: root(root) node, and leaf node(NodeId)
	Output: all nodes along the path from leaf node to root node. 
	'''
	if not root:
		return []
	
	if root.idx == NodeId:
		return [(root.idx, root.val)] 

	leftPath = getPath(root.left, NodeId) 
	rightPath = getPath(root.right, NodeId)
 
	if leftPath:
		return leftPath +  [(root.idx, root.val)]
	if rightPath: 
		return rightPath + [(root.idx, root.val)]
	return []

In [1193]:
print(getPath(rt, 6))

[(6, [('', 2), ('', 4), ('', 5), ('', 9)]), (4, [('', 3), ('', 8), ('', 10), ('', 14)]), (0, [('', 20), ('', 24)])]


In [1194]:
def getDataInBlock(root, blockId):
  ''' 
  Get the data associated with a block ID in a given tree
  Input: root of a tree and a block ID
  Output: the data associated with the block ID 
  '''
  if root is not None:
    for elem in iter(root.val):
      if elem[1] == blockId:
        return elem[0]

    lres = getDataInBlock(root.left, blockId)
    if lres is not None: 
      return lres 

    rres = getDataInBlock(root.right, blockId)
    if rres is not None:
      return rres

getDataInBlock(rt,16)

''

In [1195]:
def getNodeAtLevel(root, leafidx, level):
  '''
  Input: the root of a tree, the leaf node index, a specific level
  Output: the node along the path from a given leaf node to root node at the given level 
  '''
  path = getPath(root, leafidx)
  path_len = len(path)
  if level > path_len - 1:
    return None
  return path[path_len-level-1]


In [1196]:
def ReadNodes(root, leaf, stsh):
    '''
    read all of the (block, data) pair along the path from a leaf node to the root node 
    Input: the root of a tree, a leaf node, and a 

    '''
    for l in range(LEVELS):
        bucket = getNodeAtLevel(root, leaf, l)
        if bucket is not None:
            bucket = bucket[1]
            for elem in bucket: 
                stsh.append(elem)
    return stsh


In [1197]:
def getCandidateBlocksHelper(root, leaf, blockID, level): 
	'''Output: (node.idx, node.val)'''
	lhs = getNodeAtLevel(root, leaf, level)
	rhs = getNodeAtLevel(root, position[blockID], level)
	if lhs is not None and rhs is not None:
		if lhs == rhs:
			return (getDataInBlock(root, blockID), blockID)
		else:
			return None 
	return None

In [1198]:
def testing(leaf, blockid):
    for i in range(3):
        print("Testing at Level {}".format(i))
        print(getCandidateBlocksHelper(rt, leaf, blockid, i))        

In [1199]:
def testall(): 
    for leaf,bucket in sorted_pos.items():
        print("---------------{}-------------".format(leaf))
        for bid in bucket:
            print("=========={}===========".format(bid))
            testing(leaf, bid)

In [1200]:
#{3: [1, 6, 7, 15, 21], 6: [2, 4, 5, 9, 13, 18, 26], 4: [3, 8, 10, 14, 17, 25], 1: [11, 12, 19, 23], 5: [16, 27, 28], 0: [20, 24], 2: [22]}
# print(getCandidateBlocksHelper(rt, 3, 2, 0))
print(getCandidateBlocksHelper(rt, 3, 2, 1))
# print(getCandidateBlocksHelper(rt, 3, 4, 2))

None


In [1201]:
def getCandidateBlocks(root, leaf, level, stsh): 
    acc = []
    for elem in stsh:
        blocks = getCandidateBlocksHelper(root, leaf, elem[1], level)
        acc.append(blocks)
    res = [x for x in acc if x is not None]
    return res

In [1202]:
def readBlockFromStsh(stsh, blockID):
    if stsh is None: 
        return None 
    for elem in stsh:
        if elem[1] == blockID:
            return elem[0]

In [1203]:
def writeBackNodes(root, index, currlevel, tgtlevel, data):
	if root is None:
		return None

	if root.idx == index and currlevel == tgtlevel:
		retVal = root.val 
		root.val = data 
		return retVal

	nextlevel = currlevel + 1

	leftTree = writeBackNodes(root.left, index, nextlevel, tgtlevel, data)
	if leftTree: 
		return leftTree
  
	rightTree = writeBackNodes(root.right, index, nextlevel, tgtlevel, data)
	if rightTree:
		return rightTree

	return None  

In [1204]:
def update_STASH(blockID, dataN, stsh):
    blockIndex = -1
    for (idx, pair) in enumerate(stsh):
        temp = list(pair)
        if temp[1] == blockID:
            blockIndex = idx 
    stsh[blockIndex] = (dataN, blockID)

In [1205]:
def pop_STASH(stsh, items):
    '''
    Input: STASH, and list of pairs to be popped 
    Output: updated STASH 
    '''
    popIds = [] 
    for item in items:
        popIds.append(item[1])
    updatedSTASH = [x for x in stsh if x[1] not in popIds]
    return updatedSTASH

In [1206]:
def print_tree_status(status):
    print(status +" tree is ")
    print(rt)

In [1207]:
def getBlockIdsFromLst(lst):
    return [x[1] for x in lst]

In [1208]:
def access(root, opCode, blockId, dataNew=None): 
	leafIdx = position[blockId]
	position[blockId] = random.randrange(0, (pow(2, LEVELS) - 1))
	global StashInit
	STASH = ReadNodes(root, leafIdx, StashInit)
	print("\n after reading path in STASH is {}\n".format(STASH))
	# assert blockId in getBlockIdsFromLst(StashInit),"invariant not upheld!"
	dataOld = readBlockFromStsh(STASH, blockId)
	if opCode == "wr": 
		updateNode = getNodeAtLevel(root, leafIdx, LEVELS - 1)
		update_STASH(blockId, dataNew, STASH)
	for l in reversed(range(LEVELS)):		
		if len(STASH) != 0:
			candidateBlocks = getCandidateBlocks(root, leafIdx, l, STASH)
			if (len(candidateBlocks) >= Z):
				writeBackSize = Z
			else:
				writeBackSize = len(candidateBlocks)
		else:
			candidateBlocks = []
			writeBackSize = 0
		writeBackBlocks = candidateBlocks[:writeBackSize] 
		print("write back blocks are")
		print(writeBackBlocks)
		updatedSTASH = pop_STASH(STASH, writeBackBlocks)
		print("updated STASH is {}".format(updatedSTASH))
		StashInit = updatedSTASH
		writeBackNodes(root, leafIdx, 0, l, writeBackBlocks)
		ndWr= getNodeAtLevel(root, position[blockId], l)
	return dataOld
	

Development notes:
*1 inspect the STASH 
*2 inspect why there are "None"s in the STASH after initialization
*2 automate the testing 
    - record the sequence (accessIdx, treeRoot, Op, blockId, dataN)
    - parameterized by N 

In [1209]:
# Another option is to make Operator a list
def validRAM(testN, root):
    # accesses = []
    # index = random.randrange(0,1)
    for i in range(testN):
        randBlockId = random.randrange(1,N+1)
        dataN = random.choice(string.ascii_letters)
        print("*****************writing access {}******************".format(i))
        dataO = access(root, 'wr', randBlockId, dataN)
        print("*****************reading access {}******************".format(i))
        dataR = access(root, 'rd', randBlockId, )
        print("blockID is {}, dataN is \"{}\", dataR is \"{}\" ".format(randBlockId, dataN,dataR))
        if dataN == dataR:
            print("access # {} True".format(i))
        else:
            print("access # {} False".format(i))
validRAM(30,rt)

*****************reading access 0******************

 after reading path in STASH is [('', 21), ('', 17), ('', 25), ('', 13), ('', 18), ('', 26), ('', 20), ('', 24), ('', 3), ('', 8), ('', 10), ('', 14), ('', 2), ('', 4), ('', 5), ('', 9)]

write back blocks are
[(None, 13), (None, 18), (None, 26), ('', 4)]
updated STASH is [('', 21), ('', 17), ('', 25), ('', 20), ('', 24), ('', 3), ('', 8), ('', 10), ('', 14), ('M', 2), ('', 5), ('', 9)]
write back blocks are
[(None, 17), (None, 25), (None, 13), (None, 18)]
updated STASH is [('', 21), ('', 26), ('', 20), ('', 24), ('', 3), ('', 8), ('', 10), ('', 14), ('M', 2), ('', 4), ('', 5), ('', 9)]
write back blocks are
[(None, 21), (None, 17), (None, 25), (None, 13)]
updated STASH is [('', 18), ('', 26), ('', 20), ('', 24), ('', 3), ('', 8), ('', 10), ('', 14), ('M', 2), ('', 4), ('', 5), ('', 9)]
*****************writing access 0******************

 after reading path in STASH is [('', 18), ('', 26), ('', 20), ('', 24), ('', 3), ('', 8), ('', 

assertions 
weakened version of the main invariant: assert it as either the block lives in the path or the stash. 