In [1]:
import time

import numpy as np

import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau

# from torch_geometric.nn import VGAE
from torch_geometric.loader import DataLoader
from torch_geometric.utils import (degree, negative_sampling, 
                                   batched_negative_sampling,
                                  add_self_loops, to_undirected)

from torch.utils.tensorboard import SummaryWriter

from gene_graph_dataset import G3MedianDataset
from phylognn_model import G3Median_GCNConv, G3Median_VGAE

from sklearn.metrics import (roc_auc_score, roc_curve,
                             average_precision_score, 
                             precision_recall_curve,
                             f1_score, matthews_corrcoef)

from sklearn.model_selection import KFold

import matplotlib.pyplot as plt

In [2]:
from gene_graph_dataset import G3MedianDataset

In [3]:
from dcj_comp import dcj_dist
from genome_file import mat2adj
from genome_file import dict2adj

In [4]:
dataset = G3MedianDataset('dataset_g3m_exps', 10, 9, 1000)

Generating...
Processing...
Done!


In [5]:
val_dataset = G3MedianDataset('val_seq_g3m', 10, 9, 200)

In [6]:
val_seq = torch.load('val_seq_g3m_3_10_9_200/raw/g3raw_10_9.pt')

In [7]:
train_batch, test_batch, val_batch = 256, 64, 8

device = torch.device('cuda:' + str(1) if torch.cuda.is_available() else 'cpu')

in_channels, out_channels = 75, 128

In [8]:
dataset = dataset.shuffle()

In [9]:
def test_cycle(adj_dict, x, y):
    while True:
        x = x + (1 if x % 2 == 0 else -1)
        if x == y:
            return True
        if x not in adj_dict.keys():
            return False
        x = adj_dict[x]

In [10]:
# def test_cycle(adj_dict, x, y):
#     while True:
#         x = x + (1 if x % 2 == 0 else -1)
#         if x == y:
#             return True
#         if x not in adj_dict.keys():
#             return False
#         x = adj_dict[x]

# def dict2adj(adj_dict, start):
#     res = []
#     while True:
#         end = start + (1 if start % 2 == 0 else -1)
#         res += [start, end]
#         if end not in adj_dict.keys():
#             break
#         start = adj_dict[end]
        
#     return [(res[i]//2 + 1) * (-1 if res[i] > res[i+1] else 1) for i in range(0, len(res), 2)]
        

# def mat2adj(res_mat):
#     gen_len = res_mat.shape[0]
#     tmp = np.copy(res_mat)
#     for i in range(0, gen_len, 2):
#         tmp[i, i:i+2] = 0
#         tmp[i + 1, i:i+2] = 0
#     adj_dict = {}
#     p_list = []
#     while True:
#         r,c = np.unravel_index(np.argmax(tmp, axis = None), tmp.shape)
#         if r == c:
#             print(adj_dict)
#             break
                
#         # test cycles
#         if test_cycle(adj_dict, r, c):
#             tmp[r, c] = 0
#             tmp[c, r] = 0
#             continue
#         p_list.append(tmp[r,c])    
#         tmp[(r,c), :] = 0
#         tmp[:, (r,c)] = 0
#         adj_dict[r] = c
#         adj_dict[c] = r
#         if len(adj_dict)//2 == (gen_len//2 - 1):
#             break
    
#     start = list(set(range(gen_len)) - set(adj_dict.keys()))
    
#     return [dict2adj(adj_dict, a) for a in start], -np.log(p_list).sum()

In [11]:
def train(model, train_loader):
    model.train()
    
    total_loss = 0
    for data in train_loader:    
        optimizer.zero_grad()
        data = data.to(device)
        
        z = model.encode(data.x, data.edge_index)
        loss = model.recon_loss_wt(z, data.pos_edge_label_index, data.neg_edge_label_index, 2, 1) * 5
        loss = loss + (1 / data.num_nodes) * model.kl_loss() * 0.5
        loss.backward()
        optimizer.step()
        
        total_loss += loss
    return total_loss/len(train_loader)

In [12]:
# @torch.no_grad()
# def predict(model, test_loader):
#     model.eval()
#     y_list, pred_list = [], []
        
#     for data in test_loader:
        
#         data = data.to(device)
        
#         z = model.encode(data.x, data.edge_index)
#         # loss += model.recon_loss(z, data.pos_edge_label_index, data.neg_edge_label_index)
#         y, pred = model.pred(z, data.pos_edge_label_index, data.neg_edge_label_index)
        
#         y_list.append(y)
#         pred_list.append(pred)
        
#     return y_list, pred_list

# @torch.no_grad()
# def val(model, val_loader):
#     model.eval()
#     loss = 0
    
