In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np

from tqdm import tqdm

from util import load_data, separate_data
from models.graphcnn import GraphCNN

In [14]:
# Module-level loss used by train(): cross-entropy between the model output and integer class labels.
criterion = nn.CrossEntropyLoss()

def train(model, device, train_graphs, optimizer, epoch, iters_per_epoch, batch_size):
    '''
    Run one epoch of minibatch training and return the mean loss.

    Every iteration draws a fresh random permutation of the training set,
    keeps its first ``batch_size`` entries as the minibatch, and scores the
    model output with the module-level ``criterion``.

    Parameters:
    - model: Module invoked as ``model(list_of_graphs)``; put into train mode here.
    - device: Device the label tensor is moved to.
    - train_graphs: Sequence of graph objects exposing a ``label`` attribute.
    - optimizer: Optimizer to step after each minibatch, or None to only measure the loss.
    - epoch (int): Epoch number; used only for the progress-bar caption.
    - iters_per_epoch (int): Number of minibatches processed in this call.
    - batch_size (int): Maximum number of graphs per minibatch.

    Returns:
    - Sum of the per-iteration losses divided by ``iters_per_epoch``.
    '''
    model.train()

    num_iters = iters_per_epoch
    progress = tqdm(range(num_iters), unit='batch')

    running_loss = 0
    for _ in progress:
        # Sampling without replacement: the head of a random permutation.
        shuffled = np.random.permutation(len(train_graphs))
        minibatch = [train_graphs[i] for i in shuffled[:batch_size]]

        logits = model(minibatch)
        targets = torch.LongTensor([g.label for g in minibatch]).to(device)
        batch_loss = criterion(logits, targets)

        # Parameters are only updated when an optimizer was supplied.
        if optimizer is not None:
            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()

        running_loss += batch_loss.detach().cpu().numpy()

        progress.set_description('epoch: %d' % (epoch))

    mean_loss = running_loss / num_iters
    print("loss training: %f" % (mean_loss))

    return mean_loss

def pass_data_iteratively(model, graphs, minibatch_size = 64):
    '''
    Evaluate the model on ``graphs`` in consecutive minibatches and concatenate the outputs.

    The data is fed in chunks to keep memory bounded during evaluation. The
    forward passes run under ``torch.no_grad()`` so that no autograd graph is
    recorded at all (detaching the output afterwards would still build one
    for every chunk). No backpropagation is performed.

    Parameters:
    - model: Module invoked as ``model(list_of_graphs)``; switched to eval mode here.
    - graphs: Indexable sequence of graphs, processed in order. Must be non-empty,
      otherwise there is nothing to concatenate.
    - minibatch_size (int): Number of graphs per forward pass. Default is 64.

    Returns:
    - The per-minibatch outputs concatenated along dimension 0, in input order.
      The result does not require grad.
    '''
    model.eval()
    num_graphs = len(graphs)
    outputs = []
    with torch.no_grad():
        for start in range(0, num_graphs, minibatch_size):
            stop = min(start + minibatch_size, num_graphs)
            minibatch = [graphs[j] for j in range(start, stop)]
            outputs.append(model(minibatch))
    return torch.cat(outputs, 0)

def test(model, device, train_graphs, test_graphs, epoch):
    '''
    Measure and print classification accuracy on the training and test graphs.

    Parameters:
    - model: Module evaluated through ``pass_data_iteratively``; switched to eval mode.
    - device: Device the label tensors are moved to before comparison.
    - train_graphs: Sequence of graphs exposing a ``label`` attribute.
    - test_graphs: Sequence of graphs exposing a ``label`` attribute.
    - epoch (int): Accepted for symmetry with ``train``; not used here.

    Returns:
    - Tuple ``(acc_train, acc_test)`` of fractions of correctly classified graphs.
    '''
    model.eval()

    def _accuracy(graphs):
        # Predicted class = index of the largest score along dimension 1.
        scores = pass_data_iteratively(model, graphs)
        _, predicted = scores.max(1, keepdim=True)
        targets = torch.LongTensor([g.label for g in graphs]).to(device)
        hits = predicted.eq(targets.view_as(predicted)).sum().cpu().item()
        return hits / float(len(graphs))

    acc_train = _accuracy(train_graphs)
    acc_test = _accuracy(test_graphs)

    print("accuracy train: %f test: %f" % (acc_train, acc_test))

    return acc_train, acc_test

