In [None]:
import numpy as np
import torch
import pandas as pd
import scipy.sparse as sp
import time
import copy
from sklearn.metrics import roc_auc_score, average_precision_score, precision_recall_curve, auc
import torch.nn.functional as F
import pickle
from tqdm import tqdm
import networkx as nx
import community as cm

def get_scores(edges_pos, edges_neg, A_pred, adj_label):
    """
    Scores predicted edge probabilities on sampled positive/negative edges and
    on the thresholded reconstruction of the whole adjacency matrix.
    :param edges_pos: (num_pos, 2) index pairs of existing edges
    :param edges_neg: (num_neg, 2) index pairs of sampled non-edges
    :param A_pred: torch tensor holding the predicted edge probabilities
    :param adj_label: sparse torch tensor with the target adjacency (densified here)
    :return: dict with keys 'roc', 'pr', 'ap', 'pre', 'rec', 'f1', 'acc' and
    'adj_recon' (A_pred binarised at the F1-optimal threshold)
    """
    # Gather the predicted scores of the sampled edges and build their labels
    preds_pos = A_pred[edges_pos[:, 0], edges_pos[:, 1]]
    preds_neg = A_pred[edges_neg[:, 0], edges_neg[:, 1]]

    logits = np.hstack([preds_pos, preds_neg])
    labels = np.hstack([np.ones(len(preds_pos)), np.zeros(len(preds_neg))])

    roc_auc = roc_auc_score(labels, logits)
    ap_score = average_precision_score(labels, logits)
    precisions, recalls, thresholds = precision_recall_curve(labels, logits)
    pr_auc = auc(recalls, precisions)

    # F1 for every operating point; 0/0 points are defined as 0 without
    # triggering a divide-by-zero RuntimeWarning on every call.
    denom = precisions + recalls
    f1s = np.divide(2 * precisions * recalls, denom, out=np.zeros_like(denom), where=denom > 0)
    best_comb = int(np.argmax(f1s))
    f1 = f1s[best_comb]
    pre = precisions[best_comb]
    rec = recalls[best_comb]
    # precisions/recalls carry one more entry than thresholds, so clamp the index
    thresh = thresholds[min(best_comb, len(thresholds) - 1)]

    # Binarise in a single comparison instead of two order-dependent in-place masks
    adj_rec = (A_pred >= float(thresh)).to(A_pred.dtype)

    labels_all = adj_label.to_dense().view(-1).long()
    preds_all = adj_rec.view(-1).long()
    recon_acc = (preds_all == labels_all).sum().float() / labels_all.size(0)
    results = {
        'roc': roc_auc,
        'pr': pr_auc,
        'ap': ap_score,
        'pre': pre,
        'rec': rec,
        'f1': f1,
        'acc': recon_acc,
        'adj_recon': adj_rec
    }
    return results


def sample_graph_det(adj_orig, A_pred, remove_edge_num=100):
    """
    Deterministically removes the existing edges with the lowest predicted score.
    :param adj_orig: adjacency matrix of the original graph (only the strict
    upper triangle is read)
    :param A_pred: matrix of predicted edge scores, indexed as A_pred[rows, cols]
    :param remove_edge_num: number of undirected edges to drop; values larger
    than the number of edges remove every edge instead of raising
    :return: array of shape (2 * kept_edges, 2) listing each kept edge in both
    directions. NOTE: when remove_edge_num == 0 a deep copy of adj_orig is
    returned instead (legacy behaviour kept for backward compatibility).
    """
    if remove_edge_num == 0:
        return copy.deepcopy(adj_orig)
    orig_upper = sp.triu(adj_orig, 1)
    edges = np.asarray(orig_upper.nonzero()).T
    if remove_edge_num:
        # np.argpartition requires kth < len(array); clamp the request so that
        # asking for too many removals (or an edgeless graph) does not raise.
        n_remove = min(remove_edge_num, len(edges))
        if n_remove >= len(edges):
            edges_pred = edges[:0]
        else:
            edge_prob = A_pred[edges[:, 0], edges[:, 1]]
            edge_index_to_remove = np.argpartition(edge_prob, n_remove)[:n_remove]
            keep = np.ones(len(edges), dtype=bool)
            keep[edge_index_to_remove] = False
            edges_pred = edges[keep]
    else:
        edges_pred = edges

    # Recover the edges to [a,b] and [b,a] instead of [a,b]
    edges_pred = np.concatenate([edges_pred, edges_pred[:, ::-1]])

    return edges_pred


