# GCN

In [1]:
import torch
import torch.nn.functional as F
from torch.nn import Linear, Parameter
from torch.nn.init import xavier_uniform_, zeros_
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, remove_self_loops, degree, softmax

In [2]:
num_nodes = 4 # number of nodes in the graph
embed_size = 5 # initial node features

# Creating node features
node_feat = torch.rand((num_nodes, embed_size), dtype=torch.float)
x = node_feat
print('Node features\n', node_feat, node_feat.shape)

# Creating COO format edge_indexes for the graph
src_index = torch.tensor([0,0,0,1,1,2,3,3], dtype=torch.long)
target_index = torch.tensor([1,2,3,3,0,0,0,2], dtype=torch.long)
edge_index = torch.zeros(size=(2, src_index.size()[0]), dtype=torch.int64)
edge_index[0] = src_index
edge_index[1] = target_index
print('\nedge_index\n', edge_index, edge_index.shape)

# Creating edge features
x_edge = torch.rand((edge_index.shape[1], 5), dtype=torch.float) # (num_edges, embed_size)
print('\nEdge features\n', x_edge, x_edge.shape)

Node features
 tensor([[0.1339, 0.6032, 0.3207, 0.7724, 0.3698],
        [0.7954, 0.7916, 0.1310, 0.0559, 0.8457],
        [0.5854, 0.0704, 0.9769, 0.9597, 0.7408],
        [0.7319, 0.3205, 0.5486, 0.9603, 0.5263]]) torch.Size([4, 5])

edge_index
 tensor([[0, 0, 0, 1, 1, 2, 3, 3],
        [1, 2, 3, 3, 0, 0, 0, 2]]) torch.Size([2, 8])

Edge features
 tensor([[0.2021, 0.4291, 0.1266, 0.4451, 0.1104],
        [0.3644, 0.1243, 0.6007, 0.1617, 0.0896],
        [0.8721, 0.3000, 0.7694, 0.2707, 0.6138],
        [0.8695, 0.2297, 0.7972, 0.4406, 0.9780],
        [0.7250, 0.0640, 0.8526, 0.1394, 0.9993],
        [0.4752, 0.5227, 0.4312, 0.0431, 0.8705],
        [0.5375, 0.9636, 0.4342, 0.2797, 0.7332],
        [0.3975, 0.5066, 0.2974, 0.2837, 0.9275]]) torch.Size([8, 5])


