In [47]:
import torch
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.loader import DataLoader
from rdkit import Chem
import pandas as pd
import numpy as np
from tqdm import tqdm
from sklearn.metrics import roc_auc_score, roc_curve, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
import os
import warnings
from torch.utils.data import Dataset, random_split

warnings.filterwarnings("ignore")


In [48]:
# Create directories
os.makedirs('saved_models', exist_ok=True)
os.makedirs('plots', exist_ok=True)


In [49]:
# 1. Atom & Bond Features

ATOM_LIST = [1, 6, 7, 8, 9, 15, 16, 17, 35, 53]

def atom_features(atom):
    features = [
        ATOM_LIST.index(atom.GetAtomicNum()) if atom.GetAtomicNum() in ATOM_LIST else 0,
        atom.GetDegree(),
        atom.GetImplicitValence(),
        atom.GetFormalCharge(),
        int(atom.GetIsAromatic())
    ]
    return torch.tensor(features, dtype=torch.float)

def bond_features(bond):
    features = [
        int(bond.GetBondTypeAsDouble()),
        int(bond.GetIsConjugated()),
        int(bond.IsInRing())
    ]
    return torch.tensor(features, dtype=torch.float)

In [52]:
# 2. SMILES to Graph

from torch_geometric.data import Data

def smiles_to_graph(smiles):
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None

    # Atom features
    atoms = mol.GetAtoms()
    if not atoms:
        return None
    x = torch.stack([atom_features(atom) for atom in atoms])

    # Edge features
    edge_index, edge_attr = [], []
    for bond in mol.GetBonds():
        i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        edge_index.extend([[i, j], [j, i]])
        bf = bond_features(bond)
        edge_attr.extend([bf, bf])

    edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous() if edge_index else torch.empty((2, 0), dtype=torch.long)
    edge_attr = torch.stack(edge_attr) if edge_attr else torch.empty((0, 3), dtype=torch.float)

    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr)

In [53]:
# 3. Custom Dataset

class Tox21Dataset(Dataset):
    def __init__(self, dataframe, target):
        self.graphs = []
        for _, row in tqdm(dataframe.iterrows(), total=len(dataframe), desc="Processing"):
            if pd.isna(row[target]) or row[target] == -1:
                continue
            g = smiles_to_graph(row['smiles'])
            if g is not None:
                g.y = torch.tensor([int(row[target])], dtype=torch.long)
                self.graphs.append(g)

    def __len__(self):
        return len(self.graphs)

    def __getitem__(self, idx):
        return self.graphs[idx]
        return self.graphs[idx]

In [54]:
# 4. GCN Model

class GCN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.lin = torch.nn.Linear(hidden_channels, 2)

    def forward(self, x, edge_index, batch):
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index).relu()
        x = global_mean_pool(x, batch)
        return self.lin(x)

In [55]:
# 5. Load & Prepare Data

df = pd.read_csv("tox21.csv")
target = "SR-MMP"  # Change to any target
df = df[df["smiles"].notnull()]

dataset = Tox21Dataset(df, target)
print(f"\nLoaded {len(dataset)} valid molecules for target '{target}'")

train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32)


Processing:  29%|██▊       | 2246/7831 [00:02<00:05, 1101.44it/s][23:58:00] Explicit valence for atom # 4 Al, 6, is greater than permitted
Processing:  45%|████▌     | 3524/7831 [00:03<00:04, 1047.67it/s][23:58:01] Explicit valence for atom # 4 Al, 6, is greater than permitted
Processing:  57%|█████▋    | 4502/7831 [00:04<00:03, 968.54it/s] [23:58:02] Explicit valence for atom # 9 Al, 6, is greater than permitted
Processing:  59%|█████▉    | 4602/7831 [00:04<00:03, 976.69it/s][23:58:02] Explicit valence for atom # 5 Al, 6, is greater than permitted
Processing:  70%|██████▉   | 5479/7831 [00:05<00:02, 1098.58it/s][23:58:03] Explicit valence for atom # 16 Al, 6, is greater than permitted
Processing:  85%|████████▍ | 6622/7831 [00:06<00:01, 1004.30it/s][23:58:04] Explicit valence for atom # 20 Al, 6, is greater than permitted
Processing: 100%|██████████| 7831/7831 [00:07<00:00, 1006.97it/s]


