In [1]:
import time

import torch

import torch.nn.functional as F
from torch.nn import ModuleList, Embedding
from torch.nn import Sequential, ReLU, Linear
from torch.optim.lr_scheduler import ReduceLROnPlateau

from torch_geometric.utils import degree
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, PNAConv, BatchNorm, global_add_pool

from gene_graph_dataset import GeneGraphDataset

from torch.utils.tensorboard import SummaryWriter

In [2]:
# Fractions of the full dataset assigned to the training and test splits.
train_p = 0.7
test_p = 0.2

In [3]:
# Load the gene-graph dataset and turn the split fractions into absolute
# sample counts (int() truncates, so rounding leftovers are not counted here).
dataset = GeneGraphDataset('dataset', 100, 5)
data_size = len(dataset)
train_size = int(data_size * train_p)
test_size = int(data_size * test_p)

In [4]:
# Seed torch's global RNG before shuffling so that the membership of the
# train / test / validation splits is the same after every kernel restart.
# Without a seed, each fresh run evaluates on a different test set.
# NOTE(review): assumes GeneGraphDataset.shuffle() draws its permutation from
# torch's global RNG (as torch_geometric's Dataset.shuffle does) -- confirm.
# Side effect: later random draws in the notebook also become reproducible.
torch.manual_seed(0)

dataset = dataset.shuffle()

# Contiguous slices of the shuffled dataset: [train | test | validation].
test_end = train_size + test_size
train_dataset = dataset[:train_size]
test_dataset = dataset[train_size:test_end]
val_dataset = dataset[test_end:]

In [5]:
# len(train_dataset), len(test_dataset), len(val_dataset)

In [6]:
# Mini-batch loaders. Only the training split is reshuffled every epoch; the
# evaluation splits are iterated in their stored order with larger batches.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=64)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=128)
val_loader = DataLoader(val_dataset, shuffle=False, batch_size=128)

In [7]:
# len(train_loader), len(test_loader), len(val_loader)

In [8]:
# Degree histogram over the training graphs: deg[k] is the number of nodes
# whose index appears exactly k times in edge_index[1] (the target row under
# PyG's convention). Net passes this tensor to every PNAConv as `deg`.
deg = torch.zeros(5, dtype=torch.long)
for data in train_dataset:
    d = degree(data.edge_index[1].type(torch.int64),
               num_nodes=data.num_nodes, dtype=torch.long)
    hist = torch.bincount(d, minlength=deg.numel())
    if hist.numel() > deg.numel():
        # A node with degree >= len(deg) makes bincount return more bins than
        # the accumulator holds; zero-pad the accumulator instead of letting
        # the in-place add fail with a shape mismatch.
        deg = torch.cat([deg, deg.new_zeros(hist.numel() - deg.numel())])
    deg += hist

