In [102]:
from random import random

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
import pandas as pd
import os
import warnings


warnings.filterwarnings('ignore')



In [103]:
zip_file = keras.utils.get_file(
    fname="cora.tgz",
    origin="https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz",
    extract=True,
)

data_dir = os.path.join(os.path.dirname(zip_file), "cora")

citations = pd.read_csv(
    os.path.join(data_dir, "cora.cites"),
    sep="\t",
    header=None,
    names=["target", "source"],
)

papers = pd.read_csv(
    os.path.join(data_dir, "cora.content"),
    sep="\t",
    header=None,
    names=["paper_id"] + [f"term_{idx}" for idx in range(1433)] + ["subject"],
)

class_values = sorted(papers["subject"].unique())
class_idx = {name: id for id, name in enumerate(class_values)}
paper_idx = {name: idx for idx, name in enumerate(sorted(papers["paper_id"].unique()))}

papers["paper_id"] = papers["paper_id"].apply(lambda name: paper_idx[name])
citations["source"] = citations["source"].apply(lambda name: paper_idx[name])
citations["target"] = citations["target"].apply(lambda name: paper_idx[name])
papers["subject"] = papers["subject"].apply(lambda value: class_idx[value])

print(citations)

print(papers)


      target  source
0          0      21
1          0     905
2          0     906
3          0    1909
4          0    1940
...      ...     ...
5424    1873     328
5425    1873    1876
5426    1874    2586
5427    1876    1874
5428    1897    2707

[5429 rows x 2 columns]
      paper_id  term_0  term_1  term_2  term_3  term_4  term_5  term_6  \
0          462       0       0       0       0       0       0       0   
1         1911       0       0       0       0       0       0       0   
2         2002       0       0       0       0       0       0       0   
3          248       0       0       0       0       0       0       0   
4          519       0       0       0       0       0       0       0   
...        ...     ...     ...     ...     ...     ...     ...     ...   
2703      2370       0       0       0       0       0       0       0   
2704      2371       0       0       0       0       0       0       0   
2705      2372       0       0       0       0       0   

In [104]:
np.random.seed(2)
random_indices = np.random.permutation(range(papers.shape[0]))

