In [3]:
import os
import numpy as np
import pandas as pd
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem.rdmolops import GetAdjacencyMatrix

from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR

import torch_geometric.nn as gnn
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader

from sklearn.model_selection import train_test_split

In [21]:
from ogb.utils.features import (allowable_features, atom_to_feature_vector,
 bond_to_feature_vector, atom_feature_vector_to_dict, bond_feature_vector_to_dict) 
from rdkit import Chem
import numpy as np

def mol2graph(mol):
    """
    Converts SMILES string to graph Data object
    :input: SMILES string (str)
    :return: graph object
    """

    # atoms
    atom_features_list = []
    for atom in mol.GetAtoms():
        atom_features_list.append(atom_to_feature_vector(atom))
    x = np.array(atom_features_list, dtype = np.int64)

    # bonds
    num_bond_features = 3  # bond type, bond stereo, is_conjugated
    if len(mol.GetBonds()) > 0: # mol has bonds
        edges_list = []
        edge_features_list = []
        for bond in mol.GetBonds():
            i = bond.GetBeginAtomIdx()
            j = bond.GetEndAtomIdx()

            edge_feature = bond_to_feature_vector(bond)

            # add edges in both directions
            edges_list.append((i, j))
            edge_features_list.append(edge_feature)
            edges_list.append((j, i))
            edge_features_list.append(edge_feature)

        # data.edge_index: Graph connectivity in COO format with shape [2, num_edges]
        edge_index = np.array(edges_list, dtype = np.int64).T

        # data.edge_attr: Edge feature matrix with shape [num_edges, num_edge_features]
        edge_attr = np.array(edge_features_list, dtype = np.int64)

    else:   # mol has no bonds
        edge_index = np.empty((2, 0), dtype = np.int64)
        edge_attr = np.empty((0, num_bond_features), dtype = np.int64)

    return x, edge_attr, edge_index

In [31]:
def get_coordinate_features(mol):
    conf = mol.GetConformer()
    return conf.GetPositions()

def get_mol_data(prefix, y=None):
    if prefix.startswith("train"):
        what_set = "train_set"
    else:
        what_set = "test_set"
    ex = Chem.MolFromMolFile(f"../data/mol_files/{what_set}/{prefix}_ex.mol", removeHs=False)
    g = Chem.MolFromMolFile(f"../data/mol_files/{what_set}/{prefix}_g.mol", removeHs=False)
    
    # Atom features
    X, edge_attr, edge_index = mol2graph(ex)
    
    # Atom 3D coordinates
    co_ex = get_coordinate_features(ex)
    co_g = get_coordinate_features(g)
            
    X = np.concatenate([X, co_ex, co_g], axis=1)
    X = torch.tensor(X, dtype=torch.float)
    
    y = torch.tensor([y], dtype=torch.float)
            
    return Data(x=X, edge_index=edge_index, edge_attr=edge_attr, y=y)
        

def get_dataset(df):
    data_list = []
    if "Reorg_g" in df.columns:
        for _, item in tqdm(df.iterrows()):
            y = [item.Reorg_g, item.Reorg_ex]
            data = get_mol_data(item[0], y)
            data_list.append(data)
    else:
        for _, item in tqdm(df.iterrows()):
            data = get_mol_data(item[0])
            data_list.append(data)
        
    return data_list

In [27]:
def one_hot_encoding(x, permitted_list):
    """
    Maps input elements x which are not in the permitted list to the last element
    of the permitted list.
    """

    if x not in permitted_list:
        x = permitted_list[-1]

    binary_encoding = [int(boolean_value) for boolean_value in list(map(lambda s: x == s, permitted_list))]

    return binary_encoding

def get_atom_features(atom):
    """
    Takes an RDKit atom object as input and gives a 1d-numpy array of atom features as output.
    """

    # define list of permitted atoms
    
    permitted_list_of_atoms =  ['B', 'Br', 'C', 'Cl', 'F', 'I', 'N', 'O', 'P', 'S', 'Si', "H"]
    
    # compute atom features
    
    atom_feature_vector = []
    
    atom_feature_vector += one_hot_encoding(str(atom.GetSymbol()), permitted_list_of_atoms)
    
    atom_feature_vector += one_hot_encoding(str(atom.GetChiralTag()), ["CHI_UNSPECIFIED", "CHI_TETRAHEDRAL_CW", "CHI_TETRAHEDRAL_CCW", "CHI_OTHER"])
    
    #atom_feature_vector += one_hot_encoding(str(atom.GetHybridization()), ["S", "SP", "SP2", "SP3", "SP3D", "SP3D2", "OTHER"])
    
    #atom_feature_vector += [int(atom.IsInRing())]
    
    #atom_feature_vector += [int(atom.GetIsAromatic())]

    return np.array(atom_feature_vector)


