## Utils

In [169]:
import torch
device = "cuda" if torch.cuda.is_available() else "cpu"

In [170]:
import math
import numpy as np
import torch
from texttable import Texttable

def tab_printer(args):
    """
    Function to print the logs in a nice tabular format.
    :param args: Parameters used for the model.
    """
    args = vars(args)
    keys = sorted(args.keys())
    t = Texttable()
    t.add_rows([["Parameter", "Value"]])
    t.add_rows([[k.replace("_", " ").capitalize(), args[k]] for k in keys])
    print(t.draw())

def denormalize_sim_score(g1, g2, sim_score):
    """
    Converts normalized similarity into ged.
    """
    return denormalize_ged(g1, g2, -math.log(sim_score, math.e))


def denormalize_ged(g1, g2, nged):
    """
    Converts normalized ged into ged.
    """
    return round(nged * (g1.num_nodes + g2.num_nodes) / 2) if nged != np.inf else np.inf

def to_directed(edge_index):
    row, col = edge_index
    mask = row < col
    row, col = row[mask], col[mask]
    return torch.stack([row, col], dim=0)

def calculate_ranking_correlation(rank_corr_function, prediction, target):
    """
    Calculating specific ranking correlation for predicted values.
    :param rank_corr_function: Ranking correlation function.
    :param prediction: Vector of predicted values.
    :param target: Vector of ground-truth values.
    :return ranking: Ranking correlation value.
    """
    temp = prediction.argsort()
    r_prediction = np.empty_like(temp)
    r_prediction[temp] = np.arange(len(prediction))

    temp = target.argsort()
    r_target = np.empty_like(temp)
    r_target[temp] = np.arange(len(target))

    return rank_corr_function(r_prediction, r_target).correlation

def calculate_prec_at_k(k, prediction, target):
    """
    Calculating precision at k.
    """
    best_k_pred = prediction.argsort()[:k]
    best_k_target = target.argsort()[:k]

    return len(set(best_k_pred).intersection(set(best_k_target))) / k

## Param Parser

In [171]:
import argparse

#TODO: Added 'args' parameter in order to be able to parse arguments using a pynotebook (it was CLI only)
def parameter_parser(args : list):
    """
    A method to parse up command line parameters.
    The default hyperparameters give a high performance model without grid search.
    """
    parser = argparse.ArgumentParser(description="Run SimGNN.")

    parser.add_argument("--dataset",
                        nargs="?",
                        default="AIDS700nef",  # AIDS700nef LINUX IMDBMulti
                        help="Dataset name. Default is AIDS700nef")
    parser.add_argument("--epochs",
                        type=int,
                        default=5,
	                help="Number of training epochs. Default is 5.")

    parser.add_argument("--filters-1",
                        type=int,
                        default=128,
	                help="Filters (neurons) in 1st convolution. Default is 128.")

    parser.add_argument("--filters-2",
                        type=int,
                        default=64,
	                help="Filters (neurons) in 2nd convolution. Default is 64.")

    parser.add_argument("--filters-3",
                        type=int,
                        default=32,
	                help="Filters (neurons) in 3rd convolution. Default is 32.")

    parser.add_argument("--tensor-neurons",
                        type=int,
                        default=16,
	                help="Neurons in tensor network layer. Default is 16.")

    parser.add_argument("--bottle-neck-neurons",
                        type=int,
                        default=16,
	                help="Bottle neck layer neurons. Default is 16.")

    parser.add_argument("--batch-size",
                        type=int,
                        default=128,
	                help="Number of graph pairs per batch. Default is 128.")

    parser.add_argument("--bins",
                        type=int,
                        default=16,
	                help="Similarity score bins. Default is 16.")

    parser.add_argument("--dropout",
                        type=float,
                        default=0.5,
	                help="Dropout probability. Default is 0.5.")

    parser.add_argument("--learning-rate",
                        type=float,
                        default=0.001,
	                help="Learning rate. Default is 0.001.")

    parser.add_argument("--weight-decay",
                        type=float,
                        default=5*10**-4,
	                help="Adam weight decay. Default is 5*10^-4.")

    parser.add_argument("--histogram",
                        dest="histogram",
                        action="store_true")

    parser.set_defaults(histogram=False)

    parser.add_argument("--save-path",
                        type=str,
                        default=None,
                        help="Where to save the trained model")

    parser.add_argument("--load-path",
                        type=str,
                        default=None,
                        help="Load a pretrained model")

    return parser.parse_args(args)

