<a href="https://colab.research.google.com/github/adithyamauryakr/pytorchtutorials/blob/main/graph_nn_stuff/aggregation_funcs_gnn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install torch_geometric torch_scatter

Collecting torch_geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m2.6 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting torch_scatter
  Downloading torch_scatter-2.1.2.tar.gz (108 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m108.0/108.0 kB[0m [31m5.2 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Downloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m11.4 MB/s[0m eta [36m0:00:00[0m
[?25hBuilding wheels for collected packages: torch_scatter
  Building wheel for torch_scatter (setup.py) ... [?25l[?25hdone
  Created wheel for torch_scatter: filename=torch_scatter-2.1.2-cp311-cp311-linux_x86_64.whl size=547368 sha256=59d96437f18d7d382d8d086fe352d19238ec6037046c9a5d7c809f2e87da7981
  Stored in directory: /root/.cache/pip/w

In [2]:
import torch_geometric
from torch_geometric.datasets import Planetoid

In [3]:
import torch
from torch_geometric.nn import MessagePassing

In [4]:
dir(MessagePassing)

['SUPPORTS_FUSED_EDGE_INDEX',
 'T_destination',
 '__annotations__',
 '__call__',
 '__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattr__',
 '__getattribute__',
 '__getstate__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__setstate__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_apply',
 '_call_impl',
 '_check_input',
 '_collect',
 '_compiled_call_impl',
 '_get_backward_hooks',
 '_get_backward_pre_hooks',
 '_get_edge_updater_signature',
 '_get_name',
 '_get_propagate_signature',
 '_index_select',
 '_index_select_safe',
 '_lift',
 '_load_from_state_dict',
 '_maybe_warn_non_full_backward_hook',
 '_named_members',
 '_register_load_state_dict_pre_hook',
 '_register_state_dict_hook',
 '_replicate_for_data_parallel',
 '_save_to_state_dict',
 '_set_jittable_templates',
 '_

In [5]:
from torch_geometric.nn import GINConv
from torch.nn import Linear
import torch.nn.functional as F
import torch_scatter


class AbstractLAFLayer(torch.nn.Module):
    def __init__(self, **kwargs):
        super(AbstractLAFLayer, self).__init__()
        assert 'units' in kwargs or 'weights' in kwargs
        if 'device' in kwargs.keys():
            self.device = kwargs['device']
        else:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.ngpus = torch.cuda.device_count()

        if 'kernel_initializer' in kwargs.keys():
            assert kwargs['kernel_initializer'] in [
                'random_normal',
                'glorot_normal',
                'he_normal',
                'random_uniform',
                'glorot_uniform',
                'he_uniform']
            self.kernel_initializer = kwargs['kernel_initializer']
        else:
            self.kernel_initializer = 'random_normal'

        if 'weights' in kwargs.keys():
            self.weights = Parameter(kwargs['weights'].to(self.device), \
                                     requires_grad=True)
            self.units = self.weights.shape[1]
        else:
            self.units = kwargs['units']
            params = torch.empty(12, self.units, device=self.device)
            if self.kernel_initializer == 'random_normal':
                torch.nn.init.normal_(params)
            elif self.kernel_initializer == 'glorot_normal':
                torch.nn.init.xavier_normal_(params)
            elif self.kernel_initializer == 'he_normal':
                torch.nn.init.kaiming_normal_(params)
            elif self.kernel_initializer == 'random_uniform':
                torch.nn.init.uniform_(params)
            elif self.kernel_initializer == 'glorot_uniform':
                torch.nn.init.xavier_uniform_(params)
            elif self.kernel_initializer == 'he_uniform':
                torch.nn.init.kaiming_uniform_(params)
            self.weights = Parameter(params, \
                                     requires_grad=True)
        e = torch.tensor([1,-1,1,-1], dtype=torch.float32, device=self.device)
        self.e = Parameter(e, requires_grad=False)
        num_idx = torch.tensor([1,1,0,0], dtype=torch.float32, device=self.device).\
                                view(1,1,-1,1)
        self.num_idx = Parameter(num_idx, requires_grad=False)
        den_idx = torch.tensor([0,0,1,1], dtype=torch.float32, device=self.device).\
                                view(1,1,-1,1)
        self.den_idx = Parameter(den_idx, requires_grad=False)


class LAFLayer(AbstractLAFLayer):
    def __init__(self, eps=1e-7, **kwargs):
        super(LAFLayer, self).__init__(**kwargs)
        self.eps = eps

    def forward(self, data, index, dim=0, **kwargs):
        eps = self.eps
        sup = 1.0 - eps
        e = self.e

        x = torch.clamp(data, eps, sup)
        x = torch.unsqueeze(x, -1)
        e = e.view(1,1,-1)

        exps = (1. - e)/2. + x*e
        exps = torch.unsqueeze(exps, -1)
        exps = torch.pow(exps, torch.relu(self.weights[0:4]))

        scatter = torch_scatter.scatter_add(exps, index.view(-1), dim=dim)
        scatter = torch.clamp(scatter, eps)

        sqrt = torch.pow(scatter, torch.relu(self.weights[4:8]))
        alpha_beta = self.weights[8:12].view(1,1,4,-1)
        terms = sqrt * alpha_beta

        num = torch.sum(terms * self.num_idx, dim=2)
        den = torch.sum(terms * self.den_idx, dim=2)

        multiplier = 2.0*torch.clamp(torch.sign(den), min=0.0) - 1.0

        den = torch.where((den < eps) & (den > -eps), multiplier*eps, den)

        res = num / den
        return res

### LAF Aggregation

In [6]:
class GINLAFConv(GINConv):
  def __init_(self, nn, units=1, node_dim=32, **kwargs):

    super(GINLAFConv, self).__init__(nn, **kwargs)
    self.laf = LAFLayer(units=units, kernel_initializer='random_uniform')
    self.mlp = torch.nn.Linear(node_dim*units, node_dim)
    self.dim = node_dim
    self.units = units

  def aggregate(self, input, index):
    x = torch.sigmoid(inputs)
    x = self.laf(x, index)
    x = x.view((-1, self.dim*self.units))
    x = self.mlp(x)

    return(x)

### PNA Aggregation

In [7]:
class GINPNAConv(GINConv):
  def __init__(self, nn, node_dim = 32, **kwargs):
    super(GINPNAConv, self).__init__(nn, **kwargs)
    self.mlp = torch.nn.Linear(node_dim*12, node_dim)
    self.delta = 2.5749

  def aggregate(self, inputs, index):
    sums = torch_scatter.scatter_add(inputs, index, dim=0)
    maxs = torch_scatter.scatter_max(inputs, index, dim=0)
    means = torch_scatter.scatter_mean(inputs, index, dim=0)
    var = torch.relu(torch_scatter.scatter_mean(inputs**2, index, dim-0)-means**2)

    aggrs = [sums, maxs, means, var]
    c_idx = index.bincount().float().view(-1, 1)
    l_idx = torch.log(c_idx + 1.)

    amplication_scaler = [c_idx/self.delta*a for a in aggrs]
    attenuation_scaler = [self.data/c_idx * a for a in aggrs]
    combinations = torch.cat(aggrs+amplication_scaler+attenuation_scaler, dim=1)
    x = self.mlp(combinations)
    return x

### Test for new classes

In [8]:
from torch_geometric.nn import MessagePassing, SAGEConv, GINConv, global_add_pool
import torch_scatter
import torch.nn.functional as F
from torch.nn import Sequential, Linear, ReLU
from torch_geometric.datasets import TUDataset
from torch_geometric.data import DataLoader
import os.path as osp

In [9]:
path = osp.join('./', 'data', 'TU')
dataset = TUDataset(path, name='MUTAG').shuffle()
test_dataset = dataset[:len(dataset)//10]
train_dataset = dataset[len(dataset)//10:]
test_loader = DataLoader(test_dataset, batch_size =128)
train_loader = DataLoader(train_dataset, batch_size=128)


Downloading https://www.chrsmrrs.com/graphkerneldatasets/MUTAG.zip
Processing...
Done!


In [22]:
class LAFNet(torch.nn.Module):
  def __init__(self):
    super(LAFNet, self).__init__()

    num_features = dataset.num_features
    dim = 32
    units = 3

    nn1 = Sequential(
        Linear(num_features, dim), ReLU(), Linear(dim,dim)
    )
    self.conv1 = GINLAFConv(nn1, units=units, node_dim=num_features)
    self.bn1 = torch.nn.BatchNorm1d(dim)

    nn2 = Sequential(
        Linear(dim, dim), ReLU(), Linear(dim, dim)
    )
    self.conv2 = GINLAFConv(nn2, units=units, node_dim= dim)
    self.bn2 = torch.nn.BatchNorm1d(dim)

    nn3 = Sequential(
        Linear(dim, dim), ReLU(), Linear(dim, dim)
    )
    self.conv3 = GINLAFConv(nn3, units=units, node_dim=dim)
    self.bn3 = self.BatchNorm1d(dim)

    nn4 = Sequential(
        Linear(dim, dim), ReLU(), Linear(dim, dim)
    )
    self.conv4 = GINLAFConv(nn4, units=units, node_dim=dim)
    self.bn4 = self.BatchNorm1d(dim)

    nn5 = Sequential(
        Linear(dim, dim), ReLU(), Linear(dim, dim)
    )
    self.conv5 = GINLAFConv(nn5, units=units, node_dim=dim)
    self.bn5 = self.BatchNorm1d(dim)

    self.fc1 = Linear(dim, dim)
    self.fc2 = Linear(dim, dataset.num_classes)

  def forward(self, x, edge_index, batch):
    x = F.relu(self.conv1(x, edge_index))
    x = self.bn1(x)
    x = F.relu(self.conv2(x, edge_index))
    x = self.bn2(x)
    x = F.relu(self.conv3(x, edge_index))
    x = self.bn3(x)
    x = F.relu(self.conv4(x, edge_index))
    x = self.bn4(x)
    x = F.relu(self.conv5(x, edge_index))
    x = self.bn5(x)
    x = global_add_pool(x, batch)
    x = F.relu(self.fc1(x))
    x = F.dropout(x, p=0.5, training=self.training)
    x = self.fc2(x)
    return F.log_softmax(x, dim=1)


In [11]:
class PNANet(torch.nn.Module):
  def __init(self):
    super(PNANet, self).__init__()

    num_features = dataset.num_features
    dim = 32

    nn1 = Sequential(
        Linear(num_features, dim), ReLU(), Linear(dim,dim)
    )
    self.conv1 = GINPNAConv(nn1, node_dim=num_features)
    self.bn1 = torch.nn.BatchNorm1d(dim)

    nn2 = Sequential(
        Linear(dim, dim), ReLU(), Linear(dim, dim)
    )
    self.conv2 = GINPNAConv(nn2, node_dim=dim)
    self.bn2 = torch.nn.BatchNorm1d(dim)

    nn3 = Sequential(
        Linear(dim, dim), ReLU(), Linear(dim, dim)
    )
    self.conv3 = GINPNAConv(nn3, node_dim=dim)
    self.bn3 = self.BatchNorm1d(dim)

    nn4 = Sequential(
        Linear(dim, dim), ReLU(), Linear(dim, dim)
    )
    self.conv4 = GINPNAConv(nn4, node_dim=dim)
    self.bn4 = self.BatchNorm1d(dim)

    nn5 = Sequential(
        Linear(dim, dim), ReLU(), Linear(dim, dim)
    )
    self.conv5 = GINPNAConv(nn5, node_dim=dim)
    self.bn5 = self.BatchNorm1d(dim)

    self.fc1 = Linear(dim, dim)
    self.fc2 = Linear(dim, dataset.num_classes)

  def forward(self, x, edge_index, batch):
    x = F.relu(self.conv1(x, edge_index))
    x = self.bn1(x)
    x = F.relu(self.conv2(x, edge_index))
    x = self.bn2(x)
    x = F.relu(self.conv3(x, edge_index))
    x = self.bn3(x)
    x = F.relu(self.conv4(x, edge_index))
    x = self.bn4(x)
    x = F.relu(self.conv5(x, edge_index))
    x = self.bn5(x)
    x = global_add_pool(x, batch)
    x = F.relu(self.fc1(x))
    x = F.dropout(x, p=0.5, training=self.training)
    x = self.fc2(x)
    return F.log_softmax(x, dim=1)


In [12]:
class GINNet(torch.nn.Module):
  def __init__(self):
    super(GINNet, self).__init__()

    num_features = dataset.num_features
    dim = 32

    nn1 = Sequential(
        Linear(num_features, dim), ReLU(), Linear(dim,dim)
    )
    self.conv1 = GINConv(nn1)
    self.bn1 = torch.nn.BatchNorm1d(dim)

    nn2 = Sequential(
        Linear(dim, dim), ReLU(), Linear(dim, dim)
    )
    self.conv2 = GINConv(nn2)
    self.bn2 = torch.nn.BatchNorm1d(dim)

    nn3 = Sequential(
        Linear(dim, dim), ReLU(), Linear(dim, dim)
    )
    self.conv3 = GINConv(nn3)
    self.bn3 = self.BatchNorm1d(dim)

    nn4 = Sequential(
        Linear(dim, dim), ReLU(), Linear(dim, dim)
    )
    self.conv4 = GINConv(nn4)
    self.bn4 = self.BatchNorm1d(dim)

    nn5 = Sequential(
        Linear(dim, dim), ReLU(), Linear(dim, dim)
    )
    self.conv5 = GINConv(nn5)
    self.bn5 = self.BatchNorm1d(dim)

    self.fc1 = Linear(dim, dim)
    self.fc2 = Linear(dim, dataset.num_classes)

  def forward(self, x, edge_index, batch):
    x = F.relu(self.conv1(x, edge_index))
    x = self.bn1(x)
    x = F.relu(self.conv2(x, edge_index))
    x = self.bn2(x)
    x = F.relu(self.conv3(x, edge_index))
    x = self.bn3(x)
    x = F.relu(self.conv4(x, edge_index))
    x = self.bn4(x)
    x = F.relu(self.conv5(x, edge_index))
    x = self.bn5(x)
    x = global_add_pool(x, batch)
    x = F.relu(self.fc1(x))
    x = F.dropout(x, p=0.5, training=self.training)
    x = self.fc2(x)
    return F.log_softmax(x, dim=1)


In [24]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# net = 'PNA'

# if net == 'LAF':
#   model = LAFNet().to(device)
if net == 'PNA':
  model = PNANet().to(device)
else:
  model = GINNet().to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

def  train(epoch):
  model.train()

  if epoch == 51:
    for param_group in optimizer.param_groups:
      param_group['lr'] = 0.5*param_group['lr']

  loss_all = 0
  for data in train_loader:
    data = data.to(device)
    optimizer.zero_grad()
    output = model(data.x, data.edge_index, data.batch)
    loss = F.nll_loss(output, data.y)
    loss.backward()
    loss_all += loss.item()*data.num_graphs
    optimizer.step()

  return loss_all/len(train_dataset)

def test(loader):
  model.eval()
  correct = 0
  for data in loader:
    data = data.to(device)
    output = model(data.x, data.edge_index, data.batch)
    pred = output.max(dim=1)[1]
    correct += pred.eg(data.y).sum().item()

  return correct/len(loader.dataset)

for epoch in range(1, 101):
  train_loss = train(epoch)
  train_acc = test(train_loader)
  test_acc = test(test_loader)
  print("Epoch: {:03d}, Train Loss: {:.7f}, " "Train Acc: {:.7f}, Test Acc: {:.7f}". format(epoch, train_loss, train_acc, test_acc))

AttributeError: 'GINNet' object has no attribute 'BatchNorm1d'