<a href="https://colab.research.google.com/github/Sayed-Hossein-Hosseini/Node_Classification_in_the_Amazon_Product_Graph/blob/master/Edge_Prediction_in_the_Amazon_Product_Graph.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Edge Prediction in the Amazon Product Graph**

## **Libraries**

In [1]:
!pip install torch torchvision torchaudio
!pip install torch-geometric ogb
!pip install matplotlib





In [2]:
import random

import torch
import torch.nn as nn
from torch.serialization import add_safe_globals
from torch_geometric.data.data import Data
from torch_geometric.data.data import DataEdgeAttr
from torch_geometric.data import Data
from torch_geometric.transforms import ToUndirected, RandomLinkSplit
from ogb.nodeproppred import NodePropPredDataset
from torch_geometric.loader import NeighborLoader
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, SAGEConv
from torch_geometric.utils import train_test_split_edges, negative_sampling
from sklearn.metrics import average_precision_score, f1_score, roc_auc_score, confusion_matrix, ConfusionMatrixDisplay
import matplotlib.pyplot as plt

In [3]:
# Allow-list the PyG classes for torch's weights-only unpickler.
# add_safe_globals expects a list of classes/callables; passing a dict would
# register its string keys instead of the classes themselves.
add_safe_globals([Data, DataEdgeAttr])

## **Loading the ogbn-products Dataset**

In [4]:
# Load ogbn-products through OGB (downloaded and processed on first use).
dataset = NodePropPredDataset(name='ogbn-products')

# The dataset's single entry is a (graph dict, labels) pair.
graph, labels = dataset[0]

# Split indices shipped with the dataset.
split_idx = dataset.get_idx_split()

print(graph)

Downloading http://snap.stanford.edu/ogb/data/nodeproppred/products.zip


Downloaded 1.38 GB: 100%|██████████| 1414/1414 [00:24<00:00, 58.71it/s]


Extracting dataset/products.zip
Loading necessary files...
This might take a while.
Processing graphs...


100%|██████████| 1/1 [00:01<00:00,  1.02s/it]


Saving...
{'edge_index': array([[      0,  152857,       0, ..., 2449028,   53324, 2449028],
       [ 152857,       0,   32104, ...,  162836, 2449028,   53324]]), 'edge_feat': None, 'node_feat': array([[ 0.03193326, -0.1958605 ,  0.0519961 , ...,  0.07669606,
        -0.3929545 , -0.06478424],
       [-0.02405794,  0.63032097,  1.0605699 , ..., -1.6874819 ,
         3.5866776 ,  0.818219  ],
       [ 0.33269015, -0.5585958 , -0.28860757, ..., -0.37157044,
         0.2520575 ,  0.04153213],
       ...,
       [ 0.10660695,  0.2654852 , -0.00567423, ...,  1.0867023 ,
         0.07590195, -1.1736895 ],
       [ 0.24968362, -0.25740346,  0.41230008, ...,  1.5465808 ,
         1.0309792 , -0.29657176],
       [ 0.7175324 , -0.23930131,  0.04430327, ..., -1.0132493 ,
        -0.41407427, -0.08227058]], dtype=float32), 'num_nodes': 2449029}


## **Preprocessing and Preparing the Graph Data**

In [5]:
# Convert the NumPy arrays from the OGB graph dict into torch tensors:
# float node features and long (int64) edge indices.
x = torch.tensor(graph['node_feat'], dtype=torch.float)
edge_index = torch.tensor(graph['edge_index'], dtype=torch.long)

# Wrap both tensors in a single PyG Data object.
data = Data(edge_index=edge_index, x=x)

print(data)

Data(x=[2449029, 100], edge_index=[2, 123718280])


## **Making the Graph Undirected**

In [6]:
# Convert the graph to undirected (i.e. (i→j) + (j→i)) with PyG's ToUndirected
# transform. `data` is rebound to the transformed graph; the printed edge count
# differs slightly from the one shown before the transform.
data = ToUndirected()(data)
print(data)

