In [1]:
from torch_geometric.datasets import Twitch
import os.path as osp

import torch
from sklearn.metrics import roc_auc_score

import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv
from torch_geometric.utils import negative_sampling

In [2]:
if torch.cuda.is_available():
    device = torch.device('cuda')
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')

In [3]:
dataset = Twitch(root='data/Twitch', name='EN')
print(dataset[0])
#
print()
print(f'Dataset: {dataset}:')
print('====================')
print(f'Number of graphs: {len(dataset)}')
print(f'Number of features: {dataset.num_features}')
print(f'Number of classes: {dataset.num_classes}')

Downloading https://graphmining.ai/datasets/ptg/twitch/EN.npz


Data(x=[7126, 128], edge_index=[2, 77774], y=[7126])

Dataset: Twitch():
Number of graphs: 1
Number of features: 128
Number of classes: 2


Processing...
Done!


In [11]:
from torch_geometric.utils import train_test_split_edges


data = train_test_split_edges(dataset[0])

print('Train edges:', data.train_pos_edge_index.size(1))
print('Validation edges (positive):', data.val_pos_edge_index.size(1))
print('Validation edges (negative):', data.val_neg_edge_index.size(1))
print('Test edges (positive):', data.test_pos_edge_index.size(1))
print('Test edges (negative):', data.test_neg_edge_index.size(1))

print(data)



Train edges: 60052
Validation edges (positive): 1766
Validation edges (negative): 1766
Test edges (positive): 3532
Test edges (negative): 3532
Data(x=[7126, 128], y=[7126], val_pos_edge_index=[2, 1766], test_pos_edge_index=[2, 3532], train_pos_edge_index=[2, 60052], train_neg_adj_mask=[7126, 7126], val_neg_edge_index=[2, 1766], test_neg_edge_index=[2, 3532])


In [15]:
class Net(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def encode(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        return self.conv2(x, edge_index)

    def decode(self, z, edge_label_index):
        return (z[edge_label_index[0]] * z[edge_label_index[1]]).sum(dim=-1)

    def decode_all(self, z):
        prob_adj = z @ z.t()
        return (prob_adj > 0).nonzero(as_tuple=False).t()


model = Net(dataset.num_features, 64, 32).to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)
criterion = torch.nn.BCEWithLogitsLoss()
print(device)

cpu


In [16]:
def train():
    model.train()
    optimizer.zero_grad()
    z = model.encode(data.x.to(device), data.train_pos_edge_index.to(device))

    pos_edge_index = data.train_pos_edge_index.to(device)
    pos_out = model.decode(z, pos_edge_index)
    pos_loss = criterion(pos_out, torch.ones(pos_out.size(0), device=device))

    neg_edge_index = negative_sampling(
        edge_index=data.train_pos_edge_index, 
        num_nodes=data.num_nodes,
        num_neg_samples=pos_edge_index.size(1)
    ).to(device)
    neg_out = model.decode(z, neg_edge_index)
    neg_loss = criterion(neg_out, torch.zeros(neg_out.size(0), device=device))

    loss = pos_loss + neg_loss
    loss.backward()
    optimizer.step()
    return loss.item()

In [13]:
def test(pos_edge_index, neg_edge_index):
    model.eval()
    with torch.no_grad():
        z = model.encode(data.x.to(device), data.train_pos_edge_index.to(device))
    
    pos_out = model.decode(z, pos_edge_index.to(device))
    neg_out = model.decode(z, neg_edge_index.to(device))

    pos_y = torch.ones(pos_out.size(0), device=device)
    neg_y = torch.zeros(neg_out.size(0), device=device)
    y = torch.cat([pos_y, neg_y])
    pred = torch.cat([pos_out, neg_out])

    loss = criterion(pred, y).item()
    pred = pred > 0
    acc = pred.eq(y).sum().item() / y.size(0)
    return loss, acc

In [17]:
for epoch in range(1, 201):
    loss = train()
    val_loss, val_acc = test(data.val_pos_edge_index, data.val_neg_edge_index)
    test_loss, test_acc = test(data.test_pos_edge_index, data.test_neg_edge_index)
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}, Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}')