Loaded 5804 valid molecules for target 'SR-MMP'





In [56]:
# 6. Train the Model

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = GCN(in_channels=5, hidden_channels=64).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)
criterion = torch.nn.CrossEntropyLoss()

train_losses = []
val_aucs = []

print("\nStarting training...")
for epoch in range(1, 21):
    # Training
    model.train()
    total_loss = 0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        out = model(batch.x, batch.edge_index, batch.batch)
        loss = criterion(out, batch.y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    train_losses.append(total_loss / len(train_loader))

    # Validation
    model.eval()
    y_true, y_pred = [], []
    with torch.no_grad():
        for batch in val_loader:
            batch = batch.to(device)
            out = model(batch.x, batch.edge_index, batch.batch)
            probs = torch.softmax(out, dim=1)[:, 1]
            y_pred.extend(probs.cpu().numpy())
            y_true.extend(batch.y.cpu().numpy())
    
    if len(y_true) > 0 and len(np.unique(y_true)) >= 2:
        auc = roc_auc_score(y_true, y_pred)
        val_aucs.append(auc)
        print(f"Epoch {epoch:02d} | Loss: {train_losses[-1]:.4f} | AUC: {auc:.4f}")
    else:
        print(f"Epoch {epoch:02d} | Loss: {train_losses[-1]:.4f} | Validation skipped (not enough samples)")

print("\nTraining complete!")



Starting training...
Epoch 01 | Loss: 0.4494 | AUC: 0.6885
Epoch 02 | Loss: 0.4168 | AUC: 0.7006
Epoch 03 | Loss: 0.4171 | AUC: 0.7094
Epoch 04 | Loss: 0.4150 | AUC: 0.7125
Epoch 05 | Loss: 0.4086 | AUC: 0.7239
Epoch 06 | Loss: 0.4076 | AUC: 0.7235
Epoch 07 | Loss: 0.4066 | AUC: 0.7370
Epoch 08 | Loss: 0.4036 | AUC: 0.7449
Epoch 09 | Loss: 0.3983 | AUC: 0.7399
Epoch 10 | Loss: 0.3995 | AUC: 0.7541
Epoch 11 | Loss: 0.4028 | AUC: 0.7496
Epoch 12 | Loss: 0.3920 | AUC: 0.7563
Epoch 13 | Loss: 0.3935 | AUC: 0.7656
Epoch 14 | Loss: 0.3899 | AUC: 0.7649
Epoch 15 | Loss: 0.3913 | AUC: 0.7694
Epoch 16 | Loss: 0.3903 | AUC: 0.7661
Epoch 17 | Loss: 0.3862 | AUC: 0.7738
Epoch 18 | Loss: 0.3862 | AUC: 0.7746
Epoch 19 | Loss: 0.3841 | AUC: 0.7798
Epoch 20 | Loss: 0.3843 | AUC: 0.7704

Training complete!


In [57]:
# 7. Visualization
# Training Loss Curve
if len(train_losses) > 0:
    plt.figure(figsize=(10, 5))
    plt.plot(train_losses, label='Training Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training Loss')
    plt.legend()
    plt.savefig('plots/training_loss.png')
    plt.close()

if len(val_aucs) > 0:
    plt.figure(figsize=(10, 5))
    plt.plot(val_aucs, label='Validation AUC')
    plt.xlabel('Epoch')
    plt.ylabel('AUC')
    plt.title('Validation AUC')
    plt.legend()
    plt.savefig('plots/validation_auc.png')
    plt.close()

# ROC Curve
if len(y_true) > 0 and len(np.unique(y_true)) >= 2:
    fpr, tpr, _ = roc_curve(y_true, y_pred)
    plt.figure(figsize=(8, 8))
    plt.plot(fpr, tpr, label=f'AUC = {val_aucs[-1]:.2f}')
    plt.plot([0, 1], [0, 1], 'k--')
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('ROC Curve')
    plt.legend()
    plt.savefig('plots/roc_curve.png')
    plt.close()

In [58]:
# 8. Save Model

torch.save(model.state_dict(), f"saved_models/gcn_{target.replace('-', '_')}.pt")
with open(f"saved_models/gcn_{target.replace('-', '_')}.pkl", 'wb') as f:
    pickle.dump(model, f)

print("\nAll done! Models and plots saved.")


All done! Models and plots saved.
