In [1]:
# Necessary imports
%load_ext autoreload
%autoreload 2

import networkx as nx
import numpy as np
import pandas as pd

from moge.network.heterogeneous_network import HeterogeneousNetwork

from sklearn.model_selection import ParameterGrid, ParameterSampler
from moge.visualization.plot_data import matrix_heatmap
pd.set_option('display.max_columns', 30)

In [2]:
import pickle

def _unpickle(path):
    """Return the object deserialized from the pickle file at `path`."""
    # NOTE(review): pickle.load executes arbitrary code from the file; only use trusted data files.
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# READ: training network, held-out test network (used for validation) and the LUAD multi-omics data.
network = _unpickle('moge/data/LMN_future_recall/TRAIN/Interactions_Affinity/LMN_mirtarbase_biogrid_starbase_lncrna2target_lncrinter.train.pickle')
network_val = _unpickle('moge/data/LMN_future_recall/TEST/Interactions_Affinity/LMN_mirtarbase_biogrid_starbase_lncrna2target_lncrinter.test.pickle')
luad_data = _unpickle('moge/data/luad_data_multi_U-T.pickle')

# Attach the multi-omics data to the validation network and run its
# process_genes_info() step. The same two steps for `network` remain disabled:
#     network.multi_omics_data = luad_data
#     network.process_genes_info()
network_val.multi_omics_data = luad_data
network_val.process_genes_info()

Genes info columns: ['Disease association', 'locus_type', 'Transcript sequence', 'Chromosome', 'GO Terms', 'Family']
Number of nodes without seq removed: -10093
Total nodes (filtered): 22648


# Set parameter space

In [4]:
from scipy.stats.distributions import uniform
# Search space for the random hyperparameter search: every key maps to the list
# of candidate values that may be drawn for that setting.
parameters = dict(
    # General embedding / training / sampling settings.
    d=[64, 128, 256],
    lr=[0.001, 0.0005],
    margin=[0.2],
    weighted=[True, False],
    compression_func=["sqrt", "log", "sqrt3", "linear"],
    negative_sampling_ratio=[1.0, 2.0, 5.0, 20.0, 40.0],
    directed_proba=[0.4, 0.5, 0.75, 1.0],
    directed_distance=["euclidean", "cosine", "dot_sigmoid"],
    undirected_distance=["euclidean", "cosine", "dot_sigmoid"],
    max_length=[5050],
    truncating=["random", "post"],

    # Layer-level settings (conv / pooling / LSTM / dense, going by the key names).
    # NOTE(review): a None candidate presumably disables that layer - confirm in the model code.
    conv1_kernel_size=[6, 12, 18, 26],
    conv1_batch_norm=[True, False],
    max1_pool_size=[3, 4, 6, 9],
    conv2_kernel_size=[None, 2, 6, 12],
    conv2_batch_norm=[True, False],
    max2_pool_size=[2, 4, 6],
    lstm_unit_size=[100, 160, 320],
    dense1_unit_size=[None, 256, 512, 1024],
    dense2_unit_size=[None, 256, 512],
    source_target_dense_layers=[True, False],
    embedding_normalization=[True, False],
)

# Train Model

In [5]:
from moge.embedding.siamese_graph_embedding import SiameseGraphEmbedding
from moge.embedding.siamese_triplet_online_embedding import SiameseOnlineTripletGraphEmbedding

# Embedding model instance reused (via set_params) by every trial of the search.
# Earlier configuration, kept for reference:
#   SiameseGraphEmbedding(subsample=False, batch_size=460, epochs=1, verbose=False)
siamese = SiameseOnlineTripletGraphEmbedding(
    batch_size=300,
    epochs=5,
    verbose=False,
)

Using TensorFlow backend.


directed_margin 0.2 , undirected_margin 0.1


In [6]:
# Random search over `parameters`: train one model per sampled configuration and
# track the configuration with the lowest final-epoch validation loss.
best_score = float("inf")
# Initialised up front so the cells below never raise NameError (or silently
# reuse stale kernel state from an earlier search) when no trial completes.
best_grid = None
best_history = None
X_params = []  # parameter dict of every completed trial
y = []         # matching validation loss, same order as X_params

