In [119]:
from typing import Sequence

import numpy as np

import torch
import torch.nn as nn
import torch.functional as F


class RBF(nn.Module):
    """
    Radial Basis Function expansion.

    Maps every scalar input value v to the vector of Gaussian responses
    exp(-gamma * (v - c_k)^2), one entry per center c_k.

    Args:
        centers (array-like): 1-D sequence of RBF centers.
        gamma (float): width parameter shared by all basis functions.
    """
    def __init__(self, centers, gamma):
        super().__init__()
        # Registered as a buffer so `.to(device)` / `.cuda()` move the centers
        # together with the rest of the model (a plain tensor attribute stays
        # on the CPU). persistent=False keeps state_dict keys unchanged, so
        # existing checkpoints still load. clone() avoids sharing memory with a
        # float32 array/tensor passed by the caller.
        self.register_buffer(
            "centers",
            torch.as_tensor(centers, dtype=torch.float).clone().unsqueeze(0),
            persistent=False,
        )
        self.gamma = gamma

    def forward(self, x):
        """
        Args:
            x (tensor): input values of any shape; flattened to (N, 1).
        Returns:
            y (tensor): (N, n_centers)
        """
        # reshape (unlike view) also accepts non-contiguous inputs
        x = x.reshape(-1, 1)
        return torch.exp(-self.gamma * torch.square(x - self.centers))
        

class BondLengthRBF(nn.Module):
    """
    Bond length encoder built on radial basis functions.

    Each length is expanded by an ``RBF`` layer and then linearly projected
    to ``embed_dim`` features.

    Args:
        embed_dim (int): size of the output embedding.
        rbf_params (tuple, optional): ``(centers, gamma)``. Defaults to
            centers ``np.arange(0, 2, 0.1)`` (0.0 ... 1.9) and gamma 10.0.
    """
    def __init__(self, embed_dim, rbf_params=None):
        super().__init__()
        default_params = (np.arange(0, 2, 0.1), 10.0)  # (centers, gamma)
        self.rbf_params = default_params if rbf_params is None else rbf_params

        centers, gamma = self.rbf_params
        self.rbf = RBF(centers, gamma)
        self.fc = nn.Linear(len(centers), embed_dim)

    def forward(self, bond_lengths):
        """Return an (N, embed_dim) embedding for the N flattened bond lengths."""
        return self.fc(self.rbf(bond_lengths))
    

class BondAngleRBF(nn.Module):
    """
    Bond angle encoder built on radial basis functions.

    Angles (radians, as produced by ``GetAngleRad`` in ``mol2graph``) are
    expanded by an ``RBF`` layer and linearly projected to ``embed_dim``
    features.

    Args:
        embed_dim (int): size of the output embedding.
        rbf_params (tuple, optional): ``(centers, gamma)``. Defaults to
            centers ``np.arange(0, np.pi, 0.1)`` (0.0 ... 3.1) and gamma 10.0.
    """
    def __init__(self, embed_dim, rbf_params=None):
        super().__init__()
        if rbf_params is None:
            rbf_params = (np.arange(0, np.pi, 0.1), 10.0)  # (centers, gamma)
        self.rbf_params = rbf_params

        centers, gamma = rbf_params
        self.rbf = RBF(centers, gamma)
        self.fc = nn.Linear(len(centers), embed_dim)

    def forward(self, bond_lengths):
        """
        Args:
            bond_lengths (tensor): bond-bond angles; the parameter keeps its
                historical name so keyword callers keep working.
        Returns:
            tensor: (N, embed_dim) embedding of the N flattened angles.
        """
        expanded = self.rbf(bond_lengths)
        return self.fc(expanded)


In [92]:
import os
import numpy as np
import pandas as pd
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem.rdmolops import GetAdjacencyMatrix
from rdkit.Chem.rdMolTransforms import GetBondLength, GetAngleRad

from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR

import torch_geometric.nn as gnn
from torch_geometric.data import Data
from torch_geometric.data import InMemoryDataset


from ogb.utils.features import (allowable_features, atom_to_feature_vector,
 bond_to_feature_vector, atom_feature_vector_to_dict, bond_feature_vector_to_dict) 


