In [5]:
%matplotlib notebook

from Bio.PDB.PDBParser import PDBParser

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d.axes3d import Axes3D
from mpl_toolkits.mplot3d.art3d import Poly3DCollection, Line3DCollection

import numpy as np
import pandas as pd
from scipy.spatial import distance
import math

def average_pos(poss):
    avg_X, avg_Y, avg_Z = zip(*poss)
    avg_X = np.mean(avg_X)
    avg_Y = np.mean(avg_Y)
    avg_Z = np.mean(avg_Z)
    return avg_X, avg_Y, avg_Z

def dot_product(vec1, vec2):
    return sum([float(v1) * float(v2) for v1, v2 in zip(vec1, vec2)])

def vec_len(vec):
    return math.sqrt(float(sum([v * v for v in vec])))

def vec_angle(vec1, vec2):
    if vec1 == vec2:
        return 0.0
    angle = dot_product(vec1, vec2) / (vec_len(vec1) * vec_len(vec2))
    angle = math.degrees(math.acos(angle))
    assert 0.0 <= angle <= 360.0
    return angle
def perpendicular_proj(normal, point_plane, point):
    d = (point-point_plane).dot(normal/np.linalg.norm(normal))
    return point - d * (normal/np.linalg.norm(normal))

def calc_side_chain_vector(protres_name, protres_atoms, suppress_warnings=False):
    # composed of two vectors
    # 1. from average position of 'N', 'C' and 'O' atoms to 'CA'
    # 2. average vector from 'CA' to all other atoms except ('N', 'C' and 'O')

    # 1.
    avg_backbone = []
    CA_atom_coords = []
    avg_side_chain = []
    for a_id, coords in protres_atoms.items():
        if a_id in ['N', 'C', 'O']:
            if a_id != 'O':
                avg_backbone.append(coords)
        elif a_id == 'CA':
            CA_atom_coords.append(coords)
        else:
            if a_id == 'CB':
                avg_side_chain.append(coords)

    if len(CA_atom_coords) != 1:
        if not suppress_warnings:
            print("Warning, no CA atom in:", protres_name)
        return None, None

    assert len(CA_atom_coords) == 1
    CA_atom_coords = CA_atom_coords[0]
    CA_pos_X, CA_pos_Y, CA_pos_Z = CA_atom_coords
    backbone_X, backbone_Y, backbone_Z = average_pos(avg_backbone)

    vec_1 = (
        CA_pos_X - backbone_X, CA_pos_Y - backbone_Y, CA_pos_Z - backbone_Z)

    if protres_name == 'GLY':
        assert len(avg_side_chain) == 0
        return vec_1, CA_atom_coords
    else:
        assert len(avg_side_chain) > 0

    # 2.
    side_X, side_Y, side_Z = average_pos(avg_side_chain)
    vec_2 = (side_X - CA_pos_X, side_Y - CA_pos_Y, side_Z - CA_pos_Z)

    # angle between the two vectors has to be less than 90
    # A . B = |A| * |B| * cos(angle)
    # cos(angle) = A.B / (|A|*|B|)
    # angle = arccos(A.B / (|A|*|B|))
    # angle_deg = vec_angle(vec_1, vec_2)
    #
    # if angle_deg > 180.0:
    #    angle_deg = angle_deg - 360.0
    # if angle_deg >= 70:
    #    if not suppress_warnings:
    #        print("Warning: high angle (%s) between CA-CB and bacbone-CA vectors in %s" % (angle_deg, protres_name))
    #    return None, None
    #
    # assert angle_deg < 70.0

    # average of the two vectors
    vec = tuple(v1 + v2 for v1, v2 in zip(vec_1, vec_2))

    return vec, CA_atom_coords

def get_protein_and_rna(structure):
    protein = []
    protein_names = []
    rna = []
    rna_names = []
    for model in structure:
        for chain in model:
            for residue in chain:
                res_name = residue.get_resname().strip()
                atoms = {}
                for atom in residue:
                    atoms[atom.id] = atom.get_coord()

                if res_name in ['A', 'C', 'G', 'U']:
                    rna.append(atoms)
                    rna_names.append(res_name)
                if (res_name in [ 'ALA', 'ARG','ASN','ASP','CYS','GLU','GLN','HIS','ILE','LEU','LYS','MET',
                                 'PHE','PRO','SER','THR','TRP','TYR', 'VAL']
                        and {'CA', 'CB', 'C', 'N'}.issubset(set(atoms.keys()))) or \
                        (res_name == 'GLY' and {'CA', 'C', 'N'}.issubset(set(atoms.keys()))):

                    protein.append(atoms)
                    protein_names.append(res_name)
    return protein, rna, protein_names, rna_names