## Layers

In [172]:
import torch
import torch_scatter

#TODO: Replace with torch.scatter
def scatter_(name, src, index, dim=0, dim_size=None):
    assert name in ['add', 'mean', 'min', 'max']

    op = getattr(torch_scatter, 'scatter_{}'.format(name))
    out = op(src, index, dim, None, dim_size)
    out = out[0] if isinstance(out, tuple) else out

    if name == 'max':
        out[out < -10000] = 0
    elif name == 'min':
        out[out > 10000] = 0

    return out

class AttentionModule(torch.nn.Module):
    """
    SimGNN Attention Module to make a pass on graph.
    """
    def __init__(self, args):
        """
        :param args: Arguments object.
        """
        super(AttentionModule, self).__init__()
        self.args = args
        self.setup_weights()
        self.init_parameters()

    def setup_weights(self):
        """
        Defining weights.
        """
        self.weight_matrix = torch.nn.Parameter(torch.Tensor(self.args.filters_3, self.args.filters_3))

    def init_parameters(self):
        """
        Initializing weights.
        """
        torch.nn.init.xavier_uniform_(self.weight_matrix)

    def forward(self, x, batch, size=None):
        """
        Making a forward propagation pass to create a graph level representation.
        :param x: Result of the GNN.
        :param batch: Batch vector, which assigns each node to a specific example
        :return representation: A graph level representation matrix. 
        """
        size = batch[-1].item() + 1 if size is None else size
        mean = scatter_('mean', x, batch, dim_size=size)
        transformed_global = torch.tanh(torch.mm(mean, self.weight_matrix))

        coefs = torch.sigmoid((x * transformed_global[batch]).sum(dim=1))
        weighted = coefs.unsqueeze(-1) * x

        return scatter_('add', weighted, batch, dim_size=size)

class TensorNetworkModule(torch.nn.Module):
    """
    SimGNN Tensor Network module to calculate similarity vector.
    """
    def __init__(self, args):
        """
        :param args: Arguments object.
        """
        super(TensorNetworkModule, self).__init__()
        self.args = args
        self.setup_weights()
        self.init_parameters()

    def setup_weights(self):
        """
        Defining weights.
        """
        self.weight_matrix = torch.nn.Parameter(torch.Tensor(self.args.filters_3, self.args.filters_3, self.args.tensor_neurons))
        self.weight_matrix_block = torch.nn.Parameter(torch.Tensor(self.args.tensor_neurons, 2*self.args.filters_3))
        self.bias = torch.nn.Parameter(torch.Tensor(self.args.tensor_neurons, 1))

    def init_parameters(self):
        """
        Initializing weights.
        """
        torch.nn.init.xavier_uniform_(self.weight_matrix)
        torch.nn.init.xavier_uniform_(self.weight_matrix_block)
        torch.nn.init.xavier_uniform_(self.bias)

    def forward(self, embedding_1, embedding_2):
        """
        Making a forward propagation pass to create a similarity vector.
        :param embedding_1: Result of the 1st embedding after attention.
        :param embedding_2: Result of the 2nd embedding after attention.
        :return scores: A similarity score vector.
        """
        batch_size = len(embedding_1)
        scoring = torch.matmul(embedding_1, self.weight_matrix.view(self.args.filters_3, -1))
        scoring = scoring.view(batch_size, self.args.filters_3, -1).permute([0, 2, 1])
        scoring = torch.matmul(scoring, embedding_2.view(batch_size, self.args.filters_3, 1)).view(batch_size, -1)
        combined_representation = torch.cat((embedding_1, embedding_2), 1)
        block_scoring = torch.t(torch.mm(self.weight_matrix_block, torch.t(combined_representation)))
        scores = torch.nn.functional.relu(scoring + block_scoring + self.bias.view(-1))
        return scores

## SimGNN

In [173]:
import torch
from torch_geometric.nn import GCNConv
from torch_geometric.utils import to_dense_batch

