In [1]:
import os
import sys
import glob

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
from random import sample
import math

from scipy.spatial.distance import directed_hausdorff
from sklearn.neighbors import KNeighborsClassifier

from rcsbsearchapi.search import TextQuery
from rcsbsearchapi import rcsb_attributes as attrs

In [2]:
symmetry = "C2"

# use a rscbsearchapi query to retrieve the proteins of the desired symmetry, or load it from storage.

# q1 = attrs.rcsb_struct_symmetry.symbol == symmetry
# q2 = attrs.rcsb_struct_symmetry.kind == "Global Symmetry"
# query = q1 & q2
# entry_list = list(query("assembly"))
# with open("symmetry-lists/"+symmetry+"_list.pkl", "wb") as file:
#     pickle.dump(entry_list, file)

with open("symmetry-lists/"+symmetry+"_list.pkl", "rb") as file:
    entry_list = pickle.load(file)

In [3]:
training_location = os.path.join('batch-1', '*.cif')
training_files = glob.glob(training_location)

testing_location = os.path.join('batch-2', '*.cif')
testing_files = glob.glob(testing_location)

In [4]:
#classify all of the proteins in the training data folder
training_entries = []
symmetries = 0
for f in training_files:
    entry = f[-8:-4].upper()
    training_entries.append(entry)
    if entry+'-1' in entry_list:
        symmetries = symmetries + 1
print('Desired symmetries:  ',symmetries,'/',len(training_entries))

Desired symmetries:   29 / 50


In [5]:
testing_entries = []
symmetries = 0
for f in testing_files:
    entry = f[-8:-4].upper()
    testing_entries.append(entry)
    if entry+'-1' in entry_list:
        symmetries = symmetries + 1
print('Desired symmetries:  ',symmetries,'/',len(testing_entries))

Desired symmetries:   36 / 50


In [6]:
#here's an example of getting all of the coordinates out of a protein file
i = 0
with open(training_files[i],"r") as outfile:
    data_1 = outfile.readlines()
atom_list_1 = [[float(x) for x in line.split()[10:13]] 
               for line in data_1 if 'ATOM' in line]
print(atom_list_1)

[[34.493, -12.185, 1.399], [33.306, -11.447, 1.815], [32.742, -11.944, 3.152], [31.6, -11.655, 3.491], [32.232, -11.518, 0.722], [31.801, -12.942, 0.408], [32.68, -13.797, 0.199], [30.582, -13.209, 0.352], [33.547, -12.678, 3.916], [33.103, -13.224, 5.201], [32.906, -12.112, 6.208], [31.988, -12.147, 7.012], [34.114, -14.233, 5.784], [33.538, -14.909, 6.999], [34.477, -15.262, 4.765], [33.786, -11.125, 6.142], [33.81, -10.003, 7.068], [32.448, -9.348, 7.282], [32.13, -8.898, 8.378], [34.804, -8.962, 6.565], [34.508, -7.684, 7.091], [31.646, -9.303, 6.23], [30.33, -8.69, 6.302], [29.376, -9.555, 7.127], [28.48, -9.052, 7.814], [29.776, -8.451, 4.882], [28.341, -8.937, 4.738], [29.9, -6.98, 4.511], [29.603, -10.861, 7.078], [28.721, -11.805, 7.734], [29.058, -12.044, 9.204], [28.156, -12.215, 10.024], [28.727, -13.116, 6.968], [27.946, -13.049, 5.697], [26.569, -13.017, 5.679], [28.344, -12.97, 4.406], [26.151, -12.931, 4.43], [27.207, -12.902, 3.638], [30.341, -12.052, 9.547], [30.693, 

In [7]:
#given two filenames, find the hausdorff distance between their sets of atoms
def atom_list(filename):
    with open(filename,"r") as outfile:
            data = outfile.readlines()
    points = []
    for line in data:
        if 'ATOM' not in line:
            continue
        try:
            point = tuple(float(x) for x in line.split()[10:13])
            if len(point) ==3:
                points.append(point)
        except:
            continue
    return np.array(points)

def protein_distance(filename_1, filename_2):
    atom_list_1 = atom_list(filename_1)
    atom_list_2 = atom_list(filename_2)

    return max(directed_hausdorff(atom_list_1, atom_list_2)[0], 
               directed_hausdorff(atom_list_2, atom_list_1)[0])

In [8]:
def double_to_file(i):
    if int(i) >= len(training_files):
        return testing_files[int(i)-len(training_files)]
    return training_files[int(i)]

In [9]:
#build a KNN classifier with the protein_distance metric
neigh = KNeighborsClassifier(n_neighbors=7, metric=lambda i, j : protein_distance(double_to_file(i[0]), double_to_file(j[0])))

In [10]:
X = [[i] for i in range(len(training_files))]
y = [training_files[i][-8:-4].upper()+'-1' in entry_list for i in range(len(training_files))]
neigh.fit(X,y)

In [11]:
predictions = {}
actuals = {}

for i in range(len(testing_files)):
    txt = (("\rOn {current} of {total}.   {percent}% complete.")
                   .format(current = i+1, total = len(testing_files), 
                           percent = math.floor(i/len(testing_files)*100)))
    sys.stdout.write(txt)
    sys.stdout.flush()
    try:
        predicted = neigh.predict([[i+len(training_files)]])[0]
        actual = testing_files[i][-8:-4].upper()+'-1' in entry_list
        predictions.update({testing_files[i][-8:-4].upper()+'-1' : predicted})
        actuals.update({testing_files[i][-8:-4].upper()+'-1' : actual})
    except:
        print("Exception!")



On 50 of 50.   98% complete.

In [12]:
accuracy = len([key for key in predictions.keys() if predictions[key] == actuals[key]])/len(predictions.keys())
print('Accuracy:',accuracy)

Accuracy: 0.68


In [13]:
precision = (len([key for key in predictions.keys() if predictions[key] == True and actuals[key] == True])
             /len([key for key in predictions.keys() if predictions[key] == True]))
print('Precision:',precision)

Precision: 0.7941176470588235


In [14]:
recall = (len([key for key in predictions.keys() if predictions[key] == True and actuals[key] == True])
             /len([key for key in predictions.keys() if actuals[key] == True]))
print('Recall:',recall)

Recall: 0.75


In [15]:
print(predictions)

{'2YAQ-1': True, '6ZP0-1': False, '1T26-1': True, '3JTW-1': True, '8EBN-1': True, '3FVA-1': True, '7B4E-1': True, '4Q50-1': False, '2FEL-1': True, '8ARQ-1': True, '5AVF-1': False, '2C20-1': True, '4O6H-1': True, '2YE9-1': True, '3BXP-1': True, '4RLD-1': False, '5WPI-1': True, '2WJU-1': False, '3AQA-1': True, '6QR5-1': True, '5T1X-1': True, '5PGY-1': True, '4GOE-1': True, '7YSV-1': False, '3TIF-1': True, '4GUZ-1': True, '2FSW-1': True, '7RS7-1': True, '6ECI-1': False, '2GSR-1': True, '6DG3-1': True, '5X2R-1': False, '2C21-1': True, '2VAT-1': True, '2HAX-1': True, '3K5T-1': True, '6U8X-1': False, '1EKQ-1': False, '4FEW-1': False, '3ZTS-1': False, '3LFL-1': True, '7R6A-1': False, '6J6T-1': False, '8SSY-1': True, '4XVY-1': True, '7OWL-1': False, '7WWY-1': True, '7M8L-1': False, '1IZB-1': True, '5I7D-1': True}
