# Tutorial5: Aggregation


In this tutorial, we override the aggregation method of the GIN convolution module of PyTorch Geometric (PyG), implementing the following methods:

- Principal Neighborhood Aggregation (PNA)
- Learning Aggregation Functions (LAF)

In [None]:
import os
import torch
# Export the torch version so the shell commands below can expand ${TORCH}
# when building the PyG wheel-index URL.
os.environ['TORCH'] = torch.__version__
print(torch.__version__)

# Install the compiled PyG extensions matching the installed torch version,
# then PyTorch Geometric itself from its git repository.
!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git

In [None]:
# Seed PyTorch's global RNG so that the random operations run later in this notebook are reproducible.
torch.manual_seed(42)

### Message Passing Class

In [None]:
from torch_geometric.nn import MessagePassing

In [None]:
# List MessagePassing's attributes to find the hooks a subclass can override (e.g. `aggregate`).
dir(MessagePassing)

We are interested in the <span style='color:Blue'>aggregate</span> method or, if you are using a sparse adjacency matrix, the <span style='color:Blue'>message_and_aggregate</span> method. Convolutional classes in PyG extend MessagePassing; we construct our custom convolutional class by extending GINConv.

Scatter operation in <span style='color:Blue'>aggregate</span>:

<img src="https://raw.githubusercontent.com/rusty1s/pytorch_scatter/master/docs/source/_figures/add.svg?sanitize=true" width="500">

In [None]:
from torch.nn import Parameter, Module, Sigmoid
import torch
import torch_scatter
import torch.nn.functional as F

class AbstractLAFLayer(Module):
    """Parameter container shared by LAF (Learning Aggregation Functions) layers.

    The learnable tensor ``self.weights`` has shape ``(12, units)``. As used by
    ``LAFLayer.forward``:
    rows 0-3 are exponents applied to the four per-element base terms,
    rows 4-7 are exponents applied to the four scattered sums, and
    rows 8-11 are the multiplicative coefficients of the four resulting terms.

    Keyword Args:
        units (int): number of LAF units; required unless ``weights`` is given.
        weights (torch.Tensor): initial ``(12, units)`` parameter tensor; when
            given, ``units`` is read from ``weights.shape[1]``.
        device (torch.device): device on which parameters are created; defaults
            to CUDA when available, otherwise CPU.
        kernel_initializer (str): name of the initializer used when ``weights``
            is not given (a key of ``_INITIALIZERS``); default ``'random_normal'``.

    Raises:
        ValueError: if neither ``units`` nor ``weights`` is given, or if
            ``kernel_initializer`` is not supported.
    """

    # Accepted initializer names -> in-place torch.nn.init functions.
    _INITIALIZERS = {
        'random_normal': torch.nn.init.normal_,
        'glorot_normal': torch.nn.init.xavier_normal_,
        'he_normal': torch.nn.init.kaiming_normal_,
        'random_uniform': torch.nn.init.uniform_,
        'glorot_uniform': torch.nn.init.xavier_uniform_,
        'he_uniform': torch.nn.init.kaiming_uniform_,
    }

    def __init__(self, **kwargs):
        super(AbstractLAFLayer, self).__init__()
        # Explicit exceptions instead of `assert`, which is stripped under `python -O`.
        if 'units' not in kwargs and 'weights' not in kwargs:
            raise ValueError("LAF layer requires either `units` or `weights`.")
        if 'device' in kwargs:
            self.device = kwargs['device']
        else:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.ngpus = torch.cuda.device_count()

        self.kernel_initializer = kwargs.get('kernel_initializer', 'random_normal')
        if self.kernel_initializer not in self._INITIALIZERS:
            raise ValueError(
                f"Unsupported kernel_initializer {self.kernel_initializer!r}; "
                f"expected one of {sorted(self._INITIALIZERS)}."
            )

        if 'weights' in kwargs:
            self.weights = Parameter(kwargs['weights'].to(self.device), requires_grad=True)
            self.units = self.weights.shape[1]
        else:
            self.units = kwargs['units']
            params = torch.empty(12, self.units, device=self.device)
            self._INITIALIZERS[self.kernel_initializer](params)
            self.weights = Parameter(params, requires_grad=True)

        # Fixed constants used by LAFLayer.forward. They are kept as frozen
        # Parameters (not buffers) so that `parameters()` keeps its existing entries.
        # e alternates +1/-1 so that (1 - e)/2 + x*e yields the terms x, 1-x, x, 1-x.
        e = torch.tensor([1, -1, 1, -1], dtype=torch.float32, device=self.device)
        self.e = Parameter(e, requires_grad=False)
        # Masks over the 4-term axis: terms 0-1 form the numerator, terms 2-3 the denominator.
        num_idx = torch.tensor([1, 1, 0, 0], dtype=torch.float32, device=self.device).view(1, 1, -1, 1)
        self.num_idx = Parameter(num_idx, requires_grad=False)
        den_idx = torch.tensor([0, 0, 1, 1], dtype=torch.float32, device=self.device).view(1, 1, -1, 1)
        self.den_idx = Parameter(den_idx, requires_grad=False)


