<a href="https://colab.research.google.com/github/L40S38/GAT_practice/blob/main/gat_node_classification_pytorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# GAT node classification with pytorch

## import packages

In [167]:
!pip install torchinfo

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split
import numpy as np
import os
import urllib.request
import tarfile
import pandas as pd



In [168]:
# 訓練に際して、可能であればGPU（cuda）を設定します。GPUが搭載されていない場合はCPUを使用します
device = "cuda:0" if torch.cuda.is_available() else "cpu"
print("Using {} device".format(device))

Using cuda:0 device


## download Cora dataset

In [169]:
!wget "https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz" -O cora.tgz
!tar -xvzf cora.tgz

--2023-08-05 18:01:25--  https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz
Resolving linqs-data.soe.ucsc.edu (linqs-data.soe.ucsc.edu)... 128.114.47.74
Connecting to linqs-data.soe.ucsc.edu (linqs-data.soe.ucsc.edu)|128.114.47.74|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 168052 (164K) [application/x-gzip]
Saving to: ‘cora.tgz’


2023-08-05 18:01:27 (247 KB/s) - ‘cora.tgz’ saved [168052/168052]

cora/
cora/README
cora/cora.cites
cora/cora.content


In [170]:
data_dir = "cora"

citations = pd.read_csv(
    os.path.join(data_dir, "cora.cites"),
    sep="\t",
    header=None,
    names=["target", "source"],
)

papers = pd.read_csv(
    os.path.join(data_dir, "cora.content"),
    sep="\t",
    header=None,
    names=["paper_id"] + [f"term_{idx}" for idx in range(1433)] + ["subject"],
)

In [171]:
# subjectのリナンバリング
class_idx = {name: id for id, name in enumerate(sorted(papers["subject"].unique()))}
# paperのリナンバリング
paper_idx = {name: idx for idx, name in enumerate(sorted(papers["paper_id"].unique()))}

In [172]:
papers["paper_id"] = papers["paper_id"].apply(lambda name: paper_idx[name])
citations["source"] = citations["source"].apply(lambda name: paper_idx[name])
citations["target"] = citations["target"].apply(lambda name: paper_idx[name])
papers["subject"] = papers["subject"].apply(lambda value: class_idx[value])

In [173]:
features = np.array(papers.iloc[:, 1:-1])
edges = np.array(citations[["target","source"]])
labels = np.array(papers["subject"])

In [174]:
n_features = features.shape[1]
n_classes = len(np.unique(labels))
n_nodes = features.shape[0]

In [175]:
features = torch.from_numpy(features).float()
edges = torch.from_numpy(edges)
labels = torch.from_numpy(labels)

In [176]:
indices = np.array([i for i in range(n_nodes)])

In [177]:
train_length = int(0.5*n_nodes)
boolean = [True] * train_length + [False] * (n_nodes-train_length)
boolean

