In [1]:
import time

import numpy as np

import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau

# from torch_geometric.nn import VGAE
from torch_geometric.loader import DataLoader
from torch_geometric.utils import (degree, negative_sampling, 
                                   batched_negative_sampling,
                                  add_self_loops, to_undirected)

from torch.utils.tensorboard import SummaryWriter

from gene_graph_dataset import G3MedianDataset
from phylognn_model import G3Median_GCNConv, G3Median_VGAE

from sklearn.metrics import (roc_auc_score, roc_curve,
                             average_precision_score, 
                             precision_recall_curve,
                             f1_score, matthews_corrcoef)

from sklearn.model_selection import KFold

import matplotlib.pyplot as plt

In [2]:
from gene_graph_dataset import G3MedianDataset

In [3]:
from dcj_comp import dcj_dist
# from genome_file import mat2adj

In [4]:
# Training graphs, loaded through the project-local G3MedianDataset class.
# NOTE(review): the positional arguments (10, 1, 1000) are presumably sequence
# length, rate and sample count -- confirm against gene_graph_dataset.
dataset = G3MedianDataset('dataset_g3m_exps', 10, 1, 1000)

In [5]:
# Validation graphs; the evaluation cells iterate this in lock-step with val_seq[0].
# NOTE(review): argument meaning is not visible here -- confirm against gene_graph_dataset.
val_dataset = G3MedianDataset('val_seq_g3m', 10, 1, 200)

In [6]:
# Raw validation data; val_seq[0] is zipped with val_dataset in the evaluation cells.
# NOTE(review): torch.load unpickles the file, so only load trusted paths. The
# relative path is hard-coded and presumably matches the directory that
# G3MedianDataset('val_seq_g3m', ...) creates -- confirm.
val_seq = torch.load('val_seq_g3m_3_10_1_200/raw/g3raw_10_1.pt')

In [7]:
# Mini-batch sizes for the train / test / validation loaders.
train_batch, test_batch, val_batch = 256, 64, 8

# Prefer the second GPU (index 1) as before, but fall back to GPU 0 on
# single-GPU hosts, where 'cuda:1' is an invalid device ordinal, and to the CPU
# when CUDA is unavailable.
if torch.cuda.device_count() > 1:
    device = torch.device('cuda:1')
elif torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# Encoder sizes handed to G3Median_GCNConv when the model is built.
# NOTE(review): in_channels=None presumably lets the layer infer its input
# width -- confirm against phylognn_model.
in_channels, out_channels = None, 128

In [8]:
# Randomise the sample order before training.
# NOTE(review): no random seed is set in the visible cells, so the order (and
# every number printed below) changes from run to run; re-running this cell
# reshuffles the already shuffled dataset.
dataset = dataset.shuffle()

In [9]:
def test_cycle(adj_dict, x, y):
    while True:
        x = x + (1 if x % 2 == 0 else -1)
        if x == y:
            return True
        if x not in adj_dict.keys():
            return False
        x = adj_dict[x]

In [10]:
# def test_cycle(adj_dict, x, y):
#     while True:
#         x = x + (1 if x % 2 == 0 else -1)
#         if x == y:
#             return True
#         if x not in adj_dict.keys():
#             return False
#         x = adj_dict[x]