class LAFLayer(AbstractLAFLayer):
    """Learnable aggregation of the rows of ``data`` grouped by ``index``.

    Args:
        eps: clamping margin; inputs are clamped to ``[eps, 1 - eps]``, sums to
            ``>= eps``, and near-zero denominators are replaced by ``+-eps``.
        **kwargs: forwarded to ``AbstractLAFLayer``.
    """

    def __init__(self, eps=1e-7, **kwargs):
        super(LAFLayer, self).__init__(**kwargs)
        self.eps = eps

    @staticmethod
    def _scatter_sum(src, index, dim, dim_size):
        """Sum slices of ``src`` along ``dim`` into the output slots given by the 1-D ``index``."""
        if dim < 0:
            dim += src.dim()
        if dim_size is None:
            # Same default as torch_scatter: one slot per index value up to the max.
            dim_size = int(index.max()) + 1 if index.numel() > 0 else 0
        shape = list(src.shape)
        shape[dim] = dim_size
        return src.new_zeros(shape).index_add(dim, index, src)

    def forward(self, data, index, dim=0, dim_size=None, **kwargs):
        """Aggregate ``data`` over the groups defined by ``index``.

        Args:
            data: tensor whose values are clamped into ``(0, 1)``; the reshapes
                below assume 2-D input ``(num_items, features)``.
            index: group id of each item along ``dim`` (flattened to 1-D).
            dim: axis of ``data`` to aggregate over (default 0).
            dim_size: number of output groups. When ``None`` it defaults to
                ``index.max() + 1``; pass the number of target nodes so that
                trailing groups without members still get a row.

        Returns:
            Tensor of shape ``(dim_size, features, units)`` for ``dim=0``.
        """
        eps = self.eps
        x = torch.clamp(data, eps, 1.0 - eps).unsqueeze(-1)
        e = self.e.view(1, 1, -1)

        # Base terms x, 1-x, x, 1-x, each raised to a learnable non-negative power.
        exps = ((1.0 - e) / 2.0 + x * e).unsqueeze(-1)
        exps = torch.pow(exps, torch.relu(self.weights[0:4]))

        scatter = self._scatter_sum(exps, index.view(-1), dim, dim_size)
        scatter = torch.clamp(scatter, eps)  # keep pow() below well-defined for empty groups

        sqrt = torch.pow(scatter, torch.relu(self.weights[4:8]))
        alpha_beta = self.weights[8:12].view(1, 1, 4, -1)
        terms = sqrt * alpha_beta

        num = torch.sum(terms * self.num_idx, dim=2)
        den = torch.sum(terms * self.den_idx, dim=2)

        # Replace |den| < eps by +eps (den > 0) or -eps (den <= 0) to avoid division by ~0.
        multiplier = 2.0 * torch.clamp(torch.sign(den), min=0.0) - 1.0
        den = torch.where((den < eps) & (den > -eps), multiplier * eps, den)

        return num / den

In [None]:
from torch_geometric.nn import GINConv
from torch.nn import Linear

### LAF Aggregation Module