def main(dataset: str = 'MUTAG',
         device: int = 0,
         batch_size: int = 32,
         iters_per_epoch: int = 1000,
         epochs: int = 350,
         lr: float = 0.01,
         seed: int = 0,
         fold_idx: int = 0,
         num_layers: int = 5,
         num_mlp_layers: int = 2,
         hidden_dim: int = 64,
         final_dropout: float = 0.5,
         graph_pooling_type: str = 'sum',
         neighbor_pooling_type: str = 'sum',
         learn_eps:bool = True,
         degree_as_tag:bool = True,
         filename: str = 'output file'):
    '''
    PyTorch graph convolutional neural net for whole-graph classification.

    Loads the dataset, picks one cross-validation fold, trains a GraphCNN with
    Adam and a step learning-rate schedule (halved every 50 epochs), and after
    every epoch prints the loss, the train/test accuracy and ``model.eps``.
    When ``filename`` is non-empty, the file is truncated at the start of the
    run and one line "<avg_loss> <acc_train> <acc_test>" is appended per epoch.

    Parameters:
    - dataset (str): Name of the dataset. Default is 'MUTAG'.
    - device (int): GPU device to use if any; the CPU is used when CUDA is unavailable. Default is 0.
    - batch_size (int): Input batch size for training. Default is 32.
    - iters_per_epoch (int): Number of iterations per epoch. Default is 1000.
    - epochs (int): Number of epochs to train. Default is 350.
    - lr (float): Learning rate. Default is 0.01.
    - seed (int): Random seed for splitting the dataset into 10. Default is 0.
      The torch/numpy seeds used for training are fixed at 0 regardless of this value.
    - fold_idx (int): Index of fold in 10-fold validation. Should be less than 10. Default is 0.
    - num_layers (int): Number of layers, INCLUDING the input one. Default is 5.
    - num_mlp_layers (int): Number of layers for MLP, EXCLUDING the input one. Default is 2.
    - hidden_dim (int): Number of hidden units. Default is 64.
    - final_dropout (float): Final layer dropout. Default is 0.5.
    - graph_pooling_type (str): Pooling over the nodes in a graph. Default is 'sum'. Choices = ["sum", "average"].
    - neighbor_pooling_type (str): Pooling over neighboring nodes. Default is 'sum'. Choices = ["sum", "average", "max"].
    - learn_eps (bool): Whether to learn the epsilon weighting for the center nodes. Does not affect training accuracy though.
    - degree_as_tag (bool): Let the input node features be the degree of nodes (heuristics for unlabeled graph).
    - filename (str): Output file name. Default is 'output file'. An empty value disables file logging.

    Returns:
    - None. Results are reported through stdout and the optional output file.
    '''

    # Note: Hyper-parameters need to be tuned in order to obtain results reported in the paper.

    # Set up seeds and the compute device.
    torch.manual_seed(0)
    np.random.seed(0)
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda:" + str(device)) if use_cuda else torch.device("cpu")
    if use_cuda:
        torch.cuda.manual_seed_all(0)

    graphs, num_classes = load_data(dataset, degree_as_tag)

    # 10-fold cross validation. Conduct an experiment on the fold specified by fold_idx.
    train_graphs, test_graphs = separate_data(graphs, seed, fold_idx)

    # The input feature width is read from the first training graph.
    input_dim = train_graphs[0].node_features.shape[1]
    model = GraphCNN(num_layers, num_mlp_layers, input_dim, hidden_dim, num_classes,
                     final_dropout, learn_eps, graph_pooling_type, neighbor_pooling_type,
                     device).to(device)

    optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5)

    log_results = bool(filename)
    if log_results:
        # Start from an empty file once per run; epochs append below so that
        # the history of all epochs is kept instead of only the last line.
        with open(filename, 'w'):
            pass

    for epoch in range(1, epochs + 1):
        avg_loss = train(model, device, train_graphs, optimizer, epoch, iters_per_epoch, batch_size)
        # Advance the schedule only after the optimizer has stepped; stepping
        # the scheduler first skips the initial learning-rate value on
        # PyTorch >= 1.1 and triggers a warning.
        scheduler.step()

        acc_train, acc_test = test(model, device, train_graphs, test_graphs, epoch)

        if log_results:
            with open(filename, 'a') as f:
                f.write("%f %f %f" % (avg_loss, acc_train, acc_test))
                f.write("\n")
        print("")

        print(model.eps)

In [16]:
# Train on the PROTEINS dataset, leaving every other hyper-parameter at its default.
main(dataset='PROTEINS')

loading data




# classes: 2
# maximum node tag: 17
# data: 1113


epoch: 1:   9%|▉         | 90/1000 [00:02<00:25, 35.98batch/s]


KeyboardInterrupt: 