In [3]:
class GCNConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(GCNConv, self).__init__(aggr='add') # aggregation 
        self.lin = torch.nn.Linear(in_channels, out_channels)
    
    def forward(self, x, edge_index):
        # add self loops
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
        # initial feature transform
        x = self.lin(x)
        
        return self.propagate(edge_index, size=(x.size(0), x.size(0)), x=x)
    
    def message(self, x_j, edge_index, size):
        row, col = edge_index
        deg = degree(row, size[0], dtype=x_j.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
        
        # print(norm.view(-1, 1))
        return norm.view(-1, 1) * x_j
    
    def update(self, aggr_out):
        return aggr_out

In [4]:
gcn = GCNConv(embed_size, 12)
print('Forward pass')
res = gcn.forward(x, edge_index)
print('\nres\n', res, res.shape)

Forward pass

res
 tensor([[-0.2242, -0.3125,  0.8685, -0.8613, -0.2454,  0.3702, -0.8251,  0.1788,
         -0.0387, -0.1061,  0.2458,  0.0949],
        [-0.1680, -0.0349,  0.3368, -0.3738, -0.0404,  0.1945, -0.3948,  0.1143,
         -0.1057,  0.0622,  0.1184,  0.0454],
        [-0.1638, -0.4710,  1.0498, -0.9654, -0.3504,  0.4334, -0.8794,  0.0887,
          0.1200, -0.1498,  0.2659,  0.1456],
        [-0.2308, -0.1513,  0.6065, -0.6451, -0.1382,  0.3111, -0.6684,  0.1496,
         -0.1025, -0.0131,  0.2315,  0.0826]], grad_fn=<ScatterAddBackward0>) torch.Size([4, 12])


# GraphSAGE

In [5]:
class GraphSAGE(MessagePassing):
    def __init__(self, in_channels, out_channels, reducer='mean', normalize_embedding=True):
        super(GraphSAGE, self).__init__(aggr='mean') # Aggregate
        self.aggr_lin = torch.nn.Linear(in_channels * 2, out_channels)
        
        if normalize_embedding:
            self.normalize = True
            
    def forward(self, x, edge_index):
        num_nodes = x.size(0)
        
        return self.propagate(edge_index, size=(num_nodes, num_nodes), x=x)
    
    def message(self, x_j, edge_index, size):
        return x_j
    
    def update(self, aggr_out):
        # Concate and transform
        concat_out = torch.cat((x, aggr_out), dim=1)
        aggr_out = self.aggr_lin(concat_out)
        aggr_out = F.relu(aggr_out)
        
        if self.normalize:
            aggr_out = F.normalize(aggr_out, p=2, dim=1)
        # print(aggr_out.shape)
       
        return aggr_out

In [6]:
graphsage = GraphSAGE(embed_size, 12)
res = graphsage.forward(x, edge_index)
print('\nres\n', res, res.shape)


res
 tensor([[0.1311, 0.2816, 0.2891, 0.2310, 0.8252, 0.0000, 0.0000, 0.1016, 0.1328,
         0.0000, 0.2400, 0.0000],
        [0.2585, 0.0000, 0.0000, 0.2647, 0.8878, 0.0000, 0.0000, 0.2189, 0.0619,
         0.1523, 0.0000, 0.0000],
        [0.1815, 0.3459, 0.2146, 0.2531, 0.6026, 0.0000, 0.0000, 0.4979, 0.2567,
         0.0000, 0.2456, 0.0000],
        [0.1105, 0.0880, 0.2681, 0.3348, 0.7681, 0.0000, 0.0000, 0.4339, 0.0354,
         0.0000, 0.1284, 0.0000]], grad_fn=<DivBackward0>) torch.Size([4, 12])


# GAT

In [7]:
class GAT(MessagePassing):
    def __init__(self, in_channels, out_channels, num_heads=1, concat=True ,dropout=0, bias=True, **kwargs):
        super(GAT, self).__init__(aggr='add', **kwargs)
        
        self.in_channels = in_channels
        self.out_channels = int(out_channels / num_heads) # out_channels must be multiplication of num_heads
        self.heads = num_heads
        self.concat = concat
        self.dropout = dropout
        
        self.lin = Linear(self.in_channels, self.out_channels * num_heads) # Linear transformation
        self.att = Parameter(torch.Tensor(1, self.heads, self.out_channels * 2)) # Learnable attention matrix
        
        if bias and concat:
            self.bias = Parameter(torch.Tensor(self.heads * self.out_channels))
        elif bias and not concat:
            self.bias = Parameter(torch.Tensor(self.out_channels))
        else:
            self.register_parameter('bias', None)
            
        xavier_uniform_(self.att)
        zeros_(self.bias)
    
    
    def forward(self, x, edge_index, size=None):
        # Adding self-loops for the graph...
        if size is None and torch.is_tensor(x):
            edge_index, _ = remove_self_loops(edge_index)
            edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
        
        
        x = self.lin(x) # linear transformation
        
        return self.propagate(edge_index, size=size, x=x)
    
    def message(self, edge_index_i, x_i, x_j, size_i):
        # Compute attention coefficient
        x_i = x_i.view(-1, self.heads, self.out_channels) # split hidden features into multi-heads
        x_j = x_j.view(-1, self.heads, self.out_channels)
        
        # Concate source and target node hidden features
        # Compute cosine similarity on the last axis: inner product
        alpha = (torch.cat([x_i, x_j], dim=-1) * self.att).sum(dim=-1)
        alpha = F.leaky_relu(alpha, 0.2)
        print(alpha.shape)
        print(edge_index_i)
        print(size_i)
        
        # Softmax: will call 'scatter_add' internaly
        alpha = softmax(alpha, edge_index_i, num_nodes=size_i)
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)
        
        print(alpha.view(-1, self.heads, 1).shape)
        print(x_j.shape)
        print('tmp', (x_j * torch.unsqueeze(alpha, dim=-1)).shape)
        
        # return x_j * alpha.view(-1, self.heads, 1) # weighted input: (alpha * src nodes features)
        return x_j * torch.unsqueeze(alpha, dim=-1)
        
    def update(self, aggr_out):
        print('aggr', aggr_out.shape)
        if self.concat is True:
            aggr_out = aggr_out.view(-1, self.heads * self.out_channels)
        else:
            aggr_out = aggr_out.mean(dim=1) # For the last layer, just aggregation
        if self.bias is not None:
            aggr_out = aggr_out + self.bias
        
        return aggr_out

