In [1]:
import time
import torch

from phylognn_model import G2Braph
from gene_graph_dataset import G2BraphDataset

from torch_geometric.utils import degree
from torch_geometric.loader import DataLoader

import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau

from sklearn.metrics import roc_auc_score, average_precision_score

from torch.utils.tensorboard import SummaryWriter

In [2]:
# Fractions of the dataset used for training and testing; whatever is left
# after these two slices becomes the validation split.
train_p = 0.7
test_p = 0.2

# Mini-batch sizes for the three data loaders.
train_batch = 25
test_batch = val_batch = 8

In [3]:
# Load the dataset rooted at 'dataset_g2b' and shuffle it once up front.
# NOTE(review): the meaning of the positional arguments 1000 and 10 is defined
# by G2BraphDataset and is not visible here -- confirm against that class.
dataset = G2BraphDataset('dataset_g2b', 1000, 10).shuffle()
data_size = len(dataset)

# Split sizes are truncated towards zero, so rounding leftovers end up in the
# remainder of the dataset.
train_size = int(data_size * train_p)
test_size = int(data_size * test_p)

In [4]:
# Carve the shuffled dataset into three consecutive slices:
# [0, train_size) -> train, the next test_size items -> test, the rest -> val.
train_dataset, test_dataset, val_dataset = (
    dataset[:train_size],
    dataset[train_size:train_size + test_size],
    dataset[train_size + test_size:],
)

In [5]:
# Only the training loader reshuffles; the test and validation loaders keep
# DataLoader's default ordering.
train_loader = DataLoader(
    train_dataset,
    batch_size=train_batch,
    shuffle=True,
)
test_loader = DataLoader(test_dataset, batch_size=test_batch)
val_loader = DataLoader(val_dataset, batch_size=val_batch)

In [6]:
# Degree histogram over the training graphs: deg[k] is the number of nodes
# that appear exactly k times in row 1 of `edge_index`. It is handed to the
# G2Braph constructor further down.
deg = torch.zeros(5, dtype=torch.long)
for data in train_dataset:
    d = degree(data.edge_index[1].type(torch.int64),
               num_nodes=data.num_nodes, dtype=torch.long)
    hist = torch.bincount(d, minlength=deg.numel())
    if hist.numel() > deg.numel():
        # bincount returns max(d) + 1 bins when a node's degree exceeds the
        # current histogram length; zero-pad `deg` so the in-place add below
        # does not fail with a shape mismatch.
        deg = torch.cat([deg, deg.new_zeros(hist.numel() - deg.numel())])
    deg += hist

In [7]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# The model is built from the training-set degree histogram `deg`.
model = G2Braph(deg).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Halve the learning rate after 20 scheduler steps without improvement of the
# monitored value (lower is better), never going below 1e-5.
scheduler = ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=20,
    min_lr=1e-5,
)

In [8]:
def train(train_dataset):
    """Run one optimisation pass over ``train_dataset`` and return the mean loss.

    Uses the module-level ``model``, ``optimizer`` and ``device``.

    Args:
        train_dataset: Sized iterable of batches. Each batch must provide
            ``.to(device)``, ``.x``, ``.edge_index`` and ``.node_label``.
            The training loop in this notebook passes a ``DataLoader``.

    Returns:
        A 0-dim tensor, detached from autograd: the sum of the per-batch
        binary cross-entropy losses divided by ``len(train_dataset)``.
    """
    model.train()

    total_loss = 0
    for data in train_dataset:
        data = data.to(device)
        optimizer.zero_grad()

        res = model(data.x, data.edge_index, None, None)
        loss = F.binary_cross_entropy(res.squeeze(), data.node_label.to(torch.float))
        loss.backward()
        optimizer.step()

        # Accumulate a detached copy: adding `loss` itself would keep every
        # batch's autograd history alive in the running total and make the
        # returned value require grad.
        total_loss += loss.detach()

    return total_loss / len(train_dataset)

In [9]:
@torch.no_grad()
def validate(test_dataset):
    """Evaluate the module-level ``model`` and return batch-averaged metrics.

    Args:
        test_dataset: Sized iterable of batches. Each batch must provide
            ``.to(device)``, ``.x``, ``.edge_index`` and ``.node_label``.

    Returns:
        ``(loss, auc, ap)``: per-batch binary cross-entropy, ROC AUC and
        average precision, each summed over the batches and divided by
        ``len(test_dataset)``.
    """
    model.eval()

    loss_sum, auc_sum, ap_sum = 0, 0, 0
    for batch in test_dataset:
        batch = batch.to(device)
        prob = model(batch.x, batch.edge_index, None, None).squeeze()
        target = batch.node_label

        loss_sum += F.binary_cross_entropy(prob, target.to(torch.float))

        # scikit-learn metrics operate on host-side NumPy arrays.
        y_true, y_score = target.cpu().numpy(), prob.cpu().numpy()
        auc_sum += roc_auc_score(y_true, y_score)
        ap_sum += average_precision_score(y_true, y_score)

    n_batches = len(test_dataset)
    return loss_sum / n_batches, auc_sum / n_batches, ap_sum / n_batches

In [10]:
@torch.no_grad()
def test(test_dataset):
    model.eval()
    
    loss, auc, ap = [], [], []
    for data in test_dataset:
        data = data.to(device)
        res = model(data.x, data.edge_index, None, None)
        
        y, pred = data.node_label.cpu().numpy(), res.squeeze().cpu().numpy()
        
        loss.apppend(F.binary_cross_entropy(res.squeeze(), data.node_label.to(torch.float)))
        auc.append(roc_auc_score(y, pred))
        ap.append(average_precision_score(y, pred))
        
    return loss, auc, ap

In [11]:
# TensorBoard writer used by the training loop below to log its scalars;
# event files go to runs/g2braph_batch.
writer = SummaryWriter(log_dir='runs/g2braph_batch')

In [None]:
# Main training loop: 2000 epochs, each followed by a validation pass.
for epoch in range(1, 2001):
    train_loss = train(train_loader)
    val_loss, val_auc, val_ap = validate(val_loader)

    # The plateau scheduler monitors the validation loss.
    scheduler.step(val_loss)

    # Log every epoch's metrics to TensorBoard.
    for tag, value in (
        ('Loss/train', train_loss),
        ('Loss/validate', val_loss),
        ('AUC/validate', val_auc),
        ('AP/validate', val_ap),
    ):
        writer.add_scalar(tag, value, epoch)

    # Only report to stdout every 50th epoch.
    if epoch % 50 != 0:
        continue

    print(f'{time.ctime()}  Epoch: {epoch:04d}, '
          f'Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, '
          f'Val AUC: {val_auc:.4f}, Val AP: {val_ap:.4f}')

In [None]:
# Evaluate on the reshuffled test split; `test` returns one value per graph,
# and each metric is printed as a space-separated row.
test_loss, test_auc, test_ap = test(test_dataset.shuffle())
print(
    'Test Loss: \t', ' '.join(f'{x:.6f}' for x in test_loss),
    '\nTest AUC: \t', ' '.join(f'{x:.6f}' for x in test_auc),
    '\nTest AP: \t', ' '.join(f'{x:.6f}' for x in test_ap),
)

In [None]:
# Close the TensorBoard writer now that training and evaluation are finished.
writer.close()