In [None]:
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
import numpy as np

In [None]:
import tensorflow as tf
import keras
import pysmiles
import numpy as np

In [None]:
# Read the training table from parquet with pyarrow and convert it to a pandas DataFrame.
de_data_train = pq.read_table("../data/de_train.parquet").to_pandas()
# Show the frame as the cell output.
de_data_train

## Divide into train and test

In [None]:
# Cell types whose every (cell_type, sm) pair goes into the training set.
train_only_cell_types = [
    "T cells CD4+",
    "T cells CD8+",
    "T regulatory cells",
]
# Cell types whose (cell_type, sm) pairs are divided between training and test.
train_and_test_cell_types = [
    "B cells",
    "Myeloid cells",
    "NK cells",
]

In [None]:
# Collect, per cell type, the distinct small-molecule (sm) names that appear with it,
# then expose the result as a plain dict: cell_type -> array of sm names.
unique_sm_per_cell_type = de_data_train.groupby("cell_type")["sm_name"].unique()
sm_names_by_cell_type = unique_sm_per_cell_type.to_dict()
# Small molecules recorded for B cells, one of the cell types with a reduced set of pairs.
train_and_test_sm = sm_names_by_cell_type["B cells"]

In [None]:
# For the cell types that are split, the first half of each cell type's small
# molecules (in the order stored in sm_names_by_cell_type) is used for training
# and the remainder for test. With an odd count the test half gets the extra one.
b_cell_sm       = sm_names_by_cell_type["B cells"]
myeloid_cell_sm = sm_names_by_cell_type["Myeloid cells"]
nk_cell_sm      = sm_names_by_cell_type["NK cells"]

num_b_sm       = len(b_cell_sm)
num_myeloid_sm = len(myeloid_cell_sm)
num_nk_sm      = len(nk_cell_sm)

