In [1]:

import pandas as pd
import torch
from rdkit import Chem

### Extracting the Data

In [2]:
# Load the labelled molecule table from data.csv (path is relative to the
# notebook's working directory). Per the preview below, each row holds a set
# of assay label columns (NaN where no label is present), a `mol_id`, and a
# `smiles` string.
tox = pd.read_csv("data.csv")
# Show the first rows as a quick sanity check of the columns.
tox.head()

Unnamed: 0,NR-AR,NR-AR-LBD,NR-AhR,NR-Aromatase,NR-ER,NR-ER-LBD,NR-PPAR-gamma,SR-ARE,SR-ATAD5,SR-HSE,SR-MMP,SR-p53,mol_id,smiles
0,0.0,0.0,1.0,,,0.0,0.0,1.0,0.0,0.0,0.0,0.0,TOX3021,CCOc1ccc2nc(S(N)(=O)=O)sc2c1
1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,,0.0,,0.0,0.0,TOX3020,CCN1C(=O)NC(c2ccccc2)C1=O
2,,,,,,,,0.0,,0.0,,,TOX3024,CC[C@]1(O)CC[C@H]2[C@@H]3CCC4=CCCC[C@@H]4[C@H]...
3,0.0,0.0,0.0,0.0,0.0,0.0,0.0,,0.0,,0.0,0.0,TOX3027,CCCN(CC)C(CC)C(=O)Nc1c(C)cccc1C
4,0.0,0.0,,0.0,0.0,0.0,0.0,0.0,0.0,0.0,,0.0,TOX3028,CC(O)(P(=O)(O)O)P(=O)(O)O


In [60]:
# Label columns the model is trained on. Other assay columns present in the table:
# 'NR-AR-LBD','NR-AhR','NR-Aromatase','NR-ER','NR-ER-LBD','NR-PPAR-gamma',
# 'SR-ARE','SR-ATAD5','SR-HSE','SR-MMP','SR-p53'
classes = ['NR-AR']

# Drop rows with a missing label in any selected class *before* measuring the
# table, so the split boundaries below are fractions of the rows that are
# actually kept (measuring first would inflate the train split and shrink the
# test split). The cell is idempotent: re-running dropna on the filtered frame
# is a no-op.
tox = tox.dropna(subset=classes)
tox_rows, tox_cols = tox.shape

# Train / validation / test split: 60% / 10% / 30%, in file order (no shuffle).
train_end = int(tox_rows * 0.6)
val_end = int(tox_rows * 0.7)

# .iloc makes the positional slicing explicit; dropna leaves gaps in the index labels.
in_train, in_val, in_test = (
    tox['smiles'].iloc[:train_end],
    tox['smiles'].iloc[train_end:val_end],
    tox['smiles'].iloc[val_end:],
)
lb_train, lb_val, lb_test = (
    tox[classes].iloc[:train_end],
    tox[classes].iloc[train_end:val_end],
    tox[classes].iloc[val_end:],
)

### Generating edges and node features

In [61]:
# Size limit used by get_inputs: a molecule is skipped unless its parsed atom
# count (measured before hydrogens are added) is strictly below this value.
MAX_ATOMS = 100

# Each atom is a node and each bond between two atoms is an edge. Only the connectivity (atom index pairs) is extracted below; the bond type is not yet encoded as an edge feature.
def get_edges(molecule, undirected=True):
    """Build the edge list of a molecule from its bonds.

    Args:
        molecule: Object exposing ``GetBonds()``; each bond must provide
            ``GetBeginAtom().GetIdx()`` and ``GetEndAtom().GetIdx()``
            (e.g. an RDKit ``Mol``).
        undirected: When True (default) every bond contributes both
            ``[begin, end]`` and ``[end, begin]``. Message-passing layers
            treat edges as directed, so a chemical bond must appear in both
            directions for information to flow both ways. Pass False to get
            one row per bond.

    Returns:
        ``torch.long`` tensor of shape ``(num_edges, 2)``; shape ``(0, 2)``
        when the molecule has no bonds.
    """
    edge = []
    for bond in molecule.GetBonds():
        begin = bond.GetBeginAtom().GetIdx()
        end = bond.GetEndAtom().GetIdx()
        edge.append([begin, end])
        if undirected:
            edge.append([end, begin])

    if not edge:
        # Keep a well-defined 2-D shape; nelement() is still 0 so callers that
        # skip empty graphs keep working.
        return torch.empty((0, 2), dtype=torch.long)
    return torch.tensor(edge, dtype=torch.long)

