<a href="https://colab.research.google.com/github/124andrewM/molecule_gnn_xyz/blob/main/test_xyz_workflow.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Molecule XYZ

### Purpose: A basic workflow using the Keras API to test data input-output.

Note: you will need to transfer the sample data which is comma seperated into your Google Drive. Adjust the file paths and folders names as required below.

## Data Process
- Data is first parsed by taking the atoms and a NumPy array of the coordinates.
- Atoms are one-hot encoded for numerical input.
- Each atom gets a one-hot encoding and it's xyz coordinates.
- Building very basic adjacency matrices using a simple threshold value for whether atoms are connected.
- Padding the molecule graphs with 0's.
- Using a binary mask (0 = padding, 1 = real atom) for tracking the padded nodes.
- Batching using tf (TensorFlow Dataset).

## Model
- The model uses GCN layers followed by masked average pooling (pooling but ignoring the padded atoms), then a dense layer to give us a binary classification using a sigmoid activation function.

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [25]:
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, Model, Input
from tensorflow.keras.layers import Dense
from tensorflow.keras.losses import BinaryCrossentropy
from tensorflow.keras.optimizers import Adam
from sklearn.preprocessing import OneHotEncoder
from tensorflow.keras.layers import Dropout

In [3]:
# An easy way to toggle print statements
VERBOSE = True

In [4]:
stable_data_path = '/content/drive/MyDrive/sample-data-xyz/stable'
unstable_data_path = '/content/drive/MyDrive/sample-data-xyz/unstable'

Converting the chars into a numerical form.

In [5]:
qm9_atoms = ['H', 'C', 'N', 'O', 'F']
encoder = OneHotEncoder(categories=[qm9_atoms], sparse_output=False, handle_unknown='ignore')
encoder.fit(np.array(qm9_atoms).reshape(-1, 1))

### Function: parse_comma_files
- Just extracting the contents of the files, comma works better as a safety net.

In [6]:
def parse_comma_files(file_path):
  atoms, coords = [], []
  with open(file_path, 'r') as file:
    for line in file:
      single_atom = line.strip().split(',')
      if len(single_atom) == 4:
        atoms.append(single_atom[0])
        coords.append([float(val) for val in single_atom[1:]])
  return atoms, np.array(coords)

### Function: dense_adj
- 1.0 is a completely random value to select which atoms are connected. Function definetly needs changing.

In [7]:
def dense_adj(coords):
    # linalg (linear algebra from NumPy)
    # Computes pairwise distances between atoms using Euclidean norm
    dists = np.linalg.norm(coords[:, None] - coords[None, :], axis=-1)
    adj = ((dists < 1.0) & (dists > 0)).astype(np.float32)
    return adj

### Function: pad_array
- Padding to the largest molecule size

In [8]:
# Padding because Spektrals disjoint data mode was a nightmare.
def pad_array(arr, new_shape):
    p_val = 0.0
    padded = np.full(new_shape, p_val, dtype=arr.dtype)
    padded[:arr.shape[0], :arr.shape[1]] = arr
    return padded

### Function: prepare_data

In [9]:
def prepare_data(stable_dir, unstable_dir):
    node_list, adj_list, mask_list, label_list = [], [], [], []
    all_files = [] # Temp list to collect all files and labels

    # Iterate over both stable and unstable folders
    for label, folder in [(1, stable_dir), (0, unstable_dir)]:
        for file_name in os.listdir(folder):
            file_path = os.path.join(folder, file_name)
            all_files.append((file_path, label))

    max_nodes = 0 # This will tell us how much padding to add
    temp_graphs = [] # Purely temp storage for the second loop

    # Go through each file-label pair and build a temp dataset
    for file_path, label in all_files:
        # Parse atom labels and 3D coords
        atoms, coords = parse_comma_files(file_path)
        # Save as a single data set
        temp_graphs.append((atoms, coords, label))
        # Adjust max_nodes if needed
        max_nodes = max(max_nodes, len(atoms))

    # Process each molecule
    for atoms, coords, label in temp_graphs:
        shaped_atoms = np.array(atoms).reshape(-1, 1) # Reshape the atoms
        encoded_atoms = encoder.transform(shaped_atoms).astype(np.float32) # Convert chars into one-hot encoded vector
        e_atoms = np.concatenate([encoded_atoms, coords], axis=-1) # Shape the en as a single vector
        num_real = e_atoms.shape[0] # Store the number of real nodes before padding

        node_list.append(pad_array(e_atoms, (max_nodes, e_atoms.shape[1])))
        adj_list.append(pad_array(dense_adj(coords), (max_nodes, max_nodes)))
        label_list.append(np.array([label], dtype=np.float32))

        mask = np.zeros((max_nodes,), dtype=np.float32) # Create a binary mask vector
        mask[:num_real] = 1.0 # 1 is a real atom, otherwise it's padding
        mask_list.append(mask)

    return (
        np.array(node_list),
        np.array(adj_list),
        np.array(mask_list),
        np.array(label_list),
        max_nodes
    )

### Function: create_tf_ds
- Creates a tf dataset and batches them

In [10]:
def create_tf_ds(molecules, adj_arr, mask_arr, labels, batch_size=32, shuffle=True):
    ds = tf.data.Dataset.from_tensor_slices(((molecules, adj_arr, mask_arr), labels))
    if shuffle:
        ds = ds.shuffle(buffer_size=len(molecules))
    ds = ds.batch(batch_size)
    return ds

### Function: masked_avg_pooling
- Custom pooling function to account for padding.
- Essentially, compressing all node features into one vector for passing into the dense layer.

