<a href="https://colab.research.google.com/github/BonanYang/GNN/blob/main/GNN_GCN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install rdkit

Collecting rdkit
  Downloading rdkit-2025.9.3-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (4.2 kB)
Downloading rdkit-2025.9.3-cp312-cp312-manylinux_2_28_x86_64.whl (36.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m36.4/36.4 MB[0m [31m76.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: rdkit
Successfully installed rdkit-2025.9.3


In [2]:
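## PDB generator (disabled): embed a 3D conformer for one SMILES and write it to '1.pdb'.
# NOTE: uses Chem / AllChem, which are imported in the next cell -- run that cell first.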
# smiles = "CCOc1ccc2nc(S(N)(=O)=O)sc2c1"
# mol = Chem.MolFromSmiles(smiles)
# mol = Chem.AddHs(mol)
# AllChem.EmbedMolecule(mol)
# AllChem.MMFFOptimizeMolecule(mol)

# Chem.MolToPDBFile(mol, '1.pdb')

In [20]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset,DataLoader
import pandas as pd
import numpy as np
from collections import Counter
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')

# Load the Tox21 table: one row per compound, with a 'smiles' column and the
# label columns listed below (blank entries are read as NaN).
df = pd.read_csv(filepath_or_buffer='tox21.csv')

# The 12 label columns used as prediction targets, in model-output order.
label_cols = [
    'NR-AR',
    'NR-AR-LBD',
    'NR-AhR',
    'NR-Aromatase',
    'NR-ER',
    'NR-ER-LBD',
    'NR-PPAR-gamma',
    'SR-ARE',
    'SR-ATAD5',
    'SR-HSE',
    'SR-MMP',
    'SR-p53',
]

df.head()

Unnamed: 0,NR-AR,NR-AR-LBD,NR-AhR,NR-Aromatase,NR-ER,NR-ER-LBD,NR-PPAR-gamma,SR-ARE,SR-ATAD5,SR-HSE,SR-MMP,SR-p53,mol_id,smiles
0,0.0,0.0,1.0,,,0.0,0.0,1.0,0.0,0.0,0.0,0.0,TOX3021,CCOc1ccc2nc(S(N)(=O)=O)sc2c1
1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,,0.0,,0.0,0.0,TOX3020,CCN1C(=O)NC(c2ccccc2)C1=O
2,,,,,,,,0.0,,0.0,,,TOX3024,CC[C@]1(O)CC[C@H]2[C@@H]3CCC4=CCCC[C@@H]4[C@H]...
3,0.0,0.0,0.0,0.0,0.0,0.0,0.0,,0.0,,0.0,0.0,TOX3027,CCCN(CC)C(CC)C(=O)Nc1c(C)cccc1C
4,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,TOX20800,CC(O)(P(=O)(O)O)P(=O)(O)O


In [4]:
# Tally how often each element symbol occurs across the whole dataset, to
# decide which elements deserve their own one-hot slot in the atom features.
all_atoms = Counter()
for smiles in df['smiles']:
    mol = Chem.MolFromSmiles(smiles)
    if not mol:
        # No molecule came back for this SMILES; leave it out of the tally.
        continue
    all_atoms.update(atom.GetSymbol() for atom in mol.GetAtoms())

print(all_atoms.most_common())

[04:32:04] Explicit valence for atom # 8 Al, 6, is greater than permitted
[04:32:05] Explicit valence for atom # 3 Al, 6, is greater than permitted
[04:32:05] Explicit valence for atom # 4 Al, 6, is greater than permitted
[04:32:05] Explicit valence for atom # 4 Al, 6, is greater than permitted
[04:32:05] Explicit valence for atom # 9 Al, 6, is greater than permitted
[04:32:05] Explicit valence for atom # 5 Al, 6, is greater than permitted
[04:32:05] Explicit valence for atom # 16 Al, 6, is greater than permitted
[04:32:06] Explicit valence for atom # 20 Al, 6, is greater than permitted


[('C', 106440), ('O', 21547), ('N', 10479), ('Cl', 2210), ('S', 1821), ('F', 1628), ('Br', 381), ('P', 269), ('I', 198), ('Si', 94), ('B', 26), ('Sn', 20), ('Hg', 14), ('As', 12), ('Zn', 9), ('Cr', 9), ('Na', 9), ('H', 7), ('Fe', 7), ('Se', 6), ('Cu', 5), ('Au', 5), ('Ca', 4), ('Ba', 4), ('Co', 4), ('Sb', 4), ('Ni', 4), ('K', 3), ('In', 3), ('Cd', 3), ('Ti', 3), ('Mn', 3), ('Pt', 3), ('Mg', 2), ('Zr', 2), ('Gd', 2), ('Li', 2), ('Bi', 2), ('Pd', 1), ('Tl', 1), ('Ag', 1), ('Mo', 1), ('V', 1), ('Nd', 1), ('Yb', 1), ('Pb', 1), ('Dy', 1), ('Sr', 1), ('Be', 1), ('Ge', 1)]


In [5]:
def atom_features(atom):
    """Encode one atom as a flat list of 12 numeric features.

    Layout: a 9-way one-hot over the element symbols
    C, N, O, S, F, Cl, Br, I, P (all zeros for any other element), followed by
    the atom's degree, its formal charge, and an aromaticity flag (0/1).

    Args:
        atom: object exposing ``GetSymbol()``, ``GetDegree()``,
            ``GetFormalCharge()`` and ``GetIsAromatic()``.

    Returns:
        list: 12 integers in the order described above.
    """
    symbol = atom.GetSymbol()
    one_hot = [
        int(symbol == element)
        for element in ('C', 'N', 'O', 'S', 'F', 'Cl', 'Br', 'I', 'P')
    ]
    scalars = [
        atom.GetDegree(),
        atom.GetFormalCharge(),
        int(atom.GetIsAromatic()),
    ]
    return one_hot + scalars

# atom_features(Chem.MolFromSmiles('C').GetAtomWithIdx(0))  # expects an RDKit atom, not a symbol string

In [6]:
def smiles_to_graph(smiles):
    """Convert a SMILES string into node-feature and adjacency tensors.

    Args:
        smiles: SMILES string describing one molecule.

    Returns:
        tuple | None: ``(X, A)`` where ``X`` is a float tensor holding one row
        of ``atom_features`` per atom and ``A`` is the symmetric ``n x n``
        float adjacency matrix with a 1 for every bonded atom pair.
        ``None`` is returned when the input is not a string, when RDKit cannot
        parse it, or when the parsed molecule has no atoms.
    """
    # Missing CSV entries show up as float NaN rather than text; treat them
    # like any other unusable input instead of handing them to RDKit.
    if not isinstance(smiles, str):
        return None

    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None

    n = mol.GetNumAtoms()
    if n == 0:
        # An atom-less molecule would produce a 1-D feature tensor of shape
        # (0,), which the graph layers cannot consume; skip it instead.
        return None

    X = torch.tensor(
        [atom_features(atom) for atom in mol.GetAtoms()],
        dtype=torch.float,
    )

    A = torch.zeros(n, n)
    for bond in mol.GetBonds():
        i = bond.GetBeginAtomIdx()
        j = bond.GetEndAtomIdx()
        # Undirected graph: mark the bond in both directions.
        A[i, j] = 1
        A[j, i] = 1

    return X, A

In [7]:
class toxData(Dataset):
    """Dataset that turns rows of a SMILES/label table into graph samples.

    Args:
        df: DataFrame with a ``'smiles'`` column and the columns named in
            ``label_cols``.
        label_cols: names of the label columns; missing labels are NaN.
    """

    def __init__(self, df, label_cols):
        # Plain arrays so that __getitem__ indexes by position, independent of
        # the DataFrame's index (which is shuffled after a train/test split).
        self.Y = df[label_cols].values
        self.X = df['smiles'].values

    def __len__(self):
        """Number of rows (molecules), including ones that fail to parse."""
        return len(self.X)

    def __getitem__(self, idx):
        """Build the sample at position ``idx``.

        Returns:
            ``(attr, adj, mask, y)`` -- node features, adjacency matrix, a
            boolean mask that is True where a label is present, and the float
            labels with NaN replaced by 0. Returns ``None`` when
            ``smiles_to_graph`` yields no graph for the row.
        """
        graph = smiles_to_graph(self.X[idx])
        if graph is None:
            return None
        attr, adj = graph
        y = torch.tensor(self.Y[idx], dtype=torch.float)
        # Record which labels exist before the NaNs are overwritten.
        mask = ~torch.isnan(y)
        y = torch.nan_to_num(y)
        return attr, adj, mask, y

def collate_fn(batch):
    """Collate graph samples of differing size into one batch.

    ``None`` entries are dropped. Node-feature and adjacency tensors stay as
    per-graph tuples because their sizes differ between molecules; masks and
    labels are stacked into 2-D tensors.

    Returns:
        ``(X_list, A_list, masks, ys)``
    """
    usable = tuple(filter(lambda sample: sample is not None, batch))
    node_feats, adjacencies, mask_rows, label_rows = zip(*usable)
    stacked_masks = torch.stack(mask_rows)
    stacked_labels = torch.stack(label_rows)
    return node_feats, adjacencies, stacked_masks, stacked_labels

# Sanity check: wrap the whole table in a dataset and pull one shuffled
# mini-batch (batch_size=4) through the custom collate function.
data = toxData(df, label_cols)
loader = DataLoader(dataset=data, batch_size=4, shuffle=True, collate_fn=collate_fn)
first_batch = next(iter(loader))
X_list, A_list, mask, y = first_batch
print(len(X_list), mask.shape, y.shape)

4 torch.Size([4, 12]) torch.Size([4, 12])


In [8]:
# Sanity check on one molecule: X should be (atoms, features) and A (atoms, atoms).
sample_smiles = df['smiles'].iloc[3]
X, A = smiles_to_graph(sample_smiles)
print(X.shape, A.shape)

torch.Size([20, 12]) torch.Size([20, 20])


In [9]:
class GCNLayer(nn.Module):
    """Single graph-convolution layer operating on a dense adjacency matrix.

    Computes ``W(A_norm @ X) + B(X)`` with
    ``A_norm = D^-1/2 (A + I) D^-1/2``, where ``D`` holds the row sums of
    ``A + I``. ``W`` transforms the neighbourhood aggregate and ``B`` is a
    separate linear map applied to each node's own features.

    Args:
        in_dim: size of each input node-feature vector.
        out_dim: size of each output node-feature vector.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.W = nn.Linear(in_dim, out_dim)
        self.B = nn.Linear(in_dim, out_dim)

    def forward(self, X, A):
        """Apply the layer to one graph.

        Args:
            X: node features, one row per node.
            A: square adjacency matrix with one row/column per node.

        Returns:
            Tensor with one ``out_dim``-sized row per node.
        """
        # Self-loops let every node keep its own features in the aggregate;
        # the identity is created on A's device so GPU inputs work.
        A_hat = A + torch.eye(A.size(0), device=A.device)
        inv_sqrt_deg = A_hat.sum(dim=1).pow(-0.5)
        # Row/column scaling is equivalent to D @ A_hat @ D for diagonal D,
        # without materialising D or paying for two dense matmuls.
        A_norm = inv_sqrt_deg.unsqueeze(1) * A_hat * inv_sqrt_deg.unsqueeze(0)
        return self.W(A_norm @ X) + self.B(X)


# Smoke-test a single layer on the first molecule: 12 input features -> 32 outputs per atom.
first_smiles = df['smiles'].iloc[0]
X, A = smiles_to_graph(first_smiles)
layer = GCNLayer(in_dim=12, out_dim=32)
H = layer(X, A)
print(H.shape)


torch.Size([16, 32])


In [10]:
class GCN(nn.Module):
    """Two-layer GCN with mean pooling and a linear read-out.

    Args:
        in_dim: number of input features per node.
        hid_dim: width of both graph-convolution layers.
        out_dim: number of outputs per graph.
    """

    def __init__(self, in_dim, hid_dim, out_dim):
        super().__init__()
        self.gcn1 = GCNLayer(in_dim, hid_dim)
        self.gcn2 = GCNLayer(hid_dim, hid_dim)
        self.classifier = nn.Linear(hid_dim, out_dim)

    def _embed_graph(self, X, A):
        """Return one graph's embedding: conv -> ReLU -> conv -> mean over nodes."""
        hidden = self.gcn1(X, A).relu()
        return self.gcn2(hidden, A).mean(dim=0)

    def forward(self, X_list, A_list):
        """Score a batch of graphs.

        Graphs are processed one at a time because each has its own number of
        nodes; the per-graph embeddings are then stacked for the classifier.

        Args:
            X_list: sequence of node-feature tensors, one per graph.
            A_list: sequence of adjacency tensors aligned with ``X_list``.

        Returns:
            Tensor with one row of ``out_dim`` raw scores per graph.
        """
        pooled = [self._embed_graph(X, A) for X, A in zip(X_list, A_list)]
        return self.classifier(torch.stack(pooled))

# Smoke-test the full model on one random mini-batch drawn from `loader`
# (created in the dataset cell): expect one row of 12 scores per graph.
model = GCN(in_dim=12, hid_dim=64, out_dim=12)
batch = next(iter(loader))
X_list, A_list, mask, y = batch
print(len(X_list), mask.shape, y.shape)
out = model(X_list, A_list)
print(out.shape)

# Single-molecule variant (the model takes lists of graphs, so wrap X and A):
# X, A = smiles_to_graph(df['smiles'].iloc[0])
# out = model([X], [A])
# print(out.shape)

4 torch.Size([4, 12]) torch.Size([4, 12])
torch.Size([4, 12])


In [None]:
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split
from torch.optim import Adam

# --- Configuration -----------------------------------------------------------
SEED = 42
TEST_FRACTION = 0.2
BATCH_SIZE = 32
N_ATOM_FEATURES = 12   # length of the vector returned by atom_features
HIDDEN_DIM = 128
LEARNING_RATE = 1e-3
N_EPOCHS = 50

# Seed torch so weight initialisation and batch shuffling repeat across runs.
torch.manual_seed(SEED)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
train_df, test_df = train_test_split(df, test_size=TEST_FRACTION, random_state=SEED)

train_loader = DataLoader(toxData(train_df, label_cols), batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
test_loader = DataLoader(toxData(test_df, label_cols), batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)

model = GCN(N_ATOM_FEATURES, HIDDEN_DIM, len(label_cols)).to(device)
optimizer = Adam(model.parameters(), lr=LEARNING_RATE)
# Per-element losses are needed so that missing labels can be masked out.
criterion = nn.BCEWithLogitsLoss(reduction='none')

# --- Training ----------------------------------------------------------------
for epoch in range(N_EPOCHS):
    model.train()
    train_loss = 0.0
    for X_list, A_list, mask, y in train_loader:
        X_list = [x.to(device) for x in X_list]
        A_list = [a.to(device) for a in A_list]
        mask, y = mask.to(device), y.to(device)

        optimizer.zero_grad()
        out = model(X_list, A_list)
        # Average over observed labels only. The clamp turns a batch with no
        # observed labels into a zero loss instead of 0/0 = NaN, which would
        # otherwise poison the weights on the next optimizer step.
        loss = (criterion(out, y) * mask).sum() / mask.sum().clamp(min=1)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    print(f"Epoch {epoch+1}: Loss {train_loss/len(train_loader):.4f}")

# --- Evaluation --------------------------------------------------------------
model.eval()
all_preds, all_labels, all_masks = [], [], []
with torch.no_grad():
    for X_list, A_list, mask, y in test_loader:
        X_list = [x.to(device) for x in X_list]
        A_list = [a.to(device) for a in A_list]
        out = torch.sigmoid(model(X_list, A_list))
        all_preds.append(out.cpu())
        # Labels and masks never left the CPU, so they can be collected as-is.
        all_labels.append(y)
        all_masks.append(mask)

preds = torch.cat(all_preds)
labels = torch.cat(all_labels)
masks = torch.cat(all_masks)

for i, col in enumerate(label_cols):
    m = masks[:, i].bool()
    y_true = labels[m, i]
    # ROC-AUC is undefined unless both classes occur among the observed
    # labels; this also covers the case of no observed labels at all.
    if y_true.unique().numel() < 2:
        print(f"{col}: AUC undefined (fewer than two classes among observed test labels)")
        continue
    auc = roc_auc_score(y_true.numpy(), preds[m, i].numpy())
    print(f"{col}: AUC = {auc:.4f}")

Epoch 1: Loss 0.2685
Epoch 2: Loss 0.2440
Epoch 3: Loss 0.2378
Epoch 4: Loss 0.2333
Epoch 5: Loss 0.2308
Epoch 6: Loss 0.2294
Epoch 7: Loss 0.2274
Epoch 8: Loss 0.2263
Epoch 9: Loss 0.2250
Epoch 10: Loss 0.2250
Epoch 11: Loss 0.2250


In [12]:
# df[label_cols].values,df['smiles'].tolist()[:2]
# df[label_cols].values
# type(df[label_cols].values)
# df['smiles'].values[2]
# len(df['smiles'].values)


In [13]:
# a = torch.tensor([0, 1, float('nan'), 1])
# m = ~torch.isnan(a)
# m

In [14]:
# dataset = toxData(df, label_cols)
# X, A, mask, y = dataset[0]
# print(X.shape, A.shape, mask.shape, y.shape)

In [15]:
# model = GCN(in_dim=12, hid_dim=64, out_dim=12)
# X, A = smiles_to_graph(df['smiles'].iloc[0])
# out = model([X], [A])
# print(out.shape)

In [16]:
# # test
# # def atom_features(atom):
# #     atom_types = ['C', 'N', 'O', 'S', 'F', 'Cl', 'Br', 'I', 'P']
# #     features = [1 if atom.GetSymbol() == t else 0 for t in atom_types]
# #     features += [
# #         atom.GetDegree(),
# #         atom.GetFormalCharge(),
# #         int(atom.GetIsAromatic()),
# #     ]
# #     return features
# # s = df['smiles'].iloc[0]
# # mol = Chem.MolFromSmiles(s)
# # for i in mol.GetAtoms():
# #   print(atom_features(i))
# # smiles_to_graph(df['smiles'].iloc[0])[1].shape
# t = smiles_to_graph(df['smiles'].iloc[0])[1]
# print(t)
# a =t.sum(dim=1)
# print()
# # a.shape
# torch.diag(a)