In [19]:
from torch_geometric.datasets import Twitch
import os.path as osp

import torch
from sklearn.metrics import roc_auc_score

import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv
from torch_geometric.utils import negative_sampling

In [20]:
# Pick the compute device: CUDA first, then the MPS backend (only if this
# torch build exposes it and reports it available), otherwise the CPU.
device = torch.device(
    'cuda' if torch.cuda.is_available()
    else 'mps' if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()
    else 'cpu'
)

In [21]:
# Load the 'EN' Twitch graph, using data/Twitch as the dataset root directory.
dataset = Twitch(root='data/Twitch', name='EN')
print(dataset[0])

# Short summary of the dataset container itself.
print()
print(f'Dataset: {dataset}:')
print('====================')
print(
    f'Number of graphs: {len(dataset)}',
    f'Number of features: {dataset.num_features}',
    f'Number of classes: {dataset.num_classes}',
    sep='\n',
)

Data(x=[7126, 128], edge_index=[2, 77774], y=[7126])

Dataset: Twitch():
Number of graphs: 1
Number of features: 128
Number of classes: 2


In [22]:
from torch_geometric.utils import train_test_split_edges

# Seed torch's global RNG before any stochastic step so that the edge split
# below (and the model initialisation / negative sampling that run after it)
# is repeatable after a kernel restart.
SEED = 42
torch.manual_seed(SEED)

# Split the graph's edges into train / validation / test positive sets; the
# returned object also carries sampled negative edges for validation and test.
# NOTE(review): train_test_split_edges is reported as deprecated in recent PyG
# releases in favour of transforms.RandomLinkSplit - confirm for the installed version.
data = train_test_split_edges(dataset[0])

print('Train edges:', data.train_pos_edge_index.size(1))
print('Validation edges (positive):', data.val_pos_edge_index.size(1))
print('Validation edges (negative):', data.val_neg_edge_index.size(1))
print('Test edges (positive):', data.test_pos_edge_index.size(1))
print('Test edges (negative):', data.test_neg_edge_index.size(1))

print(data)

Train edges: 60052
Validation edges (positive): 1766
Validation edges (negative): 1766
Test edges (positive): 3532
Test edges (negative): 3532
Data(x=[7126, 128], y=[7126], val_pos_edge_index=[2, 1766], test_pos_edge_index=[2, 3532], train_pos_edge_index=[2, 60052], train_neg_adj_mask=[7126, 7126], val_neg_edge_index=[2, 1766], test_neg_edge_index=[2, 3532])


