<a href="https://colab.research.google.com/github/Raiden-Makoto/ChemicalsRTasty/blob/main/ChemicalsRTasty.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%pip -q install rdkit-pypi torch_geometric

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m1.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m29.4/29.4 MB[0m [31m21.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m22.0 MB/s[0m eta [36m0:00:00[0m
[?25h

Cell below must be run twice for some obscure reason

In [3]:
import pandas as pd
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Subset

from torch_geometric.data import InMemoryDataset, Data
from torch_geometric.loader import DataLoader
from rdkit.Chem.Scaffolds import MurckoScaffold
from sklearn.metrics import roc_auc_score

In [4]:
!wget wget https://raw.githubusercontent.com/deepchem/deepchem/master/datasets/tox21.csv.gz
!gunzip tox21.csv.gz

--2025-05-22 01:28:16--  http://wget/
Resolving wget (wget)... failed: Name or service not known.
wget: unable to resolve host address ‘wget’
--2025-05-22 01:28:16--  https://raw.githubusercontent.com/deepchem/deepchem/master/datasets/tox21.csv.gz
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.108.133, 185.199.109.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 125310 (122K) [application/octet-stream]
Saving to: ‘tox21.csv.gz’


2025-05-22 01:28:16 (12.2 MB/s) - ‘tox21.csv.gz’ saved [125310/125310]

FINISHED --2025-05-22 01:28:16--
Total wall clock time: 0.3s
Downloaded: 1 files, 122K in 0.01s (12.2 MB/s)


In [5]:
!mkdir -p raw
!mv tox21.csv raw/

## Dataset

In [6]:
from rdkit.Chem import MolFromSmiles, MolToSmiles

In [26]:
import os

from torch.serialization import safe_globals, add_safe_globals
from torch_geometric.data.data import DataTensorAttr, DataEdgeAttr
from torch_geometric.data.storage import GlobalStorage

In [28]:
add_safe_globals([DataTensorAttr, DataEdgeAttr, GlobalStorage])

In [29]:
class Tox21Dataset(InMemoryDataset):
    def __init__(self, root: str = '.', transform=None, pre_transform=None):
        """
        Expects:
        root/
            raw/tox21.csv
        Will create:
            processed/data.pt
        """
        super().__init__(root, transform, pre_transform)
        # Load the processed data
        with safe_globals([DataTensorAttr, DataEdgeAttr, GlobalStorage]):
            self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        # File expected in root/raw/
        return ['tox21.csv']

    @property
    def processed_file_names(self):
        return ['data.pt']

    def download(self):
        # No download step needed; CSV is already in place
        return

    def process(self):
        df = pd.read_csv(self.raw_paths[0])
        if "mol_id" in df.columns:
            df = df.drop(columns=["mol_id"])

        # 3) convert all non-smiles columns to numeric
        non_smiles = [c for c in df.columns if c != "smiles"]
        df[non_smiles] = df[non_smiles].apply(pd.to_numeric, errors="coerce")

        # 4) drop any row with NaN in smiles or any label
        df = df.dropna(how="any", axis=0).reset_index(drop=True)

        # 5) prepare list of label columns
        label_cols = non_smiles

        data_list = []
        for _, row in df.iterrows():
            # a) get the SMILES string
            smiles = row["smiles"]
            molecule = MolFromSmiles(smiles, sanitize=True)
            if not molecule:
                continue

            # b) build atom features & edge_index (as before)…
            atom_features = [
                [
                    atom.GetAtomicNum(),
                    atom.GetDegree(),
                    atom.GetFormalCharge(),
                    atom.GetNumRadicalElectrons(),
                ]
                for atom in molecule.GetAtoms()
            ]
            x = torch.tensor(atom_features, dtype=torch.float32)

            edge_index = []
            for bond in molecule.GetBonds():
                i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
                edge_index.extend([[i, j], [j, i]])
            edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()

            # c) pull out just the label columns, now all float dtype
            y_vals = row[label_cols].values.astype(float)  # a float64 array
            y = torch.tensor(y_vals, dtype=torch.float32)

            data_list.append(Data(x=x, edge_index=edge_index, y=y))

        data, slices = self.collate(data_list)
        os.makedirs(self.processed_dir, exist_ok=True)
        torch.save((data, slices), self.processed_paths[0])
        #return data, slices


In [30]:
dataset = Tox21Dataset('.')

A **Murcko Scaffold** is a technique used to group molecules based on their core structural components. It identifies the essential building blocks by removing side chains and other non-core components, leaving behind the important ring systems and connecting chains. This method is widely used in medicinal chemistry and drug design to identify core structures that have preferential activity against specific targets, which is very useful in **molecular property prediction**.

In [31]:
def GetMurckoScaffold(data, train_frac: float, val_frac: float, test_frac: float):
    scaffold_indices = {}
    for idx, smiles in enumerate(data):
        scaffold_smiles = MurckoScaffold.MurckoScaffoldSmiles(smiles)
        #scaffold_smiles = MolToSmiles(scaffold)
        scaffold_indices.setdefault(scaffold_smiles, []).append(idx)

    groups = sorted(scaffold_indices.values(), key=len, reverse=True)
    n_total = len(data)
    n_train = int(train_frac * n_total)
    n_valid = int(val_frac * n_total)

    train_idx, valid_idx, test_idx = [], [], []
    for group in groups:
        if len(train_idx) + len(group) <= n_train:
            train_idx.extend(group)
        elif len(valid_idx) + len(group) <= n_valid:
            valid_idx.extend(group)
        else:
            test_idx.extend(group)

    return train_idx, valid_idx, test_idx

## GNN Model

In [32]:
from torch_geometric.nn import GINEConv, GatedGraphConv, global_mean_pool

In [33]:
class GINEModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_tasks):
        super().__init__()
        self.conv1 = GINEConv(
            nn.Sequential(
                nn.Linear(input_dim, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim),
            )
        )
        self.conv2 = GINEConv(
            nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim),
            )
        )
        self.gate = GatedGraphConv(hidden_dim, 2)
        self.pool = global_mean_pool
        self.fc = nn.Linear(hidden_dim, num_tasks)

    def forward(self, x, edge_index, edge_attr, batch):
        x = self.conv1(x, edge_index, edge_attr).relu()
        x = self.conv2(x, edge_index, edge_attr).relu()
        x = nn.Dropout(0.2)(x)
        x = self.gate(x, edge_index).relu()
        x = self.pool(x, batch)
        return self.fc(x)

## Train the Model

In [34]:
def train_epoch(model, loader, optimizer, criterion, device):
    model.train()
    total_loss = 0
    for batch in loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        if hasattr(batch, 'edge_attr'):
            logits = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
        else:
            logits = model(batch.x, batch.edge_index, batch.batch)
        mask = (batch.y >= 0).float()
        loss = (criterion(logits, batch.y) * mask).sum() / mask.sum()
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * batch.num_graphs
    return total_loss / len(loader.dataset)

In [35]:
def evaluate(model, loader, device):
    model.eval()
    y, preds = [], []
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            if hasattr(batch, 'edge_attr'):
                logits = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
            else:
                logits = model(batch.x, batch.edge_index, batch.batch)
            y.append(batch.y.cpu())
            preds.append(torch.sigmoid(logits).cpu())
    return torch.cat(preds, dim=0).numpy(), torch.cat(y, dim=0).numpy()

## Main

In [37]:
df = pd.read_csv('raw/tox21.csv').dropna(inplace=True)
smiles_list = df['smiles'].tolist()
train_idx, valid_idx, test_idx = GetMurckoScaffold(smiles_list, 0.8, 0.1, 0.1)



In [38]:
df.head()

Unnamed: 0,NR-AR,NR-AR-LBD,NR-AhR,NR-Aromatase,NR-ER,NR-ER-LBD,NR-PPAR-gamma,SR-ARE,SR-ATAD5,SR-HSE,SR-MMP,SR-p53,mol_id,smiles
0,0.0,0.0,1.0,,,0.0,0.0,1.0,0.0,0.0,0.0,0.0,TOX3021,CCOc1ccc2nc(S(N)(=O)=O)sc2c1
1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,,0.0,,0.0,0.0,TOX3020,CCN1C(=O)NC(c2ccccc2)C1=O
2,,,,,,,,0.0,,0.0,,,TOX3024,CC[C@]1(O)CC[C@H]2[C@@H]3CCC4=CCCC[C@@H]4[C@H]...
3,0.0,0.0,0.0,0.0,0.0,0.0,0.0,,0.0,,0.0,0.0,TOX3027,CCCN(CC)C(CC)C(=O)Nc1c(C)cccc1C
4,0.0,0.0,,0.0,0.0,0.0,0.0,0.0,0.0,0.0,,0.0,TOX3028,CC(O)(P(=O)(O)O)P(=O)(O)O


In [39]:
train_ds = Subset(dataset, train_idx)
val_ds   = Subset(dataset, valid_idx)
test_ds  = Subset(dataset, test_idx)

train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)
val_loader   = DataLoader(val_ds,   batch_size=32)
test_loader  = DataLoader(test_ds,  batch_size=32)

In [41]:
sample = dataset[0]
in_channels = sample.x.size(1)
num_tasks   = sample.y.size(0)

In [43]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cpu')

In [45]:
model = GINEModel(in_channels, 300, num_tasks).to(device)

optimizer = optim.AdamW(
    model.parameters(),
    lr=1e-3,
    weight_decay=1e-5,
)

criterion = nn.BCEWithLogitsLoss(reduction='none')

In [46]:
for epoch in range(1, 11):
    loss = train_epoch(model, train_loader, optimizer, criterion, device)
    ys_val, ps_val = evaluate(model, val_loader, device)
    aucs = []
    for i in range(ps_val.size(1)):
        mask = ys_val[:, i] >= 0
        if mask.sum() > 0:
            aucs.append(roc_auc_score(ys_val[mask, i], ps_val[mask, i]))
    print(f"Epoch {epoch:02d} | Loss: {loss:.4f} | Val AUC: {np.mean(aucs):.3f}")

IndexError: range object index out of range