for g in ParameterSampler(parameters, n_iter=300):
    print(len(X_params), ":", g)
    siamese.set_params(**g)

    try:
        siamese.learn_embedding(network, network_val=network_val, multi_gpu=False,
                                n_steps=250, validation_steps=None, histogram_freq=0,
                                early_stopping=True,
                                tensorboard=False, rebuild_model=True,
                                seed=999)
    except KeyboardInterrupt:
        # Manual interrupt: stop searching but keep the trials recorded so far.
        break
    except Exception as e:
        # A failing configuration is reported and skipped, not recorded.
        print("Failed with exception:", type(e).__name__, e)
        continue

    # Score of a trial = validation loss of the last epoch that was run.
    current_score = siamese.hist.history['val_loss'][-1]
    X_params.append(g)
    y.append(current_score)

    if current_score < best_score:
        best_score = current_score
        best_grid = g
        best_history = siamese.hist.history
        print("Score:", best_score, "\n")
    


0 : {'weighted': False, 'undirected_distance': 'euclidean', 'truncating': 'post', 'source_target_dense_layers': True, 'negative_sampling_ratio': 2.0, 'max_length': 5050, 'max2_pool_size': 6, 'max1_pool_size': 9, 'margin': 0.2, 'lstm_unit_size': 160, 'lr': 0.001, 'embedding_normalization': False, 'directed_proba': 1.0, 'directed_distance': 'cosine', 'dense2_unit_size': None, 'dense1_unit_size': None, 'd': 64, 'conv2_kernel_size': 2, 'conv2_batch_norm': False, 'conv1_kernel_size': 18, 'conv1_batch_norm': False, 'compression_func': 'sqrt3'}
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
Score: 11.274370880126954 