def bounding_box(atoms):
    positive = atoms[np.where(atoms[:, 2] >= 0)[0], :3]
    negative = atoms[np.where(atoms[:, 2] < 0)[0], :3]


    all_max = np.amax(atoms, axis=0)[:3]
    all_min = np.amin(atoms, axis=0)[:3]

    x_, y_, z_ = np.max(np.abs(np.vstack((all_max, all_min,))), axis=0)

    z_pos = np.amax(positive, axis=0)[2]

    if negative.shape[0] != 0:
        z_neg = np.amin(negative, axis=0)[2]
    else:
        z_neg = 0

    points = np.array([[-x_, -y_, z_neg],
                       [x_, -y_, z_neg],
                       [x_, y_, z_neg],
                       [-x_, y_, z_neg],
                       [-x_, -y_, z_pos],
                       [x_, -y_, z_pos],
                       [x_, y_, z_pos],
                       [-x_, y_, z_pos]])

    return (x_, y_, z_pos, z_neg), points

def transformation_matrix(aminoacid_atoms, aminoacid_name):
    a = aminoacid_atoms['CA']
    b = np.array(calc_side_chain_vector(aminoacid_name, aminoacid_atoms, suppress_warnings=False)[0])
    # b = aminoacid_atoms['CB']
    c1 = aminoacid_atoms['C']
    c2 = aminoacid_atoms['N']
    if (b - a).dot(c2 - c1) == 0:
        c = a + (c2 - c1)
    else:
        c1_p = perpendicular_proj(b - a, a, c1)
        c2_p = perpendicular_proj(b - a, a, c2)
        #print((c1_p - c2_p).dot(b - a))
        c = a + (c2_p - c1_p)

    d = a + np.cross((b - a), (c - a))

    u = ((d - a) / (np.linalg.norm(d - a)))
    v = ((c - a) / np.linalg.norm(c - a))
    w = ((b - a) / np.linalg.norm(b - a))
    u_m = np.append(u, 0)
    v_m = np.append(v, 0)
    w_m = np.append(w, 0)

    m = np.vstack((u_m, v_m, w_m, np.array([0, 0, 0, 1])))
    a_t = m.dot(np.append(a, 1))
    m[:, 3] = -a_t

    m2 = np.column_stack((u_m, v_m, w_m, np.array([0, 0, 0, 1])))
    m2[:, 3] = np.append(a, 1)

    #angle between b and c1->c2
    angle = (np.arccos((b.dot(c2-c1))/(np.linalg.norm(b) * np.linalg.norm(c2-c1)))) * (180/np.pi)

    return m, m2, angle

