In [1]:
import pickle
import torch
from torch.nn import ELU, ModuleList
import torch.nn.functional as F
from torch_geometric.transforms import RandomNodeSplit
import random
import numpy as np
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch.nn import Linear
from torch_geometric.nn import GCNConv, GATConv, SAGEConv
from torch_geometric.nn import global_mean_pool, global_max_pool, SAGPooling
from sklearn.metrics import average_precision_score

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Set seeds for consistent results
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)

In [3]:
def load(file):
    with open(file, 'rb') as f:
        return pickle.load(f)
    
def save(data, file):
    with open(file, 'wb') as f:
        pickle.dump(data, f)

In [4]:
device = torch.device('cuda')

In [5]:
train_dataset = load('data/train_graphs')
validate_dataset = load('data/validate_graphs')
test_dataset = load('data/test_graphs')

In [16]:
def train(model, optimizer):
    model.train()

    running_loss = 0.0
    for data in DataLoader(train_dataset, batch_size=32, shuffle=True):
        data = data.to(device)
        out = model(data.x, data.edge_index, data.batch)
        criterion = torch.nn.CrossEntropyLoss()
        loss = criterion(out, data.y)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        optimizer.zero_grad() 
    
    return running_loss


@torch.no_grad()
def test(model, dataset):
    model.eval()

    running_loss = 0.0
    outs = list()
    ys = list()
    for data in DataLoader(dataset, batch_size=32, shuffle=False):
        data = data.to(device)
        out = model(data.x, data.edge_index, data.batch)
        criterion = torch.nn.CrossEntropyLoss()
        loss = criterion(out, data.y)
        running_loss += loss.item()
        y = data.y.cpu()
        outs.append(F.softmax(out, dim=1).cpu()[:,1])
        ys.append(y)
    out = torch.cat(outs, dim=0)
    y = torch.cat(ys, dim=0)
    return running_loss, average_precision_score(y, out)

In [7]:
class GenModel(torch.nn.Module):
    def __init__(self, config):
        super().__init__()
        self.activation_conv = config['activation_conv']
        self.activation_lin = config['activation_lin']
        self.dropout = config['dropout']
        conv = config['conv']
        self.convs = ModuleList()
        self.convs.append(conv(1, config['conv_hidden_size']))
        for _ in range(0, config['num_conv']):
            self.convs.append(config['conv'](config['conv_hidden_size'], config['conv_hidden_size']))
        self.lins = ModuleList()
        if config['num_lin'] == 0:
            self.lins.append(Linear(3 * config['conv_hidden_size'], 2))
        else:
            self.lins.append(Linear(3 * config['conv_hidden_size'], config['lin_hidden_size']))
            for _ in range(1, config['num_lin']):
                self.lins.append(Linear(config['lin_hidden_size'], config['lin_hidden_size']))
            self.lins.append(Linear(config['lin_hidden_size'], 2))
    
    def forward(self, x, edge_index, batch):
        for conv in self.convs[:-1]:
            x = conv(x, edge_index)
            x = self.activation_conv(x)
            x = F.dropout(x, p=self.dropout)
        self.convs[-1](x, edge_index)
        x_mean = global_mean_pool(x, batch)
        x_max = global_max_pool(x, batch)
        x_min = global_max_pool(-x, batch)
        x = torch.cat([x_mean, x_min, x_max], dim=1)
        for lin in self.lins[:-1]:
            x = lin(x)
            x = self.activation_lin(x)
        x = self.lins[-1](x)
        return x

In [8]:
def sage_max_conv(size1, size2):
    return SAGEConv(size1, size2, aggr='max')

def gatconv2(size1, size2):
    return GATConv(size1, int(size2/2), heads=2)

def gatconv4(size1, size2):
    return GATConv(size1, int(size2/4), heads=4)