train_data = papers.iloc[random_indices[: len(random_indices)//2]]
test_data = papers.iloc[random_indices[len(random_indices)//2:]]


In [105]:
train_indices = train_data['paper_id'].to_numpy()
test_indices = test_data['paper_id'].to_numpy()

train_labels = train_data['subject'].to_numpy()
test_labels = test_data['subject'].to_numpy()

edges = tf.convert_to_tensor(citations[['target', 'source']])
node_states = tf.convert_to_tensor(papers.sort_values('paper_id').iloc[:, 1: -1])


In [106]:
class GraphAttention(layers.Layer):
    def __init__(
        self,
        units,
        kernel_initializer="glorot_uniform",
        kernel_regularizer=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.units = units
        self.kernel_initializer = keras.initializers.get(kernel_initializer)
        self.kernel_regularizer = keras.regularizers.get(kernel_regularizer)

    def build(self, input_shape):

        self.kernel = self.add_weight(
            shape=(input_shape[0][-1], self.units),
            trainable=True,
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            name="kernel",
        )
        self.kernel_attention = self.add_weight(
            shape=(self.units * 2, 1),
            trainable=True,
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            name="kernel_attention",
        )
        self.built = True

    def call(self, inputs):
        node_states, edges = inputs

        # Linearly transform node states
        node_states_transformed = tf.matmul(node_states, self.kernel)

        # (1) Compute pair-wise attention scores
        node_states_expanded = tf.gather(node_states_transformed, edges)
        node_states_expanded = tf.reshape(
            node_states_expanded, (tf.shape(edges)[0], -1)
        )
        attention_scores = tf.nn.leaky_relu(
            tf.matmul(node_states_expanded, self.kernel_attention)
        )
        attention_scores = tf.squeeze(attention_scores, -1)

        # (2) Normalize attention scores
        attention_scores = tf.math.exp(tf.clip_by_value(attention_scores, -2, 2))


        attention_scores_sum = tf.math.unsorted_segment_sum(
            data=attention_scores,
            segment_ids=edges[:, 0],
            num_segments=tf.reduce_max(edges[:, 0]) + 1,
        )

        segment_ids = tf.cast(edges[:, 0], tf.int32)
        expanded_attention_scores_sum = tf.gather(attention_scores_sum, segment_ids)
        attention_scores_norm = attention_scores / expanded_attention_scores_sum


        # (3) Gather node states of neighbors, apply attention scores and aggregate
        node_states_neighbors = tf.gather(node_states_transformed, edges[:, 1])
        out = tf.math.unsorted_segment_sum(
            data=node_states_neighbors * attention_scores_norm[:, tf.newaxis],
            segment_ids=edges[:, 0],
            num_segments=tf.shape(node_states)[0],
        )
        return out




In [107]:
class MultiHeadGraphAttention(layers.Layer):
    def __init__(
        self,
        units,
        num_heads,
        merge_type='concat',
        **kwargs
    ):
        super().__init__(**kwargs)
        self.num_heads = num_heads
        self.merge_type = merge_type
        self.attention_layers = [GraphAttention(units) for _ in range(num_heads)]

    def call(self, inputs):
        features, indices = inputs
        outputs = [
            attention_layer([features, indices]) for attention_layer in self.attention_layers
        ]

        if self.merge_type == 'concat':
            outputs = tf.concat(outputs, axis=-1)
        else:
            outputs = tf.reduce_mean(tf.stack(outputs, axis=-1), axis=-1)

        return tf.nn.relu(outputs)



In [108]:
class GraphAttentionNetwork(keras.Model):
    def __init__(
            self,
            node_states,
            edges,
            hidden_units,
            num_heads,
            num_layers,
            output_dim,
            **kwargs
    ):
        super().__init__(**kwargs)
        self.node_states = node_states
        self.edges = edges
        self.preprocess = layers.Dense(hidden_units * num_heads, activation='relu')
        self.attention_layers = [
            MultiHeadGraphAttention(hidden_units, num_heads) for _ in range(num_layers)
        ]
        self.output_layer = layers.Dense(output_dim)

    def call(self, inputs):
        node_states, edges = inputs
        x = self.preprocess(node_states)
        for attention_layer in self.attention_layers:
            x = attention_layer([x, edges]) + x
        output = self.output_layer(x)
        return output

    def train_step(self, data):
        indices, labels = data

        with tf.GradientTape() as tape:
             outputs = self([self.node_states, self.edges])
             loss = self.compiled_loss(labels, tf.gather(outputs, indices))

        grads = tape.gradient(loss, self.trainable_weights)
        optimizer.apply_gradients(zip(grads, self.trainable_weights))
        self.compiled_metrics.update_state(labels, tf.gather(outputs, indices))

        return {m.name: m.result() for m in self.metrics}

    def predict_step(self, data):
        indices = data
        output = self([self.node_states, self.edges])
        probs = tf.nn.softmax(output, axis=-1)
        return tf.gather(probs, indices)

    def test_step(self, data):
        indices, labels = data
        outputs = self([self.node_states, self.edges])
        loss = self.compiled_loss(labels, tf.gather(outputs, indices))

        self.compiled_metrics.update_state(labels, tf.gather(outputs, indices))

        return {m.name: m.result() for m in self.metrics}


In [109]:
hidden_units = 128
num_heads = 8
num_layers = 3
output_dim = len(class_values)

num_epochs = 100
batch_size = 256
validation_split = 0.1
learning_rate = 3e-1
momentum = 0.9

In [110]:
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
optimizer = keras.optimizers.SGD(learning_rate, momentum=momentum)
accuracy_fn = keras.metrics.SparseCategoricalAccuracy(name="accuracy")
early_stopping = keras.callbacks.EarlyStopping(
    monitor="val_accuracy",
    min_delta=1e-5,
    patience=5,
    restore_best_weights=True
)

In [111]:
model = GraphAttentionNetwork(
    node_states,
    edges,
    hidden_units,
    num_heads,
    num_layers,
    output_dim,
)

In [112]:
model.compile(
    loss=loss_fn,
    optimizer=optimizer,
    metrics=[accuracy_fn],
)

In [113]:
train_indices.shape

(1354,)

In [114]:
train_labels.shape

(1354,)

In [115]:
model.fit(
    x=train_indices,
    y=train_labels,
    validation_split=validation_split,
    batch_size=batch_size,
    epochs=num_epochs,
    callbacks=[early_stopping],
    verbose=2,
)

Epoch 1/100
5/5 - 41s - 8s/step - accuracy: 0.3456 - loss: 0.1281 - val_loss: 0.0997
Epoch 2/100
5/5 - 0s - 55ms/step - accuracy: 0.5956 - loss: 0.0783 - val_loss: 0.0744
Epoch 3/100
5/5 - 0s - 30ms/step - accuracy: 0.7647 - loss: 0.0753 - val_loss: 0.0893
Epoch 4/100
5/5 - 0s - 30ms/step - accuracy: 0.8235 - loss: 0.1046 - val_loss: 0.1214
Epoch 5/100
5/5 - 0s - 30ms/step - accuracy: 0.7941 - loss: 0.1257 - val_loss: 0.1347
Epoch 6/100
5/5 - 0s - 29ms/step - accuracy: 0.8309 - loss: 0.1393 - val_loss: 0.1442
Epoch 7/100
5/5 - 0s - 30ms/step - accuracy: 0.8235 - loss: 0.1476 - val_loss: 0.1435
Epoch 8/100
5/5 - 0s - 29ms/step - accuracy: 0.8015 - loss: 0.1477 - val_loss: 0.1468
Epoch 9/100
5/5 - 0s - 29ms/step - accuracy: 0.8088 - loss: 0.1555 - val_loss: 0.1579
Epoch 10/100
5/5 - 0s - 29ms/step - accuracy: 0.8088 - loss: 0.1661 - val_loss: 0.1657
Epoch 11/100
5/5 - 0s - 29ms/step - accuracy: 0.8162 - loss: 0.1717 - val_loss: 0.1682
Epoch 12/100
5/5 - 0s - 29ms/step - accuracy: 0.8235 

<keras.src.callbacks.history.History at 0x7f8eb8110b00>

In [117]:
test_probs = model.predict(x=test_indices)
mapping = {v: k for (k, v) in class_idx.items()}

for i, (probs, labels) in enumerate(zip(test_probs[:10], test_labels[:10])):
    print(f'example {i+1}: {mapping[labels]}')
    for j, c in zip(probs, class_idx.keys()):
        print(f'probability of {c: <25} = {j*100:7.3f}%')
    print('---' * 20)

[1m43/43[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 11ms/step
example 1: Probabilistic_Methods
probability of Case_Based                =   0.290%
probability of Genetic_Algorithms        =   0.001%
probability of Neural_Networks           =  13.548%
probability of Probabilistic_Methods     =  86.046%
probability of Reinforcement_Learning    =   0.025%
probability of Rule_Learning             =   0.000%
probability of Theory                    =   0.089%
------------------------------------------------------------
example 2: Genetic_Algorithms
probability of Case_Based                =   0.000%
probability of Genetic_Algorithms        = 100.000%
probability of Neural_Networks           =   0.000%
probability of Probabilistic_Methods     =   0.000%
probability of Reinforcement_Learning    =   0.000%
probability of Rule_Learning             =   0.000%
probability of Theory                    =   0.000%
------------------------------------------------------------
example 3: Th