In [1]:
import torch
from torch import nn
import torch.nn.functional as F
import networkx as nx
import matplotlib.pyplot as plt
import torch.optim as optim
from scipy import sparse as sp
import random

device = torch.device('cuda')

In [2]:
# Decoder
class Decoder(torch.nn.Module):
    def __init__(self, in_features, hidden_features, out_features, n_heads, d_h):
        super(Decoder, self).__init__()
        self.n_heads = n_heads
        self.hidden_features = hidden_features
        self.d_h = d_h

        self.linear1 = nn.Linear(in_features, hidden_features)
        self.linear2 = nn.Linear(hidden_features, out_features)
        self.attn_linear1 = nn.Linear(hidden_features, d_h)
        self.attn_linear2 = nn.Linear(hidden_features, d_h)
        self.softmax = nn.Softmax(dim=1)
        self.activation = nn.Tanh()

    def forward(self, x, v_prev, neighbors):
        n_nodes = x.shape[0]

        v_prev = self.linear1(v_prev)
        v_prev = v_prev.unsqueeze(0).repeat(n_nodes, 1)

        neighbors = self.linear1(neighbors)

        attn_input = torch.cat([v_prev, neighbors], dim=-1)
        attn_input = self.activation(attn_input)

        attn1 = self.attn_linear1(attn_input)
        attn2 = self.attn_linear2(attn_input)

        attn_output = torch.matmul(attn1, attn2.transpose(0, 1)) / (self.d_h ** 0.5)
        attn_output = self.activation(attn_output)

        masked_attn_output = attn_output.masked_fill(neighbors == 0, float('-inf'))
        attn_weights = self.softmax(masked_attn_output)

        x = self.linear2(x)
        x = x.unsqueeze(0).repeat(n_nodes, 1, 1)

        output = torch.matmul(attn_weights.unsqueeze(1), x)
        output = output.squeeze(1)

        return output, attn_weights

In [3]:
# Encoder
class GraphAttentionLayer(torch.nn.Module):
    def __init__(self, in_features, out_features, n_heads, is_concat = True, dropout = 0.6, leacky_relu_negative_slope = 0.2):
        super(GraphAttentionLayer, self).__init__()
        self.W = torch.nn.Parameter(torch.randn(in_features, out_features))
        self.is_concat = is_concat
        self.n_heads = n_heads

        if is_concat:
            assert out_features % n_heads == 0

            self.n_hidden = out_features // n_heads
        else:
            self.n_hidden = out_features

        self.linear = nn.Linear(in_features, self.n_hidden * n_heads, bias = False)

        self.attn = nn.Linear(self.n_hidden * 2, 1, bias = False)
        self.activation = nn.LeakyReLU(negative_slope = leacky_relu_negative_slope)
        self.softmax = nn.Softmax(dim=1)
        self.dropout = nn.Dropout(dropout)
        # self.decoder = Decoder(self.n_hidden)
        

    def forward(self, x, adj):
        n_nodes = x.shape[0]
        g=self.linear(x).view(n_nodes, self.n_heads, self.n_hidden)
        g_repeat = g.repeat(n_nodes, 1,1)
        g_repeat_interleave = g.repeat_interleave(n_nodes, dim=0)
        g_concat = torch.cat([g_repeat_interleave, g_repeat], dim = -1)
        g_concat = g_concat.view(n_nodes, n_nodes, self.n_heads, 2 * self.n_hidden)
        e = self.activation(self.attn(g_concat))
        e = e.squeeze(-1)
        assert adj.shape[0] == 1 or adj.shape[0] == n_nodes
        assert adj.shape[1] == 1 or adj.shape[1] == n_nodes
        assert adj.shape[2] == 1 or adj.shape[2] == self.n_heads
        e=e.masked_fill(adj == 0, 1)
        a = self.softmax(e)
        a = self.dropout(a)
        attn_res = torch.einsum('ijh,jhf->ihf', a, g)
        if self.is_concat:
            return attn_res.reshape(n_nodes, self.n_heads * self.n_hidden)
        else:
            return attn_res.mean(dim = 1)
        
        # attention_coefficients = self.decoder(attn_res)
        # return attention_coefficients


In [4]:
class GAT(torch.nn.Module):
    def __init__(self, in_features, hidden_features, out_features, n_heads, d_h):
        super(GAT, self).__init__()
        self.n_heads = n_heads
        self.attention1 = GraphAttentionLayer(in_features, hidden_features, n_heads)
        self.attention2 = GraphAttentionLayer(hidden_features, out_features, n_heads)
        self.norm= nn.LayerNorm(out_features)
        self.decoder = Decoder(out_features, hidden_features, out_features, n_heads, d_h)
    
    def forward(self, x, adj):
        x = self.attention1(x, adj)
        x = self.attention2(x, adj)
        x = self.norm(x)
        x = F.softmax(x, dim=-1)
        return x
    
    def decode(self, x, v_prev, neighbors):
        return self.decoder(x, v_prev, neighbors)