def dict2adj(adj_dict, start):
    """Walk an adjacency dictionary from ``start`` and emit signed numbers.

    Each step pairs the current extremity with its partner index (``+1`` for
    an even index, ``-1`` for an odd one), then follows the adjacency stored
    for that partner. The walk stops at the first partner that has no entry
    in ``adj_dict``.

    Args:
        adj_dict: mapping extremity -> adjacent extremity.
        start: extremity index where the walk begins.

    Returns:
        One signed value per visited pair: ``index // 2 + 1`` of the entry
        extremity, negated when the pair was entered from its higher index.
    """
    visited_pairs = []
    entry = start
    while True:
        partner = entry + 1 if entry % 2 == 0 else entry - 1
        visited_pairs.append((entry, partner))
        if partner not in adj_dict:
            break
        entry = adj_dict[partner]

    return [(first // 2 + 1) * (-1 if first > second else 1)
            for first, second in visited_pairs]
        

def mat2adj(res_mat):
    """Greedily turn a pairwise score matrix into signed sequences.

    Works on a copy of ``res_mat``. Repeatedly takes the largest remaining
    entry ``(r, c)``, skips it when ``test_cycle`` reports that ``r`` and ``c``
    already terminate the same path, and otherwise records the adjacency
    ``r <-> c`` and clears both rows and both columns.

    Args:
        res_mat: 2-D array of scores; ``shape[0]`` is used as the number of
            extremities (the matrix is assumed to be square -- TODO confirm).

    Returns:
        A pair ``(sequences, score)``: ``sequences`` holds one ``dict2adj``
        walk per extremity left without an adjacency, and ``score`` is the
        negated sum of the logs of the accepted entries.
    """
    n_ext = res_mat.shape[0]
    scores = np.copy(res_mat)

    # Blank every 2x2 diagonal block so an index pair (2k, 2k + 1) can never be
    # chosen as an adjacency (this also clears the diagonal itself).
    for base in range(0, n_ext, 2):
        for row in (base, base + 1):
            scores[row, base:base + 2] = 0

    links = {}
    picked = []
    wanted = n_ext // 2 - 1
    while True:
        r, c = np.unravel_index(scores.argmax(), scores.shape)
        if r == c:
            # The maximum sits on the (zeroed) diagonal, i.e. nothing positive
            # is left to pick: report the partial result and stop.
            print(links)
            break

        # Reject an entry that would join the two ends of one path.
        if test_cycle(links, r, c):
            scores[r, c] = scores[c, r] = 0
            continue

        picked.append(scores[r, c])
        # Both extremities are used up now.
        scores[[r, c], :] = 0
        scores[:, [r, c]] = 0
        links[r] = c
        links[c] = r
        if len(links) // 2 == wanted:
            break

    free_ends = list(set(range(n_ext)) - set(links.keys()))
    sequences = [dict2adj(links, end) for end in free_ends]
    return sequences, -np.log(picked).sum()

In [11]:
def train(model, train_loader):
    """Run one optimisation epoch and return the mean per-batch loss.

    Uses the notebook-level ``optimizer`` and ``device`` globals.

    Args:
        model: object exposing ``train``, ``encode``, ``recon_loss_wt`` and
            ``kl_loss``.
        train_loader: iterable of batches with ``x``, ``edge_index``,
            ``pos_edge_label_index``, ``neg_edge_label_index`` and
            ``num_nodes``.

    Returns:
        The summed batch losses divided by ``len(train_loader)``, detached
        from the autograd graph.
    """
    model.train()

    total_loss = 0
    for data in train_loader:
        optimizer.zero_grad()
        data = data.to(device)

        z = model.encode(data.x, data.edge_index)
        # Reconstruction term (the trailing 2 and 1 are passed straight to
        # recon_loss_wt) scaled by 5, plus the node-normalised KL term scaled by 0.5.
        loss = model.recon_loss_wt(z, data.pos_edge_label_index, data.neg_edge_label_index, 2, 1) * 5
        loss = loss + (1 / data.num_nodes) * model.kl_loss() * 0.5
        loss.backward()
        optimizer.step()

        # Accumulate a detached value: adding the live loss would chain every
        # batch's autograd history into the running total for the whole epoch.
        total_loss += loss.detach()
    return total_loss/len(train_loader)

In [12]:
@torch.no_grad()
def predict(model, test_loader):
    """Collect ``model.pred`` outputs for every batch of ``test_loader``.

    Uses the notebook-level ``device`` global and switches the model to eval
    mode.

    Returns:
        Two lists with one entry per batch: the first and the second value
        returned by ``model.pred``.
    """
    model.eval()
    labels, scores = [], []

    for batch in test_loader:
        batch = batch.to(device)
        latent = model.encode(batch.x, batch.edge_index)
        y, pred = model.pred(latent,
                             batch.pos_edge_label_index,
                             batch.neg_edge_label_index)
        labels.append(y)
        scores.append(pred)

    return labels, scores

@torch.no_grad()
def val(model, val_loader):
    """Return the mean weighted reconstruction loss over ``val_loader``.

    Uses the notebook-level ``device`` global and switches the model to eval
    mode. Unlike ``train`` there is no KL term and no scaling factor.
    """
    model.eval()
    running = 0

    for batch in val_loader:
        batch = batch.to(device)
        latent = model.encode(batch.x, batch.edge_index)
        running += model.recon_loss_wt(latent,
                                       batch.pos_edge_label_index,
                                       batch.neg_edge_label_index,
                                       2, 1)

    return running / len(val_loader)

def auc_ap(y_list, pred_list):
    """Average ROC-AUC and average precision over paired label/score batches.

    Args:
        y_list: per-batch ground-truth labels.
        pred_list: per-batch predicted scores, aligned with ``y_list``.

    Returns:
        ``(auc, ap)``: the column means of the per-batch
        ``[roc_auc_score, average_precision_score]`` rows.
    """
    per_batch = []
    for y, pred in zip(y_list, pred_list):
        per_batch.append([roc_auc_score(y, pred),
                          average_precision_score(y, pred)])
    auc, ap = np.mean(per_batch, axis = 0)
    return auc, ap

In [13]:
# NOTE(review): y_pred_res is not read by any visible cell.
y_pred_res = []
# Fold number shown in the run banner printed by the setup cell.
counter = 1

In [14]:
# Run banner. NOTE(review): seqlen, rate and samples are hard-coded literals,
# not read from the dataset object; the banner says 5000 samples while the
# dataset cell passes 1000 to G3MedianDataset -- confirm which is intended.
print(f'{time.ctime()} -- seqlen:{10:0>4} '
      f'rate:{0.1:.2f} samples:{5000:0>5} -- fold: {counter:0>2}')

# Fresh model, optimiser and LR scheduler; train() picks up `optimizer` and
# `device` as globals.
model = G3Median_VGAE(G3Median_GCNConv(in_channels, out_channels)).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)
# Halve the learning rate after 10 epochs without improvement of the value
# passed to scheduler.step, down to a floor of 1e-5.
# NOTE(review): `verbose` is reported as deprecated in recent PyTorch releases -- confirm.
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10,
                              min_lr=0.00001,verbose=True)

