This notebook is a simple replication of the [GraphMAE paper](https://arxiv.org/abs/2205.10803),

which uses a masked GNN autoencoder.

Its pretraining task is to reconstruct masked node features from neighboring nodes.

Before passing the graph (node features and `edge_index`) to the encoder, it masks some of the nodes: their features are replaced with a learnable vector, like the \[MASK] token in BERT.

The notebook focuses on the effectiveness of self-supervised learning (SSL), so the downstream classifiers are left untuned and could be improved.

[colab link](https://colab.research.google.com/drive/1-uoWt0KtpAYs0egjAhJPhGGUFhHgxVwF?usp=sharing)

In [None]:
!pip install torch-scatter torch-sparse torch-cluster 

In [None]:
from sklearn.preprocessing import StandardScaler
import torch 
import numpy as np
import torch.nn.functional as F
import torch_geometric.nn as gnn
from torch_geometric.datasets import Planetoid
from torch import nn
from typing import Optional,Tuple,List
import torch_geometric

# Pick the compute device: the first CUDA GPU when one is visible, otherwise the CPU.
# To force CPU execution, set dev = "cpu" before the torch.device(...) call below.
dev = "cuda:0" if torch.cuda.is_available() else "cpu"
if dev != "cpu":
  print("gpu up")

device = torch.device(dev)

# Feature standardiser; only instantiated here, not fitted.
scaler = StandardScaler()

# Seed torch's global RNG so that torch-side randomness is repeatable across runs.
torch.manual_seed(20)

gpu up


<torch._C.Generator at 0x7f9b7cdc6550>

In [None]:

def sce(x,y,alpha=2):
    """Scaled cosine error between ``x`` and ``y``.

    Computes ``mean((1 - cos(x_i, y_i)) ** alpha)`` with the cosine taken
    along the last dimension. Each L2 norm is floored at 1e-12 so that
    all-zero vectors do not cause a division by zero.
    """
    floor = torch.tensor(1e-12)

    def _length(t):
        # L2 norm over the last axis, kept as a trailing size-1 axis for broadcasting.
        return torch.max(t.norm(p=2, dim=-1, keepdim=True), floor)

    similarity = (x * y / (_length(x) * _length(y))).sum(dim=-1)
    per_row_error = 1 - similarity
    return per_row_error.pow(alpha).mean()

#these .float() calls are just for unifying the tensor datatypes 
class GConv(nn.Module):
    """Stack of ``num_layers`` SAGEConv layers of constant width ``emb_dim``.

    Args:
        emb_dim: input and output width of every convolution.
        num_layers: number of stacked convolutions.
        encode: when True, each convolution is followed by its own LayerNorm;
            when False the convolutions are applied without normalisation.
        concat_out: when True, ``forward`` returns the outputs of all layers
            concatenated along the last dimension (width ``num_layers * emb_dim``);
            otherwise only the last layer's output is returned.
        device: device the layers are created on.
        dropout: dropout probability applied to the input of every convolution
            (active only in training mode).
    """

    def __init__(self, emb_dim:int=500, num_layers:int=2, encode:bool=True, concat_out:bool=False, device='cpu', dropout=0.2):

        super(GConv,self).__init__()
        self.num_layers = num_layers
        self.encode = encode
        self.concat_out = concat_out

        # nn.ModuleList (rather than a plain Python list) registers the layers as
        # sub-modules, so their weights show up in .parameters() and are trained by
        # the optimizer, and .train()/.eval()/.to() propagate to them.
        self.gconv_layers = nn.ModuleList()
        self.norm_layers = nn.ModuleList()

        # SAGEConv documents no `dropout` argument, so dropout is applied explicitly.
        # NOTE(review): depending on the installed PyG version the former
        # `dropout=` keyword was either rejected or silently ignored - confirm.
        self.dropout = nn.Dropout(dropout)

        for _ in range(num_layers):
            self.gconv_layers.append(gnn.SAGEConv(emb_dim, emb_dim).to(device))
            if self.encode:
                self.norm_layers.append(nn.LayerNorm(emb_dim).to(device))

    def forward(self, x, edge_index):
        """Run the stack on node features ``x`` over the graph ``edge_index``.

        Returns the last layer's output, or the concatenation of every
        layer's output when ``concat_out`` is set.
        """
        outs = []
        hidden = x
        for i in range(self.num_layers):
            hidden = self.gconv_layers[i](self.dropout(hidden), edge_index)
            if self.encode:
                hidden = self.norm_layers[i](hidden)
            outs.append(hidden)

        if self.concat_out:
            return torch.cat(outs, dim = -1)

        return outs[-1]
        
class GraphMAE(nn.Module):
    """Masked graph autoencoder built from two ``GConv`` stacks.

    ``forward`` replaces the features of a random subset of nodes with a
    learnable mask token, encodes the graph, projects the encoding back to
    ``emb_dim``, overwrites the same nodes with a second learnable token
    (re-masking) and decodes with a single un-normalised convolution.

    Args:
        emb_dim: node feature width (input and reconstruction width).
        masked_ratio: fraction of nodes masked on every forward pass.
        num_encode_layers: number of encoder convolutions.
        concat_out: whether the encoder concatenates the outputs of all its layers.
        device: device on which parameters are created and inputs are placed.
        dropout: dropout value forwarded to the encoder (the decoder uses 0).
    """

    def __init__(self, 
            emb_dim:int=500,
            masked_ratio:float=0.3, 
            num_encode_layers:int=2, 
            concat_out:bool=False, 
            device:torch.device='cpu',
            dropout:float=0.2
            ):
        super(GraphMAE, self).__init__()
        self.device = device
        self.encoder = GConv(emb_dim, num_encode_layers, concat_out=concat_out, encode=True, device=device, dropout=dropout).float().to(device)
        self.decoder = GConv(emb_dim, 1, encode=False, device=device, dropout=0).float().to(device)
        self.masked_ratio = masked_ratio

        # Build the tokens directly on the target device. Calling .to(device) on an
        # already-wrapped nn.Parameter returns a plain tensor copy when the device
        # differs, which is neither registered on the module nor seen by the optimizer.
        self.encoder_mask_token = nn.Parameter(torch.zeros(1, emb_dim, device=device))
        self.remask_token = nn.Parameter(torch.zeros(1, emb_dim, device=device))

        # With concat_out the encoder emits num_encode_layers * emb_dim features.
        encoder_out_dim = num_encode_layers*emb_dim if concat_out else emb_dim
        self.encoder_to_decoder = nn.Linear(encoder_out_dim, emb_dim).to(device)

    def forward(self, x, edge_index):
        """Mask, encode, re-mask and decode.

        Returns:
            ``(recon_x, masked_nodes_index)``: the decoder output for every node
            and the indices of the nodes that were masked in this pass.
        """
        num_nodes = len(x)
        masked_nodes_index = torch.randperm(num_nodes)[:int(num_nodes * self.masked_ratio)]

        # Work on a float copy so the caller's tensor is never modified.
        recon_x = x.clone().float().to(self.device)
        recon_x[masked_nodes_index,:] = self.encoder_mask_token

        recon_x = self.encoder(recon_x, edge_index)
        recon_x = self.encoder_to_decoder(recon_x)

        # Re-masking: the decoder must rebuild the masked nodes from their neighbours.
        recon_x[masked_nodes_index,:] = self.remask_token

        recon_x = self.decoder(recon_x, edge_index)

        return recon_x, masked_nodes_index

    def encode(self, x, edge_index):
        """Return the projected encoder output for ``x``; no masking is applied."""
        return self.encoder_to_decoder(self.encoder(x.float(), edge_index.long()))

    def decode(self, x, edge_index):
        """Run the decoder stack on ``x``."""
        # Previously this method called itself, recursing without end.
        return self.decoder(x, edge_index)

In [None]:
dataset = Planetoid("./pubmed","PubMed")
graph  =  dataset[0]

# The autoencoder reconstructs the input features, so its width is taken from the data
# instead of relying on the constructor default.
graphmae = GraphMAE(emb_dim=graph.x.shape[1], masked_ratio=0.4, num_encode_layers=2, dropout=0.2, concat_out=True, device=device).float().to(device)

# Standardise the node features. sklearn returns float64; cast once to float32 so the
# model input and the reconstruction target share the model's dtype.
scaler = StandardScaler()
scaled_x = torch.tensor(scaler.fit_transform(graph.x.cpu().numpy()), dtype=torch.float32).to(device)
graph = graph.to(device)

# some modules might use lazy initialization (which means parameters are initialized on the first forward call)
# The warm-up pass needs no gradients, so it runs without building an autograd graph.
with torch.no_grad():
    graphmae(scaled_x, graph.edge_index)
optimizer = torch.optim.Adam(graphmae.parameters(), lr=0.001, weight_decay=1e-5)


In [None]:
import warnings
warnings.filterwarnings('ignore')
epochs  = 301
svc_accs = []
lr_accs = []

# evaluate_embeddings() switches graphmae to eval mode; make the mode explicit so that
# re-running this cell after an evaluation still trains in training mode.
graphmae.train()
for e in range(epochs):

    optimizer.zero_grad()
    out, masked_nodes_index = graphmae(scaled_x, graph.edge_index)
    # The loss is measured on the masked nodes only.
    err = sce(out[masked_nodes_index,:], scaled_x[masked_nodes_index,:], alpha=2)
    err.backward()
    optimizer.step()

    if e % 50 == 0:
        print("scaled cosine error:",err.item())
      

scaled cosine error: 0.9978873160494309
scaled cosine error: 0.6998318488349362
scaled cosine error: 0.6720454481031665
scaled cosine error: 0.6661558207945321
scaled cosine error: 0.6524083907653202
scaled cosine error: 0.6536281098493184
scaled cosine error: 0.6462239074967576


In [None]:
from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score

def evaluate_embeddings(x, edge_index, target, train_mask, val_mask):
  """Score the global ``graphmae`` embeddings with two sklearn classifiers.

  Encodes ``x`` with ``graphmae.encode``, fits a LogisticRegression and an SVC
  on the rows selected by ``train_mask`` and prints each classifier's accuracy
  on the rows selected by ``val_mask``. Nothing is returned.
  """
  was_training = graphmae.training
  graphmae.eval()
  # Re-seeds torch's global RNG (this also affects whatever runs after this call).
  torch.manual_seed(20)
  try:
    # Inference only: skip building an autograd graph for the whole node set.
    with torch.no_grad():
      encoded_embeddings = graphmae.encode(x, edge_index).cpu().numpy()
  finally:
    # Put the model back in the mode it was in, so a later training run is unaffected.
    graphmae.train(was_training)

  # Convert once instead of repeating the tensor -> numpy round trip per classifier.
  labels = target.cpu().numpy()
  train_rows = train_mask.cpu().numpy()
  val_rows = val_mask.cpu().numpy()
  train_x, train_y = encoded_embeddings[train_rows], labels[train_rows]
  val_x, val_y = encoded_embeddings[val_rows], labels[val_rows]

  lr = LogisticRegression()
  svc = SVC()

  lr.fit(train_x, train_y)
  print("Logistic Regression Accuracy:",accuracy_score(val_y, lr.predict(val_x)))

  svc.fit(train_x, train_y)
  print("SVC Accuracy:",accuracy_score(val_y, svc.predict(val_x)))
  

In [None]:

def fit_classifiers(nodes, y, train_mask, val_mask):
  """Baseline: fit LogisticRegression and SVC directly on ``nodes``.

  Both models are trained on the rows picked by ``train_mask`` and their
  accuracy on the rows picked by ``val_mask`` is printed.
  """
  print("scores on non-encoded vectors")

  train_x = nodes[train_mask].detach().cpu().numpy()
  train_y = y[train_mask].cpu().numpy()
  val_x = nodes[val_mask].detach().cpu().numpy()
  val_y = y[val_mask].cpu().numpy()

  # (printed label, estimator) pairs, evaluated in this order.
  candidates = (
    ("Logistic regression ", LogisticRegression()),
    ("SVC score", SVC()),
  )
  for label, estimator in candidates:
    estimator.fit(train_x, train_y)
    print(label, accuracy_score(val_y, estimator.predict(val_x)))


Comparing the generated embeddings to using only the scaled features for classification.

In [None]:
# SSL embeddings: classifiers fitted on train_mask, scored on test_mask.
evaluate_embeddings(
    x=scaled_x,
    edge_index=graph.edge_index.to(device),
    target=graph.y,
    train_mask=graph.train_mask,
    val_mask=graph.test_mask,
)


Logistic Regression Accuracy: 0.743
SVC Accuracy: 0.76


In [None]:
# Baseline on the scaled input features: fitted on train_mask, scored on test_mask.
fit_classifiers(
    nodes=scaled_x,
    y=graph.y,
    train_mask=graph.train_mask,
    val_mask=graph.test_mask,
)

scores on non-encoded vectors
Logistic regression  0.702
SVC score 0.677


But if we increase the number of training examples (by training on `val_mask`), the benefit of pretraining is much smaller.

In [None]:
# SSL embeddings: classifiers fitted on val_mask, scored on test_mask.
evaluate_embeddings(
    x=scaled_x,
    edge_index=graph.edge_index,
    target=graph.y,
    train_mask=graph.val_mask,
    val_mask=graph.test_mask,
)

Logistic Regression Accuracy: 0.737
SVC Accuracy: 0.803


In [None]:
# Baseline on the scaled input features: fitted on val_mask, scored on test_mask.
fit_classifiers(
    nodes=scaled_x,
    y=graph.y,
    train_mask=graph.val_mask,
    val_mask=graph.test_mask,
)

scores on non-encoded vectors
Logistic regression  0.768
SVC score 0.762


Doing the same comparison using a GNN for classification.

In [None]:

class GNNClassifier(nn.Module):
  """Node classifier: a ``GConv`` stack followed by a linear head.

  Args:
      emb_dim: input feature width and width of every convolution.
      masked_ratio: unused; kept so existing call sites keep working.
      num_encode_layers: number of convolutions in the stack.
      concat_out: whether the stack concatenates the outputs of all its layers.
      device: device the layers are created on.
      dropout: dropout value forwarded to the stack.
      num_classes: number of output classes (defaults to 3, the previous fixed value).
  """

  def __init__(self, 
              emb_dim:int=500,
              masked_ratio:float=0.3, 
              num_encode_layers:int=2, 
              concat_out:bool=False, 
              device:torch.device='cpu',
              dropout:float=0.2,
              num_classes:int=3
              ):
    super(GNNClassifier, self).__init__()
    self.device = device
    self.gnn_conv = GConv(emb_dim, num_encode_layers, concat_out=concat_out, encode=True, device=device, dropout=dropout).float().to(device)
    self.criterion = nn.CrossEntropyLoss()
    # With concat_out the stack emits num_encode_layers * emb_dim features, so the
    # head must be sized accordingly to avoid a shape mismatch.
    conv_out_dim = emb_dim * num_encode_layers if concat_out else emb_dim
    self.classifier = nn.Linear(conv_out_dim, num_classes).to(device)

  def forward(self, x, edge_index):
    """Return the class logits for every node."""
    return self.classifier(self.gnn_conv(x.float(), edge_index.long()))

  def loss(self, x, edge_index, target, mask):
    """Return ``(logits, cross_entropy)``; the loss uses only the nodes selected by ``mask``."""
    outs = self.forward(x, edge_index)

    return outs, self.criterion(outs[mask], target[mask])
      


In [None]:
import warnings
warnings.filterwarnings('ignore')
gnn_classifier = GNNClassifier(device=device)
epochs  = 50  # low number of epochs to not give much of a chance to overfit
svc_accs = []
lr_accs = []
# A dedicated name keeps the global `optimizer` bound to graphmae's parameters, so
# re-running the pretraining cell after this one still updates the right model.
classifier_optimizer = torch.optim.Adam(gnn_classifier.parameters(), lr=0.001, weight_decay=1e-5)
test_labels = graph.y[graph.test_mask].cpu().numpy()

gnn_classifier.train()
for e in range(epochs):

    classifier_optimizer.zero_grad()
    outs, err = gnn_classifier.loss(scaled_x, graph.edge_index, graph.y, graph.train_mask)
    err.backward()
    classifier_optimizer.step()

    if e % 10 == 0:
        print("Cross Entropy:",err.item())
        # Test accuracy of the logits produced by this epoch's training forward pass.
        test_preds = outs[graph.test_mask].argmax(-1).detach().cpu().numpy()
        # accuracy_score takes (y_true, y_pred).
        print(accuracy_score(test_labels, test_preds))
      

Cross Entropy: 1.2189218997955322
0.345
Cross Entropy: 0.535369336605072
0.536
Cross Entropy: 0.21646060049533844
0.596
Cross Entropy: 0.08846046775579453
0.61
Cross Entropy: 0.03282538428902626
0.616


In [None]:
import warnings
warnings.filterwarnings('ignore')
gnn_classifier = GNNClassifier(device=device)
epochs  = 50
svc_accs = []
lr_accs = []
# A dedicated name keeps the global `optimizer` bound to graphmae's parameters, so
# re-running the pretraining cell after this one still updates the right model.
classifier_optimizer = torch.optim.Adam(gnn_classifier.parameters(), lr=0.001, weight_decay=1e-5)

# The pretrained encoder is frozen: embeddings are computed once, without gradients.
with torch.no_grad():
    graphmae.eval()
    ssl_embeddings = graphmae.encode(scaled_x, graph.edge_index)

test_labels = graph.y[graph.test_mask].cpu().numpy()

gnn_classifier.train()
for e in range(epochs):

    classifier_optimizer.zero_grad()
    outs, err = gnn_classifier.loss(ssl_embeddings, graph.edge_index, graph.y, graph.train_mask)
    err.backward()
    classifier_optimizer.step()

    if e % 10 == 0:
        print("Cross Entropy:",err.item())
        # Test accuracy of the logits produced by this epoch's training forward pass.
        test_preds = outs[graph.test_mask].argmax(-1).detach().cpu().numpy()
        # accuracy_score takes (y_true, y_pred).
        print(accuracy_score(test_labels, test_preds))
      

Cross Entropy: 1.1787744760513306
0.476
Cross Entropy: 0.469384104013443
0.722
Cross Entropy: 0.36979395151138306
0.714
Cross Entropy: 0.3411937654018402
0.726
Cross Entropy: 0.26469433307647705
0.738