In [6]:
# Create multiple dummy graphs with different node sizes
graph_list = []

# Graph 1
G1 = nx.Graph()
G1.add_nodes_from(range(4))  # Add nodes
G1.add_edges_from([(0, 1), (1, 2), (2, 3)])  # Add edges

adj_matrix1 = nx.adjacency_matrix(G1)
adj_matrix1 = adj_matrix1 + sp.eye(adj_matrix1.shape[0])  # Add self-loop
adj_tensor1 = torch.Tensor(adj_matrix1.todense())


num_nodes1 = G1.number_of_nodes()
in_features1 = 8
x1 = torch.randn(num_nodes1, in_features1)

# Resize adjacency tensor to match the input features size
adj_tensor1 = adj_tensor1.unsqueeze(0)  # Add an extra dimension
adj_tensor1 = adj_tensor1.repeat(num_nodes1, 1, 1)  # Repeat the adjacency tensor
adj_tensor1 = adj_tensor1.transpose(0, 1)  # Transpose the dimensions

# Generate labels for Graph 1
labels1 = torch.randint(0, 2, (num_nodes1,)).to(device)

graph_list.append((x1, adj_tensor1))

# Graph 2
G2 = nx.Graph()
G2.add_nodes_from(range(5))  # Add nodes
G2.add_edges_from([(0, 1), (1, 2), (2, 3), (3, 4)])  # Add edges

adj_matrix2 = nx.adjacency_matrix(G2)
adj_matrix2 = adj_matrix2 + sp.eye(adj_matrix2.shape[0])  # Add self-loop
adj_tensor2 = torch.Tensor(adj_matrix2.todense())

num_nodes2 = G2.number_of_nodes()
in_features2 = 8
x2 = torch.randn(num_nodes2, in_features2)

# Resize adjacency tensor to match the input features size
adj_tensor2 = adj_tensor2.unsqueeze(0)  # Add an extra dimension
adj_tensor2 = adj_tensor2.repeat(num_nodes2, 1, 1)  # Repeat the adjacency tensor
adj_tensor2 = adj_tensor2.transpose(0, 1)  # Transpose the dimensions

# Generate labels for Graph 2
labels2 = torch.randint(0, 2, (num_nodes2,)).to(device)

graph_list.append((x2, adj_tensor2))

# Access the graphs and their components from the graph list
for i, (feature_matrix, adj_tensor) in enumerate(graph_list):
    # # Expand the adj_tensor dimensions if using multiple attention heads
    # if adj_tensor.dim() == 2:
    #     adj_tensor = adj_tensor.unsqueeze(2).expand(-1, -1, self.n_heads)

    graph_list[i] = (feature_matrix.cuda(), adj_tensor.cuda())


    print(f"Graph {i+1} - Feature Matrix:")
    print(feature_matrix)

    print(f"\nGraph {i+1} - Adjacency Tensor:")
    print(adj_tensor)

    print("\n")

Graph 1 - Feature Matrix:
tensor([[-1.0030, -1.8632, -0.7330,  2.5421, -1.6586,  0.1317, -1.6019,  1.3239],
        [-1.6821, -0.7835,  0.1511,  1.6416, -0.3283,  0.6239, -0.4302,  0.3748],
        [-0.4827,  0.9448,  0.7304,  1.0704, -1.1878,  0.3692, -1.8309, -1.8636],
        [ 1.6016,  0.6327, -0.9213,  2.3722,  0.1404,  0.3952, -0.0668,  1.0579]])

Graph 1 - Adjacency Tensor:
tensor([[[1., 1., 0., 0.],
         [1., 1., 0., 0.],
         [1., 1., 0., 0.],
         [1., 1., 0., 0.]],

        [[1., 1., 1., 0.],
         [1., 1., 1., 0.],
         [1., 1., 1., 0.],
         [1., 1., 1., 0.]],

        [[0., 1., 1., 1.],
         [0., 1., 1., 1.],
         [0., 1., 1., 1.],
         [0., 1., 1., 1.]],

        [[0., 0., 1., 1.],
         [0., 0., 1., 1.],
         [0., 0., 1., 1.],
         [0., 0., 1., 1.]]])