def plot_transformation(structure_id, amino_acid, z_aprox, num_figures=5):
    j = 0
    parser = PDBParser()
    structure = parser.get_structure(structure_id, 'pdb structures/' + 'pdb' + structure_id + '.ent')

    protein, rna, protein_names, rna_names = get_protein_and_rna(structure)
    
    # all_protein_atoms = np.concatenate(list(map(lambda x: np.array(list(x.values())), protein)))
    all_rna_atoms = np.concatenate(list(map(lambda x: np.array(list(x.values())), rna)))

    for p, (aminoacid_atoms, aminoacid_name) in enumerate(zip(protein, protein_names)):
        if j >= num_figures:
            break
            
        min_dist = 1000
        closest = -1
        for i, rna_atoms in enumerate(rna):
            dist = np.min(distance.cdist(np.array(list(aminoacid_atoms.values())), np.array(list(rna_atoms.values()))))
            if dist < min_dist:
                min_dist = dist
                closest = i
                
        if min_dist > 4:
            continue

        #transformation matrix
        a = aminoacid_atoms['CA']
        b = np.array(calc_side_chain_vector(aminoacid_name, aminoacid_atoms, suppress_warnings=False)[0])
        # b = aminoacid_atoms['CB']
        c1 = aminoacid_atoms['C']
        c2 = aminoacid_atoms['N']
        if (b - a).dot(c2 - c1) == 0:
            c = a + (c2 - c1)
        else:
            c1_p = perpendicular_proj(b - a, a, c1)
            c2_p = perpendicular_proj(b - a, a, c2)
            # print((c1_p - c2_p).dot(b - a))
            c = a + (c2_p - c1_p)

        d = np.cross((b - a), (c - a))

        u = (d / (np.linalg.norm(d)))
        v = ((c - a) / np.linalg.norm(c - a))
        w = ((b - a) / np.linalg.norm(b - a))
        u_m = np.append(u, 0)
        v_m = np.append(v, 0)
        w_m = np.append(w, 0)

        m = np.vstack((u_m, v_m, w_m, np.array([0, 0, 0, 1])))
        a_t = m.dot(np.append(a, 1))
        m[:, 3] = -a_t

        m2 = np.column_stack((u_m, v_m, w_m, np.array([0, 0, 0, 1])))
        m2[:, 3] = np.append(a, 1)

        angle = (np.arccos((b.dot(c2 - c1)) / (np.linalg.norm(b) * np.linalg.norm(c2 - c1)))) * (180 / np.pi)

        t = np.array(list(aminoacid_atoms.values()))
        t_names = list(aminoacid_atoms)
        t_1 = np.hstack((t, np.ones((t.shape[0], 1))))
        tn = np.round(m.dot(t_1.T).T, 5)[:,:3]

        # t2_1 = np.hstack((tn, np.ones((tn.shape[0], 1))))
        # t2 = m2.dot(t2_1.T).T[:,:3]
        #
        # print(t- t2)

        (x_, y_, z_pos, z_neg), tocke = bounding_box(tn[:, :3])


        x, y, z = 2* x_, 2*y_, z_pos - z_neg

        if (aminoacid_name == amino_acid) and (z_aprox - 0.2< z < z_aprox + 0.2) and (j < num_figures):
            j +=1

            print('sizes(x, y, z): ', x, y, z)
            print('angle: ', angle)

            t_rna = np.array(list(rna[closest].values()))
            t_rna_1 = np.hstack((t_rna, np.ones((t_rna.shape[0], 1))))
            tn_rna = np.round(m.dot(t_rna_1.T).T, 5)[:, :3]

            # all_protein_1 = np.hstack((all_protein_atoms, np.ones((all_protein_atoms.shape[0], 1))))
            # all_protein_n = np.round(m.dot(all_protein_1.T).T, 5)[:, :3]

            closest_protein_atoms = np.concatenate(list(map(lambda x: np.array(list(x.values())), np.array(protein)[[
                p - 2, p - 1, p + 1, p + 2]])))

            closest_protein_1 = np.hstack((closest_protein_atoms, np.ones((closest_protein_atoms.shape[0], 1))))
            closest_protein_n = np.round(m.dot(closest_protein_1.T).T, 5)[:, :3]

            all_rna_1 = np.hstack((all_rna_atoms, np.ones((all_rna_atoms.shape[0], 1))))
            all_rna_n = np.round(m.dot(all_rna_1.T).T, 5)[:, :3]


            # negative = tn[np.where(tn[:, 2] < 0)[0], :3]
            # positive = tn[np.where(tn[:, 2] >= 0)[0], :3]


            fig = plt.figure(figsize=plt.figaspect(0.5))

            ax = fig.add_subplot(1, 2, 2, projection='3d')
            ax.set_aspect('equal')

            center = np.mean(tocke, axis=0)
            max_range = np.max(np.amax(tn_rna, axis=0) - np.amin(tn_rna, axis=0)) * 0.95

            x_min, x_max = center[0] - max_range, center[0] + max_range
            y_min, y_max = center[1] - max_range, center[1] + max_range
            z_min, z_max = center[2] - max_range, center[2] + max_range

            # draw_protein = np.where((x_min < all_protein_n[:, 0]) & (all_protein_n[:, 0] < x_max) &
            #                         (y_min < all_protein_n[:, 1]) & (all_protein_n[:, 1] < y_max) &
            #                         (z_min < all_protein_n[:, 2]) & (all_protein_n[:, 2] < z_max))
            draw_rna = np.where((x_min < all_rna_n[:,0]) &  (all_rna_n[:,0] < x_max) &
                                (y_min < all_rna_n[:,1]) & (all_rna_n[:, 1] < y_max) &
                                (z_min < all_rna_n[:, 2]) & (all_rna_n[:, 2] < z_max))

            # draw_protein_n = all_protein_n[draw_protein]
            draw_rna_n = all_rna_n[draw_rna]

            # ax.plot(draw_protein_n[:, 0], draw_protein_n[:, 1], draw_protein_n[:, 2], color='grey', marker='o', linestyle = 'None')
            ax.plot(closest_protein_n[:, 0], closest_protein_n[:, 1], closest_protein_n[:, 2], color='grey',
                    marker='o',
                    linestyle = 'None')
            ax.plot(draw_rna_n[:, 0], draw_rna_n[:, 1], draw_rna_n[:, 2], 'yo')

            ax.plot(tn[:, 0], tn[:, 1], tn[:, 2], 'ko')
            # ax.plot(positive[:, 0], positive[:, 1], positive[:, 2], 'ko')
            # ax.plot(negative[:, 0], negative[:, 1], negative[:, 2], 'go')

            ax.plot(tn_rna[:, 0], tn_rna[:, 1], tn_rna[:, 2], 'mo')

            ca_n = tn[t_names.index('CA'), :3]
            b_n = m.dot(np.append(b, 1))[:3]
            c_n = tn[t_names.index('C'), :3]
            n_n = tn[t_names.index('N'), :3]
            cp_n = m.dot(np.append(c1_p, 1))[:3]
            np_n = m.dot(np.append(c2_p, 1))[:3]

            ax.plot([c_n[0]], [c_n[1]], [c_n[2]], 'ro')
            ax.plot([n_n[0]], [n_n[1]], [n_n[2]], 'bo')

            ax.plot([ca_n[0], b_n[0]], [ca_n[1], b_n[1]], [ca_n[2], b_n[2]])
            ax.plot([cp_n[0], np_n[0]], [cp_n[1], np_n[1]], [cp_n[2], np_n[2]])


            # plot vertices
            ax.scatter3D(tocke[:, 0], tocke[:, 1], tocke[:, 2])

            # list of sides' polygons of figure
            verts = [[tocke[0], tocke[1], tocke[2], tocke[3]],
                     [tocke[4], tocke[5], tocke[6], tocke[7]],
                     [tocke[0], tocke[1], tocke[5], tocke[4]],
                     [tocke[2], tocke[3], tocke[7], tocke[6]],
                     [tocke[1], tocke[2], tocke[6], tocke[5]],
                     [tocke[4], tocke[7], tocke[3], tocke[0]],
                     [tocke[2], tocke[3], tocke[7], tocke[6]]]

            # plot sides
            ax.add_collection3d(
                Poly3DCollection(verts, facecolors='lightcyan', linewidths=1, edgecolors='dodgerblue', alpha=.25))

            ax.set_xlabel('X')
            ax.set_ylabel('Y')
            ax.set_zlabel('Z')


            ax.set_xlim(x_min, x_max)
            ax.set_ylim(y_min, y_max)
            ax.set_zlim(z_min, z_max)


            ax = fig.add_subplot(1, 2, 1, projection='3d')
            ax.set_aspect('equal')

            # draw_protein_o = all_protein_atoms[draw_protein]
            draw_rna_o = all_rna_atoms[draw_rna]

            # ax.plot(draw_protein_o[:, 0], draw_protein_o[:, 1], draw_protein_o[:, 2], color='grey', marker='o',
            #         linestyle = 'None')
            ax.plot(closest_protein_atoms[:, 0], closest_protein_atoms[:, 1], closest_protein_atoms[:, 2], color='grey', marker='o',
                    linestyle = 'None', label='closet amino acid atoms')
            ax.plot(draw_rna_o[:, 0], draw_rna_o[:, 1], draw_rna_o[:, 2], 'yo', label='rna atoms')

            ax.plot(t[:, 0], t[:, 1], t[:, 2], 'ko', label='amino acid atoms')

            ax.plot(t_rna[:, 0], t_rna[:, 1], t_rna[:, 2], 'mo', label='closest nucleotid atoms')

            # ca = t[t_names.index('CA'), :3]
            # c = t[t_names.index('C'), :3]
            # n = t[t_names.index('N'), :3]
            ax.plot([c1[0]], [c1[1]], [c1[2]], 'ro', label='C')
            ax.plot([c2[0]], [c2[1]], [c2[2]], 'bo', label='N')

            ax.plot([a[0], b[0]], [a[1], b[1]], [a[2], b[2]])
            ax.plot([c1_p[0], c2_p[0]], [c1_p[1], c2_p[1]], [c1_p[2], c2_p[2]])
            # print(np.linalg.norm(an - bn), np.linalg.norm(a - b))


            tocke_1 = np.hstack((tocke, np.ones((tocke.shape[0], 1))))
            tocke_ = np.round(m2.dot(tocke_1.T).T, 5)[:,:3]

            # list of sides' polygons of figure
            verts_ = [[tocke_[0], tocke_[1], tocke_[2], tocke_[3]],
                      [tocke_[4], tocke_[5], tocke_[6], tocke_[7]],
                      [tocke_[0], tocke_[1], tocke_[5], tocke_[4]],
                      [tocke_[2], tocke_[3], tocke_[7], tocke_[6]],
                      [tocke_[1], tocke_[2], tocke_[6], tocke_[5]],
                      [tocke_[4], tocke_[7], tocke_[3], tocke_[0]],
                      [tocke_[2], tocke_[3], tocke_[7], tocke_[6]]]

            #print(np.linalg.norm(tocke[0] - tocke[1]), np.linalg.norm(tocke_[0] - tocke_[1]))

            # plot sides
            ax.add_collection3d(
                Poly3DCollection(verts_, facecolors='lightcyan', linewidths=1, edgecolors='dodgerblue', alpha=.25))

            ax.set_xlabel('X')
            ax.set_ylabel('Y')
            ax.set_zlabel('Z')

            center_ = np.mean(tocke_, axis=0)

            ax.set_xlim(center_[0] - max_range, center_[0] + max_range)
            ax.set_ylim(center_[1] - max_range, center_[1] + max_range)
            ax.set_zlim(center_[2] - max_range, center_[2] + max_range)

            fig.suptitle(f"Structure  ID: {structure_id}, Protein: {aminoacid_name}")
            ax.legend(loc='best', fancybox=True, framealpha=0.5)

            plt.show()