class SimGNN(torch.nn.Module):
    """
    SimGNN: A Neural Network Approach to Fast Graph Similarity Computation
    https://arxiv.org/abs/1808.05689
    """
    def __init__(self, args, number_of_node_labels, number_of_edge_labels):
        """
        :param args: Arguments object.
        :param number_of_labels: Number of node labels.
        """
        super(SimGNN, self).__init__()
        self.args = args
        self.number_node_labels = number_of_node_labels
        # self.number_edge_labels = number_of_edge_labels
        self.setup_layers()    
        
    def calculate_bottleneck_features(self):
        """
        Deciding the shape of the bottleneck layer.
        """
        if self.args.histogram == True:
            self.feature_count = self.args.tensor_neurons + self.args.bins
        else:
            self.feature_count = self.args.tensor_neurons
            
    def setup_layers(self):
        """
        Creating the layers.
        """
        self.calculate_bottleneck_features()
        self.convolution_1 = GCNConv(self.number_node_labels, self.args.filters_1)
        self.convolution_2 = GCNConv(self.args.filters_1, self.args.filters_2)
        self.convolution_3 = GCNConv(self.args.filters_2, self.args.filters_3)
        self.attention = AttentionModule(self.args)
        self.tensor_network = TensorNetworkModule(self.args)
        self.fully_connected_first = torch.nn.Linear(self.feature_count, self.args.bottle_neck_neurons)
        self.scoring_layer = torch.nn.Linear(self.args.bottle_neck_neurons, 1)
        
    def calculate_histogram(self, abstract_features_1, abstract_features_2, batch_1, batch_2):
        """
        Calculate histogram from similarity matrix.
        :param abstract_features_1: Feature matrix for target graphs.
        :param abstract_features_2: Feature matrix for source graphs.
        :param batch_1: Batch vector for source graphs, which assigns each node to a specific example
        :param batch_1: Batch vector for target graphs, which assigns each node to a specific example
        :return hist: Histsogram of similarity scores.
        """
        abstract_features_1, mask_1 = to_dense_batch(abstract_features_1, batch_1)
        abstract_features_2, mask_2 = to_dense_batch(abstract_features_2, batch_2)

        B1, N1, _ = abstract_features_1.size()
        B2, N2, _ = abstract_features_2.size()

        mask_1 = mask_1.view(B1, N1)
        mask_2 = mask_2.view(B2, N2)
        num_nodes = torch.max(mask_1.sum(dim=1), mask_2.sum(dim=1))

        scores = torch.matmul(abstract_features_1, abstract_features_2.permute([0, 2, 1])).detach()

        hist_list = []
        for i, mat in enumerate(scores):
            mat = torch.sigmoid(mat[:num_nodes[i], :num_nodes[i]]).view(-1)
            hist = torch.histc(mat, bins=self.args.bins)
            hist = hist / torch.sum(hist)
            hist = hist.view(1, -1)
            hist_list.append(hist)

        return torch.stack(hist_list).view(-1, self.args.bins)
    
    def convolutional_pass(self, edge_index, features):
        """
        Making convolutional pass.
        :param edge_index: Edge indices.
        :param features: Feature matrix.
        :return features: Absstract feature matrix.
        """
        features = self.convolution_1(features, edge_index)
        features = torch.nn.functional.relu(features)
        features = torch.nn.functional.dropout(features,
                                               p=self.args.dropout,
                                               training=self.training)

        features = self.convolution_2(features, edge_index)
        features = torch.nn.functional.relu(features)
        features = torch.nn.functional.dropout(features,
                                               p=self.args.dropout,
                                               training=self.training)

        features = self.convolution_3(features, edge_index)
        return features
    
    def forward(self, data):
        """
        Forward pass with graphs.
        :param data: Data dictionary.
        :return score: Similarity score.
        """

        # Graph attribute initiallizations
        g1 = data["g1"]
        g2 = data["g2"]

        pooled_features_1, abstract_features_1, batch_1 = self.get_embedding(g1)
        pooled_features_2, abstract_features_2, batch_2 = self.get_embedding(g2)

        # Output scores' vector
        scores = self.tensor_network(pooled_features_1, pooled_features_2)

        if self.args.histogram:
            hist = self.calculate_histogram(abstract_features_1, abstract_features_2, batch_1, batch_2)
            scores = torch.cat((scores, hist), dim=1)

        # Final connected layers
        scores = torch.nn.functional.relu(self.fully_connected_first(scores))
        score = torch.sigmoid(self.scoring_layer(scores)).view(-1)

        return score
    
    def get_embedding(self, g):
        """
        This method implements the first 2 steps of the pipeline for a given graph. It used to be directly implemented
        in the forward() method. It returns intermediate data too, because they are needed from "Strategy 2" module.
        :param g: The graph we want to take the embedding for.
        :return: The embedding of the given graph g, abstract_features(node embeddings) and batch.
        """
        # Graph attribute initiallizations
        edge_index = g.edge_index

        # edge_attr = g.edge_attr
        features = g.x
        batch = g.batch if hasattr(g, 'batch') else torch.tensor((), dtype=torch.long).new_zeros(g.num_nodes)

        # Node embeddings
        abstract_features = self.convolutional_pass(edge_index, features)

        # Graph embeddings
        graph_embedding = self.attention(abstract_features, batch)

        return graph_embedding, abstract_features, batch