Graph 2 - Feature Matrix:
tensor([[-0.7394,  1.7005, -0.1197, -1.1318,  1.6592,  0.2626, -0.4151,  0.1770],
        [ 0.2701,  0.3951, -0.2521,  0.3053, -2.1761,  0.5216, -0.

In [7]:
# Create and initialize the GAT models for each graph
gat_models = []
for i, (feature_matrix, adj_tensor) in enumerate(graph_list):
    in_features = feature_matrix.shape[1]
    n_heads = adj_tensor.shape[2]
    hidden_features = 4 * n_heads
    out_features = 2 * n_heads
    d_h = 4 * n_heads
    gat_model = GAT(in_features, hidden_features, out_features, n_heads, d_h).cuda()
    gat_models.append(gat_model)
    feature_matrix = feature_matrix.cuda()
    adj_tensor = adj_tensor.cuda()
    output = gat_model(feature_matrix, adj_tensor)
    print(f"Graph {i+1} - Output:")
    print(output)
    #output : 각 노드에 대한 클래스 라벨 예측 값

Graph 1 - Output:
tensor([[0.0198, 0.1400, 0.3839, 0.1161, 0.0379, 0.2337, 0.0343, 0.0343],
        [0.0156, 0.0506, 0.2588, 0.0805, 0.0436, 0.0411, 0.4440, 0.0659],
        [0.0170, 0.1283, 0.3172, 0.1139, 0.0301, 0.0301, 0.2601, 0.1031],
        [0.0217, 0.1102, 0.1182, 0.0431, 0.0269, 0.5848, 0.0350, 0.0601]],
       device='cuda:0', grad_fn=<SoftmaxBackward0>)
Graph 2 - Output:
tensor([[0.0386, 0.3556, 0.1901, 0.0851, 0.0224, 0.0251, 0.1601, 0.0178, 0.0231,
         0.0821],
        [0.0435, 0.3346, 0.0120, 0.0337, 0.0818, 0.0468, 0.1661, 0.0247, 0.0344,
         0.2224],
        [0.0327, 0.1145, 0.3320, 0.1621, 0.0595, 0.0319, 0.0431, 0.0221, 0.1897,
         0.0125],
        [0.1681, 0.2533, 0.0167, 0.0466, 0.0199, 0.0144, 0.1773, 0.1206, 0.0915,
         0.0915],
        [0.0070, 0.5392, 0.0506, 0.0506, 0.0506, 0.0506, 0.0581, 0.1019, 0.0411,
         0.0503]], device='cuda:0', grad_fn=<SoftmaxBackward0>)


In [8]:
# Set the optimizer and loss function
optimizer = optim.Adam(gat_model.parameters(), lr=0.01)
criterion = torch.nn.NLLLoss()

# Move the model and loss function to the GPU
gat_model = gat_model.cuda()
criterion = criterion.cuda()

# Training loop
epochs = 100

for epoch in range(epochs):
    total_loss = 0.0
    for graph_idx, (feature_matrix, adj_tensor) in enumerate(graph_list):
        feature_matrix = feature_matrix.to(device)
        adj_tensor = adj_tensor.to(device)
         # Generate random labels for the current graph
        num_nodes = feature_matrix.shape[0]
        labels = torch.tensor([random.randint(0, 1) for _ in range(num_nodes)]).to(device)
                
        # Zero the gradients
        gat_model.zero_grad()
        
        # Forward pass
        output = gat_models[graph_idx](feature_matrix, adj_tensor)
        
        # Compute the loss
        loss = criterion(output.squeeze(0), labels)
        total_loss += loss.item()
        
        # Backward pass and optimization
        loss.backward()
        optimizer.step()
    
    # Calculate the average loss for the epoch
    average_loss = total_loss / len(graph_list)
    
    for graph_idx in range(len(graph_list)):
        print("Graph {}: Epoch: {:03d}, Loss: {:.4f}".format(graph_idx+1, epoch+1, average_loss))
     
        

Graph 1: Epoch: 001, Loss: -0.0426
Graph 2: Epoch: 001, Loss: -0.0426
Graph 1: Epoch: 002, Loss: -0.1730
Graph 2: Epoch: 002, Loss: -0.1730
Graph 1: Epoch: 003, Loss: -0.1797
Graph 2: Epoch: 003, Loss: -0.1797
Graph 1: Epoch: 004, Loss: -0.1878
Graph 2: Epoch: 004, Loss: -0.1878
Graph 1: Epoch: 005, Loss: -0.1145
Graph 2: Epoch: 005, Loss: -0.1145
Graph 1: Epoch: 006, Loss: -0.1843
Graph 2: Epoch: 006, Loss: -0.1843
Graph 1: Epoch: 007, Loss: -0.2047
Graph 2: Epoch: 007, Loss: -0.2047
Graph 1: Epoch: 008, Loss: -0.1988
Graph 2: Epoch: 008, Loss: -0.1988
Graph 1: Epoch: 009, Loss: -0.2134
Graph 2: Epoch: 009, Loss: -0.2134
Graph 1: Epoch: 010, Loss: -0.1352
Graph 2: Epoch: 010, Loss: -0.1352
Graph 1: Epoch: 011, Loss: -0.3044
Graph 2: Epoch: 011, Loss: -0.3044
Graph 1: Epoch: 012, Loss: -0.2079
Graph 2: Epoch: 012, Loss: -0.2079
Graph 1: Epoch: 013, Loss: -0.2585
Graph 2: Epoch: 013, Loss: -0.2585
Graph 1: Epoch: 014, Loss: -0.2304
Graph 2: Epoch: 014, Loss: -0.2304
Graph 1: Epoch: 015,