def mol2graph(mol):
    """
    Convert an RDKit molecule with a 3D conformer into graph arrays.

    Every bond is added as two directed edges (i->j and j->i) that share the
    same features and length. A bond-bond (line-graph) edge (e1, e2) is
    created whenever directed edge e2 starts at the atom where e1 ends; its
    value is the angle i-j-k in radians. The reverse edge (k == i) is
    included as well.

    :param mol: RDKit ``Mol`` that has a conformer.
    :return: tuple ``(x, edge_attr, edge_index, bond_lengths,
        bond_bond_index, bond_bond_angles)`` of numpy arrays with shapes
        (num_atoms, n_atom_feat) int64, (num_edges, 3) int64,
        (2, num_edges) int64, (num_edges,) float32,
        (2, num_bond_pairs) int64, (num_bond_pairs,) float32.
    """
    conf = mol.GetConformer()

    # atoms
    x = np.array([atom_to_feature_vector(atom) for atom in mol.GetAtoms()], dtype=np.int64)

    # bonds
    num_bond_features = 3  # bond type, bond stereo, is_conjugated
    bonds = mol.GetBonds()
    if len(bonds) == 0:
        # Bond-free molecules still return all six arrays (with empty but
        # correctly-shaped contents) so callers can always unpack them.
        return (
            x,
            np.empty((0, num_bond_features), dtype=np.int64),
            np.empty((2, 0), dtype=np.int64),
            np.empty((0,), dtype=np.float32),
            np.empty((2, 0), dtype=np.int64),
            np.empty((0,), dtype=np.float32),
        )

    edges_list = []
    edge_features_list = []
    edge_lengths_list = []
    for bond in bonds:
        i = bond.GetBeginAtomIdx()
        j = bond.GetEndAtomIdx()
        edge_feature = bond_to_feature_vector(bond)
        edge_length = GetBondLength(conf, i, j)

        # add edges in both directions
        for src, dst in ((i, j), (j, i)):
            edges_list.append((src, dst))
            edge_features_list.append(edge_feature)
            edge_lengths_list.append(edge_length)

    # Index directed edges by start atom (indices stay in ascending order), so
    # the bond-bond pairs are found without scanning all edge pairs, yet in
    # exactly the same order as a full double loop would produce them.
    edges_by_start = {}
    for edge_idx, (src, _) in enumerate(edges_list):
        edges_by_start.setdefault(src, []).append(edge_idx)

    bond_bond_list = []
    bond_bond_angles_list = []
    for edge_idx, (i, j) in enumerate(edges_list):
        for next_idx in edges_by_start.get(j, ()):
            k = edges_list[next_idx][1]
            bond_bond_list.append((edge_idx, next_idx))
            bond_bond_angles_list.append(GetAngleRad(conf, i, j, k))

    # edge_index: graph connectivity in COO format with shape [2, num_edges]
    edge_index = np.array(edges_list, dtype=np.int64).T
    # edge_attr: edge feature matrix with shape [num_edges, num_edge_features]
    edge_attr = np.array(edge_features_list, dtype=np.int64)
    bond_lengths = np.array(edge_lengths_list, dtype=np.float32)

    # reshape keeps a (2, 0) shape even if no pairs were produced
    bond_bond_index = np.array(bond_bond_list, dtype=np.int64).reshape(-1, 2).T
    bond_bond_angles = np.array(bond_bond_angles_list, dtype=np.float32)

    return x, edge_attr, edge_index, bond_lengths, bond_bond_index, bond_bond_angles


def get_coordinate_features(mol):
    """Return the atom positions of ``mol``'s conformer via ``GetPositions()``."""
    positions = mol.GetConformer().GetPositions()
    return positions

