# Tox21 GINENet

We’re solving a multi-task binary toxicity prediction problem on Tox21. Concretely, for each molecule the model outputs 12 probabilities, one for each assay (e.g. NR-AR, SR-ARE, p53, etc.). At training time we use a binary cross-entropy loss (with masking for missing labels) over those 12 tasks, and at the end of each epoch we compute the ROC-AUC per task (then average) on the held-out validation set to see how well the model is distinguishing actives vs. inactives across all assays.

In [3]:
%pip -q install rdkit-pypi torch_geometric

Note: you may need to restart the kernel to use updated packages.


Cell below must be run twice for some obscure reason

In [4]:
import pandas as pd
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Subset

from torch_geometric.data import InMemoryDataset, Data
from torch_geometric.loader import DataLoader
from rdkit.Chem.Scaffolds import MurckoScaffold
from sklearn.metrics import roc_auc_score

In [5]:
!wget wget https://raw.githubusercontent.com/deepchem/deepchem/master/datasets/tox21.csv.gz
!gunzip tox21.csv.gz

--2025-05-24 19:49:40--  http://wget/
Resolving wget (wget)... failed: Name or service not known.
wget: unable to resolve host address ‘wget’
--2025-05-24 19:49:40--  https://raw.githubusercontent.com/deepchem/deepchem/master/datasets/tox21.csv.gz
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.108.133, 185.199.111.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 125310 (122K) [application/octet-stream]
Saving to: ‘tox21.csv.gz’


2025-05-24 19:49:40 (5.18 MB/s) - ‘tox21.csv.gz’ saved [125310/125310]

FINISHED --2025-05-24 19:49:40--
Total wall clock time: 0.3s
Downloaded: 1 files, 122K in 0.02s (5.18 MB/s)


In [6]:
!mkdir -p raw
!mv tox21.csv raw/

## Dataset

In [7]:
from rdkit.Chem import MolFromSmiles, MolToSmiles, rdchem

In [8]:
import os

from torch.serialization import safe_globals, add_safe_globals
from torch_geometric.data.data import DataTensorAttr, DataEdgeAttr
from torch_geometric.data.storage import GlobalStorage

In [9]:
add_safe_globals([DataTensorAttr, DataEdgeAttr, GlobalStorage])

#### Tox21 Dataset Description
**Nuclear Receptor (NR) assays (7 endpoints)**

1. **NR-AR** (Androgen Receptor)  
   - **Target**: Full-length human AR  
   - **Cell line**: MDA-kb2 breast cancer cells expressing luciferase under an AR-responsive promoter  
   - **Readout**: Luminescence upon AR agonist binding, measures agonist activity  

2. **NR-AR-LBD** (Androgen Receptor Ligand-Binding Domain)  
   - **Target**: AR ligand-binding domain fused to a transcriptional activator  
   - **Assay format**: β-lactamase reporter in HEK293-derived cells  
   - **Readout**: β-lactamase signal indicating direct LBD binding  

3. **NR-ER** (Estrogen Receptor α)  
   - **Target**: Full-length human ER α  
   - **Cell line**: BG1 ovarian carcinoma cells with ER-responsive luciferase reporter  
   - **Readout**: Luminescence upon estrogen-like agonism  

4. **NR-ER-LBD** (Estrogen Receptor α Ligand-Binding Domain)  
   - **Target**: ER α LBD fused to a DNA-binding domain  
   - **Cell line**: GAL4-ER α LBD reporter in HEK293-derived cells  
   - **Readout**: Reporter activation upon direct LBD engagement  

5. **NR-AhR** (Aryl Hydrocarbon Receptor)  
   - **Target**: Human AhR  
   - **Cell line**: HepG2 cells with DRE-driven luciferase reporter  
   - **Readout**: Luminescence when ligands (e.g. dioxin analogs) activate AhR  

6. **NR-Aromatase**  
   - **Target**: CYP19A1 (Aromatase enzyme)  
   - **Format**: Cell-free or microsomal conversion of testosterone to estradiol  
   - **Readout**: Fluorescent or luminescent detection of estradiol, measures inhibition of estrogen synthesis  

7. **NR-PPAR-γ** (Peroxisome Proliferator-Activated Receptor γ)  
   - **Target**: Human PPAR γ  
   - **Cell line**: PPRE-driven luciferase reporter in CV-1 or HEK293 cells  
   - **Readout**: Reporter activation by PPAR γ agonists  

---

### Stress Response (SR) assays (5 endpoints)

8. **SR-ARE** (Antioxidant Response Element)  
   - **Pathway**: Nrf2-ARE oxidative stress response  
   - **Cell line**: HepG2 cells with ARE-luciferase reporter  
   - **Readout**: Luminescence when oxidative-stress inducers activate the pathway  

9. **SR-ATAD5** (ATAD5 DNA Damage Response)  
   - **Target**: ATAD5 promoter stability  
   - **Cell line**: HEK293 cells with luciferase-tagged ATAD5  
   - **Readout**: Reporter stabilization upon genotoxic stress  