Epoch: 001, Loss: 1.6221, Val Loss: 0.6134, Val Acc: 0.5770, Test Loss: 0.6135, Test Acc: 0.5810
Epoch: 002, Loss: 1.1809, Val Loss: 0.6419, Val Acc: 0.5589, Test Loss: 0.6518, Test Acc: 0.5534
Epoch: 003, Loss: 1.2422, Val Loss: 0.6205, Val Acc: 0.5603, Test Loss: 0.6250, Test Acc: 0.5617
Epoch: 004, Loss: 1.1881, Val Loss: 0.6058, Val Acc: 0.5807, Test Loss: 0.6041, Test Acc: 0.5858
Epoch: 005, Loss: 1.1464, Val Loss: 0.6059, Val Acc: 0.5943, Test Loss: 0.6018, Test Acc: 0.6070
Epoch: 006, Loss: 1.1417, Val Loss: 0.6001, Val Acc: 0.6065, Test Loss: 0.5971, Test Acc: 0.6179
Epoch: 007, Loss: 1.1275, Val Loss: 0.5894, Val Acc: 0.6104, Test Loss: 0.5895, Test Acc: 0.6230
Epoch: 008, Loss: 1.1036, Val Loss: 0.5815, Val Acc: 0.6175, Test Loss: 0.5859, Test Acc: 0.6217
Epoch: 009, Loss: 1.0864, Val Loss: 0.5785, Val Acc: 0.6203, Test Loss: 0.5868, Test Acc: 0.6244
Epoch: 010, Loss: 1.0742, Val Loss: 0.5750, Val Acc: 0.6305, Test Loss: 0.5847, Test Acc: 0.6290
Epoch: 011, Loss: 1.0602, Val 

In [18]:
gData = dataset[0]
print(gData)

Data(x=[7126, 128], edge_index=[2, 77774], y=[7126])


In [21]:
def get_neighbor_count(data, node_index):
    if node_index < 0 or node_index >= data.num_nodes:
        raise ValueError("exceed the dataset")
    edge_index = data.edge_index
    neighbors = edge_index[1][edge_index[0] == node_index]
    neighbor_of_neighbor_count = 0
    for neighbor in neighbors:
        second_neighbors = edge_index[1][edge_index[0] == neighbor]
        neighbor_of_neighbor_count += second_neighbors.size(0)
    return neighbors.size(0), neighbor_of_neighbor_count

In [24]:
# print out the size of the size graph (the neighbors of the target node, the neibhors of the neighbors node)
for i in range(100):
    print(get_neighbor_count(gData, i))

(2, 6)
(27, 801)
(2, 339)
(8, 133)
(2, 81)
(5, 833)
(11, 194)
(3, 13)
(2, 13)
(13, 1840)
(2, 6)
(5, 103)
(3, 9)
(13, 290)
(11, 648)
(6, 551)
(6, 96)
(7, 131)
(3, 474)
(10, 237)
(8, 176)
(8, 329)
(3, 732)
(59, 2283)
(123, 5005)
(2, 76)
(92, 4983)
(2, 5)
(3, 80)
(4, 392)
(38, 1803)
(3, 138)
(10, 377)
(4, 49)
(12, 755)
(14, 272)
(7, 837)
(5, 1161)
(3, 798)
(5, 134)
(5, 810)
(8, 191)
(5, 77)
(14, 333)
(6, 114)
(2, 23)
(10, 1675)
(2, 46)
(11, 169)
(6, 90)
(2, 16)
(3, 11)
(7, 93)
(10, 219)
(3, 25)
(15, 307)
(15, 1519)
(14, 284)
(13, 1559)
(4, 32)
(5, 112)
(6, 129)
(5, 772)
(3, 33)
(5, 1416)
(23, 943)
(3, 724)
(3, 12)
(5, 448)
(6, 121)
(4, 20)
(2, 5)
(3, 81)
(28, 395)
(5, 79)
(19, 950)
(3, 42)
(6, 367)
(5, 46)
(2, 7)
(18, 173)
(18, 541)
(4, 75)
(7, 288)
(3, 282)
(13, 368)
(9, 276)
(2, 723)
(7, 1048)
(6, 442)
(3, 250)
(11, 541)
(2, 10)
(155, 5411)
(25, 2801)
(6, 879)
(8, 210)
(5, 60)
(3, 85)
(11, 1538)
