# This notebook builds molecular crystal graphs from the pickled descriptor data produced by `cif2cluster_data.ipynb`.



In [2]:
import gzip
import pickle
import urllib
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
import tensorflow as tf
import pandas as pd
import rdkit, rdkit.Chem, rdkit.Chem.rdDepictor, rdkit.Chem.Draw

In [3]:
# Load the acetam dataset into a pandas object.
# NOTE: unpickling can execute arbitrary code -- only load trusted files.
f = 'acetam_GNN_DATA.pkl'
# the context manager closes the file handle once the data has been read
with open(f, 'rb') as file:
    acetam_trajs = pd.read_pickle(file)

In [4]:
# Load the glutam dataset into a pandas object.
# NOTE: unpickling can execute arbitrary code -- only load trusted files.
f = 'glutam_GNN_DATA.pkl'
# the context manager closes the file handle once the data has been read
with open(f, 'rb') as file:
    glutam_trajs = pd.read_pickle(file)

In [5]:
# Load the meoh dataset into a pandas object.
# NOTE: unpickling can execute arbitrary code -- only load trusted files.
f = 'meoh_GNN_DATA.pkl'
# the context manager closes the file handle once the data has been read
with open(f, 'rb') as file:
    meoh_trajs = pd.read_pickle(file)

In [6]:
# Load the csd dataset into a pandas object.
# NOTE: unpickling can execute arbitrary code -- only load trusted files.
f = 'csd_GNN_DATA.pkl'
# the context manager closes the file handle once the data has been read
with open(f, 'rb') as file:
    csd_trajs = pd.read_pickle(file)

In [6]:
#f = 'GNN_DATA.pkl'
#file = open(f,'rb')
#trajs = pickle.load(file)
#trajs = pd.read_pickle(file)
# label_str = list(set([k.split("-")[0] for k in trajs]))

In [7]:
# # now build dataset
def generator():
    """Yield one (coordinates, label, descriptors) triple per row of ``acetam_trajs``.

    The values are read from the 'Coordinates', 'Label' and 'Descriptors'
    columns of each row, in that output order.
    """
    for _, record in acetam_trajs.iterrows():
        yield record['Coordinates'], record['Label'], record['Descriptors']

# Wrap the Python generator as a tf.data pipeline. Each element is
# (coordinates, label, node descriptors); the leading dimension of the
# coordinate and descriptor tensors is left unspecified.
data = tf.data.Dataset.from_generator(
    generator,
    output_signature=(
        tf.TensorSpec(shape=(None, 3), dtype=tf.float32),
        tf.TensorSpec(shape=(), dtype=tf.int32),
        tf.TensorSpec(shape=(None, 24), dtype=tf.float32),
    ),
)
# Shuffle once and keep that order on every pass: the splits below are cut
# with take/skip, so a per-iteration reshuffle would mix them together.
data = data.shuffle(buffer_size=1000, reshuffle_each_iteration=False)

# The shuffling above is really important because this dataset is in order of labels!

# Disjoint splits: elements 0-99 validate, 100-199 test, the rest train.
# The training split must skip BOTH held-out blocks; skipping only 100 would
# put every test example into the training set as well.
val_data = data.take(100)
test_data = data.skip(100).take(100)
train_data = data.skip(200)

In [None]:
"""
Visualize 2D projection of 3D data
"""
fig, axs = plt.subplots(1, 5, figsize=(12, 2))
axs = axs.flatten()

# Plot one example per axis. The number of examples is tied to the number of
# axes so the loop can never index past the end of ``axs``.
for i, (x, y, z) in enumerate(val_data.take(len(axs))):
    # project onto the plane spanned by coordinate columns 1 and 2
    axs[i].plot(x[:, 1], x[:, 2], ".")
    axs[i].set_title(str(i))
    axs[i].axis("off")


