<a href="https://colab.research.google.com/github/AslantheAslan/Node-classification-by-several-methods/blob/main/Cora_Node_Classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import numpy as np
import pandas as pd

from google.colab import drive
drive.mount('/content/drive/')

Mounted at /content/drive/


In [2]:
a = np.genfromtxt(r'/content/drive/My Drive/Colab Notebooks/Data/cora_data/edges.csv', delimiter=',', dtype=int)
print(type(a))

<class 'numpy.ndarray'>


In [3]:
print(a.shape)

(5429, 2)


In [4]:
edges = list(map(tuple, a))

In [None]:
dt = np.dtype([('nodes', np.integer), ('classes', np.unicode_, 25)])

In [6]:
df = pd.read_csv("/content/drive/My Drive/Colab Notebooks/Data/cora_data/group-edges.csv")
df["Class"] = df["Class"].map({"Rule_Learning": 0, "Neural_Networks": 1, "Theory": 2, "Case_Based": 3, "Probabilistic_Methods": 4, "Genetic_Algorithms": 5, "Reinforcement_Learning": 6})
df['Class'].value_counts()


1    818
4    426
5    418
2    351
3    298
6    217
0    180
Name: Class, dtype: int64

In [7]:
b = df.to_numpy(dtype=int)
b

array([[1000012,       0],
       [ 100197,       1],
       [ 100701,       3],
       ...,
       [  99023,       1],
       [  99025,       1],
       [  99030,       1]])

In [8]:
classes = list(map(tuple, b))
len(classes)

2708

In [9]:
## Enumerating the data samples once again since we need to assign nodes numbers from 0 to 2707 adjacently.

for i in range(len(b)):
  k = b[i][0]
  for j in range(len(a)):
    for m in range(2):
      if a[j][m] == k:
        a[j][m] = i + 2000000
  b[i][0] = i

In [10]:
a1 = a
for i in range(len(a)):
  for j in range(2):
    a1[i][j] = a[i][j] - 2000000

In [11]:
edges = list(map(tuple, a1))
classes = list(map(tuple, b))

In [None]:
edges

In [None]:
classes

In [None]:
!pip install git+git://github.com/deepmind/jraph.git
!pip install flax
!pip install dm-haiku

In [15]:
%matplotlib inline
import functools
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
import jax.tree_util as tree
import jraph
import flax
import haiku as hk
import optax
import numpy as onp
import networkx as nx
from typing import Tuple

In [16]:
def get_cora_dataset() -> jraph.GraphsTuple:
  """Returns GraphsTuple representing cora dataset."""
  social_graph = edges
  # Add reverse edges.
  social_graph += [(edge[1], edge[0]) for edge in social_graph]
  n_club_members = 2708

  return jraph.GraphsTuple(
      n_node=jnp.asarray([n_club_members]),
      n_edge=jnp.asarray([len(social_graph)]),
      # One-hot encoding for nodes, i.e. argmax(nodes) = node index.
      nodes=jnp.eye(n_club_members),
      # No edge features.
      edges=None,
      globals=None,
      senders=jnp.asarray([edge[0] for edge in social_graph]),
      receivers=jnp.asarray([edge[1] for edge in social_graph]))

def get_ground_truth_assignments_for_cora_dataset() -> jnp.ndarray:
  """Returns ground truth assignments for cora dataset."""
  return jnp.asarray([element[1] for element in classes])

In [17]:
graph = get_cora_dataset()



In [18]:
print(f'Number of nodes: {graph.n_node[0]}')
print(f'Number of edges: {graph.n_edge[0]}')

Number of nodes: 2708
Number of edges: 10858


In [19]:
def convert_jraph_to_networkx_graph(jraph_graph):
  nodes, edges, receivers, senders, _, _, _ = jraph_graph
  nx_graph = nx.DiGraph()
  if nodes is None:
    for n in range(jraph_graph.n_node[0]):
      nx_graph.add_node(n)
  else:
    for n in range(jraph_graph.n_node[0]):
      nx_graph.add_node(n, node_feature=nodes[n])
  if edges is None:
    for e in range(jraph_graph.n_edge[0]):
      nx_graph.add_edge(int(senders[e]), int(receivers[e]))
  else:
    for e in range(jraph_graph.n_edge[0]):
      nx_graph.add_edge(int(senders[e]), int(receivers[e]), edge_feature=edges[e])
  return nx_graph

