https://ochem.eu/static/challenge.do

# TOX24

In [20]:
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import numpy as np
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import GATConv, global_mean_pool
from rdkit import Chem
from sklearn.metrics import mean_squared_error, r2_score

In [21]:
def smiles_to_graph(smiles):
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None

    x = torch.tensor([[atom.GetAtomicNum()] for atom in mol.GetAtoms()], dtype=torch.float)

    edge_index = []
    edge_attr = []
    for bond in mol.GetBonds():
        i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        bond_feature = [bond.GetBondTypeAsDouble()] 

        edge_index.append((i, j))
        edge_index.append((j, i))
        edge_attr.append(bond_feature)
        edge_attr.append(bond_feature)

    edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous() if edge_index else torch.empty((2, 0), dtype=torch.long)
    edge_attr = torch.tensor(edge_attr, dtype=torch.float) if edge_attr else torch.zeros((edge_index.shape[1], 1), dtype=torch.float)

    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr)

In [31]:
class GATNet(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, dropout):
        super(GATNet, self).__init__()
        self.conv1 = GATConv(in_channels, hidden_channels, heads=10, dropout=dropout)
        self.conv2 = GATConv(hidden_channels * 10, hidden_channels, heads=5, dropout=dropout)
        self.fc = nn.Linear(hidden_channels * 5, out_channels)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        edge_attr = data.edge_attr if hasattr(data, "edge_attr") else None 

        x = self.conv1(x, edge_index, edge_attr)
        x = torch.relu(x)
        x = self.conv2(x, edge_index, edge_attr)
        x = torch.relu(x)

        x = global_mean_pool(x, batch)  #Aggregate node features
        return self.fc(x)

In [23]:
datafile = pd.read_csv("tox24_challenge_train.csv")
smiles_list = datafile["SMILES"].to_list()
activity_values = datafile["activity"].to_list()

In [24]:
graph_data = []
for s, y in zip(smiles_list, activity_values):
    graph = smiles_to_graph(s)
    if graph is not None:
        graph.y = torch.tensor([y], dtype=torch.float)
        graph_data.append(graph)

dataloader = DataLoader(graph_data, batch_size=100, shuffle=True)



In [32]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = GATNet(in_channels=graph_data[0].x.shape[1], hidden_channels=512, out_channels=1, dropout=0.1).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.01, weight_decay=1e-5)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=7, verbose=True)
loss_fn = nn.MSELoss()


epochs = 100
for epoch in range(epochs):
    total_loss = 0
    all_preds, all_targets = [], []
    
    model.train()
    for batch in dataloader:
        batch = batch.to(device)
        optimizer.zero_grad()
        
        output = model(batch).squeeze() 
        loss = loss_fn(output, batch.y.to(device))
        
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

        all_preds.extend(output.cpu().detach().numpy())
        all_targets.extend(batch.y.cpu().numpy())

    avg_loss = total_loss / len(dataloader)
    scheduler.step(avg_loss)

    rmse = np.sqrt(mean_squared_error(all_targets, all_preds))
    r2 = r2_score(all_targets, all_preds)

    print(f"Epoch {epoch+1}/{epochs} | Loss: {avg_loss:.4f} | RMSE: {rmse:.4f} | R²: {r2:.4f}")

Epoch 1/100 | Loss: 2065.3941 | RMSE: 46.0781 | R²: -0.5970
Epoch 2/100 | Loss: 1494.2948 | RMSE: 39.5183 | R²: -0.1747
Epoch 3/100 | Loss: 1488.1286 | RMSE: 38.5549 | R²: -0.1181
Epoch 4/100 | Loss: 1535.3967 | RMSE: 39.7381 | R²: -0.1878
Epoch 5/100 | Loss: 1530.6414 | RMSE: 39.5205 | R²: -0.1748
Epoch 6/100 | Loss: 1588.5416 | RMSE: 40.0372 | R²: -0.2058
Epoch 7/100 | Loss: 1469.0220 | RMSE: 37.4441 | R²: -0.0546
Epoch 8/100 | Loss: 1519.9440 | RMSE: 37.8180 | R²: -0.0758
Epoch 9/100 | Loss: 1445.8883 | RMSE: 36.9990 | R²: -0.0297
Epoch 10/100 | Loss: 1441.6210 | RMSE: 37.0494 | R²: -0.0325
Epoch 11/100 | Loss: 1492.8501 | RMSE: 38.3745 | R²: -0.1077
Epoch 12/100 | Loss: 1695.4257 | RMSE: 40.7676 | R²: -0.2501


KeyboardInterrupt: 