The Jupyter notebook "Generate_input_matrix.ipynb" contains the code for extracting the coordinates of a whole protein. If we want to re-implement the prediction task from the Wilke paper
as a pre-training task, we need to do the following:

1. Select an amino acid of the protein.
2. Get all coordinates of the atoms surrounding this amino acid (up to a specific distance; in the example below 20 Angstrom/voxels).
3. Mask the selected amino acid, i.e. do not encode the atoms of this amino acid.

import numpy as np
import pandas as pd
from scipy.spatial import distance

# Path to an example PDB file.
# NOTE(review): no later cell shown here reads this value before `filename` is
# reassigned — confirm it is still needed.
filename = "./data/4hhb.pdb"

In [None]:
import numpy as np
import pandas as pd
import json
from scipy.spatial import distance
from Bio.PDB.PDBParser import PDBParser
# Module-level Biopython PDBParser instance, created once with default options.
parser = PDBParser()

# filename = "/home/sedrica/students/Vishal/4hhb.pdb"

# Edge length (in voxels) of the cubic grid used by the voxelisation code.
input_size = 20

# Van der Waals radius per element symbol (presumably in Angstrom — TODO confirm).
dic_VdW_radius = dict(C=1.70, N=1.55, S=1.80, O=1.52, P=1.80)

# Channel index of each element in the last axis of the voxel grid.
dic = {element: channel for channel, element in enumerate(("C", "N", "S", "O", "P"))}

# Divisor applied to the atom-voxel distance inside the exponential weight.
sigma_sqd = 1


# The 20 residue names that receive a class label.
all_residues = [
    'GLN', 'THR', 'ASP', 'CYS', 'ARG',
    'GLU', 'VAL', 'ILE', 'MET', 'TRP',
    'HIS', 'LYS', 'LEU', 'GLY', 'PHE',
    'SER', 'PRO', 'ASN', 'ALA', 'TYR',
]


# Class label of a residue name = its position in `all_residues`.
residue_class = {name: label for label, name in enumerate(all_residues)}


In [None]:
def get_input_matrix(point_list, input_size = 20):
    """Scatter sparse voxel records into a dense 4-D grid.

    Parameters
    ----------
    point_list : iterable of 5-element records
        Each record is unpacked as ``(x, y, z, channel, value)``. The first
        four entries are truncated with ``int`` and used as array indices.
    input_size : int, optional
        Edge length of the cubic grid (default 20).

    Returns
    -------
    numpy.ndarray
        Array of shape ``(input_size, input_size, input_size, 5)`` that is zero
        except at the listed positions. If a position is listed more than
        once, the last record wins.
    """
    grid = np.zeros([input_size, input_size, input_size, 5])
    for record in point_list:
        vx, vy, vz, channel, value = record
        voxel = (int(vx), int(vy), int(vz), int(channel))
        grid[voxel] = value
    return grid


In [None]:
def shift_axes(PDB_DataFrame, input_size = 96):
    """Translate the x, y and z columns so the structure is centred in a cube.

    For each axis the minimum is moved to 0 and then half of the leftover
    space, ``(input_size - extent) / 2``, is added. The midpoint of the
    occupied range therefore ends up at ``input_size / 2``. Values are not
    rounded.

    Parameters
    ----------
    PDB_DataFrame : pandas.DataFrame
        Must provide numeric columns "x", "y" and "z". The frame is modified
        in place.
    input_size : number, optional
        Edge length of the target cube (default 96).

    Returns
    -------
    pandas.DataFrame
        The same ``PDB_DataFrame`` object with shifted coordinates.
    """
    axes = ("x", "y", "z")

    # Compute every shifted column before writing any of them back.
    raw = {axis: np.array(PDB_DataFrame[axis]) for axis in axes}
    lowest = {axis: np.min(raw[axis]) for axis in axes}
    extent = {axis: np.max(raw[axis]) - lowest[axis] for axis in axes}

    shifted = {}
    for axis in axes:
        # Offset that centres the occupied range inside the cube.
        centring_offset = (input_size - extent[axis]) / 2
        shifted[axis] = raw[axis] - lowest[axis] + centring_offset

    for axis in axes:
        PDB_DataFrame[axis] = shifted[axis]
    return PDB_DataFrame


