## Description

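##### Implementation of "Deep Residual Learning for Image Recognition" paper - https://iopscience.iop.org/article/10.1088/1742-6596/1871/1/012071/meta

Note: the code below builds a graph attention network (stacked `GraphAttention` layers with one residual connection), not a ResNet. The title and the link above also do not match: "Deep Residual Learning for Image Recognition" (He et al.) is arXiv:1512.03385 / CVPR 2016. Confirm which paper this notebook is meant to follow.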

### Libraries

In [None]:
import tensorflow as tf

### Layers

In [None]:
class GraphAttention(tf.keras.layers.Layer):
    """Multi-head, general (Luong-style) attention of one node over its neighbors.

    Every head projects the concatenation ``[value; query]`` (the node itself
    followed by its neighbor slots) with its own Dense alignment matrix.
    It scores each projected row against ``value`` and takes the
    score-weighted sum of the projected rows. The per-head contexts are
    concatenated on the last axis and passed through a Dense + ReLU output
    layer.
    """

    def __init__(self, num_heads: int, hidden_size: int, max_neighbors: int):
        super().__init__()
        self.hidden_size = hidden_size
        self.max_neighbors = max_neighbors
        self.num_heads = num_heads
        # one trainable alignment matrix per head for general-style attention
        self.attn_ws = [tf.keras.layers.Dense(hidden_size) for _ in range(num_heads)]
        # output layer applied to the concatenated head contexts
        self.out_w = tf.keras.layers.Dense(hidden_size)

    def call(self, inputs: dict, debug=False):
        """Attend from the node (`value`) over itself and its neighbors (`query`).

        # Inputs:
            value: the vector relating to the node to encode. Should be of size:
              [batch_size, 1, hidden_size]
            query: a tensor providing embeddings for the neighbors of the node
              which are to be attended to. Should be of size:
              [batch_size, max_neighbors, hidden_size]
            num_neighbors: how many of the `query` rows are real neighbors.
              The intent is to attend only to the node plus the first
              `num_neighbors` rows (see the note below). Must not exceed
              `max_neighbors`.
        debug: unused; kept so existing call sites keep working.

        # Returns:
            ReLU features from the output layer. For a scalar `num_neighbors`
            the shape is [batch_size, max_neighbors, hidden_size] (see note).

        NOTE(review): for a scalar `num_neighbors` the mask below has shape
        [max_neighbors, 1] (one flag per ROW). It does not have shape
        [1, max_neighbors + 1] (one flag per node/neighbor COLUMN).
        `tf.where` therefore broadcasts the [batch, 1, max_neighbors + 1]
        scores to [batch, max_neighbors, max_neighbors + 1]:
          - rows < num_neighbors + 1 are the same unmasked softmax over
            every column;
          - the remaining rows are filled with -1e9 and become uniform.
        No neighbor is actually excluded, and the output gains a
        max_neighbors axis instead of staying [batch, 1, hidden_size].
        GraphAttentionNetwork currently feeds this output back in as `query`.
        Correcting the mask therefore also requires reworking that class.
        """

        value = inputs['value']
        query = inputs['query']
        num_neighbors = inputs['num_neighbors']

        assert value.shape[1] == 1, f'second dim of value should be 1, but was {value.shape[1]}'
        assert query.shape[1] == self.max_neighbors, f'second dim of query should equal max_neighbors, but was {query.shape[1]}'
        # equality is allowed: attending to every neighbor slot is legitimate
        assert num_neighbors <= self.max_neighbors, f'num_neighbors input of {num_neighbors} cannot be greater than max neighbors'

        # aggregate features from all neighbors, including the node itself
        nodes = tf.concat([value, query], axis=1)

        # the mask does not depend on the head, so build it once
        mask = tf.sequence_mask(num_neighbors + 1, maxlen=self.max_neighbors)[:, tf.newaxis]

        # multi-head self-attention
        contexts = []
        for attn_w in self.attn_ws:
            # each head projects the shared input; the original reassigned
            # `query` here, chaining every head onto the previous head's output
            projected = attn_w(nodes)
            e = tf.matmul(value, projected, transpose_b=True)
            e = tf.nn.swish(e)
            # apply mask before softmaxing
            e = tf.where(mask, e, tf.ones_like(e) * -1e9)
            scores = tf.nn.softmax(e)
            # sum the projected embeddings according to the attention scores
            contexts.append(tf.matmul(scores, projected))

        # concatenate contexts from each attention head
        contexts = tf.concat(contexts, axis=-1)

        # produce new features from full context
        x = self.out_w(contexts)
        return tf.nn.relu(x)

In [None]:
# Random stand-ins with batch = 8, hidden = 512:
#   value: one node's features, [batch, 1, hidden]
#   query: 10 neighbor slots per node, [batch, max_neighbors, hidden]
value = tf.random.normal(shape=(8, 1, 512))
query = tf.random.normal(shape=(8, 10, 512))

In [None]:
# notebook display of the sampled node tensor
value

In [None]:
# Smoke test for a single GraphAttention layer: 10 neighbor slots, num_neighbors=5.
# NOTE(review): with the layer's current masking the printed output shape has
# max_neighbors (10) as its second dim, not 1 -- confirm that is intended.
g_attn = GraphAttention(max_neighbors=10, hidden_size=512, num_heads=4)
x = g_attn({'query': query, 'value': value, 'num_neighbors': 5})

print(f'Input shape: {value.shape}')
print(f'Output shape: {x.shape}')

### GNN Block

In [None]:
class GraphAttentionNetwork(tf.Module):
    """Four stacked GraphAttention layers followed by a 4-way softmax classifier.

    gat1 attends over the caller's ``query``. gat2, gat3 and gat4 each attend
    over the previous stage's output, always scored against the caller's
    original ``value``. gat1's output is added to gat3's (a residual
    connection) before gat4. The outputs of all four layers are concatenated,
    flattened, and mapped to 4 class probabilities.

    NOTE(review): gat2-gat4 pass the previous output as ``query``, and
    GraphAttention asserts ``query.shape[1] == max_neighbors``. This holds only
    because GraphAttention currently returns [batch, max_neighbors, hidden] (a
    broadcasting side effect of its mask). If that layer is changed to return
    [batch, 1, hidden], these calls will trip its assertion.
    """

    def __init__(self,
                 num_heads: int,
                 hidden_sizes: int,
                 num_neighbors: int,
                 max_neighbors: int):
        """
        Args:
            num_heads: attention heads per GraphAttention layer.
            hidden_sizes: hidden width of every GraphAttention layer. It must
              equal the last dim of inputs['value'], because each layer
              concatenates `value` with its `query` along axis 1.
            num_neighbors: neighbor count passed to gat2-gat4 (gat1 uses
              inputs['num_neighbors'] instead).
            max_neighbors: number of neighbor slots each layer expects.
        """
        # the original `super(GraphAttentionNetwork, self)` built the proxy but
        # never called it, so tf.Module.__init__ (name / name scope) never ran
        super().__init__()

        self.num_neighbors = num_neighbors

        self.gat1 = GraphAttention(num_heads, hidden_sizes, max_neighbors)
        self.gat2 = GraphAttention(num_heads, hidden_sizes, max_neighbors)
        self.gat3 = GraphAttention(num_heads, hidden_sizes, max_neighbors)
        self.gat4 = GraphAttention(num_heads, hidden_sizes, max_neighbors)

        self.flatten = tf.keras.layers.Flatten()
        self.mlp = tf.keras.layers.Dense(4)  # 4-way classification

    def _attend(self, layer, query, value):
        """Run `layer` over `query` against `value` with the stored neighbor count."""
        return layer({
            'query': query,
            'value': value,
            'num_neighbors': self.num_neighbors
        })

    def __call__(self, inputs: dict):
        """Classify the node described by `inputs`.

        Args:
            inputs: dict with 'query', 'value' and 'num_neighbors' in the
              format GraphAttention expects.

        Returns:
            Softmax probabilities of shape [batch, 4].
        """
        value = inputs['value']

        temp_1 = self.gat1(inputs=inputs)
        temp_2 = self._attend(self.gat2, temp_1, value)
        temp_3 = self._attend(self.gat3, temp_2, value)
        # residual connection: gat1's output skips around gat2/gat3 into gat4
        temp_4 = self._attend(self.gat4, tf.add(temp_1, temp_3), value)

        # concat on axis 1; concatenating on axis 0 would grow the batch dimension
        x = tf.concat([temp_1, temp_2, temp_3, temp_4], 1)

        x = self.flatten(x)  # flatten before passing in MLP
        x = self.mlp(x)
        return tf.nn.softmax(x)

In [None]:
# Full network sized to match the sample tensors (hidden 512, 10 neighbor slots).
gat_network = GraphAttentionNetwork(max_neighbors=10, num_neighbors=5, hidden_sizes=512, num_heads=4)

In [None]:
# Inputs for the network in the dict format GraphAttention expects.
in_test = dict(query=query, value=value, num_neighbors=5)

In [None]:
# Forward pass; `inputs` is the only parameter of GraphAttentionNetwork.__call__.
out_test = gat_network(in_test)

In [None]:
# notebook display of the network's softmax output
out_test

### TODO:
- figure out what num_heads is
- figure out how all the parameters relate to the graph representation of an image