In [15]:
import numpy as np
import pandas as pd
import tensorboard
from rdkit import Chem

import torch
import torch.nn as nn
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, Timer

from lion_pytorch import Lion

if torch.cuda.is_available():
    print("cuda", torch.cuda.is_available())
    print(torch.cuda.get_device_name(0))
    torch.cuda.empty_cache()
else:
    print("CUDA is not available.")


import warnings
warnings.filterwarnings("ignore", category=UserWarning, module="pytorch_lightning.trainer.connectors.data_connector")
warnings.filterwarnings("ignore", category=UserWarning, module="lightning_fabric.plugins.environments.slurm")


torch.cuda.empty_cache()

from utils.train import MoleculeModel, MoleculeDataModule, GATv2Model, evaluate_model
from utils.prepare import MoleculeData, MoleculeDataset, FeaturizationParameters

# TODO:
# - try GraphSAGE layers
# - try a Chebyshev graph-convolution layer
# - use 10k molecules for training and the rest as the test set

cuda True
NVIDIA GeForce RTX 3080


In [16]:
# Pre-featurized molecule graphs saved with torch.save.
# NOTE(review): torch.load unpickles arbitrary objects; only load trusted files.
# Newer PyTorch releases default to weights_only=True, which may reject a pickled
# dataset object; if loading fails there, pass weights_only=False explicitly.
DATASET_PATH = "../data/QM_137k.pt"
molecule_dataset = torch.load(DATASET_PATH)

In [17]:
# Inspect one graph: node features x, edge_index, edge features, and y (one entry per node).
sample_graph = molecule_dataset[0]
sample_graph

Data(x=[31, 133], edge_index=[2, 64], edge_attr=[64, 14], y=[31], smiles='CNC(=S)N/N=C/c1c(O)ccc2ccccc12')

In [18]:
# Data-loading settings; batch_size is also handed to MoleculeModel further down.
batch_size, num_workers = 128, 8

data_module = MoleculeDataModule(
    molecule_dataset,
    batch_size=batch_size,
    num_workers=num_workers,
)

In [19]:
from torch_geometric.nn import GATv2Conv
from torch_scatter import scatter_mean

def _dense_block(prefix, index, in_features, out_features, batch_norm, activation_fn, dropout):
    """Build one ``Linear -> [BatchNorm1d] -> activation -> Dropout`` block.

    Sub-modules are named ``{prefix}_linear_{index}``, ``{prefix}_bn_{index}``,
    ``{prefix}_activation_{index}`` and ``{prefix}_dropout_{index}``. These names
    appear in the model's ``state_dict`` keys, so they must stay stable.
    """
    block = nn.Sequential()
    block.add_module(f'{prefix}_linear_{index}', nn.Linear(in_features, out_features))
    if batch_norm:
        block.add_module(f'{prefix}_bn_{index}', nn.BatchNorm1d(out_features))
    block.add_module(f'{prefix}_activation_{index}', activation_fn())
    block.add_module(f'{prefix}_dropout_{index}', nn.Dropout(dropout))
    return block


