In [5]:
from __future__ import division
from __future__ import print_function
from operator import itemgetter

import tensorflow as tf
import numpy as np
import networkx as nx
import scipy.sparse as sp
from sklearn import metrics
import matplotlib.pyplot as plt
import h5py
import pickle
import os

from decagon.deep.optimizer import DecagonOptimizer
from decagon.deep.model import DecagonModel
from decagon.deep.minibatch import EdgeMinibatchIterator
from decagon.utility import rank_metrics, preprocessing

# Global seed for NumPy-based randomness in this notebook (kept as a constant so
# it can be changed in one place).
# NOTE(review): TensorFlow is not seeded anywhere in this notebook, so graph
# variable initialisation is presumably not reproducible across runs — TODO seed
# the TF graph after disable_eager_execution() if reproducibility is required.
SEED = 0
np.random.seed(SEED)

In [6]:
def tsne_visualization(matrix, perplexity=40, n_iter=1000, random_state=0):
    """Fit a 2-D t-SNE embedding of the rows of ``matrix`` and scatter-plot it.

    Args:
        matrix: 2-D array-like of shape (n_samples, n_features).
        perplexity: t-SNE perplexity (default 40).
        n_iter: maximum number of optimisation iterations (default 1000).
        random_state: seed passed to TSNE (default 0).

    Returns:
        The (n_samples, 2) embedding, so callers can reuse it without refitting.
    """
    import inspect
    from sklearn.manifold import TSNE

    # scikit-learn 1.5 renamed TSNE's ``n_iter`` to ``max_iter`` (the old name was
    # removed in 1.7); pass the iteration count under whichever name is accepted.
    iter_kw = 'max_iter' if 'max_iter' in inspect.signature(TSNE).parameters else 'n_iter'
    tsne = TSNE(n_components=2, verbose=1, perplexity=perplexity,
                random_state=random_state, **{iter_kw: n_iter})
    tsne_results = tsne.fit_transform(matrix)

    fig, ax = plt.subplots(dpi=300)
    ax.scatter(tsne_results[:, 0], tsne_results[:, 1])
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    plt.show()
    return tsne_results

def draw_graph(adj_matrix):
    """Draw the graph described by a SciPy sparse adjacency matrix.

    Uses a spring layout (100 iterations). Nodes are coloured by their index
    with the Dark2 colormap and sized ``20 * degree + 20``.

    Args:
        adj_matrix: SciPy sparse adjacency matrix of any size.
    """
    # networkx 3.0 removed from_scipy_sparse_matrix in favour of
    # from_scipy_sparse_array (available since 2.7); use whichever exists.
    from_sparse = getattr(nx, 'from_scipy_sparse_array', None) or nx.from_scipy_sparse_matrix
    G = from_sparse(adj_matrix)
    pos = nx.spring_layout(G, iterations=100)
    degree = dict(nx.degree(G))
    nodes = list(degree)
    # One colour value per drawn node, derived from the graph itself rather than
    # a fixed node count, so any adjacency matrix can be drawn.
    nx.draw(G, pos, node_color=nodes, nodelist=nodes,
            node_size=[degree[n] * 20 + 20 for n in nodes], cmap=plt.cm.Dark2)
    plt.show()

