In [1]:
import time
import torch

from phylognn_model import G2Braph_GCNConv
from gene_graph_dataset import G2BraphDataset

from torch_geometric.utils import degree
from torch_geometric.loader import DataLoader

import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau

from sklearn.metrics import roc_auc_score, average_precision_score

from torch.utils.tensorboard import SummaryWriter

In [2]:
# Fractions of the dataset used for the training and test splits;
# whatever is left over after both becomes the validation split.
train_p = 0.7
test_p = 0.2

# Mini-batch sizes for each DataLoader.
train_batch = 25
test_batch = 8
val_batch = 8

In [3]:
# Build (or load) the dataset and shuffle it once so the slices taken later
# are random subsets.
dataset = G2BraphDataset('dataset_g2b', 10, 10).shuffle()
data_size = len(dataset)

# Truncate towards zero so both split sizes are whole numbers of graphs.
train_size = int(data_size * train_p)
test_size = int(data_size * test_p)

Generating...
Processing...
Done!


In [4]:
# Show how many graphs the dataset contains.
print(data_size)

1000


In [5]:
# Carve the shuffled dataset into three contiguous, non-overlapping slices:
# [0, train_size) -> train, the next test_size graphs -> test, remainder -> validation.
train_dataset = dataset[:train_size]
test_dataset = dataset[train_size:train_size + test_size]
val_dataset = dataset[train_size + test_size:]

In [6]:
# Only the training loader reshuffles each epoch; the evaluation loaders
# keep the dataset order.
train_loader = DataLoader(
    train_dataset,
    batch_size=train_batch,
    shuffle=True,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=test_batch,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=val_batch,
)

In [7]:
# deg = torch.zeros(5, dtype=torch.long)
# for data in train_dataset:
#     d = degree(data.edge_index[1].type(torch.int64), 
#                num_nodes=data.num_nodes, dtype=torch.long)
#     deg += torch.bincount(d, minlength=deg.numel())

In [8]:
# Index of the CUDA device to use when a GPU is available.
gpuid = 0

In [9]:
# Use the selected GPU when CUDA is present, otherwise fall back to the CPU.
device = torch.device(f'cuda:{gpuid}' if torch.cuda.is_available() else 'cpu')

model = G2Braph_GCNConv().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

# Halve the learning rate once the monitored value (the validation loss in the
# training loop) has not improved for 10 epochs, never going below 1e-5.
scheduler = ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=10,
    min_lr=1e-5,
)

In [10]:
def train(train_dataset):
    """Run one optimisation pass over ``train_dataset`` and return the mean batch loss.

    Uses the notebook-level ``model``, ``optimizer`` and ``device``.

    Parameters
    ----------
    train_dataset : sized iterable of batches
        Despite the name, the training cell passes a ``DataLoader`` here.
        Every item must support ``.to(device)`` and expose ``x``,
        ``edge_index`` and ``node_label``.

    Returns
    -------
    float
        Sum of the per-batch binary cross-entropy losses divided by
        ``len(train_dataset)`` (i.e. the number of batches).
    """
    model.train()

    total_loss = 0.0
    for data in train_dataset:
        data = data.to(device)
        optimizer.zero_grad()

        res = model(data.x, data.edge_index)
        loss = F.binary_cross_entropy(res.squeeze(), data.node_label.to(torch.float))
        loss.backward()
        optimizer.step()

        # Accumulate a plain Python number: adding the tensor itself would keep
        # every batch's autograd history referenced until the epoch ends and
        # would hand a device tensor back to the caller.
        total_loss += loss.item()

    return total_loss / len(train_dataset)

