In [80]:
import numpy as np
import os
import pandas as pd
import datetime
import torch
import gc
import io
import json
import matplotlib.lines as mlines
import matplotlib.pyplot as plt
import networkx as nx
from rdkit.Chem import rdmolops
from rdkit.Chem.rdmolops import AddHs
from rdkit.Chem.rdmolops import GetMolFrags
from rdkit import Chem, RDLogger
from rdkit import DataStructs
from rdkit.Chem import AllChem, Draw, Descriptors
from rdkit.Chem import AtomValenceException
from rdkit.Chem import Descriptors
from rdkit.Chem import MolFromSmiles
from rdkit.Chem import rdMolDescriptors
from rdkit.Chem import rdmolfiles
from rdkit.Chem.Draw import IPythonConsole, MolsToGridImage
from rdkit.Chem.Fingerprints import FingerprintMols
from rdkit.Chem.Scaffolds import MurckoScaffold
from rdkit.DataStructs.cDataStructs import TanimotoSimilarity
import rdkit.RDLogger as rdl
from tensorflow import keras
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.callbacks import TensorBoard
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.utils import plot_model
import tensorflow as tf
from scipy.spatial.distance import pdist, squareform
import concurrent.futures
from tqdm import tqdm
import base64
from IPython.display import display
from IPython.display import Image
from multiprocessing import Pool, cpu_count
from PIL import Image
import gzip
import pickle
import psutil
# import pygraphviz as pgv
import time
from torch_geometric.data import DataLoader, Data



In [88]:
from contextlib import contextmanager
import sys
import os
import torch
from ase import Atoms
from ase.data import chemical_symbols
from ase.calculators.morse import MorsePotential
from ase.optimize import QuasiNewton
import numpy as np
import pickle

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from ase import Atoms
from ase.io import read
from torch_geometric.data import Data
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import to_networkx
# from torch_geometric.data import DataLoader, TensorDataset
from torch_geometric.loader import DataLoader
from sklearn.model_selection import train_test_split
from ase.build import molecule
import pandas as pd
from rdkit import Chem
from rdkit.Chem import AllChem
import random
from torch_geometric.nn import global_add_pool, GATConv, CGConv, GCNConv, RGCNConv
from torch_geometric.nn.models.schnet import GaussianSmearing
from sklearn.metrics import roc_auc_score, precision_score, confusion_matrix
import matplotlib.pyplot as plt
import pickle
from collections import Counter
import seaborn as sns
import time
from torch.optim.lr_scheduler import StepLR
import torch
from torch_geometric.data import Data

In [None]:
# Code taken from https://github.com/bmacedo111/MedGAN/

In [44]:
with tf.device('/GPU:0'):
    atom_mapping = {
        "C": 0,
        "N": 1,
        "O": 2,
        "H": 3,
        "F": 4,
        "S": 5,
        "Cl": 6,
    }

    bond_mapping = {
        "SINGLE": 0,
        0: Chem.BondType.SINGLE,
        "DOUBLE": 1,
        1: Chem.BondType.DOUBLE,
        "TRIPLE": 2,
        2: Chem.BondType.TRIPLE,
        "AROMATIC": 3,
        3: Chem.BondType.AROMATIC,
    }

    charge_mapping = {
        -1: 7,
        1: 8,
    }

    NUM_ATOMS = 63
    ATOM_DIM = 11
    BOND_DIM = 4 + 1
    LATENT_DIM = 256

