# Import necessary libraries

In [5]:
from typing import Optional, Tuple, Union
import os

import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Parameter
from torch_sparse import SparseTensor, set_diag

from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.typing import NoneType
from torch_geometric.typing import Adj, OptPairTensor, OptTensor, Size
from torch_geometric.utils import add_self_loops, remove_self_loops, softmax

In [6]:
import math
from typing import Any

import torch
from torch import Tensor

def glorot(value: Any):
    """Re-initialise ``value`` in place with Glorot (Xavier) uniform noise.

    A tensor is filled with samples from U(-a, a), where
    ``a = sqrt(6 / (value.size(-2) + value.size(-1)))``. The last two
    dimensions therefore act as fan-out/fan-in, and tensors with fewer than
    two dimensions raise ``IndexError``.

    Any other object is treated as a container. Every tensor from its
    ``parameters()`` and then its ``buffers()`` (when those methods exist)
    is initialised the same way. Objects with neither method, such as
    ``None``, are left untouched.
    """
    if not isinstance(value, Tensor):
        # Container case: visit parameters first, then buffers.
        params = value.parameters() if hasattr(value, 'parameters') else []
        for param in params:
            glorot(param)
        bufs = value.buffers() if hasattr(value, 'buffers') else []
        for buf in bufs:
            glorot(buf)
        return

    fan_sum = value.size(-2) + value.size(-1)
    bound = math.sqrt(6.0 / fan_sum)
    # Write through ``.data`` so the initialisation is not tracked by autograd.
    value.data.uniform_(-bound, bound)
            
def constant(value: Any, fill_value: float):
    """Overwrite ``value`` in place with ``fill_value``.

    A tensor is filled directly. Any other object is treated as a container:
    each entry yielded by its ``parameters()`` method, and then by its
    ``buffers()`` method (whichever exist), is filled recursively. Objects
    with neither method, such as ``None``, are ignored.
    """
    if isinstance(value, Tensor):
        # ``.data`` keeps the in-place fill out of the autograd graph.
        value.data.fill_(fill_value)
        return

    for accessor in ('parameters', 'buffers'):
        if not hasattr(value, accessor):
            continue
        for member in getattr(value, accessor)():
            constant(member, fill_value)
            
def zeros(value: Any):
    """Zero-fill ``value`` in place.

    Shorthand for :func:`constant` with a fill value of ``0.0``. It accepts a
    tensor or any object exposing ``parameters()``/``buffers()``. Other
    objects, such as ``None``, are ignored.
    """
    constant(value, fill_value=0.0)

# Creating the Edge-featured Graph Attention Layer

In [7]:
class EFGAL(MessagePassing):
    """
    Args:
        in_channels (int or tuple): Size of each input sample
        
        out_channels (int): Size of each output sample
        
        heads (int, optional): Number of multi-head attention heads (default: '1')
        
        concat (bool, optional):    If set to 'False', the multi-head
                                    attentions are averaged instead of concatenated (default: 'True')
            
        negative_slope (float, optional):   LeakyReLU angle of the negative
                                            slope (default: '0.2')
                                            
        dropout (float, optional):  Dropout probability of the normalized
                                    attention coefficients, which exposes each node to a stochastically
                                    sampled neighbourhood during training (default: '0')
            
        add_self_loops (bool, optional):    If set to 'False', will not add
                                            self loops to the input graph. (default: 'True')
            
        edge_dim (int, optional): Edge feature dimensionality, if there are any (default: 'None')
        
        fill_value (float or Tensor or str, optional):  The way to generate 
                                                        edge features of self-loops (in case 'edge_dim' != None) (default: 'mean')
            
        bias (bool, optional):  If set to 'False', the layer will not learn
                                an additive bias (default: 'True')
            
        **kwargs (optional):    Additional arguments of
                                :class:`torch_geometric.nn.conv.MessagePassing`
    """
    
    def __init__(self,
                 in_channels: Union[int, Tuple[int, int]],
                 out_channels: int,
                 heads: int = 1,
                 concat: bool = True,
                 negative_slope: float = 0.2,
                 dropout: float = 0.0,
                 add_self_loops: bool = True,
                 edge_dim: Optional[int] = None,
                 fill_value: Union[float, Tensor, str] = 'mean',
                 bias: bool = True,
                 **kwargs,
                 ):
        """Create the projections and learnable attention vectors of the layer.

        See the class docstring for the meaning of each argument. Parameters
        are allocated uninitialized here. ``self.reset_parameters()``, called
        as the last step, is expected to initialize them.
        """
        # Sum incoming messages unless the caller requests another aggregation.
        kwargs.setdefault('aggr', 'add')
        # node_dim=0: node features are indexed along the first axis
        # (presumably [num_nodes, heads, out_channels] -- confirm in forward()).
        super().__init__(node_dim=0, **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        self.concat = concat
        self.negative_slope = negative_slope
        self.dropout = dropout
        self.add_self_loops = add_self_loops
        # Plain assignment: a trailing comma here would store a 1-tuple, and
        # `self.edge_dim is None` would then never hold.
        self.edge_dim = edge_dim
        self.fill_value = fill_value

        # Homogeneous graphs share one projection for source and target nodes.
        # Bipartite graphs (tuple in_channels) get separate 'lin_src' and
        # 'lin_dst' transformations.
        if isinstance(in_channels, int):
            self.lin_src = Linear(in_channels, heads * out_channels,
                                  bias=False, weight_initializer='glorot')
            self.lin_dst = self.lin_src
        else:
            self.lin_src = Linear(in_channels[0], heads * out_channels,
                                  bias=False, weight_initializer='glorot')
            self.lin_dst = Linear(in_channels[1], heads * out_channels,
                                  bias=False, weight_initializer='glorot')

        # Learnable per-head vectors used to compute attention coefficients.
        self.att_src = Parameter(torch.empty(1, heads, out_channels))
        self.att_dst = Parameter(torch.empty(1, heads, out_channels))

        # Edge features get their own projection and attention vector, but only
        # when an edge feature size is provided.
        if edge_dim is not None:
            self.lin_edge = Linear(edge_dim, heads * out_channels,
                                   bias=False, weight_initializer='glorot')
            self.att_edge = Parameter(torch.empty(1, heads, out_channels))
        else:
            self.lin_edge = None
            self.register_parameter('att_edge', None)

        # The bias matches the output width: heads * out_channels when heads
        # are concatenated, out_channels when they are averaged.
        if bias:
            bias_size = heads * out_channels if concat else out_channels
            self.bias = Parameter(torch.empty(bias_size))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()
    