Data(x=[2449029, 100], edge_index=[2, 123718152])


## **GCN and GraphSAGE Models**

In [None]:
# GCN encoder
class GCN(torch.nn.Module):
    """Two-layer GCNConv encoder that maps node features to node embeddings.

    Args:
        in_channels: Size of each input node feature vector.
        hidden_channels: Output size of the first convolution.
        out_channels: Size of the returned node embeddings.
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        """Return node embeddings: conv1 -> ReLU -> conv2 (no final activation)."""
        hidden = F.relu(self.conv1(x, edge_index))
        return self.conv2(hidden, edge_index)


# GraphSAGE encoder
class GraphSAGE(torch.nn.Module):
    """Two-layer SAGEConv encoder that maps node features to node embeddings.

    Args:
        in_channels: Size of each input node feature vector.
        hidden_channels: Output size of the first convolution.
        out_channels: Size of the returned node embeddings.
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = SAGEConv(in_channels, hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        """Return node embeddings: conv1 -> ReLU -> conv2 (no final activation)."""
        hidden = F.relu(self.conv1(x, edge_index))
        return self.conv2(hidden, edge_index)

## **Link Predictor**

In [None]:
class MLPLinkPredictor(torch.nn.Module):
    """MLP that scores candidate edges from the embeddings of their endpoints.

    The source and destination embeddings of each candidate edge are
    concatenated and passed through `num_layers` linear layers; every layer
    except the last is followed by ReLU and dropout.

    Args:
        in_channels: Size of a single node embedding (the MLP input is twice this).
        hidden_channels: Width of the hidden layers.
        num_layers: Total number of linear layers, including the output layer.
            With `num_layers <= 1` the model is a single linear layer.
        dropout: Dropout probability applied after each hidden layer.
    """

    def __init__(self, in_channels, hidden_channels=128, num_layers=3, dropout=0.3):
        super().__init__()
        layers = []
        # Track the width feeding the next layer so the output layer is sized
        # correctly even when there are no hidden layers (num_layers == 1).
        width = in_channels * 2  # concatenated features
        for _ in range(num_layers - 1):
            layers.append(nn.Linear(width, hidden_channels))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(dropout))
            width = hidden_channels
        layers.append(nn.Linear(width, 1))  # final score
        self.mlp = nn.Sequential(*layers)

    def forward(self, z, edge_label_index):
        """Return one raw score (logit) per column of `edge_label_index`.

        Args:
            z: Node embeddings, indexed by node id along the first dimension.
            edge_label_index: Index tensor whose row 0 holds source node ids
                and row 1 holds destination node ids.
        """
        z_src = z[edge_label_index[0]]
        z_dst = z[edge_label_index[1]]
        z_pair = torch.cat([z_src, z_dst], dim=-1)
        return self.mlp(z_pair).squeeze(-1)

## **Evaluate Encoder and Link Predictor**

In [None]:
@torch.no_grad()
def evaluate_link_prediction(encoder, link_predictor, data):
    """Compute ROC-AUC and average precision for the train/val/test splits.

    Args:
        encoder: Module mapping `(x, edge_index)` to node embeddings.
        link_predictor: Module mapping `(z, edge_label_index)` to one score per
            candidate edge.
        data: Mapping with keys 'train', 'val' and 'test'; each value exposes
            `x`, `edge_index`, `edge_label_index` and `edge_label`, the same
            layout the training loop in this notebook consumes.

    Returns:
        dict: `{split: {'AUC': float, 'AP': float}}`.
    """
    encoder.eval()
    link_predictor.eval()

    results = {}
    for split in ['train', 'val', 'test']:
        split_data = data[split]

        # Encode with the split's own message-passing graph, as the training
        # loop does, rather than one graph shared across all splits.
        z = encoder(split_data.x, split_data.edge_index)
        pred = link_predictor(z, split_data.edge_label_index)
        label = split_data.edge_label.float()

        # Both metrics depend only on the ranking of the scores, so the raw
        # predictor outputs are used without a sigmoid.
        auc = roc_auc_score(label.cpu(), pred.cpu())
        ap = average_precision_score(label.cpu(), pred.cpu())
        results[split] = {'AUC': auc, 'AP': ap}

    return results

