In [2]:
import torch
from torch import nn
import torch.nn.functional as F
import networkx as nx
import matplotlib.pyplot as plt
import torch.optim as optim
from scipy import sparse as sp
import random
from graphviz import Graph
import pickle
import numpy as np

device = torch.device('cuda')

In [87]:
a_1 = torch.randn(3,1)
a_2 = torch.randn(3,1)
c= a_1*a_2.T
d= torch.matmul(a_1.T,a_2)

d=d.squeeze(0)
d

tensor([-3.7614])

In [127]:
# Decoder
class Decoder(torch.nn.Module):
    def __init__(self, in_features, hidden_features, out_features, n_heads, d_h):
        super(Decoder, self).__init__()
        self.n_heads = n_heads
        self.hidden_features = hidden_features
        self.d_h = d_h

        self.phi = torch.nn.Linear(hidden_features, d_h * hidden_features)
        self.softmax = nn.Softmax(dim=1)
        self.C = torch.nn.Parameter(torch.randn(1)) # constant C
        self.activation = nn.Tanh()

    def forward(self, x, v_i, v_j):
        phi1 = torch.randn(d_h,1).cuda()
        phi2 = torch.randn(d_h,1).cuda()

        phi1_v_i = phi1 * v_i # phi1_v_prev 의 사이즈 (d* d_h)
        phi2_v_j = phi2 * v_j #.squeeze(0) # phi2_neighbors 의 사이즈 (n,d* d_h)

        attn_input = torch.matmul(phi1_v_i.transpose(0,1), phi2_v_j) / (self.d_h ** 0.5) # (1,n) 의 크기를 갖는 attn_input
        # attn_input = attn_input.squeeze(0)  # Remove the extra dimension

        attn_output = self.C * self.activation(attn_input)

        masked_attn_output = attn_output.masked_fill(v_j == 0, float('-inf'))
        attn_weights = self.softmax(masked_attn_output)
        # attn_weights = self.softmax(attn_output)

        output = attn_weights.squeeze() * x
        output = output.squeeze(0)

        return x, output, attn_weights

In [128]:
# Encoder
class GraphAttentionLayer(torch.nn.Module):
    def __init__(self, in_features, out_features, n_heads, is_concat = True, dropout = 0.6, leacky_relu_negative_slope = 0.2):
        super(GraphAttentionLayer, self).__init__()
        self.W = torch.nn.Parameter(torch.randn(in_features, out_features))
        self.is_concat = is_concat
        self.n_heads = n_heads

        if is_concat:
            assert out_features % n_heads == 0

            self.n_hidden = out_features // n_heads
        else:
            self.n_hidden = out_features

        self.linear = nn.Linear(in_features, self.n_hidden * n_heads, bias = False)

        self.attn = nn.Linear(self.n_hidden * 2, 1, bias = False)
        self.activation = nn.LeakyReLU(negative_slope = leacky_relu_negative_slope)
        self.softmax = nn.Softmax(dim=1)
        self.dropout = nn.Dropout(dropout) 

    def forward(self, x, adj):
        n_nodes = x.shape[0]
        g=self.linear(x).view(n_nodes, self.n_heads, self.n_hidden)
        g_repeat = g.repeat(n_nodes, 1,1)
        g_repeat_interleave = g.repeat_interleave(n_nodes, dim=0)
        g_concat = torch.cat([g_repeat_interleave, g_repeat], dim = -1)
        g_concat = g_concat.view(n_nodes, n_nodes, self.n_heads, 2 * self.n_hidden)
        e = self.activation(self.attn(g_concat))
        e = e.squeeze(-1)
        assert adj.shape[0] == 1 or adj.shape[0] == n_nodes
        assert adj.shape[1] == 1 or adj.shape[1] == n_nodes
        assert adj.shape[2] == 1 or adj.shape[2] == self.n_heads
        e=e.masked_fill(adj == 0, 1)
        a = self.softmax(e)
        a = self.dropout(a)
        attn_res = torch.einsum('ijh,jhf->ihf', a, g)
        if self.is_concat:
            return attn_res.reshape(n_nodes, self.n_heads * self.n_hidden)
        else:
            return attn_res.mean(dim = 1)