In [11]:
@torch.no_grad()
def validate(test_dataset):
    """Return the mean per-batch binary cross-entropy of ``model`` on ``test_dataset``.

    Runs with gradients disabled and the notebook-level ``model`` in eval mode.

    Parameters
    ----------
    test_dataset : sized iterable of batches
        The training cell passes the validation ``DataLoader`` here. Every
        item must support ``.to(device)`` and expose ``x``, ``edge_index``
        and ``node_label``.

    Returns
    -------
    torch.Tensor
        Zero-dimensional tensor: the summed batch losses divided by
        ``len(test_dataset)``.
    """
    model.eval()

    batch_losses = []
    for batch in test_dataset:
        batch = batch.to(device)
        scores = model(batch.x, batch.edge_index).squeeze()
        targets = batch.node_label.to(torch.float)
        batch_losses.append(F.binary_cross_entropy(scores, targets))

    # The built-in sum starts from the int 0 and adds the losses in order.
    return sum(batch_losses) / len(test_dataset)

In [12]:
@torch.no_grad()
def test(test_dataset):
    model.eval()
    
    auc, ap, counter = 0, 0, 0
    for data in test_dataset:
        data = data.to(device)
        res = model(data.x, data.edge_index)
        
        y, pred = data.node_label.cpu().numpy(), res.squeeze().cpu().numpy()
        if y.sum() == 0 or y.sum() == len(y):
            continue
        counter += 1
        auc += roc_auc_score(y, pred)
        ap += average_precision_score(y, pred)
        
    return auc/counter, ap/counter

In [13]:
# TensorBoard writer for this run; scalars logged below land in this directory.
writer = SummaryWriter(log_dir='runs_g2b_10/1000_gcn_run5')

In [14]:
# Main training loop: 100 epochs of train -> validate -> LR scheduling ->
# TensorBoard logging, with a console summary every 10th epoch.
for epoch in range(1, 101):
    train_loss = train(train_loader)
    val_loss = validate(val_loader)
    
    # The plateau scheduler monitors the validation loss.
    scheduler.step(val_loss)
    
    writer.add_scalar('Loss/train', train_loss, epoch)
    writer.add_scalar('Loss/validate', val_loss, epoch)
    
    # NOTE(review): these metrics are computed on test_dataset (the dataset
    # slice, not test_loader) yet are logged under '*/validate' tags --
    # confirm whether the tag names or the evaluated split is the intended one.
    auc, ap = test(test_dataset)
    writer.add_scalar('AUC/validate', auc, epoch)
    writer.add_scalar('AP/validate', ap, epoch)
    
    if epoch % 10 == 0:
        print(f'{time.ctime()}  '
              f'Epoch: {epoch:04d}, train Loss: {train_loss:.4f}, '
              f'val Loss: {val_loss:.4f}, auc: {auc:.4f}, '
              f'ap: {ap:.4f}')

Fri Dec 31 12:06:11 2021  Epoch: 0010, train Loss: 0.5035, val Loss: 0.5147, auc: 0.8135, ap: 0.8599
Fri Dec 31 12:06:30 2021  Epoch: 0020, train Loss: 0.4404, val Loss: 0.4754, auc: 0.8547, ap: 0.8931
Fri Dec 31 12:06:49 2021  Epoch: 0030, train Loss: 0.3815, val Loss: 0.4919, auc: 0.8631, ap: 0.9016
Fri Dec 31 12:07:07 2021  Epoch: 0040, train Loss: 0.3657, val Loss: 0.4624, auc: 0.8738, ap: 0.9101
Fri Dec 31 12:07:27 2021  Epoch: 0050, train Loss: 0.3508, val Loss: 0.4382, auc: 0.8690, ap: 0.9082
Fri Dec 31 12:07:46 2021  Epoch: 0060, train Loss: 0.3297, val Loss: 0.4422, auc: 0.8842, ap: 0.9175
Fri Dec 31 12:08:06 2021  Epoch: 0070, train Loss: 0.3026, val Loss: 0.4311, auc: 0.8899, ap: 0.9205
Fri Dec 31 12:08:25 2021  Epoch: 0080, train Loss: 0.2922, val Loss: 0.4331, auc: 0.8882, ap: 0.9179
Fri Dec 31 12:08:45 2021  Epoch: 0090, train Loss: 0.2861, val Loss: 0.4425, auc: 0.8908, ap: 0.9212
Fri Dec 31 12:09:05 2021  Epoch: 0100, train Loss: 0.2768, val Loss: 0.4416, auc: 0.8908, a

In [15]:
# Flush pending events and release the TensorBoard log file.
writer.close()