In [16]:
plot_transformation('1rgo', 'CYS', 2.7, 3) #homo sapiens

sizes(x, y, z):  4.37258 3.48918 2.79457
angle:  88.36185676413147


<IPython.core.display.Javascript object>

sizes(x, y, z):  4.19236 2.79478 2.77961
angle:  88.15972414583412


<IPython.core.display.Javascript object>

sizes(x, y, z):  3.93092 2.68466 2.85202
angle:  88.15421218317393


<IPython.core.display.Javascript object>

In [17]:
plot_transformation('1a1t', 'CYS', 2.7, 3) #human immunodeficiency virus 1

sizes(x, y, z):  3.71664 3.76014 2.87496
angle:  88.193144261443


<IPython.core.display.Javascript object>

sizes(x, y, z):  3.71264 3.69024 2.80938
angle:  88.1719160335026


<IPython.core.display.Javascript object>

sizes(x, y, z):  3.57798 3.73318 2.82982
angle:  88.2253486033796


<IPython.core.display.Javascript object>

In [7]:
plot_transformation('1a1t', 'CYS', 3.7, 3)

sizes(x, y, z):  4.2435 4.17612 3.6956300000000004
angle:  88.21158577220203


<IPython.core.display.Javascript object>

sizes(x, y, z):  4.2183 4.08166 3.5901899999999998
angle:  88.23293011335832