## GEDDataset

In [174]:
import os.path as osp
import os
import urllib

def download_url(url, folder, log=True):
    r"""Downloads the content of an URL to a specific folder.

    Args:
        url (string): The url.
        folder (string): The folder.
        log (bool, optional): If :obj:`False`, will not print anything to the
            console. (default: :obj:`True`)
    """
    filename = url.rpartition('/')[2]
    path = osp.join(folder, filename)

    if osp.exists(path):  # pragma: no cover
        if log:
            print('Using exist file', filename)
        return path

    if log:
        print('Downloading', url)
        
    # Create Folder if not exist
    if not os.path.exists(folder):
        os.makedirs(folder)
        
    data = urllib.request.urlopen(url)

    path = path.replace('?', '')
    with open(path, 'wb', ) as f:
        f.write(data.read())

    return path

In [175]:
import os
import os.path as osp
import glob
import pickle

import torch
import torch.nn.functional as F
import networkx as nx
from torch_geometric.data import (InMemoryDataset, Data, extract_zip, extract_tar)
from torch_geometric.utils import to_undirected


class GEDDataset(InMemoryDataset):

    url = 'https://drive.google.com/uc?export=download&id={}'

    datasets = {
        'AIDS700nef': {
            'id': '10czBPJDEzEDI2tq7Z7mkBjLhj55F-a2z',
            'extract': extract_zip,
            'pickle': '1OpV4bCHjBkdpqI6H5Mg0-BqlA2ee2eBW',
        },
        'LINUX': {
            'id': '1nw0RRVgyLpit4V4XFQyDy0pI6wUEXSOI',
            'extract': extract_tar,
            'pickle': '14FDm3NSnrBvB7eNpLeGy5Bz6FjuCSF5v',
        },
        'ALKANE': {
            'id': '1-LmxaWW3KulLh00YqscVEflbqr0g4cXt',
            'extract': extract_tar,
            'pickle': '15BpvMuHx77-yUGYgM27_sQett02HQNYu',
        },
        'IMDBMulti': {
            'id': '12QxZ7EhYA7pJiF4cO-HuE8szhSOWcfST',
            'extract': extract_zip,
            'pickle': '1wy9VbZvZodkixxVIOuRllC-Lp-0zdoYZ',
        },
    }

    types = [
        'O', 'S', 'C', 'N', 'Cl', 'Br', 'B', 'Si', 'Hg', 'I', 'Bi', 'P', 'F',
        'Cu', 'Ho', 'Pd', 'Ru', 'Pt', 'Sn', 'Li', 'Ga', 'Tb', 'As', 'Co', 'Pb',
        'Sb', 'Se', 'Ni', 'Te'
    ]

    def __init__(self,
                 root,
                 name,
                 train=True,
                 transform=None,
                 pre_transform=None,
                 pre_filter=None):
        self.name = name
        assert self.name in self.datasets.keys()
        super(GEDDataset, self).__init__(root, transform, pre_transform, pre_filter)
        path = self.processed_paths[0] if train else self.processed_paths[1]
        self.data, self.slices = torch.load(path)
        path = osp.join(self.processed_dir, '{}_ged.pt'.format(self.name))
        self.ged = torch.load(path)
        path = osp.join(self.processed_dir, '{}_norm_ged.pt'.format(self.name))
        self.norm_ged = torch.load(path)

    @property
    def raw_file_names(self):
        return [osp.join(self.name, s) for s in ['train', 'test']]

    @property
    def processed_file_names(self):
        return ['{}_{}.pt'.format(self.name, s) for s in ['training', 'test']]

    def download(self):
        name = self.datasets[self.name]['id']
        path = download_url(self.url.format(name), self.raw_dir)
        self.datasets[self.name]['extract'](path, self.raw_dir)
        os.unlink(path)

        name = self.datasets[self.name]['pickle']
        path = download_url(self.url.format(name), self.raw_dir)
        os.rename(path, osp.join(self.raw_dir, self.name, 'ged.pickle'))

    def process(self):
        ids, Ns = [], []
        for r_path, p_path in zip(self.raw_paths, self.processed_paths):
            names = glob.glob(osp.join(r_path, '*.gexf'))
            ids.append(sorted([int(osp.basename(i)[:-5]) for i in names]))

            data_list = []
            for i, idx in enumerate(ids[-1]):
                i = i if len(ids) == 1 else i + len(ids[0])
                G = nx.read_gexf(osp.join(r_path, '{}.gexf'.format(idx)))
                mapping = {name: j for j, name in enumerate(G.nodes())}
                G = nx.relabel_nodes(G, mapping)
                Ns.append(G.number_of_nodes())
                edge_index = torch.tensor(list(G.edges)).t().contiguous()
                if edge_index.numel() == 0:
                    edge_index = torch.empty((2, 0), dtype=torch.long)
                edge_index = to_undirected(edge_index, num_nodes=Ns[-1])

                data = Data(edge_index=edge_index, i=i)
                data.num_nodes = Ns[-1]

                if self.name == 'AIDS700nef':
                    x = torch.zeros(data.num_nodes, dtype=torch.long)
                    for node, info in G.nodes(data=True):
                        x[int(node)] = self.types.index(info['type'])
                    data.x = F.one_hot(
                        x, num_classes=len(self.types)).to(torch.float)

                if self.pre_filter is not None and not self.pre_filter(data):
                    continue
                if self.pre_transform is not None:
                    data = self.pre_transform(data)
                data_list.append(data)
            torch.save(self.collate(data_list), p_path)

        assoc = {idx: i for i, idx in enumerate(ids[0])}
        assoc.update({idx: i + len(ids[0]) for i, idx in enumerate(ids[1])})

        path = osp.join(self.raw_dir, self.name, 'ged.pickle')
        mat = torch.full((len(assoc), len(assoc)), float('inf'))
        with open(path, 'rb') as f:
            obj = pickle.load(f)
            xs, ys, gs = [], [], []
            for (x, y), g in obj.items():
                xs += [assoc[x]]
                ys += [assoc[y]]
                gs += [g]
            x, y, g = torch.tensor(xs), torch.tensor(ys), torch.tensor(gs, dtype=torch.float)
            mat[x, y], mat[y, x] = g, g

        path = osp.join(self.processed_dir, '{}_ged.pt'.format(self.name))
        torch.save(mat, path)

        N = torch.tensor(Ns, dtype=torch.float)
        norm_mat = mat / (0.5 * (N.view(-1, 1) + N.view(1, -1)))

        path = osp.join(self.processed_dir, '{}_norm_ged.pt'.format(self.name))
        torch.save(norm_mat, path)

    def __repr__(self):
        return '{}({})'.format(self.name, len(self))

