In [1]:
def atoms_from_pdb(filepath, record_types=("ATOM",)):
    """Read atomic coordinates from a PDB file.

    Coordinates are taken from the fixed-width columns defined by the PDB
    format (x: columns 31-38, y: 39-46, z: 47-54, 1-indexed), not by splitting
    on whitespace. Whitespace splitting breaks when an optional field such as
    the chain ID is blank (every later field shifts left) or when neighbouring
    numbers touch (e.g. "-100.123-200.456").

    Parameters
    ----------
    filepath : str or path-like
        Path to the PDB file.
    record_types : tuple of str, optional
        Record-name prefixes whose coordinates are collected. The default
        ("ATOM",) reads only ATOM records; pass ("ATOM", "HETATM") to also
        include hetero atoms.

    Returns
    -------
    list of [float, float, float]
        One [x, y, z] list per matching record, in file order.

    Raises
    ------
    ValueError
        If a matching record's coordinate columns cannot be parsed as floats.
    """
    # Only plain loops and fixed-offset slicing are used, so this ports
    # directly to C++ (std::string::substr + std::stod) with no tokenizer.
    atom_positions = []

    with open(filepath, "r") as pdb_file:
        for line_number, line in enumerate(pdb_file, start=1):
            # Prefix match on the record name ("ATOM" also matches "ATOM  ").
            if not line.startswith(record_types):
                continue

            try:
                position_vec = [
                    float(line[30:38]),  # x, columns 31-38
                    float(line[38:46]),  # y, columns 39-46
                    float(line[46:54]),  # z, columns 47-54
                ]
            except ValueError:
                raise ValueError(
                    "{}:{}: could not parse x/y/z (columns 31-54) from {!r}".format(
                        filepath, line_number, line.rstrip("\n")
                    )
                ) from None

            atom_positions.append(position_vec)

    return atom_positions

PDB_PATH = "2erk.pdb"  # structure whose atom centres are loaded below

atom_positions = atoms_from_pdb(PDB_PATH)
# Fail fast: with no ATOM records the octree built later would silently
# degenerate to a single empty root cell.
assert atom_positions, "no ATOM records were read from {}".format(PDB_PATH)

In [2]:
# Written with plain Python loops instead of NumPy because this code is meant to be ported to C++.

def find_minima(atom_centers):
    """Return the per-axis minimum [min_x, min_y, min_z] of a set of 3D points.

    atom_centers is a non-empty sequence of points; only the first three
    coordinates of each point are read. An empty sequence raises IndexError.
    The result is a freshly built list, so the input points are not modified.
    """
    first = atom_centers[0]
    lowest = [first[0], first[1], first[2]]

    for point in atom_centers:
        for axis in (0, 1, 2):
            candidate = point[axis]
            if candidate < lowest[axis]:
                lowest[axis] = candidate

    return lowest

def find_maxima(atom_centers):
    """Return the per-axis maximum [max_x, max_y, max_z] of a set of 3D points.

    atom_centers is a non-empty sequence of points; only the first three
    coordinates of each point are read. An empty sequence raises IndexError.
    The result is a freshly built list, so the input points are not modified.
    """
    seed = atom_centers[0]
    highest = [seed[0], seed[1], seed[2]]

    for point in atom_centers:
        for axis in range(3):
            if highest[axis] < point[axis]:
                highest[axis] = point[axis]

    return highest

