# Tox21 GINENet

We’re solving a multi-task binary toxicity prediction problem on Tox21. Concretely, for each molecule the model outputs 12 probabilities, one for each assay (e.g. NR-AR, SR-ARE, p53, etc.). At training time we use a binary cross-entropy loss (with masking for missing labels) over those 12 tasks, and at the end of each epoch we compute the ROC-AUC per task (then average) on the held-out validation set to see how well the model is distinguishing actives vs. inactives across all assays.

In [1]:
%pip -q install rdkit-pypi torch_geometric

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m2.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m29.4/29.4 MB[0m [31m58.4 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m51.8 MB/s[0m eta [36m0:00:00[0m
[?25hNote: you may need to restart the kernel to use updated packages.


Cell below must be run twice for some obscure reason

In [2]:
import pandas as pd
import numpy as np


# Pytorch Imports
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Subset, WeightedRandomSampler
from torch.serialization import safe_globals, add_safe_globals

# Pytorch Geometric Imports
from torch_geometric.data import InMemoryDataset, Data
from torch_geometric.loader import DataLoader
from torch_geometric.data.data import DataTensorAttr, DataEdgeAttr
from torch_geometric.data.storage import GlobalStorage
from torch_geometric.nn import GINEConv, GatedGraphConv, global_mean_pool

# RDKit Imports
from rdkit.Chem.Scaffolds import MurckoScaffold
from rdkit.Chem import MolFromSmiles, MolToSmiles, rdchem

from sklearn.metrics import roc_auc_score
import os

In [3]:
!wget wget https://raw.githubusercontent.com/deepchem/deepchem/master/datasets/tox21.csv.gz
!gunzip tox21.csv.gz
!mkdir -p raw
!mv tox21.csv raw/

--2025-05-25 15:43:56--  http://wget/
Resolving wget (wget)... failed: Name or service not known.
wget: unable to resolve host address ‘wget’
--2025-05-25 15:43:56--  https://raw.githubusercontent.com/deepchem/deepchem/master/datasets/tox21.csv.gz
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.108.133, 185.199.109.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 125310 (122K) [application/octet-stream]
Saving to: ‘tox21.csv.gz’


2025-05-25 15:43:56 (5.58 MB/s) - ‘tox21.csv.gz’ saved [125310/125310]

FINISHED --2025-05-25 15:43:56--
Total wall clock time: 0.2s
Downloaded: 1 files, 122K in 0.02s (5.58 MB/s)


In [4]:
add_safe_globals([DataTensorAttr, DataEdgeAttr, GlobalStorage])

## Dataset

Perhaps we will add more features later/do some feature optimization

In [5]:
class Tox21Dataset(InMemoryDataset):
    def __init__(self, root: str = '.', transform=None, pre_transform=None):
        """
        Expects:
        root/
            raw/tox21.csv
        Will create:
            processed/data.pt
        """
        super().__init__(root, transform, pre_transform)
        # Load the processed data
        with safe_globals([DataTensorAttr, DataEdgeAttr, GlobalStorage]):
            self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        # File expected in root/raw/
        return ['tox21.csv']

    @property
    def processed_file_names(self):
        return ['data.pt']

    def download(self):
        # No download step needed; CSV is already in place
        return

    def process(self):
        df = pd.read_csv(self.raw_paths[0])
        df = df.fillna(value=0)
        
        if "mol_id" in df.columns:
            df = df.drop(columns=["mol_id"])

        # 3) convert all non-smiles columns to numeric
        non_smiles = [c for c in df.columns if c != "smiles"]
        df[non_smiles] = df[non_smiles].apply(pd.to_numeric, errors="coerce")
        df = df.reset_index(drop=True)

        data_list = []
        bond_types = [
            rdchem.BondType.SINGLE,
            rdchem.BondType.DOUBLE,
            rdchem.BondType.TRIPLE,
            rdchem.BondType.AROMATIC,
        ]

        for _, row in df.iterrows():
            smiles = row["smiles"]
            mol    = MolFromSmiles(smiles, sanitize=True)
            if mol is None:
                continue

            # node features x
            atom_feats = [
                [
                    a.GetAtomicNum(),
                    a.GetDegree(),
                    a.GetFormalCharge(),
                    a.GetNumRadicalElectrons(),
                ]
                for a in mol.GetAtoms()
            ]
            x = torch.tensor(atom_feats, dtype=torch.float32)

            # build edge_index AND edge_attr in lock‐step
            edge_index = []
            edge_attr  = []
            for bond in mol.GetBonds():
                i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
                # one‐hot encode this bond’s type
                bt = bond.GetBondType()
                feat = [int(bt == t) for t in bond_types]

                # add both directions
                edge_index += [[i, j], [j, i]]
                edge_attr  += [feat, feat]

            edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
            edge_attr  = torch.tensor(edge_attr, dtype=torch.float32)  # [2E, len(bond_types)]

            # your labels
            y = torch.tensor(row[[c for c in df.columns if c!="smiles"]]
                             .tolist(), dtype=torch.float32)

            data_list.append(Data(
                x=x,
                edge_index=edge_index,
                edge_attr=edge_attr,
                y=y,
            ))

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])


In [6]:
dataset = Tox21Dataset('.')

Processing...
Done!


A **Murcko Scaffold** is a technique used to group molecules based on their core structural components. It identifies the essential building blocks by removing side chains and other non-core components, leaving behind the important ring systems and connecting chains. This method is widely used in medicinal chemistry and drug design to identify core structures that have preferential activity against specific targets, which is very useful in **molecular property prediction**.

In [7]:
def GetMurckoScaffold(data, train_frac: float, val_frac: float, test_frac: float):
    scaffold_indices = {}
    for idx, smiles in enumerate(data):
        scaffold_smiles = MurckoScaffold.MurckoScaffoldSmiles(smiles)
        #scaffold_smiles = MolToSmiles(scaffold)
        scaffold_indices.setdefault(scaffold_smiles, []).append(idx)

    groups = sorted(scaffold_indices.values(), key=len, reverse=True)
    n_total = len(data)
    n_train = int(train_frac * n_total)
    n_valid = int(val_frac * n_total)

    train_idx, valid_idx, test_idx = [], [], []
    for group in groups:
        if len(train_idx) + len(group) <= n_train:
            train_idx.extend(group)
        elif len(valid_idx) + len(group) <= n_valid:
            valid_idx.extend(group)
        else:
            test_idx.extend(group)

    return train_idx, valid_idx, test_idx

## GNN Model

In [8]:
np.random.seed(1638)
torch.manual_seed(1638)

<torch._C.Generator at 0x7bfeb9e4f1f0>

In [30]:
class GINEModel(nn.Module):
    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        num_tasks: int,
        dropout: float,
        n_layers: int
    ):
        super().__init__()
        
        self.conv1 = GINEConv(
            nn.Sequential(
                nn.Linear(input_dim, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim),
            ),
            edge_dim = 4,
        )

        self.conv_block = nn.ModuleList()   
        for _ in range(n_layers):
            mlp = nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim),
            )
            self.conv_block.append(GINEConv(mlp, edge_dim=4))

        self.gate = GatedGraphConv(hidden_dim, 3) # 2nd arg is num of layers
        self.bn_gate = nn.BatchNorm1d(hidden_dim)
        self.pool = global_mean_pool
        self.fc = nn.Linear(hidden_dim, num_tasks)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, edge_index, edge_attr, batch):
        # 1st GINE layer
        x = self.conv1(x, edge_index, edge_attr)
        x = nn.Dropout(0.1)(x)

        # GINE layers + skip connections
        for block in self.conv_block:
            h = block(x, edge_index, edge_attr)
            x = (x + h).relu()
            x = nn.Dropout(0.1)(x)
    
        # normalization, dropout, pooling, and final FC
        x = self.bn_gate(x)
        x = self.dropout(x)
        x = self.pool(x, batch)
        return self.fc(x)

