In [57]:
from copy import deepcopy

In [1]:
def atoms_from_pdb(filepath):
    """Read the XYZ coordinates of every ATOM record in a PDB file.

    PDB is a fixed-width format: the X, Y and Z coordinates of an ATOM
    record occupy columns 31-38, 39-46 and 47-54.  Reading those columns
    directly stays correct when neighbouring fields touch (e.g.
    ``-100.123-200.456``), when the chain identifier is blank, or when the
    chain identifier runs into a four-digit residue number -- all cases in
    which splitting on whitespace shifts or merges the tokens.

    If the fixed columns of a line do not hold three numbers, the line is
    treated as whitespace-delimited and tokens 6-8 are used instead.

    Parameters
    ----------
    filepath : str or path-like
        Location of the PDB file.

    Returns
    -------
    list of [float, float, float]
        One ``[x, y, z]`` entry per ATOM record, in file order.

    Raises
    ------
    ValueError
        If an ATOM record contains no parsable coordinate triple.
    """
    atom_positions = []

    with open(filepath, "r") as file:
        for line_number, line in enumerate(file, start=1):
            if line[:4] != "ATOM":  # only read ATOM data
                continue

            try:
                # 0-based slices for columns 31-38, 39-46 and 47-54
                position_vec = [float(line[start:start + 8])
                                for start in (30, 38, 46)]
            except ValueError:
                # not laid out in fixed columns: fall back to tokens 6-8
                # of the whitespace-separated record
                entries = line.split()
                position_vec = [float(token) for token in entries[6:9]]

                if len(position_vec) != 3:
                    raise ValueError(
                        f"ATOM record on line {line_number} of {filepath!r} "
                        "has no parsable X, Y, Z coordinates"
                    )

            atom_positions.append(position_vec)

    return atom_positions

# Parse the ATOM coordinates out of "2erk.pdb" (path is relative to the
# kernel's working directory); later cells build the octree from this list.
atom_positions = atoms_from_pdb("2erk.pdb")

In [2]:
# not done with np because I plan to move this to C++

def find_minima(atom_centers):
    """Return the smallest x, y and z value found among ``atom_centers``.

    ``atom_centers`` is an indexable, non-empty collection of 3-component
    positions.  The result is a new 3-element list; the input is not
    modified.  An empty collection raises ``IndexError``.
    """
    first = atom_centers[0]
    lowest = [first[0], first[1], first[2]]

    for point in atom_centers:
        for axis in (0, 1, 2):
            coordinate = point[axis]
            if coordinate < lowest[axis]:
                lowest[axis] = coordinate

    return lowest

def find_maxima(atom_centers):
    """Return the largest x, y and z value found among ``atom_centers``.

    ``atom_centers`` is an indexable, non-empty collection of 3-component
    positions.  The result is a new 3-element list; the input is not
    modified.  An empty collection raises ``IndexError``.
    """
    seed = atom_centers[0]
    highest = [seed[axis] for axis in range(3)]

    for point in atom_centers:
        for axis, current_best in enumerate(highest):
            if point[axis] > current_best:
                highest[axis] = point[axis]

    return highest

In [3]:
def encode_octant(octant):  # int from 0-7
    """Convert an octant index into its three binary digits.

    The most significant bit comes first, so ``3`` becomes ``[0, 1, 1]``.

    Parameters
    ----------
    octant : int
        Octant index in the range 0-7.

    Returns
    -------
    list of int
        ``[b2, b1, b0]`` with every entry 0 or 1.

    Raises
    ------
    ValueError
        If ``octant`` lies outside 0-7.  Without this check a value such as
        8 (binary ``1000``) would be cut down to its first three digits and
        come back as ``[1, 0, 0]``, indistinguishable from octant 4.
    """
    if not 0 <= octant <= 7:
        raise ValueError(f"octant must be in the range 0-7, got {octant!r}")

    return [int(bit) for bit in format(octant, "03b")]

In [4]:
def generate_flips(encoding): # determines a cell's 3 immediate siblings/neighbors
    """Return every encoding that differs from ``encoding`` in exactly one bit.

    ``encoding`` is a list of 0/1 entries.  The result holds one new list per
    position, in position order, with that single entry toggled; the input
    list is left untouched.
    """
    siblings = []

    for axis, bit in enumerate(encoding):
        sibling = encoding.copy()
        sibling[axis] = (bit + 1) % 2
        siblings.append(sibling)

    return siblings

