# Message-Passing Neural Net for Blood-Brain Barrier Permeability Prediction

In [1]:
import torch
# Use Apple's Metal (MPS) backend when this PyTorch build reports it as
# available; otherwise run on the CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

In [2]:
from rdkit import Chem
from rdkit.Chem import AllChem

In [4]:
import pandas as pd
from torch_geometric.data import Data, InMemoryDataset

## BBBP Dataset

In [5]:
import os

import wget

url = "https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/BBBP.csv"
csv_path = "BBBP.csv"
# wget.download never overwrites: when the target already exists it saves the
# new copy under a suffixed name ("BBBP (1).csv", ...). Re-running this cell
# would therefore accumulate duplicates while later cells keep reading the
# first download. Fetch only when the file is missing.
if not os.path.exists(csv_path):
    wget.download(url, csv_path)

'BBBP.csv'

In [6]:
# Load the raw BBBP table and preview its first five rows.
df = pd.read_csv(filepath_or_buffer="BBBP.csv")
df.head(n=5)

Unnamed: 0,num,name,p_np,smiles
0,1,Propanolol,1,[Cl].CC(C)NCC(O)COc1cccc2ccccc12
1,2,Terbutylchlorambucil,1,C(=O)(OC(C)(C)C)CCCc1ccc(cc1)N(CCCl)CCCl
2,3,40730,1,c12c3c(N4CCN(C)CC4)c(F)cc1c(c(C(O)=O)cn2C(C)CO...
3,4,24,1,C1CCN(CC1)Cc1cccc(c1)OCCCNC(=O)C
4,5,cloxacillin,1,Cc1onc(c2ccccc2Cl)c1C(=O)N[C@H]3[C@H]4SC(C)(C)...


In [9]:
class BBBPDataset(InMemoryDataset):
    """In-memory PyG dataset built from the BBBP CSV.

    Each row's SMILES string is parsed with RDKit, explicit hydrogens are
    added, and the molecule is converted into a ``Data`` graph:

    * ``x``: one row per atom (hydrogens included). Columns are a one-hot
      over C, N, O, F, P, S, Cl, Br, I, followed by degree, total valence
      and number of attached hydrogens, for 12 columns in total. Every other
      element, including H, gets an all-zero one-hot.
    * ``edge_index`` / ``edge_attr``: every bond stored in both directions.
      Bond features are single/double/triple flags plus a conjugation flag
      (4 columns). Any other bond type (e.g. aromatic) leaves all three
      type flags at zero.
    * ``y``: the ``p_np`` label as a 0-dim float tensor.

    Rows whose SMILES fail to parse, whose 3D embedding fails, or that raise
    during featurization are skipped and counted.

    Args:
        csv_path: Path to a CSV with ``smiles`` and ``p_np`` columns.
        transform: Optional transform. PyG's ``Dataset.__getitem__`` applies
            it to each graph returned by ``get``.

    Raises:
        ValueError: If no row could be converted into a graph.
    """

    _ELEMENTS = ("C", "N", "O", "F", "P", "S", "Cl", "Br", "I")

    def __init__(self, csv_path, transform=None):
        super().__init__(root=None, transform=transform, pre_transform=None)
        df = pd.read_csv(csv_path, usecols=["smiles", "p_np"])
        data_list = []
        skipped = 0
        for smiles, label in zip(df["smiles"], df["p_np"]):
            try:
                data = self._smiles_to_data(smiles, label)
            except Exception:
                # Deliberate best-effort: RDKit can raise on malformed or
                # missing SMILES. Such rows are counted as skipped.
                data = None
            if data is None:
                skipped += 1
                continue
            data_list.append(data)
        print(f"Successfully processed {len(data_list)} molecules, skipped {skipped}")
        if not data_list:
            raise ValueError(f"No molecules could be processed from {csv_path!r}")
        self.data, self.slices = self.collate(data_list)

    @classmethod
    def _smiles_to_data(cls, smiles, label):
        """Convert one SMILES string and its label into a ``Data`` graph.

        Returns ``None`` when RDKit cannot parse the SMILES or cannot embed
        a 3D conformer.
        """
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return None
        mol = Chem.AddHs(mol)
        # The conformer is not stored. Embedding only acts as a filter that
        # drops molecules RDKit cannot embed; it returns -1 on failure.
        if AllChem.EmbedMolecule(mol, randomSeed=67) == -1:
            return None

        x = torch.tensor(
            [cls._atom_features(atom) for atom in mol.GetAtoms()],
            dtype=torch.float,
        )

        edge_index, edge_attr = [], []
        for bond in mol.GetBonds():
            i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
            attr = cls._bond_features(bond)
            edge_index += [[i, j], [j, i]]
            edge_attr += [attr, attr]
        # view(-1, k) keeps the 2-D layout for a molecule without bonds.
        # Without it, an empty list would become a 1-D tensor of shape (0,).
        edge_index = torch.tensor(edge_index, dtype=torch.long).view(-1, 2).t().contiguous()
        edge_attr = torch.tensor(edge_attr, dtype=torch.float).view(-1, 4)

        y = torch.tensor(float(label), dtype=torch.float)
        return Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y)

    @classmethod
    def _atom_features(cls, atom):
        """Return the 12 node features for one RDKit atom."""
        element = atom.GetSymbol()
        one_hot = [float(element == symbol) for symbol in cls._ELEMENTS]
        # After Chem.AddHs the hydrogens are explicit neighbour atoms.
        # includeNeighbors=True counts them; the default call would report 0
        # for every heavy atom.
        num_h = atom.GetTotalNumHs(includeNeighbors=True)
        return one_hot + [atom.GetDegree(), atom.GetTotalValence(), num_h]

    @staticmethod
    def _bond_features(bond):
        """Return the 4 edge features for one RDKit bond."""
        bond_type = bond.GetBondType()
        return [
            float(bond_type == Chem.rdchem.BondType.SINGLE),
            float(bond_type == Chem.rdchem.BondType.DOUBLE),
            float(bond_type == Chem.rdchem.BondType.TRIPLE),
            float(bond.GetIsConjugated()),
        ]

    def len(self):
        """Number of graphs (one label per graph)."""
        return len(self.data.y)

    def get(self, idx):
        """Rebuild graph ``idx`` from the collated storage.

        The transform is intentionally not applied here.
        ``Dataset.__getitem__`` already applies ``self.transform`` to the
        object returned by ``get``, so applying it here too would run it
        twice per access.
        """
        store, slices = self.data, self.slices
        x_start, x_end = slices["x"][idx], slices["x"][idx + 1]
        e_start, e_end = slices["edge_index"][idx], slices["edge_index"][idx + 1]
        a_start, a_end = slices["edge_attr"][idx], slices["edge_attr"][idx + 1]
        return Data(
            x=store.x[x_start:x_end],
            edge_index=store.edge_index[:, e_start:e_end],
            edge_attr=store.edge_attr[a_start:a_end],
            y=store.y[idx],
        )