## Train the Model

In [20]:
def train_epoch(model, loader, optimizer, criterion, device):
    model.train()
    total_loss = 0
    for batch in loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        if hasattr(batch, 'edge_attr'):
            logits = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
        else:
            logits = model(batch.x, batch.edge_index, batch.batch)
        mask = (batch.y >= 0).float()
        bs, nt = logits.size()
        batch_y = batch.y.view(bs, nt)
        mask = mask.view(bs, nt)
        loss = (criterion(logits, batch_y) * mask).sum() / mask.sum()
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * batch.num_graphs
    return total_loss / len(loader.dataset)

In [21]:
def evaluate(model, loader, device):
    model.eval()
    y, preds = [], []
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            if hasattr(batch, 'edge_attr'):
                logits = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
            else:
                logits = model(batch.x, batch.edge_index, batch.batch)
            y.append(batch.y.cpu())
            preds.append(torch.sigmoid(logits).cpu())
    return torch.cat(preds, dim=0).numpy(), torch.cat(y, dim=0).numpy()

## Main

In [22]:
df = pd.read_csv('raw/tox21.csv').fillna(0).reset_index(drop=True)
smiles_list = df['smiles'].tolist()
train_idx, valid_idx, test_idx = GetMurckoScaffold(smiles_list, 0.8, 0.1, 0.1)