# The 1.0 factor keeps the whole shuffled dataset for training; nothing from
# `dataset` is held out.
train_dataset = dataset[:int(len(dataset) * 1.0)]
# val_dataset = dataset[int(len(dataset) * 0.9):]

train_loader = DataLoader(train_dataset, batch_size = train_batch, shuffle=True)
# test_loader = DataLoader(test_dataset, batch_size = test_batch)
# val_loader = DataLoader(val_dataset, batch_size = val_batch)

# NOTE(review): start_time, y_pred, p_auc and p_ap are not read by any visible cell.
start_time = time.time()

y_pred = None
p_auc, p_ap = 0, 0

Mon Jan 17 18:02:36 2022 -- seqlen:0010 rate:0.10 samples:05000 -- fold: 01


In [15]:
# Train for 500 epochs; every 10th epoch decode each validation sample and
# print `epoch count num`, where
#   count = samples whose best candidate by summed DCJ distance is also the
#           candidate with the smallest mat2adj score, and
#   num   = samples whose unperturbed decoding (candidate 0) has the smallest
#           summed DCJ distance.
# The loop deliberately stays inline: later cells read the globals it leaves
# behind (res, seqs, tmp_mat).
for epoch in range(1, 500 + 1):

    loss = train(model, train_loader)
    # tloss = val(model, val_loader)
    scheduler.step(loss)

    if epoch % 10 != 0:
        continue

    model.eval()
    count = 0
    num = 0
    for d, seqs in zip(val_dataset, val_seq[0]):
        d = d.to(device)
        # Inference only: do not record an autograd graph for these passes.
        with torch.no_grad():
            z = model.encode(d.x, d.edge_index)
            res = model.decoder.forward_all(z).cpu().numpy()

        tmp_mat = np.copy(res)

        # Blank the 2x2 diagonal blocks so the argmax below never lands on an
        # index pair (2k, 2k + 1); the size follows the matrix instead of a
        # hard-coded 20.
        for i in range(0, tmp_mat.shape[0], 2):
            tmp_mat[i:i + 2, i:i + 2] = 0

        # Candidate 0: decoding of the unperturbed matrix.
        pred, prob_ = mat2adj(tmp_mat)
        mval = min(sum(dcj_dist(p, s)[-1] for s in seqs) for p in pred)
        val_list = [mval]
        prob_list = [prob_]

        # Up to 9 more candidates, each obtained by weakening the currently
        # strongest entry and decoding again.
        for _ in range(9):
            r, c = np.unravel_index(np.argmax(tmp_mat, axis=None), tmp_mat.shape)
            if r == c:
                # Maximum on the zeroed diagonal: nothing left to perturb.
                break
            if tmp_mat[r, c] == 1.0:
                # Squaring 1.0 would leave it unchanged, so use a fixed value.
                redv = 0.95
            else:
                redv = tmp_mat[r, c] * tmp_mat[r, c]
            tmp_mat[r, c] = redv
            tmp_mat[c, r] = redv
            pred, prob = mat2adj(tmp_mat)
            val_list.append(min(sum(dcj_dist(p, s)[-1] for s in seqs) for p in pred))
            prob_list.append(prob)

        best_by_dist = np.argmin(val_list)
        if best_by_dist == np.argmin(prob_list):
            count += 1
        if best_by_dist == 0:
            num += 1
    print(epoch, count, num)