In [45]:
def GraphGenerator(
    dense_units, dropout_rate, latent_dim, adjacency_shape, feature_shape,
):
    z = keras.layers.Input(shape=(LATENT_DIM,))
    # Propagate through one or more densely connected layers
    x = z
    for units in dense_units:
        x = keras.layers.Dense(units, activation="tanh")(x)
        x = keras.layers.Dropout(dropout_rate)(x)

    # Map outputs of previous layer (x) to [continuous] adjacency tensors (x_adjacency)
    x_adjacency = keras.layers.Dense(tf.math.reduce_prod(adjacency_shape))(x)
    x_adjacency = keras.layers.Reshape(adjacency_shape)(x_adjacency)
    # Symmetrify tensors in the last two dimensions
    x_adjacency = (x_adjacency + tf.transpose(x_adjacency, (0, 1, 3, 2))) / 2
    x_adjacency = keras.layers.Softmax(axis=1)(x_adjacency)

    # Map outputs of previous layer (x) to [continuous] feature tensors (x_features)
    x_features = keras.layers.Dense(tf.math.reduce_prod(feature_shape))(x)
    x_features = keras.layers.Reshape(feature_shape)(x_features)
    x_features = keras.layers.Softmax(axis=2)(x_features)

    return keras.Model(inputs=z, outputs=[x_adjacency, x_features], name="Generator")

with tf.device('/GPU:0'):
    generator = GraphGenerator(
        dense_units=[128, 256, 512,1024,2048, 4096],
        dropout_rate=0.50,
        latent_dim=LATENT_DIM,
        adjacency_shape=(BOND_DIM, NUM_ATOMS, NUM_ATOMS),
        feature_shape=(NUM_ATOMS, ATOM_DIM),
    )
generator.summary()


Model: "Generator"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_4 (InputLayer)           [(None, 256)]        0           []                               
                                                                                                  
 dense_11 (Dense)               (None, 128)          32896       ['input_4[0][0]']                
                                                                                                  
 dropout_8 (Dropout)            (None, 128)          0           ['dense_11[0][0]']               
                                                                                                  
 dense_12 (Dense)               (None, 256)          33024       ['dropout_8[0][0]']              
                                                                                          

In [46]:
class RelationalGraphConvLayer(keras.layers.Layer):
    def __init__(
        self,
        units=128,  # 128
        activation="relu",
        use_bias=False,
        kernel_initializer="glorot_uniform",
        bias_initializer="zeros",
        kernel_regularizer=None,
        bias_regularizer=None,
        **kwargs
    ):
        super().__init__(**kwargs)

        self.units = units
        self.activation = keras.activations.get(activation)
        self.use_bias = use_bias
        self.kernel_initializer = keras.initializers.get(kernel_initializer)
        self.bias_initializer = keras.initializers.get(bias_initializer)
        self.kernel_regularizer = keras.regularizers.get(kernel_regularizer)
        self.bias_regularizer = keras.regularizers.get(bias_regularizer)

    def build(self, input_shape):
        bond_dim = input_shape[0][1]
        atom_dim = input_shape[1][2]

        self.kernel = self.add_weight(
            shape=(bond_dim, atom_dim, self.units),
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            trainable=True,
            name="W",
            dtype=tf.float32,
        )

        if self.use_bias:
            self.bias = self.add_weight(
                shape=(bond_dim, 1, self.units),
                initializer=self.bias_initializer,
                regularizer=self.bias_regularizer,
                trainable=True,
                name="b",
                dtype=tf.float32,
            )

        self.built = True

    def call(self, inputs, training=False):
        adjacency, features = inputs
        # Aggregate information from neighbors
        x = tf.matmul(adjacency, features[:, None, :, :])
        # Apply linear transformation
        x = tf.matmul(x, self.kernel)
        if self.use_bias:
            x += self.bias
        # Reduce bond types dim
        x_reduced = tf.reduce_sum(x, axis=1)
        # Apply non-linear transformation
        return self.activation(x_reduced)


def GraphDiscriminator(
    gconv_units, dense_units, dropout_rate, adjacency_shape, feature_shape
):

    adjacency = keras.layers.Input(shape=adjacency_shape)
    features = keras.layers.Input(shape=feature_shape)

    # Propagate through one or more graph convolutional layers
    features_transformed = features
    for units in gconv_units:
        features_transformed = RelationalGraphConvLayer(units)(
            [adjacency, features_transformed]
        )

    # Reduce 2-D representation of molecule to 1-D
    x = keras.layers.GlobalAveragePooling1D()(features_transformed)

    # Propagate through one or more densely connected layers
    for units in dense_units:
        x = keras.layers.Dense(units, activation="relu")(x)
        x = keras.layers.Dropout(dropout_rate)(x)

    # For each molecule, output a single scalar value expressing the "realness" of the inputted molecule
    x_out = keras.layers.Dense(1, dtype="float32")(x)

    return keras.Model(inputs=[adjacency, features], outputs=x_out)

