# GraphCL Tutorial
#### This tutorial illustrates the use of the GraphCL algorithm ([Graph Contrastive Learning with Augmentations](https://proceedings.nips.cc/paper/2020/file/3fe230348e9a12c13120749e3f9fa4cd-Paper.pdf)), a graph contrastive learning framework for learning unsupervised representations of graph data, which maximizes the mutual information between two augmented views of a graph.
#### The tutorial is organized as follows:
#### 1. [Preprocessing Data and Loading Configuration](GraphCL.ipynb#L6)
#### 2. [Graph Augmentation](GraphCL.ipynb#L7)
#### 3. [Training the model](GraphCL.ipynb#L7)
#### 4. [Evaluating the model](GraphCL.ipynb#L8)

## 1. Preprocessing Data and Loading Configuration 
#### First, we load the configuration from a YAML file and then load the dataset.
#### For ease of use, we searched for suitable hyperparameters on three datasets and provide values with which this GraphCL implementation performs similarly to the results reported in the paper.
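The code in this tutorial reads settings through nested attribute access, for example `config.dataset.name` or `config.model.aug_type`. As a rough sketch of that access pattern (not the repository's `load_yaml`, whose internals are not shown here), the snippet below converts a nested dictionary into an object with the fields this tutorial uses. The helper `to_namespace` and every value in `example` are hypothetical placeholders.

```python
from types import SimpleNamespace


def to_namespace(obj):
    """Recursively turn nested dicts into objects with attribute access."""
    if isinstance(obj, dict):
        return SimpleNamespace(**{k: to_namespace(v) for k, v in obj.items()})
    if isinstance(obj, list):
        return [to_namespace(v) for v in obj]
    return obj


# Field names mirror the ones read in this tutorial; the values are placeholders.
example = {
    "torch_seed": 0,
    "use_cuda": True,
    "gpu_idx": 0,
    "dataset": {"name": "WikiCS", "root": "pyg_data"},
    "model": {"aug_type": "edge", "in_channels": 300, "hidden_channels": 512},
    "optim": {"max_epoch": 100},
    "classifier": {"base_lr": 0.01, "weight_decay": 0.0, "max_epoch": 100},
}

cfg = to_namespace(example)
print(cfg.dataset.name, cfg.model.aug_type, cfg.optim.max_epoch)
```

Checking that a configuration exposes these fields before running the cells can save a failed run later on.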

In [None]:
from src.augment import *
from src.methods import GraphCL, GraphCLEncoder
from src.trainer import SimpleTrainer
from torch_geometric.loader import DataLoader
from src.transforms import NormalizeFeatures, GCNNorm, Edge2Adj, Compose
from src.datasets import Planetoid, Entities, Amazon, WikiCS, Coauthor
from src.evaluation import LogisticRegression
import torch
from src.config import load_yaml
from src.utils.create_data import create_masks
from src.utils.add_adj import add_adj_t

# load the configuration file
# config = load_yaml('./configuration/graphcl_amazon.yml')
# config = load_yaml('./configuration/graphcl_coauthor.yml')
config = load_yaml('./configuration/graphcl_wikics.yml')
# config = load_yaml('./configuration/graphcl_cora.yml')
torch.manual_seed(config.torch_seed)
device = torch.device("cuda:{}".format(config.gpu_idx) if torch.cuda.is_available() and config.use_cuda else "cpu")

# data

if config.dataset.name == 'pyg_data':
    pre_transforms = Compose([NormalizeFeatures(ord=1), Edge2Adj(norm=GCNNorm(add_self_loops=1))])
    # dataset = Planetoid(root=config.dataset.root, name=config.dataset.name, pre_transform=pre_transforms)
    dataset = Planetoid(root='pyg_data', name=config.dataset.name)
elif config.dataset.name == 'Amazon':
    pre_transforms = NormalizeFeatures(ord=1)
    dataset = Amazon(root='pyg_data', name='Photo', pre_transform=pre_transforms)
elif config.dataset.name == 'WikiCS':
    pre_transforms = NormalizeFeatures(ord=1)
    dataset = WikiCS(root='pyg_data', pre_transform=pre_transforms)
elif config.dataset.name == 'coauthor':
    pre_transforms = NormalizeFeatures(ord=1)
    dataset = Coauthor(root='pyg_data', name='CS', pre_transform=pre_transforms)
else:
    # Raising a plain string is a TypeError in Python 3; raise a real exception.
    raise ValueError('please specify the correct dataset root')
if config.dataset.name in ['Amazon', 'WikiCS', 'coauthor']:
    dataset.data = create_masks(dataset.data, config.dataset.name)
dataset = add_adj_t(dataset)
data_loader = DataLoader(dataset)

## 2. Graph Augmentation
#### GraphCL provides several graph augmentation methods, ranging from the node level to the edge level and the subgraph level.
#### You can change aug_type to select a different graph augmentation method; the supported values are 'edge', 'mask', 'node', and 'subgraph'.
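To give an intuition for what two of these augmentations do, here is a minimal pure-Python sketch of random edge dropping and feature masking on a toy graph. It is an illustration under simple assumptions (independent drops with probability `p`, masking whole feature columns), not the implementation behind `RandomDropEdge` or `RandomMask`; the function names `drop_edges` and `mask_features` are hypothetical.

```python
import random


def drop_edges(edges, p, rng):
    """Keep each edge independently with probability 1 - p."""
    return [e for e in edges if rng.random() >= p]


def mask_features(features, p, rng):
    """Zero out each feature column with probability p (same columns for every node)."""
    if not features:
        return []
    keep = [rng.random() >= p for _ in range(len(features[0]))]
    return [[v if k else 0.0 for v, k in zip(row, keep)] for row in features]


rng = random.Random(0)  # fixed seed so the toy example is repeatable
edges = [(0, 1), (1, 2), (2, 3), (3, 0)]
features = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]

view_edges = drop_edges(edges, p=0.5, rng=rng)
view_feats = mask_features(features, p=0.5, rng=rng)
print(view_edges)
print(view_feats)
```

Each call produces a perturbed view of the same graph; contrastive training then compares representations computed from such views.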

In [None]:
aug_type = config.model.aug_type
if aug_type == 'edge':
    augment_neg = AugmentorList([RandomDropEdge()])
elif aug_type == 'mask':
    augment_neg = AugmentorList([RandomMask()])
elif aug_type == 'node':
    augment_neg = AugmentorList([RandomDropNode()])
elif aug_type == 'subgraph':
    augment_neg = AugmentorList([AugmentSubgraph()])
else:
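    # `assert '<non-empty string>'` always passes, so raise an explicit error instead.
    raise ValueError('unrecognized augmentation method: {}'.format(aug_type))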

## 3. Training the Model
#### In this step, we first initialize GraphCL. The backbone of GraphCL is Deep Graph Infomax (DGI).
#### You may replace the encoder with a user-defined one. For reference, see the encoder defined in ./src/methods/graphcl.py#L96. Keep in mind that an encoder consists of a class initializer, a forward function, and a get_embs() function.
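Before plugging in a custom encoder, it can help to verify that it exposes the methods named above. The sketch below uses a `typing.Protocol` to check only for the presence of `forward` and `get_embs`; the argument signatures are left open because they are not specified here. `EncoderLike`, `ToyEncoder`, and `check_encoder` are hypothetical names, and `ToyEncoder` is a placeholder rather than a real graph encoder.

```python
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class EncoderLike(Protocol):
    """Structural interface: anything with forward() and get_embs()."""

    def forward(self, *args: Any, **kwargs: Any) -> Any: ...

    def get_embs(self, *args: Any, **kwargs: Any) -> Any: ...


class ToyEncoder:
    """Placeholder with the three required parts; it performs no real encoding."""

    def __init__(self, in_channels, hidden_channels):
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels

    def forward(self, x):
        return x  # identity stand-in for the actual layers

    def get_embs(self, x):
        return self.forward(x)


def check_encoder(encoder):
    """Raise early if a required method is missing (signatures are not checked)."""
    if not isinstance(encoder, EncoderLike):
        raise TypeError("encoder must define forward() and get_embs()")
    return encoder


encoder = check_encoder(ToyEncoder(in_channels=4, hidden_channels=8))
print(type(encoder).__name__)
```

The check is purely structural, so it catches a missing method but not a mismatched signature.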

In [None]:
# ------------------- Method -----------------
encoder = GraphCLEncoder(in_channels=config.model.in_channels, hidden_channels=config.model.hidden_channels)
method = GraphCL(encoder=encoder, corruption=augment_neg, hidden_channels=config.model.hidden_channels)
method.augment_type = aug_type

#### We train the model by calling the trainer.train() function.

In [None]:
trainer = SimpleTrainer(method=method, data_loader=data_loader, device=device, n_epochs=config.optim.max_epoch)
trainer.train()

## 4. Evaluating the performance of GraphCL
#### In the last step, we evaluate the performance of GraphCL. We first obtain the embeddings by calling the method.get_embs() function and then use logistic regression to evaluate them. Other classifiers, such as SVM and random forest, can be found in ./src/evaluation/classifier.py. In addition, evaluation methods for the unsupervised setting, such as k-means clustering and similarity search, can be found in ./src/evaluation/cluster.py and ./src/evaluation/sim_search.py.
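As an illustration of the similarity-search idea, the following pure-Python sketch ranks nodes by cosine similarity of their embeddings and measures how often the nearest neighbours share the query node's label. It is a toy version under stated assumptions, not the code in ./src/evaluation/sim_search.py; `cosine`, `top_k_similar`, and `label_agreement` are hypothetical helpers, and the embeddings and labels are made up.

```python
import math


def cosine(u, v):
    """Cosine similarity of two vectors; 0.0 if either has zero norm."""
    dot = sum(a * b for a, b in zip(u, v))
    nu = math.sqrt(sum(a * a for a in u))
    nv = math.sqrt(sum(b * b for b in v))
    if nu == 0.0 or nv == 0.0:
        return 0.0
    return dot / (nu * nv)


def top_k_similar(embs, query_idx, k):
    """Indices of the k embeddings most similar to embs[query_idx], excluding itself."""
    q = embs[query_idx]
    scored = [(cosine(q, e), i) for i, e in enumerate(embs) if i != query_idx]
    scored.sort(key=lambda t: (-t[0], t[1]))  # highest similarity first, ties by index
    return [i for _, i in scored[:k]]


def label_agreement(embs, labels, k):
    """Average fraction of each node's top-k neighbours that share its label."""
    total = 0.0
    for i in range(len(embs)):
        neighbours = top_k_similar(embs, i, k)
        total += sum(labels[j] == labels[i] for j in neighbours) / k
    return total / len(embs)


# Toy embeddings and labels, invented for this example.
embs = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [-1.0, 0.0]]
labels = [0, 0, 1, 1]

print(top_k_similar(embs, query_idx=0, k=2))
print(label_agreement(embs, labels, k=1))
```

A higher agreement score suggests that nodes of the same class sit close together in the embedding space, which is what a good unsupervised representation should achieve.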

In [None]:
# ------------------ Evaluator -------------------
data_pyg = dataset.data.to(method.device)
embs = method.get_embs(data_pyg, data_pyg.adj_t).detach()

lg = LogisticRegression(lr=config.classifier.base_lr, weight_decay=config.classifier.weight_decay,
                        max_iter=config.classifier.max_epoch, n_run=1, device=device)
lg(embs=embs, dataset=data_pyg)