<a href="https://colab.research.google.com/github/Shubodh/learn_mol/blob/master/graphVAE_shub.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install wandb
!wandb login

In [None]:
!pip install git+https://github.com/zotko/xyz2graph.git plotly networkx torch_geometric rdkit

In [None]:
import torch
print(torch.__version__)
!pip install torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-{torch.__version__}.html

In [None]:
from xyz2graph import MolGraph, to_networkx_graph, to_plotly_figure
from plotly.offline import init_notebook_mode, iplot
import networkx as nx
import numpy as np
import torch
from torch_geometric.utils.convert import from_networkx
from torch_geometric.utils import negative_sampling
import re
from itertools import combinations
from math import sqrt
from torch_geometric.nn import DimeNet
from rdkit import Chem
import random
import wandb
from torch_geometric.data import Dataset, Data
from torch_geometric.loader import DataLoader
from torch import Tensor
import torch_geometric.transforms as T
from torch_geometric.nn import GCNConv, GATConv
from torch_geometric.utils import train_test_split_edges


In [None]:
# Prefer the GPU when CUDA is usable; otherwise everything runs on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
# Per-element radius table keyed by element symbol.
# MolGraph_mod.read_xyz looks every parsed element up here, and the bonding
# cut-off for a pair of atoms is (radius_i + radius_j) * 1.3.
# NOTE(review): values look like covalent radii in angstrom - confirm the source.
atomic_radii = dict(Ac=1.88, Ag=1.59, Al=1.35, Am=1.51, As=1.21, Au=1.50, B=0.83, Ba=1.34, Be=0.35, Bi=1.54, Br=1.21,
                    C=0.68, Ca=0.99, Cd=1.69, Ce=1.83, Cl=0.99, Co=1.33, Cr=1.35, Cs=1.67, Cu=1.52, D=0.23, Dy=1.75,
                    Er=1.73, Eu=1.99, F=0.64, Fe=1.34, Ga=1.22, Gd=1.79, Ge=1.17, H=0.23, Hf=1.57, Hg=1.70, Ho=1.74,
                    I=1.40, In=1.63, Ir=1.32, K=1.33, La=1.87, Li=0.68, Lu=1.72, Mg=1.10, Mn=1.35, Mo=1.47, N=0.68,
                    Na=0.97, Nb=1.48, Nd=1.81, Ni=1.50, Np=1.55, O=0.68, Os=1.37, P=1.05, Pa=1.61, Pb=1.54, Pd=1.50,
                    Pm=1.80, Po=1.68, Pr=1.82, Pt=1.50, Pu=1.53, Ra=1.90, Rb=1.47, Re=1.35, Rh=1.45, Ru=1.40, S=1.02,
                    Sb=1.46, Sc=1.44, Se=1.22, Si=1.20, Sm=1.80, Sn=1.46, Sr=1.12, Ta=1.43, Tb=1.76, Tc=1.35, Te=1.47,
                    Th=1.79, Ti=1.47, Tl=1.55, Tm=1.72, U=1.58, V=1.33, W=1.37, Y=1.78, Yb=1.94, Zn=1.45, Zr=1.56)