train_ds = Subset(dataset, train_idx)
val_ds   = Subset(dataset, valid_idx)
test_ds  = Subset(dataset, test_idx)

train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)#, sampler=sampler)
val_loader   = DataLoader(val_ds,   batch_size=32, shuffle=True)
test_loader  = DataLoader(test_ds,  batch_size=32)

sample = dataset[0]
in_channels = sample.x.size(1)
num_tasks   = sample.y.size(0)

print("There are", num_tasks, "tasks.")



There are 12 tasks.


In [23]:
df.tail()

Unnamed: 0,NR-AR,NR-AR-LBD,NR-AhR,NR-Aromatase,NR-ER,NR-ER-LBD,NR-PPAR-gamma,SR-ARE,SR-ATAD5,SR-HSE,SR-MMP,SR-p53,mol_id,smiles
8009,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,TOX2725,CCOc1nc2cccc(C(=O)O)c2n1Cc1ccc(-c2ccccc2-c2nnn...
8010,1.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,TOX2370,CC(=O)[C@H]1CC[C@H]2[C@@H]3CCC4=CC(=O)CC[C@]4(...
8011,1.0,1.0,0.0,0.0,1.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,TOX2371,C[C@]12CC[C@H]3[C@@H](CCC4=CC(=O)CC[C@@]43C)[C...
8012,1.0,1.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,1.0,1.0,TOX2377,C[C@]12CC[C@@H]3c4ccc(O)cc4CC[C@H]3[C@@H]1CC[C...
8013,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,TOX2724,COc1ccc2c(c1OC)CN1CCc3cc4c(cc3C1C2)OCO4


In [31]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GINEModel(
    in_channels,
    hidden_dim = 384,
    num_tasks = num_tasks,
    dropout = 0.2,
    n_layers = 3
).to(device)

print(model)

GINEModel(
  (conv1): GINEConv(nn=Sequential(
    (0): Linear(in_features=4, out_features=384, bias=True)
    (1): BatchNorm1d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Linear(in_features=384, out_features=384, bias=True)
  ))
  (conv_block): ModuleList(
    (0-2): 3 x GINEConv(nn=Sequential(
      (0): Linear(in_features=384, out_features=384, bias=True)
      (1): BatchNorm1d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): Linear(in_features=384, out_features=384, bias=True)
    ))
  )
  (gate): GatedGraphConv(384, num_layers=3)
  (bn_gate): BatchNorm1d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc): Linear(in_features=384, out_features=12, bias=True)
  (dropout): Dropout(p=0.2, inplace=False)
)


In [32]:
optimizer = optim.AdamW(
    model.parameters(),
    lr=1e-3,
    weight_decay=1e-5,
)

scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="max",           # we track val AUC, so “max”
    factor=0.5,           # halve the lr
    patience=5,           # after 5 epochs with no AUC gain
    min_lr=1e-6
)

criterion = nn.BCEWithLogitsLoss(reduction='none')

In [33]:
for epoch in range(1, 21):
    loss = train_epoch(model, train_loader, optimizer, criterion, device)
    if loss is None:
        print("Error during training")
        break
    ps_val, ys_val = evaluate(model, val_loader, device)
    ys_val = ys_val.reshape(ps_val.shape)
    aucs = []
    for i in range(ps_val.shape[1]):
        mask = ys_val[:, i] >= 0
        if mask.sum() > 0:
            aucs.append(roc_auc_score(ys_val[mask, i], ps_val[mask, i]))
    scheduler.step(np.mean(aucs)) # we use the mean like below
    print(f"Epoch {epoch:02d} | Loss: {loss:.4f} | Val AUC: {np.mean(aucs):.3f}")

Epoch 01 | Loss: 0.3516 | Val AUC: 0.629
Epoch 02 | Loss: 0.2002 | Val AUC: 0.651
Epoch 03 | Loss: 0.1970 | Val AUC: 0.663
Epoch 04 | Loss: 0.1947 | Val AUC: 0.647
Epoch 05 | Loss: 0.1937 | Val AUC: 0.680
Epoch 06 | Loss: 0.1908 | Val AUC: 0.669
Epoch 07 | Loss: 0.1896 | Val AUC: 0.666
Epoch 08 | Loss: 0.1871 | Val AUC: 0.670
Epoch 09 | Loss: 0.1872 | Val AUC: 0.691
Epoch 10 | Loss: 0.1854 | Val AUC: 0.698
Epoch 11 | Loss: 0.1860 | Val AUC: 0.680
Epoch 12 | Loss: 0.1838 | Val AUC: 0.699
Epoch 13 | Loss: 0.1840 | Val AUC: 0.715
Epoch 14 | Loss: 0.1816 | Val AUC: 0.703
Epoch 15 | Loss: 0.1810 | Val AUC: 0.704
Epoch 16 | Loss: 0.1804 | Val AUC: 0.702
Epoch 17 | Loss: 0.1794 | Val AUC: 0.712
Epoch 18 | Loss: 0.1780 | Val AUC: 0.701
Epoch 19 | Loss: 0.1771 | Val AUC: 0.695
Epoch 20 | Loss: 0.1733 | Val AUC: 0.710


Try hyperparameter tuning on the `hidden_dim`, learning rate, and number of `GINEConv` layers.

In [18]:
import optuna

In [36]:
def objective(trial):
    lr = trial.suggest_loguniform("lr", 5e-5, 5e-3)
    weight_decay = trial.suggest_loguniform("weight_decay", 1e-6, 1e-3)
    hidden_dim = trial.suggest_categorical("hidden_dim", [192, 384, 768])
    dropout = trial.suggest_uniform("dropout", 0.1, 0.3)
    n_layers = trial.suggest_int("n_layers", 3, 4)

    test_model = GINEModel(
        in_channels,
        hidden_dim = hidden_dim,
        num_tasks = num_tasks,
        dropout = dropout,
        n_layers = n_layers
    ).to(device)

    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=lr,
        weight_decay=weight_decay
    )

    for epoch in range(1, 11):
        loss = train_epoch(test_model, train_loader, optimizer, criterion, device)
        
    ps_val, ys_val = evaluate(model, val_loader, device)
    ys_val = ys_val.reshape(ps_val.shape)
    aucs = []
    for i in range(ps_val.shape[1]):
        mask = ys_val[:, i] >= 0
        if mask.sum() > 0:
            aucs.append(roc_auc_score(ys_val[mask, i], ps_val[mask, i]))

    return np.mean(aucs)