In [5]:
# (0,x,x) -> (1,x,x):  dim = 0, end_state = 1
# generates the four encodings whose entry at `dim` is the complement of
# end_state, i.e. all (0,x,x) encodings for dim = 0, end_state = 1

def generate_closest_octants(dim, end_state):  # dim: int in [0, 2], end_state: int in [0, 1]
    """List the four octant encodings that share a fixed bit along ``dim``.

    The entry at position ``dim`` is set to the complement of ``end_state``
    (``end_state = 0`` gives 1, ``end_state = 1`` gives 0).  The two
    remaining positions run through every 0/1 combination, with the entry at
    ``(dim + 1) % 3`` varying slowest.

    Returns a list of four 3-element lists.
    """
    fixed_bit = (end_state + 1) % 2
    slow_axis = (dim + 1) % 3
    fast_axis = (dim + 2) % 3

    closest_octants = []

    for slow_bit in (0, 1):
        for fast_bit in (0, 1):
            octant = [0, 0, 0]
            octant[dim] = fixed_bit
            octant[slow_axis] = slow_bit
            octant[fast_axis] = fast_bit
            closest_octants.append(octant)

    return closest_octants

In [6]:
# Example: dim = 1, end_state = 0 -> the four encodings with a 1 in position 1
generate_closest_octants(1, 0)

[[0, 1, 0], [1, 1, 0], [0, 1, 1], [1, 1, 1]]

In [17]:
# every node is indexed by a sequence of octants 
# e.g. [0, 7, 1, 3] for a node at depth 4
# here, each octant is returned as a binary triplet
# e.g. 3 is represented as [0, 1, 1]


def find_path_of_node(node, path=None):
    """Return the octant path that leads from the root down to ``node``.

    Each step of the path is the binary triplet produced by
    ``encode_octant`` for the octant taken at that level, ordered from the
    root's child down to ``node`` itself.  The root (a node whose ``parent``
    is ``None``) has the empty path.

    Parameters
    ----------
    node : object with ``parent`` and ``octant`` attributes
        The node whose path is wanted.
    path : ignored
        Never read; kept only so existing calls that pass it keep working.
        Its default is ``None`` rather than a shared mutable list.

    Returns
    -------
    list of list of int
        A freshly built list, one triplet per level below the root.
    """
    # Walk upwards iteratively so deep trees cannot hit the recursion limit,
    # then reverse to obtain root-to-node order.
    steps = []
    current = node

    while current.parent is not None:
        steps.append(encode_octant(current.octant))
        current = current.parent

    steps.reverse()
    return steps

In [22]:
# Say you have a sequence such as [0, 1, 0, 0, 0, 0, 0]
# or [0, 0, 0, 0, 1, 1, 1] and, starting from the end of the sequence,
# you want to know where the run of identical entries breaks, i.e. the
# index of the first entry that differs from the one after it.
# In the first example, it's at index 1. In the second, it's 3.



def detect_deviation(sequence):
    """Find where the trailing run of equal entries in ``sequence`` ends.

    Scanning from the last entry towards the first, return the index of the
    first entry that differs from its successor.  For
    ``[0, 1, 0, 0, 0, 0, 0]`` that is 1; for ``[0, 0, 0, 0, 1, 1, 1]`` it
    is 3.

    Returns ``None`` when every entry is equal, and also for an empty or
    single-element sequence, which has nothing to compare.
    """
    # start at the second-to-last entry: the last one has no successor
    for index in range(len(sequence) - 2, -1, -1):
        if sequence[index] != sequence[index + 1]:
            return index

    return None

In [58]:
def invert_path_at_index(path, index, dim=0):
    """Return a copy of ``path`` with bit ``dim`` toggled from ``index`` onward.

    ``path`` is a list of 0/1 triplets.  The copy is deep, so neither
    ``path`` nor its triplets are modified.  When ``index`` is ``None`` the
    copy is returned unchanged.
    """
    flipped = deepcopy(path)

    if index is not None:
        for position in range(index, len(path)):
            octant = flipped[position]
            octant[dim] = (octant[dim] + 1) % 2

    return flipped

In [59]:
# Demo: toggle bit 0 of every octant in the path, then print the original
# path again to show that invert_path_at_index left it unmodified.
path = [[0,0,1],[1,0,1],[1,1,1]]

print(invert_path_at_index(path, 0, dim=0))
print(path)

[[1, 0, 1], [0, 0, 1], [0, 1, 1]]
[[0, 0, 1], [1, 0, 1], [1, 1, 1]]