def get_mol_data(root, prefix, y=None):
    """
    Build a PyG ``Data`` object for one molecule from its two mol files.

    Reads ``{root}/{set_dir}/{prefix}_ex.mol`` and
    ``{root}/{set_dir}/{prefix}_g.mol`` with explicit hydrogens kept, where
    ``set_dir`` is ``train_set`` when ``prefix`` starts with "train" and
    ``test_set`` otherwise.

    Args:
        root (str): directory containing ``train_set`` / ``test_set``.
        prefix (str): molecule id used in the file names.
        y (list[float], optional): regression targets, stored as a
            (1, len(y)) float tensor. When None, ``y=None`` is passed to
            ``Data`` instead of attempting ``torch.tensor([None])``.

    Returns:
        Data with node features ``x``, positions of both conformers
        (``pos_g``, ``pos_ex``), the topology (``edge_index``, ``edge_attr``,
        ``bond_bond_index``) and per-conformer bond lengths and bond-bond
        angles.

    Raises:
        ValueError: if RDKit cannot parse either mol file.
    """
    set_dir = "train_set" if prefix.startswith("train") else "test_set"

    def _read_mol(suffix):
        # Read one conformer file; RDKit returns None on parse failure.
        path = f"{root}/{set_dir}/{prefix}_{suffix}.mol"
        mol = Chem.MolFromMolFile(path, removeHs=False)
        if mol is None:
            raise ValueError(f"RDKit could not parse mol file: {path}")
        return mol

    ex = _read_mol("ex")
    g = _read_mol("g")

    # Both files are assumed to share atom order and bonds (TODO confirm for
    # the dataset): topology is taken from `g`, only the geometric quantities
    # (bond lengths, bond-bond angles) are kept from `ex`.
    _, _, _, bond_lengths_ex, _, bond_bond_angles_ex = mol2graph(ex)
    X, edge_attr, edge_index, bond_lengths_g, bond_bond_index, bond_bond_angles_g = mol2graph(g)

    y_tensor = None if y is None else torch.tensor([y], dtype=torch.float)

    return Data(
        x=torch.tensor(X, dtype=torch.float),
        pos_g=torch.tensor(get_coordinate_features(g), dtype=torch.float),
        pos_ex=torch.tensor(get_coordinate_features(ex), dtype=torch.float),
        edge_index=torch.tensor(edge_index, dtype=torch.long),
        edge_attr=torch.tensor(edge_attr, dtype=torch.float),
        bond_lengths_ex=torch.tensor(bond_lengths_ex, dtype=torch.float),
        bond_lengths_g=torch.tensor(bond_lengths_g, dtype=torch.float),
        bond_bond_index=torch.tensor(bond_bond_index, dtype=torch.long),
        bond_bond_angles_ex=torch.tensor(bond_bond_angles_ex, dtype=torch.float),
        bond_bond_angles_g=torch.tensor(bond_bond_angles_g, dtype=torch.float),
        y=y_tensor,
    )
        

def get_datalist(df, root):
    """
    Build one graph ``Data`` object per row of ``df``.

    The first column of each row is the molecule id passed to
    ``get_mol_data``. When ``df`` has a ``Reorg_g`` column, the targets
    ``[Reorg_g, Reorg_ex]`` are attached; otherwise ``y`` is None.

    Args:
        df (pandas.DataFrame): molecule table.
        root (str): directory passed through to ``get_mol_data``.

    Returns:
        list: Data objects in row order.
    """
    has_targets = "Reorg_g" in df.columns
    data_list = []
    # iterrows() has no len(); total= lets tqdm render a real progress bar.
    for _, item in tqdm(df.iterrows(), total=len(df)):
        # .iloc[0] is explicit positional access; plain item[0] on a
        # label-indexed Series relies on a fallback pandas has deprecated.
        prefix = item.iloc[0]
        y = [item.Reorg_g, item.Reorg_ex] if has_targets else None
        data_list.append(get_mol_data(root, prefix, y))

    return data_list