In [37]:
study = optuna.create_study(
    direction="maximize",
    sampler=optuna.samplers.TPESampler(),
    pruner=optuna.pruners.MedianPruner()
)

study.optimize(objective, n_trials=20, timeout=3600)

print("Best AUC: ", study.best_value)
print("Best hyperparams: ", study.best_params)

[I 2025-05-25 16:18:40,988] A new study created in memory with name: no-name-cae8272b-676b-4e90-887d-227eeb5071a6
  lr = trial.suggest_loguniform("lr", 5e-5, 5e-3)
  weight_decay = trial.suggest_loguniform("weight_decay", 1e-6, 1e-3)
  dropout = trial.suggest_uniform("dropout", 0.1, 0.3)
[I 2025-05-25 16:19:00,947] Trial 0 finished with value: 0.7174598574712118 and parameters: {'lr': 0.00258709777543951, 'weight_decay': 0.000126524404497115, 'hidden_dim': 768, 'dropout': 0.17285110043010013, 'n_layers': 3}. Best is trial 0 with value: 0.7174598574712118.
  lr = trial.suggest_loguniform("lr", 5e-5, 5e-3)
  weight_decay = trial.suggest_loguniform("weight_decay", 1e-6, 1e-3)
  dropout = trial.suggest_uniform("dropout", 0.1, 0.3)
