In [1]:
import time

import torch
from torch.nn import ModuleList, Embedding
from torch.nn import Sequential, ReLU, Linear
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torch_geometric.transforms as T

from torch_geometric.nn import GCNConv, VGAE, PNAConv, BatchNorm
from torch_geometric.loader import DataLoader

from gene_graph_dataset import G2GraphDataset

from torch_geometric.utils import (degree,
                                   negative_sampling, 
                                   remove_self_loops,
                                   add_self_loops)

from torch.utils.tensorboard import SummaryWriter

In [2]:
# Run configuration.
gpuid = 0  # CUDA device index; only used when CUDA is available

# Fractions of the dataset used for training and testing.
train_p, test_p = 0.8, 0.2

# Seed torch's global RNG so that dataset.shuffle(), DataLoader shuffling and
# model initialisation are reproducible across Restart & Run All.
seed = 0
torch.manual_seed(seed)

In [3]:
# Use the configured GPU when CUDA is present, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:' + str(gpuid))
else:
    device = torch.device('cpu')

In [4]:
# Load the gene-graph dataset; 'dataset_g2g' is presumably the dataset root dir.
# NOTE(review): the meaning of the positional arguments 100 and 3 is defined by
# G2GraphDataset (not visible here) -- TODO confirm and pass them by keyword.
dataset = G2GraphDataset('dataset_g2g', 100, 3)

In [5]:
data_size = len(dataset)
# Split sizes. The test size is taken from the cumulative train+test boundary
# instead of truncating data_size * test_p on its own: truncating both products
# independently can silently drop a sample even when train_p + test_p == 1
# (e.g. 7 graphs -> int(5.6) + int(1.4) = 6). If train_p + test_p < 1, the
# leftover graphs still end up after the test slice.
train_size = int(data_size * train_p)
test_size = round(data_size * (train_p + test_p)) - train_size

In [6]:
# Shuffle once, then take contiguous train / test slices of the permutation.
dataset = dataset.shuffle()
train_dataset, test_dataset = (dataset[:train_size],
                               dataset[train_size:train_size + test_size])
# Anything past train_size + test_size is unused; a validation split could take
# dataset[train_size + test_size:].

In [7]:
# Mini-batch loaders: the training loader reshuffles every epoch, the test
# loader iterates in a fixed order.
train_loader = DataLoader(train_dataset,
                          shuffle=True,
                          batch_size=64)
test_loader = DataLoader(test_dataset,
                         batch_size=64)

In [8]:
# In-degree histogram of the training graphs for PNAConv's degree scalers:
# deg[k] is the number of training nodes with in-degree k.
# The histogram is sized from the observed maximum in-degree. A fixed length
# (previously 5) makes the `+=` below fail with a shape mismatch as soon as any
# node has in-degree >= 5, because torch.bincount grows past `minlength`.
in_degrees = [degree(data.edge_index[1].type(torch.int64),
                     num_nodes=data.num_nodes, dtype=torch.long)
              for data in train_dataset]
# Graphs without nodes are skipped: .max() of an empty tensor raises.
max_degree = max((int(d.max()) for d in in_degrees if d.numel() > 0), default=-1)

deg = torch.zeros(max_degree + 1, dtype=torch.long)
for d in in_degrees:
    deg += torch.bincount(d, minlength=deg.numel())

