In [4]:
import torch
from rdkit import Chem
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv, global_mean_pool
import torch.nn.functional as F
import csv
import os
from tqdm import tqdm

class GCN(torch.nn.Module):
    """Two-layer graph convolutional network that emits one score per graph.

    Node features go through a 16-channel ``GCNConv`` followed by ReLU and
    dropout, then through a 1-channel ``GCNConv``. The per-node outputs are
    mean-pooled according to ``batch`` and passed through a sigmoid.
    """

    def __init__(self, num_features):
        """Create the two convolution layers and store the pooling function.

        Args:
            num_features: Number of input features per node.
        """
        super().__init__()
        self.conv1 = GCNConv(num_features, 16)
        self.conv2 = GCNConv(16, 1)
        self.pool = global_mean_pool

    def forward(self, data):
        """Run the network on a graph (or batch of graphs).

        Args:
            data: Object exposing ``x``, ``edge_index`` and ``batch``.

        Returns:
            Tensor holding the sigmoid of the mean-pooled node outputs.
        """
        hidden = self.conv1(data.x, data.edge_index)
        # Dropout is only active while the module is in training mode.
        hidden = F.dropout(F.relu(hidden), training=self.training)
        node_scores = self.conv2(hidden, data.edge_index)
        pooled = self.pool(node_scores, data.batch)
        return torch.sigmoid(pooled)

def get_atom_features(mol):
    """Build one numeric feature vector per atom of a molecule.

    Each vector has eight entries, in this order: atomic number, aromatic
    flag, "has a specified chiral tag" flag, formal charge, total hydrogen
    count, total valence, hybridization code, and in-ring flag.

    Args:
        mol: Molecule object exposing ``GetAtoms()`` (an RDKit ``Mol``).

    Returns:
        A list with one 8-element list per atom, in ``mol.GetAtoms()`` order.
    """
    hybridization_type = Chem.rdchem.HybridizationType
    # Built once per call instead of once per atom. Any hybridization that is
    # not listed here falls back to code 0 via ``.get`` below.
    hybridization_codes = {
        hybridization_type.SP: 1,
        hybridization_type.SP2: 2,
        hybridization_type.SP3: 3,
        hybridization_type.SP3D: 4,
        hybridization_type.SP3D2: 5,
        hybridization_type.UNSPECIFIED: 0,
    }

    features = []
    for atom in mol.GetAtoms():
        features.append([
            atom.GetAtomicNum(),
            int(atom.GetIsAromatic()),
            # 1 when any chiral tag other than CHI_UNSPECIFIED is set.
            int(atom.GetChiralTag() != Chem.ChiralType.CHI_UNSPECIFIED),
            atom.GetFormalCharge(),
            atom.GetTotalNumHs(),
            atom.GetTotalValence(),
            hybridization_codes.get(atom.GetHybridization(), 0),
            int(atom.IsInRing()),
        ])

    return features

def molecule_to_graph(mol):
    """Convert a molecule into a ``Data`` graph object.

    Atoms become nodes (features from ``get_atom_features``). Every bond is
    stored as two directed edges, one per direction, each carrying the bond
    order returned by ``GetBondTypeAsDouble()`` as its edge attribute.

    Args:
        mol: Molecule object exposing ``GetAtoms()`` and ``GetBonds()``.

    Returns:
        ``Data`` with ``x`` (float node features), ``edge_index`` (long,
        shape ``(2, num_edges)``) and ``edge_attr`` (float, one per edge).
    """
    node_features = torch.tensor(get_atom_features(mol), dtype=torch.float)

    edge_pairs = []
    edge_features = []
    for bond in mol.GetBonds():
        start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        bond_order = bond.GetBondTypeAsDouble()
        # Add both directions so the graph is effectively undirected.
        edge_pairs += [(start, end), (end, start)]
        edge_features += [bond_order, bond_order]

    # reshape(-1, 2) keeps the tensor two-dimensional even when the molecule
    # has no bonds (single atom / ion): the result is then (2, 0) instead of
    # a 1-D empty tensor, which is not a valid edge-index layout.
    edge_index = (
        torch.tensor(edge_pairs, dtype=torch.long)
        .reshape(-1, 2)
        .t()
        .contiguous()
    )
    edge_attr = torch.tensor(edge_features, dtype=torch.float)

    return Data(x=node_features, edge_index=edge_index, edge_attr=edge_attr)

