# Develop a GNN model for ADMET prediction

## Analyse the TDC Dataset

In [3]:
from tdc.benchmark_group import admet_group

In [4]:
group = admet_group(path = 'data/')
benchmark = group.get('Caco2_Wang')

Found local copy...


In [5]:
benchmark['train_val']

Unnamed: 0,Drug_ID,Drug,Y
0,H 95/71,CC(C)NCC(O)COc1ccc(NC=O)cc1,-5.427984
1,H 244/45,CCC(=O)Nc1ccc(OCC(O)CNC(C)C)cc1,-5.219842
2,D-Phe-D-Ala-D-Ser-OH,C[C@H](NC(=O)[C@H](N)Cc1ccccc1)C(=O)N[C@H](CO)...,-6.281999
3,Dexloxiglumide,CCCCCN(CCCOC)C(=O)[C@@H](CCC(=O)O)NC(=O)c1ccc(...,-5.140131
4,Ac-D-phe-NH2,CC(=O)N[C@@H](Cc1ccccc1)C(N)=O,-5.100090
...,...,...,...
723,11,CCCCCCC(N)C(=O)N[C@@H](Cc1ccc(O)cc1)C(=O)N1CCC...,-5.790000
724,Gancyclovir,Nc1nc2c(ncn2COC(CO)CO)c(=O)[nH]1,-6.101228
725,Val-ACV,CC(C)C(N)C(=O)OCCOCn1cnc2c(=O)[nH]c(N)nc21,-5.669776
726,1033-Dextromethorphan (DEM),COc1ccc2c(c1)[C@@]13CCCC[C@@H]1[C@@H](C2)N(C)CC3,-4.628932


In [6]:
from tdc.single_pred import ADME
data = ADME(name = 'Caco2_Wang')
split = data.get_split()

Found local copy...
Loading...


Done!


In [7]:
split['valid']

Unnamed: 0,Drug_ID,Drug,Y
0,Raloxifene HCl,O=C(c1ccc(OCCN2CCCCC2)cc1)c1c(-c2ccc(O)cc2)sc2...,-5.722754
1,13,CCOC(=O)c1ccc2c(C(C(=O)NS(=O)(=O)c3ccc(C)cc3OC...,-4.699485
2,5,N#Cc1ccc(NCC(F)(F)c2ccccc2)c(F)c1CC(=O)NCCONC(...,-5.647924
3,-,O=C(O)c1ccncc1,-5.190000
4,4b,Cc1cc(C(=O)Nc2ccc(-c3ccccc3S(N)(=O)=O)cc2F)n(-...,-6.000000
...,...,...,...
86,atropine,CN1[C@H]2CC[C@@H]1CC(OC(=O)C(CO)c1ccccc1)C2,-4.700000
87,Guanabenz,NC(N)=NN=Cc1c(Cl)cccc1Cl,-4.330000
88,4,CN(C(=O)[C@H](Cc1ccc(CN)cc1)NS(=O)(=O)c1ccc2cc...,-4.958607
89,20(S)-camptothecin (CPT),CC[C@]1(O)C(=O)OCc2c1cc1n(c2=O)-c2cc3ccccc3nc2C1,-4.331849


In [8]:
import torch 
import torch.nn as nn 
import torch.nn.functional as F
from rdkit import Chem
from rdkit.Chem import AllChem
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool

In [9]:
def smiles_to_graph(smiles):
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None
    
    # Node features: Atom properties (e.g., atomic number)
    atom_features = []
    for atom in mol.GetAtoms():
        atom_features.append([atom.GetAtomicNum()])
    x = torch.tensor(atom_features, dtype=torch.float)
    
    # Edge indices: Bonds between atoms
    edge_index = []
    for bond in mol.GetBonds():
        i = bond.GetBeginAtomIdx()
        j = bond.GetEndAtomIdx()
        edge_index.append([i, j])
        edge_index.append([j, i])  # Undirected graph
    edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
    
    return Data(x=x, edge_index=edge_index)

In [10]:
def prepare_data():
    # Load TDC ADMET Benchmark Group
    group = admet_group(path='data/')
    datasets = group.dataset_names  # 22 datasets
    
    # Task types (example mapping, adjust per TDC documentation)
    task_types = ['regression'] * 10 + ['classification'] * 12  # 10 reg, 12 class
    
    # Combine training data
    train_data = []
    for i, dataset_name in enumerate(datasets):
        dataset = group.get(dataset_name)
        train_df = dataset['train_val']
        for _, row in train_df.iterrows():
            graph = smiles_to_graph(row['Drug'])
            if graph is not None:
                graph.task_id = i  # Assign task_id
                graph.task_type = 'regression' if type(row['Y']) == float else 'classification'
                graph.y = torch.tensor([row['Y']], dtype=torch.float if type(row['Y']) == float else torch.int)
                train_data.append(graph)
    
    return train_data, task_types

In [11]:
train_data, task_types = prepare_data()

Found local copy...


In [38]:
# Count the number of regression and classification tasks from train_data.
# Iterate through `train_data` and count based on 'task_type' attribute.
# Also, store the task types in a list in the order of the tasks, once per task_id.
task_types_list = []
num_regression_tasks = 0
num_classification_tasks = 0
seen_task_ids = set()

for data in train_data:
    if data.task_id not in seen_task_ids:
        task_types_list.append(data.task_type)
        seen_task_ids.add(data.task_id)
        if data.task_type == 'regression':
            num_regression_tasks += 1
        elif data.task_type == 'classification':
            num_classification_tasks += 1

print(f"Number of regression tasks: {num_regression_tasks}")
print(f"Number of classification tasks: {num_classification_tasks}")
print(f"Task types list: {task_types_list}")

# Create a Polars DataFrame for the training data. This part is moved from the original cell 12
# because it caused an error. The error was likely due to the fact that the train_data was
# not defined in the original cell 12.




import polars as pl

# Extract relevant data from the list of Data objects
data_list = []
for data_obj in train_data:
    data_dict = {
        'x': data_obj.x.tolist(),  # Convert node features to list
        'edge_index': data_obj.edge_index.tolist(),  # Convert edge indices to list
        'task_id': data_obj.task_id,
        'task_type': data_obj.task_type,
        'y': data_obj.y.item() # Get the scalar value from the tensor
    }
    data_list.append(data_dict)

# Create the Polars DataFrame
df = pl.DataFrame(data_list)

# Print or further process the DataFrame
print(df)

Number of regression tasks: 11
Number of classification tasks: 11
Task types list: ['regression', 'classification', 'classification', 'classification', 'regression', 'regression', 'classification', 'regression', 'regression', 'classification', 'classification', 'classification', 'classification', 'classification', 'classification', 'regression', 'regression', 'regression', 'regression', 'classification', 'regression', 'regression']
shape: (65_430, 5)
┌─────────────────────────┬──────────────────────────────┬─────────┬────────────┬───────────┐
│ x                       ┆ edge_index                   ┆ task_id ┆ task_type  ┆ y         │
│ ---                     ┆ ---                          ┆ ---     ┆ ---        ┆ ---       │
│ list[list[f64]]         ┆ list[list[i64]]              ┆ i64     ┆ str        ┆ f64       │
╞═════════════════════════╪══════════════════════════════╪═════════╪════════════╪═══════════╡
│ [[6.0], [6.0], … [6.0]] ┆ [[0, 1, … 9], [1, 0, … 17]]  ┆ 0       ┆ regres

In [13]:
train_loader = DataLoader(train_data, batch_size=32, shuffle=True)

## Build the GNN Encoder 

Initially let's try with a two GCN layer to learn the node and edge features. 

In [29]:
class GNNEncoder(nn.Module):
    def __init__(
            self,
            input_dim: int,
            hidden_dim: int,
    ):
        super(GNNEncoder, self).__init__()
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)

    def forward(self, data: Data) -> torch.Tensor:
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        x = F.relu(x)
        x = global_mean_pool(x, batch)
        return x

In [37]:
gnn_encoder = GNNEncoder(input_dim=1, hidden_dim=64)

for batch in train_loader:
    embeddings = gnn_encoder(batch)
    print(embeddings)
    print(embeddings.shape)


tensor([[0.0000, 0.0000, 0.0000,  ..., 0.0000, 1.2539, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 1.2741, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 1.3426, 0.0000],
        ...,
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 1.3398, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 1.3725, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 1.2488, 0.0000]],
       grad_fn=<DivBackward0>)
torch.Size([32, 64])
tensor([[0.0000, 0.0000, 0.0000,  ..., 0.0000, 1.5403, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 1.2726, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 1.3238, 0.0000],
        ...,
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 1.7396, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 1.2513, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 1.2503, 0.0000]],
       grad_fn=<DivBackward0>)
torch.Size([32, 64])
tensor([[0.0000, 0.0000, 0.0000,  ..., 0.0000, 1.4925, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..

## Build the Multi-Task Decoder 

In [43]:
class MultiTaskDecoder(nn.Module):
    def __init__(
            self,
            hidden_dim: int,
            task_types: list[str],
    ):
        super(MultiTaskDecoder, self).__init__()
        self.task_types = task_types
        self.heads = nn.ModuleList(
            [nn.Linear(hidden_dim, 1) if task_type == 'regression' else nn.Linear(hidden_dim, 2) for task_type in task_types]
        )

    def forward(self, x: torch.Tensor, task_ids: list[int]) -> list[torch.Tensor]: 
        outputs = []
        for i, task_id in enumerate(task_ids):
            head = self.heads[task_id]
            out = head(x[i:i+1])
            print(out.shape)
            if self.task_types[task_id] == 'classification':
                out = F.softmax(out, dim=1)
            outputs.append(out)
        return outputs


In [49]:
decoder = MultiTaskDecoder(hidden_dim=64, task_types=task_types_list)

for batch in train_loader:
    task_ids = batch.task_id
    embeddings = gnn_encoder(batch)
    outputs = decoder(embeddings, task_ids)



torch.Size([1, 1])
torch.Size([1, 2])
torch.Size([1, 1])
torch.Size([1, 1])
torch.Size([1, 2])
torch.Size([1, 2])
torch.Size([1, 1])
torch.Size([1, 2])
torch.Size([1, 2])
torch.Size([1, 1])
torch.Size([1, 2])
torch.Size([1, 2])
torch.Size([1, 2])
torch.Size([1, 1])
torch.Size([1, 2])
torch.Size([1, 1])
torch.Size([1, 2])
torch.Size([1, 2])
torch.Size([1, 2])
torch.Size([1, 1])
torch.Size([1, 2])
torch.Size([1, 1])
torch.Size([1, 1])
torch.Size([1, 1])
torch.Size([1, 2])
torch.Size([1, 2])
torch.Size([1, 2])
torch.Size([1, 2])
torch.Size([1, 2])
torch.Size([1, 2])
torch.Size([1, 1])
torch.Size([1, 1])
torch.Size([1, 1])
torch.Size([1, 1])
torch.Size([1, 1])
torch.Size([1, 2])
torch.Size([1, 1])
torch.Size([1, 2])
torch.Size([1, 2])
torch.Size([1, 2])
torch.Size([1, 1])
torch.Size([1, 2])
torch.Size([1, 2])
torch.Size([1, 1])
torch.Size([1, 1])
torch.Size([1, 2])
torch.Size([1, 2])
torch.Size([1, 2])
torch.Size([1, 2])
torch.Size([1, 2])
torch.Size([1, 2])
torch.Size([1, 2])
torch.Size([

In [46]:
outputs

[tensor([[0.7116, 0.2884]], grad_fn=<SoftmaxBackward0>),
 tensor([[0.3996]], grad_fn=<AddmmBackward0>),
 tensor([[0.4778, 0.5222]], grad_fn=<SoftmaxBackward0>),
 tensor([[0.0328]], grad_fn=<AddmmBackward0>),
 tensor([[0.5778, 0.4222]], grad_fn=<SoftmaxBackward0>),
 tensor([[0.7124, 0.2876]], grad_fn=<SoftmaxBackward0>),
 tensor([[0.4734, 0.5266]], grad_fn=<SoftmaxBackward0>),
 tensor([[0.4713, 0.5287]], grad_fn=<SoftmaxBackward0>),
 tensor([[0.4738, 0.5262]], grad_fn=<SoftmaxBackward0>),
 tensor([[-0.5252]], grad_fn=<AddmmBackward0>),
 tensor([[0.7964]], grad_fn=<AddmmBackward0>),
 tensor([[-0.8853]], grad_fn=<AddmmBackward0>),
 tensor([[0.5317]], grad_fn=<AddmmBackward0>),
 tensor([[0.7185, 0.2815]], grad_fn=<SoftmaxBackward0>),
 tensor([[0.0295]], grad_fn=<AddmmBackward0>),
 tensor([[0.5739, 0.4261]], grad_fn=<SoftmaxBackward0>),
 tensor([[0.0510]], grad_fn=<AddmmBackward0>),
 tensor([[0.4721, 0.5279]], grad_fn=<SoftmaxBackward0>),
 tensor([[0.4735, 0.5265]], grad_fn=<SoftmaxBackward

## Put it together to create the Multi-Task GNN Model 

In [51]:
class MultiTaskGNN(nn.Module):
    def __init__(
            self,
            input_dim: int,
            hidden_dim: int,
            task_types: list[str],
    ) -> None:
        super(MultiTaskGNN, self).__init__()
        self.encoder = GNNEncoder(input_dim, hidden_dim)
        self.decoder = MultiTaskDecoder(hidden_dim, task_types)

    def forward(self, data: Data) -> list[torch.Tensor]:
        embeddings = self.encoder(data)
        outputs = self.decoder(embeddings, data.task_id)
        return outputs

In [53]:
mtGNN = MultiTaskGNN(input_dim=1, hidden_dim=64, task_types=task_types_list)

for batch in train_loader:
    outputs = mtGNN(batch)


torch.Size([1, 1])
torch.Size([1, 2])
torch.Size([1, 1])
torch.Size([1, 2])
torch.Size([1, 2])
torch.Size([1, 2])
torch.Size([1, 2])
torch.Size([1, 2])
torch.Size([1, 1])
torch.Size([1, 1])
torch.Size([1, 2])
torch.Size([1, 1])
torch.Size([1, 1])
torch.Size([1, 2])
torch.Size([1, 2])
torch.Size([1, 1])
torch.Size([1, 2])
torch.Size([1, 2])
torch.Size([1, 1])
torch.Size([1, 2])
torch.Size([1, 2])
torch.Size([1, 1])
torch.Size([1, 1])
torch.Size([1, 1])
torch.Size([1, 1])
torch.Size([1, 2])
torch.Size([1, 2])
torch.Size([1, 2])
torch.Size([1, 2])
torch.Size([1, 2])
torch.Size([1, 2])
torch.Size([1, 2])
torch.Size([1, 2])
torch.Size([1, 2])
torch.Size([1, 1])
torch.Size([1, 1])
torch.Size([1, 2])
torch.Size([1, 1])
torch.Size([1, 2])
torch.Size([1, 2])
torch.Size([1, 2])
torch.Size([1, 2])
torch.Size([1, 1])
torch.Size([1, 2])
torch.Size([1, 1])
torch.Size([1, 2])
torch.Size([1, 2])
torch.Size([1, 1])
torch.Size([1, 2])
torch.Size([1, 1])
torch.Size([1, 1])
torch.Size([1, 1])
torch.Size([

In [54]:
outputs

[tensor([[0.4382, 0.5618]], grad_fn=<SoftmaxBackward0>),
 tensor([[-0.1800]], grad_fn=<AddmmBackward0>),
 tensor([[0.4350, 0.5650]], grad_fn=<SoftmaxBackward0>),
 tensor([[-0.2055]], grad_fn=<AddmmBackward0>),
 tensor([[0.6404, 0.3596]], grad_fn=<SoftmaxBackward0>),
 tensor([[1.0563]], grad_fn=<AddmmBackward0>),
 tensor([[0.4384, 0.5616]], grad_fn=<SoftmaxBackward0>),
 tensor([[0.6592]], grad_fn=<AddmmBackward0>),
 tensor([[0.4388, 0.5612]], grad_fn=<SoftmaxBackward0>),
 tensor([[0.5568, 0.4432]], grad_fn=<SoftmaxBackward0>),
 tensor([[-0.1807]], grad_fn=<AddmmBackward0>),
 tensor([[0.4374, 0.5626]], grad_fn=<SoftmaxBackward0>),
 tensor([[0.4354, 0.5646]], grad_fn=<SoftmaxBackward0>),
 tensor([[0.4336, 0.5664]], grad_fn=<SoftmaxBackward0>),
 tensor([[0.6361, 0.3639]], grad_fn=<SoftmaxBackward0>),
 tensor([[0.7100, 0.2900]], grad_fn=<SoftmaxBackward0>),
 tensor([[0.6520, 0.3480]], grad_fn=<SoftmaxBackward0>),
 tensor([[0.5017]], grad_fn=<AddmmBackward0>),
 tensor([[0.9477]], grad_fn=<Ad

## Build a Training Step 


In [61]:
def train_step(
        model: nn.Module,
        train_dataloader: DataLoader,
        optimizer: torch.optim.Optimizer,
        n_epochs: int,
        device: torch.device,
        regression_loss_fn: nn.Module,
        classification_loss_fn: nn.Module,
) -> None:
    model = model.to(device)
    model.train()
    for epoch in range(n_epochs):
        total_loss = 0
        n_batches = 0
        for batch in train_dataloader:
            batch = batch.to(device)
            optimizer.zero_grad()
            task_ids = batch.task_id
            y = batch.y
            outputs = model(batch)
            loss = 0
            for i, out in enumerate(outputs):
                task_id = task_ids[i]
                y_i = y[i]
                if task_types[task_id] == 'regression':
                    loss += regression_loss_fn(out, y_i)
                else:
                    loss += classification_loss_fn(out, y_i.long().unsqueeze(0))
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            n_batches += 1
        avg_loss = total_loss / n_batches
        print(f"Epoch {epoch+1}/{n_epochs}, Loss: {avg_loss:.4f}")

In [62]:
from utils import set_device, set_seed
device = set_device()

In [63]:
set_seed()
train_step(
    model=mtGNN,
    train_dataloader=train_loader,
    optimizer=torch.optim.Adam(mtGNN.parameters(), lr=0.001),
    n_epochs=10,
    device=device,
    regression_loss_fn=nn.MSELoss(),
    classification_loss_fn=nn.CrossEntropyLoss(),
)

torch.Size([1, 1])
torch.Size([1, 2])
torch.Size([1, 1])
torch.Size([1, 2])
torch.Size([1, 2])
torch.Size([1, 2])
torch.Size([1, 2])
torch.Size([1, 1])
torch.Size([1, 2])
torch.Size([1, 2])
torch.Size([1, 1])
torch.Size([1, 2])
torch.Size([1, 2])
torch.Size([1, 2])
torch.Size([1, 1])
torch.Size([1, 2])
torch.Size([1, 1])
torch.Size([1, 2])
torch.Size([1, 1])
torch.Size([1, 1])
torch.Size([1, 1])
torch.Size([1, 1])
torch.Size([1, 1])
torch.Size([1, 2])
torch.Size([1, 2])
torch.Size([1, 1])
torch.Size([1, 2])
torch.Size([1, 2])
torch.Size([1, 2])
torch.Size([1, 2])
torch.Size([1, 1])
torch.Size([1, 1])
torch.Size([1, 2])
torch.Size([1, 1])
torch.Size([1, 2])
torch.Size([1, 1])
torch.Size([1, 1])
torch.Size([1, 1])
torch.Size([1, 2])
torch.Size([1, 2])
torch.Size([1, 1])
torch.Size([1, 2])
torch.Size([1, 2])
torch.Size([1, 2])
torch.Size([1, 1])
torch.Size([1, 2])
torch.Size([1, 2])
torch.Size([1, 1])
torch.Size([1, 1])
torch.Size([1, 2])
torch.Size([1, 2])
torch.Size([1, 2])
torch.Size([

KeyboardInterrupt: 