In [9]:
class EncoderNet(torch.nn.Module):
    """Variational encoder for ``VGAE``: integer node attributes -> (mu, logstd).

    Each node's ``feats_per_node`` integer attributes are looked up in one shared
    embedding table, concatenated, and projected back to ``hidden_channels`` by
    ``pre_lin``. Then ``num_layers`` PNAConv -> BatchNorm -> ReLU layers follow,
    and two GCNConv heads map the result to ``out_channels`` (mean and log-std).

    The defaults reproduce the original hard-coded architecture
    (21 embeddings, width 75, 2 attributes per node, 4 PNA layers).

    Args:
        in_channels: Kept for interface compatibility; not used. The input width
            is determined by ``feats_per_node`` and ``hidden_channels``.
            NOTE(review): the caller passes ``dataset.num_features`` -- confirm it
            matches ``feats_per_node``.
        out_channels: Dimensionality of the latent ``mu`` / ``logstd``.
        deg_histogram: In-degree histogram for PNAConv's degree scalers. When
            ``None`` (default) the notebook-level ``deg`` is used, as before.
        num_embeddings: Number of rows in the attribute embedding table; attribute
            values must lie in ``[0, num_embeddings)``.
        hidden_channels: Embedding width and hidden width of every PNA layer
            (PNAConv with 5 towers requires it to be divisible by 5).
        feats_per_node: Number of integer attributes embedded per node.
        num_layers: Number of PNAConv + BatchNorm layers.
    """

    def __init__(self, in_channels, out_channels, deg_histogram=None,
                 num_embeddings=21, hidden_channels=75, feats_per_node=2,
                 num_layers=4):
        super().__init__()
        if deg_histogram is None:
            # Backward-compatible fallback to the global histogram.
            deg_histogram = deg

        self.feats_per_node = feats_per_node
        self.hidden_channels = hidden_channels

        # Submodules are created in the same order as the original
        # implementation, so seeded parameter initialisation is unchanged.
        self.node_emb = Embedding(num_embeddings, hidden_channels)

        aggregators = ['mean', 'min', 'max', 'std']
        scalers = ['identity', 'amplification', 'attenuation']

        self.conv_mu = GCNConv(hidden_channels, out_channels)
        self.conv_logstd = GCNConv(hidden_channels, out_channels)

        self.convs = ModuleList()
        self.batch_norms = ModuleList()
        for _ in range(num_layers):
            conv = PNAConv(in_channels=hidden_channels,
                           out_channels=hidden_channels,
                           aggregators=aggregators, scalers=scalers,
                           deg=deg_histogram,
                           towers=5, pre_layers=1, post_layers=1,
                           divide_input=False)
            self.convs.append(conv)
            self.batch_norms.append(BatchNorm(hidden_channels))

        # Projects the concatenated per-attribute embeddings back to hidden width.
        self.pre_lin = Linear(feats_per_node * hidden_channels, hidden_channels)

    def forward(self, x, edge_index):
        """Encode a (batched) graph.

        Args:
            x: Integer-valued node attributes, ``feats_per_node`` per node.
                NOTE(review): presumably shaped (num_nodes, feats_per_node),
                possibly with extra singleton dims removed by ``squeeze`` --
                confirm against the dataset.
            edge_index: Graph connectivity passed to every convolution.

        Returns:
            Tuple ``(mu, logstd)`` produced by the two GCN heads.
        """
        # Embed every attribute, then flatten so each row holds one node's
        # concatenated embeddings (feats_per_node * hidden_channels values).
        x = self.node_emb(x.squeeze().long())
        x = x.reshape(-1, self.feats_per_node * self.hidden_channels)
        x = self.pre_lin(x)

        for conv, batch_norm in zip(self.convs, self.batch_norms):
            x = batch_norm(conv(x, edge_index)).relu()

        return self.conv_mu(x, edge_index), self.conv_logstd(x, edge_index)

In [10]:
in_channels = dataset.num_features  # node-feature count reported by the dataset
out_channels = 16                   # latent (z) dimensionality of the VGAE

In [11]:
# VGAE wrapping the PNA-based encoder, moved to the selected device.
model = VGAE(EncoderNet(in_channels, out_channels)).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# Halve the learning rate after 10 epochs without improvement of the monitored
# value (minimised), never going below 1e-5.
scheduler = ReduceLROnPlateau(optimizer, 'min',
                              factor=0.5,
                              patience=10,
                              min_lr=1e-5)

In [12]:
# TensorBoard event files for this run are written under runs/g2g_100_nattr.
writer = SummaryWriter('runs/g2g_100_nattr')