[I 2025-05-25 16:19:23,786] Trial 1 finished with value: 0.7163709779591758 and parameters: {'lr': 0.00044563857267675697, 'weight_decay': 1.0514300732872546e-06, 'hidden_dim': 768, 'dropout': 0.12717895439668586, 'n_layers': 4}. Best is trial 0 with value: 0.717

Best AUC:  0.7240326163972816
Best hyperparams:  {'lr': 0.0008333090883296305, 'weight_decay': 3.4116326791858185e-05, 'hidden_dim': 384, 'dropout': 0.14275025219273693, 'n_layers': 3}


### Evaluate Best Model

In [40]:
best_model = GINEModel(
    in_channels,
    hidden_dim = study.best_params["hidden_dim"],
    num_tasks = num_tasks,
    dropout = study.best_params["dropout"],
    n_layers = study.best_params["n_layers"]
).to(device)

optimizer = torch.optim.Adam(
    model.parameters(),
    lr=study.best_params["lr"],
    weight_decay=study.best_params["weight_decay"]
)

In [41]:
print(best_model)

GINEModel(
  (conv1): GINEConv(nn=Sequential(
    (0): Linear(in_features=4, out_features=384, bias=True)
    (1): BatchNorm1d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Linear(in_features=384, out_features=384, bias=True)
  ))
  (conv_block): ModuleList(
    (0-2): 3 x GINEConv(nn=Sequential(
      (0): Linear(in_features=384, out_features=384, bias=True)
      (1): BatchNorm1d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): Linear(in_features=384, out_features=384, bias=True)
    ))
  )
  (gate): GatedGraphConv(384, num_layers=3)
  (bn_gate): BatchNorm1d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc): Linear(in_features=384, out_features=12, bias=True)
  (dropout): Dropout(p=0.14275025219273693, inplace=False)
)


In [42]:
for epoch in range(1, 21):
    loss = train_epoch(best_model, train_loader, optimizer, criterion, device)
    ps_val, ys_val = evaluate(best_model, val_loader, device)
    ys_val = ys_val.reshape(ps_val.shape)
    aucs = []
    for i in range(ps_val.shape[1]):
        mask = ys_val[:, i] >= 0
        if mask.sum() > 0:
            aucs.append(roc_auc_score(ys_val[mask, i], ps_val[mask, i]))
    print(f"Epoch {epoch:02d} | Loss: {loss:.4f} | Val AUC: {np.mean(aucs):.3f}")

Epoch 01 | Loss: 0.6977 | Val AUC: 0.516
Epoch 02 | Loss: 0.6978 | Val AUC: 0.522
Epoch 03 | Loss: 0.6981 | Val AUC: 0.515
Epoch 04 | Loss: 0.6978 | Val AUC: 0.504
Epoch 05 | Loss: 0.6972 | Val AUC: 0.512
Epoch 06 | Loss: 0.6976 | Val AUC: 0.513
Epoch 07 | Loss: 0.6983 | Val AUC: 0.529
Epoch 08 | Loss: 0.6972 | Val AUC: 0.524
Epoch 09 | Loss: 0.6968 | Val AUC: 0.523
Epoch 10 | Loss: 0.6977 | Val AUC: 0.507
Epoch 11 | Loss: 0.6982 | Val AUC: 0.523
Epoch 12 | Loss: 0.6979 | Val AUC: 0.515
Epoch 13 | Loss: 0.6972 | Val AUC: 0.538
Epoch 14 | Loss: 0.6971 | Val AUC: 0.516
Epoch 15 | Loss: 0.6971 | Val AUC: 0.509
Epoch 16 | Loss: 0.6977 | Val AUC: 0.528
Epoch 17 | Loss: 0.6971 | Val AUC: 0.524
Epoch 18 | Loss: 0.6970 | Val AUC: 0.518
Epoch 19 | Loss: 0.6971 | Val AUC: 0.540
Epoch 20 | Loss: 0.6968 | Val AUC: 0.523


Somehow did worse...