def draw_jraph_graph_structure(jraph_graph: jraph.GraphsTuple):
  nx_graph = convert_jraph_to_networkx_graph(jraph_graph)
  pos = nx.spring_layout(nx_graph)
  nx.draw(nx_graph, pos=pos, with_labels = True, node_size=500, font_color='yellow')

In [None]:
"""
# Only make it work when you need to visualize the whole graph. 
# Note that it costs too much of computation time on Google Colab.

nx_graph = convert_jraph_to_networkx_graph(graph)
pos = nx.random_layout(nx_graph)
plt.figure(figsize=(12, 12))

nx.draw(nx_graph, pos=pos, with_labels = True, node_size=100, font_color='yellow')

"""

In [21]:
"""
Here, the necessary functions to implement a GCN were defined
"""

class MLP(hk.Module):
  def __init__(self, features: jnp.ndarray):
    super().__init__()
    self.features = features

  def __call__(self, x: jnp.ndarray):
    layers = []
    for feat in self.features[:-1]:
      layers.append(hk.Linear(feat))
      layers.append(jax.nn.relu)
    layers.append(hk.Linear(self.features[-1]))

    mlp = hk.Sequential(layers)
    return mlp(x)

def apply_simplified_gcn(graph: jraph.GraphsTuple):
  # Unpack GraphsTuple
  nodes, _, receivers, senders, _, _, _ = graph

  # 1. Update node features
  # For simplicity, we will first use an identify function here, and replace it
  # with a trainable MLP block later.
  update_node_fn = lambda nodes: nodes
  nodes = update_node_fn(nodes)

  # 2. Aggregate node features over nodes in neighborhood
  # Equivalent to jnp.sum(n_node), but jittable
  total_num_nodes = tree.tree_leaves(nodes)[0].shape[0]
  aggregate_nodes_fn = jax.ops.segment_sum

  # Compute new node features by aggregating messages from neighboring nodes
  nodes = tree.tree_map(lambda x: aggregate_nodes_fn(x[senders], receivers,
                                        total_num_nodes), nodes)
  out_graph = graph._replace(nodes=nodes)
  return out_graph

def add_self_edges_fn(receivers, senders, total_num_nodes):
  """Adds self edges. Assumes self edges are not in the graph yet."""
  receivers = jnp.concatenate((receivers, jnp.arange(total_num_nodes)), axis=0)
  senders = jnp.concatenate((senders, jnp.arange(total_num_nodes)), axis=0)
  return receivers, senders

