# Layer Building

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Linear, Parameter
from torch.nn.init import xavier_uniform_, zeros_
from torch_geometric.nn import MessagePassing
from torch_geometric.nn.inits import reset
from torch_geometric.utils import add_self_loops, remove_self_loops, degree, softmax
from typing import Type


In [None]:
class EINv3(MessagePassing):
    """
    An edge-featured, attention-based graph neural network layer for graph
    classification / regression tasks (version 3).

    Notes:
        Full multi-head attention is implemented in this version: every head
        keeps its own output, and the heads are either concatenated
        (``concat=True``) or averaged (``concat=False``) at the end of
        ``forward``. The mean of the attention over the heads is used only
        inside the influence mechanism that injects the edge features.

    Args:
        in_channels: Size of each input node feature vector.
        out_channels: Size of each output node feature vector *per head*.
        heads: Number of attention heads.
        negative_slope: Negative slope of the LeakyReLU used in the attention
            score.
        dropout: Dropout probability applied to the normalised attention
            coefficients during training.
        edge_dim: Size of each edge feature vector. When ``None`` the
            influence mechanism is disabled and edge features are ignored.
        train_eps: If ``True``, ``eps`` is a learnable parameter; otherwise it
            is a fixed buffer.
        eps: Initial value of the self-feature weighting ``(1 + eps)``.
        bias: If ``True``, the linear projections use a bias and a learnable
            output bias of size ``out_channels`` is added.
        share_weights: If ``True``, source and target nodes share the same
            linear projection.
        concat: If ``True`` the head outputs are concatenated, otherwise they
            are averaged.
        **kwargs: Extra arguments forwarded to ``MessagePassing``.
    """

    def __init__(
            self,
            in_channels,
            out_channels,
            heads=1,
            negative_slope=0.2,
            dropout=0.0,
            edge_dim=None,
            train_eps=False,
            eps=0.0,
            bias=True,
            share_weights=False,
            concat=True,
            **kwargs,
    ):
        super().__init__(node_dim=0, aggr='add', **kwargs)  # defines the aggregation method: `aggr='add'`

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        self.negative_slope = negative_slope
        self.dropout = dropout
        self.share_weights = share_weights
        self.edge_dim = edge_dim
        self.initial_eps = eps
        self.concat = concat

        # Linear transformation of the source nodes
        self.lin_l = Linear(in_channels, heads * out_channels, bias=bias)

        # Linear transformation of the target nodes
        if share_weights:
            self.lin_r = self.lin_l  # use same matrix
        else:
            self.lin_r = Linear(in_channels, heads * out_channels, bias=bias)

        # For attention calculation
        self.att = Parameter(torch.Tensor(1, heads, out_channels))

        # For the influence mechanism. Only built when the edge feature size
        # is known, so the layer can also be used without edge features.
        if edge_dim is not None:
            self.inf = Linear(edge_dim, out_channels)
        else:
            self.inf = None

        # Tunable parameter for adding self node features...
        if train_eps:
            self.eps = torch.nn.Parameter(torch.Tensor([eps]))
        else:
            self.register_buffer('eps', torch.Tensor([eps]))

        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)

        self._alpha = None  # attention weights of the forward pass in progress

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise every learnable parameter of the layer."""
        self.lin_l.reset_parameters()
        self.lin_r.reset_parameters()
        if self.inf is not None:
            self.inf.reset_parameters()
        self.eps.data.fill_(self.initial_eps)
        xavier_uniform_(self.att)
        if self.bias is not None:  # `bias=False` registers the parameter as None
            zeros_(self.bias)

    def forward(self, x, edge_index, edge_attr=None, return_attention_weights=None):
        """
        Run one round of message passing.

        Shape legend: N - number of nodes, NH - number of heads,
        H_in - input channels, H_out - output channels.

        Args:
            x: Node features, reshaped internally to (N, NH, H_out) after the
                linear projection.
            edge_index: Edge indices handed to ``propagate``.
            edge_attr: Optional edge features whose last dimension must equal
                ``edge_dim``.
            return_attention_weights: If this is a ``bool`` (``True`` *or*
                ``False``), the attention weights are returned as well; any
                other value (e.g. the default ``None``) returns only the node
                output.

        Returns:
            The node output of shape (N, NH * H_out) when ``concat`` is set,
            otherwise (N, H_out). When attention weights are requested, a
            tuple ``(out, alpha)`` where ``alpha`` has shape (#edges, NH) when
            ``concat`` is set, otherwise (#edges,).
        """
        H, C = self.heads, self.out_channels

        x_l = self.lin_l(x).view(-1, H, C)  # source nodes: (N, H_in) -> (N, NH, H_out)
        if self.share_weights:
            x_r = x_l  # target nodes reuse the source projection
        else:
            x_r = self.lin_r(x).view(-1, H, C)

        # Start propagating info...: construct message -> aggregate message -> update/obtain new representations
        out = self.propagate(edge_index, x=(x_l, x_r), edge_attr=edge_attr, size=None)  # (N, NH, H_out)

        alpha = self._alpha  # (#edges, NH)
        assert alpha is not None, 'Alpha weights can not be None value!'
        # Drop the cached reference so the layer does not keep the attention
        # tensor (and its autograd graph) alive between forward passes.
        self._alpha = None

        if self.bias is not None:
            out = out + self.bias

        with_attention = isinstance(return_attention_weights, bool)

        if self.concat:
            out = out.view(-1, H * C)  # (N, NH * H_out)
        else:
            out = out.mean(dim=1)  # mean over the heads -> (N, H_out)
            if with_attention:
                alpha = alpha.mean(dim=1)  # (#edges,)

        if with_attention:
            return out, alpha
        return out

    def message(self, x_j, x_i, index, size_i, edge_attr):
        """
        Build the per-edge messages.

        Args:
            x_j: Source node features per edge, shape (#edges, NH, H_out).
            x_i: Target node features per edge, shape (#edges, NH, H_out).
            index: Target node index of every edge; groups the edges for the
                softmax.
            size_i: Number of target nodes, passed to the softmax.
            edge_attr: Optional edge features of shape (#edges, edge_dim).

        Returns:
            Messages of shape (#edges, NH, H_out).

        Raises:
            ValueError: If the last dimension of ``edge_attr`` differs from
                ``edge_dim``.
        """
        # Attention score: element-wise sum of source and target features,
        # LeakyReLU, then a per-head dot product with the attention vector.
        score = F.leaky_relu(x_i + x_j, self.negative_slope)
        score = (score * self.att).sum(dim=-1)  # (#edges, NH)

        # Sparse softmax: normalise the scores over the edges of each target node
        attention = softmax(score, index, num_nodes=size_i)  # (#edges, NH)
        self._alpha = attention  # cached for `forward`

        # Randomly drop attention coefficients during training
        dropped = F.dropout(attention, p=self.dropout, training=self.training)
        node_out = x_j * dropped.unsqueeze(dim=-1)  # (#edges, NH, H_out)

        if self.inf is None or edge_attr is None:
            return node_out

        if self.edge_dim != edge_attr.size(-1):
            raise ValueError("Node and edge feature dimensionality do not "
                             "match. Consider setting the 'edge_dim' attribute")

        # Influence mechanism: scale the edge features by the head-averaged
        # (pre-dropout) attention, then project them to H_out.
        influence = self.inf(attention.mean(dim=-1, keepdim=True) * edge_attr)  # (#edges, H_out)
        return node_out + influence.unsqueeze(1)  # broadcast over the heads

    def update(self, aggr_out, x):
        """
        Add the weighted self features to the aggregated messages.

        Args:
            aggr_out: Aggregated messages, shape (N, NH, H_out).
            x: The ``(x_l, x_r)`` tuple given to ``propagate``; ``x[1]`` holds
                the target-side projection.

        Returns:
            ``aggr_out + (1 + eps) * x[1]`` as a new tensor, shape
            (N, NH, H_out).
        """
        # Out-of-place so the aggregation result is never mutated.
        return aggr_out + (1 + self.eps) * x[1]

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.out_channels}, heads={self.heads})')

In [None]:
# Dummy node feature matrix: 10 nodes with 32 features each.
x = torch.randn((10, 32))
x.shape

In [None]:
# Dummy edge feature matrix: 16 edges with 8 features each.
edge_attr = torch.randn((16, 8))
edge_attr.shape

In [None]:
# 16 random edges between the 10 dummy nodes. `torch.randint` excludes the
# upper bound, so the range [0, 10) covers every valid node index, node 0 included.
edge_index = torch.randint(0, 10, (2, 16))

In [None]:
# First layer: 32 input features -> 8 heads of 100 channels, heads concatenated.
conv1 = EINv3(in_channels=32, out_channels=100, heads=8, edge_dim=8)

In [None]:
# Run the first layer and also fetch its attention weights.
out1, out2 = conv1(
    x,
    edge_index,
    edge_attr=edge_attr,
    return_attention_weights=True,
)
print(out1.shape, out2.shape)

In [None]:
# Second layer consumes the concatenated heads of the first one and averages
# its own heads (concat=False).
conv2 = EINv3(
    in_channels=out1.size(-1),
    out_channels=300,
    heads=8,
    edge_dim=8,
    concat=False,
)

out3, out4 = conv2(
    out1,
    edge_index,
    edge_attr=edge_attr,
    return_attention_weights=True,
)
print(out3.shape, out4.shape)

In [None]:
# Broadcasting check: a (40, 50) tensor gains a singleton middle axis so it can
# be added to every one of the 16 slices of a (40, 16, 50) tensor.
t1 = torch.randn(40, 16, 50)
t2 = torch.randn(40, 50)

t3 = t1 + t2[:, None, :]
t3.shape

# Model Building

In [None]:
import sys
sys.path.append('../../')
sys.path.append('../../Libs')
from models import *
from dataloaders import create_dataloaders
from train import train, run_experiment

In [None]:
# Dataset configuration handed to the project's dataloader factory.
args = dict(dataset_name='MUTAG', batch_size=64)

train_loader, val_loader, test_loader, metadata = create_dataloaders(args)

In [None]:
# Number of mini-batches in each split, one per line.
print(*map(len, (train_loader, val_loader, test_loader)), sep='\n')

# Peek at the first training mini-batch only.
for data in train_loader:
    print(data)
    break

print(metadata)

In [None]:
import torch
from torch.nn import Linear, Parameter, Sequential, BatchNorm1d, ReLU, Dropout
from torch_geometric.nn import MessagePassing, GCNConv, GATv2Conv, GINConv, GINEConv, global_mean_pool
from torch_geometric.utils import add_self_loops, remove_self_loops, degree, softmax

In [None]:
class EINModel_v3(torch.nn.Module):
    """
    Graph-level classifier built from three stacked ``EINv3`` layers.

    The first two layers concatenate their heads, the third averages them.
    The mean-pooled outputs of all three layers are concatenated and fed to a
    two-layer MLP that produces log-probabilities.

    Args:
        input_dim: Size of the input node features.
        dim_h: Per-head hidden size of every ``EINv3`` layer.
        final_dim: Number of output classes.
        num_heads: Number of attention heads per layer.
        edge_dim: Size of the edge features.
        **kwargs: Extra keyword arguments forwarded to every ``EINv3`` layer.
    """

    def __init__(self, input_dim, dim_h, final_dim, num_heads, edge_dim, **kwargs):
        super().__init__()
        torch.manual_seed(42)  # fixed seed so the layer initialisation is reproducible

        shared = dict(edge_dim=edge_dim, heads=num_heads, **kwargs)
        stacked_dim = dim_h * num_heads  # width of a concatenated multi-head output

        # Message-passing layers (created in this order to keep the seeded init stable)
        self.conv1 = EINv3(input_dim, dim_h, **shared)
        self.conv2 = EINv3(stacked_dim, dim_h, **shared)
        self.conv3 = EINv3(stacked_dim, dim_h, concat=False, **shared)

        readout_dim = dim_h * 3  # three pooled layer outputs side by side

        # Hidden linear layer followed by the classification head
        self.lin1 = Linear(readout_dim, readout_dim)
        self.lin2 = Linear(readout_dim, final_dim)

    def forward(self, x, edge_index, edge_attr, batch):
        """
        Compute per-graph log-probabilities.

        Args:
            x: Node features.
            edge_index: Edge indices.
            edge_attr: Edge features.
            batch: Graph assignment of every node, used by the mean pooling.

        Returns:
            Log-softmax over ``final_dim`` classes for each graph.
        """
        # Node embeddings
        h1 = torch.relu(self.conv1(x, edge_index, edge_attr))
        h2 = torch.relu(self.conv2(h1, edge_index, edge_attr))
        h3 = torch.relu(self.conv3(h2, edge_index, edge_attr))

        channels = h3.size(-1)  # dim_h
        n_heads = h2.size(-1) // channels  # num_heads

        # Graph-level readout: average the heads of the concatenated layers
        # first so that all three pooled tensors have `dim_h` columns.
        pooled = [
            global_mean_pool(h.view(-1, n_heads, channels).mean(dim=1), batch)
            for h in (h1, h2)
        ]
        pooled.append(global_mean_pool(h3, batch))
        h = torch.cat(pooled, dim=1)

        # Classifier
        h = self.lin1(h).relu()
        h = F.dropout(h, p=0.5, training=self.training)
        logits = self.lin2(h)

        return F.log_softmax(logits, dim=1)

In [None]:
# Build the classifier: 7 node features, 4 edge features, 2 classes.
# `eps=1` is forwarded to every EINv3 layer.
model = EINModel_v3(
    input_dim=7,
    dim_h=64,
    final_dim=2,
    num_heads=16,
    edge_dim=4,
    eps=1,
)
model

In [None]:
# Smoke-test a single EINv3 layer on the first training mini-batch.
conv_tmp = EINv3(in_channels=7, out_channels=100, heads=16, edge_dim=4, concat=False)
for batch in train_loader:
    print(batch)
    # EINv3.forward takes (x, edge_index, edge_attr, return_attention_weights);
    # it has no `batch` argument, so the node-to-graph vector is not passed
    # (it used to land in the `return_attention_weights` slot).
    res = conv_tmp(batch.x, batch.edge_index, batch.edge_attr)
    print(res.shape)
    break

In [None]:
# Smoke-test the full model on the first training mini-batch.
for batch in train_loader:
    # Call the module itself rather than `.forward` so registered hooks run.
    res = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
    print(res.shape)
    break