In [176]:
#TODO: in the future
from torch_geometric.data import (InMemoryDataset)

class CustomDataset(InMemoryDataset):
    def __init__(self,
                 root,
                 name,
                 train=True,
                 transform=None,
                 pre_transform=None,
                 pre_filter=None):
        self.name = name
        super(CustomDataset, self).__init__(root, transform, pre_transform, pre_filter)
        path = self.processed_paths[0] if train else self.processed_paths[1]
        print(path)
        # self.data, self.slices = torch.load(path)
        # path = osp.join(self.processed_dir, '{}_ged.pt'.format(self.name))
        # self.ged = torch.load(path)
        # path = osp.join(self.processed_dir, '{}_norm_ged.pt'.format(self.name))
        # self.norm_ged = torch.load(path)
        
    def process(self):
        ids, Ns = [], []
        for r_path, p_path in zip(self.raw_paths, self.processed_paths):
            names = glob.glob(osp.join(r_path, '*.json'))
            ids.append(sorted([int(i.split(os.sep)[-1][:-5]) for i in names]))

            print(ids)
            data_list = []
            for i, idx in enumerate(ids[-1]):
                i = i if len(ids) == 1 else i + len(ids[0])
                G = nx.read_gexf(osp.join(r_path, '{}.gexf'.format(idx)))
        #         mapping = {name: j for j, name in enumerate(G.nodes())}
        #         G = nx.relabel_nodes(G, mapping)
        #         Ns.append(G.number_of_nodes())
        #         edge_index = torch.tensor(list(G.edges)).t().contiguous()
        #         if edge_index.numel() == 0:
        #             edge_index = torch.empty((2, 0), dtype=torch.long)
        #         edge_index = to_undirected(edge_index, num_nodes=Ns[-1])

        #         data = Data(edge_index=edge_index, i=i)
        #         data.num_nodes = Ns[-1]

        #         if self.name == 'AIDS700nef':
        #             x = torch.zeros(data.num_nodes, dtype=torch.long)
        #             for node, info in G.nodes(data=True):
        #                 x[int(node)] = self.types.index(info['type'])
        #             data.x = F.one_hot(
        #                 x, num_classes=len(self.types)).to(torch.float)

        #         if self.pre_filter is not None and not self.pre_filter(data):
        #             continue
        #         if self.pre_transform is not None:
        #             data = self.pre_transform(data)
        #         data_list.append(data)
        #     torch.save(self.collate(data_list), p_path)

        # assoc = {idx: i for i, idx in enumerate(ids[0])}
        # assoc.update({idx: i + len(ids[0]) for i, idx in enumerate(ids[1])})

        # path = osp.join(self.raw_dir, self.name, 'ged.pickle')
        # mat = torch.full((len(assoc), len(assoc)), float('inf'))
        # with open(path, 'rb') as f:
        #     obj = pickle.load(f)
        #     xs, ys, gs = [], [], []
        #     for (x, y), g in obj.items():
        #         xs += [assoc[x]]
        #         ys += [assoc[y]]
        #         gs += [g]
        #     x, y, g = torch.tensor(xs), torch.tensor(ys), torch.tensor(gs, dtype=torch.float)
        #     mat[x, y], mat[y, x] = g, g

        # path = osp.join(self.processed_dir, '{}_ged.pt'.format(self.name))
        # torch.save(mat, path)

        # N = torch.tensor(Ns, dtype=torch.float)
        # norm_mat = mat / (0.5 * (N.view(-1, 1) + N.view(1, -1)))

        # path = osp.join(self.processed_dir, '{}_norm_ged.pt'.format(self.name))
        # torch.save(norm_mat, path)
    
    @property
    def raw_file_names(self):
        return [osp.join(self.name, s) for s in ['train', 'test']]
    
    @property
    def processed_file_names(self):
        return ['{}_{}.pt'.format(self.name, s) for s in ['training', 'test']]
    
    def __repr__(self):
        return '{}({})'.format(self.name, len(self))
    
