In [39]:
import numpy as np
import itertools
import h5py

In [104]:
def Tree():
    def __init__(self):
        self.root = Node('0')
    
    def find_node(self,name):
        current_node = self.root
        for c in name:
            current_node = current_node.child[int(c)]
        return current_node
    def walk(self):
        return self.root.walk()
    def save(self,fname):
        with h5py.File(fname,'w') as f:
            grp = f.create_group('0')
            self.root.save(grp)
            
class Node():
    def __init__(self,indx,parent=None):
        self.indx = indx
        self.level = len(indx)-1
        self.global_index = (0,0,0)
        self.parent = parent
        if parent is not None:
            self.global_index = self.calc_global_index(*parent.global_index)
        self.leaf = True
        self.child = [None]*4
        
        self.data = np.zeros((4,))
    def calc_global_index(self,k,i,j):
        """
        Calculate the global index for the cell given (k,i,j).
        (k,i,j) can either be from the level above or the level below
        """
        if k==self.level:
            # Same level
            return (k,i,j)
        elif k < self.level:
            # Going from coarser level
            inew = 2*i + i//2
            jnew = 2*j + j%2
            return (k+1,inew,jnew)
        else:
            # Going from finer level
            inew = i // 2
            jnew = j // 2
            return (k-1,inew,jnew)
    def name_from_index(self,k,i,j):
        """
        Calculate the name of the cell corresponding to 
        global index (k,i,j)
        """
        name = []
        icurr = i
        jcurr = j
        for curr_level in range(k+1)[::-1]:
            name.append(str(2*(icurr%2)+jcurr%2))
            icurr = icurr // 2
            jcurr = jcurr // 2
        return ''.join(name[::-1])
    def index_from_name(self,name):
        """
        Calculate the index of the cell corresponding to 
        the given name
        """
        icurr = 0
        jcurr = 0
        for k,c in enumerate(name):
            icurr = 2*icurr + int(c)//2
            jcurr = 2*jcurr + int(c)%2
        
        return (len(name)-1,icurr,jcurr)
            
            
    def split(self):
        self.leaf = False
        for i in range(4):
            name = self.indx + str(i)
            self.child[i] = Node(name,self)
        return self.child
    def find_neighbors(self):
        """Start with neighbors on same level"""
        myindx = self.global_index
        n_name = []
        n_name_upper = []
        for i in range([-1,0,1]):
            for j in range([-1,0,1]):
                if ~((i==0)&(j==0)):
                    nindx = (myindx[0], myindx[1]+i,myindx[2]+j)
                    nindx_upper = (nindx[0]-1,nindx[1]//2,nindx[2]//2)
                    n_name.append( self.name_from_index(nindx))
                    n_name_upper.append(self.name_from_index( nindx_upper))
        return n_name, n_name_upper
    def check_refinement(self,neighbors,neighbors_upper,cuttof=):
        res = False
        for n,nu in zip(neighbors,neighbors_upper):
            node = self.find(nu)
            if node.leaf:
                res |= (node.compare(self) or node.r_flag)
            else:
                res |= (node.child[int(n[-1])].compare(self) or node.r_flag)
        return res
#         for i in range(3):
#             neighbor = self.parent.child[int(self.name[-1]^i)]
        
#         for i in range(3):
#             up_ind = int(self.name[-2]^i)
#             same_ind = int(self.name[-1]^i)
#             neighbor_name = str(up_ind) + str(same_ind)
#             n = find_name(neighbor_name)
#             if n.leaf:
#                 neighbor = n
#             else:
#                 neighbor = n.child[same_ind]
                
#         LR_neighbor = name[:-1] + str(int(self.name[-1])^1)
#         UR_neighbor = name[:-1] + str(int(self.name[-1])^2)
#         C_neighbor = str(int(self.name[-1]^3))

    def find(self,name):
    """
       Find the next step towards the desired 
       node with name name.
    """ 
        len_myself = len(self.indx)
        len_name = len(name)
        if self.indx == name:
            # Found it!
            return self
        if len_myself < len_name:
            if self.indx == name[:len_myself]:
                # It's below us in the tree
                child = name[:len_myself+1][-1]
                return self.down(int(child))
        # It's not below us, so move up
        return self.up()
        
    def up(self):
        return self.parent
    def down(self,i=0):
        return self.child[i] 
    def walk(self):
        if self.leaf:
            return 
        for c in self.child:
            return c.walk()
    def save(self,f):
        f.create_dataset('Data',data=self.data)
        if not self.leaf:
            for i,c in enumerate(self.child):
                grp = f.create_group(str(i))
                c.save(grp)
        
    def __repr__(self):
        return self.indx
    def __str__(self):
        return self.indx

In [100]:
root = [Node('0')]
root.append(root[0].split())
for i in range(7):
    root.append(list(itertools.chain.from_iterable([x.split() for x in root[-1]])))
    


In [154]:
for j in range(1,4):
    for i in range(4):
        print('{:d}: {:d} -> {:d}'.format(j,i,i^j))

1: 0 -> 1
1: 1 -> 0
1: 2 -> 3
1: 3 -> 2
2: 0 -> 2
2: 1 -> 3
2: 2 -> 0
2: 3 -> 1
3: 0 -> 3
3: 1 -> 2
3: 2 -> 1
3: 3 -> 0


In [164]:
for i in range(4):
    print('{:d} -> {:d}'.format(i,i//2))

0 -> 0
1 -> 0
2 -> 1
3 -> 1


In [170]:
for indx in [ (2**4,2**4),(2**4,2**4+1),(2**4+1,2**4),(2**4+1,2**4+1)]:
    print('({:d},{:d},{:d}) -> ({:d},{:d},{:d})'.format(4,indx[0],indx[1],3,indx[0]//2 + indx[0]%2,indx[1]//2 + indx[1]%2))

(4,16,16) -> (3,8,8)
(4,16,17) -> (3,8,9)
(4,17,16) -> (3,9,8)
(4,17,17) -> (3,9,9)


In [181]:
for indx in [ (2**4,2**4),(2**4,2**4+1),(2**4+1,2**4),(2**4+1,2**4+1)]:
    print(indx,[x%2 for x in indx],2*(indx[0]%2)+indx[1]%2)

(16, 16) [0, 0] 0
(16, 17) [0, 1] 1
(17, 16) [1, 0] 2
(17, 17) [1, 1] 3


In [186]:
k,i,j = (5,14,7)
name = []
icurr = i
jcurr = j
for curr_level in range(k+1)[::-1]:
    print((curr_level,icurr,jcurr))
    name.append(str(2*(icurr%2)+jcurr%2))
    icurr = icurr // 2
    jcurr = jcurr // 2
print(''.join(name[::-1]))

(5, 14, 7)
(4, 7, 3)
(3, 3, 1)
(2, 1, 0)
(1, 0, 0)
(0, 0, 0)
002331


In [195]:
icurr = 0
jcurr = 0
print(len('002331'))
for k,c in enumerate('002331'):
    icurr = 2*icurr + int(c)//2
    jcurr = 2*jcurr + int(c)%2
    print((k,icurr,jcurr))

6
(0, 0, 0)
(1, 0, 0)
(2, 1, 0)
(3, 3, 1)
(4, 7, 3)
(5, 14, 7)


In [40]:
list(itertools.chain.from_iterable(l2))

[000,
 001,
 002,
 003,
 010,
 011,
 012,
 013,
 020,
 021,
 022,
 023,
 030,
 031,
 032,
 033]

In [196]:
n1 = '002331'
n0 = '00201'

In [219]:
n1.split('0023')

['', '31']

In [220]:
[(n0[:i],(n0[:i] in n1)) for i in range(1,len(n0)+1)]

[('0', True), ('00', True), ('002', True), ('0020', False), ('00201', False)]

In [217]:
n0[:-1],n0[-1]

('0020', '1')