<IPython.core.display.Javascript object>

sizes(x, y, z):  4.2311 4.21964 3.71347
angle:  88.2215919993361


<IPython.core.display.Javascript object>

In [8]:
plot_transformation('1a1t', 'GLN', 3.7, 3)

sizes(x, y, z):  11.01812 3.68156 3.7898
angle:  88.28755796634007


<IPython.core.display.Javascript object>

sizes(x, y, z):  10.98468 3.39214 3.8734
angle:  88.30268683535166


<IPython.core.display.Javascript object>

sizes(x, y, z):  7.53614 7.33488 3.6703900000000003
angle:  88.31294577948502


<IPython.core.display.Javascript object>

In [9]:
plot_transformation('1a1t', 'GLN', 4.5, 3)

sizes(x, y, z):  7.68062 6.14122 4.666729999999999
angle:  88.30513204307452


<IPython.core.display.Javascript object>

sizes(x, y, z):  10.17638 3.80076 4.44623
angle:  88.31897683652214


<IPython.core.display.Javascript object>

sizes(x, y, z):  10.47644 3.18824 4.48334
angle:  88.32057510078792


<IPython.core.display.Javascript object>

In [19]:
plot_transformation('1rgo', 'HIS', 4.5, 3)

sizes(x, y, z):  8.94266 6.7789 4.44883
angle:  88.64947603008072


<IPython.core.display.Javascript object>

sizes(x, y, z):  8.99066 6.87484 4.3364899999999995
angle:  88.30373868448385


<IPython.core.display.Javascript object>

sizes(x, y, z):  8.21702 7.0486 4.64369
angle:  88.37171272710381


<IPython.core.display.Javascript object>

In [14]:
plot_transformation('1a1t', 'HIS', 4.5, 3)

sizes(x, y, z):  10.8091 4.58312 4.6448
angle:  87.53602542177441


<IPython.core.display.Javascript object>

sizes(x, y, z):  10.6777 4.24078 4.39922
angle:  87.48419111616317


<IPython.core.display.Javascript object>