In [13]:
def train(train_loader):
    """Run one training epoch of the notebook-level VGAE ``model``.

    Uses the global ``model``, ``optimizer`` and ``device``. Each batch is scored
    with the reconstruction loss on its positive / negative supervision edges,
    plus the KL term scaled by ``1 / num_nodes``.

    Args:
        train_loader: Iterable of batches exposing ``x``, ``edge_index``,
            ``pos_edge_label_index``, ``neg_edge_label_index`` and ``num_nodes``.

    Returns:
        float: Mean per-batch loss over the epoch (not weighted by batch size).
    """
    model.train()

    total_loss = 0.0
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()

        z = model.encode(data.x, data.edge_index)
        loss = model.recon_loss(z, data.pos_edge_label_index,
                                data.neg_edge_label_index)
        loss = loss + (1 / data.num_nodes) * model.kl_loss()
        loss.backward()
        optimizer.step()

        # Accumulate a Python float: adding the loss tensor itself would keep
        # extending its autograd history across the whole epoch.
        total_loss += loss.item()
    return total_loss / len(train_loader)

@torch.no_grad()
def test(test_loader):
    """Evaluate link prediction of the global ``model`` without tracking gradients.

    Args:
        test_loader: Iterable of batches with ``x``, ``edge_index``,
            ``pos_edge_label_index`` and ``neg_edge_label_index``.

    Returns:
        tuple: ``(auc, ap)``, each the unweighted mean of the per-batch scores
        returned by ``model.test``.
    """
    model.eval()
    auc_sum = ap_sum = 0

    for batch in test_loader:
        batch = batch.to(device)
        z = model.encode(batch.x, batch.edge_index)
        batch_auc, batch_ap = model.test(z, batch.pos_edge_label_index,
                                         batch.neg_edge_label_index)
        auc_sum, ap_sum = auc_sum + batch_auc, ap_sum + batch_ap

    n_batches = len(test_loader)
    return auc_sum / n_batches, ap_sum / n_batches

In [14]:
# Train for 2000 epochs. The scheduler minimises its metric, so it receives
# 1 - AUC. Metrics go to TensorBoard every epoch; a summary prints every 50.
for epoch in range(1, 2001):
    loss = train(train_loader)
    auc, ap = test(test_loader)

    scheduler.step(1 - auc)

    writer.add_scalar('Loss/train', loss, epoch)
    writer.add_scalar('AUC/test', auc, epoch)
    writer.add_scalar('AP/test', ap, epoch)

    if not epoch % 50:
        print(f'{time.ctime()}\tEpoch: {epoch:03d}, loss: {loss:.4f}, '
              f'AUC: {auc:.4f}, AP: {ap:.4f}')

Tue Dec 14 10:12:51 2021	Epoch: 050, loss: 1.0725, AUC: 0.8778, AP: 0.0246
Tue Dec 14 10:16:00 2021	Epoch: 100, loss: 1.0536, AUC: 0.8897, AP: 0.0268
Tue Dec 14 10:19:09 2021	Epoch: 150, loss: 1.0463, AUC: 0.9033, AP: 0.0293
Tue Dec 14 10:22:18 2021	Epoch: 200, loss: 1.0449, AUC: 0.9041, AP: 0.0296
Tue Dec 14 10:25:27 2021	Epoch: 250, loss: 1.0452, AUC: 0.9043, AP: 0.0297
Tue Dec 14 10:28:38 2021	Epoch: 300, loss: 1.0438, AUC: 0.9043, AP: 0.0297
Tue Dec 14 10:31:49 2021	Epoch: 350, loss: 1.0452, AUC: 0.9042, AP: 0.0297
Tue Dec 14 10:34:59 2021	Epoch: 400, loss: 1.0443, AUC: 0.9045, AP: 0.0298
Tue Dec 14 10:38:11 2021	Epoch: 450, loss: 1.0433, AUC: 0.9046, AP: 0.0298
Tue Dec 14 10:41:49 2021	Epoch: 500, loss: 1.0442, AUC: 0.9046, AP: 0.0298
Tue Dec 14 10:45:27 2021	Epoch: 550, loss: 1.0429, AUC: 0.9047, AP: 0.0299
Tue Dec 14 10:49:05 2021	Epoch: 600, loss: 1.0438, AUC: 0.9047, AP: 0.0299
Tue Dec 14 10:52:43 2021	Epoch: 650, loss: 1.0424, AUC: 0.9049, AP: 0.0300
Tue Dec 14 10:56:20 2021	

In [15]:
# Flush any pending TensorBoard events and close the writer.
writer.close()