In [44]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from ase import Atoms
from ase.io import read
from torch_geometric.data import Data
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import to_networkx
from torch_geometric.data import DataLoader
from sklearn.model_selection import train_test_split
from ase.build import molecule
import pandas as pd
from rdkit import Chem
from rdkit.Chem import AllChem

In [17]:
class GraphConvolution(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(GraphConvolution, self).__init__(aggr='add')
        self.lin = nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        return self.propagate(edge_index, x=x)

    def message(self, x_j):
        return x_j

    def update(self, aggr_out):
        return self.lin(aggr_out)

class GraphEncoder(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super(GraphEncoder, self).__init__()
        self.conv1 = GraphConvolution(input_dim, hidden_dim)
        self.conv2 = GraphConvolution(hidden_dim, hidden_dim)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = nn.ReLU()(x)
        x = self.conv2(x, edge_index)
        return x


In [79]:
def load_data(csv_path, sdf_path):
    graphs = []
    # Load features from CSV
    features_df = pd.read_csv(csv_path)
    suppl = Chem.SDMolSupplier(sdf_path)

    for mol in suppl:
        if mol is not None:
            positions = mol.GetConformer().GetPositions()
            atomic_numbers = [atom.GetAtomicNum() for atom in mol.GetAtoms()]
            properties = {}
            for prop in mol.GetPropNames():
                properties[prop] = mol.GetProp(prop)

            #formula = properties['Formula']

            atoms_list = []

            bonds = []
            for bond in mol.GetBonds():
                bonds.append((bond.GetBeginAtomIdx(), bond.GetEndAtomIdx(), bond.GetBondTypeAsDouble()))

            charges = []
            for atom in mol.GetAtoms():
                atom_index = atom.GetIdx()
                symbol = atom.GetSymbol()
                formal_charge = atom.GetFormalCharge()
                atomic_number = atom.GetAtomicNum()
                atomic_mass = atom.GetMass()
                hybridization = atom.GetHybridization()
                num_explicit_hs = atom.GetNumExplicitHs()
                num_radical_electrons = atom.GetNumRadicalElectrons()
                is_aromatic = atom.GetIsAromatic()
                is_in_ring = atom.IsInRing()
                position = mol.GetConformer().GetAtomPosition(atom_index)
                charges.append(atom.GetFormalCharge())

                atoms_list.append({
                    "index": atom_index,
                    "symbol": symbol,
                    "formal_charge": formal_charge,
                    "atomic_number": atomic_number,
                    "atomic_mass": atomic_mass,
                    "hybridization": hybridization,
                    "num_explicit_hs": num_explicit_hs,
                    "num_radical_electrons": num_radical_electrons,
                    "is_aromatic": is_aromatic,
                    "is_in_ring": is_in_ring,
                    "position": position
                })

            # Create ASE atoms object
            graphs.append({
                "Atoms": atoms_list,
                "Bonds": bonds,
                "Charges": charges,
                "Properties": properties
            })
    return graphs, features_df

In [80]:
# Load data
graphs, features = load_data("../data/tox21/tox21_dense_train.csv", "../data/tox21/tox21.sdf")


[17:48:52] Explicit valence for atom # 3 Cl, 2, is greater than permitted
[17:48:52] ERROR: Could not sanitize molecule ending on line 21572
[17:48:52] ERROR: Explicit valence for atom # 3 Cl, 2, is greater than permitted
[17:49:04] Explicit valence for atom # 2 Si, 8, is greater than permitted
[17:49:04] ERROR: Could not sanitize molecule ending on line 346021
[17:49:04] ERROR: Explicit valence for atom # 2 Si, 8, is greater than permitted
[17:49:09] Explicit valence for atom # 3 Cl, 2, is greater than permitted
[17:49:09] ERROR: Could not sanitize molecule ending on line 446665
[17:49:09] ERROR: Explicit valence for atom # 3 Cl, 2, is greater than permitted
[17:49:14] Explicit valence for atom # 1 Cl, 2, is greater than permitted
[17:49:14] ERROR: Could not sanitize molecule ending on line 619150
[17:49:14] ERROR: Explicit valence for atom # 1 Cl, 2, is greater than permitted
[17:49:21] Explicit valence for atom # 2 Si, 8, is greater than permitted
[17:49:21] ERROR: Could not sanitiz

In [81]:
print(graphs[0])

{'Atoms': [{'index': 0, 'symbol': 'Cl', 'formal_charge': -1, 'atomic_number': 17, 'atomic_mass': 35.453, 'hybridization': rdkit.Chem.rdchem.HybridizationType.SP3, 'num_explicit_hs': 0, 'num_radical_electrons': 0, 'is_aromatic': False, 'is_in_ring': False, 'position': <rdkit.Geometry.rdGeometry.Point3D object at 0x17a7d6a40>}, {'index': 1, 'symbol': 'C', 'formal_charge': 0, 'atomic_number': 6, 'atomic_mass': 12.011, 'hybridization': rdkit.Chem.rdchem.HybridizationType.SP3, 'num_explicit_hs': 0, 'num_radical_electrons': 0, 'is_aromatic': False, 'is_in_ring': False, 'position': <rdkit.Geometry.rdGeometry.Point3D object at 0x17a7d7140>}, {'index': 2, 'symbol': 'N', 'formal_charge': 1, 'atomic_number': 7, 'atomic_mass': 14.007, 'hybridization': rdkit.Chem.rdchem.HybridizationType.SP2, 'num_explicit_hs': 0, 'num_radical_electrons': 0, 'is_aromatic': True, 'is_in_ring': True, 'position': <rdkit.Geometry.rdGeometry.Point3D object at 0x17a7d6b40>}, {'index': 3, 'symbol': 'C', 'formal_charge': 0

In [82]:
print(len(graphs))

12700


In [85]:
print(features)

            Unnamed: 0            AW  AWeight   Arto  BertzCT    Chi0    Chi1  \
0      NCGC00178831-03  5.436720e+07   13.053  2.176    3.194  23.112  15.868   
1      NCGC00166114-03  1.268818e+07   22.123  2.065    3.137  21.033  13.718   
2      NCGC00263563-01  3.076932e+06   13.085  2.154    3.207  46.896  29.958   
3      NCGC00013058-02  7.168569e+07   12.832  2.029    3.380  51.086  32.045   
4      NCGC00167516-01  7.989702e+06   12.936  2.124    3.573  70.295  46.402   
...                ...           ...      ...    ...      ...     ...     ...   
12055  NCGC00261292-01  1.428572e+07   14.255  2.000    2.628   9.259   6.309   
12056  NCGC00261245-01  1.193182e+07   13.674  2.061    2.920  21.142  15.382   
12057  NCGC00260828-01  1.081800e+01   12.374  2.045    3.128  33.242  20.457   
12058  NCGC00260687-01  3.229000e+00   12.543  2.267    2.700  10.251   7.381   
12059  NCGC00261465-01  1.931035e+07   15.004  1.867    2.985  20.190  12.619   

       Chi10    Chi2    Chi

In [None]:
train_graphs, test_graphs = train_test_split(graphs, test_size=0.2, random_state=42)


In [None]:
class ToxicClassifier(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super(ToxicClassifier, self).__init__()
        self.encoder = GraphEncoder(input_dim, hidden_dim)
        self.fc = nn.Linear(hidden_dim, 1)

    def forward(self, data):
        x = self.encoder(data)
        x = self.fc(x)
        return torch.sigmoid(x)


# Initialize model
input_dim = len(graphs[0].x[0])  # Assuming node feature size is same for all graphs
hidden_dim = 64
model = ToxicClassifier(input_dim, hidden_dim)

criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Train model
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    for graph in train_graphs:
        optimizer.zero_grad()
        output = model(graph)
        target = torch.tensor([graph.tox], dtype=torch.float)  #### CHECK with the tox21 data to validate what the column name is!!
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {loss.item()}')

# Evaluate model
model.eval()
with torch.no_grad():
    correct = 0
    total = 0
    for graph in test_graphs:
        output = model(graph)
        predicted = torch.round(output)
        target = torch.tensor([graph.y], dtype=torch.float)
        total += 1
        if predicted == target:
            correct += 1

    accuracy = correct / total
    print(f'Accuracy: {accuracy}')