In [1]:
import time

import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau

from torch_geometric.nn import VGAE
from torch_geometric.loader import DataLoader
from torch_geometric.utils import (degree, negative_sampling, 
                                   batched_negative_sampling)

from torch.utils.tensorboard import SummaryWriter

from gene_graph_dataset import G3MedianDataset
from phylognn_model import G3Median_GCNConv

In [2]:
gpuid = 1

train_p, test_p, val_p = 0.7, 0.2, 0.1
train_batch, test_batch, val_batch = 128, 128, 4

In [3]:
device = torch.device('cuda:' + str(gpuid) if torch.cuda.is_available() else 'cpu')

In [4]:
dataset = G3MedianDataset('dataset_g3m', 10, 10, 200)

In [5]:
data_size = len(dataset)
train_size, test_size, val_size = ((int)(data_size * train_p), 
                                   (int)(data_size * test_p), 
                                   (int)(data_size * val_p))

In [6]:
dataset = dataset.shuffle()
train_dataset = dataset[:train_size]
test_dataset = dataset[train_size:(train_size + test_size)]
val_dataset = dataset[(train_size + test_size):(train_size + test_size + val_size)]

In [7]:
test_dataset = list(test_dataset)
for t in test_dataset:
    t.neg_edge_label_index = negative_sampling(t.pos_edge_label_index, 
                                        t.num_nodes,
                                        t.num_nodes**2)
train_dataset = list(train_dataset)
for t in train_dataset:
    t.neg_edge_label_index = negative_sampling(t.pos_edge_label_index, 
                                        t.num_nodes,
                                        t.num_nodes**2)
val_dataset = list(val_dataset)
for t in val_dataset:
    t.neg_edge_label_index = negative_sampling(t.pos_edge_label_index, 
                                        t.num_nodes,
                                        t.num_nodes**2)

In [8]:
train_loader = DataLoader(train_dataset, batch_size = train_batch, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size = test_batch)
val_loader = DataLoader(val_dataset, batch_size= val_batch)

In [9]:
# deg = torch.zeros(5, dtype=torch.long)
# for data in train_dataset:
#     d = degree(data.edge_index[1].type(torch.int64), 
#                num_nodes=data.num_nodes, dtype=torch.long)
#     deg += torch.bincount(d, minlength=deg.numel())

In [10]:
in_channels, out_channels = dataset.num_features, 16

In [11]:
model = VGAE(G3Median_GCNConv(in_channels, out_channels)).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.05)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.9, patience=10,
                              min_lr=0.00001)

In [12]:
writer = SummaryWriter(log_dir='runs_g3m_10/g3median_2000_gcn_aneg_run8')

In [13]:
from torch_geometric.data import Batch
def train(train_loader):
    model.train()
    
    total_loss = 0
    for data in train_loader:
        
        data = data.to(device)
        optimizer.zero_grad()
        
        z = model.encode(data.x, data.edge_index)
        loss = model.recon_loss(z, data.pos_edge_label_index, data.neg_edge_label_index)
        loss = loss + (1 / data.num_nodes) * model.kl_loss()
        loss.backward()
        optimizer.step()
        
        total_loss += loss
    return total_loss/len(train_loader)

In [14]:
@torch.no_grad()
def test(test_loader):
    model.eval()
    auc, ap = 0, 0
    
    for data in test_loader:
        
        data = data.to(device)
        
        z = model.encode(data.x, data.edge_index)
        # loss += model.recon_loss(z, data.pos_edge_label_index, data.neg_edge_label_index)
        tauc, tap = model.test(z, data.pos_edge_label_index, data.neg_edge_label_index)
        
        auc += tauc
        ap += tap
        
    return auc/len(test_loader), ap/len(test_loader)

In [15]:
@torch.no_grad()
def val(val_loader):
    model.eval()
    loss = 0
    
    for data in val_loader:
        
        data = data.to(device)
        
        z = model.encode(data.x, data.edge_index)
        loss += model.recon_loss(z, data.pos_edge_label_index, data.neg_edge_label_index)
        # tauc, tap = model.test(z, data.pos_edge_label_index, data.neg_edge_label_index)
                
    return loss #/len(val_loader)

In [16]:
for epoch in range(1, 1000 + 1):
    # print(f'{time.ctime()} - Epoch: {epoch:04d}')
    loss = train(train_loader)
    # print(f'{time.ctime()} - \t train loss: {loss:.6f}')
    tloss = val(val_loader)
    # print(f'{time.ctime()} - \t val   loss: {tloss:.6f}')
    scheduler.step(tloss)
    
    writer.add_scalar('Loss/train', loss, epoch)
    writer.add_scalar('Loss/val', tloss/len(val_loader), epoch)
    
    
    auc, ap = test(test_dataset)
    
    writer.add_scalar('AUC/test', auc, epoch)
    writer.add_scalar('AP/test', ap, epoch)
    
    if epoch % 50 == 0:
        print(f'{time.ctime()} - Epoch: {epoch:04d}        auc: {auc:.6f}, ap: {ap:.6f}')

Mon Dec 27 11:54:48 2021 - Epoch: 0050        auc: 0.851906, ap: 0.242792
Mon Dec 27 12:00:35 2021 - Epoch: 0100        auc: 0.870499, ap: 0.262982
Mon Dec 27 12:06:22 2021 - Epoch: 0150        auc: 0.883273, ap: 0.282682
Mon Dec 27 12:12:09 2021 - Epoch: 0200        auc: 0.888804, ap: 0.290417
Mon Dec 27 12:17:55 2021 - Epoch: 0250        auc: 0.891653, ap: 0.294942
Mon Dec 27 12:23:39 2021 - Epoch: 0300        auc: 0.893039, ap: 0.298911
Mon Dec 27 12:29:26 2021 - Epoch: 0350        auc: 0.895849, ap: 0.303107
Mon Dec 27 12:35:14 2021 - Epoch: 0400        auc: 0.898252, ap: 0.306953
Mon Dec 27 12:40:59 2021 - Epoch: 0450        auc: 0.898347, ap: 0.306679
Mon Dec 27 12:46:43 2021 - Epoch: 0500        auc: 0.901002, ap: 0.312568
Mon Dec 27 12:52:30 2021 - Epoch: 0550        auc: 0.902183, ap: 0.314088
Mon Dec 27 12:58:19 2021 - Epoch: 0600        auc: 0.904403, ap: 0.319120
Mon Dec 27 13:04:06 2021 - Epoch: 0650        auc: 0.905526, ap: 0.321767
Mon Dec 27 13:09:53 2021 - Epoch: 0700

In [17]:
writer.close()

In [18]:
# torch.save(model.state_dict(), 'g2g_test_model_batch')