#     for data in val_loader:        
#         data = data.to(device)        
#         z = model.encode(data.x, data.edge_index)        
#         loss += model.recon_loss_wt(z, data.pos_edge_label_index, data.neg_edge_label_index, 2, 1)
#         # tauc, tap = model.test(z, data.pos_edge_label_index, data.neg_edge_label_index)
                
#     return loss/len(val_loader)

# def auc_ap(y_list, pred_list):
#     pred_accuracy = [[roc_auc_score(y, pred), average_precision_score(y, pred)]
#                      for y, pred in zip(y_list, pred_list)]
#     auc, ap = np.mean(pred_accuracy, axis = 0)
#     return auc, ap

In [13]:
y_pred_res = []
counter = 1

In [14]:
print(f'{time.ctime()} -- seqlen:{10:0>4} '
      f'rate:{0.1:.2f} samples:{5000:0>5} -- fold: {counter:0>2}')

model = G3Median_VGAE(G3Median_GCNConv(in_channels, out_channels)).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10,
                              min_lr=0.00001,verbose=True)

train_dataset = dataset[:int(len(dataset) * 1.0)]
# val_dataset = dataset[int(len(dataset) * 0.9):]

train_loader = DataLoader(train_dataset, batch_size = train_batch, shuffle=True)
# test_loader = DataLoader(test_dataset, batch_size = test_batch)
# val_loader = DataLoader(val_dataset, batch_size = val_batch)

start_time = time.time()

y_pred = None
p_auc, p_ap = 0, 0

Sat Feb 12 15:41:30 2022 -- seqlen:0010 rate:0.10 samples:05000 -- fold: 01


In [15]:
def pre_set_mat(res_mat):
    tmp_mat = np.copy(res_mat)
    for i in range(0, tmp_mat.shape[0], 2):
        tmp_mat[i, i:i+2] = 0
        tmp_mat[i + 1, i:i+2] = 0
        
    return tmp_mat

In [16]:
def dist_eval(d, steps = 10, rate = 0.9):
    d = d.to(device)
    z = model.encode(d.x, d.edge_index)
    res = model.decoder.forward_all(z).detach().cpu().numpy()
    
    pred_list, prob_list = [], []

    res_mat = pre_set_mat(res)
    pred, prob_ = mat2adj(res_mat)
    pred_list.append(pred)
    prob_list.append(prob_)

    for _ in range(1, steps):
        r,c = np.unravel_index(np.argmax(res_mat, axis = None), res_mat.shape)
        if r == c:
            break
        if res_mat[r,c] == 1.0:
            redv = 0.95
        else:
            redv = res_mat[r,c] * rate # res_mat[r,c]
        res_mat[r,c] = redv
        res_mat[c,r] = redv
        pred, prob_ = mat2adj(res_mat)
        
        pred_list.append(pred)
        prob_list.append(prob_)
        
    return pred_list, prob_list

In [17]:
def pred_pair(val_data, val_seqs, steps = 10, rate = 0.9):
    model.eval()
    pred_list, prob_list = dist_eval(val_data, steps, rate)
    result = []
    for pred, prob_ in zip(pred_list, prob_list):
        tmp_dist = [sum([dcj_dist(p, s)[-1] for s in val_seqs]) for p in pred]
        t = np.argmin(tmp_dist)
        # p_seq = pred[np.argmin(tmp_dist)]
        result.append((pred[t], prob_, tmp_dist[t]))
    return result

In [18]:
low_mean = 0
for seqs in val_seq[0]:
    low_mean += np.ceil(sum([dcj_dist(seqs[0], seqs[1])[-1], 
                            dcj_dist(seqs[0], seqs[2])[-1], 
                            dcj_dist(seqs[1], seqs[2])[-1]])/2)
low_mean = low_mean / len(val_seq[0])

In [19]:
low_mean

8.63

In [20]:
# writer = SummaryWriter(log_dir='dist_g3median_' f'{args.seqlen:0>4}' '/e' f'{args.samples:0>5}' '_r' 
#                        f'{args.rate:0>3.1f}' '_' 'run_' f'{args.vals:0>4}')
writer = SummaryWriter(log_dir='dist_g3median_' f'{10:0>4}' '/e' f'{1000:0>5}' '_r' 
                       f'{0.1:0>3.1f}' '_' 'run_' f'{200:0>4}')

In [21]:
final_dist = np.inf
final_list = []

In [22]:
for epoch in range(1, 100 + 1):
    loss = train(model, train_loader)
    # tloss = val(model, val_loader)
    scheduler.step(loss)
    
#     if epoch % 5 != 0:
#         continue
    
    result_list = []
    for vd, vs in zip(val_dataset, val_seq[0]):
        seq_dist = pred_pair(vd, vs, 20, rate = 0.7)
        dist_list = [x[2] for x in seq_dist]
        res_pair = seq_dist[np.argmin(dist_list)]
        result_list.append(res_pair)
        
    mean_dist = np.mean([x[2] for x in result_list])
    print(mean_dist, epoch)
    writer.add_scalar('dist/mean', mean_dist - low_mean, epoch)
    
    if mean_dist < final_dist:
        final_list = result_list

