In [1]:
import time
import torch

from phylognn_model import G2Braph
from gene_graph_dataset import G2BraphDataset

from torch_geometric.utils import degree
from torch_geometric.loader import DataLoader

import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau

from sklearn.metrics import roc_auc_score, average_precision_score

from torch.utils.tensorboard import SummaryWriter

In [2]:
# Seed torch's global RNG before anything random happens. dataset.shuffle(),
# model weight init and the per-epoch shuffles all draw from it, so this makes
# the split and the training run reproducible under Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# Split fractions: 70% train, 20% test; whatever remains (~10%) is validation.
train_p, test_p = 0.7, 0.2

In [3]:
# Load the gene-graph dataset in a random order. The constructor arguments
# (1000, 10) are passed through unchanged; their meaning is defined in
# gene_graph_dataset -- TODO document them there.
dataset = G2BraphDataset('dataset_g2b', 1000, 10).shuffle()
data_size = len(dataset)

# Split sizes, truncated toward zero; the remainder goes to validation.
train_size = int(data_size * train_p)
test_size = int(data_size * test_p)

In [4]:
# Contiguous, non-overlapping slices of the shuffled dataset:
#   [0, train_end)        -> train
#   [train_end, test_end) -> test
#   [test_end, end)       -> validation
train_end = train_size
test_end = train_size + test_size

train_dataset = dataset[:train_end]
test_dataset = dataset[train_end:test_end]
val_dataset = dataset[test_end:]

In [5]:
# Mini-batch loaders; only the training loader reshuffles between passes.
# NOTE(review): the train/test calls later in this notebook are given the
# datasets directly, not these loaders.
BATCH_SIZE = 32
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader, val_loader = (
    DataLoader(split, batch_size=BATCH_SIZE) for split in (test_dataset, val_dataset)
)

In [6]:
# In-degree histogram over the training graphs: deg[k] is the number of nodes
# whose in-degree (occurrences in edge_index[1]) equals k. It is handed to
# G2Braph; presumably for PNA-style degree scalers -- TODO confirm.
#
# The histogram is sized from the largest in-degree actually observed. A fixed
# 5-bin tensor fails on `deg +=` as soon as any node has in-degree > 4, because
# bincount then returns more bins. At least 5 bins are kept so the layout is
# unchanged whenever all degrees are <= 4.
def _in_degrees(graph):
    """Return a LongTensor with the in-degree of every node of `graph`."""
    return degree(graph.edge_index[1].type(torch.int64),
                  num_nodes=graph.num_nodes, dtype=torch.long)


MIN_DEG_BINS = 5

max_degree = MIN_DEG_BINS - 1
for data in train_dataset:
    d = _in_degrees(data)
    if d.numel() > 0:  # a graph with no nodes contributes nothing
        max_degree = max(max_degree, int(d.max()))

deg = torch.zeros(max_degree + 1, dtype=torch.long)
for data in train_dataset:
    deg += torch.bincount(_in_degrees(data), minlength=deg.numel())

In [7]:
# Prefer the GPU when one is available, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')

# The model is built from the training-set degree histogram `deg`.
model = G2Braph(deg).to(device)

# Adam at lr=1e-3. The scheduler is stepped with the validation loss once per
# epoch; with patience=20 it halves the learning rate after more than 20
# epochs without improvement, never going below 1e-5.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=20,
    min_lr=1e-5,
)

In [8]:
def train(train_dataset):
    """Run one training pass over `train_dataset` and return the mean loss.

    Every element the iterable yields is moved to the global `device`, scored
    by the global `model`, and used for exactly one `optimizer` step.

    Args:
        train_dataset: sized iterable of PyG data objects exposing `x`,
            `edge_index` and `node_label`.

    Returns:
        float: binary cross-entropy averaged over the iterated elements.
    """
    model.train()

    total_loss = 0.0
    for data in train_dataset:
        data = data.to(device)
        optimizer.zero_grad()

        # F.binary_cross_entropy expects probabilities, so the model output is
        # assumed to already be in [0, 1]; shape presumably (num_nodes, 1) and
        # squeezed to match node_label -- TODO confirm.
        res = model(data.x, data.edge_index, None, None)
        loss = F.binary_cross_entropy(res.squeeze(), data.node_label.to(torch.float))
        loss.backward()
        optimizer.step()

        # .item() yields a plain float. Accumulating the tensor itself would
        # chain every step's autograd graph into `total_loss` for the whole epoch.
        total_loss += loss.item()

    return total_loss / len(train_dataset)