class MolGraph_mod:
    """Molecular graph built from XYZ-style coordinate text.

    Atoms are numbered in the order they are parsed. Bonds are stored as an
    undirected adjacency list; every bond also gets a normalised length in
    ``bond_lengths`` keyed by ``frozenset({i, j})``.

    Bonds come from one of two sources:

    * no bond orders supplied: two atoms are bonded when their distance lies
      strictly between 0.1 and ``(radius_i + radius_j) * 1.3``;
    * bond orders supplied: exactly the atom pairs that are keys of
      ``bond_orders`` are bonded, whatever their distance.
    """
    __slots__ = ['elements', 'x', 'y', 'z', 'adj_list',
                 'atomic_radii', 'bond_lengths', 'bond_orders']

    # "<symbol> <x> <y> <z>" with plain decimal coordinates; compiled once
    # instead of on every read_xyz call.
    _ATOM_PATTERN = re.compile(r'([A-Za-z]{1,3})\s*(-?\d+(?:\.\d+)?)\s*(-?\d+(?:\.\d+)?)\s*(-?\d+(?:\.\d+)?)')

    def __init__(self):
        self.elements = []      # element symbol per atom
        self.x = []             # cartesian coordinates, one entry per atom
        self.y = []
        self.z = []
        self.adj_list = {}      # atom index -> set of bonded atom indices
        self.atomic_radii = []  # radius per atom, looked up in the module table
        self.bond_lengths = {}  # frozenset({i, j}) -> normalised bond length
        self.bond_orders = None  # optional {frozenset({i, j}): bond order}

    def read_xyz(self, molxyz, bo=None):
        """Parse atom lines and build the adjacency list.

        Args:
            molxyz: XYZ-style text. It is passed through ``str()`` before the
                regex scan, so a single string or a list of lines both work.
            bo: optional mapping ``frozenset({i, j}) -> bond order`` with
                0-based atom indices. When given, it decides which pairs are
                bonded.

        Raises:
            KeyError: if a parsed element symbol is missing from the
                module-level ``atomic_radii`` table.
            ValueError: if ``bo`` is an empty mapping and more than one atom
                was parsed.

        Note:
            Parsed atoms are appended to the existing lists, so calling this
            twice on one instance accumulates atoms.
        """
        for element, x, y, z in self._ATOM_PATTERN.findall(str(molxyz)):
            self.elements.append(element)
            self.x.append(float(x))
            self.y.append(float(y))
            self.z.append(float(z))
        self.atomic_radii = [atomic_radii[element] for element in self.elements]
        if bo is not None:
            self.bond_orders = bo
        self._generate_adjacency_list()

    def _generate_adjacency_list(self):
        """Fill ``adj_list`` and ``bond_lengths`` from coordinates and bond orders.

        The stored length is ``(distance - 0.1) / (cutoff - 0.1)`` rounded to
        five decimals, where ``cutoff = (radius_i + radius_j) * 1.3``.

        Raises:
            ValueError: if bond orders were supplied but are empty while the
                structure has more than one atom.
        """
        bond_orders = self.bond_orders
        n_atoms = len(self.elements)
        # Validate once up front with a real exception; an assert inside the
        # pair loop is re-evaluated for every pair and disappears under -O.
        if bond_orders is not None and n_atoms > 1 and not bond_orders:
            raise ValueError('bond_orders is empty: no bond can be assigned')

        for i, j in combinations(range(n_atoms), 2):
            pair = frozenset([i, j])
            # With explicit bond orders only listed pairs matter, so skip the
            # distance computation for every other pair.
            if bond_orders is not None and pair not in bond_orders:
                continue
            x_i, y_i, z_i = self[i][1]
            x_j, y_j, z_j = self[j][1]
            distance = sqrt((x_i - x_j) ** 2 + (y_i - y_j) ** 2 + (z_i - z_j) ** 2)
            dist_limit = (self.atomic_radii[i] + self.atomic_radii[j]) * 1.3
            # Without bond orders, bonding is decided purely by distance.
            if bond_orders is None and not 0.1 < distance < dist_limit:
                continue
            self.adj_list.setdefault(i, set()).add(j)
            self.adj_list.setdefault(j, set()).add(i)
            self.bond_lengths[pair] = round(((distance - 0.1) / (dist_limit - 0.1)), 5)

    def edges(self):
        """Yield every bond once as a ``(node, neighbour)`` tuple."""
        seen = set()
        for node, neighbours in self.adj_list.items():
            for neighbour in neighbours:
                edge = frozenset([node, neighbour])
                if edge in seen:
                    continue
                seen.add(edge)
                yield node, neighbour

    def __len__(self):
        """Return the number of parsed atoms."""
        return len(self.elements)

    def __getitem__(self, position):
        """Return ``(element, (x, y, z))`` for the atom at ``position``."""
        return self.elements[position], (
            self.x[position], self.y[position], self.z[position])

In [None]:

# Load the three tmQM_X text files. Context managers make sure each file
# handle is closed as soon as its content has been read.
with open('/content/drive/MyDrive/generating_chelating_agents/data/tmQM_X.xyz', "r") as xyz_file:
    # Geometry file, kept as a list of lines.
    data = xyz_file.read().splitlines()
with open('/content/drive/MyDrive/generating_chelating_agents/data/tmQM_X.q', "r") as charge_file:
    # Charge file, kept as a list of lines.
    charges = charge_file.read().splitlines()
with open('/content/drive/MyDrive/generating_chelating_agents/data/tmQM_X.BO', "r") as bo_file:
    # One chunk per 'CSD_code = ' marker; chunk 0 is whatever precedes the
    # first marker.
    BO = bo_file.read().split('CSD_code = ')
