In [None]:
# Environment bootstrap: pin torch/torchvision, report the interpreter/GPU
# configuration, then install the PyTorch Geometric extension packages from the
# wheel index that corresponds to the detected torch version.
import torch
# NOTE(review): this is an exact string comparison; if torch.__version__ carries a
# local build suffix (e.g. "+cu101") the branch is taken even for a 1.6.0 build —
# confirm against the target environment.
if torch.__version__ != '1.6.0':
  !pip uninstall torch -y
  !pip uninstall torchvision -y
  !pip install torch==1.6.0
  !pip install torchvision==0.7.0

# Check pytorch version and make sure you use a GPU kernel.
# These commands run in separate processes, so they report what is installed on
# disk rather than what this kernel has already imported.
!python -c "import torch; print(torch.__version__)"
!python -c "import torch; print(torch.version.cuda)"
!python --version
!nvidia-smi
# NOTE(review): torch was already imported at the top of this cell, so this second
# import presumably returns the module loaded before any reinstall; a kernel
# restart may be needed for pytorch_version to reflect the pinned build — TODO confirm.
import torch
# Name of the wheel index page matching the torch version, e.g. "torch-<version>.html".
pytorch_version = f"torch-{torch.__version__}.html"
# --no-index together with -f restricts pip to the wheels listed on that page.
!pip install --no-index torch-scatter -f https://pytorch-geometric.com/whl/$pytorch_version
!pip install --no-index torch-sparse -f https://pytorch-geometric.com/whl/$pytorch_version
!pip install --no-index torch-cluster -f https://pytorch-geometric.com/whl/$pytorch_version
!pip install --no-index torch-spline-conv -f https://pytorch-geometric.com/whl/$pytorch_version
# NOTE(review): torch-geometric is not version-pinned, while later cells import
# torch_geometric.loader — verify that the resolved release supports the torch pin above.
!pip install torch-geometric

In [None]:
pip install rdkit

In [None]:
import rdkit
from rdkit import Chem
from rdkit.Chem import Draw
from rdkit.Chem.Draw import IPythonConsole
from rdkit.Chem import Descriptors
from rdkit.Chem import AllChem
from rdkit import DataStructs
import numpy as np
import rdkit
from rdkit import Chem
from rdkit import rdBase
from rdkit.Chem.rdchem import HybridizationType
from rdkit import RDConfig
from rdkit.Chem import ChemicalFeatures
from rdkit.Chem.rdchem import BondType as BT
import torch.nn.functional as F
from torch_sparse import coalesce
import pandas as pd
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.data import (InMemoryDataset, download_url, extract_zip,
                                  Data)
import os
import os.path as osp
from torch_geometric.data import Dataset
from torch_sparse import coalesce
import torch.utils.data 
from torch.utils.data.dataloader import default_collate

