In [1]:
import pandas as pd
import numpy as np
import random
import os
import subprocess
import shutil
from typing import List, Set, Dict, Tuple

from tqdm import tqdm
from biopandas.pdb import PandasPdb
from torch_geometric.data import HeteroData

from scipy.spatial.transform import Rotation
import Bio.PDB

from Bio.Data.IUPACData import protein_letters_3to1

from collections import defaultdict

from sklearn.neighbors import kneighbors_graph as knn_graph


Pyarrow will become a required dependency of pandas in the next major release of pandas (pandas 3.0),
(to allow more performant data types, such as the Arrow string type, and better interoperability with other libraries)
but was not found to be installed on your system.
If this would cause problems for you,
please provide us feedback at https://github.com/pandas-dev/pandas/issues/54466
        
  import pandas as pd


In [2]:
import statsmodels.datasets.utils as du

### Чтение данных из PDB-файла 

In [3]:
# Role labels stored on each Protein (passed as `type_`) to tell the two
# partners of a pair apart.
LIG, REC = 'ligand', 'receptor'

In [4]:
class Atom:
    """One atom record: its name, Cartesian coordinates and an element label."""

    def __init__(self, name, coords):
        """Store the atom name and unpack ``coords`` into x/y/z attributes.

        Args:
            name: Atom name (e.g. ``'N'``, ``'CA'``, ``'C'``).
            coords: Iterable of exactly three coordinate values.
        """
        self.name = name
        x, y, z = coords
        self.x_coord = x
        self.y_coord = y
        self.z_coord = z
        # 'CA' and 'C' are labelled carbon; every other name is labelled nitrogen.
        if name in ('CA', 'C'):
            self.element = 'C'
        else:
            self.element = 'N'

    def get_coords(self):
        """Return the coordinates as a 1-D numpy array ``[x, y, z]``."""
        return np.array((self.x_coord, self.y_coord, self.z_coord))

In [5]:
class Chain:
    """Backbone atoms collected for one residue, tagged with a chain id.

    As used by ``Loader.parse_pdb`` in this file, one instance is created per
    residue (not per PDB chain) and only ``N``, ``CA`` and ``C`` atoms are added.
    """

    # Atom names required to build the local frame.
    _BACKBONE = ('N', 'CA', 'C')

    def __init__(self, id, res_name, res_seq):
        """
        Args:
            id: Identifier of the chain the residue belongs to.
            res_name: Residue name.
            res_seq: Residue sequence number.
        """
        self.id = id
        self.atoms = []
        self.res_name = res_name
        self.res_seq = res_seq

    def add_atom(self, name, coords):
        """Append a new ``Atom`` built from ``name`` and ``coords``."""
        self.atoms.append(Atom(name, coords))

    def __len__(self):
        """Number of atoms stored so far."""
        return len(self.atoms)

    def get_vectors_of_coords(self):
        """Compute a local orthonormal frame from the N, CA and C positions.

        With ``u = unit(N - CA)`` and ``t = unit(C - CA)`` the frame is
        ``n = unit(u x t)`` and ``v = n x u``.

        Returns:
            Tuple ``(n_i, u_i, v_i)`` of numpy arrays.

        Raises:
            ValueError: If any of the ``N``, ``CA`` or ``C`` atoms is missing.
                Previously missing atoms were silently replaced by the scalar
                ``0``, which produced a meaningless (or NaN) frame.
        """
        positions = {}
        for atom in self.atoms:
            # Match by explicit name; if a name repeats, the last one wins.
            if atom.name in self._BACKBONE:
                positions[atom.name] = atom.get_coords()

        missing = [name for name in self._BACKBONE if name not in positions]
        if missing:
            raise ValueError(
                'Residue {} {} (chain {}) is missing backbone atom(s): {}'.format(
                    self.res_name, self.res_seq, self.id, ', '.join(missing)))

        n_c, ca_c, c_c = positions['N'], positions['CA'], positions['C']

        u_i = (n_c - ca_c) / np.linalg.norm(n_c - ca_c)
        t_i = (c_c - ca_c) / np.linalg.norm(c_c - ca_c)
        # Compute the cross product once and reuse it for the normalisation.
        normal = np.cross(u_i, t_i)
        n_i = normal / np.linalg.norm(normal)
        v_i = np.cross(n_i, u_i)

        return n_i, u_i, v_i
        
    