def correct_errors(ls):
    """Re-split the whitespace tokens of an ATOM line whose fields ran together.

    The checks run in a fixed order, each on the result of the previous one:

    * token 2 longer than 4 characters: cut after 3 characters
      (atom name / residue name);
    * token 4 longer than 1 character: cut after the first character;
    * token 6, then token 7, longer than 8 characters: cut three digits after
      the first decimal point (x / y, then y / z);
    * token 9 longer than 4 characters: cut after 4 characters
      (occupancy / temperature factor); only one further token (index 10) is
      kept after the two pieces.

    A message is printed when the result does not have 12 tokens.

    NOTE(review): PDB atom names can be four characters wide; a fused token
    with such a name would be cut after three characters — confirm this is
    acceptable for the input files.

    Parameters
    ----------
    ls : list of str
        Tokens of one line, as produced by ``str.split``.

    Returns
    -------
    list of str
        The corrected tokens; the very same list object when no check fired.
    """
    def _split_token(tokens, index, cut):
        # Replace tokens[index] by its two pieces; neighbours keep their order.
        fused = tokens[index]
        return tokens[:index] + [fused[:cut], fused[cut:]] + tokens[index + 1:]

    # Atom name fused with residue name.
    if len(ls[2]) > 4:
        ls = _split_token(ls, 2, 3)

    # One-character field fused with the field that follows it.
    if len(ls[4]) > 1:
        ls = _split_token(ls, 4, 1)

    # Fused coordinates: x with y (token 6), then y with z (token 7). Each
    # coordinate is assumed to carry three decimals, hence the "+ 4".
    for index in (6, 7):
        if len(ls[index]) > 8:
            ls = _split_token(ls, index, ls[index].find(".") + 4)

    # Occupancy fused with temperature factor; exactly one more token follows.
    if len(ls[9]) > 4:
        ls = ls[:9] + [ls[9][:4], ls[9][4:], ls[10]]

    if len(ls) != 12:
        print("list still has not length 12")
    return ls

def point_dict_to_npy(point_dict, input_size = 20):
    """Convert a sparse voxel dictionary into an array of in-grid records.

    Parameters
    ----------
    point_dict : dict
        Maps ``(x, y, z, channel)`` keys to a numeric value.
    input_size : int, optional
        Edge length of the cubic grid. Entries whose x, y or z lies outside
        ``0 .. input_size - 1`` are dropped. The default of 20 reproduces the
        previously hard-coded upper bound of 19.

    Returns
    -------
    numpy.ndarray
        One row ``[x, y, z, channel, value]`` per kept entry, in dictionary
        order, with the value rounded to 5 decimals. An empty array of shape
        ``(0,)`` is returned when no entry is kept.
    """
    upper = input_size - 1
    point_list = [
        [key[0], key[1], key[2], key[3], np.round(value, 5)]
        for key, value in point_dict.items()
        # Same comparisons as before, with the bound taken from input_size.
        if not any(key[axis] < 0 or key[axis] > upper for axis in range(3))
    ]
    return np.array(point_list)