In [None]:
class GNN1(InMemoryDataset):
    """In-memory PyTorch Geometric dataset built from a CSV of SMILES strings.

    Each row of the CSV (columns ``smiles`` and ``energy``) becomes one graph:

    * ``x``          - per-atom features: one-hot element type, 7 binary flags
                       (aromatic, in-ring, sp, sp2, sp3, sp3d, sp3d2), one-hot
                       neighbour count (7 classes) and one-hot hydrogen count
                       (5 classes).
    * ``edge_index`` - both directions of every bond.
    * ``edge_attr``  - one-hot bond type, 2 binary flags (conjugated, in-ring)
                       and one-hot bond stereo (6 classes).
    * ``pos``        - atom coordinates as written in the intermediate SDF mol block.
    * ``y``          - the ``energy`` value of the row the molecule came from.

    Args:
        root: Dataset root directory (processed file is stored beneath it).
        transform: Optional transform applied on every access.
        pre_transform: Optional transform applied once before saving.
        pre_filter: Optional predicate; graphs for which it returns False are dropped.
    """

    # Input table, resolved against the current working directory.
    _SOURCE_CSV = 'JSONdata (1).csv'
    # Scratch file used for the SMILES -> SDF -> molecule round trip.
    _SCRATCH_SDF = 'temporary.sdf'

    def __init__(self, root, transform=None, pre_transform=None,
                 pre_filter=None):
        super(GNN1, self).__init__(root, transform, pre_transform, pre_filter)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        # rdkit is imported unconditionally by this notebook, so the former
        # "rdkit is None" branch could never be taken.
        return 'raw.csv'

    @property
    def processed_file_names(self):
        return 'data.pt'

    def download(self):
        # Nothing to download: process() reads the CSV from the working directory.
        pass

    def process(self):
        """Convert every CSV row into a ``Data`` graph and save the collated result."""
        import warnings

        frame = pd.read_csv(self._SOURCE_CSV)
        smiles_values = frame['smiles'].values
        target = torch.tensor(frame['energy'].values, dtype=torch.float)

        # Parse each SMILES once and remember which CSV row it came from, so the
        # labels stay aligned even when a row cannot be parsed.
        mols, kept_rows = [], []
        for row_idx, smiles in enumerate(smiles_values):
            mol = Chem.MolFromSmiles(smiles)
            # mol = Chem.rdmolops.AddHs(mol)   # explicit trivial Hs (excluded)
            if mol is None:
                warnings.warn(f"Skipping row {row_idx}: could not parse SMILES {smiles!r}")
                continue
            mols.append(mol)
            kept_rows.append(row_idx)

        # Build the element / bond-type vocabularies from ALL atoms and ALL bonds.
        # (Zipping atoms with bonds would stop at the shorter of the two sequences
        # and silently miss symbols or bond types.)
        atom_symbols, bond_types = [], []
        for mol in mols:
            for atom in mol.GetAtoms():
                if atom.GetSymbol() not in atom_symbols:
                    atom_symbols.append(atom.GetSymbol())
            for bond in mol.GetBonds():
                if bond.GetBondType() not in bond_types:
                    bond_types.append(bond.GetBondType())
        dict_types = {symbol: k for k, symbol in enumerate(atom_symbols)}
        dict_bonds = {bond_type: k for k, bond_type in enumerate(bond_types)}
        # Reserve a slot for hydrogen without clobbering an index it already has.
        dict_types.setdefault('H', len(dict_types))

        # Round-trip through an SDF file so that mol-block coordinates are available.
        writer = Chem.SDWriter(self._SCRATCH_SDF)
        try:
            for mol in mols:
                writer.write(mol)
        finally:
            writer.close()  # flush to disk before the supplier reads the file

        suppl = Chem.SDMolSupplier(self._SCRATCH_SDF, removeHs=False)

        data_list = []
        for sdf_idx, mol in enumerate(suppl):
            if mol is None:
                warnings.warn(f"Skipping SDF entry {sdf_idx}: could not be read back")
                continue

            N = mol.GetNumAtoms()

            # Coordinates: in the mol block the atom lines start at line index 4.
            text = suppl.GetItemText(sdf_idx)
            pos = text.split('\n')[4:4 + N]
            pos = [[float(value) for value in line.split()[:3]] for line in pos]
            pos = torch.tensor(pos, dtype=torch.float)

            # ---- node features -------------------------------------------------
            type_idx, atom_flags, num_hs, num_neighbors = [], [], [], []
            for atom in mol.GetAtoms():
                type_idx.append(dict_types[atom.GetSymbol()])
                hybridization = atom.GetHybridization()
                atom_flags.append([
                    float(atom.GetIsAromatic()),
                    float(atom.IsInRing()),
                    float(hybridization == HybridizationType.SP),
                    float(hybridization == HybridizationType.SP2),
                    float(hybridization == HybridizationType.SP3),
                    float(hybridization == HybridizationType.SP3D),
                    float(hybridization == HybridizationType.SP3D2),
                ])
                num_hs.append(atom.GetTotalNumHs(includeNeighbors=True))
                num_neighbors.append(len(atom.GetNeighbors()))
            x1 = F.one_hot(torch.tensor(type_idx, dtype=torch.long),
                           num_classes=len(dict_types))
            x2 = torch.tensor(atom_flags, dtype=torch.float).reshape(-1, 7)
            x3 = F.one_hot(torch.tensor(num_neighbors, dtype=torch.long), num_classes=7)
            x4 = F.one_hot(torch.tensor(num_hs, dtype=torch.long), num_classes=5)
            x = torch.cat([x1.to(torch.float), x2, x3.to(torch.float),
                           x4.to(torch.float)], dim=-1)

            # ---- edge features (each bond is stored in both directions) ----------
            row, col, bond_idx, bond_flags, stereo = [], [], [], [], []
            for bond in mol.GetBonds():
                start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
                row += [start, end]
                col += [end, start]
                bond_idx += 2 * [dict_bonds[bond.GetBondType()]]
                bond_flags += 2 * [[float(bond.GetIsConjugated()),
                                    float(bond.IsInRing())]]
                stereo += 2 * [int(bond.GetStereo())]
            edge_index = torch.tensor([row, col], dtype=torch.long)
            e1 = F.one_hot(torch.tensor(bond_idx, dtype=torch.long),
                           num_classes=len(dict_bonds)).to(torch.float)
            e2 = torch.tensor(bond_flags, dtype=torch.float).reshape(-1, 2)
            e3 = F.one_hot(torch.tensor(stereo, dtype=torch.long),
                           num_classes=6).to(torch.float)
            edge_attr = torch.cat([e1, e2, e3], dim=-1)
            edge_index, edge_attr = coalesce(edge_index, edge_attr, N, N)

            # Label of the CSV row this molecule was generated from.
            y = target[kept_rows[sdf_idx]].unsqueeze(0)

            graph = Data(x=x, pos=pos, edge_index=edge_index,
                         edge_attr=edge_attr, y=y)

            # Honour the optional hooks accepted by __init__.
            if self.pre_filter is not None and not self.pre_filter(graph):
                continue
            if self.pre_transform is not None:
                graph = self.pre_transform(graph)

            data_list.append(graph)

        torch.save(self.collate(data_list), self.processed_paths[0])

