In [None]:
! pip install --upgrade torch_geometric ogb
! python -c "import ogb; print(ogb.__version__)"

from ogb.linkproppred import PygLinkPropPredDataset
from torch_geometric.loader import DataLoader

# Download and process data, stored in './dataset/ogbl_collab/'
dataset = PygLinkPropPredDataset(name="ogbl-collab", root='dataset/')

Downloading http://snap.stanford.edu/ogb/data/linkproppred/collab.zip


Downloaded 0.11 GB: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 117/117 [00:06<00:00, 19.18it/s]


Extracting dataset/collab.zip


Processing...


Loading necessary files...
This might take a while.
Processing graphs...


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 32.49it/s]


Converting graphs into PyG objects...


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 222.62it/s]

Saving...



Done!


Dataset Name: ogbl-collab
Number of Nodes: 235868
Number of Edges: 2358104
Number of Node Features: 128
Number of Edge Features: None


In [None]:
data = dataset[0]

print("Dataset Name:", dataset.name)
print("Number of Nodes:", data.num_nodes)
print("Number of Edges:", data.edge_index.shape[1])
print("Number of Node Features:", data.x.shape[1] if data.x is not None else "None")
print("Number of Edge Features:", data.edge_attr.shape[1] if data.edge_attr is not None else "None")

In [3]:
# Get split edges from the dataset
split_edge = dataset.get_edge_split()

# Extract edge lists for train, valid, and test splits
train_edges = split_edge["train"]['edge']  # List of training edges
valid_edges = split_edge["valid"]['edge']  # List of validation edges
test_edges = split_edge["test"]['edge']    # List of test edges

from torch_geometric.data import Data

# Convert to Data objects if necessary
train_data = Data(edge_index=train_edges.T)  # Transpose to match PyG format
valid_data = Data(edge_index=valid_edges.T)
test_data = Data(edge_index=test_edges.T)

# Create DataLoaders
train_loader = DataLoader([train_data], batch_size=32, shuffle=True)
valid_loader = DataLoader([valid_data], batch_size=32, shuffle=False)
test_loader = DataLoader([test_data], batch_size=32, shuffle=False)

In [None]:
import torch
from torch_geometric.nn import GCNConv
from torch_geometric.utils import negative_sampling
from torch_geometric.data import Data
from torch.utils.data import DataLoader

# Define a simple GCN-based link prediction model
class GCNLinkPredictor(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super(GCNLinkPredictor, self).__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)
    
    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index)
        return x
    
    def predict(self, x_i, x_j):
        # Dot product to predict link probability
        return (x_i * x_j).sum(dim=-1)

# Load graph data
data = dataset[0]  # PyG graph object
in_channels = data.num_features
hidden_channels = 64
out_channels = 32

# Instantiate the model and optimizer
model = GCNLinkPredictor(in_channels, hidden_channels, out_channels)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# Training loop
def train():
    model.train()
    optimizer.zero_grad()
    
    # Forward pass
    x = model(data.x, data.edge_index)
    
    # Positive edges
    pos_edge_index = train_edges.T
    
    # Sample negative edges
    neg_edge_index = negative_sampling(pos_edge_index, num_nodes=data.num_nodes)
    
    # Compute link prediction scores for positive and negative edges
    pos_pred = model.predict(x[pos_edge_index[0]], x[pos_edge_index[1]])
    neg_pred = model.predict(x[neg_edge_index[0]], x[neg_edge_index[1]])
    
    # Labels for positive and negative edges
    pos_label = torch.ones(pos_pred.size(0))
    neg_label = torch.zeros(neg_pred.size(0))
    
    # Concatenate predictions and labels
    pred = torch.cat([pos_pred, neg_pred], dim=0)
    label = torch.cat([pos_label, neg_label], dim=0)
    
    # Loss
    loss = torch.nn.functional.binary_cross_entropy_with_logits(pred, label)
    loss.backward()
    optimizer.step()
    return loss.item()

# Training the model
for epoch in range(100):
    loss = train()
    if epoch % 10 == 0:
        print(f'Epoch {epoch}, Loss: {loss}')

# Evaluate on validation or test set
def evaluate(edge_index):
    model.eval()
    with torch.no_grad():
        x = model(data.x, data.edge_index)
        pos_pred = model.predict(x[edge_index[0]], x[edge_index[1]])
        return torch.sigmoid(pos_pred)

# Predict on validation and test edges
valid_scores = evaluate(valid_edges.T)
test_scores = evaluate(test_edges.T)

print("Validation Scores:", valid_scores)
print("Test Scores:", test_scores)


Epoch 0, Loss: 0.6763266324996948
Epoch 10, Loss: 0.5770453214645386
Epoch 20, Loss: 0.5569776296615601
Epoch 30, Loss: 0.5353121757507324
Epoch 40, Loss: 0.4935433864593506
Epoch 50, Loss: 0.48045605421066284
Epoch 60, Loss: 0.4622229337692261
Epoch 70, Loss: 0.45671918988227844
Epoch 80, Loss: 0.4531313478946686
Epoch 90, Loss: 0.44962701201438904
Validation Scores: tensor([1.0000, 0.8845, 0.7879,  ..., 0.9831, 0.8556, 0.6998])
Test Scores: tensor([0.0314, 0.8364, 0.7879,  ..., 0.9927, 0.8873, 0.8494])


In [5]:
import matplotlib.pyplot as plt
import networkx as nx
from torch_geometric.utils import to_networkx

# Convert to NetworkX graph for visualization
G = to_networkx(data, to_undirected=True)

# Plot original graph with nodes and edges
plt.figure(figsize=(10, 10))
pos = nx.spring_layout(G, seed=42)  # Use spring layout for visualization

# Draw original graph nodes and edges
nx.draw_networkx_nodes(G, pos, node_size=50, node_color="blue", alpha=0.7)
nx.draw_networkx_edges(G, pos, alpha=0.5)

# Visualize the predicted links with high confidence
def visualize_predicted_links(edge_index, scores, threshold=0.5):
    for i, (u, v) in enumerate(edge_index.T):
        # Check if the score exceeds the threshold
        if scores[i] > threshold:
            # Draw the edge if the score is high
            nx.draw_networkx_edges(
                G, pos, edgelist=[(u.item(), v.item())], edge_color="red", alpha=0.3, width=2
            )

# Assume `valid_scores` and `test_scores` contain predicted probabilities for links
# Visualize links with high confidence in validation and test sets
visualize_predicted_links(valid_edges.T, valid_scores, threshold=0.8)
visualize_predicted_links(test_edges.T, test_scores, threshold=0.8)

# Show plot
plt.title("Graph Visualization with Predicted Links")
plt.show()


KeyboardInterrupt: 

<Figure size 1000x1000 with 0 Axes>