def sample_graph_community(adj_orig, A_pred, remove_edge_num=100, tau=0.05):
    """
    Removes low-scoring edges one at a time while discouraging repeated
    removals from the same Louvain community.
    After each removal, the scores of all rows belonging to the two affected
    communities are pulled towards 1 by a penalty proportional to how many
    removals those communities have already received.
    :param adj_orig: adjacency matrix of the original graph (nodes are assumed
    to be indexed 0..n-1, matching the rows of A_pred)
    :param A_pred: matrix of predicted edge scores; it is copied, so the
    caller's matrix is left untouched
    :param remove_edge_num: number of undirected edges to drop
    :param tau: strength of the community penalty
    :return: array of shape (2 * kept_edges, 2) listing each kept edge in both
    directions. NOTE: when remove_edge_num == 0 a deep copy of adj_orig is
    returned instead (legacy behaviour kept for backward compatibility).
    """
    if remove_edge_num == 0:
        return copy.deepcopy(adj_orig)
    orig_upper = sp.triu(adj_orig, 1)
    edges = np.asarray(orig_upper.nonzero()).T
    partition = cm.best_partition(nx.Graph(adj_orig), random_state=42)

    if not remove_edge_num:
        edges_pred = edges
    else:
        n_remove = remove_edge_num

        # Community label of every node, addressed by node id rather than by
        # the iteration order of the partition dict.
        labels = np.array([partition[node] for node in range(len(partition))], dtype=int)
        n_communities = int(labels.max()) + 1 if labels.size else 0
        community_rows = [np.flatnonzero(labels == c) for c in range(n_communities)]

        # Work on a private copy: the penalties below are applied in place and
        # must not leak into the caller's matrix between successive calls.
        scores = np.array(A_pred, copy=True)

        # Number of removals charged to each community so far
        removal_count = np.zeros(n_communities)
        removed = np.zeros(len(edges), dtype=bool)

        for _ in tqdm(range(n_remove)):
            if removed.all():
                break

            # Lowest-scoring edge among those still present (first one on ties)
            edge_prob = np.where(removed, np.inf, scores[edges[:, 0], edges[:, 1]])
            min_edge_idx = int(np.argmin(edge_prob))
            node_u, node_v = edges[min_edge_idx]

            comm_u = labels[node_u]
            comm_v = labels[node_v]

            # Penalties use the counts from before this removal is recorded
            penalty_u = removal_count[comm_u] / (2 * n_remove)
            penalty_v = removal_count[comm_v] / (2 * n_remove)

            # Raise the scores of every row in the two communities (applied
            # twice when both endpoints share a community, as before)
            for comm, penalty in ((comm_u, penalty_u), (comm_v, penalty_v)):
                rows = community_rows[comm]
                scores[rows, :] += (1 - scores[rows, :]) * penalty * tau

            removed[min_edge_idx] = True
            removal_count[comm_u] += 1
            removal_count[comm_v] += 1

        edges_pred = edges[~removed]

    # Recover the edges to [a,b] and [b,a] instead of [a,b]
    edges_pred = np.concatenate([edges_pred, edges_pred[:, ::-1]])

    return edges_pred