In [3]:
class Node:
    """A cell of an octree built over 3D points (atom centres).

    The root cell is the axis-aligned bounding box of all its points. Each
    child halves its parent's box along every axis, so the eight children of
    a node tile the parent exactly.

    Attributes (all set in __init__):
    # points     - points (indexable as p[0], p[1], p[2]) inside this cell
    # parent     - enclosing Node, or None for the root
    # octant     - 0-7 position within the parent, or None for the root; bits
    #              2/1/0 are set when the cell lies on the lower (<= centre)
    #              side of the parent's centre along x/y/z respectively
    # children   - the 8 sub-cells, or [] for a leaf
    # leaves     - filled by collect_leaf_nodes() on the root only
    # index      - a leaf's position in root.leaves, set by collect_leaf_nodes()
    # dimensions - edge lengths [dx, dy, dz]; None for a root with no points
    # center     - centre [cx, cy, cz]; None for a root with no points
    # depth      - 0 for the root, parent.depth + 1 otherwise
    """

    def __init__(self, points, parent=None, octant=None):
        self.points = points
        self.parent = parent
        self.octant = octant
        self.children = []
        self.leaves = []
        self.index = None
        self.dimensions = None
        self.center = None

        if parent is None:
            # Root: the cell is the bounding box of the points (left with
            # dimensions/center = None if there are no points).
            self.depth = 0
            if self.points:
                minima_position = find_minima(self.points)
                maxima_position = find_maxima(self.points)
                self.dimensions = [maxima_position[i] - minima_position[i] for i in range(3)]
                self.center = [minima_position[i] + self.dimensions[i] / 2.0 for i in range(3)]
        else:
            # Child: half the parent's extent, centred a quarter of the
            # parent's extent away from the parent's centre on each axis.
            self.depth = parent.depth + 1
            self.dimensions = [parent.dimensions[i] / 2.0 for i in range(3)]

            octant_bits = format(octant, "03b")  # e.g. 5 -> "101", read as (x, y, z)
            offset = [self.dimensions[i] / 2.0 for i in range(3)]
            # bit 1 -> lower side (minus offset), bit 0 -> upper side (plus offset)
            self.center = [
                parent.center[i] + (1 - 2 * int(octant_bits[i])) * offset[i]
                for i in range(3)
            ]

    @staticmethod
    def _all_coincident(points):
        """Return True when `points` holds at most one distinct (x, y, z) position."""
        if len(points) == 0:
            return True
        first = points[0]
        return all(
            p[0] == first[0] and p[1] == first[1] and p[2] == first[2]
            for p in points
        )

    def create_children(self, min_atoms):
        """Split this cell into 8 octant children unless it should stay a leaf.

        The cell stays a leaf when it holds fewer than `min_atoms` points, or
        when all its points share one position. Such a stack can never be
        separated by further splitting and would otherwise recurse forever.
        So a cell is subdivided when it holds at least `min_atoms` points, and
        every leaf holds fewer than `min_atoms` points unless they coincide.
        The children are rebuilt from scratch on every call.
        """
        # Reset first so a repeated call (e.g. span() run twice) does not
        # append a second set of 8 children.
        self.children = []

        if len(self.points) < min_atoms or self._all_coincident(self.points):
            return

        # Bucket each point by the side of the centre it lies on along each
        # axis. A coordinate equal to the centre goes to the lower side,
        # matching the bit -> lower-half convention used in __init__.
        divided_points = [[] for _ in range(8)]
        cx, cy, cz = self.center[0], self.center[1], self.center[2]
        for p in self.points:
            ind = 4 * int(p[0] <= cx) + 2 * int(p[1] <= cy) + int(p[2] <= cz)
            divided_points[ind].append(p)

        self.children = [Node(divided_points[i], parent=self, octant=i) for i in range(8)]

    def span(self, min_atoms):
        """Build the octree below this node by recursive subdivision (see create_children)."""
        self.create_children(min_atoms)
        for child in self.children:
            child.span(min_atoms)

    def count_subnodes(self):
        """Return the number of nodes in this subtree, including this node."""
        return 1 + sum(child.count_subnodes() for child in self.children)

    def count_leaf_nodes(self):
        """Return the number of leaves (childless nodes) in this subtree."""
        if not self.children:
            return 1
        return sum(child.count_leaf_nodes() for child in self.children)

    def collect_leaf_nodes(self, current_count=0):
        """Gather this subtree's leaves in depth-first order and number them.

        Each leaf's `index` is set to `current_count` plus its position in the
        returned list. Calling this on the root with the default
        current_count=0 therefore gives every leaf a unique index equal to its
        position in `root.leaves`. On the root (depth 0) the list is also
        stored in `self.leaves`, including when the root is itself a leaf.

        Returns
        -------
        (int, list of Node)
            Number of leaves in this subtree, and the leaves themselves.
        """
        if not self.children:
            self.index = current_count
            leaves = [self]
        else:
            leaves = []
            for child in self.children:
                # Offset the child's numbering by the leaves already collected.
                _, child_leaves = child.collect_leaf_nodes(current_count + len(leaves))
                leaves.extend(child_leaves)

        if self.depth == 0:
            self.leaves = leaves

        return len(leaves), leaves

In [4]:
MIN_ATOMS = 8  # a cell holding at least this many atoms is subdivided

tree = Node(atom_positions)

# create an (8,1)-admissible octree
tree.span(MIN_ATOMS)

In [5]:
print("total nodes in the tree:", tree.count_subnodes())

tree.collect_leaf_nodes()  # fills tree.leaves and sets each leaf's .index

# Sanity check: the collected leaf list must agree with an independent count.
assert len(tree.leaves) == tree.count_leaf_nodes()

print("total leaves in the tree:", len(tree.leaves))

total nodes in the tree: 1665
total leaves in the tree: 1457


In [6]:
# How many leaf cells contain no atoms?
n_empty_leaves = sum(1 for leaf in tree.leaves if not leaf.points)

n_empty_leaves

416