In [1]:
# Necessary imports
%load_ext autoreload
%autoreload 2\

import networkx as nx
import numpy as np
import tensorflow as tf

from keras import backend as K
from keras.layers import Dense
from keras.models import Model, Sequential


from TCGAMultiOmics.multiomics import MultiOmicsData
from moge.network.heterogeneous_network import HeterogeneousNetwork


Using TensorFlow backend.


#  Import network from file

In [2]:
import pickle

# WRITE
# with open('moge/data/lncRNA_miRNA_mRNA/miRNA-mRNA_network_test_05_val_01_seed_0.pickle', 'wb') as file:
#     pickle.dump(network, file)

# READ
with open('moge/data/lncRNA_miRNA_mRNA/lncRNA-miRNA-mRNA_network_new.pickle', 'rb') as file:
# with open('moge/data/lncRNA_miRNA_mRNA/miRNA-mRNA_network_biogrid.pickle', 'rb') as file:
    network = pickle.load(file)
#     network.remove_extra_nodes()
#     network.node_list = network.all_nodes
#     node_list = network.node_list

In [3]:
for u,v,d in network.G.edges(data=True):
    if d["type"] == 'u_n':
        d['weight']+=1e-8

In [4]:
# READ edgelists
# with open('moge/data/lncRNA_miRNA_mRNA/miRNA-mRNA_network_test_05_val_01_seed_0_test_edges.pickle', 'rb') as file:
#     test_edges_dict = pickle.load(file)
    
# with open('moge/data/lncRNA_miRNA_mRNA/miRNA-mRNA_network_test_05_val_01_seed_0_val_edges.pickle', 'rb') as file:
#     val_edges_dict = pickle.load(file)

# Load training data

In [5]:
# X, y = network.multi_omics_data.load_data(modalities=["MIR", "GE"])

In [6]:
# network.multi_omics_data.external_data_path = "/home/jonny/PycharmProjects/Bioinformatics_ExternalData/"

In [7]:
# X["MIR"].shape

# Training Source Target Graph Embedding

In [4]:
from keras.layers import Input, Conv1D, Lambda, Dot, Dense, Flatten, MaxPooling1D, Lambda, Convolution1D, Layer
from keras.layers import LSTM, Dense, TimeDistributed, Dropout, Bidirectional, AveragePooling1D
from keras.models import Model, Sequential
from keras.regularizers import l2
from keras import backend as K
import keras

from keras.optimizers import SGD, Adam, RMSprop
from keras.losses import binary_crossentropy
from keras.metrics import kullback_leibler_divergence, binary_crossentropy, binary_accuracy

from keras.utils import to_categorical

from keras.callbacks import TensorBoard

def W_init(shape,name=None):
    """Initialize weights as in paper"""
    values = np.random.normal(loc=0,scale=1e-2,size=shape)
    return K.variable(values,name=name)
#//TODO: figure out how to initialize layer biases in keras.
def b_init(shape,name=None):
    """Initialize bias as in paper"""
    values=np.random.normal(loc=0.5,scale=1e-2,size=shape)
    return K.variable(values,name=name)


In [5]:
K.clear_session()
tf.reset_default_graph()
# sess.close()
sess = tf.InteractiveSession()

In [6]:
# node_features_size = X["MIR"].shape[0]
input_shape = (None, 4)
_d = 128
batch_size = 1

In [7]:
# Inputs
E_ij = Input(batch_shape=(batch_size, 1), name="E_ij")
input_seq_i = Input(batch_shape=(batch_size, None, 4), name="input_seq_i")
input_seq_j = Input(batch_shape=(batch_size, None, 4), name="input_seq_j")
is_directed = Input(batch_shape=(batch_size, 1), dtype=tf.bool, name="is_directed")


In [8]:
def create_base_network(input_shape=(None, 4)):
    """ Base network to be shared (eq. to feature extraction).
    """
    input = Input(shape=input_shape)
#     x = Flatten()(input)
    x = Convolution1D(filters=512, kernel_size=5, input_shape=input_shape, activation='relu')(input)
    print("conv1d", x)
    x = MaxPooling1D(pool_size=13, strides=13)(x) # Similar to DanQ Model
    print("avg pool", x)
    x = Bidirectional(LSTM(256, input_shape=input_shape, return_sequences=False, return_state=False))(x)
    print("brnn", x)
    x = Dense(1024, activation='relu')(x)
    x = Dropout(0.1)(x)
    x = Dense(512, activation='relu')(x)
    x = Dropout(0.1)(x)
    x = Dense(128, activation='relu')(x) # Embedding space
    return Model(input, x)

def euclidean_distance(vects):
    x, y = vects
    return K.sqrt(K.maximum(K.sum(K.square(x - y), axis=1, keepdims=True), K.epsilon()))

def st_euclidean_distance(vects):
    emb_i, emb_j, is_directed = vects
    sum_directed = 2 * K.sum(K.square(emb_i[:, 0:int(_d/2)] - emb_j[:, int(_d/2):_d]), axis=1, keepdims=True)
    sum_undirected = K.sum(K.square(emb_i - emb_j), axis=1, keepdims=True)
    sum_switch = K.switch(is_directed, sum_directed, sum_undirected)
    return K.sqrt(K.maximum(sum_switch, K.epsilon()))

