In [1]:
import torch
from torch import nn
import torch.nn.functional as F
import networkx as nx
import matplotlib.pyplot as plt
import torch.optim as optim
from scipy import sparse as sp
import random
from graphviz import Graph
import pickle
import numpy as np

device = torch.device('cuda')

In [2]:
# Encoder
class GraphAttentionLayer(torch.nn.Module):
    def __init__(self, in_features, out_features, n_heads, is_concat = True, dropout = 0.6, leacky_relu_negative_slope = 0.2):
        super(GraphAttentionLayer, self).__init__()
        self.W = torch.nn.Parameter(torch.randn(in_features, out_features))
        self.is_concat = is_concat
        self.n_heads = n_heads

        if is_concat:
            assert out_features % n_heads == 0

            self.n_hidden = out_features // n_heads
        else:
            self.n_hidden = out_features

        self.linear = nn.Linear(in_features, self.n_hidden * n_heads, bias = False)

        self.attn = nn.Linear(self.n_hidden * 2, 1, bias = False)
        self.activation = nn.LeakyReLU(negative_slope = leacky_relu_negative_slope)
        self.softmax = nn.Softmax(dim=1)
        self.dropout = nn.Dropout(dropout) 

    def forward(self, x, adj):
        n_nodes = x.shape[0]
        g=self.linear(x).view(n_nodes, self.n_heads, self.n_hidden)
        g_repeat = g.repeat(n_nodes, 1,1)
        g_repeat_interleave = g.repeat_interleave(n_nodes, dim=0)
        g_concat = torch.cat([g_repeat_interleave, g_repeat], dim = -1)
        g_concat = g_concat.view(n_nodes, n_nodes, self.n_heads, 2 * self.n_hidden)
        e = self.activation(self.attn(g_concat))
        e = e.squeeze(-1)
        adj = adj.repeat(1, 1, self.n_heads)
        assert adj.shape[0] == 1 or adj.shape[0] == n_nodes
        assert adj.shape[1] == 1 or adj.shape[1] == n_nodes
        assert adj.shape[2] == 1 or adj.shape[2] == self.n_heads
        e=e.masked_fill(adj == 0, 1)
        a = self.softmax(e)
        a = self.dropout(a)
        attn_res = torch.einsum('ijh,jhf->ihf', a, g)
        if self.is_concat:
            return attn_res.reshape(n_nodes, self.n_heads * self.n_hidden)
        else:
            attn_res = attn_res.mean(dim=1)
            return attn_res


In [3]:
# Decoder
class Decoder(torch.nn.Module):
    def __init__(self, in_features, hidden_features, out_features, n_heads, d_h):
        super(Decoder, self).__init__()
        self.n_heads = n_heads
        self.hidden_features = hidden_features
        self.d_h = d_h
        self.in_features = in_features

        self.phi1 = torch.nn.Linear(d_h, 1)
        self.phi2 = torch.nn.Linear(d_h, 1)
        self.softmax = nn.Softmax(dim=1)
        self.C = torch.nn.Parameter(torch.randn(1)) # constant C
        self.activation = nn.Tanh()

    def forward(self, x, v_i, v_j):
        v_i = v_i.unsqueeze(0)
        phi1_v_i = torch.matmul(v_i, self.phi1.state_dict()['weight']) # phi1_v_i 의 사이즈 ()
        print("phi_v_i.size()")
        print(phi1_v_i.size())

        phi2_v_j = torch.matmul(v_j, self.phi2.state_dict()['weight']) # phi2_neighbors 의 사이즈 ()
        print("phi_v_j.size()")
        print(phi2_v_j.size())
        
        attn_input = torch.matmul(phi1_v_i, phi2_v_j.transpose(0,1)) / (self.d_h ** 0.5) # (1,n) 의 크기를 갖는 attn_input
        # attn_input = attn_input.squeeze(0)  # Remove the extra dimension
        attn_input = attn_input

        attn_output = self.C * self.activation(attn_input)
        print("attn_output.size")
        print(attn_output.size())
        # v_j의 크기를 (n, 1)로 변형하여 크기를 맞춤
        # v_j = v_j.unsqueeze(1)
        # masked_attn_output = attn_output.masked_fill(v_j == 0, float('-inf'))
        # masked_attn_output = masked_attn_output.squeeze(1)
        masked_attn_output = torch.where(v_j == 0, float('-inf'), attn_output)
        print("masked_attn_")
        print(masked_attn_output.size())

        masked_attn_output= masked_attn_output[0]
        masked_attn_output = masked_attn_output.unsqueeze(0)
        print("masked_attn_1")
        print(masked_attn_output.size())
        attn_weights = self.softmax(masked_attn_output)
        # attn_weights = self.softmax(attn_output)        
        print("attn_weights.size")
        print(attn_weights.size())


        output = torch.matmul(attn_weights, x)
        print("output_size")
        print(output.size())
        output = output.squeeze(0)
        print("output_size1")
        print(output.size())
        return permutation

In [4]:
class GAT(torch.nn.Module):
    def __init__(self, in_features, hidden_features, out_features, n_heads, d_h):
        super(GAT, self).__init__()
        self.n_heads = n_heads
        self.attention1 = GraphAttentionLayer(in_features, hidden_features, n_heads, is_concat = True, dropout = dropout)
        self.attention2 = GraphAttentionLayer(hidden_features, out_features, 1, is_concat = False, dropout = dropout)
        self.norm= nn.LayerNorm(out_features)
        self.decoder = Decoder(in_features, hidden_features, out_features, n_heads, d_h)
    
    def forward(self, x, adj):
        x = self.attention1(x, adj)
        x = self.attention2(x, adj)
        x = self.norm(x)
        output = F.softmax(x, dim=0)
        output = torch.mean(output, dim=1)
        output = output.unsqueeze(0)
        output = output.transpose(0,1)
        return output
    
    def decode(self, output, v_i, v_j):
        return self.decoder(output, v_i, v_j)