# Drop the text before the first marker, split each chunk into lines and
# discard the last line of every chunk.
BO = [i.splitlines()[:-1] for i in BO[1:]]

In [None]:
# For every molecule chunk in BO build {frozenset({atom_a, atom_b}): bond order}
# with 0-based atom indices, keyed in `bond_orders` by the chunk's first line
# (the CSD code). `csd_codes` keeps the codes in file order.
bond_orders = {}
csd_codes = []
for record in BO:
    csd_codes.append(record[0])
    pair_orders = {}
    for atom_line in record[1:]:
        fields = atom_line.split()
        # Fields 0 and 1 are the 1-based atom index and its symbol; field 2 is
        # skipped; after that come (symbol, index, bond order) triples.
        center_idx, center_symbol = int(fields[0]) - 1, fields[1]
        for pos in range(3, len(fields) - 1, 3):
            partner_symbol, partner_idx, order = (
                fields[pos], int(fields[pos + 1]) - 1, float(fields[pos + 2]))
            pair_orders[frozenset([partner_idx, center_idx])] = order
    bond_orders[csd_codes[-1]] = pair_orders

In [None]:
len(bond_orders)

86665

In [None]:

# RDKit periodic table; to_networkx_graph uses it to turn element symbols
# into atomic numbers.
PT = Chem.GetPeriodicTable()
# Switch plotly's offline notebook mode on (connected=True).
init_notebook_mode(connected=True)

def to_networkx_graph(graph: MolGraph_mod) -> nx.Graph:
    """Convert a ``MolGraph_mod`` into an undirected NetworkX graph.

    The graph is built from ``graph.adj_list``, so only atoms listed there
    become nodes.

    Node attributes:
        ``'x'``: ``[atomic number, x, y, z]``
        ``'xyz'``: the coordinate tuple
    Edge attributes:
        ``'x'``: ``[bond order, bond length]`` with the length taken from
        ``graph.bond_lengths``

    ``graph.bond_orders`` must hold an entry for every key of
    ``graph.bond_lengths``.
    """
    G = nx.Graph(graph.adj_list)

    node_attrs = {}
    for num, (element, xyz) in enumerate(graph):
        features = [PT.GetAtomicNumber(element), xyz[0], xyz[1], xyz[2]]
        node_attrs[num] = {'x': features, 'xyz': xyz}
    nx.set_node_attributes(G, node_attrs)

    edge_attrs = {}
    for edge, length in graph.bond_lengths.items():
        edge_attrs[edge] = {'x': [graph.bond_orders[edge], length]}
    nx.set_edge_attributes(G, edge_attrs)

    return G


# Build one graph for every molecule record in `data` that has parsed bond
# orders and fewer than 30 atoms.
#   graphs    - directed NetworkX graph per kept molecule
#   nodes     - list of element symbols per kept molecule
#   data_list - torch_geometric Data object per kept molecule (tensors on `device`)
graphs = []
nodes = []
data_list = []
csd_codes_mol = []

# A record is: atom-count line, comment line ("CSD_code = <code> ..."), then
# one line per atom. Records start right after a blank separator line.
record_starts = [ndx + 1 for ndx, line in enumerate(data[:-1]) if line == '']
# A record at the very top of the file has no blank line in front of it and
# would otherwise never be visited.
if data and data[0].strip().isdigit():
    record_starts.insert(0, 0)

for start in record_starts:
    total_atoms_in_mol = int(data[start])
    csd_code = data[start + 1].split()[2]
    # `bond_orders` has exactly the codes collected in `csd_codes`; the dict
    # lookup is O(1) whereas scanning the list is O(n) per record.
    if csd_code not in bond_orders or total_atoms_in_mol >= 30:
        continue

    # Atom lines begin two lines after the count line. Slicing from the count
    # line instead would pull in the two header lines and drop the last two atoms.
    first_atom = start + 2
    mol_xyz = data[first_atom:first_atom + total_atoms_in_mol]

    mol = MolGraph_mod()
    # Read the data from the xyz coordinate block
    mol.read_xyz(mol_xyz, bond_orders[csd_code])
    nodes.append(mol.elements)

    G = to_networkx_graph(mol)
    p = from_networkx(G)
    # Rebuild node and edge features from the directed graph so the edge
    # attribute rows cover both directions of every bond.
    G = G.to_directed()
    graphs.append(G)
    p.x = Tensor([G.nodes[i]['x'] for i in G.nodes]).to(device)
    p.edge_attr = Tensor([G.edges[i]['x'] for i in G.edges]).to(device)
    p.edge_index = p.edge_index.to(device)
    data_list.append(p)