In [11]:
def masked_avg_pooling(node_embeddings, node_mask):
    masked_embeddings = node_embeddings * tf.expand_dims(node_mask, axis=-1) # 0 out padded nodes
    summed_features = tf.reduce_sum(masked_embeddings, axis=1) # sum over real nodes
    count = tf.reduce_sum(node_mask, axis=1) # count the real nodes
    return summed_features / tf.maximum(count, 1) # prevent division by zero

### Class GraphConvLayer
- Basic template class for graph data.
- Call takes two inputs: `node_features` and `adjacency`

In [12]:
class GraphConvLayer(layers.Layer):
    def __init__(self, output_size, activation=None):
        super().__init__()
        self.output_size = output_size
        self.activation = tf.keras.activations.get(activation)

    def build(self, input_shape):
        num_features = input_shape[0][-1]

        self.kernel = self.add_weight(shape=(num_features, self.output_size),
                                      initializer='glorot_uniform',
                                      name='kernel')
        self.bias = self.add_weight(shape=(self.output_size,),
                                    initializer='zeros',
                                    name='bias')
        super().build(input_shape)

    def call(self, inputs):
        node_features, adjacency = inputs
        messages = tf.matmul(adjacency, node_features)
        output = tf.matmul(messages, self.kernel) + self.bias
        if self.activation is not None:
            output = self.activation(output)
        return output

### Prepare Data

In [13]:
# Load and process our data
molecules, adj_arr, mask_arr, labels, max_nodes = prepare_data(stable_data_path, unstable_data_path)
if VERBOSE:
  print("Loaded", molecules.shape[0], "graphs, each padded to", max_nodes, "nodes")

Loaded 4000 graphs, each padded to 20 nodes


### Create Datasets

In [14]:
num_molecules = molecules.shape[0]
indices = np.arange(num_molecules)
np.random.shuffle(indices)
split = int(0.8 * num_molecules)
train_idx, test_idx = indices[:split], indices[split:]

m_train, adj_train, mask_train, label_train = molecules[train_idx], adj_arr[train_idx], mask_arr[train_idx], labels[train_idx]
m_test, adj_test, mask_test, label_test = molecules[test_idx], adj_arr[test_idx], mask_arr[test_idx], labels[test_idx]

train_ds = create_tf_ds(m_train, adj_train, mask_train, label_train, batch_size=32, shuffle=True)
test_ds = create_tf_ds(m_test, adj_test, mask_test, label_test, batch_size=32, shuffle=False)

### Build the model

In [33]:
input_mols = Input(shape=(max_nodes, molecules.shape[-1]), name="node_features")
input_adj = Input(shape=(max_nodes, max_nodes), name="adjacency")
input_mask = Input(shape=(max_nodes,), name="mask")

layer_1 = GraphConvLayer(32, activation='relu')([input_mols, input_adj])
layer_2 = GraphConvLayer(32, activation='relu')([layer_1, input_adj])
layer_3 = GraphConvLayer(32, activation='relu')([layer_2, input_adj])
pool = layers.Lambda(lambda args: masked_avg_pooling(*args))([layer_3, input_mask])
output = Dense(1, activation='sigmoid')(pool)
model = Model(inputs=[input_mols, input_adj, input_mask], outputs=output)

### Compile the model

In [34]:
model.compile(optimizer=Adam(1e-3),
              loss=BinaryCrossentropy(from_logits=False),
              metrics=['accuracy'])
model.summary()

### Train the model

In [35]:
history = model.fit(train_ds, epochs=12)

Epoch 1/12
[1m100/100[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 6ms/step - accuracy: 0.6474 - loss: 0.6518
Epoch 2/12
[1m100/100[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 4ms/step - accuracy: 0.7635 - loss: 0.5398
Epoch 3/12
[1m100/100[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 4ms/step - accuracy: 0.8303 - loss: 0.4167
Epoch 4/12
[1m100/100[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 4ms/step - accuracy: 0.9064 - loss: 0.2904
Epoch 5/12
[1m100/100[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 4ms/step - accuracy: 0.9101 - loss: 0.2583
Epoch 6/12
[1m100/100[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 4ms/step - accuracy: 0.8973 - loss: 0.2731
Epoch 7/12
[1m100/100[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 4ms/step - accuracy: 0.9080 - loss: 0.2538
Epoch 8/12
[1m100/100[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 4ms/step - accuracy: 0.9062 - loss: 0.2556
Epoch 9/12
[1m100/100[0m [32m━━━━━━━━

### Evaluate the model

In [36]:
predictions = model.predict(test_ds)
preds = predictions.flatten()
true = label_test.flatten()

for i in range(10):
    print(f"Sample Num: {i+1}, Real Label: {int(true[i])}, Predicted Probability: {preds[i]:.4f}")

[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step
Sample Num: 1, Real Label: 0, Predicted Probability: 0.0033
Sample Num: 2, Real Label: 0, Predicted Probability: 0.0000
Sample Num: 3, Real Label: 1, Predicted Probability: 0.8829
Sample Num: 4, Real Label: 0, Predicted Probability: 0.0000
Sample Num: 5, Real Label: 0, Predicted Probability: 0.0023
Sample Num: 6, Real Label: 0, Predicted Probability: 0.0000
Sample Num: 7, Real Label: 1, Predicted Probability: 0.7913
Sample Num: 8, Real Label: 1, Predicted Probability: 0.8062
Sample Num: 9, Real Label: 0, Predicted Probability: 0.7913
Sample Num: 10, Real Label: 0, Predicted Probability: 0.8202


In [37]:
loss, accuracy = model.evaluate(test_ds)
print("Test Loss: ", loss)
print("Test Accuracy: ", accuracy)

[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - accuracy: 0.8918 - loss: 0.2858
Test Loss:  0.2896278500556946
Test Accuracy:  0.8899999856948853