In [6]:
class Protein:
    """A parsed structure: a list of per-residue ``Chain`` objects plus derived arrays."""

    def __init__(self, type_):
        """
        Args:
            type_: Role label of the structure (the notebook passes ``LIG`` or ``REC``).
        """
        self.chains = []
        self.type = type_
        # Filled in by the caller after the graph has been built.
        self.edge_index = None
        self.n_coords, self.c_coords, self.ca_coords = [], [], []

    def add_chain(self, chain: "Chain"):
        """Append one ``Chain`` to the protein."""
        self.chains.append(chain)

    def get_coords_arrays(self):
        """Collect atom coordinates into ``ca_coords``, ``c_coords`` and ``n_coords``.

        Atoms named ``'CA'`` and ``'C'`` go to their own lists; every other atom
        is stored in ``n_coords``. The lists are rebuilt from scratch on each
        call, so calling the method repeatedly no longer duplicates entries.
        Returns ``None``; results are read from the instance attributes.
        """
        ca_coords, c_coords, n_coords = [], [], []

        for chain in self.chains:
            for atom in chain.atoms:
                if atom.name == 'CA':
                    ca_coords.append(atom.get_coords())
                elif atom.name == 'C':
                    c_coords.append(atom.get_coords())
                else:
                    n_coords.append(atom.get_coords())

        self.ca_coords, self.c_coords, self.n_coords = ca_coords, c_coords, n_coords

    def __len__(self):
        """Number of entries in ``n_coords`` (0 until ``get_coords_arrays`` has run)."""
        return len(self.n_coords)

    def get_vectors_of_coords_arrays(self):
        """Stack the local frames of all chains.

        Returns:
            Tuple ``(n, u, v)`` of arrays produced by ``np.stack``, with one
            row per element of ``self.chains`` in insertion order.
        """
        n_i_list, u_i_list, v_i_list = [], [], []

        for chain in self.chains:
            n_i, u_i, v_i = chain.get_vectors_of_coords()
            n_i_list.append(n_i)
            u_i_list.append(u_i)
            v_i_list.append(v_i)

        return np.stack(n_i_list), np.stack(u_i_list), np.stack(v_i_list)


In [15]:
class Tasks:
    """Ordered collection of work items, each a dict with ``'pair'`` and ``'type'`` keys."""

    def __init__(self):
        self.tasks = list()

    def add_task(self, paths, type_:str):
        """Record one task: ``paths`` is stored under ``'pair'``, ``type_`` under ``'type'``."""
        entry = dict(pair=paths, type=type_)
        self.tasks.append(entry)

    def get_task(self, idx: int):
        """Return the task stored at position ``idx``.

        Raises:
            IndexError: If ``idx`` is not below the number of stored tasks.
        """
        if not idx < len(self.tasks):
            raise IndexError('Index out of range')
        return self.tasks[idx]