In [None]:
class GINLAFConv(GINConv):
    """GIN convolution whose neighborhood aggregation is a learnable LAF layer.

    Args:
        nn: network applied after aggregation (passed through to ``GINConv``).
        units: number of LAF units computed per message feature.
        node_dim: size of the incoming message features. The LAF output of shape
            ``(num_nodes, node_dim, units)`` is flattened and projected back to
            ``node_dim`` by ``self.mlp``.
        **kwargs: forwarded to ``GINConv``.
    """

    def __init__(self, nn, units=1, node_dim=32, **kwargs):
        super(GINLAFConv, self).__init__(nn, **kwargs)
        self.laf = LAFLayer(units=units, kernel_initializer='random_uniform')
        self.mlp = torch.nn.Linear(node_dim * units, node_dim)
        self.dim = node_dim
        self.units = units

    def aggregate(self, inputs, index, dim_size=None):
        """Aggregate messages with LAF, producing one row per target node.

        ``dim_size`` is supplied by ``MessagePassing.propagate`` (the number of
        target nodes). Forwarding it to the LAF scatter keeps nodes without
        incoming edges from being dropped, which previously caused an
        off-by-one row-count mismatch with ``x``.
        """
        x = torch.sigmoid(inputs)  # LAF expects values in (0, 1)
        x = self.laf(x, index, dim_size=dim_size)
        # Fails loudly if the message width differs from `node_dim`.
        x = x.reshape(x.size(0), self.dim * self.units)
        return self.mlp(x)

    def forward(self, x, edge_index, batch=None):
        """Apply ``nn((1 + eps) * x + LAF-aggregate(neighbours))``.

        ``batch`` is accepted for backward compatibility but is no longer needed:
        the number of nodes is taken from ``x`` itself.
        """
        num_nodes = x.size(0)
        out = self.propagate(edge_index, x=(x, x), size=(num_nodes, num_nodes))
        out = out + (1 + self.eps) * x
        return self.nn(out)

In [None]:
class LAFNet(torch.nn.Module):
    """Per-node classifier: stacked GINLAFConv + BatchNorm layers, then two linear layers.

    No graph-level pooling is applied, so one prediction is produced per node.

    Args:
        num_features: input feature size. Defaults to ``dataset.num_features``,
            read from the notebook-global ``dataset``.
        num_classes: number of output classes. Defaults to ``dataset.num_classes``.
        dim: hidden size of every layer (default 9).
        units: number of LAF units per GINLAFConv (default 5).
        num_layers: number of GINLAFConv/BatchNorm pairs (default 5). They are
            registered as ``conv1..convN`` / ``bn1..bnN``.
    """

    def __init__(self, num_features=None, num_classes=None, dim=9, units=5, num_layers=5):
        super(LAFNet, self).__init__()
        if num_features is None:
            num_features = dataset.num_features  # notebook-global fallback (original behavior)
        self.num_layers = num_layers

        # Same construction order and attribute names as the former hand-written
        # layers, so parameter initialisation and state_dict keys are unchanged.
        in_dim = num_features
        for layer in range(1, num_layers + 1):
            mlp = Sequential(Linear(in_dim, dim), ReLU(), Linear(dim, dim))
            setattr(self, f'conv{layer}', GINLAFConv(mlp, units=units, node_dim=in_dim))
            setattr(self, f'bn{layer}', torch.nn.BatchNorm1d(dim))
            in_dim = dim

        self.fc1 = Linear(dim, dim)
        if num_classes is None:
            num_classes = dataset.num_classes  # notebook-global fallback (original behavior)
        self.fc2 = Linear(dim, num_classes)

    def forward(self, x, edge_index, batch):
        """Return per-node log-probabilities of shape ``(num_nodes, num_classes)``."""
        for layer in range(1, self.num_layers + 1):
            x = F.relu(getattr(self, f'conv{layer}')(x, edge_index, batch))
            x = getattr(self, f'bn{layer}')(x)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=-1)

### Test the new classes

In [None]:
from torch_geometric.nn import MessagePassing, SAGEConv, GINConv, global_add_pool
import torch_scatter
import torch.nn.functional as F
from torch.nn import Sequential, Linear, ReLU
from torch_geometric.datasets import TUDataset
from torch_geometric.data import DataLoader
import os.path as osp

In [None]:
# Load MUTAG, shuffle it, and hold out the first 10% of the shuffled items for testing.
path = osp.join('./', 'data', 'TU')
dataset = TUDataset(path, name='MUTAG').shuffle()
split_at = len(dataset) // 10
test_dataset, train_dataset = dataset[:split_at], dataset[split_at:]
test_loader = DataLoader(test_dataset, batch_size=128)
train_loader = DataLoader(train_dataset, batch_size=128)

# Here are a few tests of mine...

## Load balanced data

In [None]:
# imports and train/test split (to be put in part 2. of the notebook)
# Reload edited project modules automatically before each cell runs.
%load_ext autoreload
%autoreload 2

import warnings
warnings.filterwarnings('ignore')

import torch
# If PyG is missing, compute the torch/CUDA tags used in the wheel URLs below.
try:
    import torch_geometric