In [None]:
# Build the dataset (or load the cached processed file) rooted at ./data/.
dataset = GNN1(root='data/')

In [None]:
from torch.nn import Linear
import torch.nn.functional as F 
from torch_geometric.nn import GCNConv, TopKPooling, global_mean_pool
from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp

In [None]:
# Width of every hidden graph-convolution layer.
embedding_size = 64

class GCN(torch.nn.Module):
    """Four-layer GCN that regresses one scalar per graph.

    Node features pass through four ``GCNConv`` layers with tanh activations; the
    node embeddings are then pooled per graph with global max and global mean
    pooling, concatenated, and fed to a single linear output layer.

    Args:
        num_node_features: Size of the input node feature vector. Defaults to
            ``dataset.num_features`` of the notebook-level ``dataset``.
        hidden_channels: Width of the hidden layers. Defaults to the
            notebook-level ``embedding_size``.
    """

    def __init__(self, num_node_features=None, hidden_channels=None):
        # Init parent
        super(GCN, self).__init__()
        # Seed before creating the layers so the initial weights are reproducible.
        torch.manual_seed(42)

        # Fall back to the notebook-level globals when no sizes are given.
        if num_node_features is None:
            num_node_features = dataset.num_features
        if hidden_channels is None:
            hidden_channels = embedding_size

        # GCN layers
        self.initial_conv = GCNConv(num_node_features, hidden_channels)
        self.conv1 = GCNConv(hidden_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, hidden_channels)

        # Output layer; input is twice the hidden width because max- and
        # mean-pooled embeddings are concatenated.
        self.out = Linear(hidden_channels * 2, 1)

    def forward(self, x, edge_index, batch_index):
        """Return ``(prediction, pooled_embedding)`` for a batch of graphs.

        Args:
            x: Node feature matrix.
            edge_index: Graph connectivity.
            batch_index: Vector assigning each node to its graph in the batch.
        """
        # torch.tanh replaces the deprecated torch.nn.functional.tanh.
        hidden = torch.tanh(self.initial_conv(x, edge_index))
        for conv in (self.conv1, self.conv2, self.conv3):
            hidden = torch.tanh(conv(hidden, edge_index))

        # Global pooling (stack different aggregations)
        hidden = torch.cat([gmp(hidden, batch_index),
                            gap(hidden, batch_index)], dim=1)

        # Final linear regression head.
        out = self.out(hidden)

        return out, hidden

