# Graph Contrastive Learning with Adaptive Augmentation
Based on:
- **Paper:** https://arxiv.org/pdf/2010.14945.pdf
- **Implementation:** https://github.com/CRIPAC-DIG/GCA

In [1]:
# Install required packages.
!pip install -q torch-scatter -f https://pytorch-geometric.com/whl/torch-1.9.0+cu102.html
!pip install -q torch-sparse -f https://pytorch-geometric.com/whl/torch-1.9.0+cu102.html
!pip install -q torch-geometric

In [2]:
import torch
from torch import nn
import torch.nn.functional as F
from torch_geometric.datasets import WikiCS
from torch_geometric.utils import dropout_adj, degree, to_undirected
import torch_geometric.transforms as T
from torch.utils.data import random_split
from torch_geometric.nn import GCNConv

In [3]:
device_type = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_type)

## Dataset

In [4]:
dataset = WikiCS(root="data/WikiCS", transform=T.NormalizeFeatures())

print()
print(f'Dataset: {dataset}:')
print('======================')
print(f'Number of graphs: {len(dataset)}')
print(f'Number of features: {dataset.num_features}')
print(f'Number of classes: {dataset.num_classes}')

data = dataset[0]  # Get the first graph object.
data = data.to(device)

print()
print(data)
print('===========================================================================================================')

# Gather some statistics about the graph.
print(f'Number of nodes: {data.num_nodes}')
print(f'Number of edges: {data.num_edges}')
print(f'Average node degree: {data.num_edges / data.num_nodes:.2f}')
print(f'Number of training nodes: {data.train_mask.sum()}')
print(f'Training node label rate: {int(data.train_mask.sum()) / data.num_nodes:.2f}')
print(f'Contains isolated nodes: {data.contains_isolated_nodes()}')
print(f'Contains self-loops: {data.contains_self_loops()}')
print(f'Is undirected: {data.is_undirected()}')


Dataset: WikiCS():
Number of graphs: 1
Number of features: 300
Number of classes: 10

Data(edge_index=[2, 297110], stopping_mask=[11701, 20], test_mask=[11701], train_mask=[11701, 20], val_mask=[11701, 20], x=[11701, 300], y=[11701])
Number of nodes: 11701
Number of edges: 297110
Average node degree: 25.39
Number of training nodes: 11600
Training node label rate: 0.99
Contains isolated nodes: True
Contains self-loops: True
Is undirected: False


In [5]:
def generate_split(num_samples: int, train_ratio: float, val_ratio: float):
    train_len = int(num_samples * train_ratio)
    val_len = int(num_samples * val_ratio)
    test_len = num_samples - train_len - val_len

    train_set, test_set, val_set = random_split(torch.arange(0, num_samples), (train_len, test_len, val_len))

    idx_train, idx_test, idx_val = train_set.indices, test_set.indices, val_set.indices
    train_mask = torch.zeros((num_samples,)).to(torch.bool)
    test_mask = torch.zeros((num_samples,)).to(torch.bool)
    val_mask = torch.zeros((num_samples,)).to(torch.bool)

    train_mask[idx_train] = True
    test_mask[idx_test] = True
    val_mask[idx_val] = True

    return {
            'train': train_mask,
            'test': test_mask,
            'val': val_mask
        }

In [6]:
split = generate_split(data.num_nodes, train_ratio=0.1, val_ratio=0.1)

## Models

In [7]:
class Encoder(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, num_layers: int = 2):
        super(Encoder, self).__init__()
        self.num_layers = num_layers
        
        self.conv = [GCNConv(in_channels, 2 * out_channels).jittable()]
        for _ in range(1, num_layers - 1):
            self.conv.append(base_model(2 * out_channels, 2 * out_channels))
        self.conv.append(GCNConv(2 * out_channels, out_channels))

        self.conv = nn.ModuleList(self.conv)
        self.activation = nn.PReLU()

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor):
        for i in range(self.num_layers):
                x = self.activation(self.conv[i](x, edge_index))
        return x