In [8]:
gat = GAT(embed_size, 100)
gat.forward(x, edge_index)

torch.Size([12, 1])
tensor([1, 2, 3, 3, 0, 0, 0, 2, 0, 1, 2, 3])
4
torch.Size([12, 1, 1])
torch.Size([12, 1, 100])
tmp torch.Size([12, 1, 100])


RuntimeError: The expanded size of the tensor (1) must match the existing size (12) at non-singleton dimension 1.  Target sizes: [12, 1, 100].  Tensor sizes: [1, 12, 1]

In [9]:
from torch_geometric.nn import GATv2Conv

gat = GATv2Conv(5, 100)
res1, res2 = gat.forward(x, edge_index, return_attention_weights=True)
print(res1[0].shape, res2[0].shape)

torch.Size([100]) torch.Size([2, 12])


In [89]:
class GAT(MessagePassing):
    def __init__(self, in_channels, out_channels, heads=1, concat=True ,dropout=0, bias=True, **kwargs):
        super(GAT, self).__init__(aggr='add', **kwargs)
        
        self.in_channels = in_channels
        self.out_channels = out_channels # out_channels must be multiplication of num_heads
        self.heads = heads
        self.concat = concat
        self.dropout = dropout
        
        # Linear transformations
        self.lin_src = Linear(in_channels, heads * out_channels, bias=False)
        self.lin_dst = Linear(in_channels, heads * out_channels, bias=False)
        self.lin = Linear(in_channels, heads * out_channels, bias=False)
        
        # The learnable parameters to compute attention coefficients:
        self.att_src = Parameter(torch.Tensor(1, heads, out_channels))
        self.att_dst = Parameter(torch.Tensor(1, heads, out_channels))

        
        if bias and concat:
            self.bias = Parameter(torch.Tensor(self.heads * self.out_channels))
        elif bias and not concat:
            self.bias = Parameter(torch.Tensor(self.out_channels))
        else:
            self.register_parameter('bias', None)
            
        xavier_uniform_(self.att_src)
        xavier_uniform_(self.att_dst)
        zeros_(self.bias)
    
    
    def forward(self, x, edge_index, size=None):
        # Adding self-loops for the graph...
        if size is None and torch.is_tensor(x):
            edge_index, _ = remove_self_loops(edge_index)
            edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
        
        
        H, C = self.heads, self.out_channels
        
        x_src = torch.index_select(x, 0, edge_index[0]) # selecting source node features
        x_dst = torch.index_select(x, 0, edge_index[1]) # selecting target node features
        
        # Source and Target wise Linear Transformation
        x_src = self.lin_src(x_src).view(-1, H, C)
        x_dst = self.lin_src(x_dst).view(-1, H, C)
        
        x = (x_src, x_dst)
        # x = self.lin(x).view(-1, H, C)
        # print((self.lin(x).view(-1, H, C).shape))
        # print(x.shape)
        # print(x_src.shape)
        # x = torch.cat([x, x], dim=1)
        
        
        # Calcualting attention coefficients
        alpha_src = (x_src * self.att_src).sum(dim=-1)
        alpha_dst = (x_dst * self.att_dst).sum(dim=-1)
        alpha = (alpha_src, alpha_dst)
            
        alpha = self.edge_updater(edge_index, alpha=alpha)
        out = self.propagate(edge_index, x=x, size=size, alpha=alpha)
       
        if self.concat:
            out = out.view(-1, self.heads * self.out_channels)
        else:
            out = out.mean(dim=1)
            
        if self.bias is not None:
            out = out + self.bias
        
        return out
        
    def edge_update(self, alpha_j, alpha_i, index, ptr, size_i):
        # This will calculate attetntion score for the whole graph...
        
        alpha = alpha_j if alpha_i is None else alpha_j + alpha_i
        if index.numel() == 0:
            return alpha
        
        alpha = F.leaky_relu(alpha, 0.2)
        alpha = softmax(alpha, index, ptr, size_i)
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)
        return alpha
    
    def message(self, x_j, alpha):
        # return alpha.unsqueeze(-1) * x_j
        # print((alpha * x_j).shape)
        return x_j * alpha
        