In [62]:
# each node in the tree has a unique path
# e.g. [[0,0,1],[1,0,1],[1,1,1]]
# as well as up to 6* neighboring nodes with similar paths
# *(up to 6 only if considering cells at the same depth or below)
#
# Our example path has neighboring paths:
# [[0,0,1],[1,0,1],[0,1,1]]  (the last octant with any of the 3 entries flipped)
# [[0,0,1],[1,1,1],[1,0,1]]  (entry [1] from the last two octants inverted)
# [[1,0,1],[0,0,1],[0,1,1]]  (slicing the [0] indices gives a sequence
# {0, 1, 1} which is inverted to {1, 0, 0} keeping everything else fixed)
#
# So our path has 5 neighboring cells in the octree

def find_neighbors_at_same_depth(path):
    """Return the paths of the face neighbours of the cell at ``path``.

    ``path`` is a list of 0/1 triplets, one per tree level.  Along each of
    the three dimensions a cell has at most two same-depth neighbours:

    * the sibling reached by toggling that dimension's bit in the last
      octant only, and
    * the cell reached by toggling the bit from the deepest level at which
      it differs from the level below it, down to the last level.  No such
      level exists when the bit is identical at every level.

    Only the first such level counted from the deep end is used.  Toggling
    from a shallower change of the bit (e.g. from index 0 of the bit
    sequence {0, 1, 0}) would produce a cell that does not touch this one.

    Neighbours are listed from the deepest level upwards and, within a
    level, in dimension order.  ``path`` itself is not modified.  An empty
    path yields an empty list.
    """
    neighbors = []
    last_level = len(path) - 1

    deeper = [None, None, None]  # octant one level below the current one
    deviation_found = [False, False, False]  # per dimension

    for index in range(last_level, -1, -1):
        current = path[index]

        for dim in range(3):
            if deviation_found[dim]:
                continue  # both neighbours along this dimension are known

            if current[dim] != deeper[dim]:
                neighbors.append(invert_path_at_index(path, index, dim=dim))

                # The last level always "differs" from the None sentinel and
                # yields the sibling; any later hit is the real deviation.
                if index != last_level:
                    deviation_found[dim] = True

        deeper = current

    return neighbors

In [63]:
# Example from the comment above: this path has five same-depth neighbours
find_neighbors_at_same_depth([[0,0,1],[1,0,1],[1,1,1]])

[[[0, 0, 1], [1, 0, 1], [0, 1, 1]],
 [[0, 0, 1], [1, 0, 1], [1, 0, 1]],
 [[0, 0, 1], [1, 0, 1], [1, 1, 0]],
 [[0, 0, 1], [1, 1, 1], [1, 0, 1]],
 [[1, 0, 1], [0, 0, 1], [0, 1, 1]]]