1 : {'weighted': False, 'undirected_distance': 'cosine', 'truncating': 'random', 'source_target_dense_layers': True, 'negative_sampling_ratio': 40.0, 'max_length': 5050, 'max2_pool_size': 6, 'max1_pool_size': 9, 'margin': 0.2, 'lstm_unit_size': 160, 'lr': 0.0005, 'embedding_normalization': True, 'directed_proba': 0.5, 'directed_distance': 'euclidean', 'dense2_unit_size': 512,

In [7]:
# Lowest final-epoch validation loss among the completed trials.
print(best_score)

4.906635723114014


In [8]:
# Hyperparameter combination that produced best_score.
print(best_grid)

{'weighted': True, 'undirected_distance': 'euclidean', 'truncating': 'post', 'source_target_dense_layers': False, 'negative_sampling_ratio': 20.0, 'max_length': 5050, 'max2_pool_size': 4, 'max1_pool_size': 3, 'margin': 0.2, 'lstm_unit_size': 320, 'lr': 0.0005, 'embedding_normalization': False, 'directed_proba': 0.4, 'directed_distance': 'cosine', 'dense2_unit_size': None, 'dense1_unit_size': 512, 'd': 128, 'conv2_kernel_size': 12, 'conv2_batch_norm': False, 'conv1_kernel_size': 6, 'conv1_batch_norm': False, 'compression_func': 'sqrt'}


In [9]:
# Number of completed trials; both lists are appended to together, so the counts should match.
len(X_params), len(y)

(8, 8)

In [10]:
# One row per completed trial: the sampled hyperparameters plus the resulting
# validation loss in a trailing "loss" column.
X = pd.DataFrame(X_params).assign(loss=y)

# Show the trials ordered from lowest to highest loss (X itself stays unsorted).
X.sort_values(by="loss")

Unnamed: 0,compression_func,conv1_batch_norm,conv1_kernel_size,conv2_batch_norm,conv2_kernel_size,d,dense1_unit_size,dense2_unit_size,directed_distance,directed_proba,embedding_normalization,lr,lstm_unit_size,margin,max1_pool_size,max2_pool_size,max_length,negative_sampling_ratio,source_target_dense_layers,truncating,undirected_distance,weighted,loss
3,sqrt,False,6,False,12.0,128,512.0,,cosine,0.4,False,0.0005,320,0.2,3,4,5050,20.0,False,post,euclidean,True,4.906636
0,sqrt3,False,18,False,2.0,64,,,cosine,1.0,False,0.001,160,0.2,9,6,5050,2.0,True,post,euclidean,False,11.274371
2,sqrt3,False,26,False,2.0,256,,512.0,euclidean,0.5,True,0.0005,100,0.2,3,4,5050,20.0,False,post,euclidean,True,11.996773
5,sqrt,True,18,False,6.0,256,512.0,,cosine,1.0,False,0.0005,100,0.2,3,4,5050,1.0,True,post,cosine,True,12.033129
6,log,True,6,False,12.0,256,,512.0,euclidean,0.5,False,0.0005,320,0.2,9,2,5050,2.0,False,post,euclidean,True,12.033129
7,log,False,6,True,6.0,128,,256.0,cosine,1.0,True,0.0005,160,0.2,6,2,5050,20.0,True,random,euclidean,True,12.033129
1,sqrt,False,12,True,,64,512.0,512.0,euclidean,0.5,True,0.0005,160,0.2,9,6,5050,40.0,True,random,cosine,False,14.811341
4,sqrt3,True,6,True,2.0,64,256.0,256.0,euclidean,1.0,True,0.0005,160,0.2,6,4,5050,2.0,True,post,dot_sigmoid,False,22.457425


# Correlation to loss

In [38]:
# Sampled triplet loss
# Correlate only the numeric/boolean columns: X also holds string-valued
# hyperparameters, which DataFrame.corr() no longer drops silently on pandas >= 2.0.
X.select_dtypes(include=["number", "bool"]).corr()

Unnamed: 0,conv1_kernel_size,conv2_kernel_size,d,dense1_unit_size,dense2_unit_size,directed_proba,lr,lstm_unit_size,max1_pool_size,max2_pool_size,max_length,loss
conv1_kernel_size,1.0,0.083094,-0.067887,0.101041,0.029739,0.089139,0.001046,0.171966,0.010055,-0.05679,0.013527,-0.034231
conv2_kernel_size,0.083094,1.0,-0.016334,0.033706,0.125814,0.185991,-0.058078,0.019996,-0.013578,0.097513,0.056835,-0.154107
d,-0.067887,-0.016334,1.0,0.028041,-0.010332,0.061115,-0.038663,-0.186344,-0.008234,-0.165212,0.010824,0.000513
dense1_unit_size,0.101041,0.033706,0.028041,1.0,-0.038954,0.03471,-0.029428,-0.027874,0.032143,-0.114902,-0.110186,-0.015151
dense2_unit_size,0.029739,0.125814,-0.010332,-0.038954,1.0,0.037692,-0.055607,-0.092136,-0.038954,0.062757,0.104013,-0.082427
directed_proba,0.089139,0.185991,0.061115,0.03471,0.037692,1.0,0.039121,0.071548,-0.004573,-0.035792,0.069335,-0.084432
lr,0.001046,-0.058078,-0.038663,-0.029428,-0.055607,0.039121,1.0,0.124277,0.0328,-0.14581,0.020267,-0.188738
lstm_unit_size,0.171966,0.019996,-0.186344,-0.027874,-0.092136,0.071548,0.124277,1.0,0.093581,0.012975,-0.044179,-0.173416
max1_pool_size,0.010055,-0.013578,-0.008234,0.032143,-0.038954,-0.004573,0.0328,0.093581,1.0,0.049348,0.055061,-0.003275
max2_pool_size,-0.05679,0.097513,-0.165212,-0.114902,0.062757,-0.035792,-0.14581,0.012975,0.049348,1.0,-0.073825,-0.325119


In [10]:
# Sampled online triplet loss
# Correlate only the numeric/boolean columns: X also holds string-valued
# hyperparameters, which DataFrame.corr() no longer drops silently on pandas >= 2.0.
X.select_dtypes(include=["number", "bool"]).corr()

Unnamed: 0,conv1_batch_norm,conv1_kernel_size,conv2_kernel_size,d,dense1_unit_size,dense2_unit_size,embedding_normalization,lr,lstm_unit_size,max1_pool_size,max2_pool_size,max_length,negative_sampling_ratio,loss
conv1_batch_norm,1.0,-0.081679,-6.223344e-18,-0.105209,-0.051803,-0.131466,0.013616,0.049637,-0.025464,0.015408,-0.090649,-0.027332,0.042505,-0.483913
conv1_kernel_size,-0.08167858,1.0,-0.02690937,0.096422,-0.019064,-0.13366,-0.011525,0.081356,-0.057643,-0.001097,0.087768,0.036916,-0.053697,0.03559
conv2_kernel_size,-6.223344e-18,-0.026909,1.0,-0.111828,-0.013795,0.008061,-0.044859,0.056112,-0.034023,0.002629,-0.035868,-0.030346,0.066992,0.037505
d,-0.1052085,0.096422,-0.1118282,1.0,0.039701,0.079836,0.10707,-0.036641,-0.097438,0.074063,-0.091483,0.064913,-0.010593,0.094543
dense1_unit_size,-0.05180293,-0.019064,-0.0137949,0.039701,1.0,-0.048898,0.041201,-0.086884,0.025244,0.013746,-0.087694,0.056602,-0.032716,-0.012541
dense2_unit_size,-0.1314657,-0.13366,0.008061051,0.079836,-0.048898,1.0,0.061331,-0.004586,0.053806,-0.100669,0.034264,-0.070893,0.046463,-0.061047
embedding_normalization,0.01361645,-0.011525,-0.04485903,0.10707,0.041201,0.061331,1.0,0.08862,0.049824,-0.055081,-0.032091,-0.08315,-0.016594,0.332455
lr,0.04963726,0.081356,0.05611161,-0.036641,-0.086884,-0.004586,0.08862,1.0,-0.041065,-0.086235,-0.08334,-0.053815,-0.018581,0.018668
lstm_unit_size,-0.02546436,-0.057643,-0.03402283,-0.097438,0.025244,0.053806,0.049824,-0.041065,1.0,-0.053959,0.047418,-0.025911,-0.057729,0.129468
max1_pool_size,0.01540841,-0.001097,0.002628934,0.074063,0.013746,-0.100669,-0.055081,-0.086235,-0.053959,1.0,-0.045595,0.112067,0.025152,-0.041907


In [9]:
# Sampled online triplet loss with context sampling
# Correlate only the numeric/boolean columns: X also holds string-valued
# hyperparameters, which DataFrame.corr() no longer drops silently on pandas >= 2.0.
X.select_dtypes(include=["number", "bool"]).corr()

Unnamed: 0,conv1_batch_norm,conv1_kernel_size,conv2_batch_norm,conv2_kernel_size,d,dense1_unit_size,dense2_unit_size,embedding_normalization,lr,lstm_unit_size,max1_pool_size,max2_pool_size,max_length,negative_sampling_ratio,loss
conv1_batch_norm,1.0,-0.142089,-0.079786,0.085749,0.410495,-0.073969,0.096077,,-0.187283,0.19351,-0.112782,-0.202152,-0.047741,-0.003918,0.025792
conv1_kernel_size,-0.142089,1.0,0.276534,-0.350297,-0.025695,-0.009039,-0.30697,,-0.094955,-0.112095,0.268743,-0.385359,0.076977,-0.105051,0.059286
conv2_batch_norm,-0.079786,0.276534,1.0,-0.025364,0.294364,0.068344,0.096077,,0.28378,-0.076146,0.147246,0.07038,0.227504,0.073653,-0.018621
conv2_kernel_size,0.085749,-0.350297,-0.025364,1.0,0.141377,0.076945,0.239339,,0.310668,0.112913,0.044466,0.177003,0.048233,0.052893,-0.111492
d,0.410495,-0.025695,0.294364,0.141377,1.0,0.317513,0.268484,,0.246636,0.246957,0.136816,-0.089175,0.09477,-0.369831,0.15115
dense1_unit_size,-0.073969,-0.009039,0.068344,0.076945,0.317513,1.0,-0.240214,,0.090075,0.022861,-0.046204,0.0,0.146687,-0.269631,0.201122
dense2_unit_size,0.096077,-0.30697,0.096077,0.239339,0.268484,-0.240214,1.0,,0.163934,0.124599,0.123897,-0.044412,-0.169736,-0.128688,-0.226575
embedding_normalization,,,,,,,,,,,,,,,
lr,-0.187283,-0.094955,0.28378,0.310668,0.246636,0.090075,0.163934,,1.0,0.000254,0.17376,0.072386,-0.068492,0.184728,-0.456526
lstm_unit_size,0.19351,-0.112095,-0.076146,0.112913,0.246957,0.022861,0.124599,,0.000254,1.0,0.279556,0.069867,-0.067676,0.102058,-0.254394