In [9]:
class Net(torch.nn.Module):
    """Graph-level regressor: embeddings -> 4 x (PNAConv + BatchNorm + ReLU)
    -> sum pooling -> MLP with a single output per graph.

    The PNA layers read the module-level degree histogram ``deg``, so that
    tensor must exist before the class is instantiated.
    """

    def __init__(self):
        super().__init__()

        # Lookup tables for categorical node (21 ids) and edge (4 ids) codes.
        self.node_emb = Embedding(21, 75)
        self.edge_emb = Embedding(4, 25)

        # Settings shared by every PNA layer.
        pna_kwargs = dict(
            in_channels=75,
            out_channels=75,
            aggregators=['mean', 'min', 'max', 'std'],
            scalers=['identity', 'amplification', 'attenuation'],
            deg=deg,
            edge_dim=50,
            towers=5,
            pre_layers=1,
            post_layers=1,
            divide_input=False,
        )

        self.convs = ModuleList()
        self.batch_norms = ModuleList()
        for _ in range(4):
            # Alternative kept from the original notebook: a plain
            # GCNConv(in_channels=75, out_channels=75), called without
            # edge_attr in forward().
            self.convs.append(PNAConv(**pna_kwargs))
            self.batch_norms.append(BatchNorm(75))

        # Projects the 150-wide flattened node embedding down to 75 channels.
        self.pre_lin = Linear(150, 75)

        self.mlp = Sequential(
            Linear(75, 50), ReLU(),
            Linear(50, 25), ReLU(),
            Linear(25, 1),
        )

    def forward(self, x, edge_index, edge_attr, batch):
        """Return one prediction row per graph in the batch.

        ``x`` and ``edge_attr`` hold integer codes for the embedding tables.
        NOTE(review): the reshapes to 150 = 2 * 75 and 50 = 2 * 25 suggest two
        codes per node and per edge -- confirm against the dataset.
        """
        # Embed every code, then flatten the embeddings into one row each.
        node_feat = self.node_emb(x.squeeze()).reshape(-1, 150)
        h = self.pre_lin(node_feat)

        edge_feat = self.edge_emb(edge_attr).reshape(-1, 50)

        for idx in range(len(self.convs)):
            h = self.convs[idx](h, edge_index, edge_feat)
            h = self.batch_norms[idx](h).relu()

        # Sum node states per graph, then regress a single value.
        pooled = global_add_pool(h, batch)
        return self.mlp(pooled)

In [10]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

model = Net().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Halve the learning rate after 20 scheduler steps without improvement of the
# monitored (minimised) metric, never going below 1e-5.
scheduler = ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=20,
    min_lr=1e-5,
)

In [11]:
def train(train_loader):
    """Run one optimisation epoch and return the mean L1 loss per graph.

    Uses the module-level ``model``, ``optimizer`` and ``device``.

    Args:
        train_loader: iterable of batches exposing ``x``, ``edge_index``,
            ``edge_attr``, ``batch``, ``y``, ``num_graphs`` and ``to()``;
            ``train_loader.dataset`` must support ``len()``.

    Returns:
        Sum of the per-batch mean absolute errors, each weighted by the
        batch's ``num_graphs``, divided by ``len(train_loader.dataset)``.
    """
    model.train()

    total_loss = 0
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data.x, data.edge_index, data.edge_attr, data.batch)
        # Flatten both operands to 1-D before subtracting. With squeeze(), a
        # target stored as [B, 1] would broadcast against the [B] prediction
        # into a [B, B] matrix and silently yield a wrong loss.
        loss = (out.reshape(-1) - data.y.reshape(-1)).abs().mean()

        loss.backward()
        # Weight the batch mean by its graph count so the epoch average is
        # per graph even when the last batch is smaller.
        total_loss += loss.item() * data.num_graphs
        optimizer.step()
    return total_loss / len(train_loader.dataset)

In [12]:
@torch.no_grad()
def test(loader):
    """Return the mean absolute error per graph over ``loader``.

    Uses the module-level ``model`` and ``device``; gradients are disabled by
    the decorator and the model is switched to eval mode.

    Args:
        loader: iterable of batches exposing ``x``, ``edge_index``,
            ``edge_attr``, ``batch``, ``y`` and ``to()``; ``loader.dataset``
            must support ``len()``.

    Returns:
        Summed absolute error over all batches divided by
        ``len(loader.dataset)``.
    """
    model.eval()

    total_error = 0
    for data in loader:
        data = data.to(device)
        out = model(data.x, data.edge_index, data.edge_attr, data.batch)
        # Flatten both operands to 1-D before subtracting. With squeeze(), a
        # target stored as [B, 1] would broadcast against the [B] prediction
        # into a [B, B] matrix and inflate the summed error.
        total_error += (out.reshape(-1) - data.y.reshape(-1)).abs().sum().item()
    return total_error / len(loader.dataset)

In [13]:
# TensorBoard event writer; scalars logged during training go to runs/g2d.
writer = SummaryWriter('runs/g2d')