16.72 1
16.535 2
16.355 3
16.05 4
15.695 5
15.39 6
15.07 7
15.035 8
14.985 9
14.97 10
14.885 11
14.725 12
14.715 13
14.525 14
14.435 15
14.58 16
14.26 17
14.275 18
14.03 19
14.025 20
14.16 21
14.13 22
14.08 23
14.145 24
14.215 25
14.145 26
14.005 27
14.095 28
14.09 29
14.055 30
14.005 31
13.9 32
14.0 33
14.015 34
14.03 35
13.99 36
13.92 37
14.01 38
13.97 39
13.915 40
13.865 41
13.87 42
13.95 43
13.93 44
13.86 45
13.815 46
13.865 47
13.915 48
13.87 49
13.645 50
13.805 51
13.78 52
13.82 53
13.72 54
13.7 55
13.715 56
13.7 57
13.605 58
13.6 59
13.705 60
13.785 61
13.71 62
13.655 63
13.75 64
13.63 65
13.75 66
13.705 67
13.645 68
13.715 69
13.575 70
13.745 71
13.75 72
13.845 73
13.73 74
13.66 75
13.78 76
13.7 77
13.635 78
13.69 79
13.55 80
13.695 81
13.725 82
13.55 83
13.71 84
13.68 85
13.655 86
13.7 87
13.6 88
13.72 89
13.605 90
13.7 91
13.485 92
13.55 93
13.665 94
13.53 95
13.435 96
13.595 97
13.56 98
13.54 99
13.67 100


In [32]:
final_list

