In [1]:
!pip install rdkit
!pip install torch_geometric
!pip gdown

Collecting rdkit
  Downloading rdkit-2023.9.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (34.9 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m34.9/34.9 MB[0m [31m19.9 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: rdkit
Successfully installed rdkit-2023.9.6
Collecting torch_geometric
  Downloading torch_geometric-2.5.3-py3-none-any.whl (1.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m6.0 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: torch_geometric
Successfully installed torch_geometric-2.5.3
ERROR: unknown command "gdown" - maybe you meant "download"


In [2]:
import numpy as np
import pandas as pd
import pickle
from rdkit import Chem
from rdkit.Chem.rdmolops import Get3DDistanceMatrix
from rdkit.Chem.Crippen import MolLogP
from rdkit.Chem import GetAdjacencyMatrix

import torch
from torch.nn import Linear, MSELoss
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import MessagePassing, GCNConv, aggr, pool
from torch_geometric.loader import DataLoader as PGDataLoader

device = torch.device("cuda") if torch.cuda.is_available() else torch.device('cpu')

One Hot Encoding.  Рассматриваем молекулу, кодируем инфо об узле в молекуле. Если узел не в запрещенном листе(который мы определяем сами), то значение(кодирование) присваиваем узлу из разрешенного листа. Если не хотим рассматривать, то permitted_list[-1].

In [3]:
def one_hot_encoding(x, permitted_list):
  if x not in permitted_list:
    x = permitted_list[-1]

  binary_encoding = [int(boolean_value) for boolean_value in list(map(lambda s: x == s, permitted_list))]

  return binary_encoding

# Получаем начальные свойства узлов
def get_atom_features(atom,
                      use_chirality = True,
                      hydrogens_implicit = True):

  permitted_list_of_atoms =  ['C','N','O','S','F','Si','P','Cl','Br','Mg','Na','Ca','Fe','As','Al','I', 'B','V','K','Tl','Yb','Sb','Sn','Ag','Pd','Co','Se','Ti','Zn', 'Li','Ge','Cu','Au','Ni','Cd','In','Mn','Zr','Cr','Pt','Hg','Pb','Unknown']

  if hydrogens_implicit == False:
      permitted_list_of_atoms = ['H'] + permitted_list_of_atoms

#тип атома
  atom_type_enc = one_hot_encoding(str(atom.GetSymbol()), permitted_list_of_atoms)
#кол-во соседей
  n_heavy_neighbors_enc = one_hot_encoding(int(atom.GetDegree()), [0, 1, 2, 3, 4, "MoreThanFour"])
#кодируем заряд
  formal_charge_enc = one_hot_encoding(int(atom.GetFormalCharge()), [-3, -2, -1, 0, 1, 2, 3, "Extreme"])
#свойство узла(степень гибридизации)
  hybridisation_type_enc = one_hot_encoding(str(atom.GetHybridization()), ["S", "SP", "SP2", "SP3", "SP3D", "SP3D2", "OTHER"])
#кольцевой ли
  is_in_a_ring_enc = [int(atom.IsInRing())]
#ароматический ли
  is_aromatic_enc = [int(atom.GetIsAromatic())]
#атомная масса
  atomic_mass_scaled = [float((atom.GetMass() - 10.812)/116.092)]
#Радиус Ван Дер Ваальса
  vdw_radius_scaled = [float((Chem.GetPeriodicTable().GetRvdw(atom.GetAtomicNum()) - 1.5)/0.6)]
#Ковалентный радиус
  covalent_radius_scaled = [float((Chem.GetPeriodicTable().GetRcovalent(atom.GetAtomicNum()) - 0.64)/0.76)]
#Полный вектор (сумма всех свойств)
  atom_feature_vector = atom_type_enc + n_heavy_neighbors_enc + formal_charge_enc + hybridisation_type_enc + is_in_a_ring_enc + is_aromatic_enc + atomic_mass_scaled + vdw_radius_scaled + covalent_radius_scaled

  if use_chirality == True:
      chirality_type_enc = one_hot_encoding(str(atom.GetChiralTag()), ["CHI_UNSPECIFIED", "CHI_TETRAHEDRAL_CW", "CHI_TETRAHEDRAL_CCW", "CHI_OTHER"])
      atom_feature_vector += chirality_type_enc

  if hydrogens_implicit == True:
      n_hydrogens_enc = one_hot_encoding(int(atom.GetTotalNumHs()), [0, 1, 2, 3, 4, "MoreThanFour"])
      atom_feature_vector += n_hydrogens_enc

  return np.array(atom_feature_vector)

#Кодируем свойства связей узлов
def get_bond_features(bond,
                      use_stereochemistry = True):
# Тип связи (1, 2, 3...)
  permitted_list_of_bond_types = [Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE, Chem.rdchem.BondType.TRIPLE, Chem.rdchem.BondType.AROMATIC]

  bond_type_enc = one_hot_encoding(bond.GetBondType(), permitted_list_of_bond_types)

  bond_is_conj_enc = [int(bond.GetIsConjugated())]

  bond_is_in_ring_enc = [int(bond.IsInRing())]

  bond_feature_vector = bond_type_enc + bond_is_conj_enc + bond_is_in_ring_enc

  if use_stereochemistry == True:
      stereo_type_enc = one_hot_encoding(str(bond.GetStereo()), ["STEREOZ", "STEREOE", "STEREOANY", "STEREONONE"])
      bond_feature_vector += stereo_type_enc

  return np.array(bond_feature_vector)


In [4]:
# Передаем смайзлы, у - logp. Делает граф.
def create_pytorch_geometric_graph_data_list_from_smiles_and_labels(x_smiles, y):

  data_list = []

  for (smiles, y_val) in tqdm(zip(x_smiles, y), total=len(x_smiles)):

    mol = Chem.MolFromSmiles(smiles)
    n_nodes = mol.GetNumAtoms() #считаем связей
    n_edges = 2*mol.GetNumBonds()
    unrelated_smiles = "O=O" #Грязный ход, чтобы посчитать количество фичей для узла и для связей.
    unrelated_mol = Chem.MolFromSmiles(unrelated_smiles)
    n_node_features = len(get_atom_features(unrelated_mol.GetAtomWithIdx(0)))
    n_edge_features = len(get_bond_features(unrelated_mol.GetBondBetweenAtoms(0,1)))

    #Матрица количества узлов и количества признаков на один атом
    X = np.zeros((n_nodes, n_node_features))
    # Заполняем матрицу
    for atom in mol.GetAtoms():
        X[atom.GetIdx(), :] = get_atom_features(atom)

    X = torch.tensor(X, dtype = torch.float)

 #Матрица для связей
    (rows, cols) = np.nonzero(GetAdjacencyMatrix(mol))
    torch_rows = torch.from_numpy(rows.astype(np.int64)).to(torch.long)
    torch_cols = torch.from_numpy(cols.astype(np.int64)).to(torch.long)
    E = torch.stack([torch_rows, torch_cols], dim = 0)

    # Количество связей и узлов
    EF = np.zeros((n_edges, n_edge_features))
# Проходим по всем колонкам и заполняем матрциу построчно, ставя соответствие цифры атома i-го и j-го
    for (k, (i,j)) in enumerate(zip(rows, cols)):

        EF[k] = get_bond_features(mol.GetBondBetweenAtoms(int(i),int(j)))

    EF = torch.tensor(EF, dtype = torch.float)

    # Вектор свойств каждой молекулы
    y_tensor = torch.tensor(np.array([y_val]), dtype = torch.float)

    ''' Метод дата, где мы делаем матрицу узлов, индекс матрицы, где какой узел с каким узлом связан,
       и передаем матрицу свойств каждой связи, какой величине соответствует'''
    data_list.append(Data(x = X, edge_index = E, edge_attr = EF, y = y_tensor))

  return data_list

In [5]:
from tqdm.notebook import tqdm

df = pd.read_csv('dataset_v1.csv', nrows=13000)
logp = [MolLogP(Chem.MolFromSmiles(x)) for x in df.SMILES]
data_list = create_pytorch_geometric_graph_data_list_from_smiles_and_labels(df['SMILES'], logp)

  0%|          | 0/13000 [00:00<?, ?it/s]

In [6]:
smiles_mol = df.SMILES[0]
#smiles_mol = "O=O"
print(smiles_mol)
print(MolLogP(Chem.MolFromSmiles(smiles_mol)))
mol = Chem.MolFromSmiles(smiles_mol)
n_nodes = mol.GetNumAtoms()
n_edges = 2*mol.GetNumBonds()
print(n_nodes)
print(n_edges)

CCCS(=O)c1ccc2[nH]c(=NC(=O)OC)[nH]c2c1
1.6806999999999999
19
40


Суммарный вектор - то с чем работает граф нейросеть, он и будет обновляться. Он должен быть информативен, потому что в графе ее нет. Функции позволяют отследить, как составляется этот суммарный вектор

In [7]:
unrelated_smiles = "O=O"
unrelated_mol = Chem.MolFromSmiles(unrelated_smiles)
n_node_features = len(get_atom_features(unrelated_mol.GetAtomWithIdx(0)))
n_edge_features = len(get_bond_features(unrelated_mol.GetBondBetweenAtoms(0,1)))
print(n_node_features)
print(n_edge_features)

79
10


In [8]:
atom = unrelated_mol.GetAtomWithIdx(0)
permitted_list_of_atoms =  ['C','N','O','S','F','Si','P','Cl','Br','Mg','Na','Ca','Fe','As',
                            'Al','I', 'B','V','K','Tl','Yb','Sb','Sn','Ag','Pd','Co','Se',
                            'Ti','Zn', 'Li','Ge','Cu','Au','Ni','Cd','In','Mn','Zr','Cr',
                            'Pt','Hg','Pb','Unknown']
atom_type_enc = one_hot_encoding(str(atom.GetSymbol()), permitted_list_of_atoms)
print(atom_type_enc)

[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]


In [9]:
n_heavy_neighbors_enc = one_hot_encoding(int(atom.GetDegree()), [0, 1, 2, 3, 4, "MoreThanFour"])
print(n_heavy_neighbors_enc)

[0, 1, 0, 0, 0, 0]


In [10]:
formal_charge_enc = one_hot_encoding(int(atom.GetFormalCharge()), [-3, -2, -1, 0, 1, 2, 3, "Extreme"])
print(formal_charge_enc)

[0, 0, 0, 1, 0, 0, 0, 0]


In [11]:
hybridisation_type_enc = one_hot_encoding(str(atom.GetHybridization()), ["S", "SP", "SP2", "SP3", "SP3D", "SP3D2", "OTHER"])
print(hybridisation_type_enc)

[0, 0, 1, 0, 0, 0, 0]


In [12]:
is_in_a_ring_enc = [int(atom.IsInRing())]
print(is_in_a_ring_enc)

[0]


In [13]:
is_aromatic_enc = [int(atom.GetIsAromatic())]
print(is_aromatic_enc)

[0]


In [14]:
atomic_mass_scaled = [float((atom.GetMass() - 10.812)/116.092)]
print(atomic_mass_scaled)

[0.04468008131481929]


In [15]:
vdw_radius_scaled = [float((Chem.GetPeriodicTable().GetRvdw(atom.GetAtomicNum()) - 1.5)/0.6)]
print(vdw_radius_scaled)

[0.08333333333333341]


In [16]:
covalent_radius_scaled = [float((Chem.GetPeriodicTable().GetRcovalent(atom.GetAtomicNum()) - 0.64)/0.76)]
print(covalent_radius_scaled)

[0.026315789473684233]


In [17]:
atom_feature_vector = atom_type_enc + n_heavy_neighbors_enc + formal_charge_enc + hybridisation_type_enc + is_in_a_ring_enc + is_aromatic_enc + atomic_mass_scaled + vdw_radius_scaled + covalent_radius_scaled
print(atom_feature_vector)


[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0.04468008131481929, 0.08333333333333341, 0.026315789473684233]


In [18]:
X = np.zeros((n_nodes, n_node_features))
for atom in unrelated_mol.GetAtoms():
  X[atom.GetIdx(), :] = get_atom_features(atom)
X = torch.tensor(X, dtype = torch.float)
print(X)

tensor([[0., 0., 1.,  ..., 0., 0., 0.],
        [0., 0., 1.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])


In [19]:
train = []
test = []
test_scaffolds = []

In [20]:
for sid, split in enumerate(df.SPLIT):
  if split == 'train':
      train.append(data_list[sid])
  elif split == 'test':
      test.append(data_list[sid])
  else:
      test_scaffolds.append(data_list[sid])

In [21]:
class GCN(torch.nn.Module):
  def __init__(self):
      super(GCN, self).__init__()
      self.conv1 = GCNConv(79, 32) #конечный вектор (свертка для атомов)
      self.conv2 = GCNConv(32, 32) #Свертка для связей
      self.lin1 = Linear(32, 16) #Линейное преобразование. Первый слой
      self.lin2 = Linear(16, 1) #Второй слой
      self.mean_pooling = pool.global_mean_pool #Усреднение n векторов узлов для молекулы

  def forward(self, data):
      x, edge_index= data.x, data.edge_index
      x = self.conv1(x, edge_index)
      x = F.relu(x) #Функция активации
      x = F.dropout(x, p=0.25, training=self.training) #Уменьшение с 79 до 32. Отключаем 25% узлов нейросети

      x = self.conv2(x, edge_index)
      x = F.relu(x)

      x = self.mean_pooling(x, data.batch) #Берем граф полностью
      x = self.lin1(x)
      x = F.relu(x)

      x = self.lin2(x)
      return x

In [22]:
gnn_model = GCN().to(device)

In [23]:
batch_size = 10

train_dataloader = PGDataLoader(dataset = train, batch_size = batch_size)
test_dataloader = PGDataLoader(dataset = test, batch_size = batch_size)
test_scaffolds_dataloader = PGDataLoader(dataset = test_scaffolds, batch_size = batch_size)

loss_function = MSELoss()

# Градиентный спуск
optimiser = torch.optim.Adam(gnn_model.parameters(), lr = 1e-3)

In [24]:

for epoch in range(10):
#Обучаем модель благодаря отключению 25% узлов в функции forward
  gnn_model.train()
  losses = []
  for (k, batch) in tqdm(enumerate(train_dataloader), total = len(train) // 2**7):
#Предказываем logP
    output = gnn_model(batch.to(device))
    loss_function_value = loss_function(output[:,0], batch.y)
    losses.append(loss_function_value.detach().cpu())
#Обнуляем градиенты
    optimiser.zero_grad()
#Обратное распространение
    loss_function_value.backward()


    optimiser.step()
#Лосс по тесту.
  losses_test = []
  losses_test_scaffolds = []

  with torch.no_grad():

    for (k, batch) in tqdm(enumerate(test_dataloader), total = len(test) // 2**7):

        output = gnn_model(batch.to(device))
        loss_function_value = loss_function(output[:,0], batch.y)
        losses_test.append(loss_function_value.detach().cpu())


    for (k, batch) in tqdm(enumerate(test_scaffolds_dataloader), total = len(test_scaffolds) // 2**7):
        output = gnn_model(batch.to(device))
        loss_function_value = loss_function(output[:,0], batch.y)
        losses_test_scaffolds.append(loss_function_value.detach().cpu())

  test_line = 'train loss: %1.4f\ntest loss: %1.4f\nscaf loss: %1.4f'
  print(test_line % (np.mean(losses), np.mean(losses_test), np.mean(losses_test_scaffolds)))

  0%|          | 0/83 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

train loss: 1.0218
test loss: 0.5726
scaf loss: 0.6865


  0%|          | 0/83 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

train loss: 0.4221
test loss: 0.2885
scaf loss: 0.3407


  0%|          | 0/83 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

train loss: 0.2453
test loss: 0.2274
scaf loss: 0.2541


  0%|          | 0/83 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

train loss: 0.2115
test loss: 0.2009
scaf loss: 0.2017


  0%|          | 0/83 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

train loss: 0.1983
test loss: 0.1874
scaf loss: 0.1867


  0%|          | 0/83 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

train loss: 0.1885
test loss: 0.1900
scaf loss: 0.1785


  0%|          | 0/83 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

train loss: 0.1776
test loss: 0.1677
scaf loss: 0.1594


  0%|          | 0/83 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

train loss: 0.1757
test loss: 0.1717
scaf loss: 0.1809


  0%|          | 0/83 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

train loss: 0.1653
test loss: 0.1588
scaf loss: 0.1587


  0%|          | 0/83 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

train loss: 0.1541
test loss: 0.1419
scaf loss: 0.1408


In [25]:
from sklearn.metrics import mean_squared_error, r2_score
from scipy.stats import pearsonr

def productive(gnn_model, dataloader, loss_function, device):
    gnn_model.eval()
    total_loss = 0
    pred = []
    targ= []

    with torch.no_grad():
        for batch in dataloader:
            batch = batch.to(device)
            output = gnn_model(batch)
            loss = loss_function(output[:, 0], batch.y)
            total_loss += loss.item() * batch.num_graphs
            pred.extend(output.cpu().numpy())
            targ.extend(batch.y.cpu().numpy())

    avg_loss = total_loss / len(dataloader.dataset)
    return avg_loss, np.array(pred), np.array(targ)

In [26]:
test_loss, test_predictions, test_targets = productive(gnn_model, test_dataloader, loss_function, device)
test_MSE = mean_squared_error(test_targets, test_predictions)
test_R2 = r2_score(test_targets, test_predictions)
# test_pearson = pearsonr(test_targets, test_predictions)[0]
test_pearson = pearsonr(test_targets.flatten(), test_predictions.flatten())[0]


In [27]:
print(f"Test Loss: {test_loss:.4f}")
print(f"Test MSE: {test_MSE:.4f}")
print(f"Test R2: {test_R2:.4f}")
print(f"Test Pearson Correlation: {test_pearson:.4f}")

Test Loss: 0.1213
Test MSE: 0.1213
Test R2: 0.8552
Test Pearson Correlation: 0.9305


In [28]:
np.savetxt('y_true_GCN.csv', test_targets, delimiter=',')
np.savetxt('y_pred_GCN.csv', test_predictions, delimiter=',')