def get_bond_features(bond):
    """
    Takes an RDKit bond object as input and gives a 1d-numpy array of bond features as output.
    """

    permitted_list_of_bond_types = [Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE, Chem.rdchem.BondType.TRIPLE, Chem.rdchem.BondType.AROMATIC]
    
    bond_feature_vector = []
    
    bond_feature_vector += one_hot_encoding(bond.GetBondType(), permitted_list_of_bond_types)
    
    bond_feature_vector += one_hot_encoding(str(bond.GetBondDir()), ["NONE", "ENDUPRIGHT", "ENDDOWNRIGHT"])
    
    #bond_feature_vector += [int(bond.GetIsConjugated())]
    
    #bond_feature_vector += [int(bond.IsInRing())]
    
    #bond_feature_vector += one_hot_encoding(str(bond.GetStereo()), ["STEREOZ", "STEREOE", "STEREOANY", "STEREONONE"])

    return np.array(bond_feature_vector)

def get_node_features(mol):
    # get feature dimensions
    X = []
    for atom in mol.GetAtoms():
        X.append(get_atom_features(atom))
    
    return np.array(X)

def get_edge_features(mol):
        # get feature dimensions
    X = []
    for atom in mol.GetBonds():
        X.append(get_bond_features(atom))
    
    return np.array(X)

def get_coordinate_features(mol):
    conf = mol.GetConformer()
    return conf.GetPositions()

def get_mol_data(prefix, y=None):
    if prefix.startswith("train"):
        what_set = "train_set"
    else:
        what_set = "test_set"
    ex = Chem.MolFromMolFile(f"../data/mol_files/{what_set}/{prefix}_ex.mol", removeHs=False)
    g = Chem.MolFromMolFile(f"../data/mol_files/{what_set}/{prefix}_g.mol", removeHs=False)
    
    # Atom features
    node_X = get_node_features(ex)
    
    # Atom 3D coordinates
    co_ex = get_coordinate_features(ex)
    co_g = get_coordinate_features(g)
    
    # Adjacency matrix
    (rows, cols) = np.nonzero(GetAdjacencyMatrix(ex))
    torch_rows = torch.from_numpy(rows.astype(np.int64)).to(torch.long)
    torch_cols = torch.from_numpy(cols.astype(np.int64)).to(torch.long)
    edge_index = torch.stack([torch_rows, torch_cols], dim = 0)
    
    # Bond features
    edge_attr = []

    for i, j in zip(rows, cols):
        edge_attr.append(get_bond_features(ex.GetBondBetweenAtoms(int(i),int(j))))

    edge_attr = torch.tensor(np.array(edge_attr), dtype = torch.float)
    
    if type(y) != type(None):
            y = torch.tensor(np.array([y]), dtype=torch.float)
            
    X = np.concatenate([node_X, co_ex, co_g], axis=1)
    X = torch.tensor(X, dtype=torch.float)
            
    return Data(x=X, edge_index=edge_index, edge_attr=edge_attr, y=y)
        

def get_dataset(df):
    data_list = []
    if "Reorg_g" in df.columns:
        for _, item in tqdm(df.iterrows()):
            y = [item.Reorg_g, item.Reorg_ex]
            data = get_mol_data(item[0], y)
            data_list.append(data)
    else:
        for _, item in tqdm(df.iterrows()):
            data = get_mol_data(item[0])
            data_list.append(data)
        
    return data_list

In [None]:
train_df = pd.read_csv("../data/train_set.ReorgE.csv")
test_df = pd.read_csv("../data/test_set.csv")

train_df, val_df = train_test_split(train_df, test_size=0.2, random_state=42)

train_data = get_dataset(train_df)
val_data = get_dataset(val_df)
test_data = get_dataset(test_df)

14525it [00:38, 377.04it/s]
1942it [00:04, 391.04it/s]

In [24]:
train_dataloader = DataLoader(train_data, batch_size=32, shuffle=True, drop_last=True, num_workers=4, pin_memory=True)
val_dataloader = DataLoader(val_data, batch_size=32, shuffle=False, drop_last=False, num_workers=4, pin_memory=True)
test_dataloader = DataLoader(test_data, batch_size=32, shuffle=False, drop_last=False, num_workers=4, pin_memory=True)

In [26]:
next(iter(train_dataloader))

DataBatch(x=[1468, 15], edge_index=[32], edge_attr=[32], y=[32], batch=[1468], ptr=[33])