10 82 150
20 90 164
30 96 155
40 88 146
50 87 143
60 94 137
70 92 134
80 90 126
90 84 121
100 81 101
110 86 106
120 84 100
130 82 100
140 81 102
150 82 97
160 84 99
170 85 89
180 89 91
190 87 90
200 85 92
Epoch   209: reducing learning rate of group 0 to 2.5000e-03.
210 90 91
220 85 83
230 92 86
Epoch   233: reducing learning rate of group 0 to 1.2500e-03.
240 89 88
250 88 83
260 84 84
Epoch   268: reducing learning rate of group 0 to 6.2500e-04.
270 90 86
280 91 88
290 88 86
300 87 87
310 85 87
320 91 86
Epoch   323: reducing learning rate of group 0 to 3.1250e-04.
330 85 89
340 83 89
350 85 89
360 83 86
370 85 89
Epoch   375: reducing learning rate of group 0 to 1.5625e-04.
380 82 89
Epoch   388: reducing learning rate of group 0 to 7.8125e-05.
390 85 91
400 85 90
Epoch   402: reducing learning rate of group 0 to 3.9063e-05.
410 83 89
420 83 89
Epoch   421: reducing learning rate of group 0 to 1.9531e-05.
430 84 90
Epoch   432: reducing learning rate of group 0 to 1.0000e-05.
440 83 