In [None]:
print(len(data_list))
print(len(bond_orders))

4302
86665


In [None]:
from torch_geometric.nn import VGAE

class VariationalGCNEncoder(torch.nn.Module):
    """GAT-based encoder producing the mean and log-std of the node latents.

    Args:
        in_channels: size of the input node features.
        out_channels: size of the latent space; hidden layers use twice that.
        edge_dim: size of the edge feature vectors passed to every GATConv.
        heads: number of attention heads per layer.
        num_layers: number of hidden GATConv layers (at least one is built).
    """

    def __init__(self, in_channels, out_channels, edge_dim=1, heads=1,num_layers=1):
        super().__init__()
        hidden = 2 * out_channels
        # Hidden layers concatenate their heads, so every following layer
        # receives hidden * heads features.
        layers = [GATConv(in_channels, hidden, edge_dim=edge_dim, heads=heads)]
        # Honour num_layers: one extra layer per requested layer beyond the first.
        for _ in range(num_layers - 1):
            layers.append(GATConv(hidden * heads, hidden, edge_dim=edge_dim, heads=heads))
        self.conv1 = torch.nn.ModuleList(layers)

        # Output heads average over attention heads (concat=False) so the
        # latent size stays out_channels for any number of heads.
        self.conv_mu = GATConv(hidden * heads, out_channels, edge_dim=edge_dim,
                               heads=heads, concat=False)
        self.conv_logstd = GATConv(hidden * heads, out_channels, edge_dim=edge_dim,
                                   heads=heads, concat=False)

    def forward(self, x, edge_index, edge_weights):
        """Return ``(mu, logstd)`` for the given node features and edges.

        ``edge_weights`` is forwarded as ``edge_attr`` to every layer,
        including both output heads.
        """
        for conv in self.conv1:
            x = conv(x, edge_index, edge_attr=edge_weights).relu()
        mu = self.conv_mu(x, edge_index, edge_attr=edge_weights)
        # The log-std head must see the same edge features as the mean head.
        logstd = self.conv_logstd(x, edge_index, edge_attr=edge_weights)
        return mu, logstd

# class InnerProductDecoder(torch.nn.Module):

#     def __init__(self, )

#     def forward(self, z, edge_index, sigmoid=True):
#         r"""Decodes the latent variables :obj:`z` into edge probabilities for
#         the given node-pairs :obj:`edge_index`.

#         Args:
#             z (Tensor): The latent space :math:`\mathbf{Z}`.
#             sigmoid (bool, optional): If set to :obj:`False`, does not apply
#                 the logistic sigmoid function to the output.
#                 (default: :obj:`True`)
#         """
#         value = (z[edge_index[0]] * z[edge_index[1]]).sum(dim=1)
#         return torch.sigmoid(value) if sigmoid else value

In [None]:
def train(epoch, loader, beta=0.2, train=True):
    """Run one optimisation pass over ``loader`` and log the mean losses.

    Uses the module-level ``model``, ``optimizer`` and ``wandb`` run.

    Args:
        epoch: epoch number, logged alongside the metrics.
        loader: iterable of batches exposing ``x``, ``edge_index`` and
            ``edge_attr``.
        beta: weight of the KL term in the optimised loss.
        train: when False no batch is processed and nothing is logged.

    Returns:
        Mean over batches of (reconstruction loss + unweighted KL loss), or
        0.0 when no batch was processed.
    """
    model.train()
    running_loss_kl = 0
    running_loss = 0
    n = 0
    if train:
        for batch in loader:
            n += 1
            optimizer.zero_grad()
            z = model.encode(batch.x, batch.edge_index, batch.edge_attr)
            loss = model.recon_loss(z, batch.edge_index)
            kl = model.kl_loss()
            running_loss += loss.item()
            running_loss_kl += kl.item()
            loss = loss + beta * kl
            loss.backward()
            optimizer.step()

    # train=False or an empty loader leaves n at 0; averaging would divide by zero.
    if n == 0:
        return 0.0

    wandb.log({"epoch": epoch, 'loss_kl/train': running_loss_kl/n, 'loss_recon/train': running_loss/n})
    return float((running_loss+running_loss_kl)/n)