## **Train Link Predictor**

In [None]:
def train_link_prediction_model(encoder, link_predictor, data, epochs=20):
    """Jointly train an encoder and a link predictor, printing AUC every epoch.

    Both modules and every split in `data` are moved to the module-level
    `device`. Each epoch performs one full-batch optimisation step on the
    'train' split with binary cross-entropy on logits, then reports ROC-AUC on
    the 'train', 'val' and 'test' splits.

    Args:
        encoder: Module mapping `(x, edge_index)` to node embeddings.
        link_predictor: Module mapping `(z, edge_label_index)` to edge logits.
        data: Mapping with keys 'train', 'val' and 'test'; each value exposes
            `x`, `edge_index`, `edge_label_index`, `edge_label` and `.to(device)`.
        epochs: Number of training epochs.

    Returns:
        dict: `{split: AUC}` from the last epoch (empty if `epochs` is 0).
    """
    encoder.to(device)
    link_predictor.to(device)
    data = {split: split_data.to(device) for split, split_data in data.items()}

    optimizer = torch.optim.Adam(
        list(encoder.parameters()) + list(link_predictor.parameters()),
        lr=0.01,
        weight_decay=1e-4
    )
    criterion = torch.nn.BCEWithLogitsLoss()

    results = {}
    for epoch in range(1, epochs + 1):
        encoder.train()
        link_predictor.train()

        optimizer.zero_grad()

        z = encoder(data['train'].x, data['train'].edge_index)

        pred = link_predictor(z, data['train'].edge_label_index)
        label = data['train'].edge_label.float()

        loss = criterion(pred, label)
        loss.backward()
        optimizer.step()

        # Evaluate
        encoder.eval()
        link_predictor.eval()

        results = {}
        # The encoder pass is inside no_grad as well, so evaluation does not
        # build an autograd graph for every split on every epoch.
        with torch.no_grad():
            for split in ['train', 'val', 'test']:
                z = encoder(data[split].x, data[split].edge_index)
                pred = link_predictor(z, data[split].edge_label_index)
                score = torch.sigmoid(pred).cpu().numpy()
                label = data[split].edge_label.cpu().numpy()
                results[split] = roc_auc_score(label, score)

        print(f"[Epoch {epoch:02d}] Loss: {loss.item():.4f} | "
              f"AUC - Train: {results['train']:.4f}, "
              f"Val: {results['val']:.4f}, "
              f"Test: {results['test']:.4f}")

    # Callers assign the return value, so hand back the final-epoch metrics.
    return results

## **Train Model + Link Predictor**

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# The training and confusion-matrix cells below read `data_splits`, which no
# other cell defines. Build it here with RandomLinkSplit: each split carries
# its own message-passing `edge_index` plus labelled positive/negative
# candidate edges (`edge_label_index`, `edge_label`).
# NOTE(review): the 10% / 10% validation/test ratios and the seed are a choice
# made here; adjust as needed.
torch.manual_seed(42)  # make the random split and negative sampling repeatable
train_data, val_data, test_data = RandomLinkSplit(
    num_val=0.1,
    num_test=0.1,
    is_undirected=True,  # the graph was made undirected above
    add_negative_train_samples=True,
)(data)
data_splits = {'train': train_data, 'val': val_data, 'test': test_data}

In [None]:
# Model Encoder: GCN
encoder_GCN = GCN(
    in_channels=data.num_node_features,
    hidden_channels=128,
    out_channels=128  # Size of the node embeddings
)