In [80]:
# Stand-alone evaluation of the current model on the validation samples.
# Prints `count num`, where
#   count = samples whose best candidate by summed DCJ distance is also the
#           candidate with the smallest mat2adj score, and
#   num   = samples whose unperturbed decoding (candidate 0) has the smallest
#           summed DCJ distance.
# Kept inline because later cells read the globals left behind (res, seqs, tmp_mat).
model.eval()
count = 0
num = 0
for d, seqs in zip(val_dataset, val_seq[0]):
    d = d.to(device)
    # Inference only: do not record an autograd graph for these passes.
    with torch.no_grad():
        z = model.encode(d.x, d.edge_index)
        res = model.decoder.forward_all(z).cpu().numpy()

    tmp_mat = np.copy(res)

    # Blank the 2x2 diagonal blocks so the argmax below never lands on an
    # index pair (2k, 2k + 1); the size follows the matrix instead of a
    # hard-coded 20.
    for i in range(0, tmp_mat.shape[0], 2):
        tmp_mat[i:i + 2, i:i + 2] = 0

    # Candidate 0: decoding of the unperturbed matrix.
    pred, prob_ = mat2adj(tmp_mat)
    mval = min(sum(dcj_dist(p, s)[-1] for s in seqs) for p in pred)
    val_list = [mval]
    prob_list = [prob_]

    # Up to 9 more candidates, each obtained by weakening the currently
    # strongest entry and decoding again.
    for _ in range(9):
        r, c = np.unravel_index(np.argmax(tmp_mat, axis=None), tmp_mat.shape)
        if r == c:
            # Maximum on the zeroed diagonal: nothing left to perturb.
            break
        if tmp_mat[r, c] == 1.0:
            # Squaring 1.0 would leave it unchanged, so use a fixed value.
            redv = 0.95
        else:
            redv = tmp_mat[r, c] * tmp_mat[r, c]
        tmp_mat[r, c] = redv
        tmp_mat[c, r] = redv
        pred, prob = mat2adj(tmp_mat)
        val_list.append(min(sum(dcj_dist(p, s)[-1] for s in seqs) for p in pred))
        prob_list.append(prob)

    best_by_dist = np.argmin(val_list)
    if best_by_dist == np.argmin(prob_list):
        count += 1
    if best_by_dist == 0:
        num += 1
print(count, num)

91 89


In [55]:
# Hand-copied adjacency dictionary (stored in both directions) used to probe
# test_cycle in the next cell.
tdict = {14: 18, 18: 14, 4: 15, 15: 4, 6: 16, 16: 6, 0: 11, 11: 0, 2: 12, 12: 2, 1: 17, 17: 1, 3: 7, 7: 3, 5: 13, 13: 5}

In [62]:
# Manual check: 9 steps to its partner 8, which has no entry in tdict, so the
# walk ends without reaching 19.
test_cycle(tdict, 9, 19)

False

In [71]:
# Debug replay of the candidate search for a single sample. It reads `res` and
# `seqs` left behind by the evaluation loop (the last validation sample).
# NOTE(review): this cell increments the global `count` from the evaluation
# cell, so re-running it changes that counter.
tmp_mat = np.copy(res)

# Blank the 2x2 diagonal blocks so the argmax below never lands on an index
# pair (2k, 2k + 1).
for i in range(0, tmp_mat.shape[0], 2):
    tmp_mat[i:i + 2, i:i + 2] = 0

# Candidate 0: decoding of the unperturbed matrix.
pred, prob_ = mat2adj(tmp_mat)
mval = min(sum(dcj_dist(p, s)[-1] for s in seqs) for p in pred)
print(mval, prob_)
val_list = [mval]
prob_list = [prob_]
for _ in range(9):
    print('-------')
    r, c = np.unravel_index(np.argmax(tmp_mat, axis=None), tmp_mat.shape)
    if r == c:
        # Maximum on the zeroed diagonal: nothing left to perturb. Same guard
        # as in the evaluation loops.
        break
    if tmp_mat[r, c] == 1.0:
        # Squaring 1.0 would leave it unchanged, so use a fixed value.
        redv = 0.95
    else:
        redv = tmp_mat[r, c] * tmp_mat[r, c]
    tmp_mat[r, c] = redv
    tmp_mat[c, r] = redv
    pred, prob = mat2adj(tmp_mat)
    val_list.append(min(sum(dcj_dist(p, s)[-1] for s in seqs) for p in pred))
    prob_list.append(prob)

best_by_dist = np.argmin(val_list)
best_by_prob = np.argmin(prob_list)
if best_by_dist == best_by_prob:
    count += 1
print(best_by_dist, val_list)
print(best_by_prob, prob_list)