In [25]:
train_data[0]

Data(x=[57, 15], edge_index=[2, 122], edge_attr=[122, 3], y=[2])

In [26]:
class GATNet(nn.Module):
    def __init__(self, input_dim=15, out_dim=2):
        super().__init__()
        self.main = gnn.Sequential("x, edge_index, edge_attr, batch", [
            (gnn.GATv2Conv(in_channels=input_dim, out_channels=64, heads=4, edge_dim=7), "x, edge_index, edge_attr -> x"),
            nn.ReLU(inplace=True),
            (gnn.GATv2Conv(in_channels=256, out_channels=64, heads=4, edge_dim=7), "x, edge_index, edge_attr -> x"),
            nn.ReLU(inplace=True),
            (gnn.GATv2Conv(in_channels=256, out_channels=64, heads=4, edge_dim=7), "x, edge_index, edge_attr -> x"),
            nn.ReLU(inplace=True),
            (gnn.global_mean_pool, "x, batch -> x"),
            nn.Linear(256, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, out_dim)
        ])
    
    def forward(self, x, edge_index, edge_attr, batch):
        return self.main(x, edge_index, edge_attr, batch)
    
class GINNet(nn.Module):
    def __init__(self, input_dim=19, out_dim=2):
        super().__init__()
        self.main = gnn.Sequential("x, edge_index, edge_attr, batch", [
            (gnn.GINEConv(nn.Linear(input_dim, 256), edge_dim=7), "x, edge_index, edge_attr -> x"),
            nn.ReLU(inplace=True),
            (gnn.GINEConv(nn.Linear(256, 256), edge_dim=7), "x, edge_index, edge_attr -> x"),
            nn.ReLU(inplace=True),
            (gnn.global_mean_pool, "x, batch -> x"),
            nn.Linear(256, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, out_dim)
        ])
    
    def forward(self, x, edge_index, edge_attr, batch):
        return self.main(x, edge_index, edge_attr, batch)


class TwoNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.net_ex = GATNet()
        self.net_g = GATNet()
        
    def forward(self, batch):
        G_e = self.net_g(batch.x_ex, batch.edge_index, batch.edge_attr, batch.batch)
        G_g = self.net_g(batch.x_g, batch.edge_index, batch.edge_attr, batch.batch)
        G = G_e - G_g
        
        E_e = self.net_ex(batch.x_ex, batch.edge_index, batch.edge_attr, batch.batch)
        E_g = self.net_ex(batch.x_g, batch.edge_index, batch.edge_attr, batch.batch)
        E = E_g - E_e
        
        return torch.cat([G, E], dim=1)

In [27]:
num_epochs = 30

model = GINNet(22, 2)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs * len(train_dataloader))
device = torch.device("cuda:1")

model.to(device)


for epoch in range(num_epochs):
    print(f"Epoch {epoch}")
    train_loss, val_loss = 0., 0.
    
    # train
    model.train()

    for batch in tqdm(train_dataloader):
        optimizer.zero_grad()

        batch = batch.to(device)
        pred = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
        loss = criterion(pred, batch.y)
        loss.backward()
        optimizer.step()
        scheduler.step()
        train_loss += loss.item()
    
    train_loss /= len(train_dataloader)
    
    # validation
    model.eval()
    
    for batch in val_dataloader:
        batch = batch.to(device)
        pred = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
        loss = criterion(pred, batch.y)
        val_loss += loss * len(batch.y)
    
    val_loss /= len(val_data)
    
    print(f"Train Loss: {train_loss}")
    print(f"Val Loss: {val_loss}")

Epoch 0


100%|██████████| 226/226 [00:03<00:00, 64.82it/s]


Train Loss: 0.20342496080340539
Val Loss: 0.11379211395978928
Epoch 1


100%|██████████| 226/226 [00:03<00:00, 63.66it/s]


Train Loss: 0.10948931112621738
Val Loss: 0.10087671875953674
Epoch 2


100%|██████████| 226/226 [00:03<00:00, 61.54it/s]


Train Loss: 0.09700299028010495
Val Loss: 0.09784424304962158
Epoch 3


100%|██████████| 226/226 [00:03<00:00, 64.80it/s]


Train Loss: 0.09287589287335894
Val Loss: 0.09605482965707779
Epoch 4


100%|██████████| 226/226 [00:03<00:00, 64.45it/s]


Train Loss: 0.09078152205762083
Val Loss: 0.08552774786949158
Epoch 5


100%|██████████| 226/226 [00:03<00:00, 61.81it/s]


Train Loss: 0.08604284389092859
Val Loss: 0.08660633862018585
Epoch 6