def generate_point_dict(filename, ID, chain_id = None):
    """Build the sparse voxel description of the surroundings of one residue.

    The ATOM records of ``filename`` are read and shifted with ``shift_axes``.
    The coordinates are then translated so that the CA atom of the selected
    residue sits at coordinate 9 on every axis. Atoms of the selected residue
    are left out (masked). Every other atom that lies strictly within 15 of
    the CA on each axis, and whose element is a key of ``dic``, adds
    ``exp(-d / sigma_sqd)`` to the voxels around it, where ``d`` is the
    Euclidean distance between the atom and the voxel centre. Voxel ranges
    are clipped to ``0 .. input_size - 1``.

    NOTE(review): the weight uses the plain distance although the divisor is
    called ``sigma_sqd`` — confirm that a squared distance is not intended.

    Parameters
    ----------
    filename : str
        Path of the PDB file.
    ID : str
        Residue ID, compared against the string tokens read from the file.
    chain_id : str, optional
        Chain of the residue. When omitted, every residue carrying ``ID`` is
        masked regardless of chain and the first matching CA atom in file
        order is used as the centre (the previous behaviour).

    Returns
    -------
    dict
        Maps ``(x, y, z, channel)`` to the accumulated weight. Empty when no
        atom contributes.

    Raises
    ------
    IndexError
        If the selected residue has no CA atom in the file.
    """
    half_box = 15      # per-axis cut-off around the CA atom
    centre_voxel = 9   # coordinate the CA atom is moved to

    # Read the ATOM records; the context manager closes the file handle.
    atoms = []
    with open(filename) as handle:
        for line in handle:
            fields = line.split()
            if not fields or fields[0] != 'ATOM':
                continue  # blank lines and every other record type
            if len(fields) != 12:
                fields = correct_errors(fields)
            atoms.append(fields)

    pdb = pd.DataFrame(atoms, columns =["record type","atom ID","atom name","residue name","chain ID","residue ID",
                             "x","y","z","occupancy","temperature factor","atom"])
    for axis in ("x", "y", "z"):
        pdb[axis] = pd.to_numeric(pdb[axis])

    pdb = shift_axes(PDB_DataFrame =pdb)

    # Rows that belong to the residue to be masked.
    selected = pdb["residue ID"] == ID
    if chain_id is not None:
        selected = selected & (pdb["chain ID"] == chain_id)

    # Centre of the box: CA atom of the selected residue.
    ca_rows = pdb.loc[selected & (pdb["atom name"] == "CA")]
    if ca_rows.empty:
        raise IndexError(f"no CA atom found for residue ID {ID!r} in {filename}")
    cx, cy, cz = (ca_rows[axis].iloc[0] for axis in ("x", "y", "z"))

    # Mask the selected residue, then keep only atoms strictly inside the box.
    neighbours = pdb.loc[~selected]
    inside = pd.Series(True, index=neighbours.index)
    for axis, centre in (("x", cx), ("y", cy), ("z", cz)):
        inside = inside & (neighbours[axis] < centre + half_box) & (neighbours[axis] > centre - half_box)
    neighbours = neighbours.loc[inside]

    # Plain arrays are used so nothing is assigned into a slice of `pdb`.
    x_values = neighbours["x"].to_numpy() - cx + centre_voxel
    y_values = neighbours["y"].to_numpy() - cy + centre_voxel
    z_values = neighbours["z"].to_numpy() - cz + centre_voxel
    atom_types = list(neighbours["atom"])

    point_dict = {}
    for atom, ax, ay, az in zip(atom_types, x_values, y_values, z_values):
        if atom not in dic:
            continue  # elements without a channel are skipped
        channel = dic[atom]
        coord = np.array([ax, ay, az])
        for c1 in range(max(0, int(ax - 3)), min(input_size, int(ax + 4))):
            for c2 in range(max(0, int(ay - 3)), min(input_size, int(ay + 4))):
                for c3 in range(max(0, int(az - 3)), min(input_size, int(az + 4))):
                    voxel_centre = np.array([c1 + 0.5, c2 + 0.5, c3 + 0.5])
                    dst = distance.euclidean(coord, voxel_centre)
                    key = (c1, c2, c3, channel)
                    point_dict[key] = point_dict.get(key, 0.0) + np.exp(-dst/sigma_sqd)
    return point_dict



### Generating Dataset

In [None]:
import os
import warnings
from Bio.PDB.PDBExceptions import PDBConstructionWarning
from IPython.display import FileLink

warnings.filterwarnings("ignore", category=PDBConstructionWarning)


# specify the directory in which u have the pdb files
directory = '/home/sedrica/students/Vishal/50000PDBFiles'