8 18
9 19
1.0
{8: 18, 18: 8} 9 19
9 10
14 19
4 15
6 16
0 11
2 12
1 17
3 7
8 0.04107344
-------
8 19
9 18
1.0
{8: 19, 19: 8} 9 18
9 10
14 18
4 15
6 16
0 11
2 12
1 17
3 7
-------
9 18
10 19
4 15
6 16
0 11
2 12
8 14
1 17
3 7
-------
9 19
10 18
4 15
6 16
0 11
2 12
8 14
1 17
3 7
-------
9 10
14 18
4 15
6 16
0 11
2 12
1 17
3 7
8 19
-------
10 18
14 19
4 15
6 16
0 11
2 12
1 17
3 7
5 13
0.7259615
{10: 18, 18: 10, 14: 19, 19: 14, 4: 15, 15: 4, 6: 16, 16: 6, 0: 11, 11: 0, 2: 12, 12: 2, 1: 17, 17: 1, 3: 7, 7: 3} 5 13
9 13
-------
9 10
14 18
4 15
6 16
0 11
2 12
1 17
3 7
8 19
-------
10 19
14 18
4 15
6 16
0 11
2 12
1 17
3 7
5 13
0.7259615
{10: 19, 19: 10, 14: 18, 18: 14, 4: 15, 15: 4, 6: 16, 16: 6, 0: 11, 11: 0, 2: 12, 12: 2, 1: 17, 17: 1, 3: 7, 7: 3} 5 13
9 13
-------
10 18
14 19
4 15
6 16
0 11
2 12
1 17
3 7
5 13
0.7259615
{10: 18, 18: 10, 14: 19, 19: 14, 4: 15, 15: 4, 6: 16, 16: 6, 0: 11, 11: 0, 2: 12, 12: 2, 1: 17, 17: 1, 3: 7, 7: 3} 5 13
9 13
-------
9 10
14 18
4 15
6 16
0 11
2 12
1 17
3 7
8 19

In [57]:
# Extremities (0..19) that have no adjacency in tdict; same expression mat2adj
# uses to find the walk start points.
set(range(20)) - set(tdict.keys())

{8, 9, 10, 19}

In [61]:
# NOTE(review): import placed mid-notebook; it belongs with the imports in the first cell.
from genome_file import encodeAdj
# Encode one signed gene order to inspect the project's extremity numbering.
encodeAdj([1, 7, 9, -3, -10, -2, -4, -6, -8, -5])

array([-1,  0,  1, 12, 13, 16, 17,  5,  4, 19, 18,  3,  2,  7,  6, 11, 10,
       15, 14,  9,  8, -2], dtype=int32)

In [62]:
# The encodeAdj output shown above, hand-copied without its leading -1 and trailing -2.
t = [0,  1, 12, 13, 16, 17,  5,  4, 19, 18,  3,  2,  7,  6, 11, 10,
       15, 14,  9,  8]

