# Implementation of Deep Graph InfoMax

Based upon:
- **Paper:** Deep Graph Infomax (Veličković et al., ICLR 2019) https://arxiv.org/pdf/1809.10341.pdf

- **Implementation:** https://github.com/PetarV-/DGI


In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import os.path as osp
import numpy as np
import scipy.sparse as sp
import pickle as pkl
from scipy.sparse.linalg.eigen.arpack import eigsh
import sys
import networkx as nx

## Data Loading and Preprocessing Utilities

In [2]:
def parse_index_file(filename):
    """Parse a text file that contains one integer index per line.

    Blank lines are skipped. The file is closed before the function returns.

    Args:
        filename: Path of the index file.

    Returns:
        list[int]: The indices, in file order.
    """
    with open(filename) as f:
        # int() tolerates surrounding whitespace/newlines; skip empty lines.
        return [int(line) for line in f if line.strip()]

def sample_mask(idx, l):
    """Create a boolean mask of length ``l`` that is True at positions ``idx``.

    Args:
        idx: Index, slice or sequence of indices to set.
        l: Length of the mask.

    Returns:
        np.ndarray: Boolean array of shape ``(l,)``.
    """
    # The builtin ``bool`` replaces ``np.bool``, which was removed in NumPy 1.24.
    mask = np.zeros(l, dtype=bool)
    mask[idx] = True
    return mask

def load_data(path, dataset_str="cora"):
    """Load a Planetoid-format citation dataset (e.g. Cora) from pickled files.

    Reads ``ind.<dataset_str>.{x,y,tx,ty,allx,ally,graph}`` and
    ``ind.<dataset_str>.test.index`` from the directory ``path``.

    Args:
        path: Directory containing the dataset files.
        dataset_str: File-name prefix of the dataset. Defaults to ``"cora"``,
            which matches the previous hard-coded behaviour.

    Returns:
        adj: Sparse adjacency matrix built by ``nx.adjacency_matrix``.
        features: LIL feature matrix (``allx`` stacked on ``tx``), with test
            rows moved to the node positions given by the index file.
        labels: Dense label matrix, reordered the same way as ``features``.
        idx_train: ``range`` over the first ``len(y)`` nodes.
        idx_val: ``range`` over the following 500 nodes.
        idx_test: Sorted list of test node indices.
    """
    names = ['x', 'y', 'tx', 'ty', 'allx', 'ally', 'graph']
    objects = []
    for name in names:
        file_path = osp.join(path, "ind.{}.{}".format(dataset_str, name))
        # NOTE(security): pickle.load can execute arbitrary code; only load
        # dataset files obtained from a trusted source.
        with open(file_path, 'rb') as f:
            # latin1 decoding, as in the original loader, so Python-2 pickles load.
            objects.append(pkl.load(f, encoding='latin1'))

    x, y, tx, ty, allx, ally, graph = tuple(objects)
    test_idx_reorder = parse_index_file(
        osp.join(path, "ind.{}.test.index".format(dataset_str)))
    test_idx_range = np.sort(test_idx_reorder)

    if dataset_str == 'citeseer':
        # Some Citeseer test indices have no row in tx/ty (isolated nodes).
        # Pad with zero rows so the test block covers the full index range.
        test_idx_range_full = range(min(test_idx_reorder), max(test_idx_reorder) + 1)
        offset = min(test_idx_range)
        tx_extended = sp.lil_matrix((len(test_idx_range_full), x.shape[1]))
        tx_extended[test_idx_range - offset, :] = tx
        tx = tx_extended
        ty_extended = np.zeros((len(test_idx_range_full), y.shape[1]))
        ty_extended[test_idx_range - offset, :] = ty
        ty = ty_extended

    features = sp.vstack((allx, tx)).tolil()
    # Test rows sit at the sorted test positions in index-file order; this
    # permutation moves each one to the node index listed in the index file.
    features[test_idx_reorder, :] = features[test_idx_range, :]
    adj = nx.adjacency_matrix(nx.from_dict_of_lists(graph))

    labels = np.vstack((ally, ty))
    labels[test_idx_reorder, :] = labels[test_idx_range, :]

    idx_test = test_idx_range.tolist()
    idx_train = range(len(y))
    idx_val = range(len(y), len(y) + 500)

    return adj, features, labels, idx_train, idx_val, idx_test

def sparse_to_tuple(sparse_mx, insert_batch=False):
    """Convert a scipy sparse matrix, or a list of them, to tuple form.

    Each matrix becomes ``(coords, values, shape)``:
      * ``coords``: (nnz, 2) array of (row, col) positions;
      * ``values``: the stored entries;
      * ``shape``: the matrix shape.

    Set ``insert_batch=True`` to insert a leading batch dimension. Coords then
    gain a first column of zeros, and shape becomes ``(1,) + mx.shape``.

    A list input yields a new list of tuples; the caller's list is left
    unmodified.
    """
    def to_tuple(mx):
        if not sp.isspmatrix_coo(mx):
            mx = mx.tocoo()
        if insert_batch:
            coords = np.vstack((np.zeros(mx.row.shape[0]), mx.row, mx.col)).transpose()
            shape = (1,) + mx.shape
        else:
            coords = np.vstack((mx.row, mx.col)).transpose()
            shape = mx.shape
        return coords, mx.data, shape

    if isinstance(sparse_mx, list):
        return [to_tuple(mx) for mx in sparse_mx]
    return to_tuple(sparse_mx)

def standardize_data(f, train_mask):
    """Standardize features using statistics computed on the training rows only.

    Columns that are constant over the training rows (std == 0) are dropped
    first, so the final division never divides by zero. Each remaining column
    is centred by its training mean and scaled by its training standard
    deviation. No tuple conversion is performed.

    Args:
        f: scipy sparse matrix or dense array, rows = nodes, columns = features.
        train_mask: Boolean (or 0/1) mask over rows selecting training nodes.

    Returns:
        The standardized dense matrix, restricted to the kept columns.
    """
    if sp.issparse(f):
        f = f.todense()
    train_rows = np.asarray(train_mask) == True  # noqa: E712 - also accepts 0/1 masks
    # Drop columns with zero spread on the training rows; ravel (not squeeze)
    # keeps the mask 1-D even when only a single column is present.
    keep = np.asarray(f[train_rows, :].std(axis=0) > 0).ravel()
    f = f[:, keep]
    mu = f[train_rows, :].mean(axis=0)
    sigma = f[train_rows, :].std(axis=0)
    return (f - mu) / sigma

def preprocess_features(features):
    """Row-normalize a sparse feature matrix.

    Every row is divided by its sum; rows that sum to zero stay all-zero.

    Args:
        features: scipy sparse matrix, rows = nodes.

    Returns:
        tuple: ``(dense, tuple_repr)``. ``dense`` is the normalized matrix as
        returned by ``.todense()``. ``tuple_repr`` is the same matrix in
        ``(coords, values, shape)`` form, produced by ``sparse_to_tuple``.
    """
    rowsum = np.array(features.sum(1))
    if not np.issubdtype(rowsum.dtype, np.floating):
        # np.power raises ValueError for integer arrays with negative integer
        # exponents, so integer-valued inputs are normalized in float64.
        rowsum = rowsum.astype(np.float64)
    with np.errstate(divide='ignore'):
        # Empty rows give 1/0 = inf, which is replaced by 0 just below.
        r_inv = np.power(rowsum, -1).flatten()
    r_inv[np.isinf(r_inv)] = 0.
    r_mat_inv = sp.diags(r_inv)
    features = r_mat_inv.dot(features)
    return features.todense(), sparse_to_tuple(features)

def normalize_adj(adj):
    """Return ``D^{-1/2} A^T D^{-1/2}`` as a COO matrix.

    ``D`` is the diagonal matrix of the row sums of ``adj``. For a symmetric
    adjacency this is the usual symmetric normalization
    ``D^{-1/2} A D^{-1/2}``. Rows that sum to zero receive a scale factor of 0.
    """
    a = sp.coo_matrix(adj)
    degree = np.array(a.sum(1)).flatten()
    scale = np.power(degree, -0.5)
    scale[np.isinf(scale)] = 0.  # isolated nodes: 0^-0.5 = inf -> 0
    d_half = sp.diags(scale)
    scaled_cols = a.dot(d_half)
    return scaled_cols.transpose().dot(d_half).tocoo()


def preprocess_adj(adj):
    """Add self-loops, symmetrically normalize, and return the tuple form.

    Returns the ``(coords, values, shape)`` representation produced by
    ``sparse_to_tuple`` for ``normalize_adj(adj + I)``.
    """
    with_self_loops = adj + sp.eye(adj.shape[0])
    normalized = normalize_adj(with_self_loops)
    return sparse_to_tuple(normalized)

def sparse_mx_to_torch_sparse_tensor(sparse_mx):
    """Convert a scipy sparse matrix to a torch sparse tensor."""
    sparse_mx = sparse_mx.tocoo().astype(np.float32)
    indices = torch.from_numpy(
        np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64))
    values = torch.from_numpy(sparse_mx.data)
    shape = torch.Size(sparse_mx.shape)
    return torch.sparse.FloatTensor(indices, values, shape)


## Deep Graph InfoMax Layers

### GCN Model

We define an encoder $\mathcal{E}:\mathbb{R}^{N\times F}\times\mathbb{R}^{N\times N}\rightarrow \mathbb{R}^{N\times F'}$ such that $\mathcal{E}(X,A)=H=\{\vec{h}_1, \vec{h}_2, \ldots, \vec{h}_N\}$, where each $\vec{h}_i\in \mathbb{R}^{F'}$ is a high-level representation (a "patch" representation) of node $i$.


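Here, we use the **GCN layer** ([Kipf & Welling (2017)](https://arxiv.org/abs/1609.02907)) defined as $\mathcal{E}(X,A)=\sigma(\hat{D}^{-1/2}\hat{A}\hat{D}^{-1/2}X\Theta)$, where $\hat{A}=A+I_N$ is the adjacency matrix with added self-loops, $\hat{D}$ is the diagonal degree matrix of $\hat{A}$, $\Theta$ is a learnable weight matrix, and $\sigma$ is a PReLU nonlinearity.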

In [3]:
class GCN(nn.Module):
    """Single graph-convolution layer computing ``PReLU(adj @ (seq W) + b)``.

    ``seq`` is laid out as (batch, nodes, features). With ``sparse=True``,
    ``adj`` is a 2-D torch sparse matrix used with ``torch.spmm``. The batch
    axis is squeezed away first, which effectively assumes a batch size of 1.
    Otherwise ``adj`` is a dense (batch, nodes, nodes) tensor used with
    ``torch.bmm``.
    """

    def __init__(self, in_ft, out_ft, bias=True):
        super().__init__()
        # W: feature projection; the bias is added after aggregation instead.
        self.fc = nn.Linear(in_ft, out_ft, bias=False)
        self.act = nn.PReLU()
        if not bias:
            self.register_parameter('bias', None)
        else:
            self.bias = nn.Parameter(torch.zeros(out_ft, dtype=torch.float32))
        for module in self.modules():
            self.weights_init(module)

    def weights_init(self, m):
        """Xavier-uniform initialise Linear weights and zero any Linear bias."""
        if not isinstance(m, nn.Linear):
            return
        torch.nn.init.xavier_uniform_(m.weight.data)
        if m.bias is not None:
            m.bias.data.fill_(0.0)

    def forward(self, seq, adj, sparse=False):
        projected = self.fc(seq)
        if not sparse:
            aggregated = torch.bmm(adj, projected)
        else:
            aggregated = torch.spmm(adj, projected.squeeze(0)).unsqueeze(0)
        if self.bias is not None:
            aggregated += self.bias
        return self.act(aggregated)

### Readout Function
We leverage a readout function $\mathcal{R}:\mathbb{R}^{N\times F'}\rightarrow \mathbb{R}^{F'}$ and use it to summarize the obtained patch representations into a graph-level representation, i.e. $\vec{s}=\mathcal{R}(\mathcal{E}(X,A))$.

We use a simple average of all node representations followed by the logistic sigmoid $\sigma$: $\mathcal{R}(H)=\sigma\left(\frac{1}{N}\sum_{i=1}^N\vec{h}_i\right)$.

In [4]:
class AvgReadout(nn.Module):
    """Average node representations into one summary vector per graph.

    ``seq`` is expected as (batch, nodes, features) (the average runs over
    dim 1). If ``msk`` of shape (batch, nodes) is given, only nodes with
    non-zero mask weight contribute. Each graph is normalized by its own
    mask total.
    """

    def __init__(self):
        super(AvgReadout, self).__init__()

    def forward(self, seq, msk):
        if msk is None:
            return torch.mean(seq, 1)
        msk = torch.unsqueeze(msk, -1)
        # Divide by the per-graph mask sum, shape (batch, 1). Summing the mask
        # over the whole batch would be correct only for a batch size of 1.
        return torch.sum(seq * msk, 1) / torch.sum(msk, 1)

### Discriminator Function

As a proxy for maximizing the local mutual information, we employ a discriminator $\mathcal{D}:\mathbb{R}^{F'}\times \mathbb{R}^{F'}\rightarrow \mathbb{R}$ such that $\mathcal{D}(\vec{h}_i,\vec{s})$ is the probability score assigned to this patch–summary pair; it should be high when the patch comes from the graph that produced the summary.

Negative samples for $\mathcal{D}$ are provided by pairing the summary $\vec{s}$ from $(X,A)$ with patch representations $\vec{\tilde{h}}_j$ of an alternative graph $(\tilde{X},\tilde{A})$:

- In a multi-graph setting, such graphs may be obtained as other elements of a training set
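- For a single graph, an explicit corruption function is required to obtain a negative example from the original graph. Here, the corruption function randomly shuffles the rows of $X$ while keeping the adjacency matrix $A$ unchanged.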

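Here, we score summary–patch pairs with a simple bilinear scoring function, where $W$ is a learnable scoring matrix: $\mathcal{D}(\vec{h}_i,\vec{s})=\sigma(\vec{h}_i^{T}W\vec{s})$. In the code below, the discriminator returns the raw bilinear scores $\vec{h}_i^{T}W\vec{s}$ (plus a bias term), and the sigmoid is applied inside `nn.BCEWithLogitsLoss` during training.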

In [5]:
class Discriminator(nn.Module):
    """Bilinear scorer for (patch, summary) pairs; returns raw logits.

    ``forward`` scores every row of ``h_pl`` (positive patches) and of
    ``h_mi`` (negative patches) against the same summary ``c``. It returns
    the two score sets concatenated along dim 1: positives first, then
    negatives.
    """

    def __init__(self, n_h):
        super().__init__()
        self.f_k = nn.Bilinear(n_h, n_h, 1)
        for module in self.modules():
            self.weights_init(module)

    def weights_init(self, m):
        """Xavier-uniform initialise Bilinear weights and zero their bias."""
        if not isinstance(m, nn.Bilinear):
            return
        torch.nn.init.xavier_uniform_(m.weight.data)
        if m.bias is not None:
            m.bias.data.fill_(0.0)

    def forward(self, c, h_pl, h_mi, s_bias1=None, s_bias2=None):
        # Repeat the summary along the node axis so it pairs with every patch.
        summary = c.unsqueeze(1).expand_as(h_pl)
        scores = []
        for patches, extra in ((h_pl, s_bias1), (h_mi, s_bias2)):
            sc = self.f_k(patches, summary).squeeze(2)
            if extra is not None:
                sc += extra
            scores.append(sc)
        return torch.cat(scores, 1)

## Deep Graph InfoMax Models

### DGI

In [6]:
class DGI(nn.Module):
    """Deep Graph Infomax: GCN encoder, average readout and bilinear discriminator."""

    def __init__(self, n_in, n_h):
        super().__init__()
        self.gcn = GCN(n_in, n_h)
        self.read = AvgReadout()
        self.sigm = nn.Sigmoid()
        self.disc = Discriminator(n_h)

    def forward(self, seq1, seq2, adj, sparse, msk, samp_bias1, samp_bias2):
        """Return discriminator logits: real nodes (``seq1``) first, then corrupted (``seq2``)."""
        h_real = self.gcn(seq1, adj, sparse)
        summary = self.sigm(self.read(h_real, msk))
        h_fake = self.gcn(seq2, adj, sparse)
        return self.disc(summary, h_real, h_fake, samp_bias1, samp_bias2)

    def embed(self, seq, adj, sparse, msk):
        """Return detached node embeddings and the detached readout.

        Unlike ``forward``, the readout returned here is not passed through
        the sigmoid.
        """
        h = self.gcn(seq, adj, sparse)
        return h.detach(), self.read(h, msk).detach()

### Logistic Regression

In [7]:
class LogReg(nn.Module):
    """Single linear layer mapping embeddings to class logits."""

    def __init__(self, ft_in, nb_classes):
        super().__init__()
        self.fc = nn.Linear(ft_in, nb_classes)
        for module in self.modules():
            self.weights_init(module)

    def weights_init(self, m):
        """Xavier-uniform initialise Linear weights and zero their bias."""
        if not isinstance(m, nn.Linear):
            return
        torch.nn.init.xavier_uniform_(m.weight.data)
        if m.bias is not None:
            m.bias.data.fill_(0.0)

    def forward(self, seq):
        return self.fc(seq)


## Load Dataset

In [8]:
# Load data: mount Google Drive (Colab only) so the dataset files are reachable.
from google.colab import drive

mount_point = '/content/drive'
drive.mount(mount_point)
# Directory holding the ind.cora.* files, relative to the Colab working directory.
path = "./drive/MyDrive/Colab Notebooks/Representation Learning for GNNs/data/cora/"

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [9]:
adj, features, labels, idx_train, idx_val, idx_test = load_data(path)
# Row-normalise the features; only the dense form is kept, the tuple form is discarded.
features = preprocess_features(features)[0]

## Training

In [10]:
# Hyper-parameters for DGI pre-training.
args = dict(
    device='cuda' if torch.cuda.is_available() else 'cpu',
    epochs=500,
    patience=20,      # early stopping: epochs allowed without a new best loss
    lr=0.001,
    weight_decay=0.0,
    dropout=0.0,      # NOTE(review): not read by any code shown here
    hidden=512,       # embedding size F'
    batch_size=1,
    sparse=True,      # True -> torch sparse adjacency + torch.spmm
)

In [11]:
# Problem dimensions taken from the loaded data.
nb_nodes, ft_size = features.shape[0], features.shape[1]
nb_classes = labels.shape[1]

In [12]:
# Symmetrically normalize A + I (GCN renormalization with self-loops).
adj = normalize_adj(adj + sp.eye(adj.shape[0]))

if args["sparse"]:
    sp_adj = sparse_mx_to_torch_sparse_tensor(adj)
else:
    # ``adj`` already contains the self-loops and is normalized. Adding the
    # identity again would double the self-loops and undo the normalization.
    # toarray() yields a plain ndarray, so ``[np.newaxis]`` below gives a
    # clean (1, N, N) array.
    adj = adj.toarray()

# Add a leading batch axis of size 1 expected by the model.
features = torch.FloatTensor(features[np.newaxis])
if not args["sparse"]:
    adj = torch.FloatTensor(adj[np.newaxis])
labels = torch.FloatTensor(labels[np.newaxis])
idx_train = torch.LongTensor(idx_train)
idx_val = torch.LongTensor(idx_val)
idx_test = torch.LongTensor(idx_test)

In [13]:
if torch.cuda.is_available():
    print('Using CUDA')
    # Move every tensor used later to the GPU (the model is placed via args["device"]).
    features, labels = features.cuda(), labels.cuda()
    if args["sparse"]:
        sp_adj = sp_adj.cuda()
    else:
        adj = adj.cuda()
    idx_train, idx_val, idx_test = idx_train.cuda(), idx_val.cuda(), idx_test.cuda()

In [14]:
b_xent = nn.BCEWithLogitsLoss()  # DGI objective on the discriminator's raw logits
xent = nn.CrossEntropyLoss()     # loss for the downstream linear classifier
cnt_wait = 0                     # epochs since the last improvement
best = float('inf')              # lowest training loss seen so far
best_t = 0                       # epoch at which `best` was reached
# Read from the config instead of hard-coding it, so args["batch_size"] is
# the single source of truth.
batch_size = args["batch_size"]

In [15]:
# DGI encoder/discriminator and its Adam optimiser.
model = DGI(ft_size, args["hidden"])
model = model.to(args["device"])
optimizer = torch.optim.Adam(
    model.parameters(), lr=args["lr"], weight_decay=args["weight_decay"]
)

In [16]:
# DGI pre-training with early stopping on the training loss.
# The targets do not change between epochs: 1 for the nb_nodes logits of the
# real graph (which come first), 0 for the nb_nodes logits of the corrupted one.
lbl = torch.cat(
    (torch.ones(batch_size, nb_nodes), torch.zeros(batch_size, nb_nodes)), 1
).to(args["device"])
train_adj = sp_adj if args["sparse"] else adj

for epoch in range(args["epochs"] + 1):
    model.train()
    optimizer.zero_grad()

    # Corruption: shuffle the node feature rows while keeping the graph fixed.
    idx = np.random.permutation(nb_nodes)
    shuf_fts = features[:, idx, :].to(args["device"])

    logits = model(features, shuf_fts, train_adj, args["sparse"], None, None, None)
    loss = b_xent(logits, lbl)
    # Keep bookkeeping in a Python float. Storing the tensor itself in `best`
    # would hold a reference to that epoch's autograd history.
    loss_value = loss.item()

    if epoch % 10 == 0:
        print(f'Epoch: {epoch} - Loss: {loss.cpu().detach().numpy()}')

    if loss_value < best:
        best = loss_value
        best_t = epoch
        cnt_wait = 0
        torch.save(model.state_dict(), 'best_dgi.pkl')
    else:
        cnt_wait += 1

    if cnt_wait == args["patience"]:
        print('Early stopping!')
        break

    loss.backward()
    optimizer.step()

Epoch: 0 - Loss: 0.6931074261665344
Epoch: 10 - Loss: 0.6414570212364197
Epoch: 20 - Loss: 0.5115184187889099
Epoch: 30 - Loss: 0.3741282522678375
Epoch: 40 - Loss: 0.2765086889266968
Epoch: 50 - Loss: 0.22733454406261444
Epoch: 60 - Loss: 0.1932801753282547
Epoch: 70 - Loss: 0.15830954909324646
Epoch: 80 - Loss: 0.14913277328014374
Epoch: 90 - Loss: 0.13185185194015503
Epoch: 100 - Loss: 0.11653786152601242
Epoch: 110 - Loss: 0.11042811721563339
Epoch: 120 - Loss: 0.10591542720794678
Epoch: 130 - Loss: 0.09473607689142227
Epoch: 140 - Loss: 0.09520158916711807
Epoch: 150 - Loss: 0.08930528163909912
Epoch: 160 - Loss: 0.08331979811191559
Epoch: 170 - Loss: 0.0773174986243248
Epoch: 180 - Loss: 0.06864462792873383
Epoch: 190 - Loss: 0.07126422226428986
Epoch: 200 - Loss: 0.06315848976373672
Epoch: 210 - Loss: 0.05941636115312576
Epoch: 220 - Loss: 0.06072395294904709
Early stopping!


In [17]:
print(f'Loading {best_t}th epoch')
# Restore the weights saved at the best (lowest-loss) epoch.
best_state = torch.load('best_dgi.pkl')
model.load_state_dict(best_state)

Loading 205th epoch


<All keys matched successfully>

In [18]:
embeds, _ = model.embed(features, sp_adj if args["sparse"] else adj, args["sparse"], None)
# Drop the batch axis once; rows are then indexed by node.
node_embs = embeds[0]
node_onehot = labels[0]
splits = (idx_train, idx_val, idx_test)

train_embs, val_embs, test_embs = (node_embs[i] for i in splits)
# Class indices from the one-hot label rows.
train_lbls, val_lbls, test_lbls = (torch.argmax(node_onehot[i], dim=1) for i in splits)

tot = torch.zeros(1)  # running sum of per-run test accuracies
accs = []             # per-run test accuracies, in percent

In [19]:
# Linear-probe evaluation: fit a fresh logistic-regression head on the frozen
# DGI embeddings n_runs times and report test-accuracy statistics.
n_runs = 50
# Put the probe on the embeddings' device; otherwise the CUDA path fails
# with a device mismatch.
probe_device = train_embs.device

for run in range(n_runs):
    log = LogReg(args["hidden"], nb_classes).to(probe_device)
    opt = torch.optim.Adam(log.parameters(), lr=0.01, weight_decay=0.0)

    log.train()
    for _ in range(100):
        opt.zero_grad()
        logits = log(train_embs)
        loss = xent(logits, train_lbls)
        loss.backward()
        opt.step()

    log.eval()
    with torch.no_grad():
        logits = log(test_embs)
    preds = torch.argmax(logits, dim=1)
    acc = torch.sum(preds == test_lbls).float() / test_lbls.shape[0]
    accs.append(acc * 100)
    # `tot` may live on the CPU while `acc` is on the GPU; adding a Python
    # float works in either case.
    tot += acc.item()

print('Average accuracy:', tot / n_runs)

accs = torch.stack(accs)
print(accs.mean())
print(accs.std())

Average accuracy: tensor([0.8091])
tensor(80.9100)
tensor(0.1474)