# Instantiate the network, show its layer layout and report its total parameter count.
model = GCN()
print(model)
print("Number of parameters: ", sum(param.numel() for param in model.parameters()))

In [None]:
# Use the loader from torch_geometric.loader, consistent with the import cell;
# torch_geometric.data.DataLoader is the deprecated location.
from torch_geometric.loader import DataLoader
import warnings
# NOTE(review): a blanket "ignore" also hides shape/broadcast warnings raised by
# the loss function; consider narrowing this filter.
warnings.filterwarnings("ignore")

# Mean squared error (MSELoss does not take the square root).
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0007)

# Use GPU for training when one is available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Wrap data in data loaders: first 80% of the graphs for training, the rest for testing
data_size = len(dataset)
NUM_GRAPHS_PER_BATCH = 64
split_idx = int(data_size * 0.8)
loader = DataLoader(dataset[:split_idx],
                    batch_size=NUM_GRAPHS_PER_BATCH, shuffle=True)
# Evaluation data is not shuffled so that the inspected batch is reproducible.
test_loader = DataLoader(dataset[split_idx:],
                         batch_size=NUM_GRAPHS_PER_BATCH, shuffle=False)

def train(data):
    """Run one optimisation pass over the notebook-level ``loader``.

    Args:
        data: Unused; kept so existing ``train(dataset)`` calls keep working. The
            batches always come from the global ``loader``.

    Returns:
        Tuple ``(loss, embedding)`` holding the (detached) loss and the pooled
        graph embedding of the LAST mini-batch of the pass.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    last_loss, last_embedding = None, None
    # Enumerate over the data
    for batch in loader:
        # Use GPU
        batch = batch.to(device)
        # Reset gradients
        optimizer.zero_grad()
        # Passing the node features and the connection info
        pred, last_embedding = model(batch.x.float(), batch.edge_index, batch.batch)
        # Flatten both sides: pred is [B, 1] while batch.y is [B]; comparing them
        # directly broadcasts to [B, B] and optimises the wrong objective.
        last_loss = loss_fn(pred.reshape(-1), batch.y.reshape(-1))
        last_loss.backward()
        # Update using the gradients
        optimizer.step()
    if last_loss is None:
        raise ValueError("train(): the data loader yielded no batches")
    return last_loss.detach(), last_embedding

# Train for 2000 epochs, keeping the loss returned by every epoch and
# printing a progress line on every 100th one.
print("Starting training...")
losses = []
for epoch in range(2000):
    loss, h = train(dataset)
    losses.append(loss)
    if epoch % 100 != 0:
        continue
    print(f"Epoch {epoch} | Train Loss {loss}")
        
        

In [None]:
import seaborn as sns
# Convert the recorded per-epoch losses (0-d tensors) to plain Python floats.
losses_float = [float(loss) for loss in losses]
loss_indices = list(range(len(losses_float)))
# x and y must be passed by keyword: recent seaborn releases reject positional vectors.
loss_ax = sns.lineplot(x=loss_indices, y=losses_float)
loss_ax.set(xlabel="Epoch", ylabel="Training loss", title="Training loss per epoch")
# Alias retained for backward compatibility with cells that refer to `plt`;
# note that it is a matplotlib Axes, not the pyplot module.
plt = loss_ax
loss_ax

In [None]:
# Inspect a single mini-batch from the test loader: run the model without
# gradient tracking and tabulate the labels next to the predictions.
test_batch = next(iter(test_loader))
test_batch.to(device)
with torch.no_grad():
    pred, embed = model(test_batch.x.float(), test_batch.edge_index, test_batch.batch)
# y_pred holds one single-element list per graph, exactly as the model emits it.
df = pd.DataFrame({
    "y_real": test_batch.y.tolist(),
    "y_pred": pred.tolist(),
})
df