In [None]:
def louvain_clustering(adj, s_rec, random_state=None):
    """
    Performs community detection on a graph using the Louvain method
    :param adj: adjacency matrix of the graph
    :param s_rec: s hyperparameter for s-regular sparsification
    :param random_state: optional seed forwarded to the Louvain solver so the
    partition can be reproduced; None keeps the solver's default behaviour
    :return: adj_louvain, the Louvain community membership matrix obtained;
    nb_communities_louvain, the number of communities; partition, the community
    associated with each node from the graph
    """
    graph = nx.Graph(adj)

    # Community detection using the Louvain method
    partition = cm.best_partition(graph, random_state=random_state)

    # Community of each node, listed in graph node order so that row i of the
    # membership matrix always describes node i
    communities_louvain = [partition[node] for node in graph.nodes()]

    # Number of communities found by the Louvain method
    nb_communities_louvain = np.max(communities_louvain) + 1

    # One-hot representation of communities
    communities_louvain_onehot = sp.csr_matrix(np.eye(nb_communities_louvain)[communities_louvain])

    # Community membership matrix (adj_louvain[i,j] = 1 if nodes i and j are in the same community)
    adj_louvain = communities_louvain_onehot.dot(communities_louvain_onehot.transpose())

    # Remove the diagonal
    adj_louvain = adj_louvain - sp.eye(adj_louvain.shape[0])

    # s-regular sparsification of adj_louvain
    adj_louvain = sparsification(adj_louvain, s_rec)

    return adj_louvain, nb_communities_louvain, partition


def sparsification(adj_louvain, s=1):
    """
    Thins out the community membership matrix towards an s-regular graph.
    Nodes are visited in index order; a node with more than s neighbors drops
    randomly chosen neighbors, but only among those that themselves still have
    more than s neighbors. The matrix is modified in place.
    :param adj_louvain: the initial community membership matrix (sparse)
    :param s: target number of neighbors per node
    :return: the same matrix object after sparsification
    """
    num_nodes = adj_louvain.shape[0]

    # Current degree of every node (column sums), kept in sync with the edits below
    degrees = np.asarray(adj_louvain.sum(axis=0)).ravel()

    for node in range(num_nodes):
        # Column indices of the non-zero entries in this node's row
        neighbors = sp.find(adj_louvain[node, :])[1]
        surplus = len(neighbors) - s
        if surplus <= 0:
            continue

        # Only neighbors that can afford to lose an edge are candidates
        candidates = np.intersect1d(neighbors, np.flatnonzero(degrees > s))
        dropped = np.random.choice(candidates, min(len(candidates), surplus), replace=False)

        # Remove the chosen edges symmetrically
        adj_louvain[node, dropped] = 0.0
        adj_louvain[dropped, node] = 0.0
        degrees[node] = s
        degrees[dropped] -= 1

    adj_louvain.eliminate_zeros()

    return adj_louvain