def get_accuracy_scores(edges_pos, edges_neg, edge_type, name=None):
    """Evaluate decoder scores on held-out positive and negative edges of one edge type.

    Relies on notebook globals: ``feed_dict``, ``placeholders``, ``minibatch``,
    ``sess``, ``opt`` and ``adj_mats_orig``. ``feed_dict`` is updated in place:
    dropout is set to 0 and the edge-type entries are overwritten.

    Args:
        edges_pos: ``edges_pos[(row_type, col_type)][k]`` iterates ``(u, v)``
            index pairs expected to be edges in ``adj_mats_orig``.
        edges_neg: same layout; pairs expected to be non-edges.
        edge_type: ``(row_type, col_type, k)`` triple.
        name: optional file path; when given, ``[labels_all, preds_all]`` is
            pickled there.

    Returns:
        ``(roc_auc, average_precision, apk)``. ``apk`` is ``rank_metrics.apk``
        at k=200, with the positive edges as the relevant items.

    Raises:
        AssertionError: (unless asserts are disabled) if a positive pair is not 1,
            or a negative pair is not 0, in the original adjacency matrix.
    """
    feed_dict.update({
        placeholders['dropout']: 0,
        placeholders['batch_edge_type_idx']: minibatch.edge_type2idx[edge_type],
        placeholders['batch_row_edge_type']: edge_type[0],
        placeholders['batch_col_edge_type']: edge_type[1],
    })
    rec = sess.run(opt.predictions, feed_dict=feed_dict)

    def sigmoid(x):
        return 1. / (1 + np.exp(-x))

    # Look up the block and the relation index once instead of per edge.
    block, k = edge_type[:2], edge_type[2]
    adj = adj_mats_orig[block][k]

    preds = []
    actual = []
    predicted = []
    edge_ind = 0
    for u, v in edges_pos[block][k]:
        score = sigmoid(rec[u, v])
        preds.append(score)
        assert adj[u, v] == 1, 'Problem 1'
        # Positive edges are the "relevant" items for AP@k.
        actual.append(edge_ind)
        predicted.append((score, edge_ind))
        edge_ind += 1

    preds_neg = []
    for u, v in edges_neg[block][k]:
        score = sigmoid(rec[u, v])
        preds_neg.append(score)
        assert adj[u, v] == 0, 'Problem 0'
        predicted.append((score, edge_ind))
        edge_ind += 1

    preds_all = np.nan_to_num(np.hstack([preds, preds_neg]))
    labels_all = np.hstack([np.ones(len(preds)), np.zeros(len(preds_neg))])
    # Edge ids ranked by descending score; sorted() stays stable with
    # reverse=True, so ties keep insertion order.
    predicted = [idx for _, idx in sorted(predicted, reverse=True, key=itemgetter(0))]

    roc_sc = metrics.roc_auc_score(labels_all, preds_all)
    aupr_sc = metrics.average_precision_score(labels_all, preds_all)
    apk_sc = rank_metrics.apk(actual, predicted, k=200)
    if name is not None:
        with open(name, 'wb') as f:
            pickle.dump([labels_all, preds_all], f)
    return roc_sc, aupr_sc, apk_sc


def construct_placeholders(edge_types):
    """Create the TF1-style placeholders fed to the Decagon model and optimizer.

    Args:
        edge_types: dict mapping ``(row_type, col_type)`` to the number of
            adjacency matrices in that block.

    Returns:
        dict with the scalar/batch placeholders, one sparse placeholder per
        adjacency matrix (``'adj_mats_i,j,k'``) and one sparse feature
        placeholder per node type (``'feat_i'``).
    """
    # tf.compat.v1 placeholders only work in graph mode.
    tf.compat.v1.disable_eager_execution()
    placeholders = {
        'batch': tf.compat.v1.placeholder(tf.int64, name='batch'),
        'batch_edge_type_idx': tf.compat.v1.placeholder(tf.int64, shape=(), name='batch_edge_type_idx'),
        'batch_row_edge_type': tf.compat.v1.placeholder(tf.int64, shape=(), name='batch_row_edge_type'),
        'batch_col_edge_type': tf.compat.v1.placeholder(tf.int64, shape=(), name='batch_col_edge_type'),
        'degrees': tf.compat.v1.placeholder(tf.int64),
        'dropout': tf.compat.v1.placeholder_with_default(0., shape=()),
    }
    # Adjacency values are float32, matching the feature placeholders below and
    # the placeholder cell used later in this notebook.
    placeholders.update({
        'adj_mats_%d,%d,%d' % (i, j, k): tf.compat.v1.sparse_placeholder(tf.float32)
        for i, j in edge_types for k in range(edge_types[i, j])})
    # Each node type appears in several edge-type keys; dict.fromkeys de-duplicates
    # (keeping first-seen order) so each feature placeholder is created only once.
    placeholders.update({
        'feat_%d' % node_type: tf.compat.v1.sparse_placeholder(tf.float32)
        for node_type in dict.fromkeys(i for i, _ in edge_types)})
    return placeholders

def network_edge_threshold(network_adj, threshold):
    """Keep only the entries of a sparse matrix whose value is strictly above ``threshold``.

    Args:
        network_adj: SciPy sparse (or dense 2-D) matrix of edge weights.
        threshold: entries with value ``<= threshold`` are dropped.

    Returns:
        ``scipy.sparse.csr_matrix`` with the same shape, holding only the kept entries.
    """
    # A COO view exposes parallel row/col/data arrays, so a boolean mask filters
    # all three at once without building an intermediate coordinate array.
    coo = sp.coo_matrix(network_adj)
    keep = coo.data > threshold
    return sp.csr_matrix(
        (coo.data[keep], (coo.row[keep], coo.col[keep])),
        shape=coo.shape)