# TRAIN = CustomDataset("datasets/extrasmall", "extrasmall", train=True)
# TRAIN

## Trainer

In [177]:
from torch_geometric.loader import DataLoader
from torch_geometric.data import Batch
from torch_geometric.utils import degree
from torch_geometric.transforms import OneHotDegree
from tqdm import tqdm, trange
from scipy.stats import spearmanr, kendalltau

class SimGNNTrainer(object):
    """
    SimGNN model trainer.
    """

    def __init__(self, args):
        """
        :param args: Arguments object.
        """
        self.args = args
        self.process_dataset()
        self.model, self.optimizer = self.setup_model()
        
    def setup_model(self):
        """
        Creating a SimGNN.
        """
        model = SimGNN(self.args, self.number_of_node_labels, self.number_of_edge_labels).to(device)

        optimizer = torch.optim.Adam(model.parameters(), lr=self.args.learning_rate, weight_decay=self.args.weight_decay)

        return model, optimizer
    
    def process_dataset(self):
        """
        Downloading and processing dataset.
        """
        print("\nPreparing dataset.\n")
        self.training_graphs = GEDDataset(f'datasets/{self.args.dataset}', self.args.dataset, train=True)
        self.testing_graphs = GEDDataset(f'datasets/{self.args.dataset}', self.args.dataset, train=False)
        
        self.ged_matrix = self.training_graphs.ged
        self.nged_matrix = self.training_graphs.norm_ged

        if self.training_graphs[0].x is None:
            max_degree = 0
            for g in self.training_graphs + self.testing_graphs:
                if g.edge_index.size(1) > 0:
                    max_degree = max(max_degree, int(degree(g.edge_index[0]).max().item()))
            one_hot_degree = OneHotDegree(max_degree, cat=False)
            self.training_graphs.transform = one_hot_degree
            self.testing_graphs.transform = one_hot_degree

        self.number_of_node_labels = self.training_graphs.num_features
        self.number_of_edge_labels = self.training_graphs.num_edge_features

    def create_train_batches(self, train_set):
            """
            Creating suffled batches from the training graph list.
            :return batches: Zipped loaders as list.
            """
            source_loader = DataLoader(dataset=train_set.shuffle(), batch_size=self.args.batch_size)
            target_loader = DataLoader(dataset=train_set.shuffle(), batch_size=self.args.batch_size)

            return list(zip(source_loader, target_loader))

    def transform(self, data):
        """
        Getting ground truth GED for graph pairs and grouping as data into dictionary.
        :param data: Graph pair - tuple list for 1 target graph with any number of source graphs.
        :return new_data: Dictionary with data. Data contain the source graphs as g1, the target graph as g2 and the
        normalized ged (exponentiated) for each pair as target.
        """
        new_data = dict()

        new_data["g1"] = data[0].to(device)
        new_data["g2"] = data[1].to(device)

        # for each g1 and g2, access the lists with the graphs' index 'i', and retrieve the respective ged value.
        normalized_ged = self.nged_matrix[data[0]["i"].reshape(-1).tolist(), data[1]["i"].reshape(-1).tolist()].tolist()
        new_data["target"] = torch.from_numpy(np.exp([(-el) for el in normalized_ged])).view(-1).float().to(device)

        return new_data

    def process_batch(self, data):
        """
        Performs the forward pass with a batch of data.
        :param data: Data that is essentially pairs of batches, for source and target graphs.
        :return loss: Loss on the data.
        """
        self.optimizer.zero_grad()
        data = self.transform(data)
        prediction = self.model(data)
        loss = torch.nn.functional.mse_loss(prediction, data["target"], reduction='sum')
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def fit(self):
            """
            Training a model.
            """
            self.model.train()

            train_set = self.training_graphs
            e = self.args.epochs
            epochs = trange(e, leave=None, desc="Epoch", position=0)

            loss_list = []

            for epoch in epochs:
                batches = self.create_train_batches(train_set)
                main_index = 0
                loss_sum = 0
                # TODO: fix the printing
                # Batches progress bar is deactivated until it is fixed. Relevant to "nested bars issue"
                #_batches = tqdm(enumerate(batches), total=len(batches), desc="Batches", position=1)
                for index, batch_pair in enumerate(batches):
                    loss_score = self.process_batch(batch_pair)
                    main_index = main_index + batch_pair[0].num_graphs
                    loss_sum = loss_sum + loss_score
                    #_batches.update()

                loss = loss_sum / main_index
                loss_list.append(loss)

                epochs.set_description("Epoch (Loss=%g)" % round(loss, 5))
                
    def globalScore(self):
        """
        Simple scoring
        """
        print("\n\nModel evaluation.\n")
        self.model = self.model.cpu()
        self.model.eval()

        # tests.debugModelStateDict(self.model)

        # filled with exponentiated normalized ged values
        self.norm_ground_truth = np.empty((len(self.testing_graphs), len(self.training_graphs)))
        self.norm_prediction_mat = np.empty((len(self.testing_graphs), len(self.training_graphs)))

        self.scores = np.empty((len(self.testing_graphs), len(self.training_graphs)))
        self.rho_list = []
        self.tau_list = []
        self.prec_at_10_list = []
        self.prec_at_20_list = []

        t = tqdm(total=len(self.testing_graphs) * len(self.training_graphs))

        for i, g in enumerate(self.testing_graphs):
            # source batch is a batch with 1 test graph repeated multiple times
            source_batch = Batch.from_data_list([g] * len(self.training_graphs))
            # target batch is a batch with every training graph
            target_batch = Batch.from_data_list(self.training_graphs)

            # Get ground truth
            data = self.transform((source_batch, target_batch))
            data["g1"] = data["g1"].cpu()
            data["g2"] = data["g2"].cpu()
            target = data["target"].cpu()
            self.norm_ground_truth[i] = target

            # Get prediction
            prediction = self.model(data)
            self.norm_prediction_mat[i] = prediction.detach().numpy()

            # Update metrics
            self.scores[i] = torch.nn.functional.mse_loss(prediction, target, reduction='none').detach().numpy()
            self.rho_list.append(calculate_ranking_correlation(spearmanr, self.norm_prediction_mat[i], self.norm_ground_truth[i]))
            self.tau_list.append(calculate_ranking_correlation(kendalltau, self.norm_prediction_mat[i], self.norm_ground_truth[i]))
            self.prec_at_10_list.append(calculate_prec_at_k(10, self.norm_prediction_mat[i], self.norm_ground_truth[i]))
            self.prec_at_20_list.append(calculate_prec_at_k(20, self.norm_prediction_mat[i], self.norm_ground_truth[i]))

            t.update(len(self.training_graphs))

        # Calculate metrics
        self.model_error = np.mean(self.scores)
        self.rho = np.mean(self.rho_list)
        self.tau = np.mean(self.tau_list)
        self.prec_at_10 = np.mean(self.prec_at_10_list)
        self.prec_at_20 = np.mean(self.prec_at_20_list)

        paperMetrics = [self.model_error, self.rho, self.tau, self.prec_at_10, self.prec_at_20]
        
        print("\nmse(10^-3): " + str(round(paperMetrics[0] * 1000, 5)) + ".")
        print("Spearman's rho: " + str(round(paperMetrics[1], 5)) + ".")
        print("Kendall's tau: " + str(round(paperMetrics[2], 5)) + ".")
        print("p@10: " + str(round(paperMetrics[3], 5)) + ".")
        print("p@20: " + str(round(paperMetrics[4], 5)) + ".")

        # self.presentGraphsWithGEDs(entireDataset=False, withPrediction=True)