except ModuleNotFoundError:
    TORCH = torch.__version__.split("+")[0]
    # NOTE(review): on CPU-only torch builds `torch.version.cuda` is presumably None,
    # so `.replace` would fail here -- confirm before running without CUDA.
    CUDA = "cu" + torch.version.cuda.replace(".","")
# NOTE(review): these installs are at top level, so they run even when the import
# above succeeded -- in which case TORCH and CUDA were never set by this cell.
!pip install torch-scatter     -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
!pip install torch-sparse      -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
#!pip install torch-geometric
#import torch_geometric

import sys
import time

import numpy as np
import pandas as pd
import torch
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split

# Make the local `classif_basic` package importable from this notebook.
sys.path.append("../")

from classif_basic.data_preparation import handle_cat_features
from classif_basic.graph.data_to_graph import table_to_graph, add_new_edge
from classif_basic.graph.train import train_GNN_ancestor
from classif_basic.graph.utils import get_unified_col, get_balanced_df, normalize_df

SEED = 7
VALID_SIZE = 0.15
preprocessing_cat_features = "label_encoding"

# OpenML dataset 1590 with a binary target: 1 for the '>50K' class, 0 otherwise.
data = fetch_openml(data_id=1590, as_frame=True)
X = data.data
Y = (data.target == '>50K') * 1

X = handle_cat_features(X=X, preprocessing_cat_features=preprocessing_cat_features)

# Merge each group of columns carrying redundant information into one feature.
for columns_to_merge, merged_name in (
    (["education", "education-num"], "education"),
    (["relationship", "marital-status"], "relationship"),
    (["occupation", "workclass"], "job"),
    (["capital-gain", "capital-loss"], "capital"),
):
    X = get_unified_col(X=X, list_cols_to_join=columns_to_merge, new_col_name=merged_name)

# Class-balanced variant of the data. It is built for experimentation only:
# the full (imbalanced) X / Y are the ones used from here on.
balanced_df = get_balanced_df(X=X, Y=Y)
X_balanced = balanced_df.drop("target", axis=1)
Y_balanced = balanced_df["target"]

# Min-max normalisation of the features for the neural network.
X = normalize_df(df=X, normalization='min_max')

# Hold out VALID_SIZE of the rows; stratify=Y keeps the target proportions
# identical in the model (train/valid) and test splits.
X_model, X_test, Y_model, Y_test = train_test_split(
    X, Y, test_size=VALID_SIZE, random_state=SEED, stratify=Y
)

In [None]:
# get batch size equally splitting individuals 
# TODO check of user inputs (batch_size, nb_batches) before LAF computation

def smallest_prime_factor(x):
    """Return the smallest divisor of ``x`` that is >= 2, found by trial division.

    For ``x >= 2`` this is the smallest prime factor (``x`` itself when ``x`` is
    prime). Values below 4 that have no divisor in range, including 0 and 1, are
    returned unchanged. Integral floats are accepted; the divisor is returned
    as an int, while an undivided ``x`` keeps its own type.
    """
    candidate = 2
    while candidate * candidate <= x:
        if x % candidate == 0:
            break
        candidate += 1
    else:
        # No divisor up to sqrt(x): x has no smaller factor.
        return x
    return candidate