In [14]:
# Per-epoch history: row i holds [train loss, validation MAE, test MAE] of
# epoch i + 1.
num_epochs = 1000
result = torch.zeros(num_epochs, 3)

for epoch in range(1, num_epochs + 1):
    loss = train(train_loader)
    test_mae = test(test_loader)
    val_mae = test(val_loader)

    # The plateau scheduler monitors the validation error.
    scheduler.step(val_mae)
    result[epoch - 1] = torch.tensor([loss, val_mae, test_mae])

    for tag, value in (('Loss/train', loss),
                       ('Loss/test', test_mae),
                       ('Loss/val', val_mae)):
        writer.add_scalar(tag, value, epoch)

    # Console progress is printed on every 50th epoch only.
    if epoch % 50 != 0:
        continue
    print(f'{time.ctime()}\t'
          f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Val: {val_mae:.4f}, '
          f'Test: {test_mae:.4f}')

Mon Dec 13 11:02:48 2021	Epoch: 050, Loss: 0.6048, Val: 0.6378, Test: 0.6549
Mon Dec 13 11:04:55 2021	Epoch: 100, Loss: 0.4603, Val: 0.6132, Test: 0.5862
Mon Dec 13 11:07:03 2021	Epoch: 150, Loss: 0.4343, Val: 0.5481, Test: 0.5113
Mon Dec 13 11:09:10 2021	Epoch: 200, Loss: 0.3479, Val: 0.6002, Test: 0.5975
Mon Dec 13 11:11:17 2021	Epoch: 250, Loss: 0.2553, Val: 0.4916, Test: 0.4296
Mon Dec 13 11:13:25 2021	Epoch: 300, Loss: 0.2291, Val: 0.4848, Test: 0.4255
Mon Dec 13 11:15:32 2021	Epoch: 350, Loss: 0.2122, Val: 0.4791, Test: 0.4217
Mon Dec 13 11:17:40 2021	Epoch: 400, Loss: 0.1973, Val: 0.4744, Test: 0.4251
Mon Dec 13 11:19:47 2021	Epoch: 450, Loss: 0.2023, Val: 0.4805, Test: 0.4244
Mon Dec 13 11:21:55 2021	Epoch: 500, Loss: 0.2326, Val: 0.4803, Test: 0.4254
Mon Dec 13 11:24:02 2021	Epoch: 550, Loss: 0.1972, Val: 0.4775, Test: 0.4305
Mon Dec 13 11:26:10 2021	Epoch: 600, Loss: 0.1855, Val: 0.4775, Test: 0.4327
Mon Dec 13 11:28:17 2021	Epoch: 650, Loss: 0.1987, Val: 0.4740, Test: 0.4256

In [15]:
# Close the TensorBoard writer once logging is finished.
writer.close()

In [16]:
# Bare expression: the notebook displays the module's repr (layer summary)
# as this cell's output.
model

Net(
  (node_emb): Embedding(21, 75)
  (edge_emb): Embedding(4, 25)
  (convs): ModuleList(
    (0): PNAConv(75, 75, towers=5, edge_dim=50)
    (1): PNAConv(75, 75, towers=5, edge_dim=50)
    (2): PNAConv(75, 75, towers=5, edge_dim=50)
    (3): PNAConv(75, 75, towers=5, edge_dim=50)
  )
  (batch_norms): ModuleList(
    (0): BatchNorm(75)
    (1): BatchNorm(75)
    (2): BatchNorm(75)
    (3): BatchNorm(75)
  )
  (pre_lin): Linear(in_features=150, out_features=75, bias=True)
  (mlp): Sequential(
    (0): Linear(in_features=75, out_features=50, bias=True)
    (1): ReLU()
    (2): Linear(in_features=50, out_features=25, bias=True)
    (3): ReLU()
    (4): Linear(in_features=25, out_features=1, bias=True)
  )
)