def test(epoch, loader):
    """Evaluate the module-level ``model`` on ``loader`` and log mean metrics.

    Args:
        epoch: epoch number, logged alongside the metrics.
        loader: iterable of batches exposing ``x``, ``edge_index`` and
            ``edge_attr``.

    Returns:
        Mean over batches of (reconstruction loss + KL loss), or 0.0 when the
        loader yields no batch.
    """
    model.eval()
    running_loss = 0
    running_loss_kl = 0
    running_auc = 0
    running_ap = 0
    n = 0
    with torch.no_grad():
        for batch in loader:
            n += 1
            z = model.encode(batch.x, batch.edge_index, batch.edge_attr)
            loss = model.recon_loss(z, batch.edge_index)
            kl = model.kl_loss()
            running_loss += loss.item()
            running_loss_kl += kl.item()
            # Sample over all encoded nodes rather than inferring the node
            # count from the largest index in edge_index.
            neg_edges = negative_sampling(batch.edge_index, num_nodes=z.size(0))
            auc, ap = model.test(z, batch.edge_index, neg_edges)
            # float() accepts both numpy scalars and plain Python floats.
            running_auc += float(auc)
            running_ap += float(ap)

    # Nothing evaluated: skip logging instead of dividing by zero.
    if n == 0:
        return 0.0

    wandb.log({"epoch": epoch, 'loss_kl/val': running_loss_kl/n, 'loss_recon/val': running_loss/n, 'auc/val': running_auc/n, 'ap/val': running_ap/n})
    return float((running_loss+running_loss_kl)/n)

In [None]:
# --- hyper-parameters ---
split = [0.8, 0.2]   # train / test fractions
batch_size = 32
lr = 0.01
beta=0.3             # weight of the KL term during training
epochs = 300
num_layers = 3
out_channels = 2     # latent size per node
num_features = 4     # node features: [atomic number, x, y, z]
edge_dim = 2         # edge features: [bond order, bond length]
heads = 1

# --- train / test split (seeded shuffle, then a single cut) ---
N = len(data_list)
N_train = int(N * split[0])
random.seed(42)
random.shuffle(data_list)
train_data, test_data = data_list[:N_train], data_list[N_train:]
train_loader = DataLoader(train_data, batch_size=batch_size)
test_loader = DataLoader(test_data, batch_size=batch_size)

# --- model, moved to the GPU when one is available ---
model = VGAE(
    VariationalGCNEncoder(
        num_features,
        out_channels,
        num_layers=num_layers,
        edge_dim=edge_dim,
        heads=heads,
    )
)
model = model.to(device)

# --- optimizer ---
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

In [None]:
# Start the Weights & Biases run and record the hyper-parameters.
wandb.init(project="graphVAE", entity="shivanshseth", config={
    "beta": beta,
    "num_layers": num_layers,
    "latent_channels": out_channels,
    "learning_rate": lr,
    "epochs": epochs,
    "batch_size": batch_size
})
# Exactly the keys that train() and test() pass to wandb.log, so every logged
# curve is plotted against 'epoch'.
metrics = [
    "loss_kl/train",
    "loss_recon/train",
    "loss_kl/val",
    "loss_recon/val",
    "auc/val",
    "ap/val",
]
for metric_name in metrics:
    wandb.define_metric(name=metric_name, step_metric='epoch')


# One training pass followed by one evaluation pass per epoch.
for epoch in range(1, epochs + 1):
    loss = train(epoch, train_loader, beta)
    test_loss = test(epoch, test_loader)

VBox(children=(Label(value='0.000 MB of 0.009 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.044827…

0,1
ap/val,█▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
auc/val,█▁▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
loss_kl/train,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss_kl/val,▇▆▇▃▂▆▅▅▃▃▄▆█▂▂▃█▂▇▇▁▆▄▃▄▇▃▅▃▅▆▇▇▄▄▄▇▁▅▅
loss_recon/train,█▁▁▁▁▁▁▁▁▂▁▂▁▁▂▁▁▁▁▂▁▁▁▂▂▁▁▁▂▁▁▂▁▁▁▂▁▁▁▂
loss_recon/val,▁▇██████████████████████████████████████

0,1
ap/val,0.50049
auc/val,0.49979
epoch,111.0
loss_kl/train,0.20341
loss_kl/val,0.20113
loss_recon/train,1.49679
loss_recon/val,1.38629