def get_prediction(edge_type):
    """Return sigmoid-activated decoder scores for every node pair of ``edge_type``.

    Uses the notebook globals ``feed_dict``, ``placeholders``, ``minibatch``,
    ``sess`` and ``opt``. ``feed_dict`` is modified in place: dropout is forced
    to 0 and the edge-type entries are set.

    Args:
        edge_type: ``(row_type, col_type, k)`` triple selecting the decoder.

    Returns:
        numpy array of probabilities with the shape of the evaluated ``opt.predictions``.
    """
    feed_dict[placeholders['dropout']] = 0
    feed_dict[placeholders['batch_edge_type_idx']] = minibatch.edge_type2idx[edge_type]
    feed_dict[placeholders['batch_row_edge_type']] = edge_type[0]
    feed_dict[placeholders['batch_col_edge_type']] = edge_type[1]

    logits = sess.run(opt.predictions, feed_dict=feed_dict)
    # Element-wise logistic sigmoid of the raw decoder output.
    return 1. / (1 + np.exp(-logits))


In [11]:
# Gene-gene network, stored in the .mat (HDF5) file as the data/ir/jc arrays of a
# compressed-sparse-column matrix. The handle `f` is deliberately left open:
# later cells read more datasets from it.
gene_phenes_path = './data_prioritization/genes_phenes.mat'
f = h5py.File(gene_phenes_path, 'r')
gene_network_adj = sp.csc_matrix(
    (
        np.array(f['GeneGene_Hs']['data']),
        np.array(f['GeneGene_Hs']['ir']),
        np.array(f['GeneGene_Hs']['jc']),
    ),
    shape=(12331, 12331),
).tocsr()
print(gene_network_adj)

  (0, 41)	0.3429869656355365
  (0, 42)	0.229051470285678
  (0, 158)	0.2571748481708065
  (0, 169)	1.42149513328269
  (0, 202)	0.79362367099361
  (0, 216)	0.2744082575786565
  (0, 232)	0.221678022306955
  (0, 235)	0.786535443737085
  (0, 242)	0.359475328536375
  (0, 365)	0.80406353210639
  (0, 418)	0.308061403101067
  (0, 447)	0.3258654664344195
  (0, 460)	0.917947990535165
  (0, 461)	0.291648092061384
  (0, 465)	0.308741672856747
  (0, 468)	0.3439394442851005
  (0, 558)	0.2813892224095805
  (0, 561)	0.2047004183398565
  (0, 570)	0.2290764456406575
  (0, 597)	0.2837714909401
  (0, 616)	0.2906006554127645
  (0, 626)	0.479701048909572
  (0, 631)	0.2090372752959145
  (0, 659)	0.4912143818685225
  (0, 691)	0.52548244154808
  :	:
  (12052, 3133)	0.989738747149405
  (12052, 4840)	0.989738747149405
  (12052, 7245)	0.989738747149405
  (12052, 7572)	0.989738747149405
  (12052, 8141)	0.396576240113495
  (12052, 9196)	0.989738747149405
  (12066, 1677)	0.93273741741171
  (12066, 1678)	0.89540169644