def st_embedding_probability(vects):
    emb_i, emb_j, is_directed = vects
    dot_directed = 2 * Dot(axes=1)([emb_i[:, 0:int(_d/2)], emb_j[:, int(_d/2):_d]])
    dot_undirected = Dot(axes=1)([emb_i, emb_j])
    return K.switch(is_directed, K.sigmoid(dot_directed), K.sigmoid(dot_undirected))

def st_embedding_probability_w_dense(vects):
    emb_i, emb_j, is_directed = vects
    directed = Dense(1, activation='sigmoid')(Dot(axes=1)([emb_i[:, 0:int(_d/2)], emb_j[:, int(_d/2):_d]]))
    undirected = Dense(1, activation='sigmoid')(Dot(axes=1)([emb_i, emb_j]))
    return K.switch(is_directed, directed, undirected)

def st_l1_distance(vects):
    emb_i, emb_j, is_directed = vects
    L1_layer = Lambda(lambda tensors: K.abs(tensors[0] - tensors[1]))
    directed_distance = L1_layer([emb_i[:, 0:int(_d/2)], emb_j[:, int(_d/2):_d]])
    undirected_distance = L1_layer([emb_i, emb_j])
    
    return Dense(1, activation='sigmoid')(
        K.switch(is_directed, directed_distance, undirected_distance))


# Loss function
def contrastive_loss(y_true, y_pred):
    '''Contrastive loss from Hadsell-et-al.'06
    http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
    '''
    margin = 0.8
    return K.mean(y_true * K.square(y_pred) +
                  (1 - y_true) * K.square(K.maximum(margin - y_pred, 0)))

def regularized_cross_entropy(y_true, y_pred):
    return K.mean(y_true * K.log(y_pred) + (1 - y_true) * K.log(1 - y_pred))


def kl_loss(y_true, y_pred):
    return -K.mean(y_true * K.log(y_pred))


# Metrics
def accuracy(y_true, y_pred):
    '''Compute classification accuracy with a fixed threshold on distances.
    '''
    return K.mean(K.equal(K.cast(y_true > 0.5, y_true.dtype), K.cast(y_pred > 0.5, y_true.dtype)))

In [9]:
#build create_base_network to use in each siamese 'leg'
lstm_network = create_base_network(input_shape=input_shape)

print("lstm_network", lstm_network)

# encode each of the two inputs into a vector with the convnet
encoded_i = lstm_network(input_seq_i)
encoded_j = lstm_network(input_seq_j)
print("encoded_i", encoded_i, "\nencoded_j", encoded_j)

distance = Lambda(st_euclidean_distance)([encoded_i, encoded_j, is_directed])
print("distance", distance)

siamese_net = Model(inputs=[input_seq_i, input_seq_j, is_directed], outputs=distance)

conv1d Tensor("conv1d_1/Relu:0", shape=(?, ?, 512), dtype=float32)
avg pool Tensor("max_pooling1d_1/Squeeze:0", shape=(?, ?, 512), dtype=float32)
brnn Tensor("bidirectional_1/concat:0", shape=(?, 512), dtype=float32)
lstm_network <keras.engine.training.Model object at 0x7f09dc86dbe0>
encoded_i Tensor("model_1/dense_3/Relu:0", shape=(1, 128), dtype=float32) 
encoded_j Tensor("model_1_1/dense_3/Relu:0", shape=(1, 128), dtype=float32)
distance Tensor("lambda_1/Sqrt:0", shape=(1, 1), dtype=float32)


In [10]:
#//TODO: get layerwise learning rates and momentum annealing scheme described in paperworking
siamese_net.compile(loss=contrastive_loss, 
                    optimizer=RMSprop(lr=0.01),
                    metrics=[accuracy])

siamese_net.count_params()


2701440

In [45]:
# Tensorboard
# tbCallBack = keras.callbacks.TensorBoard(log_dir='./Graph', histogram_freq=0, write_graph=True, write_images=True)

# Data Generator

In [26]:
from moge.network.data_generator import DataGenerator

generator = DataGenerator(network.node_list, network=network, 
                          batch_size=1, dim=(None, 4), shuffle=True)

X, y = generator.__getitem__(0)
print("X:", [(k, v.shape) for k, v in X.items()], "\ny:", y.shape)

Ed_count 448908 Eu_count 2550360 En_count 2307016
X: [('input_seq_i', (1, 66, 4)), ('input_seq_j', (1, 126, 4)), ('is_directed', (1,))] 
y: (1,)


# Train

In [None]:
siamese_net.fit_generator(generator, use_multiprocessing=True, workers=9)

Epoch 1/1
   6067/5306284 [..............................] - ETA: 290:49:45 - loss: 3.5414 - accuracy: 0.2186

In [None]:
siamese_net.save("siamese_net_1.h5")

In [None]:
del siamese_net

# Inference

In [None]:
siamese_net.predict(X)