[([2, 6, 10, -3, -5, 7, -8, 4, 1, 9], 0.4018711, 3),
 ([-4, 7, -10, 1, 9, 3, 6, 5, -8, -2], 0.09979269, 8),
 ([1, 5, -7, -10, -9, -3, -2, -6, -8, -4], 0.57516044, 6),
 ([-6, 8, -10, -9, 1, -7, -3, 2, -5, 4], 0.040184647, 4),
 ([4, 3, -8, -10, -5, -6, 2, -7, 9, -1], 0.24551669, 4),
 ([-4, -5, 9, 2, 6, 10, 1, -3, -8, 7], 0.2541943, 4),
 ([-8, 9, 10, 2, 5, -1, -7, 3, -4, -6], 0.12051022, 7),
 ([9, -10, 5, -6, 1, -7, -4, 2, -8, -3], 0.48029083, 10),
 ([1, 3, -5, 4, 8, 6, -9, 10, 7, 2], 0.24127603, 4),
 ([10, -3, 8, -4, -7, 9, 2, -5, 1, -6], 0.88799655, 4),
 ([1, -10, -9, -6, -5, -8, 2, -7, -4, 3], 0.54682887, 7),
 ([-9, -7, 2, 4, 6, 1, -5, 3, 10, 8], 0.7295481, 10),
 ([9, 6, 5, 8, 10, 7, 4, 2, 3, 1], 0.23502158, 3),
 ([-4, -10, -1, -6, -9, -2, 7, 3, 5, -8], 0.6899465, 6),
 ([-8, -6, 5, 9, -10, 1, -4, 7, 3, 2], 0.13138314, 6),
 ([-10, 6, 2, 5, -9, 7, -8, 3, -1, -4], 0.21802264, 7),
 ([-7, -4, -9, -3, -2, 5, -10, -8, 1, -6], 0.16754094, 7),
 ([4, 3, 7, 8, 6, 5, 10, 2, 1, 9], 0.4412796, 3),
 

In [23]:
seq_dist = pred_pair(vd, vs, 20, rate = 0.9)

In [24]:
seq_dist

[([8, 6, 2, 4, 9, 5, 1, 10, 3, 7], 0.16923569, 8),
 ([8, 6, 2, 4, 9, -5, 1, 10, 3, 7], 0.18237135, 9),
 ([8, 6, 2, -4, 9, -5, 1, 10, 3, 7], 0.18809369, 12),
 ([8, -2, 6, -4, 9, -5, 1, 10, 3, 7], 0.19728915, 16),
 ([8, 6, -2, -4, 9, -5, 1, 10, 3, 7], 0.18979667, 14),
 ([9, -4, 2, -6, -8, -7, -3, -10, -1, 5], 0.4397651, 19),
 ([-5, 1, 10, 3, 7, 8, -2, 6, 4, -9], 0.45007086, 22),
 ([8, 6, 4, -2, 9, -5, 1, 10, 3, 7], 0.19785967, 13),
 ([4, -2, 8, 6, 9, -5, 1, 10, 3, 7], 0.2644184, 15),
 ([8, -4, 2, 6, 9, -5, 1, 10, 3, 7], 0.35642767, 17),
 ([-5, 1, 10, 3, 7, 8, 2, 6, 9, -4], 0.47409064, 21),
 ([-5, 1, -10, -3, 7, 8, 2, 6, 9, -4], 0.6319049, 23),
 ([8, 2, 6, -4, 9, -5, 1, -10, -3, 7], 0.48155028, 18),
 ([2, 8, 6, -4, 9, -5, 1, -10, -3, 7], 0.47837177, 16),
 ([8, 6, -4, -2, 9, -5, 1, -10, -3, 7], 0.4603958, 12),
 ([4, -9, 2, 8, 6, -7, 3, 10, -1, 5], 1.3512394, 19),
 ([4, -9, -2, 8, 6, -7, 3, 10, -1, 5], 1.4031819, 22),
 ([4, -6, -8, 2, 9, -5, 1, -10, -3, 7], 0.5346977, 16),
 ([4, -6, -8, 2, 

In [25]:
pred_pair(vd,vs, 20, 0.8)

[([8, 6, 2, 4, 9, 5, 1, 10, 3, 7], 0.16923569, 8),
 ([8, 6, 2, 4, 9, -5, 1, 10, 3, 7], 0.18237135, 9),
 ([8, 6, 2, -4, 9, -5, 1, 10, 3, 7], 0.18809369, 12),
 ([8, -2, 6, -4, 9, -5, 1, 10, 3, 7], 0.19728915, 16),
 ([8, 6, -2, -4, 9, -5, 1, 10, 3, 7], 0.18979667, 14),
 ([9, -4, 2, -6, -8, -7, -3, -10, -1, 5], 0.4397651, 19),
 ([-5, 1, 10, 3, 7, 8, -2, 6, 4, -9], 0.45007086, 22),
 ([8, 6, 4, -2, 9, -5, 1, 10, 3, 7], 0.19785967, 13),
 ([4, -2, 8, 6, 9, -5, 1, 10, 3, 7], 0.2644184, 15),
 ([8, -4, 2, 6, 9, -5, 1, 10, 3, 7], 0.35642767, 17),
 ([-5, 1, 10, 3, 7, 8, 2, 6, 9, -4], 0.47409064, 21),
 ([-5, 1, -10, -3, 7, 8, 2, 6, 9, -4], 0.6319049, 23),
 ([8, 2, 6, -4, 9, -5, 1, -10, -3, 7], 0.59933335, 18),
 ([2, 8, 6, -4, 9, -5, 1, -10, -3, 7], 0.5961548, 16),
 ([8, 6, -4, -2, 9, -5, 1, -10, -3, 7], 0.5781788, 12),
 ([4, -9, 2, 8, 6, -7, 3, 10, -1, 5], 1.3512394, 19),
 ([4, -9, -2, 8, 6, -7, 3, 10, -1, 5], 1.4031819, 22),
 ([4, -6, -8, 2, 9, -5, 1, -10, -3, 7], 0.6524807, 16),
 ([4, -6, -8, 2, 9

In [27]:
pred_pair(vd,vs, 20, 0.1)

[([8, 6, 2, 4, 9, 5, 1, 10, 3, 7], 0.16923583, 8),
 ([8, 6, 2, 4, 9, -5, 1, 10, 3, 7], 0.18237148, 9),
 ([8, 6, 2, -4, 9, -5, 1, 10, 3, 7], 0.18809383, 12),
 ([8, -2, 6, -4, 9, -5, 1, 10, 3, 7], 0.19728929, 16),
 ([8, 6, -2, -4, 9, -5, 1, 10, 3, 7], 0.1897968, 14),
 ([9, -4, 2, -6, -8, -7, -3, -10, -1, 5], 0.43976521, 19),
 ([-5, 1, 10, 3, 7, 8, -2, 6, 4, -9], 0.45007098, 22),
 ([8, 6, 4, -2, 9, -5, 1, 10, 3, 7], 0.19785981, 13),
 ([4, -2, 8, 6, 9, -5, 1, 10, 3, 7], 0.2644185, 15),
 ([8, -4, 2, 6, 9, -5, 1, 10, 3, 7], 0.3564278, 17),
 ([-5, 1, 10, 3, 7, 8, 2, 6, 9, -4], 0.47409075, 21),
 ([-5, 1, -10, -3, 7, 8, 2, 6, 9, -4], 0.6319049, 23),
 ([-6, -2, -8, -7, 3, 10, -1, 5, -9, 4], 0.64165294, 20),
 ([-4, 9, -5, 1, -10, -3, 7, 2, 8, 6], 1.1589904, 18),
 ([-6, -8, -7, 3, 10, -1, 5, -9, 2, 4], 0.6204985, 18),
 ([4, -9, 2, 8, 6, -7, 3, 10, -1, 5], 1.3512394, 19),
 ([4, -9, -2, 8, 6, -7, 3, 10, -1, 5], 1.4031819, 22),
 ([-6, -8, 2, 9, -5, 1, -10, -3, 7, 4], 1.4252586, 18),
 ([-6, -8, 2, 9, 

In [28]:
pred_pair(vd,vs, 50, 0.1)

[([8, 6, 2, 4, 9, 5, 1, 10, 3, 7], 0.16923569, 8),
 ([8, 6, 2, 4, 9, -5, 1, 10, 3, 7], 0.18237135, 9),
 ([8, 6, 2, -4, 9, -5, 1, 10, 3, 7], 0.18809369, 12),
 ([8, -2, 6, -4, 9, -5, 1, 10, 3, 7], 0.19728915, 16),
 ([8, 6, -2, -4, 9, -5, 1, 10, 3, 7], 0.18979667, 14),
 ([9, -4, 2, -6, -8, -7, -3, -10, -1, 5], 0.4397651, 19),
 ([-5, 1, 10, 3, 7, 8, -2, 6, 4, -9], 0.45007086, 22),
 ([8, 6, 4, -2, 9, -5, 1, 10, 3, 7], 0.19785967, 13),
 ([4, -2, 8, 6, 9, -5, 1, 10, 3, 7], 0.2644184, 15),
 ([8, -4, 2, 6, 9, -5, 1, 10, 3, 7], 0.35642758, 17),
 ([-5, 1, 10, 3, 7, 8, 2, 6, 9, -4], 0.47409064, 21),
 ([-5, 1, -10, -3, 7, 8, 2, 6, 9, -4], 0.6319049, 23),
 ([-6, -2, -8, -7, 3, 10, -1, 5, -9, 4], 0.64165294, 20),
 ([-4, 9, -5, 1, -10, -3, 7, 2, 8, 6], 1.1589904, 18),
 ([-6, -8, -7, 3, 10, -1, 5, -9, 2, 4], 0.6204985, 18),
 ([4, -9, 2, 8, 6, -7, 3, 10, -1, 5], 1.3512394, 19),
 ([4, -9, -2, 8, 6, -7, 3, 10, -1, 5], 1.4031819, 22),
 ([-6, -8, 2, 9, -5, 1, -10, -3, 7, 4], 1.4252586, 18),
 ([-6, -8, 2, 9,

In [52]:
len(val_dataset)

200

In [41]:
pred_list

[[[9, 2, 6, 10, -3, -5, 7, -8, 4, 1], [-1, -4, 8, -7, 5, 3, -10, -6, -2, -9]],
 [[9, 2, 6, 10, -3, -5, 7, -8, 4, 1], [-1, -4, 8, -7, 5, 3, -10, -6, -2, -9]],
 [[9, 2, 6, 10, -3, -5, 7, -8, 4, 1], [-1, -4, 8, -7, 5, 3, -10, -6, -2, -9]]]

In [42]:
prob_list

[0.6614356, 0.66143703, 0.6614413]

In [38]:
for pred, prob_ in zip(pred_list, prob_list):
    tmp_dist = [sum([dcj_dist(p, s)[-1] for s in val_seq[0][0]]) for p in pred]
    p_seq = pred[np.argmin(tmp_dist)]
    print(p_seq, tmp_dist, prob_)

[-1, -4, 8, -7, 5, 3, -10, -6, -2, -9] [8, 7] 0.6614356
[-1, -4, 8, -7, 5, 3, -10, -6, -2, -9] [8, 7] 0.6614363
[-1, -4, 8, -7, 5, 3, -10, -6, -2, -9] [8, 7] 0.66143775


In [33]:
2**3

8

In [30]:
np.argmin([8, 7])

1

In [29]:
val_seq[0][0]

array([[  2,  -9,  -1,  -4,   8,  -7,   5,   3, -10,  -6],
       [  2,   6,  10,  -3,  -5,   7,   8,   4,   1,   9],
       [  2,   6,  10,  -3,  -5,   7,  -8,   4,   1,  -9]], dtype=int32)

In [None]:
for d, seqs in zip(val_dataset, val_seq[0]):
    pred_list, prob_list = dist_eval(d, 1)
    dist_val = min([sum() for p in ])
    

In [None]:
# for epoch in range(1, 500 + 1):

#     loss = train(model, train_loader)
#     # tloss = val(model, val_loader)
#     scheduler.step(loss)
    
#     if epoch % 10 != 0:
#         continue
    
#     model.eval()
#     count = 0
#     num = 0
#     for d, seqs in zip(val_dataset, val_seq[0]):
#         d = d.to(device)
#         z = model.encode(d.x, d.edge_index)
#         res = model.decoder.forward_all(z).detach().cpu().numpy()
#     #     pred_seqs = mat2adj(res)
#     #     pred_dist = min([sum([dcj_dist(pred, s)[-1] for s in seqs]) for pred in pred_seqs])
#     #     low_dist = np.ceil(sum([dcj_dist(seqs[0], seqs[1])[-1], 
#     #                             dcj_dist(seqs[0], seqs[2])[-1], 
#     #                             dcj_dist(seqs[1], seqs[2])[-1]])/2)

#     #     print(f'{pred_dist:>3} -- {low_dist:<3}')
#     #     count += pred_dist - low_dist
#         # print('---------')
#         tmp_mat = np.copy(res)

#         for i in range(0, 20, 2):
#             tmp_mat[i, i:i+2] = 0
#             tmp_mat[i + 1, i:i+2] = 0

#         pred, prob_ = mat2adj(tmp_mat)
#         mval = min([sum([dcj_dist(p, s)[-1] for s in seqs]) for p in pred])
#         # print(mval, prob_)
#         # print(pred)
#         tag = False
#         val_list = [mval]
#         prob_list = [prob_]
#         for _ in range(9):
#             r,c = np.unravel_index(np.argmax(tmp_mat, axis = None), tmp_mat.shape)
#             if r == c:
#                 break
#             if tmp_mat[r,c] == 1.0:
#                 redv = 0.95
#             else:
#                 redv = tmp_mat[r,c] * tmp_mat[r,c]
#             tmp_mat[r,c] = redv
#             tmp_mat[c,r] = redv
#             pred, prob = mat2adj(tmp_mat)
#             # if prob < prob_:
#             #     print(f'** {prob}')
#             #     print(min([sum([dcj_dist(p, s)[-1] for s in seqs]) for p in pred]), prob)
#             val_list.append(min([sum([dcj_dist(p, s)[-1] for s in seqs]) for p in pred]))
#             prob_list.append(prob)
#             # if min([sum([dcj_dist(p, s)[-1] for s in seqs]) for p in pred]) < mval:
#             #     # print(pred)
#             #     # print(min([sum([dcj_dist(p, s)[-1] for s in seqs]) for p in pred]), prob)
#             #     tag = True
#             # print(pred)
#         if np.argmin(val_list) == np.argmin(prob_list):
#             count += 1
#         if np.argmin(val_list) == 0:
#             num += 1
#         # print(np.argmin(val_list), val_list)
#         # print(np.argmin(prob_list), prob_list)
#     print(epoch, count, num)

In [80]:
# model.eval()
# count = 0
# num = 0
# for d, seqs in zip(val_dataset, val_seq[0]):
#     d = d.to(device)
#     z = model.encode(d.x, d.edge_index)
#     res = model.decoder.forward_all(z).detach().cpu().numpy()
# #     pred_seqs = mat2adj(res)
# #     pred_dist = min([sum([dcj_dist(pred, s)[-1] for s in seqs]) for pred in pred_seqs])
# #     low_dist = np.ceil(sum([dcj_dist(seqs[0], seqs[1])[-1], 
# #                             dcj_dist(seqs[0], seqs[2])[-1], 
# #                             dcj_dist(seqs[1], seqs[2])[-1]])/2)

# #     print(f'{pred_dist:>3} -- {low_dist:<3}')
# #     count += pred_dist - low_dist
#     # print('---------')
#     tmp_mat = np.copy(res)

#     for i in range(0, 20, 2):
#         tmp_mat[i, i:i+2] = 0
#         tmp_mat[i + 1, i:i+2] = 0

#     pred, prob_ = mat2adj(tmp_mat)
#     mval = min([sum([dcj_dist(p, s)[-1] for s in seqs]) for p in pred])
#     # print(mval, prob_)
#     # print(pred)
#     tag = False
#     val_list = [mval]
#     prob_list = [prob_]
#     for _ in range(9):
#         r,c = np.unravel_index(np.argmax(tmp_mat, axis = None), tmp_mat.shape)
#         if r == c:
#             break
#         if tmp_mat[r,c] == 1.0:
#             redv = 0.95
#         else:
#             redv = tmp_mat[r,c] * tmp_mat[r,c]
#         tmp_mat[r,c] = redv
#         tmp_mat[c,r] = redv
#         pred, prob = mat2adj(tmp_mat)
#         # if prob < prob_:
#         #     print(f'** {prob}')
#         #     print(min([sum([dcj_dist(p, s)[-1] for s in seqs]) for p in pred]), prob)
#         val_list.append(min([sum([dcj_dist(p, s)[-1] for s in seqs]) for p in pred]))
#         prob_list.append(prob)
#         # if min([sum([dcj_dist(p, s)[-1] for s in seqs]) for p in pred]) < mval:
#         #     # print(pred)
#         #     # print(min([sum([dcj_dist(p, s)[-1] for s in seqs]) for p in pred]), prob)
#         #     tag = True
#         # print(pred)
#     if np.argmin(val_list) == np.argmin(prob_list):
#         count += 1
#     if np.argmin(val_list) == 0:
#         num += 1
#     # print(np.argmin(val_list), val_list)
#     # print(np.argmin(prob_list), prob_list)
# print(count, num)

91 89


In [55]:
# tdict = {14: 18, 18: 14, 4: 15, 15: 4, 6: 16, 16: 6, 0: 11, 11: 0, 2: 12, 12: 2, 1: 17, 17: 1, 3: 7, 7: 3, 5: 13, 13: 5}

In [62]:
# test_cycle(tdict, 9, 19)

False

In [None]:
# tmp_mat = np.copy(res)

# for i in range(0, 20, 2):
#     tmp_mat[i, i:i+2] = 0
#     tmp_mat[i + 1, i:i+2] = 0

# pred, prob_ = mat2adj(tmp_mat)
# mval = min([sum([dcj_dist(p, s)[-1] for s in seqs]) for p in pred])
# print(mval, prob_)
# # print(pred)
# tag = False
# val_list = [mval]
# prob_list = [prob_]
# for _ in range(9):
#     print('-------')
#     r,c = np.unravel_index(np.argmax(tmp_mat, axis = None), tmp_mat.shape)
#     if tmp_mat[r,c] == 1.0:
#         redv = 0.95
#     else:
#         redv = tmp_mat[r,c] * tmp_mat[r,c]
#     tmp_mat[r,c] = redv
#     tmp_mat[c,r] = redv
#     pred, prob = mat2adj(tmp_mat)
#     # if prob < prob_:
#     #     print(f'** {prob}')
#     #     print(min([sum([dcj_dist(p, s)[-1] for s in seqs]) for p in pred]), prob)
#     val_list.append(min([sum([dcj_dist(p, s)[-1] for s in seqs]) for p in pred]))
#     prob_list.append(prob)
#     # if min([sum([dcj_dist(p, s)[-1] for s in seqs]) for p in pred]) < mval:
#     #     # print(pred)
#     #     # print(min([sum([dcj_dist(p, s)[-1] for s in seqs]) for p in pred]), prob)
#     #     tag = True
#     # print(pred)
# if np.argmin(val_list) == np.argmin(prob_list):
#     count += 1
# print(np.argmin(val_list), val_list)
# print(np.argmin(prob_list), prob_list)

In [57]:
set(range(20)) - set(tdict.keys())

{8, 9, 10, 19}

In [61]:
from genome_file import encodeAdj
encodeAdj([1, 7, 9, -3, -10, -2, -4, -6, -8, -5])

array([-1,  0,  1, 12, 13, 16, 17,  5,  4, 19, 18,  3,  2,  7,  6, 11, 10,
       15, 14,  9,  8, -2], dtype=int32)

In [62]:
t = [0,  1, 12, 13, 16, 17,  5,  4, 19, 18,  3,  2,  7,  6, 11, 10,
       15, 14,  9,  8]

In [63]:
[(t[i]//2 + 1) * (-1 if t[i] > t[i+1] else 1) for i in range(0, len(t), 2)]

[1, 7, 9, -3, -10, -2, -4, -6, -8, -5]

In [28]:
mat2adj(tmp_mat)

[[1, 7, 9, -3, -10, -2, -4, -6, -8, -5], [5, 8, 6, 4, 2, 10, 3, -9, -7, -1]]

In [75]:
tmp_mat = np.copy(res)

for i in range(0, 20, 2):
    tmp_mat[i, i:i+2] = 0
    tmp_mat[i + 1, i:i+2] = 0

pred = mat2adj(tmp_mat)
print(min([sum([dcj_dist(p, s)[-1] for s in seqs]) for p in pred]))
print(pred)

6
[[10, -2, 4, 9, 8, 5, 6, -1, -3, -7], [7, 3, 1, -6, -5, -8, -9, -4, 2, -10]]


In [None]:
tmp_mat

In [26]:
[  7,  10,  -2,   4,   9,   8,   5,   6,  -1,  -3]

6

In [70]:
seqs

array([[  7,  10,  -2,   4,   3,   1,  -6,  -5,  -8,  -9],
       [  1,  -6,  -5,  -8,  -9,  -4,   2, -10,  -7,  -3],
       [  7,  -4,   2, -10,   9,   8,   5,   6,  -1,  -3]], dtype=int32)

In [71]:
[dcj_dist([7, 3, 1, -6, -5, -8, -9, -4, 2, -10], s)[-1] for s in seqs]

[2, 2, 2]

In [73]:
res[13, 18]

0.6990546

In [29]:
tmp_mat[6,11] = 0
tmp_mat[11, 6] = 0

In [30]:
mat2adj(tmp_mat)

[[1, 7, 9, -3, -10, -2, 4, -6, -8, -5], [5, 8, 6, -4, 2, 10, 3, -9, -7, -1]]

In [76]:
for _ in range(10):
    r,c = np.unravel_index(np.argmax(tmp_mat, axis = None), tmp_mat.shape)
    tmp_mat[r,c] = 0
    tmp_mat[c,r] = 0
    pred = mat2adj(tmp_mat)
    print(min([sum([dcj_dist(p, s)[-1] for s in seqs]) for p in pred]))
    print(pred)

9
[[10, -2, 4, 9, 8, 5, 6, 1, -3, -7], [7, 3, -1, -6, -5, -8, -9, -4, 2, -10]]
9
[[10, -2, 4, 9, 8, 5, -6, -1, -3, -7], [7, 3, 1, 6, -5, -8, -9, -4, 2, -10]]
15
[[10, -2, 4, 9, 8, -5, 7, 3, 1, 6], [-6, -1, -3, -7, 5, -8, -9, -4, 2, -10]]
12
[[10, -2, 4, 9, -8, 5, -6, -1, -3, -7], [7, 3, 1, 6, -5, 8, -9, -4, 2, -10]]
18
[[2, -10, 7, 3, 1, 6, -5, 8, -9, -4], [4, 9, -8, 5, -6, -1, -3, -7, 10, -2]]
19
[[1, -3, -7, 10, -2, -6, -5, 8, -9, -4], [4, 9, -8, 5, 6, 2, -10, 7, 3, -1]]
19
[[1, -3, -7, 10, 2, 4, 9, -8, 5, 6], [-6, -5, 8, -9, -4, -2, -10, 7, 3, -1]]
18
[[10, 2, 4, 9, -8, 5, -6, 1, -3, -7], [7, 3, -1, 6, -5, 8, -9, -4, -2, -10]]
24
[[-9, -4, -2, -10, 7, 3, -1, 6, -5, 8], [-8, 5, -6, 1, -3, -7, 10, 2, 4, 9]]
23
[[10, 2, 4, -9, 8, -5, 7, 3, -1, 6], [-6, 1, -3, -7, 5, -8, 9, -4, -2, -10]]


In [32]:
seqs

array([[-6,  4,  2, 10,  3,  5,  8, -1,  7,  9],
       [ 6, -8, -5, -2, -4, -3, -7, -1, 10,  9],
       [ 6, -4,  2, 10,  3,  5,  8,  1,  7,  9]], dtype=int32)

In [34]:
dcj_dist([5, 8, 6, 4, 2, 10, 3, -9, -7, -1], seqs[0])[-1]

5

In [55]:
[dcj_dist([1, 7, 9, -3, -10, -2, -4, 6, -8, -5], s)[-1] for s in seqs]

[3, 5, 4]

In [47]:
sum([dcj_dist([  6,   4,   2,  10,   3,   5,   8,   1,   7,   9], seqs[1])[-1], dcj_dist(seqs[0], [  6,   4,   2,  10,   3,   5,   8,   1,   7,   9])[-1], dcj_dist(seqs[2], [  6,   4,   2,  10,   3,   5,   8,   1,   7,   9])[-1]])

7

In [69]:
val_seq[1]

array([[  7,  -8,   3,  -2,  -9,  -6,   1,  -4,  -5, -10],
       [  5,  -3,   4,   7,  -2,   1,   6,  10,   9,   8],
       [  5, -10,  -4,   8,  -2,  -9,   6,   1,  -7,   3],
       [ -9,  -5,  -3,   8,  -4,  -2,  -6, -10,   7,  -1],
       [ -9,  10,  -4,   7,  -3,  -2,   1,   6,   8,  -5],
       [ -7,  10,  -4,   6,   8,   1,  -2,  -9,  -3,   5],
       [  1, -10,   3,   5,   7,  -8,   4,   9,  -6,   2],
       [ -3,  -2,  -7,   4,   8,   6,  10,   5,   1,   9],
       [ -7,   1,   8,   3,   5,   2,   6,  10,   9,  -4],
       [ -5,   7,  -4,  10,  -1,  -8,  -3,   6,   2,  -9],
       [  5,  -2,   7,   9,  -4,   1,  10,   8,   6,  -3],
       [ -8,  -6,   4,   7,  10,   2,  -1,  -9,   3,  -5],
       [  1,   3,  -8,  10,   4,   6,  -9,  -5,   7,  -2],
       [  2,   9,   3,  -1,   8,   4,   5,  10,   7,   6],
       [  4,  -5,   9,  -1,  -8,   3,   2,   7,  10,   6],
       [ -6,  -4,   3,   8,   7,   5,  -1, -10,  -9,  -2],
       [ 10,   4,   2,   5,   8,   1,   3,  -9,   6,   7

In [56]:
[dcj_dist([  6,   4,   2,  10,   3,   5,   8,   1,   7,   9], s)[-1] for s in seqs]

[2, 4, 1]