class GATv2Model(nn.Module):
    """Per-node regressor: atom/edge MLPs -> stacked GATv2 convolutions -> MLP head.

    The per-layer option lists (``dropout_rates``, ``activation_fns``, ``use_batch_norm``)
    are indexed globally across the network:

    * indices ``[0, num_preprocess_layers)`` configure the preprocessing layers (shared by
      the atom and the edge branch);
    * the next ``len(num_heads)`` indices configure the GATv2 layers (only
      ``dropout_rates`` is read there, passed as GATv2Conv's ``dropout``);
    * the last ``num_postprocess_layers`` indices configure the postprocessing layers.

    NOTE(review): this cell redefines the ``GATv2Model`` imported from ``utils.train`` at
    the top of the notebook; from here on the local definition is the one used.

    Args:
        atom_in_features: width of the raw node feature matrix ``x``.
        edge_in_features: width of the raw edge feature matrix ``edge_attr``.
        num_preprocess_layers: number of dense blocks in each preprocessing branch.
        preprocess_hidden_features: output width of each preprocessing block; its last
            entry is also the per-head output width of every GATv2 layer.
        num_heads: attention heads of each GATv2 layer (one entry per layer).
        dropout_rates: per-layer dropout probabilities (see global indexing above).
        activation_fns: per-layer activation classes, instantiated with no arguments.
        use_batch_norm: per-layer flags for adding ``BatchNorm1d``.
        num_postprocess_layers: number of dense blocks after the GATv2 stack.
        postprocess_hidden_features: output width of each postprocessing block.
        out_features: width of the final linear layer.

    Raises:
        ValueError: if a width or per-layer option list is too short for the requested depth.
    """

    def __init__(self, atom_in_features, edge_in_features, num_preprocess_layers, preprocess_hidden_features, num_heads, dropout_rates, activation_fns, use_batch_norm, num_postprocess_layers, postprocess_hidden_features, out_features):
        super().__init__()

        n_pre = num_preprocess_layers
        n_gat = len(num_heads)
        n_post = num_postprocess_layers
        post_offset = n_pre + n_gat  # global index of the first postprocessing layer

        # Check lengths up front so a config mismatch produces a readable error instead of
        # an IndexError halfway through construction. dropout_rates is read for every
        # layer; activation_fns / use_batch_norm are read only by the pre/postprocessing
        # layers, but their GATv2 slots must still exist when postprocessing is used.
        needed_all = post_offset + n_post
        needed_mlp = needed_all if n_post else n_pre
        required_lengths = (
            ('preprocess_hidden_features', preprocess_hidden_features, max(n_pre, 1 if n_gat else 0)),
            ('postprocess_hidden_features', postprocess_hidden_features, n_post),
            ('dropout_rates', dropout_rates, needed_all),
            ('activation_fns', activation_fns, needed_mlp),
            ('use_batch_norm', use_batch_norm, needed_mlp),
        )
        for name, values, needed in required_lengths:
            if len(values) < needed:
                raise ValueError(
                    f'{name} has {len(values)} entries, but {needed} are needed for '
                    f'{n_pre} preprocessing + {n_gat} GATv2 + {n_post} postprocessing layers'
                )

        # Preprocessing MLP for atom features.
        self.atom_preprocess = nn.ModuleList()
        node_width = atom_in_features
        for i in range(n_pre):
            self.atom_preprocess.append(_dense_block(
                'atom', i, node_width, preprocess_hidden_features[i],
                use_batch_norm[i], activation_fns[i], dropout_rates[i],
            ))
            node_width = preprocess_hidden_features[i]

        # Preprocessing MLP for edge features (same per-layer options as the atom branch).
        self.edge_preprocess = nn.ModuleList()
        edge_width = edge_in_features
        for i in range(n_pre):
            self.edge_preprocess.append(_dense_block(
                'edge', i, edge_width, preprocess_hidden_features[i],
                use_batch_norm[i], activation_fns[i], dropout_rates[i],
            ))
            edge_width = preprocess_hidden_features[i]

        # forward() concatenates node features with the averaged incoming edge features,
        # so the first GATv2 layer sees node_width + edge_width channels.
        width = node_width + edge_width

        # GATv2 convolutional layers.
        # NOTE(review): layers are registered by name on a ModuleList, so integer indexing
        # (self.gat_convolutions[0]) looks up key '0' and fails; iterate with .children().
        # The names are kept because they form the state_dict keys.
        self.gat_convolutions = nn.ModuleList()
        for i, heads in enumerate(num_heads):
            gat_out = preprocess_hidden_features[-1]
            gat_layer = GATv2Conv(
                in_channels=width,
                out_channels=gat_out,
                heads=heads,
                dropout=dropout_rates[n_pre + i],
                concat=True,
            )
            self.gat_convolutions.add_module(f'gat_conv_{i}', gat_layer)
            width = gat_out * heads  # concat=True: head outputs are concatenated

        # Postprocessing MLP.
        self.postprocess = nn.ModuleList()
        for i in range(n_post):
            j = post_offset + i
            self.postprocess.append(_dense_block(
                'post', i, width, postprocess_hidden_features[i],
                use_batch_norm[j], activation_fns[j], dropout_rates[j],
            ))
            width = postprocess_hidden_features[i]

        # Sized from the actual width reaching it, so it also works with no postprocessing layers.
        self.output_layer = nn.Linear(width, out_features)

    def forward(self, x, edge_index, edge_attr):
        """Run the network and return one prediction row per node.

        Args:
            x: node feature matrix, one row per node.
            edge_index: ``[2, num_edges]`` tensor of edge endpoints.
            edge_attr: edge feature matrix, one row per edge.

        Returns:
            Output of ``output_layer`` with a trailing size-1 dimension squeezed away
            (shape ``[num_nodes]`` when ``out_features == 1``).
        """
        for layer in self.atom_preprocess:
            x = layer(x)

        for layer in self.edge_preprocess:
            edge_attr = layer(edge_attr)

        # Average edge features onto nodes grouped by edge_index[1] (the target/"col"
        # row in PyG's default convention), then append them to each node's features.
        col = edge_index[1]
        aggregated_edge_features = scatter_mean(edge_attr, col, dim=0, dim_size=x.size(0))
        x = torch.cat([x, aggregated_edge_features], dim=1)

        # NOTE(review): no activation is applied between consecutive GATv2 layers;
        # confirm that this is intended.
        for conv in self.gat_convolutions.children():
            x = conv(x, edge_index)

        for layer in self.postprocess:
            x = layer(x)

        return self.output_layer(x).squeeze(-1)

