In [1]:
import time

import torch

import torch.nn.functional as F
from torch.nn import ModuleList, Embedding
from torch.nn import Sequential, ReLU, Linear
from torch.optim.lr_scheduler import ReduceLROnPlateau

from torch_geometric.utils import degree
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, PNAConv, BatchNorm, global_add_pool

from gene_graph_dataset import GeneGraphDataset

In [2]:
# Seed torch's global RNG before anything random happens (dataset shuffle,
# weight initialisation, DataLoader batch order) so that a fresh
# "Restart & Run All" reproduces the same split and training run.
# NOTE(review): CUDA kernels may still introduce small nondeterminism.
SEED = 0
torch.manual_seed(SEED)

# Fractions of the dataset used for the training and test splits.
train_p, test_p = 0.7, 0.2

In [3]:
# Load the gene-graph dataset and derive the absolute split sizes from the
# configured fractions (truncated towards zero).
# NOTE(review): the meaning of the positional arguments 100 and 5 is defined
# by GeneGraphDataset and is not visible here.
dataset = GeneGraphDataset('dataset', 100, 5)
data_size = len(dataset)
train_size = int(data_size * train_p)
test_size = int(data_size * test_p)

In [4]:
# Shuffle once, then cut the shuffled dataset into three contiguous slices:
# [0, train_size) -> train, [train_size, eval_start) -> test, rest -> validation.
dataset = dataset.shuffle()
eval_start = train_size + test_size
train_dataset = dataset[:train_size]
test_dataset = dataset[train_size:eval_start]
val_dataset = dataset[eval_start:]

In [5]:
# len(train_dataset), len(test_dataset), len(val_dataset)

(350, 100, 50)

In [6]:
# Mini-batch loaders. Only the training loader reshuffles between epochs;
# the two evaluation loaders iterate in a fixed order with a larger batch.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=64)
test_loader, val_loader = (
    DataLoader(test_dataset, batch_size=128),
    DataLoader(val_dataset, batch_size=128),
)

In [7]:
# len(train_loader), len(test_loader), len(val_loader)

(3, 1, 1)

In [8]:
# Degree histogram over the training graphs, later handed to PNAConv as
# its `deg` argument. deg[k] is the number of nodes that appear exactly k
# times in edge_index[1].
deg = torch.zeros(5, dtype=torch.long)
for data in train_dataset:
    d = degree(data.edge_index[1].type(torch.int64),
               num_nodes=data.num_nodes, dtype=torch.long)
    hist = torch.bincount(d, minlength=deg.numel())
    if hist.numel() > deg.numel():
        # bincount returns more than `minlength` bins as soon as a node has
        # degree >= len(deg); widen the running histogram instead of letting
        # the in-place add fail with a shape mismatch.
        deg = torch.cat([deg, deg.new_zeros(hist.numel() - deg.numel())])
    deg += hist

In [9]:
class Net(torch.nn.Module):
    """Graph-level regressor: embeddings -> 4 PNA layers -> sum pooling -> MLP.

    The PNA layers read the module-level ``deg`` histogram, so ``deg`` must be
    defined before this class is instantiated. The network emits one scalar
    per graph in the batch.
    """

    def __init__(self):
        super().__init__()

        # Lookup tables: 21 node categories -> 75 dims, 4 edge categories -> 25 dims.
        self.node_emb = Embedding(21, 75)
        self.edge_emb = Embedding(4, 25)

        # Settings shared by every PNA layer.
        pna_kwargs = dict(
            in_channels=75, out_channels=75,
            aggregators=['mean', 'min', 'max', 'std'],
            scalers=['identity', 'amplification', 'attenuation'],
            deg=deg, edge_dim=50, towers=5,
            pre_layers=1, post_layers=1, divide_input=False,
        )

        self.convs = ModuleList()
        self.batch_norms = ModuleList()
        for _ in range(4):
            # Alternative message-passing layer that was tried:
            #   GCNConv(in_channels=75, out_channels=75)
            self.convs.append(PNAConv(**pna_kwargs))
            self.batch_norms.append(BatchNorm(75))

        # Projects the concatenated node embeddings (150) down to the conv width (75).
        self.pre_lin = Linear(150, 75)

        # Read-out head applied to the pooled graph representation.
        self.mlp = Sequential(
            Linear(75, 50), ReLU(),
            Linear(50, 25), ReLU(),
            Linear(25, 1),
        )

    def forward(self, x, edge_index, edge_attr, batch):
        """Compute one prediction per graph.

        Args:
            x: integer node ids fed to ``node_emb``. The reshape to 150 columns
                presumably means two ids per node (2 x 75) - TODO confirm
                against the dataset.
            edge_index: edge connectivity passed straight to each conv layer.
            edge_attr: integer edge ids fed to ``edge_emb``. The reshape to 50
                columns presumably means two ids per edge (2 x 25) - TODO confirm.
            batch: node-to-graph assignment used by ``global_add_pool``.

        Returns:
            Output of the final MLP applied to the sum-pooled node features.
        """
        node_feats = self.node_emb(x.squeeze()).reshape(-1, 150)
        hidden = self.pre_lin(node_feats)

        edge_feats = self.edge_emb(edge_attr).reshape(-1, 50)

        for conv, batch_norm in zip(self.convs, self.batch_norms):
            # GCNConv variant: conv(hidden, edge_index) without edge features.
            hidden = conv(hidden, edge_index, edge_feats)
            hidden = batch_norm(hidden)
            hidden = F.relu(hidden)

        pooled = global_add_pool(hidden, batch)
        return self.mlp(pooled)

