<a href="https://colab.research.google.com/github/Lee-Gunju/AI-paper-code-review-for-personal-project/blob/master/danielegrattarola_keras_gat_GAT_github.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Clone keras-gat into a fixed absolute location and skip the clone if it is
# already there, so the cell can be re-run safely (including after the notebook
# has already changed its working directory into the repository).
# TODO: pin a specific commit for reproducibility.
![ -d /content/keras-gat ] || git clone https://github.com/danielegrattarola/keras-gat.git /content/keras-gat

Cloning into 'keras-gat'...
remote: Enumerating objects: 225, done.[K
remote: Counting objects: 100% (7/7), done.[K
remote: Compressing objects: 100% (7/7), done.[K
remote: Total 225 (delta 2), reused 1 (delta 0), pack-reused 218[K
Receiving objects: 100% (225/225), 5.76 MiB | 26.00 MiB/s, done.
Resolving deltas: 100% (97/97), done.


In [None]:
# Work from the repository root (presumably needed so `from keras_gat import ...`
# resolves). The explicit `%cd` magic works even when IPython automagic is off
# or a Python variable named `cd` exists.
%cd /content/keras-gat

/content/keras-gat


In [None]:
from __future__ import division

import os

import numpy as np
from keras.callbacks import EarlyStopping, TensorBoard, ModelCheckpoint
from keras.layers import Input, Dropout
from keras.models import Model
from keras.optimizers import Adam
from keras.regularizers import l2

from keras_gat import GraphAttention
from keras_gat.utils import load_data, preprocess_features

# ---- Data ---------------------------------------------------------------
# load_data('cora') yields the adjacency matrix A, the node-feature matrix X,
# one label matrix per split and one split indicator per split (the idx_*
# values are later handed to Keras as per-node sample weights).
(A, X,
 Y_train, Y_val, Y_test,
 idx_train, idx_val, idx_test) = load_data('cora')

# ---- Sizes derived from the data ----------------------------------------
N, F = X.shape[0], X.shape[1]   # number of nodes, raw feature dimension
n_classes = Y_train.shape[1]    # label-matrix width = number of classes

# ---- Hyper-parameters ---------------------------------------------------
F_ = 8                # output size of the first GraphAttention layer (presumably per head)
n_attn_heads = 8      # attention heads in the first GAT layer
dropout_rate = 0.6    # dropout between and inside the GAT layers
l2_reg = 5e-4 / 2     # L2 factor; halved, presumably to match a 1/2*sum(w^2) weight-decay convention
learning_rate = 5e-3  # Adam step size
epochs = 10000        # upper bound on epochs; early stopping is meant to end training sooner
es_patience = 100     # epochs without improvement tolerated before early stopping

# ---- Preprocessing ------------------------------------------------------
X = preprocess_features(X)
# Add self-loops: put ones on the diagonal of the adjacency matrix.
A = A + np.eye(A.shape[0])

# Model definition (as per Section 3.3 of the paper)
# Two inputs: node features (F values per node) and adjacency rows (N values
# per node). Training and evaluation below feed the whole graph as one batch
# of N rows.
X_in = Input(shape=(F,))
A_in = Input(shape=(N,))

# Layer 1: dropout -> n_attn_heads attention heads (concatenated), ELU.
dropout1 = Dropout(dropout_rate)(X_in)
graph_attention_1 = GraphAttention(F_,
                                   attn_heads=n_attn_heads,
                                   attn_heads_reduction='concat',
                                   dropout_rate=dropout_rate,
                                   activation='elu',
                                   kernel_regularizer=l2(l2_reg),
                                   attn_kernel_regularizer=l2(l2_reg))([dropout1, A_in])
# Layer 2: dropout -> single attention head with a softmax over the classes.
dropout2 = Dropout(dropout_rate)(graph_attention_1)
graph_attention_2 = GraphAttention(n_classes,
                                   attn_heads=1,
                                   attn_heads_reduction='average',
                                   dropout_rate=dropout_rate,
                                   activation='softmax',
                                   kernel_regularizer=l2(l2_reg),
                                   attn_kernel_regularizer=l2(l2_reg))([dropout2, A_in])

# Build model
model = Model(inputs=[X_in, A_in], outputs=graph_attention_2)
# Use `learning_rate`, not the deprecated `lr`. Keras >= 2.11 optimizers drop
# `lr` with only a warning (silently training at the default 1e-3), and Keras 3
# rejects it outright. `learning_rate` is accepted since Keras 2.3 / tf.keras.
optimizer = Adam(learning_rate=learning_rate)
# Accuracy is registered as a *weighted* metric, so the per-node sample weights
# given to fit/evaluate (idx_* — presumably 0/1 split masks) restrict it to the
# nodes of the split being scored.
model.compile(optimizer=optimizer,
              loss='categorical_crossentropy',
              weighted_metrics=['acc'])
model.summary()

# Callbacks
# tf.keras (TF >= 2.2) logs a weighted metric under its plain name
# ('acc' / 'val_acc'). The 'weighted_' prefix is only added when an unweighted
# metric with the same name also exists. The old standalone-Keras name
# 'val_weighted_acc' is therefore never logged, so EarlyStopping never fired
# (the recorded run was still training at epoch 4196/10000 despite patience 100)
# and ModelCheckpoint never wrote a checkpoint.
monitor_metric = 'val_acc'
checkpoint_path = 'logs/best_model.h5'
# Make sure the checkpoint directory exists before the first save.
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

es_callback = EarlyStopping(monitor=monitor_metric, mode='max', patience=es_patience)
# No `batch_size` here: tf.keras' TensorBoard ignores it with a warning and
# Keras 3 rejects it. Training is full-batch anyway.
tb_callback = TensorBoard()
mc_callback = ModelCheckpoint(checkpoint_path,
                              monitor=monitor_metric,
                              mode='max',
                              save_best_only=True,
                              save_weights_only=True)

# Train model
# The whole graph is a single batch. The idx_* arrays serve as per-node sample
# weights so that only nodes of the respective split drive loss and metrics
# (presumably 0/1 masks — confirm against keras_gat.utils.load_data).
validation_data = ([X, A], Y_val, idx_val)
model.fit([X, A],
          Y_train,
          sample_weight=idx_train,
          epochs=epochs,
          batch_size=N,
          validation_data=validation_data,
          shuffle=False,  # Shuffling data means shuffling the whole graph
          callbacks=[es_callback, tb_callback, mc_callback])

# Load best model: evaluate the weights with the best validation accuracy,
# not whatever the final (post-patience) epoch produced.
model.load_weights(checkpoint_path)

# Evaluate model
eval_results = model.evaluate([X, A],
                              Y_test,
                              sample_weight=idx_test,
                              batch_size=N,
                              verbose=0)
print('Done.\n'
      'Test loss: {}\n'
      'Test accuracy: {}'.format(*eval_results))

[1;30;43m스트리밍 출력 내용이 길어서 마지막 5000줄이 삭제되었습니다.[0m
Epoch 4141/10000
Epoch 4142/10000
Epoch 4143/10000
Epoch 4144/10000
Epoch 4145/10000
Epoch 4146/10000
Epoch 4147/10000
Epoch 4148/10000
Epoch 4149/10000
Epoch 4150/10000
Epoch 4151/10000
Epoch 4152/10000
Epoch 4153/10000
Epoch 4154/10000
Epoch 4155/10000
Epoch 4156/10000
Epoch 4157/10000
Epoch 4158/10000
Epoch 4159/10000
Epoch 4160/10000
Epoch 4161/10000
Epoch 4162/10000
Epoch 4163/10000
Epoch 4164/10000
Epoch 4165/10000
Epoch 4166/10000
Epoch 4167/10000
Epoch 4168/10000
Epoch 4169/10000
Epoch 4170/10000
Epoch 4171/10000
Epoch 4172/10000
Epoch 4173/10000
Epoch 4174/10000
Epoch 4175/10000
Epoch 4176/10000
Epoch 4177/10000
Epoch 4178/10000
Epoch 4179/10000
Epoch 4180/10000
Epoch 4181/10000
Epoch 4182/10000
Epoch 4183/10000
Epoch 4184/10000
Epoch 4185/10000
Epoch 4186/10000
Epoch 4187/10000
Epoch 4188/10000
Epoch 4189/10000
Epoch 4190/10000
Epoch 4191/10000
Epoch 4192/10000
Epoch 4193/10000
Epoch 4194/10000
Epoch 4195/10000
Epoch 4196/1000

KeyboardInterrupt: ignored