10. **SR-HSE** (Heat Shock Element)  
    - **Pathway**: HSF1-mediated heat shock response  
    - **Cell line**: Neuro-2a or HEK cells with HSE-driven reporter  
    - **Readout**: Luminescence when proteotoxic stress activates HSF1  

11. **SR-MMP** (Mitochondrial Membrane Potential)  
    - **Assay format**: Dye uptake (e.g. TMRE or Mito-MPS) in hepatocytes  
    - **Readout**: Fluorescent ratio indicating mitochondrial depolarization  

12. **SR-p53** (p53 Response)  
    - **Pathway**: p53 tumor suppressor activation  
    - **Cell line**: HepG2 cells with p53-responsive luciferase reporter  
    - **Readout**: Luminescence when DNA damage or stress stabilizes p53  


In [12]:
class Tox21Dataset(InMemoryDataset):
    def __init__(self, root: str = '.', transform=None, pre_transform=None):
        """
        Expects:
        root/
            raw/tox21.csv
        Will create:
            processed/data.pt
        """
        super().__init__(root, transform, pre_transform)
        # Load the processed data
        with safe_globals([DataTensorAttr, DataEdgeAttr, GlobalStorage]):
            self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        # File expected in root/raw/
        return ['tox21.csv']

    @property
    def processed_file_names(self):
        return ['data.pt']

    def download(self):
        # No download step needed; CSV is already in place
        return

    def process(self):
        df = pd.read_csv(self.raw_paths[0]).dropna()
        if "mol_id" in df.columns:
            df = df.drop(columns=["mol_id"])

        # 3) convert all non-smiles columns to numeric
        non_smiles = [c for c in df.columns if c != "smiles"]
        df[non_smiles] = df[non_smiles].apply(pd.to_numeric, errors="coerce")

        # 4) drop any row with NaN in smiles or any label
        df = df.dropna(how="any", axis=0).reset_index(drop=True)

        data_list = []
        bond_types = [
            rdchem.BondType.SINGLE,
            rdchem.BondType.DOUBLE,
            rdchem.BondType.TRIPLE,
            rdchem.BondType.AROMATIC,
        ]

        for _, row in df.iterrows():
            smiles = row["smiles"]
            mol    = MolFromSmiles(smiles, sanitize=True)
            if mol is None:
                continue

            # node features x
            atom_feats = [
                [
                    a.GetAtomicNum(),
                    a.GetDegree(),
                    a.GetFormalCharge(),
                    a.GetNumRadicalElectrons(),
                ]
                for a in mol.GetAtoms()
            ]
            x = torch.tensor(atom_feats, dtype=torch.float32)

            # build edge_index AND edge_attr in lock‐step
            edge_index = []
            edge_attr  = []
            for bond in mol.GetBonds():
                i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
                # one‐hot encode this bond’s type
                bt = bond.GetBondType()
                feat = [int(bt == t) for t in bond_types]

                # add both directions
                edge_index += [[i, j], [j, i]]
                edge_attr  += [feat, feat]

            edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
            edge_attr  = torch.tensor(edge_attr, dtype=torch.float32)  # [2E, len(bond_types)]

            # your labels
            y = torch.tensor(row[[c for c in df.columns if c!="smiles"]]
                             .tolist(), dtype=torch.float32)

            data_list.append(Data(
                x=x,
                edge_index=edge_index,
                edge_attr=edge_attr,
                y=y,
            ))

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])


In [16]:
dataset = Tox21Dataset('.')

A **Murcko Scaffold** is a technique used to group molecules based on their core structural components. It identifies the essential building blocks by removing side chains and other non-core components, leaving behind the important ring systems and connecting chains. This method is widely used in medicinal chemistry and drug design to identify core structures that have preferential activity against specific targets, which is very useful in **molecular property prediction**.

In [17]:
def GetMurckoScaffold(data, train_frac: float, val_frac: float, test_frac: float):
    scaffold_indices = {}
    for idx, smiles in enumerate(data):
        scaffold_smiles = MurckoScaffold.MurckoScaffoldSmiles(smiles)
        #scaffold_smiles = MolToSmiles(scaffold)
        scaffold_indices.setdefault(scaffold_smiles, []).append(idx)

    groups = sorted(scaffold_indices.values(), key=len, reverse=True)
    n_total = len(data)
    n_train = int(train_frac * n_total)
    n_valid = int(val_frac * n_total)

    train_idx, valid_idx, test_idx = [], [], []
    for group in groups:
        if len(train_idx) + len(group) <= n_train:
            train_idx.extend(group)
        elif len(valid_idx) + len(group) <= n_valid:
            valid_idx.extend(group)
        else:
            test_idx.extend(group)

    return train_idx, valid_idx, test_idx

## GNN Model

In [18]:
from torch_geometric.nn import GINEConv, GatedGraphConv, global_mean_pool