In [129]:
class GAT(torch.nn.Module):
    def __init__(self, in_features, hidden_features, out_features, n_heads, d_h):
        super(GAT, self).__init__()
        self.n_heads = n_heads
        self.attention1 = GraphAttentionLayer(in_features, hidden_features, n_heads)
        self.attention2 = GraphAttentionLayer(hidden_features, out_features, n_heads)
        self.norm= nn.LayerNorm(out_features)
        self.decoder = Decoder(out_features, hidden_features, out_features, n_heads, d_h)
    
    def forward(self, x, adj):
        x = self.attention1(x, adj)
        x = self.attention2(x, adj)
        x = self.norm(x)
        x = F.softmax(x, dim=-1)
        return x
    
    def decode(self, x, v_i, v_j):
        return self.decoder(x, v_i, v_j)

In [130]:
def generate_random_weighted_graph(num_nodes, num_edges, max_weight=10):
    # 방향 그래프 생성
    graph = nx.Graph()
    
    # 노드 추가
    nodes = range(num_nodes)
    graph.add_nodes_from(nodes)
    
    # 간선 추가
    edges = []
    for i in range(num_edges):
        # 임의의 출발 노드와 도착 노드 선택
        source = random.choice(nodes)
        target = random.choice(nodes)
        
        # 출발 노드와 도착 노드가 같은 경우 건너뜀
        if source == target:
            continue
        
        # 가중치 랜덤 생성
        weight = random.randint(1, max_weight)
        
        # 간선 추가
        edges.append((source, target, weight))

    adj_matrix = nx.adjacency_matrix(graph)
    adj_matrix = adj_matrix + sp.eye(adj_matrix.shape[0]) # Add self-loop
    adj_tensor = torch.Tensor(adj_matrix.todense())

    in_features =  1
    x = torch.randn(num_nodes, in_features)

    adj_tensor = adj_tensor.unsqueeze(0)
    adj_tensor = adj_tensor.repeat(num_nodes, 1, 1)
    adj_tensor = adj_tensor.transpose(0,1)
        
    graph.add_weighted_edges_from(edges)
    
    return graph, x, adj_tensor

In [131]:
num_graphs = 100
output_file = 'random_undirected_graphs.pkl'

graphs = []

for _ in range(num_graphs):
    num_nodes, num_edges, max_weight = np.random.randint(1,20), np.random.randint(1,30), np.random.randint(1,30)
    graph, x, adj_tensor = generate_random_weighted_graph(num_nodes, num_edges, max_weight)
    graphs.append((x, adj_tensor))


# 그래프를 pickle 파일로 저장
with open(output_file, 'wb') as f:
    pickle.dump(graphs, f)

In [132]:
# pickle 파일에서 그래프 데이터 로드
with open('random_undirected_graphs.pkl', 'rb') as f:
    graphs = pickle.load(f)

In [136]:
gat_models = []
for graph_idx, (x, adj_tensor) in enumerate(graphs):
    in_features = x.shape[1]
    n_heads = adj_tensor.shape[2]
    hidden_features = 4 * n_heads
    out_features = n_heads
    d_h = 4 * n_heads
    gat_model = GAT(in_features, hidden_features, out_features, n_heads, d_h).cuda()
    gat_models.append(gat_model)
    x = x.cuda()
    adj_tensor = adj_tensor.cuda()
    output = gat_model(x, adj_tensor)
    print(f"Graph {graph_idx+1} - Output:")
    print(output.shape)
    #output : 각 노드에 대한 클래스 라벨 예측 값

    # Generate v_prev tensor
    v_i = torch.randn(1).cuda()
    print(v_i.size(0))
    # Generate neighbors tensor
    v_j = torch.randn(1).cuda()

    x, decode_output, attn_weights = gat_model.decode(output, v_i, v_j)
    # print(f"Graph {graph_idx+1} - Decode Output:")
    # print(decode_output)
    # print("Attention Weights:")
    # print(attn_weights)
    print(x)