In [23]:
class Net(torch.nn.Module):
    """Two-layer GCNConv encoder with a dot-product link decoder.

    `encode` produces node embeddings, `decode` scores individual node pairs,
    and `decode_all` lists every node pair whose score is positive. `forward`
    runs all of these in one pass and returns every intermediate tensor.
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def encode(self, x, edge_index):
        """Return node embeddings: conv1 -> ReLU -> conv2."""
        hidden = self.conv1(x, edge_index).relu()
        return self.conv2(hidden, edge_index)

    def decode(self, z, edge_label_index):
        """Return one raw score per column of `edge_label_index`.

        The score is the dot product of the two endpoint embeddings.
        """
        src, dst = edge_label_index[0], edge_label_index[1]
        return (z[src] * z[dst]).sum(dim=-1)

    def decode_all(self, z):
        """Return a 2 x K tensor of all (row, col) pairs with a positive score."""
        scores = z @ z.t()
        return (scores > 0).nonzero(as_tuple=False).t()

    def forward(self, x, edge_index, edge_label_index):
        """Run encoder and both decoders, returning every intermediate result.

        Returns:
            dict with keys, in this order: "conv1", "conv2", "decode_mul",
            "decode_sum", "prob_adj", "decode_all_final".
        """
        hidden = self.conv1(x, edge_index).relu()
        z = self.conv2(hidden, edge_index)

        # Per-edge decoder: element-wise product, then sum over the feature axis.
        endpoint_product = z[edge_label_index[0]] * z[edge_label_index[1]]
        edge_scores = endpoint_product.sum(dim=-1)

        # Dense decoder: all-pairs dot products and the indices of positive ones.
        prob_adj = z @ z.t()
        positive_pairs = (prob_adj > 0).nonzero(as_tuple=False).t()

        return {
            "conv1": hidden,
            "conv2": z,
            "decode_mul": endpoint_product,
            "decode_sum": edge_scores,
            "prob_adj": prob_adj,
            "decode_all_final": positive_pairs,
        }
        


# Build the link-prediction model (64 hidden and 64 output channels) on the
# selected device, together with its optimizer and loss.
model = Net(
    in_channels=dataset.num_features,
    hidden_channels=64,
    out_channels=64,
).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
# The decoder emits raw (unbounded) scores, so the loss takes logits directly.
criterion = torch.nn.BCEWithLogitsLoss()
print(device)

cpu


In [24]:
def train():
    """Run one full-graph optimisation step and return the loss as a float.

    Uses the notebook-level `model`, `optimizer`, `criterion`, `data` and
    `device`. The loss is the sum of the BCE-with-logits loss on the positive
    training edges and on an equally sized set of sampled negative edges.
    """
    model.train()
    optimizer.zero_grad()

    train_edges = data.train_pos_edge_index.to(device)
    z = model.encode(data.x.to(device), train_edges)

    # Observed training edges are labelled 1.
    pos_scores = model.decode(z, train_edges)
    pos_loss = criterion(pos_scores, torch.ones(pos_scores.size(0), device=device))

    # Sample one negative per positive from `data.train_pos_edge_index` as
    # stored, then move the sampled pairs to the compute device.
    sampled_negatives = negative_sampling(
        edge_index=data.train_pos_edge_index,
        num_nodes=data.num_nodes,
        num_neg_samples=train_edges.size(1),
    )
    neg_scores = model.decode(z, sampled_negatives.to(device))
    neg_loss = criterion(neg_scores, torch.zeros(neg_scores.size(0), device=device))

    total_loss = pos_loss + neg_loss
    total_loss.backward()
    optimizer.step()
    return total_loss.item()

In [25]:
def test(pos_edge_index, neg_edge_index):
    """Evaluate the model on given positive and negative edge sets.

    Node embeddings are computed from the training edges only. Uses the
    notebook-level `model`, `criterion`, `data` and `device`.

    Args:
        pos_edge_index: 2 x P tensor of edges labelled 1.
        neg_edge_index: 2 x N tensor of edges labelled 0.

    Returns:
        (loss, accuracy): BCE-with-logits loss over all P + N edges, and the
        fraction of edges classified correctly at a 0.5 probability threshold.
    """
    model.eval()
    with torch.no_grad():
        z = model.encode(data.x.to(device), data.train_pos_edge_index.to(device))

    pos_scores = model.decode(z, pos_edge_index.to(device))
    neg_scores = model.decode(z, neg_edge_index.to(device))

    logits = torch.cat([pos_scores, neg_scores])
    labels = torch.cat([
        torch.ones(pos_scores.size(0), device=device),
        torch.zeros(neg_scores.size(0), device=device),
    ])

    loss = criterion(logits, labels).item()
    predicted_positive = torch.sigmoid(logits) > 0.5
    accuracy = predicted_positive.eq(labels).sum().item() / labels.size(0)
    return loss, accuracy

In [26]:
# Train for 100 epochs, reporting validation and test metrics after each one.
for epoch in range(1, 101):
    loss = train()
    val_loss, val_acc = test(data.val_pos_edge_index, data.val_neg_edge_index)
    test_loss, test_acc = test(data.test_pos_edge_index, data.test_neg_edge_index)
    print(
        f'Epoch: {epoch:03d}, Loss: {loss:.4f}, '
        f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}, '
        f'Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}'
    )

Epoch: 001, Loss: 2.5447, Val Loss: 0.6718, Val Acc: 0.5227, Test Loss: 0.6670, Test Acc: 0.5207
Epoch: 002, Loss: 1.2956, Val Loss: 0.6217, Val Acc: 0.5515, Test Loss: 0.6222, Test Acc: 0.5476
Epoch: 003, Loss: 1.1953, Val Loss: 0.6397, Val Acc: 0.5357, Test Loss: 0.6418, Test Acc: 0.5327
Epoch: 004, Loss: 1.2357, Val Loss: 0.6419, Val Acc: 0.5286, Test Loss: 0.6417, Test Acc: 0.5236
Epoch: 005, Loss: 1.2407, Val Loss: 0.6295, Val Acc: 0.5260, Test Loss: 0.6270, Test Acc: 0.5242
Epoch: 006, Loss: 1.2147, Val Loss: 0.6191, Val Acc: 0.5272, Test Loss: 0.6158, Test Acc: 0.5276
Epoch: 007, Loss: 1.1918, Val Loss: 0.6114, Val Acc: 0.5323, Test Loss: 0.6087, Test Acc: 0.5345
Epoch: 008, Loss: 1.1777, Val Loss: 0.6031, Val Acc: 0.5413, Test Loss: 0.6020, Test Acc: 0.5430
Epoch: 009, Loss: 1.1608, Val Loss: 0.5945, Val Acc: 0.5504, Test Loss: 0.5956, Test Acc: 0.5515
Epoch: 010, Loss: 1.1470, Val Loss: 0.5876, Val Acc: 0.5629, Test Loss: 0.5907, Test Acc: 0.5637
Epoch: 011, Loss: 1.1334, Val 

In [27]:
# Fetch the original, unsplit graph (it still has the full `edge_index`) for
# the neighborhood statistics computed in the following cells.
gData = dataset[0]
print(gData)

Data(x=[7126, 128], edge_index=[2, 77774], y=[7126])


In [30]:
def get_neighbor_count(data, node_index):
    """Count the 1-, 2- and 3-hop neighborhood sizes of a node.

    Neighbors of a node `u` are the targets `edge_index[1]` of every edge
    whose source `edge_index[0]` equals `u`.

    Args:
        data: Object exposing `num_nodes` and a 2 x E `edge_index` tensor.
        node_index: Index of the node to inspect.

    Returns:
        A 3-tuple of ints:
        * number of first-level neighbor entries of `node_index`;
        * total number of neighbor entries over all first-level neighbors
          (a sum of per-neighbor counts, so repeated nodes - including
          `node_index` itself - are counted every time they appear);
        * number of distinct nodes reachable at the third level that are not
          already among the first- or second-level neighbors.

    Raises:
        ValueError: If `node_index` is outside `[0, data.num_nodes)`.
    """
    if node_index < 0 or node_index >= data.num_nodes:
        raise ValueError("exceed the dataset")

    sources, targets = data.edge_index[0], data.edge_index[1]

    # Each lookup is a full scan over the edge list, and the same node is often
    # reached through several paths, so remember every answer for this call.
    neighbor_cache = {}

    def _out_neighbors(node):
        node = int(node)
        if node not in neighbor_cache:
            neighbor_cache[node] = targets[sources == node].tolist()
        return neighbor_cache[node]

    first_level = _out_neighbors(node_index)
    second_level_total = 0
    first_and_second = set(first_level)
    third_level = set()

    # Single pass: accumulate the second-level count, the set of nodes to
    # exclude, and the third-level candidates together.
    for first in first_level:
        second_level = _out_neighbors(first)
        second_level_total += len(second_level)
        first_and_second.update(second_level)
        for second in second_level:
            third_level.update(_out_neighbors(second))

    # Keep only nodes that first appear at the third level.
    third_level.difference_update(first_and_second)

    return len(first_level), second_level_total, len(third_level)

In [37]:
# Rank nodes by the size of their neighborhood as reported by
# get_neighbor_count, which returns three counts per node:
# (first-level neighbors, neighbor-of-neighbor entries, new third-level nodes).
# Each result row stores (node index, sum of the three counts, the three counts).
#
# Only the first NUM_NODES_TO_SCAN node indices are inspected; the limit is
# capped at the graph size so a smaller graph cannot trigger the ValueError
# that get_neighbor_count raises for out-of-range indices.
NUM_NODES_TO_SCAN = min(1000, gData.num_nodes)

results = []
max_sum = 0
max_node = 0

for i in range(NUM_NODES_TO_SCAN):
    a, b, c = get_neighbor_count(gData, i)
    sum_neighbors = a + b + c
    results.append((i, sum_neighbors, (a, b, c)))
    # Strict comparison keeps the lowest node index when several nodes tie.
    if max_sum < sum_neighbors:
        max_sum = sum_neighbors
        max_node = i
    print(i, (a, b, c), sum_neighbors)

# Sort the results by the summed count, largest first.
results.sort(key=lambda x: x[1], reverse=True)

print("Sorted results based on sum:")
for result in results:
    print(f"Node {result[0]}: Sum {result[1]}")

print("Max node:", max_node, "Max sum:", max_sum)

0 (2, 6, 65) 73
1 (27, 801, 4459) 5287
2 (2, 339, 3116) 3457
3 (8, 133, 2477) 2618
4 (2, 81, 747) 830
5 (5, 833, 3718) 4556
6 (11, 194, 3063) 3268
7 (3, 13, 129) 145
8 (2, 13, 421) 436
9 (13, 1840, 3970) 5823
10 (2, 6, 204) 212
11 (5, 103, 763) 871
12 (3, 9, 766) 778
13 (13, 290, 3586) 3889
14 (11, 648, 3853) 4512
15 (6, 551, 3481) 4038
16 (6, 96, 1313) 1415
17 (7, 131, 2958) 3096
18 (3, 474, 1965) 2442
19 (10, 237, 1586) 1833
20 (8, 176, 2359) 2543
21 (8, 329, 2921) 3258
22 (3, 732, 2908) 3643
23 (59, 2283, 4706) 7048
24 (123, 5005, 4257) 9385
25 (2, 76, 1549) 1627
26 (92, 4983, 4005) 9080
27 (2, 5, 719) 726
28 (3, 80, 2996) 3079
29 (4, 392, 3688) 4084
30 (38, 1803, 4561) 6402
31 (3, 138, 2294) 2435
32 (10, 377, 3815) 4202
33 (4, 49, 1207) 1260
34 (12, 755, 4336) 5103
35 (14, 272, 1945) 2231
36 (7, 837, 3730) 4574
37 (5, 1161, 3445) 4611
38 (3, 798, 3173) 3974
39 (5, 134, 2729) 2868
40 (5, 810, 3746) 4561
41 (8, 191, 2435) 2634
42 (5, 77, 1134) 1216
43 (14, 333, 2637) 2984
44 (6, 114,

In [38]:
# Print the ranking again, this time with the per-level counts for each node.
print("Sorted results based on sum:")
for node, total, level_counts in results:
    print(f"Node {node}: Sum {total} : Values {level_counts}")

Sorted results based on sum:
Node 792: Sum 11355 : Values (147, 7867, 3341)
Node 166: Sum 11247 : Values (353, 6573, 4321)
Node 24: Sum 9385 : Values (123, 5005, 4257)
Node 93: Sum 9304 : Values (155, 5411, 3738)
Node 372: Sum 9176 : Values (94, 5363, 3719)
Node 26: Sum 9080 : Values (92, 4983, 4005)
Node 581: Sum 8966 : Values (253, 4521, 4192)
Node 877: Sum 8902 : Values (104, 4829, 3969)
Node 932: Sum 8529 : Values (84, 3825, 4620)
Node 779: Sum 8446 : Values (32, 4978, 3436)
Node 807: Sum 8369 : Values (63, 4514, 3792)
Node 675: Sum 8098 : Values (101, 3384, 4613)
Node 370: Sum 7873 : Values (64, 3722, 4087)
Node 560: Sum 7859 : Values (34, 4001, 3824)
Node 211: Sum 7648 : Values (66, 2961, 4621)
Node 952: Sum 7355 : Values (28, 3400, 3927)
Node 178: Sum 7344 : Values (80, 3203, 4061)
Node 793: Sum 7123 : Values (45, 3115, 3963)
Node 972: Sum 7093 : Values (59, 2957, 4077)
Node 938: Sum 7077 : Values (24, 2945, 4108)
Node 23: Sum 7048 : Values (59, 2283, 4706)
Node 485: Sum 7047 : 

In [None]:
def predict_edge(model, node_index1, node_index2):
    """Run `model` for a single candidate edge and return everything it outputs.

    The model is evaluated on the notebook-level `data` (node features and
    training edges) with a one-column `edge_label_index` holding the pair
    (`node_index1`, `node_index2`).

    Args:
        model: Module called as `model(x, edge_index, edge_label_index)`.
        node_index1: Index of the first endpoint.
        node_index2: Index of the second endpoint.

    Returns:
        Whatever `model` returns. For the `Net` defined in this notebook that
        is a dict of intermediate tensors (e.g. "decode_sum" holds the raw
        score of the requested pair), not a single probability.
    """
    model.eval()
    with torch.no_grad():
        edge_label_index = torch.tensor([[node_index1], [node_index2]], device=device)
        # Call the module itself rather than .forward() so that any registered
        # forward hooks are honoured.
        return model(data.x.to(device), data.train_pos_edge_index.to(device), edge_label_index)

# Example node pair; not referenced by any of the calls visible in this notebook.
node_index1 = 45
node_index2 = 17

# Disabled sweep over all ordered pairs of the first 20 nodes. Note that
# predict_edge returns the model's full output dict rather than a probability,
# so the message below would print that whole dict.
# for i in range(20):
#     for j in range(20):
#         if i != j:
#             prediction = predict_edge(model, i, j)
#             print(f"predict node {i} and node {j} probability that has an edge: {prediction}")

In [None]:


# Run the model for the candidate edge (12, 70) and print the 2 x K index
# tensor of all node pairs whose dot-product score is positive.
prediction = predict_edge(model, 12, 70)
print(prediction['decode_all_final'])

tensor([[   0,    0,    0,  ..., 7125, 7125, 7125],
        [   0,    2,    3,  ..., 7118, 7122, 7125]])


In [None]:
# Report the shape of each tensor in the output dict.
# The f-strings use double quotes on the outside so that the single-quoted
# dict keys inside the braces also parse on Python versions before 3.12.
# The printed labels are kept as they were: "conv3" is the "decode_mul"
# tensor, and the two "final" lines are "prob_adj" and "decode_all_final".
print(f"conv1 shape: {list(prediction['conv1'].shape)}")
print(f"conv2 shape: {list(prediction['conv2'].shape)}")
print(f"conv3 shape: {list(prediction['decode_mul'].shape)}")
print(f"final shape: {list(prediction['prob_adj'].shape)}")
print(f"final shape: {list(prediction['decode_all_final'].shape)}")

conv1 shape: [7126, 64]
conv2 shape: [7126, 64]
conv3 shape: [1, 64]
final shape: [7126, 7126]
final shape: [2, 23741230]


In [None]:
# Dummy inputs with the same shapes as the Twitch EN graph printed earlier
# (7126 nodes x 128 features, 77774 edges); they serve as example inputs for
# the ONNX export. They are created on `device` so they live where the model's
# parameters are - CPU tensors would fail when the model is on CUDA or MPS.
x = torch.randn(7126, 128, device=device)
edge_index = torch.randint(0, 7126, (2, 77774), device=device)
# The same random edge list doubles as the set of candidate edges to score.
dummy_input = (x, edge_index, edge_index)
prediction = model(*dummy_input)


In [None]:
# Dump the raw output dict produced for the dummy inputs (large: every value
# is a full tensor).
print(prediction)

{'conv1': tensor([[0.0000, 0.0401, 0.1980,  ..., 0.0000, 0.0000, 0.0000],
        [0.6730, 0.0000, 0.0000,  ..., 0.0383, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.1520],
        ...,
        [0.5915, 0.0325, 0.0000,  ..., 0.0000, 0.4006, 0.0000],
        [0.5642, 0.0913, 0.0377,  ..., 0.1975, 0.0000, 0.0000],
        [1.1327, 0.0702, 0.0000,  ..., 1.1574, 0.0350, 0.0000]],
       grad_fn=<ReluBackward0>), 'conv2': tensor([[-0.0646,  0.0085, -0.0156,  ..., -0.0884,  0.1009, -0.0490],
        [ 0.1263, -0.1882, -0.1733,  ..., -0.0648, -0.1400,  0.0913],
        [-0.0656, -0.2094, -0.1615,  ..., -0.0873, -0.0128, -0.0054],
        ...,
        [-0.0555, -0.2680, -0.0826,  ...,  0.0478, -0.0152,  0.0320],
        [-0.2204, -0.0187, -0.1971,  ..., -0.2343,  0.0632, -0.1698],
        [ 0.0695, -0.1956, -0.2670,  ..., -0.0942, -0.1813,  0.1866]],
       grad_fn=<AddBackward0>), 'decode_mul': tensor([[ 9.4007e-04,  5.2218e-02,  3.7268e-02,  ...,  7.5149e-03,
   

In [None]:
# Export the PyTorch model to ONNX, tracing it with the dummy inputs.
# Every key in `dynamic_axes` must be one of the names in `input_names` /
# `output_names`; each named axis is left variable-sized in the exported graph
# so the model is not tied to the dummy tensors' node and edge counts.
torch.onnx.export(
    model,                    # model being run
    dummy_input,              # example (x, edge_index, edge_label_index) inputs
    "gnn_link_model.onnx",    # where to save the model
    export_params=True,       # store the trained parameter weights inside the model file
    opset_version=17,         # the ONNX opset version to export the model to
    # do_constant_folding=True,  # whether to execute constant folding for optimization
    input_names=['x', 'edge_index', 'edge_label_index'],
    output_names=['conv1', 'conv2', 'decode_mul', 'decode_sum', 'prob_adj', 'decode_all_final'],
    dynamic_axes={
        # inputs
        'x': {0: 'num_nodes'},
        'edge_index': {1: 'num_edges'},
        'edge_label_index': {1: 'num_label_edges'},
        # outputs
        'conv1': {0: 'num_nodes'},
        'conv2': {0: 'num_nodes'},
        'decode_mul': {0: 'num_label_edges'},
        'decode_sum': {0: 'num_label_edges'},
        'prob_adj': {0: 'num_nodes', 1: 'num_nodes'},
        'decode_all_final': {1: 'num_positive_pairs'},
    },
)



In [None]:
# export data from dataset
import json, torch
def data_to_json(data):
    """Convert a graph object into a JSON-serialisable dict.

    Args:
        data: Object exposing `x`, `edge_index` and `y` attributes, each
            either a tensor or None. `num_nodes` is consulted only when `x`
            is None.

    Returns:
        dict with the keys 'x', 'edge_index' and 'y' for every attribute that
        is not None (nested Python lists produced by `tolist()`), plus
        'batch': a list containing one 0 per node.
    """
    json_data = {}

    # Node features as a list of per-node lists.
    if data.x is not None:
        json_data['x'] = data.x.tolist()

    # Edge index as two parallel lists (sources, targets).
    if data.edge_index is not None:
        json_data['edge_index'] = data.edge_index.tolist()

    # Labels as a flat list.
    if data.y is not None:
        json_data['y'] = data.y.tolist()

    # `x` is optional above, so do not assume it exists when counting nodes.
    if data.x is not None:
        num_nodes = data.x.size(0)
    else:
        num_nodes = int(getattr(data, 'num_nodes', 0) or 0)

    # Every node is assigned to batch 0.
    json_data['batch'] = [0] * num_nodes

    return json_data

# Serialise the single graph held by the dataset. data_to_json reads per-graph
# attributes (x, edge_index, y), so pass the graph itself rather than the
# dataset container.
json_data = data_to_json(dataset[0])
with open('twitch.json', 'w') as f:
    json.dump(json_data, f, indent=4)