[Reference](https://medium.com/stanford-cs224w/self-supervised-learning-for-graphs-963e03b9f809)

# 1. Edge Perturbation

In [1]:
class EdgePerturbation():
    """
    Edge perturbation on a single graph. Apply it to a graph via :meth:`do_trans`.

    Args:
        add (bool, optional): Set :obj:`True` to randomly add edges to the given graph.
            (default: :obj:`True`)
        drop (bool, optional): Set :obj:`True` to randomly drop edges from the given graph.
            (default: :obj:`False`)
        ratio (float, optional): Fraction of edges to add and/or drop. (default: :obj:`0.1`)
    """

    def __init__(self, add=True, drop=False, ratio=0.1):
        self.add = add
        self.drop = drop
        self.ratio = ratio

    def do_trans(self, data):
        """
        Return a perturbed copy of ``data``.

        ``perturb_num = int(num_edges * ratio)`` edge columns are dropped uniformly
        without replacement when ``drop`` is set, and the same number of random node
        pairs are appended when ``add`` is set. Duplicate columns are then removed with
        ``torch.unique(dim=1)``, which also sorts the edge columns.

        Args:
            data: Graph exposing ``x`` (first dimension = number of nodes) and
                ``edge_index`` of shape ``[2, num_edges]``.

        Returns:
            Data: New graph with the original ``x`` and the perturbed ``edge_index``.
            Other attributes of ``data`` are not copied.
        """
        node_num = data.x.size(0)
        edge_num = data.edge_index.size(1)
        perturb_num = int(edge_num * self.ratio)

        edge_index = data.edge_index.detach().clone()
        device = edge_index.device

        idx_remain = edge_index
        # Allocate on edge_index's device so the torch.cat below also works for GPU graphs.
        idx_add = torch.empty((2, 0), dtype=torch.long, device=device)

        if self.drop:
            keep = np.random.choice(edge_num, edge_num - perturb_num, replace=False)
            idx_remain = edge_index[:, torch.as_tensor(keep, dtype=torch.long, device=device)]

        if self.add:
            # Draw on the CPU generator (as before) and then move, so seeded runs are unchanged.
            idx_add = torch.randint(node_num, (2, perturb_num)).to(device)

        new_edge_index = torch.cat((idx_remain, idx_add), dim=1)
        new_edge_index = torch.unique(new_edge_index, dim=1)

        return Data(x=data.x, edge_index=new_edge_index)

# 2. Diffusion

In [2]:
class Diffusion():
    """
    Graph diffusion on a single graph, as used in
    `MVGRL <https://arxiv.org/pdf/2006.05582v1.pdf>`_. Apply it via :meth:`do_trans`.

    Every non-zero entry of the diffusion matrix becomes an edge of the returned
    graph, with the diffusion value stored in ``edge_attr``.

    Args:
        mode (string, optional): Diffusion instantiation mode with two options:
            :obj:`"ppr"`: Personalized PageRank; :obj:`"heat"`: heat kernel.
            (default: :obj:`"ppr"`)
        alpha (float, optional): Teleport probability in a random walk (PPR mode).
            (default: :obj:`0.2`)
        t (float, optional): Diffusion time (heat mode). (default: :obj:`5`)
        add_self_loop (bool, optional): Set True to add a self-loop to every node
            before building the adjacency matrix. (default: :obj:`True`)
    """

    def __init__(self, mode="ppr", alpha=0.2, t=5, add_self_loop=True):
        self.mode = mode
        self.alpha = alpha
        self.t = t
        self.add_self_loop = add_self_loop

    def do_trans(self, data):
        """
        Return a new graph whose weighted edges are the diffusion matrix of ``data``.

        PPR:  ``S = alpha * (I - (1 - alpha) * D^-1/2 A D^-1/2)^-1``
        Heat: ``S = expm(t * (A D^-1 - I))`` (matrix exponential)

        Args:
            data: Graph exposing ``x`` (first dimension = number of nodes) and ``edge_index``.

        Returns:
            Data: Graph with the original ``x``, ``edge_index`` listing the non-zero
            entries of ``S`` and ``edge_attr`` holding their values.

        Raises:
            ValueError: If ``mode`` is neither ``"ppr"`` nor ``"heat"``.
            ``torch.inverse`` fails on the singular degree matrix if some node has
            no edges (only possible when ``add_self_loop`` is False).
        """
        node_num = data.x.size(0)
        device = data.edge_index.device
        if self.add_self_loop:
            loop = torch.arange(node_num, device=device)
            edge_index = torch.cat((data.edge_index, torch.stack([loop, loop])), dim=1)
        else:
            edge_index = data.edge_index.detach().clone()

        # max_num_nodes keeps the matrix node_num x node_num even if trailing nodes have no edges.
        orig_adj = to_dense_adj(edge_index, max_num_nodes=node_num)[0]
        # Collapse multi-edges to a binary adjacency matrix.
        orig_adj = torch.where(orig_adj > 1, torch.ones_like(orig_adj), orig_adj)
        d = torch.diag(torch.sum(orig_adj, 1))
        eye = torch.eye(node_num, dtype=orig_adj.dtype, device=orig_adj.device)

        if self.mode == "ppr":
            dinv = torch.inverse(torch.sqrt(d))
            at = torch.matmul(torch.matmul(dinv, orig_adj), dinv)
            diff_adj = self.alpha * torch.inverse(eye - (1 - self.alpha) * at)

        elif self.mode == "heat":
            # Heat kernel is the *matrix* exponential of t * (A D^-1 - I); an elementwise
            # exp of (A D^-1 - 1) is a different quantity.
            transition = torch.matmul(orig_adj, torch.inverse(d))
            diff_adj = torch.matrix_exp(self.t * (transition - eye))

        else:
            raise ValueError("Must choose one diffusion instantiation mode from 'ppr' and 'heat'!")

        edge_ind, edge_attr = dense_to_sparse(diff_adj)

        return Data(x=data.x, edge_index=edge_ind, edge_attr=edge_attr)

# 3. Node Dropping

In [3]:
class UniformSample():
    """
    Uniform node dropping on a single graph. Apply it via :meth:`do_trans`.

    Args:
        ratio (float, optional): Ratio of nodes to be dropped. (default: :obj:`0.1`)
    """

    def __init__(self, ratio=0.1):
        self.ratio = ratio

    def do_trans(self, data):
        """
        Keep ``int(num_nodes * (1 - ratio))`` nodes chosen uniformly at random.

        Args:
            data: Graph exposing ``x`` (first dimension = number of nodes) and ``edge_index``.

        Returns:
            Data: Induced subgraph on the kept nodes, with node ids relabelled to
            ``0..keep_num-1`` and ``x`` restricted to the kept rows.
        """
        node_num = data.x.size(0)
        device = data.x.device

        keep_num = int(node_num * (1 - self.ratio))
        # Permute on the CPU generator (as before), then move the indices next to x.
        idx_nondrop = torch.randperm(node_num)[:keep_num].to(device)
        # Build the boolean mask directly on x's device (the old CPU-index scatter_ broke on GPU).
        mask_nondrop = torch.zeros(node_num, dtype=torch.bool, device=device)
        mask_nondrop[idx_nondrop] = True

        edge_index, _ = subgraph(mask_nondrop, data.edge_index, relabel_nodes=True, num_nodes=node_num)
        return Data(x=data.x[mask_nondrop], edge_index=edge_index)

# 4. Random Walk based Sampling

In [4]:
class RWSample():
    """
    Subgraph sampling by random neighbourhood expansion on a single graph.
    Apply it via :meth:`do_trans`.

    Starting from a uniformly random seed node, the sampler repeatedly picks a node
    uniformly at random among the not-yet-sampled neighbours of all nodes sampled so
    far. It stops once ``max(int(num_nodes * ratio), 1)`` nodes are collected or the
    seed's reachable set is exhausted.

    Args:
        ratio (float, optional): Fraction of nodes to sample from the graph.
            (default: :obj:`0.1`)
        add_self_loop (bool, optional): Set True to add self-loops to edge_index before
            the neighbour lookup. Self-loops only point back at already-sampled nodes,
            which are excluded from the frontier. (default: :obj:`False`)
    """

    def __init__(self, ratio=0.1, add_self_loop=False):
        self.ratio = ratio
        self.add_self_loop = add_self_loop

    @staticmethod
    def _sample_nodes(edge_index, node_num, sub_num):
        """Return a list of distinct node ids grown from a random seed along edge_index[0] -> edge_index[1]."""
        src, dst = edge_index[0], edge_index[1]

        def neighbours(node):
            return set(dst[src == node].tolist())

        seed = int(np.random.randint(node_num))
        sampled = [seed]
        visited = {seed}
        frontier = neighbours(seed) - visited
        target = max(sub_num, 1)

        while len(sampled) < target and frontier:
            # sorted() gives a stable candidate order for the numpy draw.
            node = int(np.random.choice(sorted(frontier)))
            sampled.append(node)
            visited.add(node)
            # Grow the frontier in place (set.union would return a discarded copy)
            # and drop everything already sampled so no draw is wasted.
            frontier |= neighbours(node)
            frontier -= visited
        return sampled

    def do_trans(self, data):
        """
        Return the induced subgraph on the sampled nodes.

        Args:
            data: Graph exposing ``x`` (first dimension = number of nodes) and ``edge_index``.

        Returns:
            Data: Subgraph with node ids relabelled to ``0..k-1`` and ``x`` restricted
            to the sampled rows.
        """
        node_num = data.x.size(0)
        sub_num = int(node_num * self.ratio)
        device = data.edge_index.device

        if self.add_self_loop:
            loop = torch.arange(node_num, device=device)
            edge_index = torch.cat((data.edge_index, torch.stack([loop, loop])), dim=1)
        else:
            edge_index = data.edge_index

        idx_sub = torch.tensor(self._sample_nodes(edge_index, node_num, sub_num),
                               dtype=torch.long, device=data.x.device)
        mask_nondrop = torch.zeros(node_num, dtype=torch.bool, device=data.x.device)
        mask_nondrop[idx_sub] = True

        edge_index, _ = subgraph(mask_nondrop, data.edge_index, relabel_nodes=True, num_nodes=node_num)
        return Data(x=data.x[mask_nondrop], edge_index=edge_index)

# 5. Node Attribute Masking

In [5]:
class NodeAttrMask():
    """
    Node attribute masking on a single graph. Apply it via :meth:`do_trans`.

    Args:
        mode (string, optional): Masking mode with three options:
            :obj:`"whole"`: replace all feature dimensions of the selected nodes with
            samples from a Gaussian distribution;
            :obj:`"partial"`: replace each individual feature entry, with probability
            :obj:`mask_ratio`, by a Gaussian sample;
            :obj:`"onehot"`: replace all feature dimensions of the selected nodes with
            a random one-hot vector.
            (default: :obj:`"whole"`)
        mask_ratio (float, optional): In :obj:`"whole"` / :obj:`"onehot"` mode, the fraction
            of nodes to mask (``int(num_nodes * mask_ratio)`` nodes); in :obj:`"partial"`
            mode, the per-entry masking probability. (default: :obj:`0.1`)
        mask_mean (float, optional): Mean of the Gaussian distribution to generate masking values.
            (default: :obj:`0.5`)
        mask_std (float, optional): Standard deviation of the distribution to generate masking values.
            Must be non-negative. (default: :obj:`0.5`)
        return_mask (bool, optional): If True, the returned graph also carries ``mask``:
            shape ``[num_nodes]`` marking masked nodes, or ``[num_nodes, feat_dim]``
            marking masked entries in :obj:`"partial"` mode. (default: :obj:`False`)
    """

    def __init__(self, mode='whole', mask_ratio=0.1, mask_mean=0.5, mask_std=0.5, return_mask=False):
        self.mode = mode
        self.mask_ratio = mask_ratio
        self.mask_mean = mask_mean
        self.mask_std = mask_std
        self.return_mask = return_mask

    def _mask_features(self, x):
        """Return ``(masked_x, mask)`` for a 2-D feature matrix ``x``; ``x`` itself is not modified."""
        node_num, feat_dim = x.size()
        x = x.detach().clone()
        device = x.device

        if self.mode == "whole":
            mask = torch.zeros(node_num, device=device)
            mask_num = int(node_num * self.mask_ratio)
            idx_mask = torch.as_tensor(np.random.choice(node_num, mask_num, replace=False),
                                       dtype=torch.long, device=device)
            noise = np.random.normal(loc=self.mask_mean, scale=self.mask_std, size=(mask_num, feat_dim))
            # Index assignment needs matching dtype and device.
            x[idx_mask] = torch.as_tensor(noise, dtype=x.dtype, device=device)
            mask[idx_mask] = 1

        elif self.mode == "partial":
            # One Bernoulli(mask_ratio) draw per entry, vectorised instead of a Python double loop.
            selected = torch.as_tensor(np.random.random((node_num, feat_dim)) < self.mask_ratio,
                                       device=device)
            noise = np.random.normal(loc=self.mask_mean, scale=self.mask_std, size=int(selected.sum()))
            x[selected] = torch.as_tensor(noise, dtype=x.dtype, device=device)
            mask = torch.zeros((node_num, feat_dim), device=device)
            mask[selected] = 1

        elif self.mode == "onehot":
            mask = torch.zeros(node_num, device=device)
            mask_num = int(node_num * self.mask_ratio)
            idx_mask = torch.as_tensor(np.random.choice(node_num, mask_num, replace=False),
                                       dtype=torch.long, device=device)
            hot = torch.as_tensor(np.random.randint(0, feat_dim, size=mask_num),
                                  dtype=torch.long, device=device)
            x[idx_mask] = torch.eye(feat_dim, dtype=x.dtype, device=device)[hot]
            mask[idx_mask] = 1

        else:
            # Previously referenced an undefined name `mode`, raising NameError instead.
            raise ValueError("Masking mode option '{0}' is not available!".format(self.mode))

        return x, mask

    def do_trans(self, data):
        """
        Return a copy of ``data`` with masked node features.

        Args:
            data: Graph exposing a 2-D ``x`` and ``edge_index``.

        Returns:
            Data: Graph with masked ``x``, the original ``edge_index`` and, if
            ``return_mask`` is set, the ``mask`` tensor.

        Raises:
            ValueError: If ``mode`` is not one of ``"whole"``, ``"partial"``, ``"onehot"``.
        """
        x, mask = self._mask_features(data.x)

        if self.return_mask:
            return Data(x=x, edge_index=data.edge_index, mask=mask)
        return Data(x=x, edge_index=data.edge_index)

# Graph Neural Networks


In [13]:
import torch.nn as nn
from torch_geometric.nn import SAGEConv

class GraphSAGE(nn.Module):
    """
    Stack of ``n_layers`` SAGEConv layers, each followed by a ReLU.

    Layer 0 maps ``feat_dim`` -> ``hidden_dim``; every later layer maps
    ``hidden_dim`` -> ``hidden_dim``. All layers share one ``nn.ReLU`` instance.
    """

    def __init__(self, feat_dim, hidden_dim, n_layers):
        super().__init__()
        relu = nn.ReLU()
        self.convs = nn.ModuleList(
            [SAGEConv(feat_dim if layer == 0 else hidden_dim, hidden_dim) for layer in range(n_layers)]
        )
        self.acts = nn.ModuleList([relu for _ in range(n_layers)])
        self.n_layers = n_layers

    def forward(self, data):
        """Run all layers; ``data`` is unpacked as ``(x, edge_index, batch)`` and ``batch`` is unused."""
        x, edge_index, _batch = data
        for conv, act in zip(self.convs, self.acts):
            x = act(conv(x, edge_index))
        return x

In [14]:
import torch.nn as nn
from torch_geometric.nn import GCNConv

class GCN(nn.Module):
    """
    Stack of ``n_layers`` GCNConv layers, each followed by a ReLU.

    Layer 0 maps ``feat_dim`` -> ``hidden_dim``; every later layer maps
    ``hidden_dim`` -> ``hidden_dim``. All layers share one ``nn.ReLU`` instance.
    """

    def __init__(self, feat_dim, hidden_dim, n_layers):
        super().__init__()
        relu = nn.ReLU()
        self.convs = nn.ModuleList(
            [GCNConv(feat_dim if layer == 0 else hidden_dim, hidden_dim) for layer in range(n_layers)]
        )
        self.acts = nn.ModuleList([relu for _ in range(n_layers)])
        self.n_layers = n_layers

    def forward(self, data):
        """Run all layers; ``data`` is unpacked as ``(x, edge_index, batch)`` and ``batch`` is unused."""
        x, edge_index, _batch = data
        for conv, act in zip(self.convs, self.acts):
            x = act(conv(x, edge_index))
        return x

In [15]:
import torch.nn as nn
from torch_geometric.nn import GATConv

class GAT(nn.Module):
    """
    Stack of ``n_layers`` GATConv layers (``heads`` heads, ``concat=False``), each
    followed by a LeakyReLU.

    Layer 0 maps ``feat_dim`` -> ``hidden_dim``; every later layer maps
    ``hidden_dim`` -> ``hidden_dim``, so ``concat=False`` keeps each layer's output
    at ``hidden_dim``. All layers share one ``nn.LeakyReLU`` instance.
    """

    def __init__(self, feat_dim, hidden_dim, n_layers, heads):
        super().__init__()
        leaky = nn.LeakyReLU()
        self.convs = nn.ModuleList(
            [
                GATConv(feat_dim if layer == 0 else hidden_dim, hidden_dim, heads=heads, concat=False)
                for layer in range(n_layers)
            ]
        )
        self.acts = nn.ModuleList([leaky for _ in range(n_layers)])
        self.n_layers = n_layers

    def forward(self, data):
        """Run all layers; ``data`` is unpacked as ``(x, edge_index, batch)`` and ``batch`` is unused."""
        x, edge_index, _batch = data
        for conv, act in zip(self.convs, self.acts):
            x = act(conv(x, edge_index))
        return x

In [16]:
import torch.nn as nn
from torch_geometric.nn import GINConv

class GIN(nn.Module):
    """
    Stack of ``n_layers`` GINConv layers, each followed by a ReLU.

    Each GINConv wraps a two-layer MLP ``Linear -> ReLU -> Linear``; the first MLP
    takes ``feat_dim`` inputs, later ones take ``hidden_dim``. The same ``nn.ReLU``
    instance is used inside every MLP and between layers.
    """

    def __init__(self, feat_dim, hidden_dim, n_layers):
        super().__init__()
        relu = nn.ReLU()
        layers = []
        for depth in range(n_layers):
            in_dim = hidden_dim if depth else feat_dim
            layers.append(GINConv(self._make_mlp(in_dim, hidden_dim, relu)))
        self.convs = nn.ModuleList(layers)
        self.n_layers = n_layers
        self.act = relu

    @staticmethod
    def _make_mlp(in_dim, out_dim, act):
        """Build the per-layer MLP ``Linear(in_dim, out_dim) -> act -> Linear(out_dim, out_dim)``."""
        return nn.Sequential(
            nn.Linear(in_dim, out_dim),
            act,
            nn.Linear(out_dim, out_dim),
        )

    def forward(self, data):
        """Run all layers; ``data`` is unpacked as ``(x, edge_index, batch)`` and ``batch`` is unused."""
        x, edge_index, _batch = data
        for conv in self.convs:
            x = self.act(conv(x, edge_index))
        return x

In [17]:
import torch.nn as nn
from torch_geometric.nn import SGConv

class SGC(nn.Module):
    """Single SGConv layer with ``K = n_layers`` propagation steps, followed by a ReLU."""

    def __init__(self, feat_dim, hidden_dim, n_layers):
        super().__init__()
        self.conv = SGConv(feat_dim, hidden_dim, K=n_layers)
        self.act = nn.ReLU()

    def forward(self, data):
        """Apply the layer; ``data`` is unpacked as ``(x, edge_index, batch)`` and ``batch`` is unused."""
        x, edge_index, _batch = data
        return self.act(self.conv(x, edge_index))

In [18]:
import torch

def infonce(readout_anchor, readout_positive, tau=0.5, norm=True):
    """
    The InfoNCE (NT-XENT) loss in contrastive learning, following
    `A Simple Framework for Contrastive Learning of Visual Representations
    <https://arxiv.org/abs/2002.05709>`.

    For each row i the loss is ``-log(exp(s_ii / tau) / sum_{j != i} exp(s_ij / tau))``,
    averaged over the batch. The positive pair is excluded from the denominator.
    It is evaluated in log space (``logsumexp``), so large unnormalized similarities
    do not overflow ``exp`` into ``inf / inf = nan``.

    Args:
        readout_anchor, readout_positive: Tensors of shape [batch_size, feat_dim];
            row i of each forms a positive pair.
        tau: Float temperature. Usually in (0, 1].
        norm: Boolean. Whether to use cosine similarity (True) instead of the raw dot product.

    Returns:
        Scalar tensor with the mean loss. With batch_size 1 there are no negatives
        and the result is ``-inf``.
    """
    batch_size = readout_anchor.shape[0]
    sim_matrix = torch.einsum("ik,jk->ij", readout_anchor, readout_positive)

    if norm:
        anchor_abs = readout_anchor.norm(dim=1)
        positive_abs = readout_positive.norm(dim=1)
        sim_matrix = sim_matrix / torch.einsum("i,j->ij", anchor_abs, positive_abs)

    logits = sim_matrix / tau
    pos_logits = torch.diagonal(logits)
    # Mask out the positive so the denominator sums over negatives only.
    eye = torch.eye(batch_size, dtype=torch.bool, device=logits.device)
    neg_logits = logits.masked_fill(eye, float("-inf"))
    loss = -(pos_logits - torch.logsumexp(neg_logits, dim=1))
    return loss.mean()

In [19]:
import torch
import numpy as np
import torch.nn.functional as F

def get_expectation(masked_d_prime, positive=True):
    """
    Jensen-Shannon expectation term for a (masked) score matrix.

    Positive pairs score ``log 2 - softplus(-d)``; negative pairs score
    ``softplus(-d) + d - log 2``. Both are exactly 0 at ``d = 0``.

    Args:
        masked_d_prime: Tensor of shape [n_graphs, n_graphs] for global_global,
                        tensor of shape [n_nodes, n_graphs] for local_global.
        positive (bool): True if the scores belong to positive pairs,
                        False for negative pairs.
    """
    shift = np.log(2.)
    softplus_neg = F.softplus(-masked_d_prime)
    if not positive:
        return softplus_neg + masked_d_prime - shift
    return shift - softplus_neg

def jensen_shannon(readout_anchor, readout_positive):
    """
    The Jensen-Shannon Estimator of Mutual Information used in contrastive learning. The
    implementation follows the paper `Learning deep representations by mutual information
    estimation and maximization <https://arxiv.org/abs/1808.06670>`.

    Diagonal entries of ``anchor @ positive.T`` are positive pairs (averaged over
    ``batch_size``). Off-diagonal entries are negatives (averaged over
    ``batch_size * (batch_size - 1)``). Returns ``E_neg - E_pos``.

    Note: The JSE loss implementation can produce negative values because a :obj:`-2log2` shift is
        added to the computation of JSE, for the sake of consistency with other f-divergence losses.
        ``batch_size`` must be at least 2; with 1 the negative term is 0/0 (nan).

    Args:
        readout_anchor, readout_positive: Tensor of shape [batch_size, feat_dim].
    """
    batch_size = readout_anchor.shape[0]
    d_prime = torch.matmul(readout_anchor, readout_positive.t())

    # Boolean diagonal mask on d_prime's device: avoids the CPU/GPU mismatch of the old
    # float masks and the O(batch_size^2) Python loop that built them.
    pos_mask = torch.eye(batch_size, dtype=torch.bool, device=d_prime.device)

    e_pos = get_expectation(d_prime[pos_mask], positive=True).sum() / batch_size
    e_neg = get_expectation(d_prime[~pos_mask], positive=False).sum() / (batch_size * (batch_size - 1))
    return e_neg - e_pos

In [20]:
import os
import torch
import torch.nn as nn

class GraphClassificationModel(nn.Module):
    """
    Model for graph classification: a GNN encoder followed by a linear classifier.

    Args:
        feat_dim (int): The dimension of input node features.
        hidden_dim (int): The dimension of node-level (local) embeddings.
        n_layers (int): The number of GNN layers in the encoder.
        output_dim (int): The number of scores produced by the classifier (e.g. classes).
        gnn (string): The type of GNN layer, :obj:`gcn` or :obj:`gin` or :obj:`gat`
            or :obj:`graphsage` or :obj:`resgcn` or :obj:`sgc`.
        load (string, optional): Name of a run directory under ``logs/`` whose
            ``best_model.ckpt`` stores pretrained SSL encoder weights under the key
            ``"state"``. When given, the encoder is initialized from it and frozen, so
            only the classifier head is trained. Otherwise, encoder and classifier head
            are trained end-to-end. (default: :obj:`None`)
    """

    def __init__(self, feat_dim, hidden_dim, n_layers, output_dim, gnn, load=None):
        super(GraphClassificationModel, self).__init__()

        # Encoder is a wrapper class for easy instantiation of pre-implemented graph encoders.
        self.encoder = Encoder(feat_dim, hidden_dim, n_layers=n_layers, gnn=gnn)

        if load:
            ckpt_path = os.path.join("logs", load, "best_model.ckpt")
            # map_location="cpu" lets GPU-saved checkpoints load on CPU-only hosts;
            # load_state_dict then copies the values into the encoder's own tensors.
            # torch.load unpickles the file, so only load trusted checkpoints.
            ckpt = torch.load(ckpt_path, map_location="cpu")
            self.encoder.load_state_dict(ckpt["state"])
            for param in self.encoder.parameters():
                param.requires_grad = False

        # NOTE(review): assumes resgcn/sgc encoders emit hidden_dim-sized graph embeddings
        # while the others concatenate all layers (n_layers * hidden_dim) -- confirm in Encoder.
        if gnn in ["resgcn", "sgc"]:
            embed_dim = hidden_dim
        else:
            embed_dim = n_layers * hidden_dim
        self.classifier = nn.Linear(embed_dim, output_dim)

    def forward(self, data):
        """Return classifier scores for the encoder's embeddings of ``data``."""
        embeddings = self.encoder(data)
        scores = self.classifier(embeddings)
        return scores

In [21]:
import torch
import random
import torch.nn.functional as F
from torch_geometric.utils import degree
from torch_geometric.datasets import TUDataset

# Fractions of the dataset for the train / val / test splits, in that order. split_dataset
# reads the first two entries; the test set is whatever remains after train and val.
DATA_SPLIT = [0.7, 0.2, 0.1] # Train / val / test split ratio

def get_max_deg(dataset):
    """
    Find the max out-degree across all nodes in all graphs.

    The out-degree of a node is how often its id appears in ``edge_index[0]``.
    Graphs with no nodes are skipped, and an empty dataset yields 0.

    Args:
        dataset: Iterable of graphs exposing ``edge_index`` and ``num_nodes``.

    Returns:
        int: The largest out-degree found.
    """
    max_deg = 0
    for data in dataset:
        if not data.num_nodes:
            continue
        # Same counts as torch_geometric.utils.degree(row, num_nodes), but integral
        # and reduced with tensor.max() instead of Python's max over elements.
        deg = torch.bincount(data.edge_index[0], minlength=data.num_nodes)
        max_deg = max(max_deg, int(deg.max()))
    return max_deg

class CatDegOnehot(object):
    """
    Append (or substitute) a one-hot encoding of each node's degree to its features.

    The encoding has ``max_degree + 1`` columns, so every graph gets the same width
    as long as no node exceeds ``max_degree``.

    Args:
        max_degree (int): Maximum degree.
        in_degree (bool, optional): If set to :obj:`True`, will compute the in-
            degree of nodes instead of the out-degree. (default: :obj:`False`)
        cat (bool, optional): Concat node degrees to node features instead
            of replacing them. (default: :obj:`True`)
    """

    def __init__(self, max_degree, in_degree=False, cat=True):
        self.max_degree = max_degree
        self.in_degree = in_degree
        self.cat = cat

    def __call__(self, data):
        """Update ``data.x`` in place with the degree encoding and return ``data``."""
        index = data.edge_index[1] if self.in_degree else data.edge_index[0]
        counts = degree(index, data.num_nodes, dtype=torch.long)
        onehot = F.one_hot(counts, num_classes=self.max_degree + 1).to(torch.float)

        features = data.x
        if features is None or not self.cat:
            data.x = onehot
            return data

        # Promote 1-D features to a single column before concatenating.
        if features.dim() == 1:
            features = features.view(-1, 1)
        data.x = torch.cat([features, onehot.to(features.dtype)], dim=-1)
        return data

def split_dataset(dataset, train_data_percent=1.0):
    """
    Shuffle ``dataset`` in place and split it into train / val / test lists.

    The val and test portions are fixed by ``DATA_SPLIT`` and do not depend on
    ``train_data_percent``: val is the next ``int(n * DATA_SPLIT[1])`` graphs after
    the full training portion, and test is everything after that. Only the first
    ``train_data_percent`` fraction of the training portion is returned as the
    (labelled) training set. The rest of the training portion is left out rather than
    leaking into val/test.

    Args:
        dataset (list): All graphs in the dataset. It is shuffled in place.
        train_data_percent (float): Fraction of training data
            which is labelled. (default 1.0)

    Returns:
        tuple[list, list, list]: ``(train, val, test)``.
    """
    random.shuffle(dataset)

    n = len(dataset)
    train_ratio, val_ratio, _ = DATA_SPLIT

    full_train_end = int(n * train_ratio)
    train_end = int(n * train_ratio * train_data_percent)
    # Anchor val/test to the full training portion so they are identical for every
    # train_data_percent (previously the unused training graphs fell into the test set).
    val_end = full_train_end + int(n * val_ratio)

    train_dataset = list(dataset[:train_end])
    val_dataset = list(dataset[full_train_end:val_end])
    test_dataset = list(dataset[val_end:])
    return train_dataset, val_dataset, test_dataset

# Load the MUTAG graph-classification dataset from TUDataset, stored under the given root.
dataset = TUDataset(root="/tmp/TUDataset/MUTAG", name="MUTAG", use_node_attr=True)

# Append a one-hot encoding of each node's out-degree to its features. The width
# (max_degree + 1) is computed over the whole dataset, so all graphs get the same size.
max_degree = get_max_deg(dataset)
transform = CatDegOnehot(max_degree)
dataset = list(map(transform, dataset))