In [90]:
gat = GAT(embed_size, 10)
gat.forward(x, edge_index).shape

RuntimeError: INDICES element is out of DATA bounds, id=1 axis_dim=1

In [120]:
class GAT(MessagePassing):
    def __init__(
        self,
        in_channels, 
        out_channels,
        heads=1,
        concat=True,
        negative_slope=0.2,
        dropout=0.0,
        add_self_loops=True,
        bias=True,
        share_weights=False,
        **kwargs,
    ):
        super().__init__(node_dim=0, **kwargs)
        
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        self.concat = concat
        self.negative_slope = negative_slope
        self.dropout = dropout
        self.add_self_loops = add_self_loops
        self.share_weights = share_weights
        
        self.lin_l = Linear(in_channels, heads * out_channels, bias=bias,)
        
        if share_weights:
            self.lin_r = self.lin_l
        else:
            self.lin_r = Linear(in_channels, heads * out_channels, bias=bias)
            
        self.att = Parameter(torch.Tensor(1, heads, out_channels))
        
        if bias and concat:
            self.bias = Parameter(torch.Tensor(heads * out_channels))
        elif bias and concat:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)
            
        self._alpha = None

        xavier_uniform_(self.att)
        zeros_(self.bias)        
        
    def forward(self, x, edge_index, return_attention_weights=None):
        H, C = self.heads, self.out_channels
        
        x_l = None
        x_r = None
        
        x_l = self.lin_l(x).view(-1, H, C)
        if self.share_weights:
            x_r = x_l
        else:
            x_r = self.lin_r(x).view(-1, H, C)
            
        assert x_l is not None
        assert x_r is not None
        
        if self.add_self_loops:
            if torch.is_tensor(edge_index):
                num_nodes = x_l.size(0)
                if x_r is not None:
                    num_nodes = min(num_nodes, x_r.size(0))
                edge_index, _ = remove_self_loops(edge_index)
                edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)
                
        out = self.propagate(edge_index, x=(x_l, x_r), size=None)
        print(out.shape)
        
        alpha = self._alpha
        assert alpha is not None
        self._alpha = None
        
        if self.concat:
            out = out.view(-1, self.heads * self.out_channels)
        else:
            out = out.mean(dim=1)
        
        if self.bias is not None:
            out = out + self.bias
            
        if isinstance(return_attention_weights, bool):
            if torch.is_tensor(edge_index):
                return out, (edge_index, alpha)
        else:   
            return out
    
    def message(self, x_j, x_i, index, ptr, size_i):
        x = x_i + x_j
        
        x = F.leaky_relu(x, self.negative_slope)
        alpha = (x * self.att).sum(dim=-1)
        alpha = softmax(alpha, index, ptr, size_i)
        self._alpha = alpha
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)
        
        return x_j * alpha.unsqueeze(-1)
                

In [122]:
gat = GAT(embed_size, 10)
gat.forward(x, edge_index).shape

torch.Size([4, 1, 10])


torch.Size([4, 10])