In [None]:
# Define the dataset
class GraphData():
    """
    Loads a graph object from disk and derives everything the auto-encoder
    needs: dense and normalized adjacency matrices, the optional
    Louvain-augmented adjacency, the reconstruction target and the
    positive/negative validation edges.
    """

    def __init__(self, data_path, device="cpu", use_louvain=True, louvain_neighbors=10, louvain_lambda=0.5):
        """
            Input:
                data_path: path of the file read with torch.load; the loaded object must expose
                           x, y, edge_index, val_mask and train_mask
                device: device on which the sampled false edges are placed
                use_louvain: whether to build the Louvain-augmented adjacency
                louvain_neighbors: s value passed to the s-regular sparsification
                louvain_lambda: weight of the community matrix when added to the adjacency
        """
        # NOTE(review): torch.load unpickles the file, so only load trusted data.
        self.data = torch.load(data_path)
        self.device = device
        self.use_louvain = use_louvain
        self.louvain_neighbors = louvain_neighbors
        self.louvain_lambda = louvain_lambda
        self.init_data()

    def init_data(self):
        """
            Builds every derived attribute from self.data.
        """
        self.x, self.y, self.edge_index, self.val_mask, self.train_mask = self.data.x, self.data.y, self.data.edge_index, self.data.val_mask, self.data.train_mask

        # Filter out the validation and training data: keep the edges whose two
        # endpoints both belong to the split. Vectorised, and an empty split
        # yields a (0, 2) tensor instead of failing in torch.stack([]).
        src, dst = self.edge_index[0], self.edge_index[1]
        train_mask = self.train_mask.bool()
        val_mask = self.val_mask.bool()
        self.x_train = self.x[self.train_mask]
        self.edge_index_train = self.edge_index[:, train_mask[src] & train_mask[dst]].t().contiguous()
        self.edge_index_val = self.edge_index[:, val_mask[src] & val_mask[dst]].t().contiguous()

        # Adjacency matrix contains all the edges in the graph
        # NOTE: train_edges here can be all the edges or the training edges only
        self.train_edges = self.edge_index
        # ! self.x ought to be modified if train_edges is not all edges
        self.adj_train = self.build_adj(self.train_edges, self.x.shape[0]).numpy()

        if self.use_louvain:
            self.louvain_adj, _, _ = louvain_clustering(self.adj_train, self.louvain_neighbors)
            self.adj_train_louvain = self.adj_train + self.louvain_adj * self.louvain_lambda
            self.adj_train_louvain_norm = self.normalize_adj(self.adj_train_louvain)
        else:
            # Without Louvain there is nothing to normalize; expose None rather
            # than failing on an attribute that was never created.
            self.louvain_adj = None
            self.adj_train_louvain = None
            self.adj_train_louvain_norm = None

        self.adj_train_norm = self.normalize_adj(self.adj_train)

        # ! self.x ought to be modified if train_edges is not all edges
        self.adj_label = self.build_adj_label(self.train_edges, self.x.shape[0])

        self.val_edges = self.edge_index_val
        self.val_edges_false = self.generate_false_edges(self.val_edges.shape[0])

    def generate_false_edges(self, num_edges):
        """
            Generates num_edges false edges for the graph

            Pairs are drawn uniformly with NumPy's global RNG and kept when they are
            distinct nodes that are not connected in self.adj_train.
            Raises ValueError when the graph does not contain num_edges distinct
            non-edges (the sampling loop could never finish otherwise).
        """
        num_nodes = self.x.shape[0]
        # The dense adjacency has the same off-diagonal pattern as the normalized
        # one and supports O(1) lookups, unlike indexing a sparse tensor.
        adj = self.adj_train
        off_diagonal_edges = np.count_nonzero(adj) - np.count_nonzero(np.diagonal(adj))
        available = num_nodes * (num_nodes - 1) - off_diagonal_edges
        if num_edges > available:
            raise ValueError(f"cannot sample {num_edges} false edges: only {available} non-edges exist")

        false_edges = set()

        while len(false_edges) < num_edges:
            src_node = np.random.randint(0, num_nodes)
            dst_node = np.random.randint(0, num_nodes)
            if src_node != dst_node and not adj[src_node, dst_node]:
                false_edges.add((src_node, dst_node))

        false_edges = torch.tensor(list(false_edges)).to(self.device)
        return false_edges

    def sparse_to_tuple(self, sparse_mx):
        """
            Input:
                sparse_mx: a scipy sparse matrix
            Returns:
                (coords, values, shape) where coords has one (row, col) pair per stored entry
        """
        if not sp.isspmatrix_coo(sparse_mx):
            sparse_mx = sparse_mx.tocoo()
        coords = np.vstack((sparse_mx.row, sparse_mx.col)).transpose()
        values = sparse_mx.data
        shape = sparse_mx.shape
        return coords, values, shape

    def build_adj(self, edge_idx, num_verts, half=False):
        """
            Input:
                edge_idx: [torch.Tensor] a tensor of shape (2, num_edges) representing the edge indices of the graph
                num_verts: the number of nodes in the graph
                half: when True only the (src, dst) direction is written; otherwise both directions
            Returns:
                adjacency_matrix: an adjacency matrix built from edge_idx 
        """
        adjacency_matrix = torch.zeros((num_verts, num_verts))

        # Write all edges at once instead of looping over them in Python
        src_nodes = edge_idx[0].long()
        dst_nodes = edge_idx[1].long()
        adjacency_matrix[src_nodes, dst_nodes] = 1
        if not half:
            adjacency_matrix[dst_nodes, src_nodes] = 1

        return adjacency_matrix

    def build_adj_label(self, edge_idx, num_verts, half=False):
        """
            Input:
                edge_idx: [torch.Tensor] a tensor of shape (2, num_edges) representing the edge indices of the graph
                num_verts: the number of nodes in the graph
                half: forwarded to build_adj
            Returns:
                adj_label: a sparse tensor holding the adjacency matrix plus the identity (self-loops)
        """
        adj = self.build_adj(edge_idx, num_verts, half)
        adj_label = adj + torch.eye(num_verts)
        adj_label = adj_label.to_sparse()

        return adj_label

    def normalize_adj(self, adj_train):
        """
            Input:
                adj_train: an adjacency matrix accepted by scipy.sparse.coo_matrix
            Returns:
                adj_norm: a sparse float32 COO tensor with D^-1/2 (A with unit diagonal) D^-1/2
        """
        adj_ = sp.coo_matrix(adj_train)
        adj_.setdiag(1)
        rowsum = np.array(adj_.sum(1))
        degree_mat_inv_sqrt = sp.diags(np.power(rowsum, -0.5).flatten())
        adj_norm = adj_.dot(degree_mat_inv_sqrt).transpose().dot(degree_mat_inv_sqrt).tocoo()
        coords, values, shape = self.sparse_to_tuple(adj_norm)
        # torch.sparse.FloatTensor is deprecated; sparse_coo_tensor is its replacement
        adj_norm = torch.sparse_coo_tensor(torch.as_tensor(coords.T, dtype=torch.long),
                                           torch.as_tensor(values, dtype=torch.float32),
                                           torch.Size(shape))

        return adj_norm