In [None]:
# tf.function traces the Python body into a TensorFlow graph so that repeated
# calls run efficiently; reduce_retracing asks TensorFlow to re-trace less
# often when the inputs change.
@tf.function(
    reduce_retracing=True,
)
def get_edges(positions, NN, sorted=True):
    """Find the nearest neighbors of every point.

    Args:
        positions: tensor whose first axis indexes points and whose last
            axis holds the coordinates.
        NN: requested number of neighbors per point; clipped to the number
            of points.
        sorted: forwarded to ``tf.math.top_k``.

    Returns:
        ``(distances, indices)``, each with one row per point and one column
        per neighbor. Pairs closer than 5e-4 (i.e. a point with itself) are
        assigned a distance of 1000 so that they rank last.
    """
    n_points = tf.shape(input=positions)[0]
    # a point cannot have more neighbors than there are points
    k = tf.minimum(NN, n_points)
    # broadcasting yields all pairwise displacements:
    # entry [i, j] is positions[j] - positions[i]
    displacement = positions[tf.newaxis, :, :] - positions[:, tf.newaxis, :]
    pair_dist = tf.norm(tensor=displacement, axis=2)
    # 1.0 where the pair is a genuine neighbor, 0.0 for self-interactions
    keep = tf.cast(pair_dist >= 5e-4, dtype=pair_dist.dtype)
    # push the masked (self) entries really far away
    penalized = pair_dist * keep + (1 - keep) * 1000
    # top_k picks the largest values, so negate to get the smallest distances
    nearest = tf.math.top_k(-penalized, k=k, sorted=sorted)
    return -nearest.values, nearest.indices



In [None]:
from matplotlib import collections

fig, axs = plt.subplots(1, 5, figsize=(12, 2))
axs = axs.flatten()
# Draw one example per axis. The number of examples is tied to the number of
# axes so the loop can never index past the end of ``axs``.
for i, (x, y, z) in enumerate(data.take(len(axs))):
    e_f, e_i = get_edges(x, 6)

    # make things easier for plotting
    e_i = e_i.numpy()
    x = x.numpy()
    y = y.numpy()

    # make lines from selected origin points to their neighbors
    lines = []
    colors = []
    # only every 23rd point is used as an origin
    # NOTE(review): presumably 23 is the number of points per molecule -- confirm
    for j in range(0, x.shape[0], 23):
        # lines are [(xstart, ystart), (xend, yend)]
        lines.extend([[(x[j, 0], x[j, 1]), (x[k, 0], x[k, 1])] for k in e_i[j]])
        colors.extend([f"C{j}"] * len(e_i[j]))
    lc = collections.LineCollection(lines, linewidths=2, colors=colors)
    axs[i].add_collection(lc)
    axs[i].plot(x[:, 0], x[:, 1], ".")
    axs[i].axis("off")
plt.show()


In [None]:
# neighbors requested per node when building graphs
MAX_DEGREE = 12
# number of radial basis functions, i.e. features per edge
EDGE_FEATURES = 8
# largest radial basis center
MAX_R = 15

# width parameter of the Gaussians
gamma = 1
# EDGE_FEATURES evenly spaced centers on [0, MAX_R]
mu = np.linspace(0, MAX_R, num=EDGE_FEATURES)


def rbf(r):
    """Expand distances ``r`` into Gaussian radial basis features.

    A trailing axis of length ``EDGE_FEATURES`` is appended; entry ``k`` is
    ``exp(-gamma * (r - mu[k]) ** 2)``.
    """
    offsets = tf.expand_dims(r, axis=-1) - mu
    return tf.exp(-gamma * offsets**2)


def make_graph(x, y, n):
    """Convert one dataset element into model inputs and a label.

    Args:
        x: point coordinates, passed to ``get_edges``.
        y: label.
        n: node features, passed through unchanged.

    Returns:
        ``((n, edge_features, neighbor_indices), label)`` where the edge
        features are the RBF expansion of the neighbor distances and the
        label gains a leading axis of length 1.
    """
    neighbor_dist, neighbor_idx = get_edges(x, MAX_DEGREE)
    inputs = (n, rbf(neighbor_dist), neighbor_idx)
    label = y[None]
    return inputs, label


# apply the same graph construction to each of the three splits
graph_train_data, graph_val_data, graph_test_data = (
    split.map(make_graph) for split in (train_data, val_data, test_data)
)


In [None]:
# Peek at the first graph of the training split. Indices are 0-based, so the
# "first" node/edge is index 0 (index 1 would be the second one and would
# fail on a graph with a single node).
for (n, e, nn), y in graph_train_data.take(1):
    print("first node:", n[0].numpy())
    print("first node, first edge features:", e[0, 0].numpy())
    print("first node, all neighbors", nn[0].numpy())
    print("label", y.numpy())


In [None]:
def ssp(x):
    """Shifted softplus activation: ``log(0.5 * exp(x) + 0.5)``.

    Computed as ``softplus(x) - log(2)``, which is the same function
    (``log(1 + e^x) - log 2``) but does not overflow to ``inf`` for large
    ``x`` the way evaluating ``exp(x)`` directly does.
    """
    import math

    return tf.math.softplus(x) - math.log(2.0)