Graph 1 - Output:
torch.Size([15, 15])
1
tensor([[0.0128, 0.3112, 0.0174, 0.0471, 0.0988, 0.0335, 0.0245, 0.0602, 0.0088,
         0.0746, 0.1281, 0.0120, 0.0783, 0.0139, 0.0789],
        [0.0183, 0.1441, 0.0158, 0.0406, 0.0817, 0.0297, 0.0119, 0.0664, 0.0073,
         0.1091, 0.1418, 0.0315, 0.0459, 0.0167, 0.2395],
        [0.0136, 0.0895, 0.0163, 0.0500, 0.0971, 0.0423, 0.0241, 0.1055, 0.0051,
         0.1098, 0.2296, 0.0244, 0.0438, 0.0196, 0.1290],
        [0.0346, 0.1748, 0.0234, 0.0527, 0.0966, 0.0391, 0.0090, 0.0525, 0.0140,
         0.1089, 0.1527, 0.0196, 0.0671, 0.0070, 0.1482],
        [0.0268, 0.0996, 0.0125, 0.0511, 0.0631, 0.0362, 0.0563, 0.1255, 0.0061,
         0.1168, 0.2097, 0.0296, 0.0800, 0.0077, 0.0789],
        [0.0295, 0.0952, 0.0150, 0.0503, 0.0612, 0.0323, 0.0144, 0.0895, 0.0101,
         0.0426, 0.2300, 0.0297, 0.1556, 0.0074, 0.1372],
        [0.0228, 0.0543, 0.0299, 0.0421, 0.0615, 0.0373, 0.0099, 0.1168, 0.0057,
         0.0809, 0.2637, 0.0209, 0.1253, 0.0

참고
https://chioni.github.io/posts/gat/

In [138]:
a=[0.0128, 0.3112, 0.0174, 0.0471, 0.0988, 0.0335, 0.0245, 0.0602, 0.0088, 0.0746, 0.1281, 0.0120, 0.0783, 0.0139, 0.0789]
sum(a)

1.0001

In [None]:
# 그래프 데이터로 모델 학습
for graph_idx, (graph, x, adj_tensor) in enumerate(graphs):
    # Initialize the GAT model for the current graph
    in_features = x.shape[1]
    n_heads = adj_tensor.shape[1]
    hidden_features = 4 * n_heads
    out_features = 2 * n_heads
    d_h = 4 * n_heads
    gat_model = GAT(in_features, hidden_features, out_features, n_heads, d_h).to(device)

    # Set the optimizer and loss function
    optimizer = optim.Adam(gat_model.parameters(), lr=0.01)
    criterion = torch.nn.NLLLoss().to(device)

    # Move the feature matrix and adjacency tensor to the GPU
    x = x.to(device)
    adj_tensor = adj_tensor.to(device)

    # Training loop
    epochs = 100
    for epoch in range(epochs):
        # Zero the gradients
        optimizer.zero_grad()

        # Forward pass
        output = gat_model(x, adj_tensor)

        # Generate random labels for the current graph
        num_nodes = x.shape[0]
        labels = torch.tensor([random.randint(0, 1) for _ in range(num_nodes)]).to(device)

        # Compute the loss
        loss = criterion(output, labels)

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        print("Graph {}: Epoch: {:03d}, Loss: {:.4f}".format(graph_idx+1, epoch+1, loss.item()))

Graph 1: Epoch: 001, Loss: -0.0254
Graph 1: Epoch: 002, Loss: -0.1674
Graph 1: Epoch: 003, Loss: -0.3029
Graph 1: Epoch: 004, Loss: -0.3608
Graph 1: Epoch: 005, Loss: -0.3142
Graph 1: Epoch: 006, Loss: -0.4000
Graph 1: Epoch: 007, Loss: -0.5569
Graph 1: Epoch: 008, Loss: -0.4095
Graph 1: Epoch: 009, Loss: -0.3047
Graph 1: Epoch: 010, Loss: -0.4282
Graph 1: Epoch: 011, Loss: -0.4983
Graph 1: Epoch: 012, Loss: -0.6975
Graph 1: Epoch: 013, Loss: -0.4644
Graph 1: Epoch: 014, Loss: -0.5200
Graph 1: Epoch: 015, Loss: -0.2665
Graph 1: Epoch: 016, Loss: -0.6714
Graph 1: Epoch: 017, Loss: -0.4025
Graph 1: Epoch: 018, Loss: -0.4637
Graph 1: Epoch: 019, Loss: -0.6675
Graph 1: Epoch: 020, Loss: -0.4056
Graph 1: Epoch: 021, Loss: -0.5430
Graph 1: Epoch: 022, Loss: -0.5476
Graph 1: Epoch: 023, Loss: -0.4104
Graph 1: Epoch: 024, Loss: -0.6197
Graph 1: Epoch: 025, Loss: -0.2773
Graph 1: Epoch: 026, Loss: -0.2756
Graph 1: Epoch: 027, Loss: -0.5531
Graph 1: Epoch: 028, Loss: -0.3454
Graph 1: Epoch: 029,