In [10]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda') if use_cuda else torch.device('cpu')

model = Net().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Halve the learning rate after 20 epochs without improvement of the
# monitored metric (lower is better), never going below 1e-5.
scheduler = ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=20,
    min_lr=1e-5,
)

In [11]:
def train(train_loader):
    """Run one optimisation pass over ``train_loader``.

    Uses the module-level ``model``, ``optimizer`` and ``device``. The loss is
    the mean absolute error between the model output and ``data.y``.

    Args:
        train_loader: iterable of batches exposing ``x``, ``edge_index``,
            ``edge_attr``, ``batch``, ``y`` and ``num_graphs``; it must also
            have a ``dataset`` attribute with a length.

    Returns:
        The per-graph average of the batch losses over the whole dataset.
    """
    model.train()

    total_loss = 0
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data.x, data.edge_index, data.edge_attr, data.batch)
        # Flatten prediction and target explicitly. A bare `out.squeeze()`
        # yields shape [B], which would silently broadcast to [B, B] against
        # targets stored as [B, 1] and produce a wrong loss.
        loss = (out.reshape(-1) - data.y.reshape(-1)).abs().mean()

        loss.backward()
        optimizer.step()
        # Weight the batch mean by its size so the final value is a true
        # per-graph average even when the last batch is smaller.
        total_loss += loss.item() * data.num_graphs
    return total_loss / len(train_loader.dataset)

In [12]:
@torch.no_grad()
def test(loader):
    """Evaluate the module-level ``model`` on ``loader`` without gradients.

    Args:
        loader: iterable of batches exposing ``x``, ``edge_index``,
            ``edge_attr``, ``batch`` and ``y``; it must also have a
            ``dataset`` attribute with a length.

    Returns:
        Mean absolute error per graph over the loader's dataset.
    """
    model.eval()

    total_error = 0
    for data in loader:
        data = data.to(device)
        out = model(data.x, data.edge_index, data.edge_attr, data.batch)
        # Flatten both sides so a [B, 1] target cannot broadcast against the
        # [B] prediction into a [B, B] error matrix.
        total_error += (out.reshape(-1) - data.y.reshape(-1)).abs().sum().item()
    return total_error / len(loader.dataset)

In [13]:
# Main training loop: 300 epochs, each followed by an evaluation on the test
# and validation splits. The validation MAE drives the LR scheduler.
for epoch in range(1, 301):
    loss = train(train_loader)
    test_mae, val_mae = test(test_loader), test(val_loader)

    scheduler.step(val_mae)

    # Logged every epoch; guard with `if epoch % 10 == 0:` to log less often.
    print(f'{time.ctime()}\t'
          f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Val: {val_mae:.4f}, '
          f'Test: {test_mae:.4f}')

Mon Nov 22 17:11:17 2021	Epoch: 01, Loss: 6.1938, Val: 3.1765, Test: 2.9964
Mon Nov 22 17:11:19 2021	Epoch: 02, Loss: 1.7099, Val: 2.7765, Test: 2.5961
Mon Nov 22 17:11:22 2021	Epoch: 03, Loss: 1.6319, Val: 2.2906, Test: 2.1094
Mon Nov 22 17:11:24 2021	Epoch: 04, Loss: 1.2005, Val: 2.0680, Test: 1.9179
Mon Nov 22 17:11:27 2021	Epoch: 05, Loss: 1.1400, Val: 1.9023, Test: 1.7593
Mon Nov 22 17:11:29 2021	Epoch: 06, Loss: 1.1348, Val: 1.7413, Test: 1.5886
Mon Nov 22 17:11:32 2021	Epoch: 07, Loss: 1.1368, Val: 1.8401, Test: 1.6766
Mon Nov 22 17:11:34 2021	Epoch: 08, Loss: 1.0848, Val: 1.7011, Test: 1.5493
Mon Nov 22 17:11:36 2021	Epoch: 09, Loss: 1.0623, Val: 1.8407, Test: 1.6330
Mon Nov 22 17:11:39 2021	Epoch: 10, Loss: 1.0582, Val: 2.7903, Test: 2.6462
Mon Nov 22 17:11:41 2021	Epoch: 11, Loss: 1.0253, Val: 3.1026, Test: 2.9518
Mon Nov 22 17:11:44 2021	Epoch: 12, Loss: 1.0080, Val: 3.4607, Test: 3.3015
Mon Nov 22 17:11:46 2021	Epoch: 13, Loss: 1.2310, Val: 3.7483, Test: 3.5069
Mon Nov 22 1