# Adapted from https://github.com/deepmind/jraph/blob/master/jraph/_src/models.py#L506
def GraphConvolution(
    update_node_fn,
    aggregate_nodes_fn=jax.ops.segment_sum,
    add_self_edges: bool = False,
    symmetric_normalization: bool = True):
  """Returns a method that applies a Graph Convolution layer.
  Graph Convolutional layer as in https://arxiv.org/abs/1609.02907,
  NOTE: This implementation does not add an activation after aggregation.
  If you are stacking layers, you may want to add an activation between
  each layer.
  Args:
    update_node_fn: function used to update the nodes. In the paper a single
      layer MLP is used.
    aggregate_nodes_fn: function used to aggregates the sender nodes.
    add_self_edges: whether to add self edges to nodes in the graph as in the
      paper definition of GCN. Defaults to False.
    symmetric_normalization: whether to use symmetric normalization. Defaults
      to True.
  Returns:
    A method that applies a Graph Convolution layer.
  """
  def _ApplyGCN(graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
    """Applies a Graph Convolution layer."""
    nodes, _, receivers, senders, _, _, _ = graph

    # First pass nodes through the node updater.
    nodes = update_node_fn(nodes)
    # Equivalent to jnp.sum(n_node), but jittable
    total_num_nodes = tree.tree_leaves(nodes)[0].shape[0]
    if add_self_edges:
      # We add self edges to the senders and receivers so that each node
      # includes itself in aggregation.
      # In principle, a `GraphsTuple` should partition by n_edge, but in
      # this case it is not required since a GCN is agnostic to whether
      # the `GraphsTuple` is a batch of graphs or a single large graph.
      conv_receivers, conv_senders = add_self_edges_fn(receivers, senders, total_num_nodes)
    else:
      conv_senders = senders
      conv_receivers = receivers

    # pylint: disable=g-long-lambda
    if symmetric_normalization:
      # Calculate the normalization values.
      count_edges = lambda x: jax.ops.segment_sum(
          jnp.ones_like(conv_senders), x, total_num_nodes)
      sender_degree = count_edges(conv_senders)
      receiver_degree = count_edges(conv_receivers)

      # Pre normalize by sqrt sender degree.
      # Avoid dividing by 0 by taking maximum of (degree, 1).
      nodes = tree.tree_map(
          lambda x: x * jax.lax.rsqrt(jnp.maximum(sender_degree, 1.0))[:, None],
          nodes,
      )
      # Aggregate the pre-normalized nodes.
      nodes = tree.tree_map(
          lambda x: aggregate_nodes_fn(x[conv_senders], conv_receivers,
                                       total_num_nodes), nodes)
      # Post normalize by sqrt receiver degree.
      # Avoid dividing by 0 by taking maximum of (degree, 1).
      nodes = tree.tree_map(
          lambda x:
          (x * jax.lax.rsqrt(jnp.maximum(receiver_degree, 1.0))[:, None]),
          nodes,
      )
    else:
      nodes = tree.tree_map(
          lambda x: aggregate_nodes_fn(x[conv_senders], conv_receivers,
                                       total_num_nodes), nodes)
    # pylint: enable=g-long-lambda
    return graph._replace(nodes=nodes)

  return _ApplyGCN

def gcn(graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
  """Defines a graph neural network with 3 GCN layers.
  Args:
    graph: GraphsTuple the network processes.

  Returns:
    output graph with updated node values.
  """
  gn = GraphConvolution(
      update_node_fn=lambda n: jax.nn.relu(hk.Linear(64)(n)),
      add_self_edges=True)
  graph = gn(graph)

  gn = GraphConvolution(
      update_node_fn=lambda n: jax.nn.relu(hk.Linear(32)(n)),
      add_self_edges=True)
  graph = gn(graph)

  gn = GraphConvolution(
      update_node_fn=hk.Linear(7))
  graph = gn(graph)
  return graph

def gcn_definition(graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
  """Defines a GCN for the Cora dataset task.
  Args:
    graph: GraphsTuple the network processes.

  Returns:
    output graph with updated node values.
  """
  gn = GraphConvolution(
      update_node_fn=lambda n: jax.nn.relu(hk.Linear(64)(n)),
      add_self_edges=True)
  graph = gn(graph)

  gn = GraphConvolution(
      update_node_fn=hk.Linear(7)) # output dim is 7 because we have 7 output classes.
  graph = gn(graph)
  return graph

In [22]:
def optimize_cora_dataset(network, num_steps: int):
  """Solves the Cora dataset problem by optimizing."""
  cora_dataset = get_cora_dataset()
  labels = get_ground_truth_assignments_for_cora_dataset()
  params = network.init(jax.random.PRNGKey(42), cora_dataset)

  @jax.jit
  def predict(params):
    decoded_graph = network.apply(params, cora_dataset)
    return jnp.argmax(decoded_graph.nodes, axis=1)

  @jax.jit
  def prediction_loss(params):
    decoded_graph = network.apply(params, cora_dataset)
    # We interpret the decoded nodes as a pair of logits for each node.
    log_prob = jax.nn.log_softmax(decoded_graph.nodes)
    # Here I have given some aprioric knowledge to the training set. I have
    # used 2166 randomly shuffled samples to predict whole labels in the dataset
    sonuc = 0
    iterate = np.arange(2708)
    np.random.shuffle(iterate)
    for i in iterate[:2166]:
      j = classes[i][1]
      sonuc += log_prob[i,j]
    return -(sonuc)

  opt_init, opt_update = optax.adam(1e-2)
  opt_state = opt_init(params)

  @jax.jit
  def update(params, opt_state):
    g = jax.grad(prediction_loss)(params)
    updates, opt_state = opt_update(g, opt_state)
    return optax.apply_updates(params, updates), opt_state

  @jax.jit
  def accuracy(params):
    decoded_graph = network.apply(params, cora_dataset)
    return jnp.mean(jnp.argmax(decoded_graph.nodes, axis=1) == labels)

  for step in range(num_steps):
    print(f"step {step} accuracy {accuracy(params).item():.4f}")
    params, opt_state = update(params, opt_state)

  return predict(params)

In [23]:
network = hk.without_apply_rng(hk.transform(gcn_definition))
result = optimize_cora_dataset(network, num_steps=500)

# accuracy : %96.38 for 2166 samples after 500 epochs


step 0 accuracy 0.1208
step 1 accuracy 0.5853
step 2 accuracy 0.5853
step 3 accuracy 0.5842
step 4 accuracy 0.5835
step 5 accuracy 0.5831
step 6 accuracy 0.5890
step 7 accuracy 0.6126
step 8 accuracy 0.6536
step 9 accuracy 0.6983
step 10 accuracy 0.7194
step 11 accuracy 0.7397
step 12 accuracy 0.7703
step 13 accuracy 0.7973
step 14 accuracy 0.8257
step 15 accuracy 0.8497
step 16 accuracy 0.8652
step 17 accuracy 0.8733
step 18 accuracy 0.8818
step 19 accuracy 0.8914
step 20 accuracy 0.8970
step 21 accuracy 0.9014
step 22 accuracy 0.9036
step 23 accuracy 0.9055
step 24 accuracy 0.9077
step 25 accuracy 0.9117
step 26 accuracy 0.9129
step 27 accuracy 0.9140
step 28 accuracy 0.9165
step 29 accuracy 0.9191
step 30 accuracy 0.9180
step 31 accuracy 0.9191
step 32 accuracy 0.9202
step 33 accuracy 0.9217
step 34 accuracy 0.9221
step 35 accuracy 0.9236
step 36 accuracy 0.9247
step 37 accuracy 0.9261
step 38 accuracy 0.9273
step 39 accuracy 0.9306
step 40 accuracy 0.9335
step 41 accuracy 0.9346
st

In [24]:
######### GAT Implementation #########

# GAT implementation adapted from https://github.com/deepmind/jraph/blob/master/jraph/_src/models.py#L442.
def GAT(attention_query_fn,
        attention_logit_fn,
        node_update_fn=None,
        add_self_edges=True):
  """Returns a method that applies a Graph Attention Network layer.
  Graph Attention message passing as described in
  https://arxiv.org/pdf/1710.10903.pdf. This model expects node features as a
  jnp.array, may use edge features for computing attention weights, and
  ignore global features. It does not support nests.
  Args:
    attention_query_fn: function that generates attention queries
      from sender node features.
    attention_logit_fn: function that converts attention queries into logits for
      softmax attention.
    node_update_fn: function that updates the aggregated messages. If None,
      will apply leaky relu and concatenate (if using multi-head attention).
  Returns:
    A function that applies a Graph Attention layer.
  """
  # pylint: disable=g-long-lambda
  if node_update_fn is None:
    # By default, apply the leaky relu and then concatenate the heads on the
    # feature axis.
    node_update_fn = lambda x: jnp.reshape(
        jax.nn.leaky_relu(x), (x.shape[0], -1))

  def _ApplyGAT(graph):
    """Applies a Graph Attention layer."""
    nodes, edges, receivers, senders, _, _, _ = graph
    # Equivalent to the sum of n_node, but statically known.
    try:
      sum_n_node = nodes.shape[0]
    except IndexError:
      raise IndexError('GAT requires node features')

    # Pass nodes through the attention query function to transform
    # node features, e.g. with an MLP.
    nodes = attention_query_fn(nodes)

    total_num_nodes = tree.tree_leaves(nodes)[0].shape[0]
    if add_self_edges:
      # We add self edges to the senders and receivers so that each node
      # includes itself in aggregation.
      receivers, senders = add_self_edges_fn(receivers, senders, total_num_nodes)

    # We compute the softmax logits using a function that takes the
    # embedded sender and receiver attributes.
    sent_attributes = nodes[senders]
    received_attributes = nodes[receivers]
    att_softmax_logits = attention_logit_fn(
        sent_attributes, received_attributes, edges)

    # Compute the attention softmax weights on the entire tree.
    att_weights = jraph.segment_softmax(att_softmax_logits, segment_ids=receivers,
                                    num_segments=sum_n_node)

    # Apply attention weights.
    messages = sent_attributes * att_weights
    # Aggregate messages to nodes.
    nodes = jax.ops.segment_sum(messages, receivers, num_segments=sum_n_node)

    # Apply an update function to the aggregated messages.
    nodes = node_update_fn(nodes)

    return graph._replace(nodes=nodes)
  # pylint: enable=g-long-lambda
  return _ApplyGAT

In [25]:
def attention_logit_fn(sender_attr, receiver_attr, edges):
  del edges
  x = jnp.concatenate((sender_attr, receiver_attr), axis=1)
  return hk.Linear(1)(x)

gat_layer = GAT(
    attention_query_fn=lambda n: hk.Linear(7)(n),  # Applies W to the node features
    attention_logit_fn=attention_logit_fn,
    node_update_fn=None,
    add_self_edges=True,
)

In [26]:
def gat_definition(graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
  """Defines a GAT network for the Cora dataset node classification task.
  Args:
    graph: GraphsTuple the network processes.

  Returns:
    output graph with updated node values.
  """
  def _attention_logit_fn( sender_attr, receiver_attr, edges):
    del edges
    x = jnp.concatenate((sender_attr, receiver_attr), axis=1)
    return hk.Linear(1)(x)

  gn = GAT(
    attention_query_fn=lambda n: hk.Linear(64)(n),
    attention_logit_fn=_attention_logit_fn,
    node_update_fn=None,
    add_self_edges=True)
  graph = gn(graph)

  gn = GAT(
    attention_query_fn=lambda n: hk.Linear(32)(n),
    attention_logit_fn=_attention_logit_fn,
    node_update_fn=hk.Linear(7),
    add_self_edges=True)
  graph = gn(graph)
  return graph

In [27]:
network = hk.without_apply_rng(hk.transform(gat_definition))
result = optimize_cora_dataset(network, num_steps=500)
# best accuracy : %96.94 for 2166 samples after 500 epochs

step 0 accuracy 0.1030
step 1 accuracy 0.3021
step 2 accuracy 0.3021
step 3 accuracy 0.3021
step 4 accuracy 0.3021
step 5 accuracy 0.3021
step 6 accuracy 0.3021
step 7 accuracy 0.3021
step 8 accuracy 0.3220
step 9 accuracy 0.5391
step 10 accuracy 0.6547
step 11 accuracy 0.8058
step 12 accuracy 0.8386
step 13 accuracy 0.8578
step 14 accuracy 0.9007
step 15 accuracy 0.9062
step 16 accuracy 0.9129
step 17 accuracy 0.9147
step 18 accuracy 0.9151
step 19 accuracy 0.9213
step 20 accuracy 0.9243
step 21 accuracy 0.9280
step 22 accuracy 0.9298
step 23 accuracy 0.9346
step 24 accuracy 0.9361
step 25 accuracy 0.9387
step 26 accuracy 0.9402
step 27 accuracy 0.9428
step 28 accuracy 0.9450
step 29 accuracy 0.9479
step 30 accuracy 0.9494
step 31 accuracy 0.9494
step 32 accuracy 0.9542
step 33 accuracy 0.9583
step 34 accuracy 0.9579
step 35 accuracy 0.9583
step 36 accuracy 0.9579
step 37 accuracy 0.9594
step 38 accuracy 0.9601
step 39 accuracy 0.9631
step 40 accuracy 0.9627
step 41 accuracy 0.9631
st

In [28]:
########## Harmonic Function ##########

# Still under construction. Previously got the accuracy as %100 for Karate Dataset. Now trying to implement it for cora dataset.

def cora_graph():
  G = nx.Graph()
  len(edges)
  for i in range(len(edges)):
    G.add_edge(edges[i][0],edges[i][1])
  return G

In [29]:
G = cora_graph()

In [None]:
def predicted(G):
  split = np.arange(2708)
  np.random.shuffle(split)
  for i in split[:100]:
    G.nodes[i]["label"] = f'{classes[i][1]}'
  predicted = nx.node_classification.harmonic_function(G, max_iter=30)
  #print(predicted)
  return predicted

epochs = 100
for i in range(epochs):
  predicted(G)


#print(len(predicted(G)))
#print(len(classes))
print(G.nodes)

def resulted(G):

  node_array = np.array(G.nodes)
  prediction = np.array(predicted(G))
  resulted = [None] * len(node_array)
  resulted = np.array(resulted)

  for i in range(len(node_array)):
    resulted[node_array[i]] = prediction[i]

  return resulted

epochs = 100
for i in range(epochs):
  predicted(G)



In [None]:
G.nodes(data=True)

In [32]:
real = np.asarray(classes)
real = real[:,1]
real = list(map(str,real))

predicted = resulted(G)

In [33]:
correct = 0
for i in range(len(real)):
        if real[i] == predicted[i]:
                correct += 1
miss = len(real)-correct
accuracy = correct/len(real)
print("Accuracy is %f " %accuracy)
print("Hits are %d and misses are %d" %(correct,miss))

Accuracy is 1.000000 
Hits are 2708 and misses are 0


In [34]:
predicted

array(['0', '1', '3', ..., '1', '1', '1'], dtype=object)