def get_node_features(molecule):
    """Return one feature row per atom of ``molecule``.

    Each row has 7 entries:
        [atomic number, total valence, SP, SP2, SP3, SP3D, SP3D2]
    where the last five form a one-hot encoding of the atom's hybridization.
    Any other hybridization value leaves all five slots at 0.

    Returns:
        ``torch.float`` tensor built from the list of rows.
    """
    rows = []
    for atom in molecule.GetAtoms():
        row = [atom.GetAtomicNum(), atom.GetTotalValence()]
        hyb = atom.GetHybridization()

        # Slot order of the one-hot block.
        hyb_types = Chem.rdchem.HybridizationType
        slot_order = (
            hyb_types.SP,
            hyb_types.SP2,
            hyb_types.SP3,
            hyb_types.SP3D,
            hyb_types.SP3D2,
        )

        one_hot = [0] * len(slot_order)
        for slot, kind in enumerate(slot_order):
            if hyb == kind:
                one_hot[slot] = 1
                break

        row.extend(one_hot)
        rows.append(row)
    return torch.tensor(rows, dtype=torch.float)

def get_inputs(smiles):
    """Turn a SMILES string into ``(edges, node_features)`` tensors.

    The string is parsed with RDKit. If parsing fails, or the parsed molecule
    has ``MAX_ATOMS`` atoms or more, a message is printed and a pair of empty
    tensors is returned so the caller can skip the molecule. Otherwise
    hydrogens are added, the molecule is kekulized, and the results of
    ``get_edges`` and ``get_node_features`` are returned.
    """
    molecule = Chem.MolFromSmiles(smiles)

    # Guard: the SMILES could not be parsed.
    if not molecule:
        print("Unable to parse {}".format(smiles))
        return torch.tensor([]), torch.tensor([])

    # Guard: size limit, checked before hydrogens are added.
    if len(molecule.GetAtoms()) >= MAX_ATOMS:
        print("Molecule {} too big".format(smiles))
        return torch.tensor([]), torch.tensor([])

    with_hydrogens = Chem.AddHs(molecule)  # add the hydrogens
    Chem.rdmolops.Kekulize(with_hydrogens)  # dealing with aromatic bonds
    connectivity = get_edges(with_hydrogens)
    atom_features = get_node_features(with_hydrogens)
    return connectivity, atom_features



### Creating the Model

In [62]:
from torch_geometric.nn.conv import GCNConv, GATConv
from torch_geometric.nn import global_max_pool
from torch import nn
import torch.optim as optim

class GNNModel(nn.Module):
    """Graph-level classifier: two GCN layers, one GAT layer, max-pool, linear head.

    Args:
        input_features: Number of features per node.
        hidden_features1: Output width of the first GCN layer.
        hidden_features2: Output width of the second GCN layer.
        hidden_features3: Output width of the GAT layer (and input width of
            the linear head).
        num_classes: Number of sigmoid outputs per graph.
    """

    def __init__(self, input_features,hidden_features1,hidden_features2,hidden_features3,num_classes):
        super().__init__()
        self.gcl1 = GCNConv(input_features, hidden_features1)
        self.gcl2 = GCNConv(hidden_features1, hidden_features2)
        self.gat = GATConv(hidden_features2,hidden_features3, dropout=0.6)
        self.fc = nn.Linear(hidden_features3, num_classes)

    def forward(self, x, edges, batch=None):
        """Compute per-graph probabilities.

        Args:
            x: Node feature matrix, one row per node.
            edges: Edge index tensor passed unchanged to every conv layer.
            batch: Optional long tensor assigning each node (row of ``x``) to
                a graph id. When omitted, all nodes are treated as belonging
                to a single graph, which is only correct for batch size 1.

        Returns:
            Tensor of shape ``(num_graphs, num_classes)`` with values in
            [0, 1] (sigmoid outputs, suitable for ``nn.BCELoss``).
        """
        x = self.gcl1(x, edges).relu()
        x = self.gcl2(x, edges).relu()
        x = self.gat(x, edges).relu()

        if batch is None:
            # Single-graph fallback. Allocate on x's device so the model also
            # works when inputs live on an accelerator.
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)

        x = global_max_pool(x, batch)
        x = self.fc(x).sigmoid()
        return x