def run(configs, name):
    random.shuffle(configs)
    best = 0.0
    for config in configs:
        print(config)
        model = GenModel(config).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'])
        loss_prev = float('NaN')
        acc_prev = float('NaN')
        for epoch in range(1, 30):
            loss = train(model, optimizer)
            v_loss, v_acc = test(model, validate_dataset)
            if v_acc > best:
                best = v_acc
                torch.save(model.state_dict(), f'models/{name}_state')
                config['epoch'] = epoch
                config['score'] = v_acc
                save(str(config), f'{name}_config')
                print(f'New best: {v_acc:.4f}: {config} ({epoch} epochs)')
            print(f'  Epoch: {epoch:03d}, loss: {loss:.4f}, validation loss: {v_loss:.4f}, validation score: {v_acc:.4f}')
            if epoch > 15 and v_loss > loss_prev and v_acc < acc_prev:
                break
            loss_prev = v_loss
            acc_prev = v_acc

In [9]:
configs = list()
for lr in [0.01, 0.005, 0.001, 0.0005, 0.0001]:
    for conv in [SAGEConv, sage_max_conv, GCNConv, gatconv2, gatconv4]:
        for dropout in [0.1, 0.3]:
            for activation_conv in [F.relu, F.elu]:
                for activation_lin in [F.relu, torch.sigmoid]:
                    for num_conv in [1, 2, 3, 4]:
                        for num_lin in [0, 1, 2, 3, 4]:
                            for conv_hidden_size in [32, 64, 128]:
                                for lin_hidden_size in [64, 128, 256]:
                                    config = {
                                        'lr': lr,
                                        'conv': conv,
                                        'dropout': dropout,
                                        'activation_conv': activation_conv,
                                        'activation_lin': activation_lin,
                                        'num_conv': num_conv,
                                        'num_lin': num_lin,
                                        'conv_hidden_size': conv_hidden_size,
                                        'lin_hidden_size': lin_hidden_size,
                                    }
                                    configs.append(config)

run(configs, 'bestmodel')

{'lr': 0.01, 'conv': <class 'torch_geometric.nn.conv.gcn_conv.GCNConv'>, 'dropout': 0.1, 'activation_conv': <function elu at 0x7f5ee4d16040>, 'activation_lin': <built-in method sigmoid of type object at 0x7f60c50ee4c0>, 'num_conv': 1, 'num_lin': 3, 'conv_hidden_size': 64, 'lin_hidden_size': 256}
New best: 0.3791: {'lr': 0.01, 'conv': <class 'torch_geometric.nn.conv.gcn_conv.GCNConv'>, 'dropout': 0.1, 'activation_conv': <function elu at 0x7f5ee4d16040>, 'activation_lin': <built-in method sigmoid of type object at 0x7f60c50ee4c0>, 'num_conv': 1, 'num_lin': 3, 'conv_hidden_size': 64, 'lin_hidden_size': 256, 'epoch': 1, 'score': 0.3791485382435702} (1 epochs)
  Epoch: 001, loss: 100.0781, validation loss: 20.6406, validation score: 0.3791
  Epoch: 002, loss: 96.1813, validation loss: 20.6779, validation score: 0.3703
  Epoch: 003, loss: 95.1316, validation loss: 20.5116, validation score: 0.3277
  Epoch: 004, loss: 95.1100, validation loss: 20.5284, validation score: 0.3631
  Epoch: 005, l

KeyboardInterrupt: 

In [24]:
manual_graphs = load('data/manual_graphs')
config = load(f'bestmodel_config')
model = GenModel({
    'conv': sage_max_conv,
    'dropout': 0.1,
    'activation_conv': F.relu,
    'activation_lin': F.relu,
    'num_conv': 4,
    'num_lin': 2,
    'conv_hidden_size': 128,
    'lin_hidden_size': 256,
})
model.load_state_dict(torch.load(f'bestmodel_state'))
model = model.to(device)
print(config)
test(model, manual_graphs)

{'lr': 0.0005, 'conv': <function sage_max_conv at 0x7f5d93af3550>, 'dropout': 0.1, 'activation_conv': <function relu at 0x7f5ee4d14dc0>, 'activation_lin': <function relu at 0x7f5ee4d14dc0>, 'num_conv': 4, 'num_lin': 2, 'conv_hidden_size': 128, 'lin_hidden_size': 256, 'epoch': 15, 'score': 0.40303174096032623}


(1.8160532712936401, 0.822974342572982)