b_cell_train, b_cell_test = b_cell_sm[:num_b_sm//2], b_cell_sm[num_b_sm//2:]
myeloid_cell_train, myeloid_cell_test = (
    myeloid_cell_sm[:num_myeloid_sm//2],
    myeloid_cell_sm[num_myeloid_sm//2:],
)
nk_cell_train, nk_cell_test = nk_cell_sm[:num_nk_sm//2], nk_cell_sm[num_nk_sm//2:]

In [None]:
# Start the training combinations with every small molecule of the train-only cell types.
training_combinations = {
    cell_type: sm_names_by_cell_type[cell_type]
    for cell_type in train_only_cell_types
}

In [None]:
# Add the training halves of the cell types that are split between train and test.
training_combinations.update({
    "B cells": b_cell_train,
    "Myeloid cells": myeloid_cell_train,
    "NK cells": nk_cell_train,
})

In [None]:
# Test combinations: the held-out halves of the split cell types.
testing_combinations = {
    "B cells": b_cell_test,
    "Myeloid cells": myeloid_cell_test,
    "NK cells": nk_cell_test,
}

In [None]:
# Show the cell_type -> training small molecules mapping as the cell output.
training_combinations

In [None]:
# Show the cell_type -> test small molecules mapping as the cell output.
testing_combinations

In [None]:
# Flatten each {cell_type: [sm, ...]} mapping into a set of "<cell_type>, <sm>" keys.
training_pairs = {
    cell_type + ", " + sm
    for cell_type, sm_names in training_combinations.items()
    for sm in sm_names
}

testing_pairs = {
    cell_type + ", " + sm
    for cell_type, sm_names in testing_combinations.items()
    for sm in sm_names
}

# Show a sample of the generated training keys.
list(training_pairs)[:10]

In [None]:
# Give every row the same "<cell_type>, <sm_name>" key format used by
# training_pairs / testing_pairs so rows can be looked up by pair.
pair_labels = de_data_train["cell_type"] + ", " + de_data_train["sm_name"]
de_data_train["cell_type_sm_pair"] = pair_labels

In [None]:
# Preview the values used as regression targets for one pair: take the first
# matching row and drop the first five columns and the last one by position.
# NOTE(review): presumably the first five columns are metadata and the last is
# the cell_type_sm_pair key added above -- confirm against the table's schema.
treg_fk866_rows = de_data_train[de_data_train["cell_type_sm_pair"] == "T regulatory cells, FK 866"]
treg_fk866_rows.iloc[0].iloc[5:-1]

# Convert to np arrays

In [None]:
# Element symbol -> zero-based index (consecutive, in periodic-table order).
# The index is the one-hot position of the element inside a node embedding.
# Keys must be bare symbols: a key with stray whitespace can never match the
# element name stored on a graph node.
symbol_to_int = {
    'H': 0, 'He': 1, 'Li': 2, 'Be': 3, 'B': 4, 'C': 5, 'N': 6, 'O': 7, 'F': 8, 'Ne': 9,
    'Na': 10, 'Mg': 11, 'Al': 12, 'Si': 13, 'P': 14, 'S': 15, 'Cl': 16, 'Ar': 17, 'K': 18, 'Ca': 19,
    'Sc': 20, 'Ti': 21, 'V': 22, 'Cr': 23, 'Mn': 24, 'Fe': 25, 'Co': 26, 'Ni': 27, 'Cu': 28, 'Zn': 29,
    'Ga': 30, 'Ge': 31, 'As': 32, 'Se': 33, 'Br': 34, 'Kr': 35, 'Rb': 36, 'Sr': 37, 'Y': 38, 'Zr': 39,
    'Nb': 40, 'Mo': 41, 'Tc': 42, 'Ru': 43, 'Rh': 44, 'Pd': 45, 'Ag': 46, 'Cd': 47, 'In': 48, 'Sn': 49,
    'Sb': 50, 'Te': 51, 'I': 52, 'Xe': 53, 'Cs': 54, 'Ba': 55, 'La': 56, 'Ce': 57, 'Pr': 58, 'Nd': 59,
    'Pm': 60, 'Sm': 61, 'Eu': 62, 'Gd': 63, 'Tb': 64, 'Dy': 65, 'Ho': 66, 'Er': 67, 'Tm': 68, 'Yb': 69,
    'Lu': 70, 'Hf': 71, 'Ta': 72, 'W': 73, 'Re': 74, 'Os': 75, 'Ir': 76, 'Pt': 77, 'Au': 78, 'Hg': 79,
    'Tl': 80, 'Pb': 81, 'Bi': 82, 'Po': 83, 'At': 84, 'Rn': 85, 'Fr': 86, 'Ra': 87, 'Ac': 88, 'Th': 89,
    'Pa': 90, 'U': 91, 'Np': 92, 'Pu': 93, 'Am': 94, 'Cm': 95, 'Bk': 96, 'Cf': 97, 'Es': 98, 'Fm': 99,
    'Md': 100, 'No': 101, 'Lr': 102, 'Rf': 103, 'Db': 104, 'Sg': 105, 'Bh': 106, 'Hs': 107, 'Mt': 108, 'Ds': 109,
    'Rg': 110, 'Cn': 111, 'Nh': 112, 'Fl': 113, 'Mc': 114, 'Lv': 115, 'Ts': 116, 'Og': 117,
}
# Fixed (padded) sizes of the arrays produced for every molecule.
MAX_NODES = 150      # rows reserved for atoms
MAX_EDGES = 200      # rows reserved for bonds
EMBEDDING_DIM = 120  # width of each node / edge / universal embedding

def convertFromNetworkX(graph, maxNodes, maxEdges, embeddingDim):
    """Convert a molecule graph into fixed-size, zero-padded numpy arrays.

    Parameters
    ----------
    graph : graph object
        Must provide ``nodes(data="element")`` yielding ``(node_index, symbol)``
        pairs, an ``edges`` iterable of ``(start, end)`` pairs and
        ``get_edge_data(start, end)`` returning a mapping with an ``"order"``
        entry. Node identifiers are used directly as row indices.
    maxNodes : int
        Number of node rows to allocate.
    maxEdges : int
        Number of edge rows to allocate.
    embeddingDim : int
        Width of every embedding vector.

    Returns
    -------
    tuple of numpy.ndarray
        ``(nodeEmbeddings, edgeEmbeddings, universalEmbedding, adjacencyMatrix,
        connectedEdges)`` with shapes ``(maxNodes, embeddingDim)``,
        ``(maxEdges, embeddingDim)``, ``(embeddingDim,)``,
        ``(maxNodes, maxNodes)`` and ``(maxNodes, maxEdges)``.
        The universal embedding is returned as all zeros.

    Raises
    ------
    KeyError
        If a node's element symbol is not in ``symbol_to_int``.
    IndexError
        If the graph needs more nodes, edges or embedding slots than allocated.
    """
    nodeEmbeddings = np.zeros((maxNodes, embeddingDim))
    edgeEmbeddings = np.zeros((maxEdges, embeddingDim))
    universalEmbedding = np.zeros((embeddingDim))
    adjacencyMatrix = np.zeros((maxNodes, maxNodes))
    connectedEdges = np.zeros((maxNodes, maxEdges))

    # Populate node embeddings: one-hot encode each atom's element.
    for nodeNum, symbol in graph.nodes(data="element"):
        if nodeNum >= maxNodes:
            raise IndexError(
                f"node index {nodeNum} does not fit in maxNodes={maxNodes}"
            )
        symbolInt = symbol_to_int[symbol]
        if symbolInt >= embeddingDim:
            raise IndexError(
                f"element {symbol!r} (index {symbolInt}) does not fit in embeddingDim={embeddingDim}"
            )
        nodeEmbeddings[nodeNum][symbolInt] = 1.0

    # Populate edge embeddings, adjacency matrix and node/edge incidence matrix.
    for edgeNum, (start, end) in enumerate(graph.edges):
        if edgeNum >= maxEdges:
            raise IndexError(
                f"graph has more than maxEdges={maxEdges} edges"
            )
        edgeOrder = graph.get_edge_data(start, end)["order"]

        # The order may be fractional (e.g. 1.5), so it is doubled to obtain an
        # integer one-hot slot. Kinda hacky: an explicit mapping would be cleaner.
        orderSlot = int(edgeOrder*2)
        if orderSlot >= embeddingDim:
            raise IndexError(
                f"edge order {edgeOrder} (slot {orderSlot}) does not fit in embeddingDim={embeddingDim}"
            )
        edgeEmbeddings[edgeNum][orderSlot] = 1.0

        # Undirected graph: mark the connection in both directions.
        adjacencyMatrix[start][end] = 1.0
        adjacencyMatrix[end][start] = 1.0

        # Record that this edge touches both of its end nodes.
        connectedEdges[start][edgeNum] = 1.0
        connectedEdges[end][edgeNum] = 1.0

    return nodeEmbeddings, edgeEmbeddings, universalEmbedding, adjacencyMatrix, connectedEdges

In [None]:
import time

In [None]:
# Per-sample graph arrays and regression targets for the training pairs.
vertTrain = []
edgeTrain = []
uniTrain = []
adjMatTrain = []
connMatTrain = []

targetsTrain = []

# Iterate in sorted order: the iteration order of a set of strings can change
# between kernel sessions, which would silently reorder the samples.
for pairName in sorted(training_pairs):
    # First row of the table recorded for this (cell_type, sm) pair.
    pairData = de_data_train[de_data_train["cell_type_sm_pair"] == pairName].iloc[0]
    moleculeSMILES = pairData["SMILES"]

    # Parse the SMILES string into a graph and turn it into padded arrays.
    graph = pysmiles.read_smiles(moleculeSMILES, explicit_hydrogen=True)
    ver, edj, uni, am, conn = convertFromNetworkX(graph, MAX_NODES, MAX_EDGES, EMBEDDING_DIM)

    vertTrain.append(ver)
    edgeTrain.append(edj)
    uniTrain.append(uni)
    adjMatTrain.append(am)
    connMatTrain.append(conn)

    # Targets: everything except the first five columns and the last one, by position.
    # NOTE(review): presumably the leading columns are metadata and the last is
    # the cell_type_sm_pair key -- confirm against the table's schema.
    targetsTrain.append(np.array(pairData.iloc[5:-1]))

In [None]:
# Per-sample graph arrays and regression targets for the held-out test pairs.
vertTest = []
edgeTest = []
uniTest = []
adjMatTest = []
connMatTest = []

targetsTest = []

# Iterate in sorted order: the iteration order of a set of strings can change
# between kernel sessions, which would silently reorder the samples.
for pairName in sorted(testing_pairs):
    # First row of the table recorded for this (cell_type, sm) pair.
    pairData = de_data_train[de_data_train["cell_type_sm_pair"] == pairName].iloc[0]
    moleculeSMILES = pairData["SMILES"]

    # Parse the SMILES string into a graph and turn it into padded arrays.
    graph = pysmiles.read_smiles(moleculeSMILES, explicit_hydrogen=True)
    ver, edj, uni, am, conn = convertFromNetworkX(graph, MAX_NODES, MAX_EDGES, EMBEDDING_DIM)

    vertTest.append(ver)
    edgeTest.append(edj)
    uniTest.append(uni)
    adjMatTest.append(am)
    connMatTest.append(conn)

    # Targets: everything except the first five columns and the last one, by position.
    # NOTE(review): presumably the leading columns are metadata and the last is
    # the cell_type_sm_pair key -- confirm against the table's schema.
    targetsTest.append(np.array(pairData.iloc[5:-1]))

In [None]:
# Stack the per-sample arrays into batched arrays. The keys are the names of
# the Input layers the model is built with, so Keras can route each array.
trainData = dict(
    vertices=np.array(vertTrain),
    edges=np.array(edgeTrain),
    universal=np.array(uniTrain),
    adj=np.array(adjMatTrain),
    connectedEdges=np.array(connMatTrain),
)

testData = dict(
    vertices=np.array(vertTest),
    edges=np.array(edgeTest),
    universal=np.array(uniTest),
    adj=np.array(adjMatTest),
    connectedEdges=np.array(connMatTest),
)

In [None]:
# Stack the per-sample target vectors into float32 arrays for training.
targetsTrain = np.array(targetsTrain, dtype=np.float32)
targetsTest = np.array(targetsTest, dtype=np.float32)

In [None]:
from tensorflow.keras.layers import Dense
from tensorflow.keras import Input, Model
import tensorflow as tf

In [None]:
class GraphUpdate(keras.layers.Layer):
    """Transform vertex, edge and universal embeddings with separate Dense layers.

    Each of the three embeddings gets its own Dense layer; the adjacency and
    connected-edges matrices are passed through unchanged.

    Args:
        v_out_dim: Output width of the vertex Dense layer.
        e_out_dim: Output width of the edge Dense layer.
        u_out_dim: Output width of the universal Dense layer.
        activation: Activation shared by all three Dense layers.
    """

    def __init__(self,
                 v_out_dim,
                 e_out_dim,
                 u_out_dim,
                 activation="relu"):
        super().__init__()
        self.v_update = Dense(units=v_out_dim, activation=activation, name="V_Update")
        self.e_update = Dense(units=e_out_dim, activation=activation, name="E_Update")
        self.u_update = Dense(units=u_out_dim, activation=activation, name="U_Update")

    def call(self, inputs):
        """Return ``[v, e, u, adj, conEd]`` with v, e and u passed through their Dense layers."""
        vertexEmb, edgeEmb, universalEmb, adjacency, edgeIncidence = inputs
        return [
            self.v_update(vertexEmb),
            self.e_update(edgeEmb),
            self.u_update(universalEmb),
            adjacency,
            edgeIncidence,
        ]

# Add the embedding of each connected edge to each vertex.
class PoolEdgesToVertices(keras.layers.Layer):
    """Add to every vertex embedding the sum of the embeddings of its connected edges."""

    def __init__(self):
        super().__init__()

    def call(self, inputs):
        """Return ``[v, e, u, adj, conEd]`` where only the vertex embeddings are changed."""
        vertexEmb, edgeEmb, universalEmb, adjacency, edgeIncidence = inputs

        # Multiplying the connected-edges matrix by the edge embeddings sums,
        # for each vertex row, the embeddings of the edges flagged for it.
        updatedVertices = vertexEmb + tf.matmul(edgeIncidence, edgeEmb)

        return [updatedVertices, edgeEmb, universalEmb, adjacency, edgeIncidence]

# Add the embedding of each connected vertex to each edge.
class PoolVerticesToEdges(keras.layers.Layer):
    """Add to every edge embedding the sum of the embeddings of its connected vertices."""

    def __init__(self):
        super().__init__()

    def call(self, inputs):
        """Return ``[v, e, u, adj, conEd]`` where only the edge embeddings are changed."""
        vertexEmb, edgeEmb, universalEmb, adjacency, edgeIncidence = inputs

        # The connected-edges matrix is transposed inside matmul so that each
        # edge row sums the embeddings of the vertices flagged for it.
        updatedEdges = edgeEmb + tf.matmul(edgeIncidence, vertexEmb, transpose_a=True)

        return [vertexEmb, updatedEdges, universalEmb, adjacency, edgeIncidence]

# Pool all vertices to universal.
class PoolVerticesToUniversal(keras.layers.Layer):
    """Replace the universal embedding with the sum of all vertex embeddings.

    The incoming universal embedding is discarded; vertices, edges and both
    matrices are passed through unchanged.
    """

    def __init__(self):
        super().__init__()

    def call(self, inputs):
        """Return ``[v, e, u, adj, conEd]`` where ``u`` is the vertex sum."""
        v_in, e_in, u_in, adj, conEd = inputs

        # Sum over the node axis (second to last) so the result keeps one value
        # per embedding feature. Reducing the last axis instead would collapse
        # the features and leave one scalar per padded node slot.
        # Assumes v_in is (..., nodes, features), as the `vertices` Input declares.
        u_out = tf.reduce_sum(v_in, axis=-2)

        return [v_in, e_in, u_out, adj, conEd]

In [None]:
# Symbolic inputs; the names match the keys of the trainData / testData dicts.
vertices = Input(shape=(MAX_NODES, EMBEDDING_DIM), name="vertices")
edges = Input(shape=(MAX_EDGES, EMBEDDING_DIM), name="edges")
universal = Input(shape=(EMBEDDING_DIM,), name="universal")
adj = Input(shape=(MAX_NODES, MAX_NODES), name="adj")
conEd = Input(shape=(MAX_NODES, MAX_EDGES), name="connectedEdges")

# Block 1: exchange information between edges and vertices, then transform.
x = PoolEdgesToVertices()([vertices, edges, universal, adj, conEd])
x = PoolVerticesToEdges()(x)
# Feed the pooled tensors (x) forward; passing the raw inputs here would
# discard the two pooling layers above.
x = GraphUpdate(32, 32, 32)(x)

# Block 2: second round of pooling and transformation.
x = PoolEdgesToVertices()(x)
x = PoolVerticesToEdges()(x)
x = GraphUpdate(16, 16, 16)(x)

# Readout: pool the vertices into the universal embedding and transform it.
x = PoolVerticesToUniversal()(x)
x = GraphUpdate(8, 8, 8)(x)

# Prediction head on the universal embedding (index 2 of the layer output list).
u = x[2]
u = Dense(4, activation="relu")(u)
u = Dense(1, activation="sigmoid")(u)

# NOTE(review): a later cell rebuilds `model` with a different output head.
model = Model(inputs=[vertices, edges, universal, adj, conEd], outputs=u)

In [None]:
# Render a diagram of the model's layer graph as the cell output.
tf.keras.utils.plot_model(model)

# Define And Train Model

In [None]:
from tensorflow.keras.layers import Dense
from tensorflow.keras import Input, Model
import tensorflow as tf

In [None]:
# NOTE(review): GraphUpdate is also defined in an earlier cell of this
# notebook; this later definition replaces it once the cell runs.
class GraphUpdate(keras.layers.Layer):
    """Transform vertex, edge and universal embeddings with separate Dense layers.

    Each of the three embeddings gets its own Dense layer; the adjacency and
    connected-edges matrices are passed through unchanged.

    Args:
        v_out_dim: Output width of the vertex Dense layer.
        e_out_dim: Output width of the edge Dense layer.
        u_out_dim: Output width of the universal Dense layer.
        activation: Activation shared by all three Dense layers.
    """

    def __init__(self,
                 v_out_dim,
                 e_out_dim,
                 u_out_dim,
                 activation="relu"):
        super().__init__()
        self.v_update = Dense(units=v_out_dim, activation=activation, name="V_Update")
        self.e_update = Dense(units=e_out_dim, activation=activation, name="E_Update")
        self.u_update = Dense(units=u_out_dim, activation=activation, name="U_Update")

    def call(self, inputs):
        """Return ``[v, e, u, adj, conEd]`` with v, e and u passed through their Dense layers."""
        vertexEmb, edgeEmb, universalEmb, adjacency, edgeIncidence = inputs
        return [
            self.v_update(vertexEmb),
            self.e_update(edgeEmb),
            self.u_update(universalEmb),
            adjacency,
            edgeIncidence,
        ]

# Add the embedding of each connected edge to each vertex.
# NOTE(review): also defined in an earlier cell; this definition replaces it.
class PoolEdgesToVertices(keras.layers.Layer):
    """Add to every vertex embedding the sum of the embeddings of its connected edges."""

    def __init__(self):
        super().__init__()

    def call(self, inputs):
        """Return ``[v, e, u, adj, conEd]`` where only the vertex embeddings are changed."""
        vertexEmb, edgeEmb, universalEmb, adjacency, edgeIncidence = inputs

        # Multiplying the connected-edges matrix by the edge embeddings sums,
        # for each vertex row, the embeddings of the edges flagged for it.
        updatedVertices = vertexEmb + tf.matmul(edgeIncidence, edgeEmb)

        return [updatedVertices, edgeEmb, universalEmb, adjacency, edgeIncidence]

# Add the embedding of each connected vertex to each edge.
# NOTE(review): also defined in an earlier cell; this definition replaces it.
class PoolVerticesToEdges(keras.layers.Layer):
    """Add to every edge embedding the sum of the embeddings of its connected vertices."""

    def __init__(self):
        super().__init__()

    def call(self, inputs):
        """Return ``[v, e, u, adj, conEd]`` where only the edge embeddings are changed."""
        vertexEmb, edgeEmb, universalEmb, adjacency, edgeIncidence = inputs

        # The connected-edges matrix is transposed inside matmul so that each
        # edge row sums the embeddings of the vertices flagged for it.
        updatedEdges = edgeEmb + tf.matmul(edgeIncidence, vertexEmb, transpose_a=True)

        return [vertexEmb, updatedEdges, universalEmb, adjacency, edgeIncidence]

# Pool all vertices to universal.
# NOTE(review): also defined in an earlier cell; this definition replaces it.
class PoolVerticesToUniversal(keras.layers.Layer):
    """Replace the universal embedding with the sum of all vertex embeddings.

    The incoming universal embedding is discarded; vertices, edges and both
    matrices are passed through unchanged.
    """

    def __init__(self):
        super().__init__()

    def call(self, inputs):
        """Return ``[v, e, u, adj, conEd]`` where ``u`` is the vertex sum."""
        v_in, e_in, u_in, adj, conEd = inputs

        # Sum over the node axis (second to last) so the result keeps one value
        # per embedding feature. Reducing the last axis instead would collapse
        # the features and leave one scalar per padded node slot.
        # Assumes v_in is (..., nodes, features), as the `vertices` Input declares.
        u_out = tf.reduce_sum(v_in, axis=-2)

        return [v_in, e_in, u_out, adj, conEd]

In [None]:
# Symbolic inputs; the names match the keys of the trainData / testData dicts.
vertices = Input(shape=(MAX_NODES, EMBEDDING_DIM), name="vertices")
edges = Input(shape=(MAX_EDGES, EMBEDDING_DIM), name="edges")
universal = Input(shape=(EMBEDDING_DIM,), name="universal")
adj = Input(shape=(MAX_NODES, MAX_NODES), name="adj")
conEd = Input(shape=(MAX_NODES, MAX_EDGES), name="connectedEdges")

# Block 1: exchange information between edges and vertices, then transform.
x = PoolEdgesToVertices()([vertices, edges, universal, adj, conEd])
x = PoolVerticesToEdges()(x)
# Feed the pooled tensors (x) forward; passing the raw inputs here would
# discard the two pooling layers above.
x = GraphUpdate(32, 32, 32)(x)

# Block 2: second round of pooling and transformation.
x = PoolEdgesToVertices()(x)
x = PoolVerticesToEdges()(x)
x = GraphUpdate(32, 32, 32)(x)

# Readout: pool the vertices into the universal embedding and transform it.
x = PoolVerticesToUniversal()(x)
x = GraphUpdate(32, 32, 32)(x)

# Regression head on the universal embedding (index 2 of the layer output list).
u = x[2]
u = Dense(32, activation="relu")(u)
# Linear output layer with one unit per regression target.
# NOTE(review): 18211 must equal the width of targetsTrain / targetsTest -- confirm.
u = Dense(18211)(u)

model = Model(inputs=[vertices, edges, universal, adj, conEd], outputs=u)

In [None]:
# Render a diagram of the model's layer graph as the cell output.
tf.keras.utils.plot_model(model)

In [None]:
# Configure training: mean squared error loss optimised with RMSprop at a 1e-3 learning rate.
model.compile(
    loss=tf.keras.losses.MeanSquaredError(),
    optimizer=tf.keras.optimizers.RMSprop(learning_rate=1e-3),
)

In [None]:
# Train for 50 epochs in batches of 32, evaluating on the held-out pairs after every epoch.
model.fit(
    trainData,
    targetsTrain,
    validation_data=(testData, targetsTest),
    epochs=50,
    batch_size=32,
)