### Training the Model

In [73]:
from torch_geometric.data import Data, DataLoader

def prepare_data(smiles_list, labels):
    """Convert parallel SMILES / label sequences into a list of ``Data`` graphs.

    Molecules for which ``get_inputs`` returns an empty edge tensor (parse
    failure, size limit, or no bonds) are left out. Each kept graph stores:
        x    -- node feature tensor from ``get_inputs``
        edge -- edge tensor transposed and made contiguous
        y    -- the label as a float tensor with a leading dimension of 1
    """
    graphs = []
    for smi, target in zip(smiles_list, labels):
        connectivity, atom_features = get_inputs(smi)
        if connectivity.nelement() == 0:
            # Nothing usable for this molecule; skip it.
            continue
        graph = Data(
            x=atom_features,
            edge=connectivity.t().contiguous(),
            y=torch.tensor(target, dtype=torch.float).unsqueeze(0),
        )
        graphs.append(graph)
    return graphs

# Seed torch's global RNG so weight initialisation and the shuffling order of
# train_loader are reproducible across "Restart & Run All".
SEED = 42
torch.manual_seed(SEED)

# Convert each split into a list of graph objects (unusable molecules are skipped).
train_data = prepare_data(in_train, lb_train.values)
val_data = prepare_data(in_val, lb_val.values)
test_data = prepare_data(in_test, lb_test.values)

# batch_size=1: one molecule per step. Only the training loader is shuffled.
train_loader = DataLoader(train_data, batch_size=1, shuffle=True)
val_loader = DataLoader(val_data, batch_size=1, shuffle=False)
test_loader = DataLoader(test_data, batch_size=1, shuffle=False)

# input_features=7 matches get_node_features: atomic number, total valence,
# and a 5-way one-hot hybridization block.
model = GNNModel(input_features=7, hidden_features1=8, hidden_features2=12, hidden_features3=16, num_classes=len(classes))
optimizer = optim.SGD(model.parameters(), lr=0.01)
criterion = nn.BCELoss()  # Binary Cross-Entropy for binary classification