In [21]:
# Sanity check: for every test graph, print the sum of the model's per-node
# outputs next to the sum of the node labels.
model.eval()
with torch.no_grad():  # inference only: do not build autograd graphs
    for tld in test_dataset:
        tld = tld.to(device)
        res = model(tld.x, tld.edge_index)
        print(res.squeeze().sum().item(), tld.node_label.sum().item())

11.466941833496094 12.0
12.003521919250488 11.0
12.3941650390625 11.0
13.10908317565918 18.0
8.639545440673828 6.0
13.911587715148926 17.0
14.392096519470215 16.0
12.558234214782715 10.0
11.795515060424805 11.0
16.90052032470703 20.0
12.111684799194336 12.0
15.301536560058594 16.0
15.997536659240723 17.0
10.294456481933594 11.0
16.00666046142578 18.0
9.794851303100586 11.0
12.918076515197754 11.0
8.472949028015137 7.0
7.614141941070557 4.0
16.88071060180664 19.0
11.560836791992188 11.0
14.556867599487305 16.0
15.653470993041992 18.0
9.923681259155273 13.0
14.836398124694824 18.0
15.461889266967773 19.0
7.731366157531738 4.0
12.946213722229004 15.0
7.93955135345459 3.0
11.71007251739502 10.0
10.144815444946289 12.0
15.797547340393066 18.0
7.550556182861328 3.0
16.488765716552734 19.0
10.833463668823242 13.0
12.03402328491211 10.0
12.475933074951172 12.0
10.929224014282227 14.0
17.793888092041016 15.0
13.447530746459961 13.0
9.8836669921875 6.0
13.130849838256836 15.0
10.616541862487793 

In [23]:
# For every test graph, print its graph-level target `y` next to the sum of
# its node labels.
for tld in test_dataset:
    print(tld.y.item(), tld.node_label.sum().item())

5 12.0
4 11.0
4 11.0
8 18.0
2 6.0
9 17.0
7 16.0
4 10.0
4 11.0
10 20.0
6 12.0
7 16.0
10 17.0
4 11.0
7 18.0
5 11.0
6 11.0
2 7.0
1 4.0
9 19.0
5 11.0
8 16.0
8 18.0
8 13.0
10 18.0
9 19.0
1 4.0
7 15.0
1 3.0
7 10.0
5 12.0
9 18.0
1 3.0
9 19.0
6 13.0
3 10.0
4 12.0
10 14.0
5 15.0
6 13.0
2 6.0
8 15.0
8 13.0
7 18.0
4 11.0
8 15.0
10 19.0
9 16.0
2 8.0
3 9.0
2 8.0
2 5.0
6 12.0
6 12.0
6 12.0
7 12.0
1 3.0
10 18.0
1 3.0
2 6.0
9 18.0
6 18.0
4 11.0
7 14.0
1 3.0
7 18.0
3 7.0
2 7.0
7 18.0
7 16.0
2 8.0
7 20.0
4 10.0
2 6.0
7 17.0
5 14.0
9 18.0
10 16.0
7 14.0
6 13.0
3 7.0
2 6.0
1 4.0
10 18.0
3 9.0
10 18.0
9 16.0
5 15.0
1 3.0
8 18.0
6 12.0
6 13.0
3 9.0
7 14.0
1 4.0
1 4.0
3 10.0
9 16.0
5 11.0
6 17.0
4 11.0
9 16.0
8 16.0
10 18.0
4 7.0
8 20.0
7 14.0
6 14.0
7 17.0
1 4.0
3 10.0
7 18.0
2 7.0
8 15.0
3 8.0
8 18.0
5 11.0
4 11.0
3 9.0
5 15.0
8 14.0
1 4.0
7 17.0
1 4.0
4 12.0
2 6.0
10 18.0
2 7.0
9 18.0
10 19.0
9 16.0
9 17.0
3 11.0
3 10.0
5 14.0
2 8.0
5 14.0
10 16.0
2 7.0
3 9.0
4 10.0
9 14.0
4 11.0
6 11.0
10 17.0
5 12.0
2 7