def make_h1(units):
    """Build a model made of a single linear ``Dense(units)`` layer."""
    linear = tf.keras.layers.Dense(units)
    return tf.keras.Sequential([linear])


def make_h2(units):
    """Build a two-layer MLP with ``ssp`` activation on both layers."""
    hidden_layers = [
        tf.keras.layers.Dense(units, activation=ssp) for _ in range(2)
    ]
    return tf.keras.Sequential(hidden_layers)


def make_h3(units):
    """Build a two-layer MLP: an ``ssp``-activated layer, then a linear one."""
    activated = tf.keras.layers.Dense(units, activation=ssp)
    linear = tf.keras.layers.Dense(units)
    return tf.keras.Sequential([activated, linear])


In [None]:
class SchNetModel(tf.keras.Model):
    """Implementation of SchNet Model.

    Column 0 of the node features is treated as a type index and embedded;
    the remaining columns are concatenated onto the embedding so that every
    node ends up with ``channels`` features.

    Args:
        gnn_blocks: number of message-passing blocks.
        channels: width of the node representation inside the network.
        label_dim: size of the output vector.
        num_node_types: vocabulary size of the type embedding (default 89).
        node_feature_dim: number of columns in the input node features,
            including the type-index column (default 24).
        **kwargs: forwarded to ``tf.keras.Model``.

    Raises:
        ValueError: if ``channels`` is too small to hold the pass-through
            node features plus at least one embedding dimension.
    """

    def __init__(
        self,
        gnn_blocks,
        channels,
        label_dim,
        num_node_types=89,
        node_feature_dim=24,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.gnn_blocks = gnn_blocks

        # node_feature_dim - 1 columns bypass the embedding, so the embedding
        # supplies the remaining width needed to reach ``channels``
        embedding_dim = channels - (node_feature_dim - 1)
        if embedding_dim < 1:
            raise ValueError(
                f"channels={channels} must be at least node_feature_dim="
                f"{node_feature_dim} to leave room for the type embedding"
            )

        # build our layers
        self.embedding = tf.keras.layers.Embedding(num_node_types, embedding_dim)
        self.h1s = [make_h1(channels) for _ in range(self.gnn_blocks)]
        self.h2s = [make_h2(channels) for _ in range(self.gnn_blocks)]
        self.h3s = [make_h3(channels) for _ in range(self.gnn_blocks)]
        self.readout_l1 = tf.keras.layers.Dense(channels // 2, activation=ssp)
        self.readout_l2 = tf.keras.layers.Dense(label_dim)

    def call(self, inputs):
        """Run the network on one graph.

        Args:
            inputs: tuple ``(nodes, edge_features, edge_i)`` of node
                features, per-edge features and neighbor indices.

        Returns:
            The readout averaged over all nodes (axis 0).
        """
        nodes, edge_features, edge_i = inputs
        # embed the type index (column 0) and append the remaining features
        nodes = tf.concat([self.embedding(nodes[:, 0]), nodes[:, 1:]], 1)
        for i in range(self.gnn_blocks):
            # get the node features per edge
            v_sk = tf.gather(nodes, edge_i)
            # weight each neighbor's features by its transformed edge features
            e_k = self.h1s[i](v_sk) * self.h2s[i](edge_features)
            # sum the messages over each node's neighbors
            e_i = tf.reduce_sum(e_k, axis=1)
            # residual update of the node features
            nodes += self.h3s[i](e_i)
        # readout now
        nodes = self.readout_l1(nodes)
        nodes = self.readout_l2(nodes)
        return tf.reduce_mean(nodes, axis=0)


In [None]:
# a small model: 3 message-passing blocks, 32 channels, 2 outputs
small_schnet = SchNetModel(gnn_blocks=3, channels=32, label_dim=2)

In [None]:
# forward pass of the model over every test graph, printing each raw output
for x, y in graph_test_data:
    yhat = small_schnet(x)
    print(yhat.numpy())

In [None]:
# display the node-feature tensor of the last graph visited by the loop above
x[0]

In [None]:
# Show the embedding of the type-index column (column 0) for the last test
# graph. ``x`` is the (nodes, edge_features, edge_i) tuple left over from the
# prediction loop; the embedding layer lives on the model instance.
small_schnet.embedding(x[0][:, 0])

In [None]:
# display the training dataset object (shows its element spec in a notebook)
train_data