def train(model, data_loader):
    """Run one training epoch and return the mean per-batch loss.

    Uses the module-level ``optimizer`` and ``criterion``.

    Args:
        model: Module called as ``model(batch.x, batch.edge)``.
        data_loader: Iterable of batches exposing ``x``, ``edge`` and ``y``.
            It does not need to support ``len()``.

    Returns:
        Mean of ``loss.item()`` over the batches seen, or 0.0 when the loader
        yields no batches (instead of raising ZeroDivisionError).
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    for batch in data_loader:
        optimizer.zero_grad()
        output = model(batch.x, batch.edge)

        loss = criterion(output, batch.y)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1  # count what was actually iterated

    if num_batches == 0:
        return 0.0
    return total_loss / num_batches

def evaluate(model, data_loader):
    """Return element-wise accuracy of ``model`` over ``data_loader``.

    Outputs are thresholded at 0.5 and compared with ``batch.y``. Accuracy is
    the fraction of matching label entries over *all* entries (rows x
    classes), so it stays within [0, 1] when more than one class is predicted.

    Args:
        model: Module called as ``model(batch.x, batch.edge)``.
        data_loader: Iterable of batches exposing ``x``, ``edge`` and ``y``.

    Returns:
        Accuracy as a float; 0.0 when the loader yields no batches.
    """
    model.eval()
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for batch in data_loader:
            output = model(batch.x, batch.edge)
            preds = (output >= 0.5).long()
            all_preds.append(preds)
            all_labels.append(batch.y)

    if not all_labels:
        # torch.cat rejects an empty list; report 0.0 for an empty split.
        return 0.0

    all_preds = torch.cat(all_preds, dim=0)
    all_labels = torch.cat(all_labels, dim=0)
    correct = (all_preds == all_labels).sum().item()
    # Count every label entry, not just rows, to match how `correct` is counted.
    total = all_labels.numel()
    if total == 0:
        return 0.0
    return correct / total

def test(model, data_loader):
    model.eval()
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for batch in data_loader:
            output = model(batch.x,batch.edge)  
            preds = (output >= 0.5).long() 
            all_preds.append(preds)
            all_labels.append(batch.y)
    
    all_preds = torch.cat(all_preds, dim=0)
    all_labels = torch.cat(all_labels, dim=0)
    correct = (all_preds == all_labels).sum().item()  
    total = all_labels.size(0)
    accuracy = correct / total 
    return accuracy


# Number of passes over the training loader.
ITERATIONS = 20
for epoch in range(ITERATIONS):
    # One optimisation pass over train_loader; returns the mean batch loss.
    train_loss = train(model, train_loader)
    print('Epoch {}/{}, Train Loss: {:.4f}'.format(epoch+1,ITERATIONS,train_loss))
    
    # Track validation accuracy after every epoch.
    val_accuracy = evaluate(model, val_loader)
    print('Epoch {}/{}, Validation Accuracy: {:.4f}'.format(epoch+1,ITERATIONS,val_accuracy))

# The test split is scored once, after all epochs have finished.
test_accuracy = test(model, test_loader)
print('Test Accuracy: {:.4f}'.format(test_accuracy))




Molecule O=C(OC[C@H]1O[C@@H](OC(=O)c2cc(O)c(O)c(OC(=O)c3cc(O)c(O)c(O)c3)c2)[C@H](OC(=O)c2cc(O)c(O)c(OC(=O)c3cc(O)c(O)c(O)c3)c2)[C@@H](OC(=O)c2cc(O)c(O)c(OC(=O)c3cc(O)c(O)c(O)c3)c2)[C@@H]1OC(=O)c1cc(O)c(O)c(OC(=O)c2cc(O)c(O)c(O)c2)c1)c1cc(O)c(O)c(OC(=O)c2cc(O)c(O)c(O)c2)c1 too big
Molecule CC[C@@H](C)CCCCC(=O)N[C@@H](CCNCS(=O)(=O)[O-])C(=O)N[C@H](C(=O)N[C@@H](CCNCS(=O)(=O)[O-])C(=O)N[C@H]1CCNC(=O)[C@H]([C@@H](C)O)NC(=O)[C@H](CCNCS(=O)(=O)[O-])NC(=O)[C@H](CCNCS(=O)(=O)[O-])NC(=O)[C@H](CC(C)C)NC(=O)[C@@H](CC(C)C)NC(=O)[C@H](CCNCS(=O)(=O)[O-])NC1=O)[C@@H](C)O too big
Molecule CN[C@H](CC(C)C)C(=O)N[C@H]1C(=O)N[C@@H](CC(N)=O)C(=O)N[C@H]2C(=O)N[C@H]3C(=O)N[C@H](C(=O)N[C@@H](C(=O)O)c4cc(O)cc(O)c4-c4cc3ccc4O)[C@H](O)c3ccc(c(Cl)c3)Oc3cc2cc(c3O[C@@H]2O[C@H](CO)[C@@H](O)[C@H](O)[C@H]2O[C@H]2C[C@](C)(N)C(O)[C@H](C)O2)Oc2ccc(cc2Cl)[C@H]1O too big
Molecule CC[C@H](C)[C@@H](NC(=O)[C@H](CCC(=O)O)NC(=O)[C@H](CC(C)C)NC(=O)[C@@H]1CSC([C@H](N)[C@H](C)CC)=N1)C(=O)N[C@@H]1CCCCNC(=O)[C@@H](CC(N)=O)NC(=O)[C@@H