In [11]:
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib as mpl
import numpy as np
import warnings
import pandas as pd
import rdkit, rdkit.Chem, rdkit.Chem.rdDepictor, rdkit.Chem.Draw
import networkx as nx
import torch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
import torch.utils.data as data
import torch.optim as optim
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import dense_to_sparse, add_self_loops, to_scipy_sparse_matrix
from torch_geometric.data import Data
import torch.nn.functional as F
from six.moves import urllib
import deepchem as dc
import random
from dgl.nn import Set2Set

# Reproducibility: seed every random number generator used by the notebook and
# ask torch to use deterministic algorithms.
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)
torch.use_deterministic_algorithms(True)

# Silence all warnings for the rest of the session.
warnings.filterwarnings("ignore")

# Plot appearance: seaborn "notebook" context on a dark style with grey ticks
# and axes edges.
# NOTE(review): seaborn's set_style appears to apply only style-related keys
# from this dict, so "figure.dpi" may be silently ignored here -- confirm.
sns.set_context("notebook")
sns.set_style(
    "dark",
    rc={
        "xtick.bottom": True,
        "ytick.left": True,
        "xtick.color": "#666666",
        "ytick.color": "#666666",
        "axes.edgecolor": "#666666",
        "axes.linewidth": 0.8,
        "figure.dpi": 300,
    },
)

# Custom five-colour palette used as matplotlib's default colour cycle.
color_cycle = ["#1BBC9B", "#F06060", "#5C4B51", "#F3B562", "#6e5687"]
mpl.rcParams["axes.prop_cycle"] = mpl.cycler(color=color_cycle)

# Install a global urllib opener that sends a browser-like User-agent header
# with every request (presumably because some hosts reject the default one).
opener = urllib.request.build_opener()
opener.addheaders = [("User-agent", "Mozilla/5.0")]
urllib.request.install_opener(opener)

# had to rehost because dataverse isn't reliable
# Load the curated solubility dataset. A local copy is preferred when it
# exists; on any machine where that absolute path is missing we fall back to
# the rehosted CSV so the notebook still survives Restart & Run All.
SOLUBILITY_CSV_LOCAL = "/Users/adityabehal/Downloads/curated-solubility-dataset.csv"
SOLUBILITY_CSV_URL = (
    "https://github.com/whitead/dmol-book/raw/main/data/curated-solubility-dataset.csv"
)

try:
    soldata = pd.read_csv(SOLUBILITY_CSV_LOCAL)
except FileNotFoundError:
    # No local copy on this machine: download the rehosted file instead.
    soldata = pd.read_csv(SOLUBILITY_CSV_URL)

In [12]:
def gen_smiles2graph(smiles):
    """Featurize a SMILES string into a DeepChem molecular graph.

    smiles: SMILES string passed to ``MolGraphConvFeaturizer.featurize``
        (the featurizer is created with ``use_edges=True`` so edge features
        are produced as well)
    returns: the first element of the featurizer's output
    """
    featurized = dc.feat.MolGraphConvFeaturizer(use_edges=True).featurize(smiles)
    return featurized[0]

In [13]:
# Featurize the SMILES string "CO" and inspect the resulting graph: first the
# object and its type, then each array attribute followed by its type.
testCO = gen_smiles2graph("CO")
print(testCO)
print(type(testCO))
for attr_name in ("node_features", "edge_index", "edge_features"):
    attr_value = getattr(testCO, attr_name)
    print(attr_value)
    print(type(attr_value))

GraphData(node_features=[2, 30], edge_index=[2, 2], edge_features=[2, 11], pos=[0])
<class 'deepchem.feat.graph_data.GraphData'>
[[1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0.
  0. 0. 0. 1. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 0. 0. 0. 1. 0. 0. 0. 0.
  0. 1. 0. 0. 0. 0.]]
<class 'numpy.ndarray'>
[[0 1]
 [1 0]]