In [20]:
# Input widths are read from the first graph of the dataset.
in_features = molecule_dataset[0].x.shape[1]
edge_attr_dim = molecule_dataset[0].edge_attr.shape[1]

# Architecture. Previously tried: 64-wide layers with num_heads = [16, 20].
# TODO: try decreasing layer widths.
hidden_features = [128] * 9               # preprocessing width of each layer (atom and edge branches)
postprocess_hidden_features = [128, 128]  # postprocessing layer widths
num_heads = [16, 20, 16]                  # attention heads of each GATv2 layer

# dropout_rates / activation_fns / use_batch_norm are indexed across the whole network
# (preprocessing layers, then GATv2 layers, then postprocessing layers). Derive their
# length from the architecture so they cannot fall out of sync with num_heads: the
# hand-written 13-element lists were one short for 9 + 3 + 2 layers (IndexError).
num_layers = len(hidden_features) + len(num_heads) + len(postprocess_hidden_features)
dropout_rates = [0.0] * num_layers
activation_fns = [nn.PReLU] * num_layers
use_batch_norm = [False] * num_layers  # an all-True variant was also tried

optimizer_class = Lion

learning_rate = 2.2e-5
weight_decay = 3e-5

# Scheduler parameters forwarded to MoleculeModel (presumably StepLR-style; confirm in utils.train).
step_size = 80
gamma = 0.2

max_epochs = 100
patience = 3  # EarlyStopping: validation checks without val_loss improvement before stopping

torch.set_float32_matmul_precision('high')

base_model = GATv2Model(
    atom_in_features=in_features,
    edge_in_features=edge_attr_dim,
    num_preprocess_layers=len(hidden_features),
    preprocess_hidden_features=hidden_features,
    num_heads=num_heads,
    dropout_rates=dropout_rates,
    activation_fns=activation_fns,
    use_batch_norm=use_batch_norm,
    num_postprocess_layers=len(postprocess_hidden_features),
    postprocess_hidden_features=postprocess_hidden_features,
    out_features=1
)

model = MoleculeModel(
    base_model=base_model,
    optimizer_class=optimizer_class,
    learning_rate=learning_rate,
    weight_decay=weight_decay,
    step_size=step_size,
    batch_size=batch_size,
    gamma=gamma,
    metric='rmse'
)

print("Model:\n", model)

# NOTE(review): checkpoint_callback is created but not passed to the Trainer (and
# enable_checkpointing=False), so no best-val_loss checkpoint is kept. Cells after
# trainer.fit therefore use the weights from the last epoch that ran.
checkpoint_callback = ModelCheckpoint(monitor='val_loss', mode='min', save_top_k=1, verbose=True)
early_stop_callback = EarlyStopping(monitor='val_loss', patience=patience, verbose=True, mode='min')
timer = Timer()
logger = pl.loggers.TensorBoardLogger('logs', name='GATv2')