In [19]:
# Phenotype-phenotype similarity network (CSC arrays in the open file `f`),
# converted to CSR; entries with similarity <= 0.2 are dropped.
disease_network_adj = network_edge_threshold(
    sp.csc_matrix(
        (
            np.array(f['PhenotypeSimilarities']['data']),
            np.array(f['PhenotypeSimilarities']['ir']),
            np.array(f['PhenotypeSimilarities']['jc']),
        ),
        shape=(3215, 3215),
    ).tocsr(),
    0.2,
)
print(disease_network_adj)

  (0, 0)	1.0
  (1, 1)	1.0
  (2, 2)	1.0
  (2, 4)	0.269991
  (2, 7)	0.2932
  (2, 12)	0.339762
  (2, 13)	0.219626
  (2, 15)	0.231472
  (2, 16)	0.268054
  (2, 26)	0.201734
  (2, 27)	0.204409
  (2, 28)	0.294928
  (2, 29)	0.271269
  (2, 30)	0.290155
  (2, 31)	0.265699
  (2, 33)	0.286343
  (2, 42)	0.23041
  (2, 44)	0.20191
  (2, 48)	0.204602
  (2, 62)	0.36473
  (2, 87)	0.206257
  (2, 97)	0.294945
  (2, 108)	0.262
  (2, 109)	0.206412
  (2, 114)	0.361163
  :	:
  (3212, 259)	0.308152
  (3212, 525)	0.230902
  (3212, 564)	0.266813
  (3212, 627)	0.256258
  (3212, 656)	0.210021
  (3212, 914)	0.270097
  (3212, 982)	0.219849
  (3212, 1040)	0.292659
  (3212, 1111)	0.336164
  (3212, 1112)	0.327738
  (3212, 1139)	0.219679
  (3212, 1234)	0.223895
  (3212, 1265)	0.213606
  (3212, 1654)	0.218028
  (3212, 1836)	0.254725
  (3212, 1869)	0.252043
  (3212, 1875)	0.205368
  (3212, 1922)	0.237681
  (3212, 1955)	0.400465
  (3212, 1992)	0.314485
  (3212, 1993)	0.212216
  (3212, 2089)	0.218028
  (3212, 3212)	1.0
  (3

In [20]:
# First GenePhene entry: an HDF5 reference that is dereferenced through f[dg_ref]
# (presumably the human gene-phenotype matrix; the other entries are loaded later
# as other-species features).
dg_ref = f['GenePhene'][0][0]
gene_disease_adj = sp.csc_matrix(
    (np.array(f[dg_ref]['data']), np.array(f[dg_ref]['ir']), np.array(f[dg_ref]['jc'])),
    shape=(12331, 3215),
).tocsr()
print(gene_disease_adj)

  (0, 24)	1.0
  (2, 939)	1.0
  (5, 3093)	1.0
  (7, 2743)	1.0
  (8, 286)	1.0
  (8, 653)	1.0
  (8, 1767)	1.0
  (10, 2392)	1.0
  (11, 1423)	1.0
  (13, 352)	1.0
  (13, 968)	1.0
  (13, 1665)	1.0
  (13, 1768)	1.0
  (22, 617)	1.0
  (23, 618)	1.0
  (24, 2267)	1.0
  (25, 619)	1.0
  (26, 641)	1.0
  (30, 74)	1.0
  (35, 1144)	1.0
  (38, 2063)	1.0
  (40, 394)	1.0
  (40, 1050)	1.0
  (41, 2479)	1.0
  (41, 2966)	1.0
  :	:
  (11828, 1231)	1.0
  (11829, 3150)	1.0
  (11830, 2699)	1.0
  (11831, 3158)	1.0
  (11848, 2875)	1.0
  (11864, 2105)	1.0
  (11890, 758)	1.0
  (11891, 1830)	1.0
  (11898, 2300)	1.0
  (11912, 1411)	1.0
  (11924, 3002)	1.0
  (11931, 1007)	1.0
  (11959, 2840)	1.0
  (11983, 873)	1.0
  (11990, 500)	1.0
  (12012, 2141)	1.0
  (12016, 500)	1.0
  (12027, 1572)	1.0
  (12050, 500)	1.0
  (12081, 2352)	1.0
  (12122, 182)	1.0
  (12127, 2765)	1.0
  (12128, 687)	1.0
  (12186, 2810)	1.0
  (12210, 1002)	1.0


In [21]:
# Novel gene-phenotype associations, kept in CSC format (genes x phenotypes).
novel_associations_adj = sp.csc_matrix(
    tuple(np.array(f['NovelAssociations'][part]) for part in ('data', 'ir', 'jc')),
    shape=(12331, 3215),
)
print(novel_associations_adj)

  (2029, 21)	1.0
  (560, 155)	1.0
  (441, 172)	1.0
  (4404, 181)	1.0
  (888, 182)	1.0
  (7751, 228)	1.0
  (10361, 228)	1.0
  (4344, 444)	1.0
  (6246, 499)	1.0
  (1540, 524)	1.0
  (2398, 556)	1.0
  (4976, 699)	1.0
  (1276, 788)	1.0
  (7444, 1006)	1.0
  (388, 1166)	1.0
  (1509, 1166)	1.0
  (4409, 1166)	1.0
  (1103, 1219)	1.0
  (5836, 1349)	1.0
  (3510, 1498)	1.0
  (916, 1578)	1.0
  (4455, 1662)	1.0
  (4017, 2824)	1.0
  (9380, 2824)	1.0
  (1344, 3142)	1.0
  (10953, 3209)	1.0
  (2122, 3210)	1.0
  (1837, 3211)	1.0
  (10060, 3211)	1.0
  (1332, 3212)	1.0
  (8407, 3212)	1.0
  (7140, 3213)	1.0
  (9161, 3213)	1.0
  (1634, 3214)	1.0


In [22]:
gene_feature_path = './data_prioritization/GeneFeatures.mat'
# np.array copies the dataset into memory, so the file can be closed right away.
with h5py.File(gene_feature_path, 'r') as f_gene_feature:
    gene_feature_exp = np.array(f_gene_feature['GeneFeatures'])
# Transposed into (genes, features) orientation — assumes the stored array has
# reversed axes, as h5py reads MATLAB v7.3 arrays; TODO confirm.
gene_feature_exp = np.transpose(gene_feature_exp)
gene_network_exp = sp.csc_matrix(gene_feature_exp)
print(gene_network_exp)

  (0, 0)	0.6882990632744054
  (1, 0)	-1.084648526696566
  (2, 0)	0.17276799812704952
  (3, 0)	-0.8395759232908131
  (4, 0)	1.6577039875956963
  (5, 0)	-0.635559436678219
  (6, 0)	2.2653509692896194
  (7, 0)	-0.9304714468500453
  (8, 0)	-0.37276614017136284
  (9, 0)	-0.024408475622041657
  (10, 0)	-0.20450392274146112
  (11, 0)	-0.4784230553830688
  (12, 0)	1.2502239202204724
  (13, 0)	-0.04775975269177065
  (14, 0)	1.1467630358279635
  (15, 0)	0.3647603963245443
  (16, 0)	0.23184940061151954
  (17, 0)	0.9046614811939566
  (18, 0)	0.06877978160304124
  (19, 0)	0.7444933594561656
  (20, 0)	-0.21368056170101346
  (21, 0)	-1.2036478507142434
  (22, 0)	-0.7363345778312629
  (23, 0)	0.22232490417681158
  (24, 0)	-1.5881083851282265
  :	:
  (11712, 4535)	0.9806119884439489
  (11728, 4535)	-0.6327634554065862
  (11745, 4535)	-1.8441555434074688
  (11746, 4535)	-0.2762316900654667
  (11765, 4535)	-0.5175118864508739
  (11770, 4535)	0.5998631651748905
  (11807, 4535)	-0.7875644118297614
  (11817

In [27]:
# Despite the name, these are the column counts (shape[1]) of each GenePhene
# matrix. Entry 0 (3215) is the matrix already loaded as gene_disease_adj, so the
# loop starts at 1 and collects the remaining matrices as extra gene features.
row_list = [3215, 1137, 744, 2503, 1143, 324, 1188, 4662, 1243]
gene_feature_list_other_spe = list()
# Loop bound follows row_list so the two can never drift apart.
for i in range(1, len(row_list)):
    dg_ref = f['GenePhene'][i][0]
    disease_gene_adj_tmp = sp.csc_matrix((np.array(f[dg_ref]['data']),
        np.array(f[dg_ref]['ir']), np.array(f[dg_ref]['jc'])),
        shape=(12331, row_list[i]))
    gene_feature_list_other_spe.append(disease_gene_adj_tmp)

print(gene_feature_list_other_spe)

[<12331x1137 sparse matrix of type '<class 'numpy.float64'>'
	with 12010 stored elements in Compressed Sparse Column format>, <12331x744 sparse matrix of type '<class 'numpy.float64'>'
	with 30519 stored elements in Compressed Sparse Column format>, <12331x2503 sparse matrix of type '<class 'numpy.float64'>'
	with 68525 stored elements in Compressed Sparse Column format>, <12331x1143 sparse matrix of type '<class 'numpy.float64'>'
	with 4500 stored elements in Compressed Sparse Column format>, <12331x324 sparse matrix of type '<class 'numpy.float64'>'
	with 72846 stored elements in Compressed Sparse Column format>, <12331x1188 sparse matrix of type '<class 'numpy.float64'>'
	with 22150 stored elements in Compressed Sparse Column format>, <12331x4662 sparse matrix of type '<class 'numpy.float64'>'
	with 75199 stored elements in Compressed Sparse Column format>, <12331x1243 sparse matrix of type '<class 'numpy.float64'>'
	with 73284 stored elements in Compressed Sparse Column format>]


In [28]:
disease_tfidf_path = './data_prioritization/clinicalfeatures_tfidf.mat'
# Open read-only explicitly: before h5py 3.0 the default mode was 'a' (read/write,
# creating the file if missing). np.array copies the data, so the file is closed after.
with h5py.File(disease_tfidf_path, 'r') as f_disease_tfidf:
    disease_tfidf = np.array(f_disease_tfidf['F'])
# Transposed into (phenotypes, terms) orientation — assumes the stored array has
# reversed axes, as h5py reads MATLAB v7.3 arrays; TODO confirm.
disease_tfidf = sp.csc_matrix(np.transpose(disease_tfidf))
print(disease_tfidf)

  (733, 0)	0.044022276470087114
  (906, 0)	0.37015016570833714
  (1052, 0)	0.13573091092516065
  (1494, 0)	0.05737655295763999
  (31, 1)	0.015546762143879837
  (241, 1)	0.04260094735742209
  (671, 1)	0.12216976627331207
  (835, 1)	0.04442092789623052
  (840, 1)	0.04917518937365812
  (842, 1)	0.03495397806390919
  (902, 1)	0.04632988498592828
  (1107, 1)	0.14257318603397254
  (1146, 1)	0.17493992638384284
  (1787, 1)	0.11837221663325781
  (216, 2)	0.1620926141297842
  (227, 2)	0.043900071605066385
  (802, 2)	0.04427521428071298
  (803, 2)	0.1632102472686118
  (1194, 2)	0.04013714817224205
  (4, 3)	0.011143340608277742
  (17, 3)	0.03211236151169534
  (19, 3)	0.11949227901679864
  (37, 3)	0.03777680514065878
  (39, 3)	0.009688607703132934
  (40, 3)	0.016260390402779545
  :	:
  (1966, 16591)	0.24287404810076083
  (1986, 16591)	0.03993463886305298
  (2017, 16591)	0.06593623340768666
  (2027, 16591)	0.10916856322296231
  (2198, 16591)	0.018588439419258987
  (2208, 16591)	0.12872300600053924


In [30]:
# A single disease-disease relation type: the thresholded phenotype-similarity network.
dis_dis_adj_list = [disease_network_adj]

# Passed to EdgeMinibatchIterator as val_test_size.
val_test_size = 0.1
# Node counts taken from the loaded gene x phenotype matrix rather than repeated literals.
n_genes, n_dis = gene_disease_adj.shape
n_dis_rel_types = len(dis_dis_adj_list)

gene_adj = gene_network_adj
# Weighted column sums of the gene network, flattened from a 1 x n matrix to 1-D.
gene_degrees = np.array(gene_adj.sum(axis=0)).squeeze()

gene_dis_adj = gene_disease_adj
dis_gene_adj = gene_dis_adj.transpose(copy=True)

In [36]:
# One degree vector per disease-disease matrix: column sums of the similarity
# weights (so values are floats, not integer edge counts).
dis_degrees_list = [
    np.array(dis_mat.sum(axis=0)).squeeze()
    for dis_mat in dis_dis_adj_list
]
print(dis_degrees_list)

[array([ 1.      ,  1.      , 91.266734, ...,  7.260462,  1.      ,
        1.      ])]


In [39]:
# Adjacency matrices per (row_type, col_type) block; node type 0 = genes,
# node type 1 = diseases/phenotypes. Same-type blocks also carry the transpose.
adj_mats_orig = {
    (0, 0): [gene_adj, gene_adj.transpose(copy=True)],
    (0, 1): [gene_dis_adj],
    (1, 0): [dis_gene_adj],
    (1, 1): dis_dis_adj_list + [x.transpose(copy=True) for x in dis_dis_adj_list],
}
# Degree vectors per node type, one per same-type relation (the transposed
# matrices reuse the same vectors).
degrees = {
    0: [gene_degrees, gene_degrees],
    1: dis_degrees_list + dis_degrees_list,
}

# Gene features: other-species phenotype matrices plus expression features,
# stacked column-wise. gene_network_exp is already the sparse form of
# gene_feature_exp, so the dense array is not converted again.
gene_feat = sp.hstack(gene_feature_list_other_spe + [gene_network_exp]).tocoo()
# NOTE(review): nonzero_feat presumably feeds Decagon's sparse dropout, which
# needs the number of stored feature values (nnz). shape[0] only equals nnz for
# identity features, so nnz is used here — confirm against decagon/deep/layers.py.
gene_nonzero_feat, gene_num_feat = gene_feat.nnz, gene_feat.shape[1]
gene_feat = preprocessing.sparse_to_tuple(gene_feat)

dis_feat = disease_tfidf.tocoo()
dis_nonzero_feat, dis_num_feat = dis_feat.nnz, dis_feat.shape[1]
dis_feat = preprocessing.sparse_to_tuple(dis_feat)

num_feat = {
    0: gene_num_feat,
    1: dis_num_feat,
}
nonzero_feat = {
    0: gene_nonzero_feat,
    1: dis_nonzero_feat,
}
feat = {
    0: gene_feat,
    1: dis_feat,
}

In [43]:
# Shape of every adjacency matrix, grouped by (row_type, col_type) block.
edge_type2dim = {block: [mat.shape for mat in mats] for block, mats in adj_mats_orig.items()}

# Every block uses the same inner-product decoder.
edge_type2decoder = {block: 'innerproduct' for block in adj_mats_orig}

# Number of adjacency matrices (relation types) per block, and their total.
edge_types = {block: len(mats) for block, mats in adj_mats_orig.items()}
num_edge_types = sum(edge_types.values())
print("Edge types:", "%d" % num_edge_types)

Edge types: 6


In [46]:
# Hyper-parameters. Within this notebook only 'batch_size', 'max_margin' and
# 'dropout' are read from this dict (minibatch iterator, optimizer, feed dict).
# NOTE(review): the remaining entries are not passed to DecagonModel or
# DecagonOptimizer here — confirm whether decagon reads them from elsewhere.
flags = dict(
    neg_sample_size=1,
    learning_rate=0.001,
    hidden1=64,
    hidden2=32,
    weight_decay=0.001,
    dropout=0.1,
    max_margin=0.1,
    batch_size=512,
    bias=True,
)

In [51]:
# Placeholders require graph mode.
tf.compat.v1.disable_eager_execution()
# Scalar / batch inputs (keyword arguments are evaluated left to right, so the
# placeholders are created in the same order as the keys).
placeholders = dict(
    batch=tf.compat.v1.placeholder(tf.int64, name='batch'),
    batch_edge_type_idx=tf.compat.v1.placeholder(tf.int64, shape=(), name='batch_edge_type_idx'),
    batch_row_edge_type=tf.compat.v1.placeholder(tf.int64, shape=(), name='batch_row_edge_type'),
    batch_col_edge_type=tf.compat.v1.placeholder(tf.int64, shape=(), name='batch_col_edge_type'),
    degrees=tf.compat.v1.placeholder(tf.int64),
    dropout=tf.compat.v1.placeholder_with_default(0., shape=()),
)
# One sparse placeholder per adjacency matrix: 'adj_mats_<row>,<col>,<k>'.
placeholders.update(
    ('adj_mats_%d,%d,%d' % (row, col, k), tf.compat.v1.sparse_placeholder(tf.float32))
    for (row, col), count in edge_types.items()
    for k in range(count)
)
# Sparse feature placeholder per node type (one is created per block key; the
# later one for a repeated node type replaces the earlier dict entry).
placeholders.update(
    ('feat_%d' % row, tf.compat.v1.sparse_placeholder(tf.float32))
    for row, _ in edge_types
)

print(placeholders)

{'batch': <tf.Tensor 'batch_4:0' shape=<unknown> dtype=int64>, 'batch_edge_type_idx': <tf.Tensor 'batch_edge_type_idx_4:0' shape=() dtype=int64>, 'batch_row_edge_type': <tf.Tensor 'batch_row_edge_type_4:0' shape=() dtype=int64>, 'batch_col_edge_type': <tf.Tensor 'batch_col_edge_type_4:0' shape=() dtype=int64>, 'degrees': <tf.Tensor 'Placeholder_124:0' shape=<unknown> dtype=int64>, 'dropout': <tf.Tensor 'PlaceholderWithDefault_4:0' shape=() dtype=float32>, 'adj_mats_0,0,0': SparseTensor(indices=Tensor("Placeholder_127:0", shape=(None, None), dtype=int64), values=Tensor("Placeholder_126:0", shape=(None,), dtype=float32), dense_shape=Tensor("Placeholder_125:0", shape=(None,), dtype=int64)), 'adj_mats_0,0,1': SparseTensor(indices=Tensor("Placeholder_130:0", shape=(None, None), dtype=int64), values=Tensor("Placeholder_129:0", shape=(None,), dtype=float32), dense_shape=Tensor("Placeholder_128:0", shape=(None,), dtype=int64)), 'adj_mats_0,1,0': SparseTensor(indices=Tensor("Placeholder_133:0",

In [54]:
# Edge minibatch iterator over all adjacency matrices; it also exposes the
# held-out test edges (test_edges / test_edges_false) used for evaluation below.
minibatch = EdgeMinibatchIterator(
    adj_mats=adj_mats_orig,
    edge_types=edge_types,
    feat=feat,
    val_test_size=val_test_size,
    batch_size=flags['batch_size'],
)

In [56]:
# Decagon model; its embeddings and latent factors are handed to the optimizer.
model = DecagonModel(
    decoders=edge_type2decoder,
    edge_types=edge_types,
    nonzero_feat=nonzero_feat,
    num_feat=num_feat,
    placeholders=placeholders,
)

In [58]:
with tf.name_scope('optimizer'):
    # Optimizer over the model's embeddings and latent factors; the margin is
    # flags['max_margin'] and the batch size flags['batch_size'].
    opt = DecagonOptimizer(
        placeholders=placeholders,
        embeddings=model.embeddings,
        latent_inters=model.latent_inters,
        latent_varies=model.latent_varies,
        edge_types=edge_types,
        edge_type2dim=edge_type2dim,
        degrees=degrees,
        batch_size=flags['batch_size'],
        margin=flags['max_margin'],
    )



In [62]:
sess = tf.compat.v1.Session()
sess.run(tf.compat.v1.global_variables_initializer())
saver = tf.compat.v1.train.Saver()

# Feed dict for the first minibatch with training dropout; the evaluation
# helpers later overwrite its dropout and edge-type entries in place.
feed_dict = minibatch.next_minibatch_feed_dict(placeholders=placeholders)
feed_dict = minibatch.update_feed_dict(
    feed_dict=feed_dict,
    dropout=flags['dropout'],
    placeholders=placeholders
)
# NOTE(review): no optimisation step (sess.run on an optimizer op) appears in this
# notebook, so the scores computed below come from freshly initialised variables.

In [64]:
# Evaluate one edge type (index into minibatch.idx2edge_type) on the held-out test edges.
eval_edge_idx = 3
eval_edge_type = minibatch.idx2edge_type[eval_edge_idx]
roc_score, auprc_score, apk_score = get_accuracy_scores(
    minibatch.test_edges, minibatch.test_edges_false, eval_edge_type)
print("Edge type=", "[%02d, %02d, %02d]" % eval_edge_type)
print("Edge type:", "%04d" % eval_edge_idx, "Test AUROC score", "{:.5f}".format(roc_score))
print("Edge type:", "%04d" % eval_edge_idx, "Test AUPRC score", "{:.5f}".format(auprc_score))
print("Edge type:", "%04d" % eval_edge_idx, "Test AP@k score", "{:.5f}".format(apk_score))

Edge type= [01, 00, 00]
Edge type: 0003 Test AUROC score 0.50000
Edge type: 0003 Test AUPRC score 0.50000
Edge type: 0003 Test AP@k score 1.00000


In [65]:
# Sigmoid scores for every node pair of edge type index 3 (the type evaluated above).
prediction = get_prediction(edge_type=minibatch.idx2edge_type[3])
print(prediction)

[[0.5 0.5 0.5 ... 0.5 0.5 0.5]
 [0.5 0.5 0.5 ... 0.5 0.5 0.5]
 [0.5 0.5 0.5 ... 0.5 0.5 0.5]
 ...
 [0.5 0.5 0.5 ... 0.5 0.5 0.5]
 [0.5 0.5 0.5 ... 0.5 0.5 0.5]
 [0.5 0.5 0.5 ... 0.5 0.5 0.5]]