# Link Predictor: MLP over the concatenated endpoint embeddings.
# (`LinkPredictor` is not defined in this notebook; the class is MLPLinkPredictor.)
# It is also bound to a GCN-specific name because the GraphSAGE cell below
# rebinds `link_predictor`.
link_predictor_GCN = MLPLinkPredictor(in_channels=128)  # must match the encoder's out_channels
link_predictor = link_predictor_GCN

# Training
print("\n🧠 Training GCN + LinkPredictor for Link Prediction:")
results = train_link_prediction_model(
    encoder_GCN, link_predictor, data_splits, epochs=100
)

In [None]:
# Model Encoder: GraphSAGE
encoder_SAGE = GraphSAGE(
    in_channels=data.num_node_features,
    hidden_channels=128,
    out_channels=128  # Size of the node embeddings
)

# Link Predictor: MLP over the concatenated endpoint embeddings.
# (`LinkPredictor` is not defined in this notebook; the class is MLPLinkPredictor.)
# A fresh predictor is created so GraphSAGE does not start from weights trained
# with another encoder; it is also kept under a SAGE-specific name.
link_predictor_SAGE = MLPLinkPredictor(in_channels=128)  # must match the encoder's out_channels
link_predictor = link_predictor_SAGE

# Training
print("\n🧠 Training GraphSAGE + LinkPredictor for Link Prediction:")
results = train_link_prediction_model(
    encoder_SAGE, link_predictor, data_splits, epochs=100
)

## **Confusion Matrix**

In [None]:
@torch.no_grad()
def plot_link_prediction_confusion_matrix(encoder, link_predictor, data_splits, split="test"):
    """Plot the confusion matrix of thresholded link predictions for one split.

    Args:
        encoder: Module mapping `(x, edge_index)` to node embeddings.
        link_predictor: Module mapping `(z, edge_label_index)` to edge logits.
        data_splits: Mapping from split name to an object exposing `x`,
            `edge_index`, `edge_label_index` and `edge_label`.
        split: Key of the split to evaluate.

    Candidate edges whose sigmoid score exceeds 0.5 are predicted positive.
    Tensors are moved to the module-level `device` before the forward pass.
    """
    encoder.eval()
    link_predictor.eval()

    split_data = data_splits[split]

    # Encode with the split's own message-passing graph (as the training loop
    # does) instead of the global full graph, which still contains the
    # held-out edges being scored.
    z = encoder(split_data.x.to(device), split_data.edge_index.to(device))
    edge_index = split_data.edge_label_index.to(device)
    edge_label = split_data.edge_label.to(device)

    preds = link_predictor(z, edge_index).sigmoid()
    pred_labels = (preds > 0.5).long().cpu().numpy()
    true_labels = edge_label.long().cpu().numpy()

    # Pin the label set so the matrix is always 2x2 and lines up with the two
    # display labels even if one class is absent from this split.
    cm = confusion_matrix(true_labels, pred_labels, labels=[0, 1])
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=["Negative", "Positive"])

    disp.plot(cmap='Greens')
    plt.title(f'Confusion Matrix for Link Prediction ({split.capitalize()} set)')
    plt.show()

### **GCN Model**

In [None]:
# Confusion matrix for the GCN encoder on the test split.
# NOTE(review): `link_predictor` is rebound in the GraphSAGE training cell, so
# when the notebook is run top to bottom this pairs the GCN encoder with the
# predictor trained alongside GraphSAGE. Keep a separate predictor per encoder
# and pass the GCN one here.
plot_link_prediction_confusion_matrix(encoder_GCN, link_predictor, data_splits, split="test")

### **SAGE Model**

In [None]:
# Confusion matrix for the GraphSAGE encoder and the most recently trained
# link predictor on the test split.
plot_link_prediction_confusion_matrix(
    encoder=encoder_SAGE,
    link_predictor=link_predictor,
    data_splits=data_splits,
    split="test",
)