trainer = pl.Trainer(
    max_epochs=max_epochs,
    enable_checkpointing=False,
    callbacks=[early_stop_callback, timer],
    enable_progress_bar=False,
    logger=logger,
    accelerator='gpu',
    devices=1,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Model:
 MoleculeModel(
  (base_model): GATv2Model(
    (atom_preprocess): ModuleList(
      (0): Sequential(
        (atom_linear_0): Linear(in_features=133, out_features=128, bias=True)
        (atom_activation_0): PReLU(num_parameters=1)
        (atom_dropout_0): Dropout(p=0.0, inplace=False)
      )
      (1): Sequential(
        (atom_linear_1): Linear(in_features=128, out_features=128, bias=True)
        (atom_activation_1): PReLU(num_parameters=1)
        (atom_dropout_1): Dropout(p=0.0, inplace=False)
      )
      (2): Sequential(
        (atom_linear_2): Linear(in_features=128, out_features=128, bias=True)
        (atom_activation_2): PReLU(num_parameters=1)
        (atom_dropout_2): Dropout(p=0.0, inplace=False)
      )
      (3): Sequential(
        (atom_linear_3): Linear(in_features=128, out_features=128, bias=True)
        (atom_activation_3): PReLU(num_parameters=1)
        (atom_dropout_3): Dropout(p=0.0, inplace=False)
      )
      (4): Sequential(
        (atom_linea

In [21]:
# Train using the callbacks/logger configured on `trainer`; the DataModule supplies the loaders.
trainer.fit(model, datamodule=data_module)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name       | Type       | Params
------------------------------------------
0 | base_model | GATv2Model | 12.2 M
------------------------------------------
12.2 M    Trainable params
0         Non-trainable params
12.2 M    Total params
48.722    Total estimated model params size (MB)


TypeError: expected Tensor as element 1 in argument 0, but got tuple

In [None]:
# Wall-clock training time recorded by the Timer callback, shown as H:MM:SS.
elapsed_seconds = timer.time_elapsed()
h, remainder = divmod(int(elapsed_seconds), 3600)
m, s = divmod(remainder, 60)

print(f"Время обучения: {h}:{m:02d}:{s:02d}")


In [None]:
# Evaluate `model` with whatever weights it holds after trainer.fit (no best-checkpoint reload here).
evaluation = evaluate_model(model, data_module)
evaluation

In [None]:
def draw_molecule(smiles, predictions, size=(600, 600), show=True):
    """Render a molecule with a per-atom prediction written next to each atom.

    Hydrogens are added explicitly (``Chem.AddHs``) before labelling, so ``predictions``
    must contain one value per atom of the hydrogen-complete molecule, in RDKit atom order.

    Args:
        smiles: SMILES string of the molecule.
        predictions: per-atom values; rounded to 2 decimals for display.
        size: image size in pixels passed to ``Draw.MolToImage``.
        show: if True, also open the image with ``img.show()`` (the previous behaviour).

    Returns:
        The image object returned by ``Draw.MolToImage``.

    Raises:
        ValueError: if the SMILES cannot be parsed, or if the number of predictions
            differs from the number of atoms (``zip`` would otherwise truncate silently).
    """
    # `from rdkit import Chem` does not guarantee that the Draw submodule is loaded,
    # so `Chem.Draw` can raise AttributeError; import it explicitly.
    from rdkit.Chem import Draw

    parsed = Chem.MolFromSmiles(smiles)
    if parsed is None:  # MolFromSmiles reports a parse failure by returning None
        raise ValueError(f"Could not parse SMILES: {smiles!r}")
    mol = Chem.AddHs(parsed)

    predictions_rounded = np.round(predictions, 2)
    if len(predictions_rounded) != mol.GetNumAtoms():
        raise ValueError(
            f"Got {len(predictions_rounded)} predictions for {mol.GetNumAtoms()} atoms "
            f"(hydrogens included) in {smiles!r}"
        )

    for atom, pred in zip(mol.GetAtoms(), predictions_rounded):
        atom.SetProp('atomNote', str(pred))

    img = Draw.MolToImage(mol, size=size, kekulize=True)
    if show:
        img.show()
    return img

# Example usage (df_results is not created anywhere in this notebook):
#smiles = df_results.iloc[0]['smiles']
#predictions = df_results.iloc[0]['predictions']

#draw_molecule(smiles, predictions)