class TrainDataset(InMemoryDataset):
    """
    In-memory PyG dataset of the training molecules.

    ``process`` reads ``{root}/../train_set.ReorgE.csv``, builds one graph per
    row with ``get_datalist``, applies the optional pre-filter / pre-transform
    and saves the collated result to ``processed_paths[0]``
    (``2d_dataset_train.pt``), which ``__init__`` then loads.
    """
    def __init__(
        self,
        root="/data/project/danyoung/reorg/data/mol_files",
        transform=None,
        pre_transform=None,
        pre_filter=None
    ):
        super().__init__(root, transform, pre_transform, pre_filter)
        data, slices = torch.load(self.processed_paths[0])
        self.data, self.slices = data, slices

    @property
    def raw_file_names(self):
        """Full paths of every file in ``{root}/train_set``."""
        train_dir = os.path.join(self.root, "train_set")
        return [os.path.join(train_dir, name) for name in os.listdir(train_dir)]

    @property
    def processed_file_names(self):
        return ["2d_dataset_train.pt"]

    def process(self):
        """Build, optionally filter/transform, collate and save all training graphs."""
        frame = pd.read_csv(f"{self.root}/../train_set.ReorgE.csv")
        graphs = get_datalist(frame, self.root)

        if self.pre_filter is not None:
            graphs = list(filter(self.pre_filter, graphs))
        if self.pre_transform is not None:
            graphs = list(map(self.pre_transform, graphs))

        # collate() returns the (data, slices) tuple that __init__ unpacks
        torch.save(self.collate(graphs), self.processed_paths[0])
        
    
class TestDataset(InMemoryDataset):
    """
    In-memory PyG dataset of the test molecules.

    ``process`` reads ``{root}/../test_set.csv``, builds one graph per row
    with ``get_datalist``, applies the optional pre-filter / pre-transform
    and saves the collated result to ``processed_paths[0]``
    (``2d_dataset_test.pt``), which ``__init__`` then loads.
    """
    def __init__(
        self,
        root="/data/project/danyoung/reorg/data/mol_files", 
        transform=None,
        pre_transform=None,
        pre_filter=None
    ):
        super().__init__(root, transform, pre_transform, pre_filter)
        data, slices = torch.load(self.processed_paths[0])
        self.data, self.slices = data, slices

    @property
    def raw_file_names(self):
        """Full paths of every file in ``{root}/test_set``."""
        test_dir = os.path.join(self.root, "test_set")
        return [os.path.join(test_dir, name) for name in os.listdir(test_dir)]

    @property
    def processed_file_names(self):
        return ["2d_dataset_test.pt"]

    def process(self):
        """Build, optionally filter/transform, collate and save all test graphs."""
        frame = pd.read_csv(f"{self.root}/../test_set.csv")
        graphs = get_datalist(frame, self.root)

        if self.pre_filter is not None:
            graphs = list(filter(self.pre_filter, graphs))
        if self.pre_transform is not None:
            graphs = list(map(self.pre_transform, graphs))

        # collate() returns the (data, slices) tuple that __init__ unpacks
        torch.save(self.collate(graphs), self.processed_paths[0])

In [93]:
# Training graphs: built by TrainDataset.process when the processed file is
# missing (see the "Processing..." output), then loaded from processed_paths[0].
train_data = TrainDataset()

Processing...
18157it [03:23, 89.33it/s] 
Done!


In [94]:
# Test graphs: built by TestDataset.process when the processed file is missing,
# then loaded from processed_paths[0].
test_data = TestDataset()

Processing...
457it [00:07, 58.09it/s]
Done!


In [95]:
# Inspect the attributes and shapes of one training graph.
train_data[0]

Data(x=[53, 9], edge_index=[2, 110], edge_attr=[110, 3], y=[1, 2], pos_g=[53, 3], pos_ex=[53, 3], bond_lengths_ex=[110], bond_lengths_g=[110], bond_bond_index=[2, 308], bond_bond_angles_ex=[308], bond_bond_angles_g=[308])

In [107]:
from torch_geometric.loader import DataLoader

# follow_batch=["edge_attr"] adds `edge_attr_batch` (one graph id per edge),
# which GEM1 uses as the batch vector for its bond-level GIN layers.
train_dataloader = DataLoader(train_data, batch_size=16, follow_batch=["edge_attr"])
# Take a single mini-batch for shape checks below.
batch = next(iter(train_dataloader))
batch

