In [None]:
%pip -q install rdkit-pypi torch_geometric

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m1.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m29.4/29.4 MB[0m [31m34.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m22.7 MB/s[0m eta [36m0:00:00[0m
[?25h

**Note:** the cell below must be run twice for some obscure reason.

In [None]:
import pandas as pd
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Subset

from torch_geometric.data import InMemoryDataset, Data
from torch_geometric.loader import DataLoader
from rdkit.Chem.Scaffolds import MurckoScaffold
from sklearn.metrics import roc_auc_score

In [None]:
!wget wget https://raw.githubusercontent.com/deepchem/deepchem/master/datasets/tox21.csv.gz
!gunzip tox21.csv.gz

--2025-05-21 03:30:53--  http://wget/
Resolving wget (wget)... failed: Name or service not known.
wget: unable to resolve host address ‘wget’
--2025-05-21 03:30:53--  https://raw.githubusercontent.com/deepchem/deepchem/master/datasets/tox21.csv.gz
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 125310 (122K) [application/octet-stream]
Saving to: ‘tox21.csv.gz’


2025-05-21 03:30:53 (3.24 MB/s) - ‘tox21.csv.gz’ saved [125310/125310]

FINISHED --2025-05-21 03:30:53--
Total wall clock time: 0.3s
Downloaded: 1 files, 122K in 0.04s (3.24 MB/s)


In [None]:
!mkdir -p raw
!mv tox21.csv raw/

mv: cannot stat 'tox21.csv': No such file or directory


## Dataset

In [None]:
from rdkit.Chem import MolFromSmiles, MolToSmiles

In [None]:
class Tox21Dataset(InMemoryDataset):
    """In-memory PyG dataset of molecular graphs built from a Tox21 CSV file.

    Expected layout::

        root/
            raw/tox21.csv          (input, must already exist)
            processed/data.pt      (created on first use)

    Each molecule becomes a ``Data`` object with

    * ``x``: one row per atom holding atomic number, degree, formal charge and
      number of radical electrons,
    * ``edge_index``: both directions of every bond,
    * ``y``: a ``[1, num_tasks]`` label row, so that batching yields
      ``[batch_size, num_tasks]``.

    Rows whose SMILES is missing or cannot be parsed by RDKit are skipped, so the
    dataset can be shorter than the CSV.
    """

    # CSV columns that identify a molecule rather than label it.
    # Assumes the DeepChem tox21.csv layout (task columns + mol_id + smiles).
    NON_TASK_COLUMNS = ('smiles', 'mol_id')

    # Value stored for missing labels. It is negative so that a `y >= 0` mask
    # excludes it, and finite so that it cannot turn a masked loss into NaN.
    MISSING_LABEL = -1.0

    # Number of per-atom features produced by `_atom_features`.
    NUM_ATOM_FEATURES = 4

    def __init__(self, root: str = '.', transform=None, pre_transform=None):
        """Build (if necessary) and load the processed dataset.

        Args:
            root: Directory containing ``raw/tox21.csv``.
            transform: Optional per-access transform, forwarded to PyG.
            pre_transform: Optional transform applied once to every graph
                before it is written to disk.
        """
        super().__init__(root, transform, pre_transform)
        # Populates the collated data and slices from the processed file.
        self.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        """Files that must be present in ``root/raw/``."""
        return ['tox21.csv']

    @property
    def processed_file_names(self):
        """Files that ``process`` writes to ``root/processed/``."""
        return ['data.pt']

    def download(self):
        """Fail loudly when the raw CSV is absent.

        PyG only calls this hook when a raw file is missing. Nothing is
        downloaded here, so the only useful action is a clear error.

        Raises:
            FileNotFoundError: always, naming the expected file location.
        """
        raise FileNotFoundError(
            f"Expected the Tox21 CSV at '{self.raw_paths[0]}'; "
            "download it and place it there before creating the dataset."
        )

    @staticmethod
    def _atom_features(atom):
        """Return the numeric feature list for one RDKit atom."""
        return [
            atom.GetAtomicNum(),
            atom.GetDegree(),
            atom.GetFormalCharge(),
            atom.GetNumRadicalElectrons(),
        ]

    def process(self):
        """Convert every parsable CSV row to a graph and save the collated result."""
        df = pd.read_csv(self.raw_paths[0])

        # Only the task columns are labels; identifiers must not end up in `y`.
        task_columns = [c for c in df.columns if c not in self.NON_TASK_COLUMNS]
        labels = (
            df[task_columns]
            .apply(pd.to_numeric, errors='coerce')
            .fillna(self.MISSING_LABEL)
            .to_numpy(dtype='float32')
        )

        data_list = []
        for row_idx, smiles in enumerate(df['smiles']):
            # Empty cells are read as float NaN, which RDKit cannot accept.
            if not isinstance(smiles, str):
                continue
            molecule = MolFromSmiles(smiles, sanitize=True)
            if molecule is None:
                continue

            # reshape keeps a 2-D tensor even for a molecule without atoms.
            x = torch.tensor(
                [self._atom_features(atom) for atom in molecule.GetAtoms()],
                dtype=torch.float32,
            ).reshape(-1, self.NUM_ATOM_FEATURES)

            # Store each bond in both directions (undirected graph).
            edges = []
            for bond in molecule.GetBonds():
                i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
                edges.append([i, j])
                edges.append([j, i])
            # reshape(-1, 2) guarantees shape [2, 0] when there are no bonds.
            edge_index = (
                torch.tensor(edges, dtype=torch.long).reshape(-1, 2).t().contiguous()
            )

            y = torch.tensor(labels[row_idx], dtype=torch.float32).unsqueeze(0)

            data = Data(x=x, edge_index=edge_index, y=y)
            if self.pre_transform is not None:
                data = self.pre_transform(data)
            data_list.append(data)

        # Persist the collated graphs; __init__ reads this file back.
        self.save(data_list, self.processed_paths[0])


In [None]:
# Build the dataset rooted at the working directory (reads ./raw/tox21.csv).
dataset = Tox21Dataset(root='.')

Processing...


TypeError: No registered converter was able to produce a C++ rvalue of type std::basic_string<wchar_t, std::char_traits<wchar_t>, std::allocator<wchar_t> > from this Python object of type numpy.float64

A **Murcko scaffold** is the core framework of a molecule: it is obtained by removing side chains and other non-core components, leaving behind the ring systems and the chains that connect them. Grouping molecules by their scaffold is widely used in medicinal chemistry and drug design to identify core structures that have preferential activity against specific targets. This makes scaffold-based grouping very useful in **molecular property prediction**, where it is used here to split the dataset so that molecules sharing a scaffold end up in the same split.

In [None]:
def GetMurckoScaffold(data, train_frac: float, val_frac: float, test_frac: float):
    """Split molecule positions into train/valid/test lists by Murcko scaffold.

    Molecules that share a scaffold SMILES always land in the same split.
    Scaffold groups are visited from largest to smallest; a group goes to the
    training split if it still fits within ``train_frac`` of the data,
    otherwise to the validation split if it fits within ``val_frac``,
    otherwise to the test split.

    Args:
        data: Sized iterable of SMILES strings.
        train_frac: Fraction of ``len(data)`` the training split may hold.
        val_frac: Fraction of ``len(data)`` the validation split may hold.
        test_frac: Not read by this function; the test split receives every
            group that fits in neither of the other two.

    Returns:
        Tuple ``(train_idx, valid_idx, test_idx)`` of lists of integer
        positions into ``data``.
    """
    # Map each scaffold SMILES to the positions of the molecules that have it.
    members_by_scaffold = {}
    for position, smi in enumerate(data):
        key = MurckoScaffold.MurckoScaffoldSmiles(smi)
        if key not in members_by_scaffold:
            members_by_scaffold[key] = []
        members_by_scaffold[key].append(position)

    # Largest groups first; the sort is stable, so ties keep first-seen order.
    ordered_groups = list(members_by_scaffold.values())
    ordered_groups.sort(key=len, reverse=True)

    total = len(data)
    train_budget = int(train_frac * total)
    valid_budget = int(val_frac * total)

    train_idx, valid_idx, test_idx = [], [], []
    for members in ordered_groups:
        size = len(members)
        if len(train_idx) + size <= train_budget:
            destination = train_idx
        elif len(valid_idx) + size <= valid_budget:
            destination = valid_idx
        else:
            destination = test_idx
        destination.extend(members)

    return train_idx, valid_idx, test_idx

## GNN Model

In [None]:
from torch_geometric.nn import GINEConv, GatedGraphConv, global_mean_pool

In [None]:
class GINEModel(nn.Module):
    """Graph-level multi-task classifier: two GINE layers, a gated graph layer,
    mean pooling and a linear head.

    Args:
        input_dim: Number of features per node.
        hidden_dim: Width of all hidden representations.
        num_tasks: Number of output logits per graph.
        edge_dim: Number of features per edge. When given, each GINE layer
            projects edge features to its own input width. When ``None``,
            edge features must already match the node feature width of the
            layer they are fed to.
        dropout: Dropout probability applied after the second GINE layer.
    """

    def __init__(self, input_dim, hidden_dim, num_tasks, edge_dim=None, dropout=0.2):
        super().__init__()
        self.edge_dim = edge_dim
        self.conv1 = GINEConv(self._make_mlp(input_dim, hidden_dim), edge_dim=edge_dim)
        self.conv2 = GINEConv(self._make_mlp(hidden_dim, hidden_dim), edge_dim=edge_dim)
        self.gate = GatedGraphConv(hidden_dim, 2)
        # Registered as a submodule so model.eval() switches it off; a Dropout
        # created inside forward() would always be in training mode.
        self.dropout = nn.Dropout(dropout)
        self.pool = global_mean_pool
        self.fc = nn.Linear(hidden_dim, num_tasks)

    @staticmethod
    def _make_mlp(in_dim, out_dim):
        """Build the two-layer MLP used inside each GINE layer."""
        return nn.Sequential(
            nn.Linear(in_dim, out_dim),
            nn.BatchNorm1d(out_dim),
            nn.ReLU(),
            nn.Linear(out_dim, out_dim),
        )

    def _edge_features(self, x, edge_index, edge_attr):
        """Return ``edge_attr``, or all-zero edge features when it is ``None``.

        The zero tensor has the width the next GINE layer expects: ``edge_dim``
        if one was configured, otherwise the current node feature width.
        """
        if edge_attr is not None:
            return edge_attr
        width = x.size(-1) if self.edge_dim is None else self.edge_dim
        return x.new_zeros(edge_index.size(1), width)

    def forward(self, x, edge_index, edge_attr, batch):
        """Compute per-graph logits.

        Args:
            x: Node feature matrix.
            edge_index: Edge list in PyG ``[2, num_edges]`` layout.
            edge_attr: Edge feature matrix, or ``None`` for graphs without
                edge features (zeros are substituted).
            batch: Graph assignment of every node, as passed to the pooling op.

        Returns:
            Tensor of logits with one row per graph and ``num_tasks`` columns.
        """
        x = self.conv1(x, edge_index, self._edge_features(x, edge_index, edge_attr)).relu()
        x = self.conv2(x, edge_index, self._edge_features(x, edge_index, edge_attr)).relu()
        x = self.dropout(x)
        x = self.gate(x, edge_index).relu()
        x = self.pool(x, batch)
        return self.fc(x)

## Train the Model

In [None]:
def train_epoch(model, loader, optimizer, criterion, device):
    """Run one training epoch and return the mean masked loss per graph.

    Labels that are negative or NaN are treated as missing and excluded from
    the loss.

    Args:
        model: Module called as ``model(x, edge_index, edge_attr, batch)``.
        loader: Iterable of batches exposing ``x``, ``edge_index``, ``batch``,
            ``y``, ``num_graphs`` and ``to(device)``; must also expose
            ``dataset`` with a length.
        optimizer: Optimizer over the model parameters.
        criterion: Element-wise loss (i.e. created with ``reduction='none'``)
            so that the mask can be applied per label.
        device: Device the batches are moved to.

    Returns:
        Sum of per-batch losses weighted by ``num_graphs``, divided by
        ``len(loader.dataset)``.
    """
    model.train()
    total_loss = 0.0
    for batch in loader:
        batch = batch.to(device)
        optimizer.zero_grad()

        # Always pass four arguments; graphs without edge features give None.
        edge_attr = getattr(batch, 'edge_attr', None)
        logits = model(batch.x, batch.edge_index, edge_attr, batch.batch)

        # Align labels with the logits in case the batch stores them flattened.
        targets = batch.y.reshape(logits.shape)
        valid = targets >= 0  # False for negative placeholders and for NaN
        n_valid = valid.sum()
        if n_valid == 0:
            # No usable label: 0/0 would yield NaN gradients, so skip the step.
            continue

        # Replace missing entries before the loss; NaN * 0 would still be NaN.
        safe_targets = torch.where(valid, targets, torch.zeros_like(targets))
        loss = (criterion(logits, safe_targets) * valid.float()).sum() / n_valid

        loss.backward()
        optimizer.step()
        total_loss += loss.item() * batch.num_graphs
    return total_loss / len(loader.dataset)

In [None]:
def evaluate(model, loader, device):
    """Collect sigmoid predictions and labels over a loader without gradients.

    Args:
        model: Module called as ``model(x, edge_index, edge_attr, batch)``.
        loader: Iterable of batches exposing ``x``, ``edge_index``, ``batch``,
            ``y`` and ``to(device)``.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(preds, y)`` of NumPy arrays, predictions first. Labels are
        reshaped to the shape of the logits so both arrays line up row by row.
        Missing-label placeholders are returned unchanged; callers must mask
        them before computing metrics. Two empty arrays are returned for an
        empty loader.
    """
    model.eval()
    targets, preds = [], []
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            # Always pass four arguments; graphs without edge features give None.
            edge_attr = getattr(batch, 'edge_attr', None)
            logits = model(batch.x, batch.edge_index, edge_attr, batch.batch)
            targets.append(batch.y.reshape(logits.shape).cpu())
            preds.append(torch.sigmoid(logits).cpu())

    if not preds:
        # torch.cat refuses an empty list; report "nothing evaluated" instead.
        return torch.empty(0).numpy(), torch.empty(0).numpy()
    return torch.cat(preds, dim=0).numpy(), torch.cat(targets, dim=0).numpy()

## Main

In [None]:
# Read the CSV from raw/, where the setup cell moved it and where Tox21Dataset reads it.
df = pd.read_csv('raw/tox21.csv')

# Tox21Dataset skips rows that RDKit cannot parse, so its positions only match the
# positions in this list if the same rows are left out here. Non-string entries
# (empty cells are read as NaN) cannot be parsed either and are left out as well.
smiles_list = [
    smiles
    for smiles in df['smiles']
    if isinstance(smiles, str) and MolFromSmiles(smiles) is not None
]
train_idx, valid_idx, test_idx = GetMurckoScaffold(smiles_list, 0.8, 0.1, 0.1)



In [None]:
# Leftover inspection cell: evaluating the bare name only displays the class object.
Tox21Dataset

In [None]:
# The constructor takes the dataset root directory (it has no `path` argument);
# the CSV is looked up at <root>/raw/tox21.csv.
dataset = Tox21Dataset(root='.')

# The split indices are positions into the dataset, so together they must cover it
# exactly; a mismatch means the split was computed on a differently filtered list.
n_split = len(train_idx) + len(valid_idx) + len(test_idx)
assert n_split == len(dataset), (
    f"split covers {n_split} molecules but the dataset holds {len(dataset)}"
)

train_ds = Subset(dataset, train_idx)
val_ds   = Subset(dataset, valid_idx)
test_ds  = Subset(dataset, test_idx)

# Only the training data is shuffled; evaluation order stays fixed.
train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)
val_loader   = DataLoader(val_ds,   batch_size=32)
test_loader  = DataLoader(test_ds,  batch_size=32)

NotImplementedError: 