In [51]:
in_features =  1
n_heads = 4

def generate_random_weighted_graph(num_nodes, num_edges, max_weight=10):
    # 방향 그래프 생성
    graph = nx.Graph()
    
    # 노드 추가
    nodes = range(num_nodes)
    graph.add_nodes_from(nodes)
    
    # 노드에 가중치 할당 및 노드 특징 벡터 생성
    x = torch.zeros(num_nodes, in_features)
    for node in graph.nodes:
        weight = random.randint(1, max_weight)
        graph.nodes[node]['weight'] = weight
        x[node] = weight

    # 간선 추가
    edges = []
    for i in range(num_edges):
        # 임의의 출발 노드와 도착 노드 선택
        source = random.choice(nodes)
        target = random.choice(nodes)
        
        # 출발 노드와 도착 노드가 같은 경우 건너뜀
        if source == target:
            continue
        
        # 간선 추가
        edges.append((source, target))

    graph.add_edges_from(edges)

    graph_original = graph

        # Generate v_prev tensor
    j = random.randint(0, num_nodes-1)

    adj_matrix = nx.adjacency_matrix(graph)
    adj_matrix_original = torch.Tensor(adj_matrix.todense())
    
    adj_matrix = adj_matrix + sp.eye(adj_matrix.shape[0]) # Add self-loop
    adj_tensor = torch.Tensor(adj_matrix.todense())

    adj_tensor = adj_tensor.unsqueeze(2) # adj_tensor (num_nodes, num_nodes, n_heads)
    # adj_tensor = adj_tensor.repeat(1, 1, n_heads) #
    
    return graph, x, adj_tensor, adj_matrix_original, graph_original, j

In [55]:
num_graphs = 100
output_file = 'random_undirected_graphs.pkl'

graphs = []

for _ in range(num_graphs):
    num_nodes, num_edges, max_weight = np.random.randint(1,20), np.random.randint(1,30), np.random.randint(1,30)
    graph, x, adj_tensor, adj_matrix_original, graph_original, j= generate_random_weighted_graph(num_nodes, num_edges, max_weight)
    graphs.append((x, adj_tensor, j, adj_matrix_original))


# 그래프를 pickle 파일로 저장
with open(output_file, 'wb') as f:
    pickle.dump(graphs, f)

In [56]:
# pickle 파일에서 그래프 데이터 로드
with open('random_undirected_graphs.pkl', 'rb') as f:
    graphs = pickle.load(f)

In [57]:
hidden_features = 4 * n_heads
out_features = n_heads
d_h = 4 * n_heads
dropout = 0.6
gat_model = GAT(in_features, hidden_features, out_features, n_heads, d_h).cuda()
gat_models = []
for graph_idx, (x, adj_tensor, j, adj_matrix_original) in enumerate(graphs):
    gat_models.append(gat_model)
    x = x.cuda()
    adj_tensor = adj_tensor.cuda()
    output = gat_model(x, adj_tensor)
    print(f"Graph {graph_idx+1} - Output:")
    print(output.shape)
    print(output)

    v_i = output[j,0]
    print("v_i")
    print(v_i, j)
    # Generate neighbors tensor
    print()
    print(adj_matrix_original.size())
    print(adj_tensor.squeeze(2).size())
    v_j = adj_matrix_original[-1,j].cuda() * output
    print("")
    print("v_j")
    print(v_j)


    output, decode_output, attn_weights = gat_model.decode(output, v_i, v_j)
    print(f"Graph {graph_idx+1} - Decode Output:")
    print(decode_output)
    print("Attention Weights:")
    print(attn_weights)


Graph 1 - Output:
torch.Size([4, 1])
tensor([[0.2426],
        [0.2580],
        [0.2497],
        [0.2497]], device='cuda:0', grad_fn=<TransposeBackward0>)
v_i
tensor(0.2426, device='cuda:0', grad_fn=<SelectBackward0>) 0

torch.Size([4, 4])
torch.Size([4, 4])

v_j
tensor([[0.2426],
        [0.2580],
        [0.2497],
        [0.2497]], device='cuda:0', grad_fn=<MulBackward0>)
phi_v_i.size()
torch.Size([16])
phi_v_j.size()
torch.Size([4, 16])
attn_output.size
torch.Size([4])
masked_attn_
torch.Size([4, 4])
masked_attn_1
torch.Size([1, 4])
attn_weights.size
torch.Size([1, 4])
output_size
torch.Size([1, 1])
output_size1
torch.Size([1])
Graph 1 - Decode Output:
tensor([0.2500], device='cuda:0', grad_fn=<SqueezeBackward1>)
Attention Weights:
tensor([[0.2500, 0.2500, 0.2500, 0.2500]], device='cuda:0',
       grad_fn=<SoftmaxBackward0>)
Graph 2 - Output:
torch.Size([3, 1])
tensor([[0.3149],
        [0.3165],
        [0.3686]], device='cuda:0', grad_fn=<TransposeBackward0>)
v_i
tensor(0.3149,

참고
https://chioni.github.io/posts/gat/