100%|██████████| 226/226 [00:03<00:00, 60.92it/s]


Train Loss: 0.08404391408955629
Val Loss: 0.0792718231678009
Epoch 7


100%|██████████| 226/226 [00:03<00:00, 60.61it/s]


Train Loss: 0.08131421532119269
Val Loss: 0.08753850311040878
Epoch 8


100%|██████████| 226/226 [00:03<00:00, 65.50it/s]


Train Loss: 0.0798107481569843
Val Loss: 0.07733990252017975
Epoch 9


100%|██████████| 226/226 [00:03<00:00, 69.29it/s]


Train Loss: 0.0796310806531558
Val Loss: 0.0800861120223999
Epoch 10


100%|██████████| 226/226 [00:03<00:00, 61.97it/s]


Train Loss: 0.07864695236878058
Val Loss: 0.07887125760316849
Epoch 11


100%|██████████| 226/226 [00:03<00:00, 63.74it/s]


Train Loss: 0.0793935161545477
Val Loss: 0.07330000400543213
Epoch 12


100%|██████████| 226/226 [00:03<00:00, 65.29it/s]


Train Loss: 0.07703827659442888
Val Loss: 0.07407233119010925
Epoch 13


100%|██████████| 226/226 [00:03<00:00, 64.59it/s]


Train Loss: 0.07748670013934637
Val Loss: 0.07228431105613708
Epoch 14


100%|██████████| 226/226 [00:03<00:00, 61.28it/s]


Train Loss: 0.0753069851811217
Val Loss: 0.07786265760660172
Epoch 15


100%|██████████| 226/226 [00:03<00:00, 64.04it/s]


Train Loss: 0.07500791465616331
Val Loss: 0.07258640974760056
Epoch 16


100%|██████████| 226/226 [00:03<00:00, 65.90it/s]


Train Loss: 0.07446256044230103
Val Loss: 0.0731412023305893
Epoch 17


100%|██████████| 226/226 [00:03<00:00, 63.08it/s]


Train Loss: 0.07438242038200914
Val Loss: 0.07230972498655319
Epoch 18


100%|██████████| 226/226 [00:03<00:00, 66.47it/s]


Train Loss: 0.0738449555417869
Val Loss: 0.07205026596784592
Epoch 19


100%|██████████| 226/226 [00:03<00:00, 65.41it/s]


Train Loss: 0.07261720907965065
Val Loss: 0.07570961862802505
Epoch 20


100%|██████████| 226/226 [00:03<00:00, 66.27it/s]


Train Loss: 0.07215696471059217
Val Loss: 0.071410171687603
Epoch 21


100%|██████████| 226/226 [00:03<00:00, 65.63it/s]


Train Loss: 0.07149629553664574
Val Loss: 0.07132541388273239
Epoch 22


100%|██████████| 226/226 [00:03<00:00, 66.42it/s]


Train Loss: 0.07066374127052527
Val Loss: 0.0705186203122139
Epoch 23


100%|██████████| 226/226 [00:03<00:00, 63.06it/s]


Train Loss: 0.07008481304268394
Val Loss: 0.06938931345939636
Epoch 24


100%|██████████| 226/226 [00:03<00:00, 65.20it/s]


Train Loss: 0.06965060098574752
Val Loss: 0.06960906088352203
Epoch 25


100%|██████████| 226/226 [00:03<00:00, 63.71it/s]


Train Loss: 0.06898473714173367
Val Loss: 0.06893468648195267
Epoch 26


100%|██████████| 226/226 [00:03<00:00, 64.10it/s]


Train Loss: 0.06872402996945698
Val Loss: 0.06892707198858261
Epoch 27


100%|██████████| 226/226 [00:03<00:00, 66.35it/s]


Train Loss: 0.06848679591728522
Val Loss: 0.06870012730360031
Epoch 28


100%|██████████| 226/226 [00:03<00:00, 64.39it/s]


Train Loss: 0.06828281240521279
Val Loss: 0.06878464668989182
Epoch 29


100%|██████████| 226/226 [00:03<00:00, 65.32it/s]


Train Loss: 0.06813607758850651
Val Loss: 0.06870730221271515


In [None]:
preds = []

model.eval()
for batch in tqdm(test_dataloader):
    batch = batch.to(device)
    pred = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
    preds.append(pred)

preds = torch.cat(preds).detach().cpu().numpy()

sub_df = pd.read_csv("data/sample_submission.csv")
sub_df["Reorg_g"] = preds[:, 0]
sub_df["Reorg_ex"] = preds[:, 1]
sub_df.to_csv("submission.csv", sep=",", index=False)