DataBatch(x=[633, 9], edge_index=[2, 1310], edge_attr=[1310, 3], edge_attr_batch=[1310], edge_attr_ptr=[17], y=[16, 2], pos_g=[633, 3], pos_ex=[633, 3], bond_lengths_ex=[1310], bond_lengths_g=[1310], bond_bond_index=[2, 3608], bond_bond_angles_ex=[3608], bond_bond_angles_g=[3608], batch=[633], ptr=[17])

In [113]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.nn as gnn


class MLP(nn.Module):
    """
    Two-layer perceptron: Linear(d, 2d) -> ReLU -> Linear(2d, d).

    GINBlock uses it as the update network of its ``GINEConv``.

    Args:
        embed_dim (int): input and output feature size ``d``.
    """
    def __init__(
        self,
        embed_dim
    ):
        super().__init__()
        hidden_dim = embed_dim * 2
        layers = [
            nn.Linear(embed_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, embed_dim),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the perceptron to ``x`` (last dimension must be ``embed_dim``)."""
        out = self.main(x)
        return out

    

class GINBlock(nn.Module):
    """
    Residual GINE message-passing block.

    Computes ``x + act(dropout(GraphNorm(LayerNorm(GINEConv(x, edge_attr)))))``.

    Args:
        embed_dim (int): node feature size; edge attributes are presumably
            expected to have the same size, since GINEConv is built without
            ``edge_dim`` — confirm against PyG docs.
        dropout (float): dropout probability.
        last_act (bool): end with a ReLU when True, identity otherwise.
    """
    def __init__(self, embed_dim, dropout, last_act):
        super().__init__()
        # Submodules are created in a fixed order (gin, ln, gn, do, act),
        # which determines parameter initialisation order and state_dict order.
        self.gin = gnn.GINEConv(MLP(embed_dim))
        self.ln = gnn.LayerNorm(embed_dim)
        self.gn = gnn.GraphNorm(embed_dim)
        self.do = nn.Dropout(dropout)
        if last_act:
            self.act = nn.ReLU()
        else:
            self.act = nn.Identity()

    def forward(self, x, edge_index, edge_attr, batch):
        """Return ``x`` plus the normalised, regularised GINE update."""
        h = self.gin(x, edge_index, edge_attr)
        for norm in (self.ln, self.gn):
            h = norm(h, batch)
        h = self.act(self.do(h))
        return x + h

In [139]:
from typing import List

class GEM1(nn.Module):
    """
    GEM-style geometry encoder producing one embedding per molecule.

    Each of ``n_layers`` rounds first runs a GIN block on the bond-angle graph
    (nodes = directed bonds, edges = bond pairs carrying angle embeddings),
    then a GIN block on the atom-bond graph (nodes = atoms, edges carrying the
    updated bond embeddings). Atom embeddings are finally pooled per graph.

    Args:
        embed_dim: hidden size of all embeddings.
        dropout: dropout probability inside each GINBlock.
        last_act: whether each GINBlock ends with a ReLU.
        n_layers: number of (bond block, atom block) rounds.
        pool: graph readout, ``"mean"`` or ``"add"``.
        atom_in_dim: number of raw atom features in ``batch.x``.
        bond_in_dim: number of raw bond features in ``batch.edge_attr``.

    Raises:
        ValueError: if ``pool`` is not a supported readout.
    """
    def __init__(
        self,
        embed_dim: int = 32,
        dropout: float = 0.1,
        last_act: bool = True,
        n_layers: int = 3,
        pool: str = "mean",
        atom_in_dim: int = 9,
        bond_in_dim: int = 3,
    ):
        super().__init__()
        # Validate up front; an unknown value would otherwise leave self.pool
        # unset and only fail later inside forward().
        if pool == "mean":
            self.pool = gnn.global_mean_pool
        elif pool == "add":
            self.pool = gnn.global_add_pool
        else:
            raise ValueError(f"Unsupported pool {pool!r}; expected 'mean' or 'add'")

        self.embed_atom = nn.Linear(atom_in_dim, embed_dim)
        self.embed_bond = nn.Linear(bond_in_dim, embed_dim)
        self.embed_bond_length = BondLengthRBF(embed_dim)
        self.embed_bond_angle = BondAngleRBF(embed_dim)

        self.atom_gin_layers = nn.ModuleList([GINBlock(embed_dim, dropout, last_act) for _ in range(n_layers)])
        self.bond_gin_layers = nn.ModuleList([GINBlock(embed_dim, dropout, last_act) for _ in range(n_layers)])

        self.n_layers = n_layers

    def forward(self, batch):
        """
        Args:
            batch: PyG batch providing ``x``, ``edge_attr``, ``edge_index``,
                ``bond_lengths_g/ex``, ``bond_bond_index``,
                ``bond_bond_angles_g/ex``, ``batch`` and ``edge_attr_batch``
                (the latter requires ``follow_batch=["edge_attr"]``).

        Returns:
            tensor: one pooled ``embed_dim`` vector per graph.
        """
        atom_x = self.embed_atom(batch.x)

        # The same RBF encoders are shared between the ground (g) and excited
        # (ex) conformers; their embeddings are summed.
        edge_x = self.embed_bond(batch.edge_attr)
        edge_x = edge_x + self.embed_bond_length(batch.bond_lengths_g) + self.embed_bond_length(batch.bond_lengths_ex)

        angle_x = self.embed_bond_angle(batch.bond_bond_angles_g) + self.embed_bond_angle(batch.bond_bond_angles_ex)

        for bond_gin, atom_gin in zip(self.bond_gin_layers, self.atom_gin_layers):
            # bonds are the nodes here, so edge_attr_batch is their batch vector
            edge_x = bond_gin(edge_x, batch.bond_bond_index, angle_x, batch.edge_attr_batch)
            atom_x = atom_gin(atom_x, batch.edge_index, edge_x, batch.batch)

        return self.pool(atom_x, batch.batch)
    

class Classifier(nn.Module):
    """
    MLP head mapping a graph embedding to ``out_dim`` outputs.

    Every hidden stage is Linear -> BatchNorm1d (optional) -> ReLU -> Dropout;
    a final Linear layer produces the outputs.

    Args:
        input_dim: size of the incoming embedding.
        hidden_dims: sizes of the hidden layers (at least one).
        batch_norm: insert BatchNorm1d after each hidden Linear when True.
        dropout: dropout probability used after every hidden stage.
        out_dim: number of outputs (default 2, matching the
            ``[Reorg_g, Reorg_ex]`` targets).

    Raises:
        ValueError: if ``hidden_dims`` is empty.
    """
    def __init__(
        self,
        input_dim: int = 32,
        hidden_dims: Sequence[int] = (256, 256),
        batch_norm: bool = True,
        dropout: float = 0.5,
        out_dim: int = 2,
    ):
        super().__init__()
        if len(hidden_dims) == 0:
            raise ValueError("hidden_dims must contain at least one layer size")

        def _hidden_stage(in_features, out_features):
            # One hidden stage; every stage honours the same `dropout` rate.
            return nn.Sequential(
                nn.Linear(in_features, out_features),
                nn.BatchNorm1d(out_features) if batch_norm else nn.Identity(),
                nn.ReLU(),
                nn.Dropout(dropout),
            )

        self.fc1 = _hidden_stage(input_dim, hidden_dims[0])
        self.fc_list = nn.ModuleList(
            [_hidden_stage(d_in, d_out) for d_in, d_out in zip(hidden_dims[:-1], hidden_dims[1:])]
        )
        self.fc_last = nn.Linear(hidden_dims[-1], out_dim)

    def forward(self, x):
        """Map (batch, input_dim) embeddings to (batch, out_dim) outputs."""
        x = self.fc1(x)
        for fc in self.fc_list:
            x = fc(x)
        return self.fc_last(x)
    
# Smoke test on the sample batch: GEM1 pools one vector per graph and the
# Classifier head maps each to 2 outputs (expected shape: [batch_size, 2]).
net = GEM1()
classifier = Classifier(hidden_dims=[128, 128, 128, 128])
classifier(net(batch)).shape

torch.Size([16, 2])