In [1]:
from gene_graph_dataset import ExpsDataset, G3MedianDataset

In [2]:
dataset = ExpsDataset('dataset_g3m_exps', 10, 3, 10000)

In [3]:
dataset100 = G3MedianDataset('dataset_g3m', 100, 10, 1000)
dataset100 = dataset100.shuffle()

In [4]:
import time

import numpy as np

import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau

# from torch_geometric.nn import VGAE
from torch_geometric.loader import DataLoader
from torch_geometric.utils import (degree, negative_sampling, 
                                   batched_negative_sampling,
                                  add_self_loops, to_undirected)

from torch.utils.tensorboard import SummaryWriter

from gene_graph_dataset import G3MedianDataset
from phylognn_model import G3Median_GCNConv, G3Median_VGAE

from sklearn.metrics import (roc_auc_score, roc_curve,
                             average_precision_score, 
                             precision_recall_curve,
                             f1_score, matthews_corrcoef)

from sklearn.model_selection import KFold

import matplotlib.pyplot as plt

In [5]:
train_batch, test_batch, val_batch = 256, 64, 8

device = torch.device('cuda:' + str(1) if torch.cuda.is_available() else 'cpu')

in_channels, out_channels = None, 128

In [6]:
dataset = dataset.shuffle()

In [7]:
def train(model, train_loader):
    model.train()
    
    total_loss = 0
    for data in train_loader:    
        optimizer.zero_grad()
        data = data.to(device)
        
        z = model.encode(data.x, data.edge_index)
        loss = model.recon_loss_wt(z, data.pos_edge_label_index, data.neg_edge_label_index, 2, 1) * 5
        loss = loss + (1 / data.num_nodes) * model.kl_loss() * 0.5
        loss.backward()
        optimizer.step()
        
        total_loss += loss
    return total_loss/len(train_loader)

In [8]:
@torch.no_grad()
def predict(model, test_loader):
    model.eval()
    y_list, pred_list = [], []
        
    for data in test_loader:
        
        data = data.to(device)
        
        z = model.encode(data.x, data.edge_index)
        # loss += model.recon_loss(z, data.pos_edge_label_index, data.neg_edge_label_index)
        y, pred = model.pred(z, data.pos_edge_label_index, data.neg_edge_label_index)
        
        y_list.append(y)
        pred_list.append(pred)
        
    return y_list, pred_list

@torch.no_grad()
def val(model, val_loader):
    model.eval()
    loss = 0
    
    for data in val_loader:        
        data = data.to(device)        
        z = model.encode(data.x, data.edge_index)        
        loss += model.recon_loss_wt(z, data.pos_edge_label_index, data.neg_edge_label_index, 2, 1)
        # tauc, tap = model.test(z, data.pos_edge_label_index, data.neg_edge_label_index)
                
    return loss/len(val_loader)

def auc_ap(y_list, pred_list):
    pred_accuracy = [[roc_auc_score(y, pred), average_precision_score(y, pred)]
                     for y, pred in zip(y_list, pred_list)]
    auc, ap = np.mean(pred_accuracy, axis = 0)
    return auc, ap

In [9]:
y_pred_res = []
counter = 1

In [10]:
print(f'{time.ctime()} -- seqlen:{10:0>4} '
      f'rate:{0.1:.2f} samples:{5000:0>5} -- fold: {counter:0>2}')

model = G3Median_VGAE(G3Median_GCNConv(in_channels, out_channels)).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10,
                              min_lr=0.00001,verbose=True)

writer = SummaryWriter(log_dir='exps_g3median_' f'{10:0>4}' '/e' f'{10000:0>5}' '_r' 
                       f'{0.3:0>3.1f}' '_' 'run' f'{counter:0>2}')

train_dataset = dataset[:int(len(dataset) * 0.9)]
val_dataset = dataset[int(len(dataset) * 0.9):]

train_loader = DataLoader(train_dataset, batch_size = train_batch, shuffle=True)
# test_loader = DataLoader(test_dataset, batch_size = test_batch)
val_loader = DataLoader(val_dataset, batch_size = val_batch)

start_time = time.time()

y_pred = None
p_auc, p_ap = 0, 0

Sun Jan 16 15:02:14 2022 -- seqlen:0010 rate:0.10 samples:05000 -- fold: 01


In [11]:
for epoch in range(1, 100 + 1):

    loss = train(model, train_loader)
    tloss = val(model, val_loader)
    scheduler.step(tloss)

    writer.add_scalar('loss/train', loss, epoch)
    writer.add_scalar('loss/val', tloss, epoch)

    y_list, pred_list = predict(model, dataset100)
    # pred_acc, figures = cal_accuracy(y_list, pred_list)        
    # auc, ap = pred_acc

    # y_list, pred_list = predict(model, test_dataset)
    auc, ap = auc_ap(y_list, pred_list)

    writer.add_scalar('auc/test', auc, epoch)
    writer.add_scalar('ap/test', ap, epoch)

    # writer.add_figure('roc/test', figures[0], epoch)
    # writer.add_figure('pr/test', figures[1], epoch)

    if auc >= p_auc and ap >= p_ap:
        y_pred = np.concatenate([np.array([y, pred])
                                 for y, pred in zip(y_list, pred_list)], 
                                axis = 1)
        p_auc, p_ap = auc, ap

end_time = time.time()
print(f'{time.ctime()} -- seqlen:{10:0>4} '
      f'rate:{0.1:.2f} samples:{5000:0>5} -- fold: {counter:0>2}'
     f' -- {(end_time - start_time)/100:>10.3f}s * {100:0>4} epoches')
y_pred_res.append(y_pred)

KeyboardInterrupt: 