In [1]:
import time

import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau

from torch_geometric.nn import VGAE
from torch_geometric.loader import DataLoader
from torch_geometric.utils import degree, negative_sampling

from torch.utils.tensorboard import SummaryWriter

from gene_graph_dataset import G3MedianDataset
from phylognn_model import G3Median

In [2]:
# Index of the CUDA device to use when a GPU is available.
gpuid = 0

# Fractions of the dataset assigned to the training and test splits.
train_p = 0.8
test_p = 0.2

# Mini-batch sizes for the training and test loaders.
train_batch = 128
test_batch = 64

In [3]:
# Use the configured GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device(f'cuda:{gpuid}' if torch.cuda.is_available() else 'cpu')

In [4]:
# Load the project-local G3Median dataset rooted at 'dataset_g3m'.
# NOTE(review): the meaning of the two numeric arguments is defined in
# gene_graph_dataset.G3MedianDataset and is not visible here - confirm there.
dataset = G3MedianDataset(
    'dataset_g3m',
    100,
    100,
)

In [5]:
# Absolute split sizes; int() truncates, so each size is rounded down.
data_size = len(dataset)
train_size = int(data_size * train_p)
test_size = int(data_size * test_p)

In [6]:
# Shuffle once, then carve out contiguous train / test slices.
dataset = dataset.shuffle()
train_dataset, test_dataset = (
    dataset[:train_size],
    dataset[train_size:train_size + test_size],
)
# Anything past train_size + test_size is currently unused (no validation split):
# val_dataset = dataset[(train_size + test_size):]

In [7]:
# Training batches are reshuffled every epoch; test batches keep dataset order.
train_loader = DataLoader(
    train_dataset,
    batch_size=train_batch,
    shuffle=True,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=test_batch,
)
# val_loader = DataLoader(val_dataset, batch_size=8)

In [8]:
# Degree histogram over the training graphs: deg[k] is the number of nodes
# whose degree, counted over edge_index[1], equals k. The result is passed
# to the G3Median encoder when the model is built.
deg = torch.zeros(5, dtype=torch.long)
for data in train_dataset:
    d = degree(data.edge_index[1].type(torch.int64),
               num_nodes=data.num_nodes, dtype=torch.long)
    counts = torch.bincount(d, minlength=deg.numel())
    if counts.numel() > deg.numel():
        # bincount returns max(d) + 1 bins, so a node with degree >= len(deg)
        # yields a longer tensor; grow the histogram instead of failing on a
        # shape mismatch in the in-place add below.
        deg = torch.cat([deg, deg.new_zeros(counts.numel() - deg.numel())])
    deg += counts

In [9]:
# Encoder input width comes from the dataset's node features;
# 16 is passed to G3Median as its out_channels argument.
in_channels = dataset.num_features
out_channels = 16

In [10]:
# Variational graph auto-encoder wrapping the project's G3Median encoder.
model = VGAE(G3Median(in_channels, out_channels, deg)).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# Halve the learning rate after 10 scheduler steps without improvement of the
# monitored (minimised) quantity, never going below 1e-5.
scheduler = ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=10,
    min_lr=0.00001,
)

In [11]:
# TensorBoard event writer for this run.
writer = SummaryWriter(
    log_dir='runs/g3median_lbatch_pna_run2',
)

In [12]:
def train(train_loader):
    """Run one optimisation epoch over ``train_loader``.

    For every batch the global ``model`` encodes the graph, and the loss is the
    reconstruction loss on ``pos_edge_label_index`` plus the KL term scaled by
    ``1 / data.num_nodes``. The global ``optimizer`` is stepped once per batch.

    Args:
        train_loader: iterable of batches exposing ``x``, ``edge_index``,
            ``pos_edge_label_index`` and ``num_nodes``.

    Returns:
        The mean per-batch loss as a 0-dim tensor detached from autograd.

    Raises:
        ZeroDivisionError: if ``train_loader`` yields no batches.
    """
    model.train()

    total_loss = 0
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()

        z = model.encode(data.x, data.edge_index)
        loss = model.recon_loss(z, data.pos_edge_label_index)
        loss = loss + (1 / data.num_nodes) * model.kl_loss()
        loss.backward()
        optimizer.step()

        # Accumulate a detached value: summing the live loss tensor would chain
        # every batch's autograd history onto the running total for the whole
        # epoch and keep it alive needlessly.
        total_loss += loss.detach()
    return total_loss / len(train_loader)