In [None]:
import torch
from torch.nn import Linear
import torch.nn.functional as F
import torch.nn as nn
import numpy as np

# Define the VGAE model
class VGAE(nn.Module):
    """
        (Variational) graph auto-encoder with a two-layer GCN encoder and an
        inner-product decoder.
        Adapted from https://github.com/zhao-tong/GAug/blob/master/vgae/models.py
    """

    def __init__(self, adj, dim_in, dim_h, dim_z, use_gae=False, adj_louvain=None):
        super().__init__()

        self.dim_z = dim_z
        self.gae = use_gae

        # The first layer propagates over the Louvain-augmented adjacency when
        # one is supplied; the two output heads always use the plain adjacency.
        first_layer_adj = adj if adj_louvain is None else adj_louvain
        self.base_gcn = GraphConvSparse(dim_in, dim_h, first_layer_adj)
        self.gcn_mean = GraphConvSparse(dim_h, dim_z, adj, activation=False)
        self.gcn_logstd = GraphConvSparse(dim_h, dim_z, adj, activation=False)

    def encode(self, X):
        """Returns the latent embedding; stores mean (and logstd when variational) on self."""
        # TODO: modify X to fuse basic graph features
        hidden = self.base_gcn(X)
        self.mean = self.gcn_mean(hidden)

        # Plain graph auto-encoder: the mean is the embedding
        if self.gae:
            return self.mean

        # Variational graph auto-encoder: reparameterised sample around the mean
        self.logstd = self.gcn_logstd(hidden)
        noise = torch.randn_like(self.mean)
        return noise * torch.exp(self.logstd) + self.mean

    def decode(self, Z):
        """Inner-product decoder producing edge logits."""
        return Z @ Z.T

    def forward(self, X, F=None):
        """Encodes X (optionally concatenated column-wise with extra features F) and decodes edge logits."""
        if F is not None:
            # TODO design different fusing strategies
            extra = torch.from_numpy(F).float().to(X.device) if isinstance(F, np.ndarray) else F
            X = torch.cat([X, extra], dim=1)
        return self.decode(self.encode(X))


class GraphConvSparse(nn.Module):
    """
        Graph convolution layer computing act(adj @ (inputs @ weight)), where
        act is ELU when activation is True and the identity otherwise.
    """

    def __init__(self, input_dim, output_dim, adj, activation=True):
        super(GraphConvSparse, self).__init__()
        self.weight = self.glorot_init(input_dim, output_dim)
        if isinstance(adj, torch.Tensor):
            # Register the adjacency as a buffer so Module.to()/.cuda()/.double()
            # move it together with the weights. It is non-persistent, so the
            # state_dict keeps exactly the same keys as before.
            self.register_buffer("adj", adj, persistent=False)
        else:
            self.adj = adj
        self.activation = activation

    def glorot_init(self, input_dim, output_dim):
        """Returns a Parameter drawn uniformly from the Glorot range +-sqrt(6 / (in + out))."""
        init_range = np.sqrt(6.0 / (input_dim + output_dim))
        initial = torch.rand(input_dim, output_dim) * 2 * init_range - init_range
        return nn.Parameter(initial)

    def forward(self, inputs):
        """Projects the inputs, propagates them over the adjacency and applies the activation."""
        x = inputs @ self.weight
        x = self.adj @ x
        if self.activation:
            return F.elu(x)
        else:
            return x


