# Message-Passing Neural Net for Blood-Brain Barrier Permeability Prediction

In [1]:
import torch
# Select the compute backend once, up front; every later `.to(device)` call
# relies on this module-level name. Preference order: CUDA GPU, then
# Apple-silicon MPS, then CPU, so the notebook uses an accelerator wherever
# one is available instead of only on macOS.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

In [None]:
from rdkit import Chem
from rdkit.Chem import AllChem

In [3]:
import pandas as pd
from torch_geometric.data import Data, InMemoryDataset
from torch_geometric.loader import DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score

## BBBP Dataset

In [10]:
import wget
import os

# Fetch the BBBP table once; later runs reuse the local copy.
url = "https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/BBBP.csv"
if os.path.exists("BBBP.csv"):
    print("BBBP.csv already exists, skipping download.")
else:
    print("Downloading BBBP.csv...")
    wget.download(url, "BBBP.csv")
    # Leading newline: wget's progress bar does not end its own line.
    print("\nDownload complete!")

BBBP.csv already exists, skipping download.


In [5]:
# Preview the raw table (all columns) before it is turned into graphs.
df = pd.read_csv("BBBP.csv")
df.head(n=5)

Unnamed: 0,num,name,p_np,smiles
0,1,Propanolol,1,[Cl].CC(C)NCC(O)COc1cccc2ccccc12
1,2,Terbutylchlorambucil,1,C(=O)(OC(C)(C)C)CCCc1ccc(cc1)N(CCCl)CCCl
2,3,40730,1,c12c3c(N4CCN(C)CC4)c(F)cc1c(c(C(O)=O)cn2C(C)CO...
3,4,24,1,C1CCN(CC1)Cc1cccc(c1)OCCCNC(=O)C
4,5,cloxacillin,1,Cc1onc(c2ccccc2Cl)c1C(=O)N[C@H]3[C@H]4SC(C)(C)...


In [6]:
class BBBPDataset(InMemoryDataset):
    """Molecular-graph dataset built in memory from a BBBP-style CSV file.

    Only the ``smiles`` and ``p_np`` columns of the CSV are read. Every row is
    parsed with RDKit, explicit hydrogens are added, and the molecule is turned
    into a ``Data`` object with:

    * ``x``          -- ``[num_atoms, 12]`` float tensor: a 9-way element
      one-hot followed by degree, total valence and total hydrogen count.
    * ``edge_index`` -- ``[2, 2 * num_bonds]`` long tensor; each bond is stored
      in both directions.
    * ``edge_attr``  -- ``[2 * num_bonds, 4]`` float tensor: single / double /
      triple flags plus a conjugation flag.
    * ``y``          -- scalar float tensor holding the ``p_np`` label.

    Rows that cannot be converted are skipped and counted; the totals are
    printed once at the end of construction.

    Args:
        csv_path: Path of the CSV file to read.
        transform: Optional callable applied to each ``Data`` object in ``get``.
    """

    # Elements with a dedicated one-hot slot. Any other element -- including
    # the explicit hydrogens added during processing -- is encoded as all zeros
    # in this part of the feature vector.
    _ELEMENTS = ("C", "N", "O", "F", "P", "S", "Cl", "Br", "I")
    # Number of values produced per bond by ``_bond_features``.
    _EDGE_DIM = 4

    def __init__(self, csv_path, transform=None):
        super().__init__(root=None, transform=transform, pre_transform=None)
        df = pd.read_csv(csv_path, usecols=["smiles", "p_np"])
        data_list = []
        skipped = 0
        for smiles, label in zip(df["smiles"], df["p_np"]):
            try:
                graph = self._smiles_to_data(smiles, label)
            except Exception:
                # Deliberate best-effort loading: any row RDKit or torch cannot
                # handle is counted as skipped rather than aborting the build.
                graph = None
            if graph is None:
                skipped += 1
                continue
            data_list.append(graph)
        print(f"Successfully processed {len(data_list)} molecules, skipped {skipped}")
        self.data, self.slices = self.collate(data_list)

    @classmethod
    def _atom_features(cls, atom):
        """Return the 12 node features of one atom as a flat list."""
        symbol = atom.GetSymbol()
        one_hot = [float(symbol == candidate) for candidate in cls._ELEMENTS]
        return one_hot + [atom.GetDegree(), atom.GetTotalValence(), atom.GetTotalNumHs()]

    @staticmethod
    def _bond_features(bond):
        """Return the 4 edge features of one bond as a flat list."""
        bond_type = bond.GetBondType()
        return [
            float(bond_type == Chem.rdchem.BondType.SINGLE),
            float(bond_type == Chem.rdchem.BondType.DOUBLE),
            float(bond_type == Chem.rdchem.BondType.TRIPLE),
            float(bond.GetIsConjugated()),
        ]

    @classmethod
    def _smiles_to_data(cls, smiles, label):
        """Convert one (SMILES, label) pair to a ``Data`` graph, or ``None`` to skip it."""
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:  # RDKit could not parse the SMILES string
            return None
        mol = Chem.AddHs(mol)
        atoms = list(mol.GetAtoms())
        if not atoms:  # an empty molecule carries no graph to learn from
            return None
        # NOTE(review): the embedded 3D coordinates are never read by this
        # class, so this call only acts as a filter that drops molecules RDKit
        # cannot embed. It is kept so the set of retained molecules does not
        # change.
        if AllChem.EmbedMolecule(mol, randomSeed=67) == -1:  # Embedding failed
            return None

        x = torch.tensor([cls._atom_features(atom) for atom in atoms], dtype=torch.float)

        pairs, attrs = [], []
        for bond in mol.GetBonds():
            i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
            attr = cls._bond_features(bond)
            # Store both directions so messages flow along each bond both ways.
            pairs += [[i, j], [j, i]]
            attrs += [attr, attr]

        # Explicit row counts keep the tensors 2-D even when the molecule has
        # no bonds; otherwise they collapse to shape [0] and cannot be
        # concatenated with the other graphs in ``collate``.
        edge_index = (
            torch.tensor(pairs, dtype=torch.long).reshape(len(pairs), 2).t().contiguous()
        )
        edge_attr = torch.tensor(attrs, dtype=torch.float).reshape(len(attrs), cls._EDGE_DIM)
        y = torch.tensor(label, dtype=torch.float)
        return Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y)

    def len(self):
        """Number of graphs stored, i.e. the number of collated labels."""
        return len(self.data.y)

    def get(self, idx):
        """Rebuild graph ``idx`` by slicing the collated tensors.

        ``self.slices[key][idx]`` and ``self.slices[key][idx + 1]`` bound the
        rows (columns for ``edge_index``) that belong to graph ``idx``. The
        optional ``transform`` is applied before returning.
        """
        x_lo, x_hi = self.slices['x'][idx], self.slices['x'][idx+1]
        ei_lo, ei_hi = self.slices['edge_index'][idx], self.slices['edge_index'][idx+1]
        ea_lo, ea_hi = self.slices['edge_attr'][idx], self.slices['edge_attr'][idx+1]
        data = Data(
            x=self.data.x[x_lo:x_hi],
            edge_index=self.data.edge_index[:, ei_lo:ei_hi],
            edge_attr=self.data.edge_attr[ea_lo:ea_hi],
            y=self.data.y[idx]
        )
        if self.transform is not None:
            data = self.transform(data)
        return data

## EdgeNetwork and MPNN

In [7]:
import torch.nn as nn
from torch_geometric.nn import MessagePassing, global_mean_pool

In [24]:
class EdgeNetwork(MessagePassing):
    """Edge-conditioned message function with sum aggregation.

    A linear layer maps each edge's features to a square
    ``message_dim x message_dim`` matrix; the message sent along an edge is
    that matrix applied to the source node's hidden state.

    Args:
        edge_dim: Number of features per edge.
        message_dim: Size of the node hidden state and of each message.
    """

    def __init__(self, edge_dim, message_dim):
        super().__init__(aggr='add')
        self.message_dim = message_dim
        # One flattened (message_dim * message_dim) matrix per edge.
        self.lin = nn.Linear(edge_dim, message_dim * message_dim)

    def forward(self, x, edge_index, edge_attr):
        """Compute per-edge matrices and aggregate the resulting messages."""
        d = self.message_dim
        flat = self.lin(edge_attr)        # [E, d*d]
        per_edge = flat.view(-1, d, d)    # [E, d, d]
        return self.propagate(edge_index, x=x, weight=per_edge)

    def message(self, x_j, weight):
        """Apply each edge's matrix to its source-node state ``x_j`` ([E, d])."""
        column = x_j.unsqueeze(-1)        # [E, d, 1]
        product = weight @ column         # [E, d, 1]
        return product.squeeze(-1)        # [E, d]

In [25]:
class MPNN(nn.Module):
    """Message-passing network producing one probability per graph.

    Node features are projected to ``message_dim``, refined by ``num_steps``
    rounds of edge-conditioned message passing followed by a GRU update,
    mean-pooled per graph and passed through a two-layer readout.

    Args:
        node_dim: Number of input features per node.
        edge_dim: Number of input features per edge.
        message_dim: Size of the node hidden state.
        num_steps: Number of message-passing rounds (the same ``edge_net`` and
            ``gru`` weights are reused in every round).
        readout_dim: Width of the hidden layer in the readout network.
    """

    def __init__(self, node_dim=12, edge_dim=4, message_dim=32, num_steps=4, readout_dim=512):
        super().__init__()
        self.message_dim = message_dim
        self.num_steps = num_steps
        self.node_lin = nn.Linear(node_dim, message_dim)  # input projection
        self.edge_net = EdgeNetwork(edge_dim, message_dim)
        self.gru = nn.GRUCell(message_dim, message_dim)
        self.readout = nn.Sequential(
            nn.Linear(message_dim, readout_dim),
            nn.ReLU(),
            nn.Linear(readout_dim, 1)
        )

    def forward(self, data):
        """Return a 1-D tensor of sigmoid outputs, one entry per graph in ``data``.

        ``data`` must expose ``x``, ``edge_index``, ``edge_attr`` and ``batch``.
        The output is already passed through a sigmoid, so it pairs with a
        probability-based loss such as ``nn.BCELoss``, not a logits-based one.
        """
        x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch
        h = self.node_lin(x)  # project input features: node_dim -> message_dim
        for _ in range(self.num_steps):
            m = self.edge_net(h, edge_index, edge_attr)
            # The aggregated message is the GRU input; the previous node
            # state is the GRU hidden state.
            h = self.gru(m, h)
        hg = global_mean_pool(h, batch)
        return torch.sigmoid(self.readout(hg)).view(-1)

## Training

In [15]:
from torch.optim import Adam
from tqdm import tqdm

In [11]:
# Build every molecular graph up front; the dataset reports how many rows it skipped.
dataset = BBBPDataset("BBBP.csv")
print(f"Dataset loaded with {len(dataset)} molecules")



Successfully processed 2036 molecules, skipped 14
Dataset loaded with 2036 molecules


[13:27:55] UFFTYPER: Unrecognized charge state for atom: 16


In [12]:
# 80/20 random split with a fixed seed so the partition is reproducible.
train_ds, test_ds = train_test_split(dataset, test_size=0.2, random_state=67)
loader_args = dict(batch_size=32, shuffle=True, num_workers=4)
train_loader = DataLoader(train_ds, **loader_args)
# Evaluation data is not shuffled: shuffling brings no benefit at test time
# and a fixed order keeps per-sample predictions comparable between runs.
test_loader = DataLoader(test_ds, **{**loader_args, "shuffle": False})



In [26]:
# Take the input widths from the dataset so the model always matches the featurisation.
node_dim = dataset.num_node_features
edge_dim = dataset.num_edge_features
model = MPNN(node_dim=node_dim, edge_dim=edge_dim).to(device)
print(f"Model initialized with node_dim={node_dim}, edge_dim={edge_dim}")

Model initialized with node_dim=12, edge_dim=4


In [27]:
# The model's forward pass already applies a sigmoid, hence BCELoss (not the logits variant).
criterion = nn.BCELoss()
optimizer = Adam(params=model.parameters(), lr=5e-4)

In [None]:
# Train for 40 epochs and report the per-graph mean loss after each one.
n_train_graphs = len(train_loader.dataset)
for epoch in range(1, 41):
    model.train()
    total_loss = 0
    progress = tqdm(train_loader, desc=f"Epoch {epoch}", leave=False)
    for batch in progress:
        batch = batch.to(device)
        optimizer.zero_grad()
        pred = model(batch)
        loss = criterion(pred, batch.y)
        loss.backward()
        optimizer.step()
        # The loss is a batch mean; weight it by batch size so the epoch
        # average stays correct when the last batch is smaller.
        total_loss += loss.item() * batch.num_graphs
    epoch_loss = total_loss / n_train_graphs
    print(f"Epoch {epoch} - Loss: {epoch_loss:.4f}")

In [29]:
# Evaluate on test set
model.eval()
all_preds, all_labels = [], []

# No gradients are needed for scoring; collect predictions and labels batch
# by batch in matching order.
with torch.no_grad():
    for batch in test_loader:
        batch = batch.to(device)
        pred = model(batch)
        all_labels.extend(batch.y.cpu().numpy())
        all_preds.extend(pred.cpu().numpy())

test_auc = roc_auc_score(y_true=all_labels, y_score=all_preds)
print(f"Test ROC-AUC: {test_auc:.4f}")

Test ROC-AUC: 0.9161


In [30]:
# Save model checkpoint
checkpoint_path = "checkpoints/mpnn_bbbp.pt"
# Derive the directory from the path so the two cannot drift apart; fall back
# to the current directory if the path ever has no directory part.
os.makedirs(os.path.dirname(checkpoint_path) or ".", exist_ok=True)

# Store the constructor arguments next to the weights so the architecture can
# be rebuilt from the checkpoint alone before loading the state dict.
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'node_dim': dataset.num_node_features,
    'edge_dim': dataset.num_edge_features,
    'message_dim': model.message_dim,
}, checkpoint_path)

print(f"Model checkpoint saved to {checkpoint_path}")


Model checkpoint saved to checkpoints/mpnn_bbbp.pt




## Load the Pre-Trained Model

In [31]:
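# To load the checkpoint later (map_location lets a checkpoint saved on one
# device, e.g. MPS, be restored on whatever `device` is available now):
# checkpoint = torch.load("checkpoints/mpnn_bbbp.pt", map_location=device)
# model = MPNN(
#     node_dim=checkpoint['node_dim'],
#     edge_dim=checkpoint['edge_dim']
# ).to(device)
# model.load_state_dict(checkpoint['model_state_dict'])
# optimizer = Adam(model.parameters(), lr=5e-4)  # must exist before its state is restored
# optimizer.load_state_dict(checkpoint['optimizer_state_dict'])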