In [13]:
@torch.no_grad()
def test(test_loader):
    """Evaluate the global ``model`` on ``test_loader`` without gradients.

    For each batch, negative edges are drawn with ``negative_sampling``
    (requesting ``10 * num_nodes`` samples) and reused for both the
    reconstruction loss and the ``model.test`` metrics.

    Args:
        test_loader: iterable of batches exposing ``x``, ``edge_index``,
            ``pos_edge_label_index`` and ``num_nodes``.

    Returns:
        Tuple ``(loss, auc, ap)``, each averaged over the number of batches.
    """
    model.eval()
    loss_sum = auc_sum = ap_sum = 0

    for batch in test_loader:
        batch = batch.to(device)
        pos_edges = batch.pos_edge_label_index

        # Sample the negatives once so loss and metrics see the same edges.
        neg_edges = negative_sampling(pos_edges,
                                      batch.num_nodes,
                                      batch.num_nodes * 10)

        z = model.encode(batch.x, batch.edge_index)
        loss_sum += model.recon_loss(z, pos_edges, neg_edges)

        batch_auc, batch_ap = model.test(z, pos_edges, neg_edges)
        auc_sum += batch_auc
        ap_sum += batch_ap

    n_batches = len(test_loader)
    return loss_sum / n_batches, auc_sum / n_batches, ap_sum / n_batches

In [14]:
# Main loop: 200 epochs of train -> evaluate -> LR scheduling -> TensorBoard logging.
for epoch in range(1, 201):
    print(f'{time.ctime()} - Epoch: {epoch:04d}')

    loss = train(train_loader)
    print(f'{time.ctime()} - \t train loss: {loss:.6f}')

    tloss, auc, ap = test(test_loader)
    print(f'{time.ctime()} - \t test  loss: {tloss:.6f}, auc: {auc:.6f}, ap: {ap:.6f}')

    # The scheduler minimises its metric, so feed it 1 - AUC.
    scheduler.step(1 - auc)

    for tag, value in (('Loss/train', loss),
                       ('Loss/test', tloss),
                       ('AUC/test', auc),
                       ('AP/test', ap)):
        writer.add_scalar(tag, value, epoch)

Sat Dec 25 12:06:43 2021 - Epoch: 0001
Sat Dec 25 12:07:45 2021 - 	 train loss: 1.498455
Sat Dec 25 12:08:03 2021 - 	 test  loss: 1.044663, auc: 0.913005, ap: 0.461320
Sat Dec 25 12:08:03 2021 - Epoch: 0002
Sat Dec 25 12:09:05 2021 - 	 train loss: 0.931019
Sat Dec 25 12:09:21 2021 - 	 test  loss: 0.944239, auc: 0.955998, ap: 0.662932
Sat Dec 25 12:09:21 2021 - Epoch: 0003
Sat Dec 25 12:10:24 2021 - 	 train loss: 0.894842
Sat Dec 25 12:10:40 2021 - 	 test  loss: 0.904250, auc: 0.968448, ap: 0.745100
Sat Dec 25 12:10:40 2021 - Epoch: 0004
Sat Dec 25 12:11:42 2021 - 	 train loss: 0.880140
Sat Dec 25 12:11:58 2021 - 	 test  loss: 0.898926, auc: 0.971577, ap: 0.751383
Sat Dec 25 12:11:58 2021 - Epoch: 0005
Sat Dec 25 12:13:00 2021 - 	 train loss: 0.872557
Sat Dec 25 12:13:16 2021 - 	 test  loss: 0.883975, auc: 0.975063, ap: 0.774436
Sat Dec 25 12:13:16 2021 - Epoch: 0006
Sat Dec 25 12:14:18 2021 - 	 train loss: 0.868838
Sat Dec 25 12:14:34 2021 - 	 test  loss: 0.887465, auc: 0.976442, ap: 0

In [15]:
# Training has finished: close the TensorBoard writer opened for this run.
writer.close()

In [16]:
# torch.save(model.state_dict(), 'g2g_test_model_batch')