In [16]:
class Loader:
    """Reads a split table and turns the listed ligand/receptor PDB pairs into
    ``Protein`` objects with a k-nearest-neighbour graph attached."""

    def __init__(self, data_file, knn_size=3, split = ',', verbose=False):
        """
        Args:
            data_file: Path to a delimited table with a header row;
                ``read_data`` uses its ``'path'`` and ``'split'`` columns.
            knn_size: Number of neighbours passed to ``kneighbors_graph``.
            split: Column separator handed to ``pandas.read_csv``
                (unrelated to the train/test/val split column).
            verbose: Stored on the instance; not consulted by the methods below.
        """
        self.parser = Bio.PDB.PDBParser()
        self.knn_size = knn_size
        self.data_to_load = pd.read_csv(data_file, sep=split, header=0)
        print(self.data_to_load)
        self.verbose = verbose
        # Start with an empty Tasks container (not a plain list) so that calling
        # process_data before read_data fails with an IndexError instead of an
        # AttributeError on list.get_task.
        self.tasks = Tasks()

    def parse_pdb(self, file_path, type_) -> Protein:
        """Parse one PDB file into a ``Protein`` holding backbone atoms only.

        One ``Chain`` object is created per residue. Only atoms named ``C``,
        ``CA`` or ``N`` are kept, and residues that contribute none of them
        are dropped. Every model in the structure is traversed.

        Args:
            file_path: Path to the PDB file (also used as the structure id).
            type_: Role label stored on the returned ``Protein``.
        """
        structure = self.parser.get_structure(file_path, file_path)
        protein = Protein(type_)

        for model in structure:
            for one_chain in model:
                # Ask the entity for its id rather than scraping it from repr().
                chain_id = one_chain.get_id()
                for part in one_chain.get_residues():
                    # Residue ids are (hetero flag, sequence number, insertion code).
                    # The old split-on-space parsing of str(residue) picked the
                    # wrong field (yielding '') whenever the hetero flag was set.
                    res_seq = str(part.get_id()[1])

                    chain = Chain(
                        id=chain_id,
                        res_name=part.get_resname(),
                        res_seq=res_seq)

                    for atom in part:
                        if atom.get_name() in ['C', 'CA', 'N']:
                            chain.add_atom(atom.get_name(), atom.get_coord())

                    if len(chain) != 0:
                        protein.add_chain(chain)

        return protein

    def read_data(self) -> Tasks:
        """Build the task list from the loaded table.

        For every row, the ligand path is ``./structures/<path>_l_b.pdb`` and the
        receptor path is ``./structures/<path>_r_b.pdb``; the row's ``'split'``
        value becomes the task type. The result is stored on ``self.tasks``
        and also returned.
        """
        tasks = Tasks()
        for _, line in self.data_to_load.iterrows():
            path_to_rec = './structures/' + line['path'] + '_r_b.pdb'
            path_to_lig = './structures/' + line['path']  + '_l_b.pdb'
            tasks.add_task((path_to_lig, path_to_rec), line['split'])
        self.tasks = tasks
        return tasks

    def _knn_edges(self, protein):
        """Return the kNN graph over the protein's stacked C, CA and N coordinates.

        Rows are stacked block-wise in the order C, CA, N, so graph nodes are
        atoms, not residues.
        """
        stacked = np.vstack((protein.c_coords, protein.ca_coords, protein.n_coords))
        return knn_graph(stacked, self.knn_size)

    def process_data(self, indexes) -> Tuple[List, List, List]:
        """Parse the tasks at ``indexes`` and bucket them by split type.

        Args:
            indexes: Iterable of positions into the task list built by ``read_data``.

        Returns:
            ``(test_pairs, train_pairs, val_pairs)`` -- note the order. Each
            element is a list of ``(ligand, receptor)`` tuples. Any type other
            than ``'test'`` or ``'train'`` is placed in ``val_pairs``.

        Raises:
            IndexError: If an index is out of range (including when
                ``read_data`` has not been called yet).
        """
        train_pairs = []
        test_pairs = []
        val_pairs = []

        for idx in indexes:
            task = self.tasks.get_task(idx)
            # Read the fields by key instead of relying on dict value order.
            paths = task['pair']
            type_ = task['type']

            ligand = self.parse_pdb(paths[0], LIG)
            receptor = self.parse_pdb(paths[1], REC)

            ligand.get_coords_arrays()
            receptor.get_coords_arrays()

            ligand.edge_index = self._knn_edges(ligand)
            receptor.edge_index = self._knn_edges(receptor)

            if type_ == 'test':
                test_pairs.append((ligand, receptor))
            elif type_ == 'train':
                train_pairs.append((ligand, receptor))
            else:
                val_pairs.append((ligand, receptor))

        return test_pairs, train_pairs, val_pairs

    

In [22]:
loader = Loader('./splits_test.csv')
loader.read_data()

# Process task 0 exactly once (the earlier duplicate call parsed both PDB files
# twice and threw the first result away). Buckets come back in the order
# (test, train, val).
test_pairs, train_pairs, val_pairs = loader.process_data([0])
one_task = test_pairs[0]
# Each pair is (ligand, receptor); show the receptor's kNN graph.
print(one_task[1].edge_index)


   path split
0  1A2K  test


  (0, 497)	1.0
  (0, 248)	1.0
  (0, 249)	1.0
  (1, 498)	1.0
  (1, 249)	1.0
  (1, 250)	1.0
  (2, 499)	1.0
  (2, 250)	1.0
  (2, 251)	1.0
  (3, 500)	1.0
  (3, 251)	1.0
  (3, 252)	1.0
  (4, 501)	1.0
  (4, 252)	1.0
  (4, 253)	1.0
  (5, 502)	1.0
  (5, 253)	1.0
  (5, 254)	1.0
  (6, 503)	1.0
  (6, 254)	1.0
  (6, 502)	1.0
  (7, 504)	1.0
  (7, 255)	1.0
  (7, 503)	1.0
  (8, 505)	1.0
  :	:
  (735, 486)	1.0
  (736, 239)	1.0
  (736, 488)	1.0
  (736, 487)	1.0
  (737, 240)	1.0
  (737, 489)	1.0
  (737, 241)	1.0
  (738, 241)	1.0
  (738, 490)	1.0
  (738, 489)	1.0
  (739, 242)	1.0
  (739, 491)	1.0
  (739, 243)	1.0
  (740, 243)	1.0
  (740, 492)	1.0
  (740, 491)	1.0
  (741, 244)	1.0
  (741, 493)	1.0
  (741, 492)	1.0
  (742, 245)	1.0
  (742, 494)	1.0
  (742, 246)	1.0
  (743, 246)	1.0
  (743, 495)	1.0
  (743, 494)	1.0