## Main

In [178]:
import time

def main(args : list):
    """
    Parsing command line parameters, reading data.
    Fitting and scoring a SimGNN model.
    """
    args = parameter_parser(args)
    tab_printer(args)
    trainer = SimGNNTrainer(args)
    
    
    use_pretrained = args.load_path != None  # To determine whether a pre-trained model will be used.
    
    if use_pretrained is True:
        print("Pre-trained mode: load an already fit state instead of training.")
        checkpoint = torch.load(args.load_path)
        trainer.model.load_state_dict(checkpoint['model_state_dict'])
        trainer.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    else:
        #TRAIN
        start = time.time()
        trainer.fit()
        end = time.time()
        print(f"\nTotal Training Time: {end-start}")
        if args.save_path != None:
            torch.save({
                'model_state_dict': trainer.model.state_dict(),
                'optimizer_state_dict': trainer.optimizer.state_dict()
            }, args.save_path)
    
    #TEST
    trainer.globalScore()

## Execution

In [179]:
main(args=[
    "--dataset=AIDS700nef",
    "--save-path=AIDS700nef.pt",
    "--epochs=100",
    "--batch-size=128",
    "--histogram"
])

Epoch:   0%|          | 0/100 [00:00<?, ?it/s]

Epoch (Loss=0.02128):   0%|          | 0/100 [00:00<?, ?it/s]


Preparing dataset.



Epoch (Loss=0.00589): 100%|██████████| 100/100 [00:12<00:00,  8.06it/s]
  1%|▏         | 1120/78400 [00:00<00:09, 8009.23it/s]


Total Training Time: 12.40739393234253


Model evaluation.



100%|██████████| 78400/78400 [00:10<00:00, 7639.75it/s]


mse(10^-3): 5.79031.
Spearman's rho: 0.65839.
Kendall's tau: 0.49844.
p@10: 0.79786.
p@20: 0.79643.





In [180]:
with open("datasets/AIDS700nef/raw/AIDS700nef/ged.pickle", 'rb') as f:
    obj = pickle.load(f)
    for (x, y), g in obj.items():
        if x == 103:
            print((x, y), g)