[True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,
 True,

In [178]:
#データセットの分割
# train_features, test_features, train_labels, test_labels = train_test_split(features, labels, test_size=0.5, stratify=labels)
np.random.shuffle(boolean)
indices_bool = torch.from_numpy(np.array(boolean))
indices_bool

tensor([False,  True,  True,  ..., False, False,  True])

## Implement Graph Attention Layer

In [304]:
# Graph Attention Layer
class GraphAttention(nn.Module):
    def __init__(self, in_features, out_features, dropout=0.6, alpha=0.2):
        super(GraphAttention, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.dropout = dropout
        self.alpha = alpha

        self.W = nn.Parameter(torch.empty(size=(in_features, out_features)))
        self.a = nn.Parameter(torch.empty(size=(2*out_features, 1)))

        self.leakyrelu = nn.LeakyReLU(self.alpha)

        self.init_weights()

    def init_weights(self):
        torch.nn.init.xavier_uniform_(self.W.data, gain=1.414)
        torch.nn.init.xavier_uniform_(self.a.data, gain=1.414)

    def forward(self, inputs, edges):
        # 線形変換
        #print(f"inputs.size():{inputs.size()},self.W.size():{self.W.size()}")
        z = torch.mm(inputs,self.W)
        #print(z.shape)

        # e_ijの計算
        edges_new_axis = edges.reshape(edges.size(0),edges.size(1),-1)
        edges_expand = edges_new_axis.expand(edges.size(0),edges.size(1),z.size(1))
        z_new_axis = z.reshape(z.size(0),-1,z.size(1))
        z_expand = z_new_axis.expand(z.size(0),edges.size(1),z.size(1))
        features_previous_concat = torch.gather(z_expand,0,edges_expand)
        #print(f"features_previous_concat.size():{features_previous_concat.size()}")
        features_concat = features_previous_concat.reshape(edges.size(0),-1)
        #print(features_concat.shape)
        attention_score = self.leakyrelu(torch.mm(features_concat,self.a))
        #print(f"attention_score.size():{attention_score.size()}")

        # 正規化
        E = torch.stack([torch.where(edges[:,0]==i,1,0) for i in range(n_nodes)]).float()
        attention_score_sum_expand = torch.mm(torch.transpose(E,1,0),torch.mm(E,torch.exp(attention_score)))
        attention_score_norm = attention_score/attention_score_sum_expand
        #print(f"attention_score_noem.size():{attention_score_norm.size()}")

        # z_jの更新
        to_renew_z = attention_score_norm * features_previous_concat[:,1::2,:].squeeze() #.reshape(-1, out_features)
        #print(f"to_renew_z.size():{to_renew_z.size()}")
        D = torch.stack([torch.where(edges[:,1]==j,1,0) for j in range(n_nodes)]).float()
        #print(f"D.size():{D.size()}")
        out = z + torch.matmul(D,to_renew_z)
        #print(f"out.size() of {self.__class__.__name__}:{out.size()}")

        return out

In [305]:
# Multi-Head Graph Attention Layer
class MultiHeadGraphAttention(nn.Module):
    def __init__(self, in_features, n_hidden, n_layers, n_heads, merge_type="concat", dropout=0.6, alpha=0.2):
        super(MultiHeadGraphAttention, self).__init__()
        self.heads = nn.ModuleList([GraphAttention(in_features, n_hidden, dropout=dropout, alpha=alpha) for _ in range(n_layers)])
        self.merge_type = merge_type

    def forward(self, inputs, edges):
        head_outs = [head(inputs, edges) for head in self.heads]
        #print(f"head_outs.size() of {self.__class__.__name__}:{torch.tensor(head_outs).size()}")
        if self.merge_type == "concat":
            out = torch.cat(head_outs, dim=1)
        else:
            out =  torch.mean(torch.stack(head_outs), dim=0)
        #print(f"out.size() of {self.__class__.__name__}:{out.size()}")
        return F.relu(out)

In [306]:
# Graph Attention Network
class GAT(nn.Module):
    def __init__(self, n_features, n_classes, n_hidden, n_layers, n_heads, dropout=0.6, alpha=0.2):
        super(GAT, self).__init__()
        self.n_features = n_features
        self.n_classes = n_classes
        self.n_hidden = n_hidden
        self.n_heads = n_heads
        self.dropout = dropout
        self.alpha = alpha
        self.preprocess = nn.Linear(n_features, n_hidden * n_layers)
        self.relu = F.relu
        self.attentions = nn.ModuleList([MultiHeadGraphAttention(n_hidden * n_layers, n_hidden, n_layers, n_heads, dropout=dropout, alpha=alpha) for _ in range(n_heads)])
        #self.out_att = GraphAttention(n_hidden*n_heads, n_classes, dropout=dropout, alpha=alpha)
        self.output = nn.Linear(n_hidden * n_layers, n_classes)

    def forward(self, inputs, edges):
        x = self.preprocess(inputs)
        x = self.relu(x)
        for att in self.attentions:
            x = att(x, edges) + x
        x = F.dropout(x, self.dropout, training=self.training)
        x = self.output(x)
        return F.log_softmax(x, dim=1)

In [307]:
# モデルの定義
n_hidden = 100
n_heads = 8
n_layers = 3
model = GAT(n_features, n_classes, n_hidden, n_layers, n_heads)

## Model Training and Evaluating

In [320]:
# モデルの学習
def train_model(model, optimizer, criterion, indices_bool, data_length, features, labels, edges):
    model.train()
    features, labels, edges = features.to(device), labels.to(device), edges.to(device)
    #features, labels = features.to(device), labels.to(device)
    output = model(features, edges)
    loss = criterion(output[indices_bool], labels[indices_bool])
    loss.backward()
    optimizer.step()
    total_loss = loss.item() * features.size(0)
    with torch.no_grad():
        pred = output[indices_bool].argmax(dim=1, keepdim=True).squeeze()
        correct = (pred==labels[indices_bool]).sum().item()
        print(f"Train correct:{correct}/{data_length}")
    #optimizer.zero_grad()
    return total_loss / data_length, correct / data_length

# モデルの評価
def evaluate_model(model, criterion, indices_bool, data_length, features, labels, edges):
    model.eval()
    features, labels, edges = features.to(device), labels.to(device), edges.to(device)
    total_loss = 0.0
    correct = 0
    with torch.no_grad():
        #features, labels = features.to(device), labels.to(device)
        output = model(features, edges)
        loss = criterion(output[indices_bool], labels[indices_bool])
        total_loss = loss.item() * features.size(0)
        pred = output[indices_bool].argmax(dim=1, keepdim=True).squeeze()
        correct = (pred==labels[indices_bool]).sum().item()
        print(f"Test correct:{correct}/{data_length}")
    return total_loss / data_length, correct / data_length

In [321]:
# 損失関数と最適化手法の定義
criterion = nn.NLLLoss
optimizer = optim.Adam(model.parameters(), lr=0.005, weight_decay=5e-4)

In [322]:
# モデルの学習と評価
n_epochs = 30
model = model.to(device)

In [323]:
from torchinfo import summary
summary(model=model,depth=4)

Layer (type:depth-idx)                        Param #
GAT                                           --
├─Linear: 1-1                                 430,200
├─ModuleList: 1-2                             --
│    └─MultiHeadGraphAttention: 2-1           --
│    │    └─ModuleList: 3-1                   --
│    │    │    └─GraphAttention: 4-1          30,200
│    │    │    └─GraphAttention: 4-2          30,200
│    │    │    └─GraphAttention: 4-3          30,200
│    └─MultiHeadGraphAttention: 2-2           --
│    │    └─ModuleList: 3-2                   --
│    │    │    └─GraphAttention: 4-4          30,200
│    │    │    └─GraphAttention: 4-5          30,200
│    │    │    └─GraphAttention: 4-6          30,200
│    └─MultiHeadGraphAttention: 2-3           --
│    │    └─ModuleList: 3-3                   --
│    │    │    └─GraphAttention: 4-7          30,200
│    │    │    └─GraphAttention: 4-8          30,200
│    │    │    └─GraphAttention: 4-9          30,200
│    └─MultiHeadGraphAt

In [324]:
from torch.utils.tensorboard import SummaryWriter
import datetime
t_delta = datetime.timedelta(hours=9)
JST = datetime.timezone(t_delta, 'JST')
now = datetime.datetime.now(JST)
writer = SummaryWriter(log_dir=f"./logs/{now:%Y%m%d%H%M}")

for epoch in range(n_epochs):
    train_loss, train_accuracy  = train_model(model, optimizer, criterion, indices_bool, train_length, features, labels, edges)
    test_loss, test_accuracy = evaluate_model(model, criterion, (indices_bool==False), n_nodes-train_length, features, labels, edges)
    print(f"Epoch {epoch + 1}/{n_epochs}, Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.4f}, Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.4f}")
    writer.add_scalar("train/loss",train_loss,epoch)
    writer.add_scalar("train/accuracy",train_accuracy,epoch)
    writer.add_scalar("test/loss",test_loss,epoch)
    writer.add_scalar("test/accuracy",test_accuracy,epoch)

writer.close()

RuntimeError: ignored

In [None]:
from sklearn.metrics import confusion_matrix

def eval_check(model, device, indices_bool, features, labels, edges):
  features, labels, edges = features.to(device), labels.to(device), edges.to(device)
  with torch.no_grad():
    output = model(features, edges)
    print(output[indices_bool==False].size(),labels[indices_bool==False].size())

    pred = output[indices_bool==False].argmax(dim=1, keepdim=True).squeeze()
    correct = (pred==labels[indices_bool==False]).sum().item()
    print(pred.size())
    print(correct)
    print(confusion_matrix(labels[indices_bool==False].cpu(), pred.cpu()))

eval_check(model, device, indices_bool, features, labels, edges)

In [None]:
try:
  # %tensorflow_version only exists in Colab.
  %tensorflow_version 2.x
except Exception:
  print("エラー、やり直してください")
  pass
%load_ext tensorboard

In [None]:
# tensorboardで結果を見るときはこのコメントアウトを外す
# %tensorboard --logdir="./logs/TTTT"

# for debug

## Graph Attention Layerの計算

In [None]:
features.shape,edges.shape,labels.shape

### 0: 線形変換

In [None]:
in_features = n_features
out_features = n_hidden
W = torch.randn(size=(in_features, out_features))
W.shape

In [None]:
z = torch.mm(features,W)
z.shape

### 1: $e_{ij}$の計算

In [None]:
edges_new_axis = edges.reshape(edges.shape[0],edges.shape[1],-1)
edges_expand = edges_new_axis.expand(edges.shape[0],edges.shape[1],z.shape[1])
edges_expand,edges_expand.shape

In [None]:
z_new_axis = z.reshape(z.shape[0],-1,z.shape[1])
z_expand = z_new_axis.expand(z.shape[0],edges.shape[1],z.shape[1])
z_expand, z_expand.shape

$features\_previous\_concat[i][j][k] = z\_expand[edges\_expand[i][j][k]][j][k]$

$features\_concat[i] = [features\_previous\_concat[i][0] || features\_previous\_concat[i][1]] = [z[edges[0]] || z[edges[1]]]$

In [None]:
features_previous_concat = torch.gather(z_expand,0,edges_expand)
features_previous_concat

In [None]:
features_concat = features_previous_concat.reshape(edges.shape[0],-1)
features_concat.shape

$LeakyReLU(x)=max(0,x)+negative\_slope∗min(0,x)$

Reference:https://pytorch.org/docs/stable/generated/torch.nn.functional.leaky_relu.html

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

negative_slope = 0.01 #default value
X = np.arange(-0.5,0.5,0.01)
LeakyReLU = np.maximum(X,0) + negative_slope*np.minimum(X,0)
plt.plot(X,LeakyReLU)

In [None]:
#LeakyReluの適用
a = torch.randn(size=(2*out_features, 1))
alpha = 0.2
leakyrelu = nn.LeakyReLU(alpha)
attention_score = leakyrelu(torch.mm(features_concat,a))
attention_score, attention_score.shape

### 2: 正規化

$edges[k]=[i,j]$のとき$attention\_score[k] = e_{ij}$

$E_{kl} =
  \begin{cases}
    1 & \quad \textrm{if } edges[k][0]==l \\
    0                 & \quad \textrm{otherwise}
  \end{cases}
$

$e\_sum_i = \sum_j exp(e_{ij}) = E exp(e_{ij})$

In [None]:
E = torch.stack([torch.where(edges[:,0]==i,1,0) for i in range(n_nodes)]).float()
E, E.shape

In [None]:
attention_score_sum = torch.matmul(E,torch.exp(attention_score))
attention_score_sum, attention_score_sum.shape

In [None]:
attention_score_sum_expand = torch.matmul(torch.transpose(E,1,0),attention_score_sum)
attention_score_sum_expand, attention_score_sum_expand.shape

In [None]:
attention_score_norm = attention_score/attention_score_sum_expand
attention_score_norm, attention_score_norm.shape

### 3: $z_j$ の更新

$to\_renew\_z_{e_{ij}} = norm(e_{ij})*\boldsymbol{h}_j$

In [None]:
to_renew_z = attention_score_norm * features_previous_concat[:,1::2,:].reshape(-1,out_features)
to_renew_z, to_renew_z.shape

In [None]:
D = torch.stack([torch.where(edges[:,1]==j,1,0) for j in range(n_nodes)]).float()
D, D.shape

In [None]:
out = z + torch.matmul(D,to_renew_z)
out, out.shape

## Graph Attention Layerのまとめ

In [None]:
# 線形変換
z = torch.mm(features,W)

# e_ijの計算
edges_new_axis = edges.reshape(edges.shape[0],edges.shape[1],-1)
edges_expand = edges_new_axis.expand(edges.shape[0],edges.shape[1],z.shape[1])
z_new_axis = z.reshape(z.shape[0],-1,z.shape[1])
z_expand = z_new_axis.expand(z.shape[0],edges.shape[1],z.shape[1])
features_previous_concat = torch.gather(z_expand,0,edges_expand)
features_concat = features_previous_concat.reshape(edges.shape[0],-1)
attention_score = leakyrelu(torch.mm(features_concat,a))

# 正規化
E = torch.stack([torch.where(edges[:,0]==i,1,0) for i in range(n_nodes)]).float()
attention_score_sum_expand = torch.mm(torch.transpose(E,1,0),torch.mm(E,torch.exp(attention_score)))
attention_score_norm = attention_score/attention_score_sum_expand

# z_jの更新
to_renew_z = attention_score_norm * features_previous_concat[:,1::2,:].reshape(edges.shape[0],-1)
D = torch.stack([torch.where(edges[:,1]==j,1,0) for j in range(n_nodes)]).float()
out = z + torch.matmul(D,to_renew_z)

In [None]:
out.shape

## torch.catについて

In [None]:
head_out = [torch.arange(0,2708*8).reshape(2708,8) for _ in range(n_heads)]
head_out

In [None]:
torch.cat(head_out,dim=1),torch.cat(head_out,dim=0)

In [None]:
torch.cat(head_out,dim=1).size(),torch.cat(head_out,dim=0).size()