<class 'numpy.ndarray'>
[[1. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
 [1. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]]
<class 'numpy.ndarray'>


In [14]:
# Report the size of the graph as exposed by num_nodes / num_edges.
for label, count in (
    ("number of atoms: ", testCO.num_nodes),
    ("number of edges: ", testCO.num_edges),
):
    print(label, count)

number of atoms:  2
number of edges:  2


In [15]:
# Exact molecular weight of the molecule written as the SMILES string "CO".
# NOTE(review): rdkit.Chem.Descriptors is not imported explicitly in the first
# cell; this relies on some other import having loaded it -- consider adding
# `import rdkit.Chem.Descriptors` to the import cell.
mol_CO = rdkit.Chem.MolFromSmiles("CO")
mw_CO = rdkit.Chem.Descriptors.ExactMolWt(mol_CO)
print(mw_CO)

32.026214748


In [16]:
# tutorial on smiles strings: https://chemicbook.com/2021/02/13/smiles-strings-explained-for-beginners-part-1.html
# Parse the bracketed SMILES "[Cl-]" and read back the molecule's total
# formal charge.
chloride_anion = rdkit.Chem.MolFromSmiles("[Cl-]")
formal_charge_Cl_anion = rdkit.Chem.rdmolops.GetFormalCharge(chloride_anion)
print(formal_charge_Cl_anion)

-1


In [38]:
def make_global_features(num_nodes, num_edges, mol_weight, formal_charge):
    """Build the (1, 6) molecule-level feature row.

    Layout of the returned row:
        [0] number of nodes (atoms)
        [1] number of edges
        [2] molecular weight truncated to an integer with ``int()``
        [3] 1 if the formal charge is negative, else 0
        [4] 1 if the formal charge is zero, else 0
        [5] 1 if the formal charge is positive, else 0

    num_nodes: node count of the graph
    num_edges: edge count of the graph
    mol_weight: molecular weight of the molecule (any real number)
    formal_charge: total formal charge of the molecule
    returns: numpy float array of shape (1, 6)
    """
    features = np.zeros((1, 6))
    features[0][0] = num_nodes
    features[0][1] = num_edges
    features[0][2] = int(mol_weight)

    # One-hot encode the sign of the charge; the other two slots keep the
    # zeros they were initialised with.
    if formal_charge < 0:
        charge_slot = 3
    elif formal_charge > 0:
        charge_slot = 5
    else:
        charge_slot = 4
    features[0][charge_slot] = 1
    return features


graphInstance = testCO

# Parse the molecule once and reuse it for both descriptors, so the SMILES
# string only has to be changed in one place.
mol_for_global_features = rdkit.Chem.MolFromSmiles("CO")
if mol_for_global_features is None:
    raise ValueError("RDKit could not parse the SMILES string")

formal_charge = rdkit.Chem.rdmolops.GetFormalCharge(mol_for_global_features)

global_features = make_global_features(
    graphInstance.num_nodes,
    graphInstance.num_edges,
    rdkit.Chem.Descriptors.ExactMolWt(mol_for_global_features),
    formal_charge,
)

In [39]:
# Rebuild the graph with the same node/edge arrays and attach the global
# feature row through the extra `z` keyword argument.
graphInstanceWithGlobalFeatures = dc.feat.graph_data.GraphData(
    **{
        field: getattr(graphInstance, field)
        for field in ("node_features", "edge_index", "edge_features")
    },
    z=global_features,
)

In [40]:
# Node features were passed through from graphInstance when this graph was built.
graphInstanceWithGlobalFeatures.node_features

array([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
        0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0.],
       [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1.,
        0., 0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.]])

In [41]:
# Edge features were passed through from graphInstance when this graph was built.
graphInstanceWithGlobalFeatures.edge_features

array([[1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.]])

In [42]:
# The global feature row supplied through the `z` keyword is exposed as an attribute.
graphInstanceWithGlobalFeatures.z

array([[ 2.,  2., 32.,  0.,  1.,  0.]])

In [43]:
# Shape check: global_features was created as a (1, 6) array.
graphInstanceWithGlobalFeatures.z.shape

(1, 6)

In [44]:
# Alternative to rebuilding the graph: attach the global features directly as
# an attribute on the existing object.
# NOTE(review): graphInstance is the same object as testCO, so this mutates
# testCO in place; presumably an attribute set this way is not tracked like a
# constructor keyword argument -- confirm before relying on it downstream.
graphInstance.z = global_features

In [45]:
# The repr printed here lists node_features, edge_index, edge_features and pos
# only; the `z` attribute assigned to this object after construction is not shown.
print(graphInstance)

GraphData(node_features=[2, 30], edge_index=[2, 2], edge_features=[2, 11], pos=[0])


In [47]:
# The directly assigned `z` attribute is still readable on the object.
print(graphInstance.z.shape)

(1, 6)


In [48]:
import torch
torch.cat?

In [49]:
# Interactive docstring lookup (IPython `?` syntax); a development aid that is
# not valid plain Python -- consider removing before sharing the notebook.
torch.sum?

In [52]:
# Small demonstration of reducing along dim=0: each entry of the result is the
# sum of one column of the 3x3 matrix.
x = torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

print(x.sum(dim=0))

tensor([12., 15., 18.])