with tf.device('/GPU:0'):
    discriminator = GraphDiscriminator(
        gconv_units= [512, 512, 512, 512], 
        dense_units= [4096, 4096],
        dropout_rate=0.50,
        adjacency_shape=(BOND_DIM, NUM_ATOMS, NUM_ATOMS),
        feature_shape=(NUM_ATOMS, ATOM_DIM),
    )
discriminator.summary()


Model: "model_1"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_5 (InputLayer)           [(None, 5, 63, 63)]  0           []                               
                                                                                                  
 input_6 (InputLayer)           [(None, 63, 11)]     0           []                               
                                                                                                  
 relational_graph_conv_layer_4   (None, 63, 512)     28160       ['input_5[0][0]',                
 (RelationalGraphConvLayer)                                       'input_6[0][0]']                
                                                                                                  
 relational_graph_conv_layer_5   (None, 63, 512)     1310720     ['input_5[0][0]',          

In [47]:
class GraphWGAN(keras.Model):
    def __init__(
        self,
        generator,
        discriminator,
        discriminator_steps=4,
        generator_steps=1,
        gp_weight=10,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.generator = generator
        self.discriminator = discriminator
        self.discriminator_steps = discriminator_steps
        self.generator_steps = generator_steps
        self.gp_weight = gp_weight
        self.latent_dim = self.generator.input_shape[-1]
        self.epoch = 0
        self.num_samples = 1
        self.metric_wgan_gen_loss = keras.metrics.Mean(name="wgan_gen_loss")

    def compile(self, optimizer_generator, optimizer_discriminator, **kwargs):
        super().compile(**kwargs)
        self.optimizer_generator = optimizer_generator
        self.optimizer_discriminator = optimizer_discriminator
        self.metric_generator = keras.metrics.Mean(name="loss_gen")
        self.metric_discriminator = keras.metrics.Mean(name="loss_dis")

    # code to train

    def train_step(self, inputs):
        start_time = time.time()

        if isinstance(inputs[0], tuple):
            inputs = inputs[0]

        graph_real = inputs
        self.batch_size = tf.shape(inputs[0])[0]

        # Train the discriminator for one or more steps
        for _ in range(self.discriminator_steps):
            z = tf.random.normal((self.batch_size, self.latent_dim))
            with tf.GradientTape() as tape:
                graph_generated = self.generator(z, training=True)
                loss = self._loss_discriminator(graph_real, graph_generated)
            grads = tape.gradient(loss, self.discriminator.trainable_weights)
            self.optimizer_discriminator.apply_gradients(zip(grads, self.discriminator.trainable_weights))
            self.metric_discriminator.update_state(loss)

        # Train the generator for one or more steps
        for _ in range(self.generator_steps):
            z = tf.random.normal((self.batch_size, self.latent_dim))

            with tf.GradientTape() as tape:
                graph_generated = self.generator(z, training=True)
                loss_wgan_generator = self._loss_generator(graph_generated)
                self.metric_wgan_gen_loss.update_state(loss_wgan_generator)
            grads = tape.gradient(loss_wgan_generator, self.generator.trainable_weights)
            self.optimizer_generator.apply_gradients(zip(grads, self.generator.trainable_weights))
            self.metric_generator.update_state(loss_wgan_generator)

        # end_time = time.time()
        # time_per_epoch = end_time - start_time

        # cpu_usage = get_cpu_usage()

        logs = {m.name: m.result() for m in self.metrics}
        # logs['time'] = time_per_epoch
        # logs['cpu_usage'] = cpu_usage

        return logs

    def _loss_discriminator(self, graph_real, graph_generated):
        logits_real = self.discriminator(graph_real, training=True)
        logits_generated = self.discriminator(graph_generated, training=True)
        loss = tf.reduce_mean(logits_generated) - tf.reduce_mean(logits_real)
        loss_gp = self._gradient_penalty(graph_real, graph_generated)
        return loss + loss_gp * self.gp_weight

    def _loss_generator(self, graph_generated):
        logits_generated = self.discriminator(graph_generated, training=True)
        return -tf.reduce_mean(logits_generated)

    def _gradient_penalty(self, graph_real, graph_generated):
        # Unpack graphs
        adjacency_real, features_real = graph_real
        adjacency_generated, features_generated = graph_generated

        # Generate interpolated graphs (adjacency_interp and features_interp)
        alpha = tf.random.uniform([self.batch_size])
        alpha = tf.reshape(alpha, (self.batch_size, 1, 1, 1))
        adjacency_interp = (adjacency_real * alpha) + (1 - alpha) * adjacency_generated
        alpha = tf.reshape(alpha, (self.batch_size, 1, 1))
        features_interp = (features_real * alpha) + (1 - alpha) * features_generated

        # Compute the logits of interpolated graphs
        with tf.GradientTape() as tape:
            tape.watch(adjacency_interp)
            tape.watch(features_interp)
            logits = self.discriminator(
                [adjacency_interp, features_interp], training=True
            )

        # Compute the gradients with respect to the interpolated graphs
        grads = tape.gradient(logits, [adjacency_interp, features_interp])
        # Compute the gradient penalty
        grads_adjacency_penalty = (1 - tf.norm(grads[0], axis=1)) ** 2
        grads_features_penalty = (1 - tf.norm(grads[1], axis=2)) ** 2
        return tf.reduce_mean(
            tf.reduce_mean(grads_adjacency_penalty, axis=(-2, -1))
            + tf.reduce_mean(grads_features_penalty, axis=(-1))
        )

    def save_model(self, folder_path="models/MedGAN"):
        if not os.path.exists(folder_path):
            os.makedirs(folder_path)
        self.generator.save(os.path.join(folder_path, "generator"))
        self.discriminator.save(os.path.join(folder_path, "discriminator"))

    def load_model(self, folder_path="models/MedGAN"):
        self.generator = keras.models.load_model(os.path.join(folder_path, "generator"))
        self.discriminator = keras.models.load_model(os.path.join(folder_path, "discriminator"))


In [48]:
with tf.device('/GPU:0'):
    wgan = GraphWGAN(generator, discriminator, discriminator_steps=1)

with tf.device('/GPU:0'):
    wgan.compile(
        optimizer_generator=keras.optimizers.RMSprop(1e-4),
        optimizer_discriminator=keras.optimizers.RMSprop(1e-4)
    )

In [49]:
with tf.device('/GPU:0'):
    bond_tensors = torch.load("data/bond_tensors.pt")
    bond_tensors = bond_tensors.numpy()
    bond_tensors = bond_tensors.reshape(3208, 5, 63, 63)
    print(bond_tensors.shape)
    atomic_tensors = torch.load("data/atomic_number_tensors.pt")
    atomic_tensors = atomic_tensors.numpy()

(3208, 5, 63, 63)


In [50]:
def data_generator(adjacency_tensor, feature_tensor, batch_size):
    dataset_size = len(adjacency_tensor)
    indices = np.arange(dataset_size)
    while True:
        # Shuffle indices at the start of each epoch
        #np.random.shuffle(indices)
        for i in range(0, dataset_size, batch_size):
            batch_indices = indices[i: min(i + batch_size, dataset_size)]
            batch_adjacency_tensor = adjacency_tensor[batch_indices]
            batch_feature_tensor = feature_tensor[batch_indices]
            yield [batch_adjacency_tensor, batch_feature_tensor]

In [51]:
batch_size = 32
with tf.device('/GPU:0'):
    data_gen = data_generator(bond_tensors, atomic_tensors, batch_size)

In [52]:
steps_per_epoch = len(atomic_tensors) // batch_size
if len(atomic_tensors) % batch_size != 0:
    steps_per_epoch += 1

with tf.device('/GPU:0'):
    # Train the model
    wgan.fit(
        data_gen,
        epochs=300,
        steps_per_epoch=steps_per_epoch
    )

2024-04-17 01:31:33.964919: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'Placeholder/_0' with dtype int32
	 [[{{node Placeholder/_0}}]]


Epoch 1/300
Epoch 2/300
Epoch 3/300
Epoch 4/300
Epoch 5/300
Epoch 6/300
Epoch 7/300
Epoch 8/300
Epoch 9/300
Epoch 10/300
Epoch 11/300
Epoch 12/300
Epoch 13/300
Epoch 14/300
Epoch 15/300
Epoch 16/300
Epoch 17/300
Epoch 18/300
Epoch 19/300
Epoch 20/300
Epoch 21/300
Epoch 22/300
Epoch 23/300
Epoch 24/300
Epoch 25/300
Epoch 26/300
Epoch 27/300
Epoch 28/300
Epoch 29/300
Epoch 30/300
Epoch 31/300
Epoch 32/300
Epoch 33/300
Epoch 34/300
Epoch 35/300
Epoch 36/300
Epoch 37/300
Epoch 38/300
Epoch 39/300
Epoch 40/300
Epoch 41/300
Epoch 42/300
Epoch 43/300
Epoch 44/300
Epoch 45/300
Epoch 46/300
Epoch 47/300
Epoch 48/300
Epoch 49/300
Epoch 50/300
Epoch 51/300
Epoch 52/300
Epoch 53/300
Epoch 54/300
Epoch 55/300
Epoch 56/300
Epoch 57/300
Epoch 58/300
Epoch 59/300
Epoch 60/300
Epoch 61/300
Epoch 62/300
Epoch 63/300
Epoch 64/300
Epoch 65/300
Epoch 66/300
Epoch 67/300
Epoch 68/300
Epoch 69/300
Epoch 70/300
Epoch 71/300
Epoch 72/300
Epoch 73/300
Epoch 74/300
Epoch 75/300
Epoch 76/300
Epoch 77/300
Epoch 78

In [60]:
z = tf.random.normal((batch_size, LATENT_DIM))
graph = generator.predict(z)
adjacency = tf.argmax(graph[0], axis=1)
adjacency = tf.one_hot(adjacency, depth=BOND_DIM, axis=1)
adjacency = tf.linalg.set_diag(adjacency, tf.zeros(tf.shape(adjacency)[:-1]))
features = tf.argmax(graph[1], axis=2)
features = tf.one_hot(features, depth=ATOM_DIM, axis=2)
molecules = []
none_counter = 0
# for i in tqdm(range(batch_size), desc="Generating molecules"):
#     try:
#         mol = 
#         molecules.append(mol)
#     except AtomValenceException:
#         molecules.append(None)
#         none_counter += 1  # Increment the counter

def save_model(self, folder_path="models/WGAN/"):
    if not os.path.exists(folder_path):
        os.makedirs(folder_path)
    self.generator.save(os.path.join(folder_path, "generator"))
    self.discriminator.save(os.path.join(folder_path, "discriminator"))




In [61]:
wgan.save_model()



2024-04-17 01:58:12.435403: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'inputs' with dtype float and shape [?,128]
	 [[{{node inputs}}]]
2024-04-17 01:58:12.458157: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'inputs' with dtype float and shape [?,256]
	 [[{{node inputs}}]]
2024-04-17 01:58:12.480055: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'inputs' with dtype float and shape [?,512]
	 [[{{node inputs}}]]
2024-04-17 01:58

INFO:tensorflow:Assets written to: models/MedGAN/generator/assets


2024-04-17 01:58:15.867414: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'inputs' with dtype float and shape [?,4096]
	 [[{{node inputs}}]]
2024-04-17 01:58:15.892828: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'inputs' with dtype float and shape [?,4096]
	 [[{{node inputs}}]]
2024-04-17 01:58:16.457307: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'inputs' with dtype float and shape [?,4096]
	 [[{{node inputs}}]]
2024-04-17 01

INFO:tensorflow:Assets written to: models/MedGAN/discriminator/assets


In [62]:
np.array(features)[:,:,-1].sum()/32

31.375

In [63]:
np.array(adjacency).reshape((32,63,63,5))[:, :, :, -1].sum()

25431.0

In [64]:
32*63*63

127008

In [65]:
25042/32


782.5625

In [71]:
def mol_sample(generator, batch_size):
    z = tf.random.normal((batch_size, LATENT_DIM))
    graph = generator.predict(z)
    adjacency = tf.argmax(graph[0], axis=1)
    adjacency = tf.one_hot(adjacency, depth=BOND_DIM, axis=1)
    adjacency = tf.linalg.set_diag(adjacency, tf.zeros(tf.shape(adjacency)[:-1]))
    features = tf.argmax(graph[1], axis=2)
    features = tf.one_hot(features, depth=ATOM_DIM, axis=2)
    molecules = []
    none_counter = 0
    for i in tqdm(range(batch_size), desc="Generating molecules"):
        try:
            mol = process_graph_data_to_data_objects(adjacency[i].numpy(), features[i].numpy())
            molecules.append(mol)
        except AtomValenceException:
            molecules.append(None)
            none_counter += 1  # Increment the counter
    return molecules, none_counter  # Return the counter

In [81]:
def process_graph_data_to_data_objects(adjacency, atoms, grad = False):
    num_nodes = NUM_ATOMS
    num_bond_types = BOND_DIM - 1  # Exclude 'no bond' type
    num_atom_types = ATOM_DIM - 1  # Exclude 'no atom' type

    adjacency = adjacency.reshape(NUM_ATOMS, NUM_ATOMS, BOND_DIM)
    adjacency = torch.tensor(adjacency)
    atoms = torch.tensor(atoms)
    
    data_objects = []

    edge_index = []
    edge_attr = []
    x = []
    node_index_map = {}
    
    new_idx = 0
    for i in range(num_nodes):
        atom_type_i = torch.argmax(atoms[i])
        if atom_type_i < num_atom_types:  # Ensure atom exists
            node_index_map[i] = new_idx
            # x.append(atoms[i, :num_atom_types])  # Add existing atom types only
            x.append(torch.nn.functional.one_hot(atom_type_i.clone().detach(), num_classes=num_atom_types).float())
            new_idx += 1

    for i in range(num_nodes):
        for j in range(i + 1, num_nodes):  # Only look at one direction to avoid duplicates
            if i in node_index_map and j in node_index_map:
                bond_type = torch.argmax(adjacency[i, j])
                # if adjacency[i,j][num_bond_types] > 0.2 or adjacency[i,j][bond_type] < 0.3:
                #     # print('overwrite')
                #     continue
                if bond_type < num_bond_types:  # Bond exists
                    edge_index.append([node_index_map[i], node_index_map[j]])
                    bond_one_hot = torch.zeros(num_bond_types)
                    bond_one_hot[bond_type] = 1
                    edge_attr.append(bond_type)
                    # print(bond_type)
                    # edge_attr.append(bond_type/num_bond_types)

    if edge_index:
        edge_index_tensor = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
        edge_attr_tensor = torch.stack(edge_attr, dim=0)

    else:
        # No edges, create dummy tensors to handle this case
        edge_index_tensor = torch.empty((2, 0), dtype=torch.long)
        edge_attr_tensor = torch.empty((0, num_bond_types), dtype=torch.float)
        # return None

    if x:
        node_features_tensor = torch.stack(x, dim=0)
    else:
        # return None
        node_features_tensor = torch.empty((0, num_atom_types), dtype=torch.float)

    # Create a PyTorch Geometric Data object
    # if grad:
    #     node_features_tensor.requires_grad_(True)
    #     edge_attr.requires_grad_(True)
    graph_data = Data(x=node_features_tensor, edge_index=edge_index_tensor, edge_attr=edge_attr_tensor)

    return graph_data

In [82]:
mols, counds = mol_sample(wgan.generator, batch_size)



Generating molecules: 100%|██████████| 32/32 [00:00<00:00, 34.05it/s]


In [85]:
print(mols[0])

Data(x=[29, 10], edge_index=[2, 381], edge_attr=[381])


In [86]:
atomic_mapping = {0: 6, 1: 7, 2:8, 3:9, 4:14, 5:15, 6:16, 7:17, 8:35, 9:53}

In [89]:
def get_molecule(datapoint):
    #atomic_numbers = datapoint.x.argmax(dim=1).tolist()
    atomic_numbers_unmapped = datapoint.x.argmax(dim=1).tolist()
    atomic_numbers = []
    for at in atomic_numbers_unmapped:
        atomic_numbers.append(atomic_mapping[at])
    # positions = datapoint.positions.numpy()
    positions = np.random.rand(len(atomic_numbers), 3) * 10  # Random positions within a 10x10x10 Å box

    molecule = Atoms(numbers=atomic_numbers, positions=positions)
    return molecule

@contextmanager
def suppress_stdout():
    with open(os.devnull, "w") as devnull:
        old_stdout = sys.stdout
        sys.stdout = devnull
        try:
            yield
        finally:
            sys.stdout = old_stdout

def get_dft(datapoint, De=0.242, re=0.74, alpha=1.5):
    molecule = get_molecule(datapoint)
    print('Molecule:', molecule)
    dft_calculator = MorsePotential(De=De, r0=re, alpha=alpha)
    molecule.set_calculator(dft_calculator)
    # energy = molecule.get_potential_energy()
    with suppress_stdout():
        opt = QuasiNewton(molecule)
        opt.run(fmax=0.02)
    optimized_energy = molecule.get_potential_energy()
    return optimized_energy

energy = get_dft(mols[0])

print(f'DFT Energy: {energy} eV')

Molecule: Atoms(symbols='OBrOSSiSBrI2NISiSNClP2NSO2NOFCOS2N', pbc=False)
DFT Energy: -5.999999409890069 eV


In [90]:
class GraphEncoder(nn.Module):
    def __init__(self, input_dim, num_relations):
        super(GraphEncoder, self).__init__()
        self.rgcnconv1 = RGCNConv(input_dim, 512, num_relations=num_relations)
        self.rgcnconv2 = RGCNConv(512, 256, num_relations=num_relations)
        self.rgcnconv3 = RGCNConv(256, 128, num_relations=num_relations)

    def forward(self, x, edge_index, edge_attr):
        x = self.rgcnconv1(x, edge_index, edge_attr)
        x = nn.ReLU()(x)
        x = self.rgcnconv2(x, edge_index, edge_attr)
        x = nn.ReLU()(x)
        x = self.rgcnconv3(x, edge_index, edge_attr)
        x = nn.ReLU()(x)
        return x

class BioClassifier(nn.Module):
    def __init__(self, input_dim, num_heads):
        super(BioClassifier, self).__init__()

        self.encoder = GraphEncoder(input_dim, num_heads)
        self.fc1 = nn.Linear(128, 64)
        self.bn1 = nn.BatchNorm1d(64)
        self.fc2 = nn.Linear(64, 1)
        self.dropout = nn.Dropout(0.4)

    def forward(self, data):
        x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
        x = x.float()

        x = self.encoder(x, edge_index, edge_attr)
        x = global_add_pool(x, data.batch)
        x = self.dropout(x)
        x = self.fc1(x)
        x = nn.ReLU()(x)
        x = self.bn1(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = torch.sigmoid(x)
        return x.squeeze(dim=1)

In [91]:
model_bio_eval = BioClassifier(10, 4)

state_dict = torch.load('models/rcgn_model_7428.pt')
model_bio_eval.load_state_dict(state_dict)

model_bio_eval.eval()
def get_model_prediction(datapoint):
    pred = model_bio_eval(datapoint)
    return pred.item()

In [94]:
dft = get_dft(mols[0])
y = get_model_prediction(mols[0]) * 100

print(f'DFT (remember negative is good): {dft}')
print(f'Chance of Biodgradability: {y:.2f}%')

Molecule: Atoms(symbols='OBrOSSiSBrI2NISiSNClP2NSO2NOFCOS2N', pbc=False)
DFT (remember negative is good): -14.99999977448502
Chance of Biodgradability: 0.00%