# For every file in `directory`: read its ATOM records, build one masked voxel
# grid per residue and write all grids of that file to `<base_name>.json` in
# the current working directory.
for file in sorted(os.listdir(directory)):  # sorted: same file order on every run
    filename = os.path.join(directory, file)
    if not os.path.isfile(filename):
        continue  # sub-directories cannot be read as PDB files
    base_name = os.path.splitext(file)[0]

    print("Processing file:", file)
    ATOMS = []
    # Read the ATOM records; `with` closes the handle even if parsing fails.
    with open(filename) as pdb_handle:
        for line in pdb_handle:
            ls = line.split()
            if not ls or ls[0] != 'ATOM':
                continue  # blank lines and every other record type
            if len(ls) != 12:
                ls = correct_errors(ls)
            ATOMS.append(ls)

    if not ATOMS:
        print("No ATOM records found, skipping:", file)
        continue

    pdb = pd.DataFrame(ATOMS, columns =["record type","atom ID","atom name","residue name","chain ID","residue ID",
                            "x","y","z","occupancy","temperature factor","atom"])
    # Biopython parse of the same file; `structure` is not read again in this cell.
    structure = parser.get_structure(filename.split("/")[-1], filename)

    pdb["x"]= pd.to_numeric(pdb["x"])
    pdb["y"]= pd.to_numeric(pdb["y"])
    pdb["z"]= pd.to_numeric(pdb["z"])

    # Centre the coordinates (no rounding is applied).
    pdb= shift_axes(PDB_DataFrame =pdb)

    # Chains in order of first appearance. A set of strings iterates in an
    # order that can differ between runs, which made the voxel ids
    # irreproducible.
    chain_id = list(pdb["chain ID"].unique())
    print(chain_id)

    # Records collected for the file being processed.
    dataset = []

    # Running voxel counter, restarted for every file.
    v_id = 0

    for cid in chain_id:
        print(cid)
        in_chain = pdb["chain ID"] == cid
        # Residue IDs of this chain, again in order of first appearance.
        res_ID = list(pdb.loc[in_chain, "residue ID"].unique())

        for r_id in res_ID:
            unique_v_id = base_name + '_' + str(v_id)
            print(unique_v_id)

            # Label of the voxel: class of the masked residue, -1 if unknown.
            r_n = pdb.loc[in_chain & (pdb["residue ID"] == r_id), "residue name"].iloc[0]
            label = residue_class.get(r_n, -1)

            # NOTE(review): only the residue ID is passed here, and residue
            # numbers can repeat across chains — confirm that the intended
            # residue is selected for every chain.
            point_dict = generate_point_dict(filename, ID = r_id)
            point_list = point_dict_to_npy(point_dict)
            X = get_input_matrix(point_list = point_list, input_size = 20)
            dataset.append({'voxel_id': unique_v_id, 'X': X.tolist(), 'label': label})
            v_id +=1
    with open(f'{base_name}.json', 'w') as json_file:
        json.dump(dataset, json_file)




Processing file: pdb6v29.ent
['A', 'C', 'D', 'B']
A
pdb6v29_0
pdb6v29_1
pdb6v29_2
pdb6v29_3
pdb6v29_4
pdb6v29_5
pdb6v29_6
pdb6v29_7
pdb6v29_8
pdb6v29_9
pdb6v29_10
pdb6v29_11
pdb6v29_12
pdb6v29_13
pdb6v29_14
pdb6v29_15
pdb6v29_16


KeyboardInterrupt: 

Visualising the 3D structure

In [None]:
from plotly import graph_objs as go
#pip install pythreejs
#pip install plotly
#pip install nbformat

# `X` must already exist in the kernel; it is assigned inside the
# dataset-generation loop, so this cell shows the grid computed last.
protein = X.copy()

# Zero out weak densities (< 0.05) so that they are not plotted.
protein[protein < 0.05] = 0

# Channel to display. `dic` maps "N" to 1, so this is the nitrogen channel;
# set it to 0 to look at carbon instead.
channel = 1
density = protein[:, :, :, channel]

# Grid indices of the non-zero voxels along each axis.
x, y, z = density.nonzero()

# Marker colour follows the density value of each plotted voxel.
alpha = density[x, y, z]

marker_style = dict(size=4, color=alpha, opacity=0.8)
fig = go.Figure(data=[go.Scatter3d(x=x, y=y, z=z, mode='markers', marker=marker_style)])
# Show the plot
fig.show()