In [9]:
@torch.no_grad()
def test(test_dataset):
    """Evaluate the global `model` on `test_dataset` without tracking gradients.

    Loss, ROC-AUC and average precision are computed separately for each
    element the iterable yields, then averaged with equal weight.

    Args:
        test_dataset: sized iterable of PyG data objects exposing `x`,
            `edge_index` and `node_label`.

    Returns:
        tuple[float, float, float]: mean BCE loss, mean ROC-AUC, mean AP.

    NOTE(review): roc_auc_score is expected to raise if an element's labels
    contain a single class -- verify the dataset guarantees both classes.
    """
    model.eval()

    total_loss, auc, ap = 0.0, 0.0, 0.0
    for data in test_dataset:
        data = data.to(device)
        # Squeeze once and reuse for both the loss and the sklearn metrics.
        pred = model(data.x, data.edge_index, None, None).squeeze()
        target = data.node_label.to(torch.float)

        # .item() keeps the running total a host-side float, not a device tensor.
        total_loss += F.binary_cross_entropy(pred, target).item()

        y, scores = data.node_label.cpu().numpy(), pred.cpu().numpy()
        auc += roc_auc_score(y, scores)
        ap += average_precision_score(y, scores)

    n = len(test_dataset)
    return total_loss / n, auc / n, ap / n

In [10]:
# TensorBoard event files for this run are written under runs/g2braph.
LOG_DIR = 'runs/g2braph'
writer = SummaryWriter(log_dir=LOG_DIR)

In [None]:
# 100 epochs: train on a freshly shuffled training split, evaluate on the
# validation split, let the plateau scheduler react to the validation loss,
# then log to TensorBoard and stdout.
# NOTE(review): train_loader / val_loader are not used here. Each call iterates
# the shuffled dataset itself, one data object per optimizer step. Switching to
# the loaders would change the effective batch size, so it is left as-is.
NUM_EPOCHS = 100
for epoch in range(1, NUM_EPOCHS + 1):
    shuffled_train = train_dataset.shuffle()
    train_loss = train(shuffled_train)
    shuffled_val = val_dataset.shuffle()
    val_loss, val_auc, val_ap = test(shuffled_val)

    scheduler.step(val_loss)

    scalars = {
        'Loss/train': train_loss,
        'Loss/validate': val_loss,
        'AUC/validate': val_auc,
        'AP/validate': val_ap,
    }
    for tag, value in scalars.items():
        writer.add_scalar(tag, value, epoch)

    print(f'{time.ctime()}  '
          f'Epoch: {epoch:03d}, Train Loss: {train_loss:.4f}, '
          f'Val Loss: {val_loss:.4f}, Val AUC: {val_auc:.4f}, '
          f'Val AP: {val_ap:.4f}')

Tue Dec 21 21:06:48 2021	Epoch: 001, Train Loss: 0.0489, Val Loss: 0.0320, Val AUC: 0.8876, Val AP: 0.7000
Tue Dec 21 21:07:49 2021	Epoch: 002, Train Loss: 0.0264, Val Loss: 0.0268, Val AUC: 0.9099, Val AP: 0.7237
Tue Dec 21 21:08:50 2021	Epoch: 003, Train Loss: 0.0218, Val Loss: 0.0243, Val AUC: 0.9125, Val AP: 0.7327
Tue Dec 21 21:09:51 2021	Epoch: 004, Train Loss: 0.0200, Val Loss: 0.0235, Val AUC: 0.9185, Val AP: 0.7421
Tue Dec 21 21:10:51 2021	Epoch: 005, Train Loss: 0.0190, Val Loss: 0.0230, Val AUC: 0.9200, Val AP: 0.7472
Tue Dec 21 21:11:51 2021	Epoch: 006, Train Loss: 0.0184, Val Loss: 0.0226, Val AUC: 0.9214, Val AP: 0.7460
Tue Dec 21 21:12:51 2021	Epoch: 007, Train Loss: 0.0180, Val Loss: 0.0235, Val AUC: 0.9149, Val AP: 0.7379
Tue Dec 21 21:13:51 2021	Epoch: 008, Train Loss: 0.0178, Val Loss: 0.0240, Val AUC: 0.9130, Val AP: 0.7325
Tue Dec 21 21:14:52 2021	Epoch: 009, Train Loss: 0.0176, Val Loss: 0.0243, Val AUC: 0.9168, Val AP: 0.7372
Tue Dec 21 21:15:51 2021	Epoch: 010, 

In [None]:
# Final evaluation on the held-out test split. The metrics are per-element
# averages, so evaluation order is irrelevant (beyond float rounding) and no
# shuffle is needed.
test_loss, test_auc, test_ap = test(test_dataset)
# Report the *test* loss; this line previously printed train_loss by mistake.
print(f'Test Loss: {test_loss:.4f}, Test AUC: {test_auc:.4f}, '
      f'Test AP: {test_ap:.4f}')

In [None]:
# Close the TensorBoard SummaryWriter opened for runs/g2braph now that training
# and the final test evaluation are finished.
writer.close()