# Despite the "pgcd" (French for GCD) names, these hold prime factors of the
# split sizes. pgcd_model ends up as the second-smallest prime factor of the
# train/valid size, counted with multiplicity (1 when that size is prime).
# pgcd_test is the smallest prime factor of the test size.
pgcd_model = smallest_prime_factor(X_model.shape[0])
# `//` is exact here (the factor divides the size) and keeps the value an int,
# whereas `/` would turn it into a float.
pgcd_model = smallest_prime_factor(X_model.shape[0] // pgcd_model)

pgcd_test = smallest_prime_factor(X_test.shape[0])

In [None]:
# Value passed below as `nb_batches` / `batch_size` for the train/valid graph.
# Integer division plus int() keep it an integer count instead of a float
# such as 1903.0; the division is exact since pgcd_model divides the size.
nb_batches_model = int(X_model.shape[0] // pgcd_model)
print(nb_batches_model)

In [None]:
# Value passed below as `nb_batches` / `batch_size` for the test graph.
# Integer division plus int() keep it an integer count instead of a float;
# the division is exact since pgcd_test divides the test size.
nb_batches_test = int(X_test.shape[0] // pgcd_test)
print(nb_batches_test)

In [None]:
from classif_basic.model import pickle_load_model
from classif_basic.graph.loader import get_loader

# Turn each split into a graph. Previously pickled graphs can instead be loaded
# with pickle_load_model (e.g. from /work/data/graph_data/unbalanced/...).
# TODO explicit: pass sizes to LAF module
# TODO: `nb_batches` is later reused as the loader's `batch_size` -- consider renaming.
EDGE_COLUMNS = ('education', 'relationship')  # columns passed to table_to_graph as edge sources

data_with_batch = table_to_graph(
    X=X_model, Y=Y_model, list_edges_names=list(EDGE_COLUMNS), nb_batches=nb_batches_model
)
data_with_batch_test = table_to_graph(
    X=X_test, Y=Y_test, list_edges_names=list(EDGE_COLUMNS), nb_batches=nb_batches_test
)

In [None]:
# Both splits are batched with the same grouping strategy; the per-split value
# computed earlier is passed as `batch_size`.
loader_method = 'index_groups'

print("Train&Valid Set")
train_loader = get_loader(
    data_total=data_with_batch, loader_method=loader_method, batch_size=nb_batches_model
)

print("Test Set")
test_loader = get_loader(
    data_total=data_with_batch_test, loader_method=loader_method, batch_size=nb_batches_test
)

## Train LAF

In [None]:
import time
import statistics

from sklearn.utils import class_weight

# The LAFNet class reads its feature/class counts from this notebook-global name.
dataset = data_with_batch  # TODO explicit: pass it as class LAFNet input

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = "LAF"
if net == "LAF":
    model = LAFNet().to(device)
elif net == "PNA":
    model = PNANet().to(device)  # NOTE(review): PNANet must be defined before selecting it
elif net == "GIN":
    # Previously the instance was discarded, leaving `model` unbound.
    model = GINNet().to(device)  # NOTE(review): GINNet must be defined before selecting it
else:
    raise ValueError(f"Unknown net: {net!r}")

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

NUM_EPOCHS = 15
LR_DECAY_EPOCH = 51  # the learning rate is halved at this epoch (not reached when NUM_EPOCHS < 51)


def train(epoch):
    """Run one training epoch over `train_loader`; return the mean per-batch loss."""
    model.train()

    if epoch == LR_DECAY_EPOCH:
        for param_group in optimizer.param_groups:
            param_group['lr'] = 0.5 * param_group['lr']

    batch_losses = []
    for data in train_loader:
        data.x = data.x.to(torch.float32)  # match the float32 model weights
        data = data.to(device)
        optimizer.zero_grad()
        output = model(data.x, data.edge_index, data.batch)
        # Per-batch 'balanced' class weights to counter class imbalance.
        # NOTE(review): a batch containing a single class yields a single weight,
        # which CrossEntropyLoss rejects when the model outputs several classes.
        labels = data.y.cpu().numpy()
        class_weights = class_weight.compute_class_weight(
            class_weight='balanced', classes=np.unique(labels), y=labels
        )
        class_weights = torch.tensor(class_weights, dtype=torch.float).to(device)
        cr_loss = torch.nn.CrossEntropyLoss(weight=class_weights, reduction='mean')
        loss = cr_loss(output, data.y)
        loss.backward()
        optimizer.step()
        batch_losses.append(loss.item())

    return statistics.mean(batch_losses)


@torch.no_grad()  # evaluation only: do not build autograd graphs
def test(loader):
    """Return the accuracy of `model` over every node label yielded by `loader`.

    Raises:
        ValueError: if `loader` yields no labels.
    """
    model.eval()

    correct = 0
    total = 0
    for data in loader:
        data.x = data.x.to(torch.float32)  # match the float32 model weights
        data = data.to(device)
        output = model(data.x, data.edge_index, data.batch)
        pred = output.argmax(dim=1)
        correct += int(pred.eq(data.y).sum())
        total += len(data.y)

    if total == 0:
        raise ValueError("loader yielded no samples")
    # Exact overall accuracy (previously: an average of per-batch values rounded to 2 digits).
    return correct / total


time0 = time.time()

for epoch in range(1, NUM_EPOCHS + 1):
    print(f"\n '''Epoch {epoch}''' ")
    train_loss = train(epoch)
    train_acc = test(train_loader)
    test_acc = test(test_loader)
    print('Epoch: {:03d}, Train Loss: {:.7f}, '
          'Train Acc: {:.7f}, Test Acc: {:.7f}'.format(epoch, train_loss,
                                                       train_acc, test_acc))
    print()

time1 = time.time()

print(f"Training LAF took {round((time1-time0)/60)} mn")