def load_model(model_path, num_features):
    """Instantiate a ``GCN`` and load its weights from a checkpoint file.

    Args:
        model_path: Path to a file containing the model's ``state_dict``.
        num_features: Number of input features per node the model expects.

    Returns:
        The ``GCN`` instance with weights loaded, switched to eval mode.
    """
    model = GCN(num_features=num_features)
    # map_location lets a checkpoint that was saved from a GPU be loaded on a
    # machine without CUDA; the freshly built model lives on the CPU anyway.
    # NOTE: torch.load unpickles the file, so only load trusted checkpoints.
    state_dict = torch.load(model_path, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    model.eval()
    return model

def predict_activity(sdf_path, model):
    """Score the molecule stored in a MOL/SDF file with the given model.

    Args:
        sdf_path: Path of the file passed to ``Chem.MolFromMolFile``.
        model: Model called on the resulting graph; it is put in eval mode.

    Returns:
        The model output as a Python float.

    Raises:
        ValueError: If no molecule could be parsed from ``sdf_path``.
    """
    mol = Chem.MolFromMolFile(sdf_path)
    if mol is None:
        # MolFromMolFile signals a parse failure by returning None; fail here
        # with the offending path rather than with an AttributeError later.
        raise ValueError(f'Could not parse a molecule from {sdf_path!r}')

    graph = molecule_to_graph(mol)
    # Single graph: every node belongs to batch element 0.
    graph.batch = torch.zeros(graph.num_nodes, dtype=torch.long)
    graph = graph.to(torch.device('cpu'))

    model.eval()
    with torch.no_grad():
        return model(graph).item()

In [5]:
# Checkpoint location and per-node feature count (get_atom_features builds
# 8-element vectors), then the model used by the cells below.
model_path, num_features = 'gcn_model.pth', 8
model = load_model(model_path=model_path, num_features=num_features)

In [9]:
# Score the rank-1 pose found in each subfolder of `best_folder` and write
# one CSV row per subfolder.
best_folder = 'best'
# The model is (re)loaded here, so this cell does not rely on a model object
# left in the kernel by an earlier cell.
model_path = 'gcn_model.pth'
num_features = 8
model = load_model(model_path, num_features)

output_file = 'building_blocks.csv'

with open(output_file, mode='w', newline='') as csv_file:
    writer = csv.writer(csv_file)
    # NOTE(review): the first column is filled with the subfolder name below;
    # presumably the subfolders are named after SMILES strings -- confirm.
    writer.writerow(['Smiles', 'Score'])

    # Directory listings come back in arbitrary order; sorting makes the row
    # order of the CSV reproducible between runs.
    with os.scandir(best_folder) as entries:
        subfolders = sorted(entry.path for entry in entries if entry.is_dir())

    for subfolder in tqdm(subfolders, desc="Processing subfolders"):
        # Sorted so that, if several files match, the same one is picked
        # on every run.
        sdf_files = sorted(
            name for name in os.listdir(subfolder)
            if 'rank1_confidence' in name and name.endswith('.sdf')
        )
        if not sdf_files:
            # tqdm.write prints without breaking the progress bar.
            tqdm.write(f'Skipped subfolder: {os.path.basename(subfolder)}')
            continue

        sdf_path = os.path.join(subfolder, sdf_files[0])
        score = predict_activity(sdf_path, model)
        writer.writerow([os.path.basename(subfolder), score])

Predicted activity score: 0.40328776836395264
