# InfoGraph Tutorial
#### This tutorial illustrates the use of the InfoGraph algorithm ([InfoGraph: Unsupervised and Semi-supervised Graph-Level Representation Learning via Mutual Information Maximization](https://openreview.net/pdf?id=r1lfF2NYvH)), an unsupervised and semi-supervised graph-level representation learning method that maximizes the mutual information between the graph-level representation and the representations of substructures at different scales.
#### The tutorial is organized as follows:
#### 1. [Preprocessing Data and Loading Configuration](InfoGraph.ipynb#L6)
#### 2. [Training the model](InfoGraph.ipynb#L7)
#### 3. [Evaluating the model](InfoGraph.ipynb#L8)

## 1. Preprocessing Data and Loading Configuration 
#### First, we load the configuration from a YAML file and then load the dataset.
#### For ease of use, we ran a hyperparameter search on three datasets and provide configurations under which this implementation of InfoGraph performs close to the results reported in the paper.

In [None]:
from torch_geometric.loader import DataLoader
from src.methods.infograph import InfoGraph, Encoder
from src.trainer import SimpleTrainer
from src.evaluation import LogisticRegression
from torch_geometric.datasets import TUDataset, Entities
import os
from torch_geometric.nn import GINConv
from src.config import load_yaml
import torch
import numpy as np


config = load_yaml('./configuration/infograph_mutag.yml')
# config = load_yaml('./configuration/infograph_imdb_b.yml')
# config = load_yaml('./configuration/infograph_imdb_m.yml')
torch.manual_seed(config.torch_seed)
np.random.seed(config.torch_seed)
device = torch.device("cuda:{}".format(config.gpu_idx) if torch.cuda.is_available() and config.use_cuda else "cpu")

# -------------------- Data --------------------
current_folder = os.path.abspath('')
path = os.path.join(current_folder, config.dataset.root, config.dataset.name)
if config.dataset.name in ['IMDB-B', 'IMDB-M', 'mutag', 'COLLAB', 'PROTEINS']:
    # dataset = TUDataset(path, name=config.dataset.name).shuffle()
    dataset = TUDataset(path, name=config.dataset.name)
else:
    raise NotImplementedError("Unsupported dataset: {}".format(config.dataset.name))
# dataset.x = torch.rand(dataset.y.shape[0], 100)
data_loader = DataLoader(dataset, batch_size=config.dataset.batch_size)

in_channels = max(dataset.num_features, 1)
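
#### The cell above reads a fixed set of fields from the loaded configuration. As a rough guide to what the yml file needs to contain, the sketch below rebuilds the same attribute-style layout from a plain dictionary. The helper `to_namespace` and every value shown are hypothetical placeholders, not the tuned settings shipped in ./configuration, and `load_yaml` itself may be implemented differently.

```python
from types import SimpleNamespace


def to_namespace(obj):
    """Recursively turn nested dicts into objects with attribute access."""
    if isinstance(obj, dict):
        return SimpleNamespace(**{k: to_namespace(v) for k, v in obj.items()})
    return obj


# Placeholder values only; the keys mirror the fields this notebook reads.
example_config = to_namespace({
    "torch_seed": 0,
    "use_cuda": True,
    "gpu_idx": 0,
    "dataset": {
        "root": "datasets",   # joined with the notebook folder to build the path
        "name": "mutag",      # must pass the dataset-name check in the cell above
        "batch_size": 128,
    },
    "model": {"hidden_channels": 32, "n_layers": 4},
    "optim": {"max_epoch": 20},
    "classifier": {"base_lr": 0.01, "weight_decay": 0.0, "max_epoch": 100},
})

print(example_config.dataset.name, example_config.model.hidden_channels)
```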

## 2. Training the Model
#### In the second step, we first initialize InfoGraph. The encoder backbone is a Graph Isomorphism Network (GIN), and InfoGraph adopts the Deep InfoMax objective as one of its major loss terms.
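#### To make the loss term concrete: a Deep InfoMax-style objective scores every (node, graph) pair, pushing up the scores of nodes paired with their own graph and pushing down all other pairs. The pure-Python sketch below uses a Jensen-Shannon estimator with a plain dot product as the critic. The function names are illustrative, additive constants are dropped, and the implementation in ./src/methods/infograph.py may differ (for example, by passing both embeddings through small networks before scoring).

```python
import math


def softplus(z):
    # Numerically stable log(1 + exp(z)).
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def jsd_local_global_loss(node_embs, graph_embs, node_to_graph):
    """Negative Jensen-Shannon MI estimate between node and graph embeddings.

    node_to_graph[i] is the index of the graph that node i belongs to.
    Needs at least two graphs so that negative pairs exist.
    """
    pos_terms, neg_terms = [], []
    for node, home in zip(node_embs, node_to_graph):
        for g, graph in enumerate(graph_embs):
            score = sum(a * b for a, b in zip(node, graph))  # dot-product critic
            if g == home:
                pos_terms.append(-softplus(-score))  # node with its own graph
            else:
                neg_terms.append(softplus(score))    # node with another graph
    e_pos = sum(pos_terms) / len(pos_terms)
    e_neg = sum(neg_terms) / len(neg_terms)
    # Minimizing this value maximizes the estimate e_pos - e_neg.
    return e_neg - e_pos


nodes = [[2.0, 0.0], [2.0, 0.0], [0.0, 2.0]]
graphs = [[2.0, 0.0], [0.0, 2.0]]
aligned = jsd_local_global_loss(nodes, graphs, [0, 0, 1])
swapped = jsd_local_global_loss(nodes, graphs, [1, 1, 0])
print(aligned < swapped)
```

#### In this toy example the loss is lower when each node is assigned to the graph whose embedding it resembles, which is the behaviour the objective rewards.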
#### You may replace the encoder with your own. For reference, see the provided encoder in ./src/methods/infograph.py#L96. Keep in mind that a custom encoder must provide three things: a class initializer, a forward function, and a get_embs() function.
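#### As a starting point, the three pieces listed above could be laid out as in the skeleton below. This is a sketch under assumptions, not the interface defined in ./src/methods/infograph.py: the class name `SumPoolEncoder`, the argument lists, the `(graph_emb, node_emb)` return order of forward, the `(labels, embeddings)` return order of get_embs, and the per-layer concatenation are all guesses to check against the reference encoder before use. It depends only on PyTorch and uses a simple sum aggregation in place of GINConv.

```python
import torch
from torch import nn


class SumPoolEncoder(nn.Module):
    """Hypothetical custom encoder; every signature here is an assumption."""

    def __init__(self, in_channels, hidden_channels, num_layers):
        super().__init__()
        dims = [in_channels] + [hidden_channels] * num_layers
        self.layers = nn.ModuleList(
            [nn.Linear(d_in, d_out) for d_in, d_out in zip(dims[:-1], dims[1:])]
        )

    def forward(self, x, edge_index, batch):
        src, dst = edge_index
        num_graphs = int(batch.max()) + 1
        node_feats, graph_feats = [], []
        for layer in self.layers:
            # Sum neighbour features into each target node, then transform.
            neighbours = torch.zeros_like(x).index_add_(0, dst, x[src])
            x = torch.relu(layer(x + neighbours))
            node_feats.append(x)
            # Sum-pool node features into one vector per graph.
            pooled = x.new_zeros(num_graphs, x.size(1)).index_add_(0, batch, x)
            graph_feats.append(pooled)
        # Concatenate the outputs of all layers, one block per layer.
        return torch.cat(graph_feats, dim=1), torch.cat(node_feats, dim=1)

    @torch.no_grad()
    def get_embs(self, loader):
        # Collect labels and graph-level embeddings over all batches.
        labels, embs = [], []
        for data in loader:
            graph_emb, _ = self.forward(data.x, data.edge_index, data.batch)
            labels.append(data.y)
            embs.append(graph_emb)
        return torch.cat(labels), torch.cat(embs)
```

#### With this layout, the embedding width is hidden_channels * num_layers, so the same two values would need to be passed to InfoGraph as in the cell below.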

In [None]:
# ------------------- Method -----------------
encoder = Encoder(in_channels=in_channels, hidden_channels=config.model.hidden_channels,
                  num_layers=config.model.n_layers, GNN=GINConv)
method = InfoGraph(encoder=encoder, hidden_channels=config.model.hidden_channels, num_layers=config.model.n_layers,
                   prior=False)

#### We train the model by calling trainer.train().

In [None]:
trainer = SimpleTrainer(method=method, data_loader=data_loader, device=device, n_epochs=config.optim.max_epoch)
trainer.train()

## 3. Evaluating the performance of InfoGraph
#### In the last step, we evaluate InfoGraph. We first obtain the graph embeddings by calling method.get_embs() and then train a logistic regression classifier on them. Other classifiers, including SVM and random forest, are available in ./src/evaluation/classifier.py. Evaluation methods for the unsupervised setting, such as k-means clustering and similarity search, are available in ./src/evaluation/cluster.py and ./src/evaluation/sim_search.py.

In [None]:
# ------------------ Evaluator -------------------
method.eval()
data_pyg = dataset.data.to(method.device)
y, embs = method.get_embs(data_loader)

data_pyg.x = embs
lg = LogisticRegression(lr=config.classifier.base_lr, weight_decay=config.classifier.weight_decay,
                        max_iter=config.classifier.max_epoch, n_run=1, device=device)
lg(embs=embs, dataset=data_pyg)
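
#### For intuition about the similarity-search style of evaluation mentioned above, the sketch below checks how often the nearest neighbour of each embedding (by cosine similarity) shares its label. It is a minimal pure-Python illustration on toy vectors; the function names are hypothetical and it is not the implementation in ./src/evaluation/sim_search.py.

```python
import math


def cosine(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm if norm else 0.0


def nearest_neighbor_accuracy(embs, labels):
    """Fraction of items whose most similar other item has the same label."""
    hits = 0
    for i, query in enumerate(embs):
        others = [j for j in range(len(embs)) if j != i]
        best = max(others, key=lambda j: cosine(query, embs[j]))
        hits += int(labels[best] == labels[i])
    return hits / len(embs)


# Toy embeddings: two tight groups pointing in different directions.
toy_embs = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]]
toy_labels = [0, 0, 1, 1]
print(nearest_neighbor_accuracy(toy_embs, toy_labels))
```

#### To try this on the embeddings computed above, the tensors returned by method.get_embs() would first need to be converted to Python lists, for example with .tolist().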