In [8]:
class Node:
    """One cell of an octree built over a list of 3-D points.

    A root node (``parent is None``) takes its bounding box from the extreme
    coordinates of its points.  A child node takes half of its parent's edge
    lengths and is centred in the octant given by ``octant``: bit value 1 in
    the octant's encoding places the child on the low side of the parent's
    centre along that axis, bit value 0 on the high side.
    """

    def __init__(self, points, parent=None, octant=None):
        """Create a root node (no ``parent``) or a child cell of ``parent``.

        Parameters
        ----------
        points : list of 3-component positions
            The points that fall inside this cell.
        parent : Node or None
            Enclosing cell; ``None`` makes this node the root.
        octant : int or None
            Index 0-7 of this cell within ``parent``; unused for the root.
        """
        self.points = points
        self.parent = parent
        self.octant = octant
        self.children = []      # filled by create_children()
        self.leaves = []        # filled on the root by collect_leaf_nodes()
        self.index = None       # leaf number assigned by collect_leaf_nodes()
        self.dimensions = None  # edge lengths along x, y, z
        self.center = None      # midpoint of the cell
        self.depth = None       # 0 for the root
        self.category = None    # never assigned inside this class

        # this initializes the root node
        if parent is None:
            self.depth = 0

            # should always be true for root node, but just in case;
            # with no points, dimensions and center stay None
            if self.points:
                minima_position = find_minima(self.points)
                maxima_position = find_maxima(self.points)

                self.dimensions = [maxima_position[i] - minima_position[i] for i in range(3)]
                self.center = [minima_position[i] + self.dimensions[i]/2.0 for i in range(3)]

        # initialize child nodes
        else:
            self.depth = self.parent.depth + 1
            self.dimensions = [self.parent.dimensions[i]/2.0 for i in range(3)]

            octant_encoding = encode_octant(self.octant)  # convert octant to binary
            offset = [self.dimensions[i]/2.0 for i in range(3)]

            # bit 0 -> shift towards +axis, bit 1 -> shift towards -axis
            self.center = [self.parent.center[i] + (1 - 2*octant_encoding[i])*offset[i] for i in range(3)]

    def _all_points_coincide(self):
        """Return True if the cell is empty or all its points are identical."""
        if len(self.points) == 0:
            return True

        first = self.points[0]
        return all(
            p[0] == first[0] and p[1] == first[1] and p[2] == first[2]
            for p in self.points
        )

    def create_children(self, min_atoms, max_depth=None):
        """Split this cell into its 8 octants.

        The cell is left without children when any of these holds:

        * it contains fewer than ``min_atoms`` points;
        * ``max_depth`` is given and this cell is already at that depth;
        * all of its points coincide (or it is empty), since subdividing
          can never separate them and would otherwise recurse without end.

        Any children from an earlier call are discarded first, so calling
        this method (or ``span``) again rebuilds rather than duplicates them.
        """
        self.children = []

        if len(self.points) < min_atoms:
            return
        if max_depth is not None and self.depth >= max_depth:
            return
        if self._all_points_coincide():
            return

        # divide into octants
        divided_points = [[] for _ in range(8)]

        for p in self.points:
            x_pol = (p[0] <= self.center[0])
            y_pol = (p[1] <= self.center[1])
            z_pol = (p[2] <= self.center[2])

            # the three booleans are read as a binary number (x is the most
            # significant bit), matching encode_octant's bit order
            ind = x_pol*4 + y_pol*2 + z_pol

            divided_points[ind].append(p)

        self.children = [
            Node(divided_points[i], parent=self, octant=i) for i in range(8)
        ]

    def span(self, min_atoms, max_depth=None):
        """Build the subtree below this node by recursive subdivision.

        ``min_atoms`` and the optional ``max_depth`` are passed to
        ``create_children`` at every level.
        """
        self.create_children(min_atoms, max_depth)

        for child in self.children:
            child.span(min_atoms, max_depth)

    def count_subnodes(self):
        """Return the number of nodes in this subtree, this node included."""
        count = 1

        for child in self.children:
            count += child.count_subnodes()

        return count

    def count_leaf_nodes(self):
        """Return the number of childless nodes in this subtree."""
        if not self.children:
            return 1

        count = 0

        for child in self.children:
            count += child.count_leaf_nodes()

        return count

    def collect_leaf_nodes(self, current_count=0):
        """Number the leaves of this subtree and return them.

        Leaves are visited depth-first in child order and receive
        consecutive ``index`` values starting at ``current_count``.  When
        called on the root (depth 0) the leaf list is also stored in
        ``self.leaves``, so ``leaf.index`` equals the leaf's position in
        that list.

        Returns
        -------
        (int, list of Node)
            The number of leaves in this subtree and the leaves themselves.
        """
        if not self.children:
            self.index = current_count
            return 1, [self]

        count = 0
        leaves = []

        for child in self.children:
            # offset by the leaves already numbered in earlier siblings
            increment, new_leaves = child.collect_leaf_nodes(current_count + count)
            count += increment
            leaves.extend(new_leaves)

        if self.depth == 0:
            self.leaves = leaves

        return count, leaves

In [9]:
# Root cell: its bounding box is taken from the extreme atom coordinates
tree = Node(atom_positions)

# create an (8,1)-admissible octree
# (cells holding fewer than 8 atoms are not subdivided)
tree.span(8)

In [10]:
# count_subnodes() counts every node, internal cells included
print("total nodes in the tree:", tree.count_subnodes())

# collect_leaf_nodes() stores the list of leaves on the root as tree.leaves
tree.collect_leaf_nodes()

print("total leaves in the tree:", len(tree.leaves))

total nodes in the tree: 1665
total leaves in the tree: 1457


In [17]:
# how many leaf nodes are empty?
s = 0

for leaf in tree.leaves:
    if not leaf.points:
        leaf.category = 'exterior'
        s += 1
        
s

416

In [56]:
# list.copy() is shallow: `a` is a new outer list, but its inner lists are
# the same objects as those in `li`, so the assignment below changes both.
li = [[0,2],[9,4]]
a = li.copy()

a[0][0] = 34

print(a, li)

[[34, 2], [9, 4]] [[34, 2], [9, 4]]
