In [1]:
import time
import torch

from phylognn_model import G2Braph
from gene_graph_dataset import G2BraphDataset

from torch_geometric.utils import degree
from torch_geometric.loader import DataLoader

import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau

from sklearn.metrics import roc_auc_score, average_precision_score

from torch.utils.tensorboard import SummaryWriter

In [2]:
# Fractions of the shuffled dataset assigned to the training and test splits;
# whatever is left over becomes the validation split.
train_p = 0.7
test_p = 0.2

# Mini-batch sizes for the three data loaders.
train_batch = 25
test_batch = 8
val_batch = 8

In [3]:
# Load the dataset and shuffle it once so the contiguous slices taken below
# are random subsets. NOTE(review): no RNG seed is set in this notebook, so the
# split presumably differs between runs -- confirm whether that is intended.
dataset = G2BraphDataset('dataset_g2b', 1000, 10).shuffle()
data_size = len(dataset)

# int() truncates toward zero, so each split size is rounded down.
train_size = int(data_size * train_p)
test_size = int(data_size * test_p)

In [4]:
# Three contiguous, non-overlapping slices of the shuffled dataset:
#   [0, train_size)        -> training
#   [train_size, test_end) -> test
#   [test_end, end)        -> validation (the remainder)
test_end = train_size + test_size

train_dataset = dataset[:train_size]
test_dataset = dataset[train_size:test_end]
val_dataset = dataset[test_end:]

In [5]:
# Only the training loader asks for shuffling; the evaluation loaders keep
# the loader's default ordering.
train_loader = DataLoader(
    train_dataset,
    batch_size=train_batch,
    shuffle=True,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=test_batch,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=val_batch,
)

In [6]:
# Degree histogram over the training graphs: deg[k] is the number of nodes
# that occur exactly k times in edge_index[1]. The histogram is passed to the
# model constructor further down.
deg = torch.zeros(5, dtype=torch.long)
for data in train_dataset:
    d = degree(data.edge_index[1].type(torch.int64),
               num_nodes=data.num_nodes, dtype=torch.long)
    hist = torch.bincount(d, minlength=deg.numel())
    if hist.numel() > deg.numel():
        # bincount returns max(d) + 1 bins when that exceeds minlength, so a
        # node of degree >= len(deg) would make the in-place add below fail
        # with a shape mismatch. Widen the running histogram with zero bins
        # instead; existing counts are preserved.
        # NOTE(review): assumes G2Braph accepts a histogram longer than 5
        # bins -- confirm against phylognn_model.
        deg = torch.cat([deg, deg.new_zeros(hist.numel() - deg.numel())])
    deg += hist

In [7]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# The model is built from the degree histogram computed above.
model = G2Braph(deg).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Lower the learning rate by a factor of 0.5 when the monitored quantity
# (stepped with the validation loss in the training loop) stops decreasing,
# with a patience of 20 and a floor of 1e-5.
scheduler = ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=20,
    min_lr=1e-5,
)

In [8]:
def train(train_dataset):
    """Run one optimisation pass over ``train_dataset`` and return the mean loss.

    Relies on the module-level ``model``, ``optimizer`` and ``device``.

    Args:
        train_dataset: Sized iterable of batches exposing ``x``,
            ``edge_index`` and ``node_label`` and supporting ``.to(device)``
            (the notebook passes ``train_loader``).

    Returns:
        The sum of the per-batch binary cross-entropy losses divided by
        ``len(train_dataset)``, as a 0-dim tensor detached from autograd.

    Raises:
        ZeroDivisionError: If ``train_dataset`` is empty.
    """
    model.train()

    total_loss = 0
    for data in train_dataset:
        data = data.to(device)
        optimizer.zero_grad()

        # The model output is squeezed and used directly as probabilities by
        # binary_cross_entropy; labels are cast to float to match.
        res = model(data.x, data.edge_index, None, None)
        loss = F.binary_cross_entropy(res.squeeze(), data.node_label.to(torch.float))
        loss.backward()
        optimizer.step()

        # Accumulate a detached copy. Adding the live loss tensor would chain
        # every batch's autograd history into the running total and keep it
        # reachable until the epoch ends.
        total_loss += loss.detach()

    return total_loss / len(train_dataset)

In [9]:
@torch.no_grad()
def validate(test_dataset):
    """Evaluate the module-level ``model`` and return batch-averaged metrics.

    Gradient tracking is disabled for the whole call and the model is put in
    evaluation mode.

    Args:
        test_dataset: Sized iterable of batches exposing ``x``,
            ``edge_index`` and ``node_label`` (the notebook passes
            ``val_loader``).

    Returns:
        A ``(loss, auc, ap)`` tuple: binary cross-entropy, ROC AUC and average
        precision, each computed per batch, summed, and divided by
        ``len(test_dataset)``.
    """
    model.eval()

    loss_sum, auc_sum, ap_sum = 0, 0, 0
    for batch in test_dataset:
        batch = batch.to(device)
        scores = model(batch.x, batch.edge_index, None, None).squeeze()
        labels = batch.node_label

        # scikit-learn metrics work on host-side NumPy arrays.
        labels_np = labels.cpu().numpy()
        scores_np = scores.cpu().numpy()

        loss_sum += F.binary_cross_entropy(scores, labels.to(torch.float))
        auc_sum += roc_auc_score(labels_np, scores_np)
        ap_sum += average_precision_score(labels_np, scores_np)

    n_batches = len(test_dataset)
    return loss_sum / n_batches, auc_sum / n_batches, ap_sum / n_batches

