# Graph Attention Network

(Multi-head) [graph attention layer](https://keras.io/examples/graph/gat_node_classification/)
The GAT model implements multi-head graph attention layers. The MultiHeadGraphAttention layer is simply a concatenation (or averaging) of multiple graph attention layers (GraphAttention), each with separate learnable weights W. The GraphAttention layer does the following:

Consider input node states $h^{l}$, which are linearly transformed by $W^{l}$, resulting in $z^{l}$.

For each target node:

Compute pairwise attention scores $e_{ij} = \mathrm{LeakyReLU}\left(a^{(l)T} (z^{l}_{i} \,\|\, z^{l}_{j})\right)$ for all $j$. Here $\|$ denotes concatenation, $i$ is the target node, and $j$ is one of its 1-hop neighbor (source) nodes.
Normalize $e_{ij}$ with a softmax over the target node's incoming edges, $e^{norm}_{ij} = \exp(e_{ij}) / \sum_{k} \exp(e_{ik})$, so that the normalized scores of all incoming edges sum to 1: $\sum_{k} e^{norm}_{ik} = 1$.
Weight each neighbor state $z^{l}_{j}$ by $e^{norm}_{ij}$ and sum over all $j$ to obtain the new target node state $h^{l+1}_{i} = \sum_{j} e^{norm}_{ij} z^{l}_{j}$.

![avatar](img/gat.PNG)

In [1]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import Model
from tensorflow.keras.layers import Dense, Layer
import numpy as np
import pandas as pd
import os
import warnings

In [None]:
# * Exploration code.

In [2]:
# SGD-with-momentum settings; the training cell further down redefines the
# same two values and builds its own optimizer from them.
LEARNING_RATE = 0.3
MOMENTUM = 0.9

optimizer = keras.optimizers.SGD(learning_rate=LEARNING_RATE, momentum=MOMENTUM)

In [5]:
# Fetch and unpack the Cora archive next to the file keras.utils.get_file returns.
zip_file = keras.utils.get_file(
    fname="cora.tgz",
    origin="https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz",
    extract=True,
)
data_dir = os.path.join(os.path.dirname(zip_file), "cora")

# Citation edge list: one tab-separated (target, source) pair per row.
citations = pd.read_csv(
    os.path.join(data_dir, "cora.cites"),
    sep="\t",
    header=None,
    names=["target", "source"],
)

# Paper table: id, 1433 term columns, then the subject label.
papers = pd.read_csv(
    os.path.join(data_dir, "cora.content"),
    sep="\t",
    header=None,
    names=["paper_id", *[f"term_{idx}" for idx in range(1433)], "subject"],
)

# Dense 0-based ids, assigned in sorted order of the raw subject names / paper ids.
class_values = sorted(papers["subject"].unique())
class_idx = {subject: position for position, subject in enumerate(class_values)}
paper_idx = {
    raw_id: position
    for position, raw_id in enumerate(sorted(papers["paper_id"].unique()))
}

# Rewrite raw ids/labels in place with their dense integer ids.
papers["paper_id"] = papers["paper_id"].apply(lambda raw_id: paper_idx[raw_id])
for column in ("source", "target"):
    citations[column] = citations[column].apply(lambda raw_id: paper_idx[raw_id])
papers["subject"] = papers["subject"].apply(lambda subject: class_idx[subject])

print(citations)

print(papers)

Downloading data from https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz
      target  source
0          0      21
1          0     905
2          0     906
3          0    1909
4          0    1940
...      ...     ...
5424    1873     328
5425    1873    1876
5426    1874    2586
5427    1876    1874
5428    1897    2707

[5429 rows x 2 columns]
      paper_id  term_0  term_1  term_2  term_3  term_4  term_5  term_6  \
0          462       0       0       0       0       0       0       0   
1         1911       0       0       0       0       0       0       0   
2         2002       0       0       0       0       0       0       0   
3          248       0       0       0       0       0       0       0   
4          519       0       0       0       0       0       0       0   
...        ...     ...     ...     ...     ...     ...     ...     ...   
2703      2370       0       0       0       0       0       0       0   
2704      2371       0       0       0       0       0   

In [6]:
# Shuffle the row positions of `papers` (unseeded, so the split differs run to
# run) and cut the shuffled order in half: first half train, second half test.
random_indices = np.random.permutation(range(papers.shape[0]))
train_positions, test_positions = np.split(random_indices, [len(random_indices) // 2])
train_data = papers.iloc[train_positions]
test_data = papers.iloc[test_positions]

In [7]:
# Node ids used later to pick rows out of the full-graph model output.
train_indices, test_indices = (
    split["paper_id"].to_numpy() for split in (train_data, test_data)
)
# Ground-truth subject label for each of those node ids.
train_labels, test_labels = (
    split["subject"].to_numpy() for split in (train_data, test_data)
)

# Graph tensors: (target, source) edge pairs, and node features in paper_id
# order with the paper_id (first) and subject (last) columns dropped.
edges = tf.convert_to_tensor(citations[["target", "source"]])
node_states = tf.convert_to_tensor(
    papers.sort_values("paper_id").iloc[:, 1:-1]
)

# Report the graph's tensor shapes.
print("Edges shape:\t\t", edges.shape)
print("Node features shape:", node_states.shape)

Edges shape:		 (5429, 2)
Node features shape: (2708, 1433)


In [8]:
# Model / training hyper-parameters (LEARNING_RATE and MOMENTUM repeat the
# values set in the earlier optimizer cell).
HIDDEN_UNITS = 100
NUM_HEADS = 8
NUM_LAYERS = 3
OUTPUT_DIM = len(class_values)

NUM_EPOCHS = 100
BATCH_SIZE = 256
VALIDATION_SPLIT = 0.1
LEARNING_RATE = 0.3
MOMENTUM = 0.9

# from_logits=True because the model's final Dense layer has no activation.
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
optimizer = keras.optimizers.SGD(learning_rate=LEARNING_RATE, momentum=MOMENTUM)
accuracy_fn = keras.metrics.SparseCategoricalAccuracy(name="acc")
# "val_acc" is the validation-set counterpart of the metric named "acc" above.
early_stopping = keras.callbacks.EarlyStopping(
    monitor="val_acc",
    min_delta=1e-5,
    patience=5,
    restore_best_weights=True,
)

# NOTE(review): GraphAttentionNetwork is defined in a later cell ("In [4]");
# executing this file top-to-bottom as a plain script raises NameError here.
gat_model = GraphAttentionNetwork(
    node_states,
    edges,
    HIDDEN_UNITS,
    NUM_HEADS,
    NUM_LAYERS,
    OUTPUT_DIM,
)
gat_model.compile(optimizer=optimizer, loss=loss_fn, metrics=[accuracy_fn])

# x is a batch of node ids; the model gathers their rows from the full graph.
gat_model.fit(
    train_indices,
    train_labels,
    batch_size=BATCH_SIZE,
    epochs=NUM_EPOCHS,
    validation_split=VALIDATION_SPLIT,
    callbacks=[early_stopping],
    verbose=2,
)

_, test_accuracy = gat_model.evaluate(test_indices, test_labels, verbose=0)

print("--" * 38 + f"\nTest Accuracy {test_accuracy * 100:.1f}%")

Epoch 1/100
5/5 - 13s - loss: 1.8701 - acc: 0.2644 - val_loss: 1.6706 - val_acc: 0.3897 - 13s/epoch - 3s/step
Epoch 2/100
5/5 - 2s - loss: 1.3011 - acc: 0.5608 - val_loss: 1.1037 - val_acc: 0.6250 - 2s/epoch - 402ms/step
Epoch 3/100
5/5 - 2s - loss: 0.7499 - acc: 0.7775 - val_loss: 0.8104 - val_acc: 0.7426 - 2s/epoch - 422ms/step
Epoch 4/100
5/5 - 2s - loss: 0.4483 - acc: 0.8711 - val_loss: 0.7907 - val_acc: 0.7500 - 2s/epoch - 403ms/step
Epoch 5/100
5/5 - 2s - loss: 0.2606 - acc: 0.9228 - val_loss: 0.7132 - val_acc: 0.7868 - 2s/epoch - 379ms/step
Epoch 6/100
5/5 - 2s - loss: 0.1734 - acc: 0.9458 - val_loss: 0.7998 - val_acc: 0.7647 - 2s/epoch - 377ms/step
Epoch 7/100
5/5 - 2s - loss: 0.1009 - acc: 0.9787 - val_loss: 0.7493 - val_acc: 0.8015 - 2s/epoch - 384ms/step
Epoch 8/100
5/5 - 2s - loss: 0.0657 - acc: 0.9893 - val_loss: 0.7791 - val_acc: 0.7794 - 2s/epoch - 421ms/step
Epoch 9/100
5/5 - 2s - loss: 0.0414 - acc: 0.9967 - val_loss: 0.8241 - val_acc: 0.7868 - 2s/epoch - 395ms/step
Ep

# Code Repo

In [3]:
class GraphAttention(Layer):
    """Single-head graph attention layer.

    Called on ``[node_states, edges]``: ``node_states`` is a 2-D node feature
    matrix (one row per node), ``edges[:, 0]`` holds target node indices and
    ``edges[:, 1]`` holds source (neighbor) node indices. Returns one row per
    node in ``node_states`` with ``units`` columns; nodes without incoming
    edges get an all-zero row.

    Args:
        units: width of the linearly transformed node states and of the output.
        kernel_initializer: initializer for both weight matrices.
        kernel_regularizer: optional regularizer for both weight matrices.
        **kwargs: forwarded to ``keras.layers.Layer``.
    """

    def __init__(
        self,
        units,
        kernel_initializer="glorot_uniform",
        kernel_regularizer=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.units = units
        self.kernel_initializer = keras.initializers.get(kernel_initializer)
        self.kernel_regularizer = keras.regularizers.get(kernel_regularizer)

    def build(self, input_shape):
        # input_shape[0] is the node-state shape; its last dim is the feature width.
        self.kernel = self.add_weight(
            shape=(input_shape[0][-1], self.units),
            trainable=True,
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            name="kernel",
        )
        # Attention vector applied to the concatenation [z_target || z_source].
        self.kernel_attention = self.add_weight(
            shape=(self.units * 2, 1),
            trainable=True,
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            name="kernel_attention",
        )
        self.built = True

    def call(self, inputs):
        node_states, edges = inputs
        targets = edges[:, 0]
        sources = edges[:, 1]

        # Linearly transform node states: z = h @ W.
        node_states_transformed = tf.matmul(node_states, self.kernel)

        # (1) Pair-wise scores e_ij = LeakyReLU(a^T [z_i || z_j]), one per edge.
        # Gathering with the (E, 2) edge tensor yields (E, 2, units); flattening
        # the last two axes concatenates the target and source states.
        node_states_expanded = tf.gather(node_states_transformed, edges)
        node_states_expanded = tf.reshape(
            node_states_expanded, (tf.shape(edges)[0], -1)
        )
        attention_scores = tf.nn.leaky_relu(
            tf.matmul(node_states_expanded, self.kernel_attention)
        )
        attention_scores = tf.squeeze(attention_scores, -1)

        # (2) Softmax over each target's incoming edges. Clipping to [-2, 2]
        # bounds exp() to [e^-2, e^2].
        attention_scores = tf.math.exp(tf.clip_by_value(attention_scores, -2, 2))
        attention_scores_sum = tf.math.unsorted_segment_sum(
            data=attention_scores,
            segment_ids=targets,
            num_segments=tf.shape(node_states)[0],
        )
        # Look up each edge's own target denominator. (Expanding the sums with
        # tf.repeat + bincount only lines up when edges are sorted by target.)
        attention_scores_norm = attention_scores / tf.gather(
            attention_scores_sum, targets
        )

        # (3) Weight neighbor states by their attention and sum them per target.
        node_states_neighbors = tf.gather(node_states_transformed, sources)
        return tf.math.unsorted_segment_sum(
            data=node_states_neighbors * attention_scores_norm[:, tf.newaxis],
            segment_ids=targets,
            num_segments=tf.shape(node_states)[0],
        )


class MultiHeadGraphAttention(Layer):
    """Run several independent GraphAttention heads and merge their outputs.

    Args:
        units: output width of each individual head.
        num_heads: number of GraphAttention heads, each with its own weights.
        merge_type: ``"concat"`` concatenates the head outputs along the last
            axis; any other value averages them.
        **kwargs: forwarded to ``keras.layers.Layer``.

    The merged result is passed through ReLU before being returned.
    """

    def __init__(self, units, num_heads=8, merge_type="concat", **kwargs):
        super().__init__(**kwargs)
        self.num_heads = num_heads
        self.merge_type = merge_type
        self.attention_layers = [
            GraphAttention(units) for _head in range(num_heads)
        ]

    def call(self, inputs):
        node_states, edges = inputs

        head_outputs = [head([node_states, edges]) for head in self.attention_layers]

        if self.merge_type != "concat":
            # Element-wise mean across heads.
            merged = tf.reduce_mean(tf.stack(head_outputs, axis=-1), axis=-1)
        else:
            merged = tf.concat(head_outputs, axis=-1)

        return tf.nn.relu(merged)

In [4]:
class GraphAttentionNetwork(Model):
    """Full-graph GAT classifier trained and evaluated on node-index batches.

    The whole graph (``node_states``, ``edges``) is stored on the model and
    run through every forward pass. The ``x`` given to ``fit``/``evaluate``/
    ``predict`` is a batch of node indices; their rows are gathered out of the
    full-graph output.

    Args:
        node_states: node feature tensor for the whole graph.
        edges: edge index tensor, passed unchanged to every attention layer.
        hidden_units: per-head width of each MultiHeadGraphAttention layer.
        num_heads: number of attention heads per layer.
        num_layers: number of stacked MultiHeadGraphAttention layers.
        output_dim: number of output logits per node.
        **kwargs: forwarded to ``keras.Model``.
    """

    def __init__(
        self,
        node_states,
        edges,
        hidden_units,
        num_heads,
        num_layers,
        output_dim,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.node_states = node_states
        self.edges = edges
        self.preprocess = Dense(hidden_units * num_heads, activation="relu")
        self.attention_layers = [
            MultiHeadGraphAttention(hidden_units, num_heads) for _ in range(num_layers)
        ]
        self.output_layer = Dense(output_dim)

    def call(self, inputs):
        node_states, edges = inputs
        x = self.preprocess(node_states)
        for attention_layer in self.attention_layers:
            # Residual connection: the preprocess Dense and the concatenated heads
            # both produce hidden_units * num_heads features.
            x = attention_layer([x, edges]) + x
        # Raw logits (no activation).
        return self.output_layer(x)

    def train_step(self, data):
        indices, labels = data

        with tf.GradientTape() as tape:
            # Forward pass over the full graph, then keep only the batch's nodes.
            outputs = self([self.node_states, self.edges])
            logits = tf.gather(outputs, indices)
            loss = self.compiled_loss(
                labels, logits, regularization_losses=self.losses
            )
        grads = tape.gradient(loss, self.trainable_weights)
        # Use the optimizer passed to compile(), not a module-level global.
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        self.compiled_metrics.update_state(labels, logits)

        return {m.name: m.result() for m in self.metrics}

    def predict_step(self, data):
        indices = data
        outputs = self([self.node_states, self.edges])
        # Class probabilities for the requested nodes.
        return tf.nn.softmax(tf.gather(outputs, indices))

    def test_step(self, data):
        indices, labels = data
        logits = tf.gather(self([self.node_states, self.edges]), indices)
        # The return value is unused; the call is kept so Keras' compiled loss
        # container records the loss reported by evaluate().
        self.compiled_loss(labels, logits, regularization_losses=self.losses)
        self.compiled_metrics.update_state(labels, logits)

        return {m.name: m.result() for m in self.metrics}