In [39]:
class GINEModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_tasks):
        super().__init__()
        self.conv1 = GINEConv(
            nn.Sequential(
                nn.Linear(input_dim, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim),
            ),
            edge_dim = 4,
        )
        self.conv2 = GINEConv(
            nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim),
            ),
            edge_dim = 4
        )
        self.conv3 = GINEConv(
            nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim)
            ),
            edge_dim = 4
        )
        self.gate = GatedGraphConv(hidden_dim, 12)
        self.bn_gate = nn.BatchNorm1d(hidden_dim)
        self.pool = global_mean_pool
        self.fc = nn.Linear(hidden_dim, num_tasks)
        self.dropout = nn.Dropout(0.2)

    # Maybe add skip connections 
    def forward(self, x, edge_index, edge_attr, batch):
        # 1st GINE layer + residual
        h1 = self.conv1(x, edge_index, edge_attr)
        x = (x + h1).relu()
        x = nn.Dropout(0.1)(x)
    
        # 2nd GINE layer + residual
        h2 = self.conv2(x, edge_index, edge_attr)
        x = (x + h2).relu()
        x = nn.Dropout(0.1)(x)
    
        # 3rd GINE layer + residual
        h3 = self.conv3(x, edge_index, edge_attr)
        x = (x + h3).relu()
    
        # normalization, dropout, pooling, and final FC
        x = self.bn_gate(x)
        x = self.dropout(x)
        x = self.pool(x, batch)
        return self.fc(x)

## Train the Model

In [40]:
def train_epoch(model, loader, optimizer, criterion, device):
    model.train()
    total_loss = 0
    for batch in loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        if hasattr(batch, 'edge_attr'):
            logits = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
        else:
            logits = model(batch.x, batch.edge_index, batch.batch)
        mask = (batch.y >= 0).float()
        bs, nt = logits.size()
        batch_y = batch.y.view(bs, nt)
        mask = mask.view(bs, nt)
        loss = (criterion(logits, batch_y) * mask).sum() / mask.sum()
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * batch.num_graphs
    return total_loss / len(loader.dataset)

In [41]:
def evaluate(model, loader, device):
    model.eval()
    y, preds = [], []
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            if hasattr(batch, 'edge_attr'):
                logits = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
            else:
                logits = model(batch.x, batch.edge_index, batch.batch)
            y.append(batch.y.cpu())
            preds.append(torch.sigmoid(logits).cpu())
    return torch.cat(preds, dim=0).numpy(), torch.cat(y, dim=0).numpy()

## Main

In [42]:
df = pd.read_csv('raw/tox21.csv').dropna()
smiles_list = df['smiles'].tolist()
train_idx, valid_idx, test_idx = GetMurckoScaffold(smiles_list, 0.8, 0.1, 0.1)

In [43]:
train_ds = Subset(dataset, train_idx)
val_ds   = Subset(dataset, valid_idx)
test_ds  = Subset(dataset, test_idx)

train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)
val_loader   = DataLoader(val_ds,   batch_size=32)
test_loader  = DataLoader(test_ds,  batch_size=32)

In [44]:
sample = dataset[0]
in_channels = sample.x.size(1)
num_tasks   = sample.y.size(0)

In [45]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [46]:
model = GINEModel(in_channels, 384, num_tasks).to(device)

print(model)

GINEModel(
  (conv1): GINEConv(nn=Sequential(
    (0): Linear(in_features=4, out_features=384, bias=True)
    (1): BatchNorm1d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Linear(in_features=384, out_features=384, bias=True)
  ))
  (conv2): GINEConv(nn=Sequential(
    (0): Linear(in_features=384, out_features=384, bias=True)
    (1): BatchNorm1d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Linear(in_features=384, out_features=384, bias=True)
  ))
  (conv3): GINEConv(nn=Sequential(
    (0): Linear(in_features=384, out_features=384, bias=True)
    (1): BatchNorm1d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Linear(in_features=384, out_features=384, bias=True)
  ))
  (gate): GatedGraphConv(384, num_layers=12)
  (bn_gate): BatchNorm1d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc): Linear(in_features=384, out_featur

In [47]:
optimizer = optim.AdamW(
    model.parameters(),
    lr=1e-3,
    weight_decay=1e-5,
)

criterion = nn.BCEWithLogitsLoss(reduction='none')

In [48]:
for epoch in range(1, 21):
    loss = train_epoch(model, train_loader, optimizer, criterion, device)
    if loss is None:
        print("Error during training")
        break
    ps_val, ys_val = evaluate(model, val_loader, device)
    ys_val = ys_val.reshape(ps_val.shape)
    aucs = []
    for i in range(ps_val.shape[1]):
        mask = ys_val[:, i] >= 0
        if mask.sum() > 0:
            aucs.append(roc_auc_score(ys_val[mask, i], ps_val[mask, i]))
    print(f"Epoch {epoch:02d} | Loss: {loss:.4f} | Val AUC: {np.mean(aucs):.3f}")

RuntimeError: The size of tensor a (4) must match the size of tensor b (384) at non-singleton dimension 1