In [63]:
# Same pair -> signed-number formula as the return statement of dict2adj,
# applied to consecutive pairs of t.
[(t[i]//2 + 1) * (-1 if t[i] > t[i+1] else 1) for i in range(0, len(t), 2)]

[1, 7, 9, -3, -10, -2, -4, -6, -8, -5]

In [28]:
# NOTE(review): the recorded output is a bare list, but the mat2adj defined
# above returns a (sequences, score) tuple -- the output predates that definition.
mat2adj(tmp_mat)

[[1, 7, 9, -3, -10, -2, -4, -6, -8, -5], [5, 8, 6, 4, 2, 10, 3, -9, -7, -1]]

In [75]:
# Decode the last evaluated sample (`res` / `seqs` are left behind by the
# evaluation loop) and print the best summed DCJ distance and the candidates.
tmp_mat = np.copy(res)

# Blank the 2x2 diagonal blocks (index pairs 2k / 2k + 1).
for i in range(0, tmp_mat.shape[0], 2):
    tmp_mat[i:i + 2, i:i + 2] = 0

# mat2adj returns (sequences, score); only the sequences are needed here.
# Iterating the raw tuple would also feed the score to dcj_dist.
pred = mat2adj(tmp_mat)[0]
print(min(sum(dcj_dist(p, s)[-1] for s in seqs) for p in pred))
print(pred)

6
[[10, -2, 4, 9, 8, 5, 6, -1, -3, -7], [7, 3, 1, -6, -5, -8, -9, -4, 2, -10]]


In [None]:
# Display the working score matrix (cell was never executed in the saved notebook).
tmp_mat

In [26]:
# NOTE(review): the recorded output (6) cannot come from this bare list
# literal; the cell was presumably edited after it was last executed.
[  7,  10,  -2,   4,   9,   8,   5,   6,  -1,  -3]

6

In [70]:
# Display `seqs`, the loop variable left over from the evaluation loop (three signed gene orders).
seqs

array([[  7,  10,  -2,   4,   3,   1,  -6,  -5,  -8,  -9],
       [  1,  -6,  -5,  -8,  -9,  -4,   2, -10,  -7,  -3],
       [  7,  -4,   2, -10,   9,   8,   5,   6,  -1,  -3]], dtype=int32)

In [71]:
# Last element of dcj_dist's result between a hand-picked candidate and each of the three sequences.
[dcj_dist([7, 3, 1, -6, -5, -8, -9, -4, 2, -10], s)[-1] for s in seqs]

[2, 2, 2]

In [73]:
# Inspect a single decoder score (row 13, column 18) of the last evaluated sample.
res[13, 18]

0.6990546

In [29]:
# Manually knock out the 6 <-> 11 entry in both directions.
# NOTE(review): mutates tmp_mat in place, so any later mat2adj(tmp_mat) call
# sees the modified matrix.
tmp_mat[6,11] = 0
tmp_mat[11, 6] = 0

In [30]:
# Decode again after the manual knock-out above.
# NOTE(review): the recorded output is a bare list, but the mat2adj defined
# above returns a (sequences, score) tuple -- the output predates that definition.
mat2adj(tmp_mat)

[[1, 7, 9, -3, -10, -2, 4, -6, -8, -5], [5, 8, 6, -4, 2, 10, 3, -9, -7, -1]]

In [76]:
# Ten rounds: remove the currently strongest entry entirely, decode again and
# print the best summed DCJ distance plus the candidates.
# NOTE(review): mutates tmp_mat in place, so the cell is not idempotent.
for _ in range(10):
    r, c = np.unravel_index(np.argmax(tmp_mat, axis=None), tmp_mat.shape)
    tmp_mat[r, c] = 0
    tmp_mat[c, r] = 0
    # mat2adj returns (sequences, score); only the sequences are needed here.
    # Iterating the raw tuple would also feed the score to dcj_dist.
    pred = mat2adj(tmp_mat)[0]
    print(min(sum(dcj_dist(p, s)[-1] for s in seqs) for p in pred))
    print(pred)

9
[[10, -2, 4, 9, 8, 5, 6, 1, -3, -7], [7, 3, -1, -6, -5, -8, -9, -4, 2, -10]]
9
[[10, -2, 4, 9, 8, 5, -6, -1, -3, -7], [7, 3, 1, 6, -5, -8, -9, -4, 2, -10]]
15
[[10, -2, 4, 9, 8, -5, 7, 3, 1, 6], [-6, -1, -3, -7, 5, -8, -9, -4, 2, -10]]
12
[[10, -2, 4, 9, -8, 5, -6, -1, -3, -7], [7, 3, 1, 6, -5, 8, -9, -4, 2, -10]]
18
[[2, -10, 7, 3, 1, 6, -5, 8, -9, -4], [4, 9, -8, 5, -6, -1, -3, -7, 10, -2]]
19
[[1, -3, -7, 10, -2, -6, -5, 8, -9, -4], [4, 9, -8, 5, 6, 2, -10, 7, 3, -1]]
19
[[1, -3, -7, 10, 2, 4, 9, -8, 5, 6], [-6, -5, 8, -9, -4, -2, -10, 7, 3, -1]]
18
[[10, 2, 4, 9, -8, 5, -6, 1, -3, -7], [7, 3, -1, 6, -5, 8, -9, -4, -2, -10]]
24
[[-9, -4, -2, -10, 7, 3, -1, 6, -5, 8], [-8, 5, -6, 1, -3, -7, 10, 2, 4, 9]]
23
[[10, 2, 4, -9, 8, -5, 7, 3, -1, 6], [-6, 1, -3, -7, 5, -8, 9, -4, -2, -10]]


In [32]:
# Display `seqs` again. NOTE(review): the values differ from the earlier
# `seqs` cell, so the two cells were executed against different samples.
seqs

array([[-6,  4,  2, 10,  3,  5,  8, -1,  7,  9],
       [ 6, -8, -5, -2, -4, -3, -7, -1, 10,  9],
       [ 6, -4,  2, 10,  3,  5,  8,  1,  7,  9]], dtype=int32)

In [34]:
# Last element of dcj_dist's result between a hand-copied candidate and the first sequence.
dcj_dist([5, 8, 6, 4, 2, 10, 3, -9, -7, -1], seqs[0])[-1]

5

In [55]:
# Per-sequence distance of a hand-edited candidate to each of the three sequences.
[dcj_dist([1, 7, 9, -3, -10, -2, -4, 6, -8, -5], s)[-1] for s in seqs]

[3, 5, 4]

In [47]:
# Summed distance between one hand-written candidate and the three sequences.
# NOTE(review): the argument order differs between the three calls; this
# presumably relies on dcj_dist being symmetric -- confirm.
sum([dcj_dist([  6,   4,   2,  10,   3,   5,   8,   1,   7,   9], seqs[1])[-1], dcj_dist(seqs[0], [  6,   4,   2,  10,   3,   5,   8,   1,   7,   9])[-1], dcj_dist(seqs[2], [  6,   4,   2,  10,   3,   5,   8,   1,   7,   9])[-1]])

7

In [69]:
# Display the second element of the loaded validation data.
# NOTE(review): this dumps the whole array; a slice such as val_seq[1][:5]
# would keep the output readable.
val_seq[1]

array([[  7,  -8,   3,  -2,  -9,  -6,   1,  -4,  -5, -10],
       [  5,  -3,   4,   7,  -2,   1,   6,  10,   9,   8],
       [  5, -10,  -4,   8,  -2,  -9,   6,   1,  -7,   3],
       [ -9,  -5,  -3,   8,  -4,  -2,  -6, -10,   7,  -1],
       [ -9,  10,  -4,   7,  -3,  -2,   1,   6,   8,  -5],
       [ -7,  10,  -4,   6,   8,   1,  -2,  -9,  -3,   5],
       [  1, -10,   3,   5,   7,  -8,   4,   9,  -6,   2],
       [ -3,  -2,  -7,   4,   8,   6,  10,   5,   1,   9],
       [ -7,   1,   8,   3,   5,   2,   6,  10,   9,  -4],
       [ -5,   7,  -4,  10,  -1,  -8,  -3,   6,   2,  -9],
       [  5,  -2,   7,   9,  -4,   1,  10,   8,   6,  -3],
       [ -8,  -6,   4,   7,  10,   2,  -1,  -9,   3,  -5],
       [  1,   3,  -8,  10,   4,   6,  -9,  -5,   7,  -2],
       [  2,   9,   3,  -1,   8,   4,   5,  10,   7,   6],
       [  4,  -5,   9,  -1,  -8,   3,   2,   7,  10,   6],
       [ -6,  -4,   3,   8,   7,   5,  -1, -10,  -9,  -2],
       [ 10,   4,   2,   5,   8,   1,   3,  -9,   6,   7

In [56]:
# Per-sequence distance of the hand-written candidate to each of the three sequences.
[dcj_dist([  6,   4,   2,  10,   3,   5,   8,   1,   7,   9], s)[-1] for s in seqs]

[2, 4, 1]