In [8]:
class GRACE(torch.nn.Module):
    def __init__(self, encoder: Encoder, num_hidden: int, num_proj_hidden: int, tau: float = 0.5):
        super(GRACE, self).__init__()
        self.encoder: Encoder = encoder
        self.tau: float = tau

        self.fc1 = torch.nn.Linear(num_hidden, num_proj_hidden)
        self.fc2 = torch.nn.Linear(num_proj_hidden, num_hidden)

        self.num_hidden = num_hidden

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        return self.encoder(x, edge_index)

    def projection(self, z: torch.Tensor) -> torch.Tensor:
        z = F.elu(self.fc1(z))
        return self.fc2(z)

    def sim(self, z1: torch.Tensor, z2: torch.Tensor):
        z1 = F.normalize(z1)
        z2 = F.normalize(z2)
        return torch.mm(z1, z2.t())

    def semi_loss(self, z1: torch.Tensor, z2: torch.Tensor):
        f = lambda x: torch.exp(x / self.tau)
        refl_sim = f(self.sim(z1, z1))
        between_sim = f(self.sim(z1, z2))

        return -torch.log(between_sim.diag() / (refl_sim.sum(1) + between_sim.sum(1) - refl_sim.diag()))

    def loss(self, z1: torch.Tensor, z2: torch.Tensor, mean: bool = True):
        h1 = self.projection(z1)
        h2 = self.projection(z2)

        l1 = self.semi_loss(h1, h2)
        l2 = self.semi_loss(h2, h1)

        ret = (l1 + l2) * 0.5
        ret = ret.mean() if mean else ret.sum()

        return ret

In [9]:
class LogReg(nn.Module):
    def __init__(self, ft_in, nb_classes):
        super(LogReg, self).__init__()
        self.fc = nn.Linear(ft_in, nb_classes)

        for m in self.modules():
            self.weights_init(m)

    def weights_init(self, m):
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight.data)
            if m.bias is not None:
                m.bias.data.fill_(0.0)

    def forward(self, seq):
        ret = self.fc(seq)
        return ret

## Utilities

In [10]:
def degree_drop_weights(edge_index):
    edge_index_ = to_undirected(edge_index)
    deg = degree(edge_index_[1])
    deg_col = deg[edge_index[1]].to(torch.float32)
    s_col = torch.log(deg_col)
    weights = (s_col.max() - s_col) / (s_col.max() - s_col.mean())

    return weights

In [11]:
def feature_drop_weights_dense(x, node_c):
    x = x.abs()
    w = x.t() @ node_c
    w = w.log()
    s = (w.max() - w) / (w.max() - w.mean())

    return s

In [12]:
def drop_edge_weighted(edge_index, edge_weights, p: float, threshold: float = 1.):
    edge_weights = edge_weights / edge_weights.mean() * p
    edge_weights = edge_weights.where(edge_weights < threshold, torch.ones_like(edge_weights) * threshold)
    sel_mask = torch.bernoulli(1. - edge_weights).to(torch.bool)

    return edge_index[:, sel_mask]

In [13]:
def drop_feature(x, drop_prob):
    drop_mask = torch.empty((x.size(1),), dtype=torch.float32, device=x.device).uniform_(0, 1) < drop_prob
    x = x.clone()
    x[:, drop_mask] = 0

    return x

In [14]:
def drop_feature_weighted_2(x, w, p: float, threshold: float = 0.7):
    w = w / w.mean() * p
    w = w.where(w < threshold, torch.ones_like(w) * threshold)
    drop_prob = w

    drop_mask = torch.bernoulli(drop_prob).to(torch.bool)

    x = x.clone()
    x[:, drop_mask] = 0.

    return x

In [15]:
class MulticlassEvaluator:
    def __init__(self, *args, **kwargs):
        pass

    @staticmethod
    def _eval(y_true, y_pred):
        y_true = y_true.view(-1)
        y_pred = y_pred.view(-1)
        total = y_true.size(0)
        correct = (y_true == y_pred).to(torch.float32).sum()
        return (correct / total).item()

    def eval(self, res):
        return {'acc': self._eval(**res)}

In [16]:
def log_regression(z,
                   dataset,
                   evaluator,
                   num_epochs: int = 5000,
                   test_device=None,
                   split: str = 'rand:0.1',
                   verbose: bool = False,
                   preload_split=None):
    test_device = z.device if test_device is None else test_device
    z = z.detach().to(test_device)
    num_hidden = z.size(1)
    y = dataset[0].y.view(-1).to(test_device)
    num_classes = dataset[0].y.max().item() + 1
    classifier = LogReg(num_hidden, num_classes).to(test_device)
    optimizer = torch.optim.Adam(classifier.parameters(), lr=0.01, weight_decay=0.0)

    split = {k: v.to(test_device) for k, v in split.items()}
    f = nn.LogSoftmax(dim=-1)
    nll_loss = nn.NLLLoss()

    best_test_acc = 0
    best_val_acc = 0
    best_epoch = 0

    for epoch in range(num_epochs):
        classifier.train()
        optimizer.zero_grad()

        output = classifier(z[split['train']])
        loss = nll_loss(f(output), y[split['train']])

        loss.backward()
        optimizer.step()

        if (epoch + 1) % 20 == 0:
            if 'val' in split:
                # val split is available
                test_acc = evaluator.eval({
                    'y_true': y[split['test']].view(-1, 1),
                    'y_pred': classifier(z[split['test']]).argmax(-1).view(-1, 1)
                })['acc']
                val_acc = evaluator.eval({
                    'y_true': y[split['val']].view(-1, 1),
                    'y_pred': classifier(z[split['val']]).argmax(-1).view(-1, 1)
                })['acc']
                if val_acc > best_val_acc:
                    best_val_acc = val_acc
                    best_test_acc = test_acc
                    best_epoch = epoch
            else:
                acc = evaluator.eval({
                    'y_true': y[split['test']].view(-1, 1),
                    'y_pred': classifier(z[split['test']]).argmax(-1).view(-1, 1)
                })['acc']
                if best_test_acc < acc:
                    best_test_acc = acc
                    best_epoch = epoch
            if verbose:
                print(f'logreg epoch {epoch}: best test acc {best_test_acc}')

    return {'acc': best_test_acc}