In [None]:
# Featurize every molecule in the downloaded CSV into a PyG graph dataset.
# The constructor prints how many molecules were processed and how many were skipped.
bbbp = BBBPDataset(csv_path="BBBP.csv")

## MPNN with an Edge Network

In [12]:
import torch.nn as nn
from torch_geometric.nn import MessagePassing, global_mean_pool

In [13]:
class EdgeNetwork(MessagePassing):
    """Edge-conditioned message function in the style of Gilmer et al. (2017).

    For every edge ``j -> i``, a linear layer turns the edge features into
    an ``(out_channels, node_channels)`` matrix. That matrix is applied to
    the source node state ``x_j``, and the resulting messages are summed at
    each target node.

    Args:
        in_channels: Width of ``edge_attr`` (number of edge features).
        out_channels: Width of each message / of the output.
        node_channels: Width of the node states ``x`` passed to ``forward``.
            Defaults to ``out_channels`` (square transition matrices, used
            when node states and messages share one hidden size).
    """

    def __init__(self, in_channels, out_channels, node_channels=None):
        super().__init__(aggr="add")
        self.out_channels = out_channels
        self.node_channels = out_channels if node_channels is None else node_channels
        # Produces one flattened (out_channels x node_channels) matrix per edge.
        self.lin1 = nn.Linear(in_channels, self.node_channels * out_channels)

    def forward(self, x, edge_index, edge_attr):
        """Return summed messages of width ``out_channels`` for every node.

        Raises:
            ValueError: If ``x`` is not ``node_channels`` wide.
        """
        if x.size(1) != self.node_channels:
            raise ValueError(
                f"EdgeNetwork expects node features of width {self.node_channels}, "
                f"got {x.size(1)}"
            )
        weight = self.lin1(edge_attr).view(-1, self.out_channels, self.node_channels)
        return self.propagate(edge_index, x=x, weight=weight)

    def message(self, x_j, weight):
        # (E, out, node) @ (E, node, 1) -> (E, out, 1) -> (E, out)
        return (weight @ x_j.unsqueeze(-1)).squeeze(-1)

In [14]:
class MPNN(nn.Module):
    """Message-passing network that outputs one probability per graph.

    Node features are projected to ``message_dim``. They are then refined
    for ``steps`` rounds, each combining edge-conditioned messages with a
    ``GRUCell`` update. The node states are mean-pooled per graph and mapped
    through a two-layer MLP and a sigmoid.

    Args:
        node_dim: Width of ``data.x``.
        edge_dim: Width of ``data.edge_attr``.
        message_dim: Hidden size shared by the node states and the messages.
        steps: Number of message-passing rounds.
        hidden_dim: Hidden width of the readout MLP.
    """

    def __init__(
        self,
        node_dim,
        edge_dim,
        message_dim: int = 64,
        steps: int = 4,
        hidden_dim: int = 128,
    ):
        super().__init__()
        self.steps = steps
        # edge_network is applied to the message_dim-wide hidden states h, so
        # it must map message_dim-wide node states to message_dim-wide messages.
        self.edge_network = EdgeNetwork(edge_dim, message_dim)
        self.gru = nn.GRUCell(message_dim, message_dim)
        self.lin1 = nn.Linear(node_dim, message_dim)
        self.readout = nn.Sequential(
            nn.Linear(message_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, data):
        """Return a 1-D tensor of per-graph probabilities in (0, 1).

        Args:
            data: A PyG graph (or batch) exposing ``x``, ``edge_index``,
                ``edge_attr`` and ``batch``.
        """
        x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch
        h = self.lin1(x)
        for _ in range(self.steps):
            # Messages are computed from the current hidden state h. The raw
            # input x would be the wrong width (node_dim) and identical at
            # every step.
            m = self.edge_network(h, edge_index, edge_attr)
            h = self.gru(m, h)
        hg = global_mean_pool(h, batch)
        return torch.sigmoid(self.readout(hg)).view(-1)

## Training