In [10]:
@torch.no_grad()
def test(test_dataset):
    """Evaluate the module-level ``model`` and return un-aggregated metrics.

    Unlike ``validate``, nothing is averaged: one value per iterated item is
    collected for each metric.

    Args:
        test_dataset: Iterable of items exposing ``x``, ``edge_index`` and
            ``node_label`` (the notebook passes the shuffled test dataset
            itself rather than a loader).

    Returns:
        A ``(loss, auc, ap)`` tuple of lists, in iteration order: binary
        cross-entropy losses (tensors), ROC AUC scores and average precision
        scores.
    """
    model.eval()

    losses = []
    aucs = []
    aps = []
    for item in test_dataset:
        item = item.to(device)
        scores = model(item.x, item.edge_index, None, None).squeeze()
        labels = item.node_label

        # scikit-learn metrics work on host-side NumPy arrays.
        labels_np = labels.cpu().numpy()
        scores_np = scores.cpu().numpy()

        losses.append(F.binary_cross_entropy(scores, labels.to(torch.float)))
        aucs.append(roc_auc_score(labels_np, scores_np))
        aps.append(average_precision_score(labels_np, scores_np))

    return losses, aucs, aps

In [11]:
# TensorBoard event writer for this run; the training loop logs its scalars here.
writer = SummaryWriter('runs/g2braph_batch_2')

In [12]:
# Train for 50 epochs, validating after every epoch.
for epoch in range(1, 51):
    train_loss = train(train_loader)
    val_loss, val_auc, val_ap = validate(val_loader)

    # The plateau scheduler is driven by the validation loss.
    scheduler.step(val_loss)

    # One TensorBoard scalar per metric, indexed by epoch.
    scalars = {
        'Loss/train': train_loss,
        'Loss/validate': val_loss,
        'AUC/validate': val_auc,
        'AP/validate': val_ap,
    }
    for tag, value in scalars.items():
        writer.add_scalar(tag, value, epoch)

    # Console progress only on every tenth epoch.
    if epoch % 10 != 0:
        continue
    print(f'{time.ctime()}  '
          f'Epoch: {epoch:04d}, Train Loss: {train_loss:.4f}, '
          f'Val Loss: {val_loss:.4f}, Val AUC: {val_auc:.4f}, '
          f'Val AP: {val_ap:.4f}')

Wed Dec 22 16:34:33 2021  Epoch: 0010, Train Loss: 0.0280, Val Loss: 0.0382, Val AUC: 0.8421, Val AP: 0.6103
Wed Dec 22 16:39:10 2021  Epoch: 0020, Train Loss: 0.0222, Val Loss: 0.0288, Val AUC: 0.9136, Val AP: 0.6887
Wed Dec 22 16:43:47 2021  Epoch: 0030, Train Loss: 0.0194, Val Loss: 0.0302, Val AUC: 0.9133, Val AP: 0.6650
Wed Dec 22 16:48:24 2021  Epoch: 0040, Train Loss: 0.0162, Val Loss: 0.0302, Val AUC: 0.9002, Val AP: 0.6814
Wed Dec 22 16:53:01 2021  Epoch: 0050, Train Loss: 0.0141, Val Loss: 0.0328, Val AUC: 0.9107, Val AP: 0.6731


In [13]:
# Evaluate on the held-out test split, one dataset item at a time (the
# dataset itself is iterated here, not test_loader).
test_loss, test_auc, test_ap = test(test_dataset.shuffle())

# Render each metric list as space-separated values with six decimals.
loss_line, auc_line, ap_line = (
    ' '.join(f'{x:.6f}' for x in metric)
    for metric in (test_loss, test_auc, test_ap)
)
print('Test Loss: \t', loss_line, '\n'
      'Test AUC: \t', auc_line, '\n'
      'Test AP: \t', ap_line)

Test Loss: 	 0.005354 0.050691 0.030682 0.045348 0.019017 0.051243 0.031681 0.012724 0.039103 0.036835 0.008026 0.058381 0.005135 0.024380 0.007817 0.004937 0.057609 0.005749 0.019348 0.038373 0.065024 0.098988 0.068290 0.024826 0.016429 0.049858 0.009597 0.031701 0.030388 0.034998 0.039312 0.055404 0.006902 0.032704 0.015426 0.047623 0.011754 0.031844 0.016085 0.016053 0.051114 0.015186 0.011782 0.034021 0.053946 0.011753 0.042085 0.043892 0.018423 0.031375 0.024189 0.072504 0.003953 0.021340 0.047164 0.050099 0.031879 0.024087 0.042593 0.021698 0.028279 0.076137 0.056587 0.060260 0.011728 0.035763 0.046189 0.043487 0.011263 0.022523 0.039387 0.013756 0.006370 0.042798 0.005843 0.013248 0.043271 0.047642 0.047707 0.022073 0.042018 0.032200 0.005220 0.062464 0.043325 0.034163 0.018414 0.052686 0.035485 0.013641 0.069731 0.003919 0.040806 0.051475 0.023486 0.043519 0.012952 0.012336 0.059805 0.053709 0.030230 0.038454 0.032066 0.023047 0.032180 0.006966 0.018665 0.023925 0.005592 0.0256

In [14]:
# Close the TensorBoard writer now that training and evaluation are done.
writer.close()

In [15]:
# Mean of the average-precision values returned by test(); the bare last
# expression is what the cell displays.
ap_values = torch.tensor(test_ap).cpu().numpy()
ap_values.mean()

0.7040116044142253