## Training

In [17]:
param = {
    "learning_rate": 0.01,
    "num_hidden": 256,
    "num_proj_hidden": 256,
    'num_layers': 2,
    "drop_edge_rate_1": 0.2,
    "drop_edge_rate_2": 0.3,
    "drop_feature_rate_1": 0.1,
    "drop_feature_rate_2": 0.1,
    "tau": 0.4,
    "num_epochs": 1000,
    'weight_decay': 1e-5,
    'drop_scheme': 'degree',
}

In [18]:
encoder = Encoder(dataset.num_features, 
                  param['num_hidden'],
                  param['num_layers']).to(device)

model = GRACE(encoder, 
              param['num_hidden'], 
              param['num_proj_hidden'], 
              param['tau']).to(device)

optimizer = torch.optim.Adam(
        model.parameters(),
        lr=param['learning_rate'],
        weight_decay=param['weight_decay']
    )

In [19]:
if param['drop_scheme'] == 'degree':
    drop_weights = degree_drop_weights(data.edge_index).to(device)

    edge_index_ = to_undirected(data.edge_index)
    node_deg = degree(edge_index_[1])
    feature_weights = feature_drop_weights_dense(data.x, node_c=node_deg).to(device)

In [20]:
def train():
    model.train()
    optimizer.zero_grad()

    def drop_edge(idx: int):
        global drop_weights

        return drop_edge_weighted(data.edge_index, 
                                  drop_weights, 
                                  p=param[f'drop_edge_rate_{idx}'], 
                                  threshold=0.7)

    edge_index_1 = drop_edge(1)
    edge_index_2 = drop_edge(2)
    x_1 = drop_feature_weighted_2(data.x, feature_weights, param['drop_feature_rate_1'])
    x_2 = drop_feature_weighted_2(data.x, feature_weights, param['drop_feature_rate_2'])

    z1 = model(x_1, edge_index_1)
    z2 = model(x_2, edge_index_2)

    loss = model.loss(z1, z2, None)
    loss.backward()
    optimizer.step()

    return loss.item()

In [21]:
def test(final=False):
    model.eval()
    z = model(data.x, data.edge_index)

    evaluator = MulticlassEvaluator()
    accs = []
    for i in range(10):
        acc = log_regression(z, dataset, evaluator, split=split, num_epochs=800)['acc']
        accs.append(acc)
    acc = sum(accs) / len(accs)

    return acc

In [22]:
for epoch in range(1, param['num_epochs'] + 1):
  loss = train()


  if epoch % 100 == 0:
      acc = test()
      print(f'Epoch={epoch:04d}, avg_acc = {acc}, loss={loss:.4f}')

Epoch=0100, avg_acc = 0.7366947889328003, loss=95959.8750
Epoch=0200, avg_acc = 0.750774472951889, loss=94538.9141
Epoch=0300, avg_acc = 0.7531460106372834, loss=93321.4062
Epoch=0400, avg_acc = 0.7703129947185516, loss=92597.2891
Epoch=0500, avg_acc = 0.7793291091918946, loss=92145.5078
Epoch=0600, avg_acc = 0.7836555778980255, loss=91668.0625
Epoch=0700, avg_acc = 0.784926813840866, loss=91547.6328
Epoch=0800, avg_acc = 0.7875761032104492, loss=91233.2109
Epoch=0900, avg_acc = 0.7916461765766144, loss=90988.6953
Epoch=1000, avg_acc = 0.7905565500259399, loss=90878.7656


In [23]:
 acc = test(final=True)
print(f'{acc}')

0.7908449769020081