In [None]:

def train_model(cfg, graph_data, model, structural_features=None):
    """
    Trains the (V)GAE and restores the weights of the best-scoring epoch.
    :param cfg: dict read for the keys "lr", "device", "epoch", "use_gae" and
    "criterion" (a key of the dict returned by get_scores)
    :param graph_data: object providing adj_train, x, adj_label, val_edges and
    val_edges_false
    :param model: the model to optimise; called as model(X=..., F=...)
    :param structural_features: optional extra features forwarded to the model
    :return: the model, loaded with the best weights when at least one epoch
    improved the criterion; the weights are also written to 'aug_model.pt'
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg["lr"])
    adj_t = graph_data.adj_train
    # Re-weight the loss to compensate for the scarcity of positive entries
    norm_w = adj_t.shape[0]**2 / float((adj_t.shape[0]**2 - adj_t.sum()) * 2)
    pos_weight = torch.FloatTensor([float(adj_t.shape[0]**2 - adj_t.sum()) / adj_t.sum()]).to(cfg["device"])

    # move input data and label to gpu if needed
    features = graph_data.x.to(cfg["device"])
    adj_label = graph_data.adj_label.to_dense().to(cfg["device"])

    best_vali_criterion = 0.0
    best_state_dict = None
    # Stays None when no epoch beats the initial criterion (or cfg["epoch"] is 0)
    r_test = None
    model.train()

    train_bar = tqdm(range(cfg["epoch"]))
    for epoch in train_bar:
        A_pred = model(X=features, F=structural_features)
        optimizer.zero_grad()
        loss = norm_w * F.binary_cross_entropy_with_logits(A_pred, adj_label, pos_weight=pos_weight)
        if not cfg["use_gae"]:
            kl_divergence = 0.5 / A_pred.size(0) * (1 + 2 * model.logstd - model.mean**2 -
                                                    torch.exp(2 * model.logstd)).sum(1).mean()
            loss -= kl_divergence

        # Score the current weights (before this epoch's update) on the validation edges
        A_pred = torch.sigmoid(A_pred).detach().cpu()
        r = get_scores(graph_data.val_edges, graph_data.val_edges_false, A_pred, graph_data.adj_label)

        if r[cfg["criterion"]] > best_vali_criterion:
            best_vali_criterion = r[cfg["criterion"]]
            best_state_dict = copy.deepcopy(model.state_dict())
            r_test = r

        loss.backward()
        optimizer.step()
        train_bar.set_description(
            f"E: {epoch+1} | L: {loss.item():.4f} | A: {r['acc']:.4f} | ROC: {r['roc']:.4f} | AP: {r['ap']:.4f} | F1: {r['f1']:.4f}"
        )

    if r_test is None:
        # Nothing to restore: avoid the NameError / load_state_dict(None) crash
        print("Training completed, but no epoch improved the validation criterion; keeping the final weights.")
    else:
        print("Training completed. Final results: test_roc: {:.4f} test_ap: {:.4f} test_f1: {:.4f} test_recon_acc: {:.4f}".
              format(r_test['roc'], r_test['ap'], r_test['f1'], r_test['acc']))
        model.load_state_dict(best_state_dict)

    # Dump the best model
    torch.save(model.state_dict(), 'aug_model.pt')
    return model


def gen_graphs(cfg, graph_data, model, structural_features=None):
    """
    Pickles the original adjacency and cfg["gen_graphs"] predicted probability
    matrices into the 'graphs' directory.
    :param cfg: dict read for the keys "use_gae", "device" and "gen_graphs"
    :param graph_data: object providing adj_train and x
    :param model: callable returning edge logits for (features, structural_features)
    :param structural_features: optional extra features forwarded to the model
    :return: None; files graph_0[_gae].pkl and graph_<i>_logits[_gae].pkl are written
    """
    import os  # local import: `os` is not among the notebook's top-level imports

    adj_orig = graph_data.adj_train
    suffix = "_gae" if cfg["use_gae"] else ""

    # Create the output directory up front instead of assuming it exists
    os.makedirs("graphs", exist_ok=True)

    # Context managers make sure every file handle is flushed and closed
    with open(f'graphs/graph_0{suffix}.pkl', 'wb') as fh:
        pickle.dump(adj_orig, fh)

    features = graph_data.x.to(cfg["device"])
    for i in range(cfg["gen_graphs"]):
        with torch.no_grad():
            A_pred = model(features, structural_features)

        # Despite the "_logits" file name, the stored values are sigmoid
        # probabilities with the diagonal zeroed out
        A_pred = torch.sigmoid(A_pred).detach().cpu()
        adj_recon = A_pred.numpy()
        np.fill_diagonal(adj_recon, 0)

        filename = f'graphs/graph_{i+1}_logits{suffix}.pkl'
        with open(filename, 'wb') as fh:
            pickle.dump(adj_recon, fh)


def main(data_path, cfg):
    """
    Builds the dataset and model, then trains (or loads) the model and
    optionally generates augmented graphs.
    :param data_path: path of the graph file handed to GraphData
    :param cfg: dict read for the keys "seed", "device", "use_louvain",
    "dim_h", "dim_z", "use_gae", "pretrained" and "gen_graphs" (plus the keys
    used by train_model and gen_graphs)
    """
    torch.manual_seed(cfg["seed"])
    # False-edge sampling and the Louvain sparsification draw from NumPy's
    # global RNG, so seed it as well (np.random.seed only accepts 32-bit values).
    np.random.seed(cfg["seed"] % (2**32))

    graph_data = GraphData(data_path, cfg["device"]) 

    model = VGAE(
        adj=graph_data.adj_train_norm,
        adj_louvain=graph_data.adj_train_louvain_norm if cfg["use_louvain"] else None,
        dim_in=graph_data.x.shape[1],
        dim_h=cfg["dim_h"],
        dim_z=cfg["dim_z"],
        use_gae=cfg["use_gae"],
    ).to(cfg["device"])

    if cfg["pretrained"]:
        # map_location lets a checkpoint saved on another device load here
        model.load_state_dict(torch.load(cfg["pretrained"], map_location=cfg["device"]))
    else:
        model = train_model(cfg, graph_data, model)

    if cfg["gen_graphs"] > 0:
        # Generate graphs
        gen_graphs(cfg, graph_data, model)

In [None]:
import os
def from_graph():
    """
    Builds 'submission.csv' from a previously generated probability matrix:
    one row per removal budget, each holding the flattened edge list left
    after community-aware edge removal (padded with -1).
    """
    # Portable path; the former "data\data.pt" literal relied on an invalid
    # escape sequence and only resolved on Windows.
    data_loader = GraphData(os.path.join("data", "data.pt"))

    graph_aug = "graph_1_logits.pkl"  # NOTE: You can change this to a different graph

    # NOTE(review): pickle.load executes arbitrary code; only open trusted files.
    with open(os.path.join("graphs", graph_aug), "rb") as f:
        adj = pickle.load(f)

    # No AUG - baseline (no delete edges)
    # [0.792, 0.796, 0.808, 0.796, 0.794, 0.800]

    remove_edges = [100, 200, 300, 400, 500, 600]
    edge_list = []
    for enum in remove_edges:
        # Hand each run its own copy so score adjustments made while sampling
        # for one budget cannot carry over into the next one.
        edges = sample_graph_community(data_loader.adj_train, copy.deepcopy(adj), enum, tau=0.005)
        edge_list.append(edges.T.reshape(-1).tolist())

    # NOTE: Don't change this, used for generating the submission csv
    df = pd.DataFrame(edge_list).fillna(-1).astype(int)
    # fill those empty units with -1 (don't change it)
    df.insert(0, 'ID', list(range(len(edge_list))))
    df.to_csv('submission.csv', index=False)
