# Link Prediction

GPU acceleration dengan CUDA

In [2]:
# Enable the nx-cugraph (CUDA) backend for NetworkX so supported algorithms are
# dispatched to the GPU. The variable is presumably read when `networkx` is first
# imported, so this cell must run before the imports cell — TODO confirm.
%env NX_CUGRAPH_AUTOCONFIG=True

env: NX_CUGRAPH_AUTOCONFIG=True


In [3]:
!pip install igraph

Collecting igraph
  Downloading igraph-0.11.8-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.8 kB)
Collecting texttable>=1.6.2 (from igraph)
  Downloading texttable-1.7.0-py2.py3-none-any.whl.metadata (9.8 kB)
Downloading igraph-0.11.8-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.1/3.1 MB[0m [31m43.8 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading texttable-1.7.0-py2.py3-none-any.whl (10 kB)
Installing collected packages: texttable, igraph
Successfully installed igraph-0.11.8 texttable-1.7.0


In [4]:
!pip install networkit

Collecting networkit
  Downloading networkit-11.1.post1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (14 kB)
Downloading networkit-11.1.post1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (11.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m11.0/11.0 MB[0m [31m72.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: networkit
Successfully installed networkit-11.1.post1


In [5]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import networkx as nx
import pickle
import random
import igraph as ig
import networkit as nk

from itertools import combinations
from sklearn.metrics import roc_auc_score, average_precision_score, f1_score, precision_score, recall_score
from sklearn.model_selection import train_test_split
from sklearn.feature_selection import mutual_info_classif
from tqdm import tqdm
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.naive_bayes import GaussianNB

## Dataset Preparation

In [6]:
# Load the pre-built Amazon co-purchase graph from its pickle file.
# NOTE(review): pickle.load can execute arbitrary code; only open trusted files.
pickle_file_path = 'dataset/amazon_copurchase_graph.pickle'

with open(file=pickle_file_path, mode='rb') as f:
    G = pickle.load(f)

print(G)

DiGraph with 259102 nodes and 1207337 edges


### Features

#### Node Features

In [7]:

from itertools import islice

# Overview of node attributes: total count, a few example nodes, and the attribute keys.
print(f"Total Nodes: {G.number_of_nodes()}")

# islice shows the first five nodes without materialising a list of every node.
for node, data in islice(G.nodes(data=True), 5):
    print(f"Node: {node}, Data: {data}")

print()
sample_node = next(iter(G.nodes(data=True)))[1]
# Collect keys across *all* nodes (first-seen order) instead of trusting the first
# node to carry every attribute.
node_feature_keys = list(dict.fromkeys(key for _, attrs in G.nodes(data=True) for key in attrs))
print("Node features:", node_feature_keys)

Total Nodes: 259102
Node: 1, Data: {'title': 'Patterns of Preaching: A Sermon Sampler', 'group': 'Book', 'salesrank': 396585.0, 'review_cnt': 2, 'downloads': 2, 'rating': 5.0, 'in_degree': 0, 'out_degree': 4, 'pagerank_centrality': 6.210153588242165e-07, 'betweenness_centrality': 0.0, 'harmonic_closeness_centrality': 0.1442557706580312, 'degree_centrality': 1.5437995221940477e-05, 'community': 10}
Node: 2, Data: {'title': 'Candlemas: Feast of Flames', 'group': 'Book', 'salesrank': 168596.0, 'review_cnt': 12, 'downloads': 12, 'rating': 4.5, 'in_degree': 1, 'out_degree': 4, 'pagerank_centrality': 7.560926314778459e-07, 'betweenness_centrality': 31563.672353370643, 'harmonic_closeness_centrality': 0.1444868764333364, 'degree_centrality': 1.92974940274256e-05, 'community': 10}
Node: 4, Data: {'title': 'Life Application Bible Commentary: 1 and 2 Timothy and Titus', 'group': 'Book', 'salesrank': 631289.0, 'review_cnt': 1, 'downloads': 1, 'rating': 4.0, 'in_degree': 24, 'out_degree': 5, 'page

Fitur-fitur dari node dalam graph ini meliputi:  

*   **`title`**:  
    *   **Tipe Data**: String (Teks)  
    *   **Deskripsi**: Nama atau judul produk. Fitur ini memberikan deskripsi tekstual tentang produk yang dimaksud.  
    *   **Contoh**: "Patterns of Preaching: A Sermon Sampler", "Candlemas: Feast of Flames", dll.  

*   **`group`**:  
    *   **Tipe Data**: String (Kategorikal)  
    *   **Deskripsi**: Kategori atau grup tempat produk tersebut berada. Fitur ini membantu dalam memahami jenis produk (misalnya, Buku, Musik, DVD, dll.).  
    *   **Contoh**: "Book"  

*   **`salesrank`**:  
    *   **Tipe Data**: Float  
    *   **Deskripsi**: Peringkat penjualan produk di Amazon. Semakin rendah nilai `salesrank`, semakin tinggi tingkat penjualan dan popularitasnya. Fitur ini sering digunakan untuk mengukur seberapa baik suatu produk terjual di Amazon.  
    *   **Contoh**: `396585.0`, `168596.0`, `1270652.0`, dll.  

*   **`review_cnt`**:  
    *   **Tipe Data**: Integer  
    *   **Deskripsi**: Jumlah ulasan pelanggan yang diterima oleh produk. Nilai `review_cnt` yang lebih tinggi bisa menunjukkan tingkat visibilitas produk yang lebih besar, popularitas yang lebih tinggi, atau keterlibatan pelanggan yang lebih banyak.  
    *   **Contoh**: `2`, `12`, `1`, `1`, `0`, dll.  

*   **`downloads`**:  
    *   **Tipe Data**: Integer  
    *   **Deskripsi**: Nama fitur ini bisa menyesatkan: pada semua contoh di atas nilainya sama persis dengan `review_cnt` (2 dan 2, 12 dan 12, 1 dan 1). Kemungkinan besar fitur ini berasal dari field `downloaded` pada metadata ulasan Amazon, yaitu jumlah ulasan yang berhasil diunduh/di-crawl untuk produk tersebut, bukan jumlah unduhan produk digital. Karena hampir identik dengan `review_cnt`, fitur ini kemungkinan redundan sebagai input model.  
    *   **Contoh**: `2`, `12`, `1`, `1`, `0`, dll.  

*   **`rating`**:  
    *   **Tipe Data**: Float  
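    *   **Deskripsi**: Rata-rata rating pelanggan terhadap produk, umumnya dalam skala 1 hingga 5 bintang. Nilai `0.0` tampaknya dipakai untuk produk yang belum memiliki ulasan (pada contoh, rating `0.0` muncul bersamaan dengan `review_cnt` = `0`), sehingga tidak berarti rating terendah. Fitur ini mencerminkan tingkat kepuasan pelanggan serta persepsi kualitas produk secara keseluruhan.  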
    *   **Contoh**: `5.0`, `4.5`, `5.0`, `4.0`, `0.0`, dll.   

*   **`in_degree`**:  
    *   **Tipe Data**: Integer  
    *   **Deskripsi**: Jumlah edge (sisi) yang masuk ke node ini. Menunjukkan seberapa banyak produk lain yang terhubung ke produk ini dalam graph. Dalam konteks dataset ini, bisa menunjukkan seberapa sering produk ini direferensikan oleh produk lain.  
    *   **Contoh**: `0`, `1`, `24`, `53`, `21`, dll.  

*   **`out_degree`**:  
    *   **Tipe Data**: Integer  
    *   **Deskripsi**: Jumlah edge (sisi) yang keluar dari node ini. Menunjukkan seberapa banyak produk lain yang direferensikan oleh produk ini.  
    *   **Contoh**: `4`, `4`, `5`, `5`, `5`, dll.  

*   **`pagerank_centrality`**:  
    *   **Tipe Data**: Float  
    *   **Deskripsi**: Skor PageRank node dalam graph. Metrik ini mengukur kepentingan sebuah node berdasarkan jumlah dan kualitas tautan yang mengarah ke node tersebut. Semakin tinggi nilainya, semakin berpengaruh node tersebut dalam jaringan.  
    *   **Contoh**: `6.21e-07`, `7.56e-07`, `1.34e-05`, dll.  

*   **`betweenness_centrality`**:  
    *   **Tipe Data**: Float  
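    *   **Deskripsi**: Mengukur seberapa sering sebuah node menjadi perantara dalam jalur terpendek antara dua node lainnya. Node dengan betweenness centrality tinggi berperan sebagai "jembatan" yang menghubungkan berbagai bagian dalam graph. Nilai pada dataset ini tidak dinormalisasi (misalnya `31563.67` atau `15442396.47`), sehingga merupakan jumlah (berbobot) jalur terpendek yang melewati node tersebut, bukan proporsi antara 0 dan 1.  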
    *   **Contoh**: `0.0`, `31563.67`, `6528478.27`, `15442396.47`, dll.  

*   **`harmonic_closeness_centrality`**:  
    *   **Tipe Data**: Float  
    *   **Deskripsi**: Versi alternatif dari closeness centrality yang menghitung seberapa dekat suatu node dengan node lain berdasarkan jarak harmonik. Makin tinggi nilainya, makin dekat node tersebut ke banyak node lain dalam graph.  
    *   **Contoh**: `0.1442`, `0.1444`, `0.1558`, `0.1658`, dll.  

*   **`degree_centrality`**:  
    *   **Tipe Data**: Float  
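    *   **Deskripsi**: Mengukur proporsi node lain yang terhubung dengan node ini dalam graph. Degree centrality dihitung sebagai jumlah total koneksi (in-degree + out-degree) node ini dibagi dengan jumlah maksimum koneksi yang mungkin, yaitu $C_D(v) = \dfrac{\deg_{in}(v) + \deg_{out}(v)}{n - 1}$ dengan $n$ jumlah node. Contoh: node 1 memiliki degree $0 + 4 = 4$, sehingga $C_D = 4 / 259101 \approx 1.54 \times 10^{-5}$.  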
    *   **Contoh**: `1.54e-05`, `1.92e-05`, `1.11e-04`, `2.23e-04`, dll.  

*   **`community`**:  
    *   **Tipe Data**: Integer (Kategorikal)  
    *   **Deskripsi**: Identitas komunitas tempat node ini tergabung, berdasarkan algoritma deteksi komunitas. Node dalam komunitas yang sama lebih cenderung saling terhubung dibandingkan dengan node di komunitas lain.  
    *   **Contoh**: `10`, `10`, `10`, `31`, dll.  

#### Edge Features

In [8]:
from itertools import islice

print(f"Total Edges: {G.number_of_edges()}")

# islice shows the first five edges without materialising all ~1.2M of them.
for u, v, data in islice(G.edges(data=True), 5):
    print(f"Edge: ({u}, {v}), Data: {data}")

sample_edge = next(iter(G.edges(data=True)))[2]
# Union of attribute keys over every edge, so the "no edge features" conclusion
# does not rest on the first edge alone.
edge_feature_keys = list(dict.fromkeys(key for _, _, attrs in G.edges(data=True) for key in attrs))
print("\nEdge features:", edge_feature_keys)


Total Edges: 1207337
Edge: (1, 2), Data: {}
Edge: (1, 4), Data: {}
Edge: (1, 5), Data: {}
Edge: (1, 15), Data: {}
Edge: (2, 11), Data: {}

Edge features: []


Tidak ada edge feature pada graph ini

### Split Dataset

In [9]:
# networkit copy of G, used for fast edge-existence checks while sampling negatives.
# NOTE(review): networkit uses its own contiguous node ids, which need not equal
# G's product-id labels; translate labels before querying nkG.
nkG = nk.nxadapter.nx2nk(G)

# Positive examples: every directed edge of G, plus a set of the same pairs
# (not used by the visible cells).
edges = [*G.edges()]
existing_edges = {*edges}

# Negative sampling against the networkit graph (fast edge lookups).
def sample_non_edges_nk(nkG, num_samples, nodes=None):
    """Sample ``num_samples`` distinct ordered node pairs that are not linked in ``nkG``.

    ``nkG`` is expected to come from ``nk.nxadapter.nx2nk(G)``. networkit addresses
    nodes by its own contiguous ids, while the returned pairs must use G's original
    labels (product ids). Pairs are therefore drawn as *indices* into ``nodes``,
    checked against ``nkG`` by index, and only then translated back to labels.
    Querying ``nkG.hasEdge`` with raw labels would test unrelated nodes, or ids
    beyond the end of the graph.

    Parameters
    ----------
    nkG : networkit.Graph
        Graph whose node ``i`` corresponds to ``nodes[i]``.
    num_samples : int
        Number of pairs to return.
    nodes : list, optional
        Node labels in networkit-id order. Defaults to ``list(G.nodes())``, the
        order ``nx2nk`` is assumed to use when assigning ids — TODO confirm for
        the installed networkit version.

    Returns
    -------
    list of tuple
        ``num_samples`` distinct ``(u, v)`` label pairs with ``u != v`` and no edge
        in either direction (the downstream ``G_train`` is undirected). Both
        ``(u, v)`` and ``(v, u)`` may appear.

    Raises
    ------
    ValueError
        If fewer than two nodes are available.

    Notes
    -----
    Loops forever if fewer than ``num_samples`` such pairs exist; fine for a
    sparse graph like this one.
    """
    if nodes is None:
        nodes = list(G.nodes())
    if num_samples <= 0:
        return []
    if len(nodes) < 2:
        raise ValueError("need at least two nodes to sample pairs")

    index_range = range(len(nodes))
    non_edges = set()
    while len(non_edges) < num_samples:
        i, j = random.sample(index_range, 2)
        # Reject both directions: a reversed edge is still a link in undirected G_train.
        if nkG.hasEdge(i, j) or nkG.hasEdge(j, i):
            continue
        non_edges.add((nodes[i], nodes[j]))

    return list(non_edges)

num_samples = len(edges)

# Seed Python's RNG so the negative sample (and its split below) is reproducible.
random.seed(42)
non_edges = sample_non_edges_nk(nkG, num_samples)
random.shuffle(non_edges)

train_edges, test_edges = train_test_split(edges, test_size=0.2, random_state=42)

# Split the negatives into *disjoint* train/test parts: two independent
# random.sample calls from the same pool would let a pair appear in both sets.
n_train_neg = len(train_edges)
train_non_edges = non_edges[:n_train_neg]
test_non_edges = non_edges[n_train_neg:n_train_neg + len(test_edges)]
assert len(test_non_edges) == len(test_edges), "not enough negatives for a disjoint split"

# Undirected training graph: every node, but only the training edges.
G_train = nx.Graph()
G_train.add_nodes_from(G.nodes())
G_train.add_edges_from(train_edges)

print(f"Train Edges: {len(train_edges)}, Test Edges: {len(test_edges)}")
print(f"Train Non-Edges: {len(train_non_edges)}, Test Non-Edges: {len(test_non_edges)}")
print(f"Train Non-Edges: {len(train_non_edges)}, Test Non-Edges: {len(test_non_edges)}")

Train Edges: 965869, Test Edges: 241468
Train Non-Edges: 965869, Test Non-Edges: 241468


## Heuristic Link Prediction

In [10]:
def heuristic_score(G, node_pairs, method, show_progress=False):
    """Score candidate node pairs with a classic neighbourhood-based heuristic.

    Parameters
    ----------
    G : networkx.Graph
        Graph the scores are computed on (e.g. ``G_train``). networkx's
        link-prediction functions are defined for undirected graphs.
    node_pairs : iterable of (u, v)
        Pairs to score; both nodes must be in ``G``.
    method : {"common_neighbors", "jaccard", "adamic_adar", "preferential_attachment"}
        Heuristic to use.
    show_progress : bool, default False
        Show a tqdm progress bar.

    Returns
    -------
    list
        One score per pair, in the order of ``node_pairs``.

    Raises
    ------
    ValueError
        If ``method`` is not a supported name. This is checked before any scoring,
        so it is raised even for an empty ``node_pairs``.
    """
    # These networkx functions accept the whole ebunch and yield (u, v, score) in
    # input order, so one call replaces building a separate generator per pair.
    batch_scorers = {
        "jaccard": nx.jaccard_coefficient,
        "adamic_adar": nx.adamic_adar_index,
        "preferential_attachment": nx.preferential_attachment,
    }
    if method != "common_neighbors" and method not in batch_scorers:
        raise ValueError("Method not recognized")

    # Materialise once: the ebunch may be iterated more than once (validation + scoring),
    # and tqdm needs a length.
    node_pairs = list(node_pairs)
    progress = {
        "desc": f"Computing {method} scores",
        "total": len(node_pairs),
        "disable": not show_progress,
    }

    if method == "common_neighbors":
        return [len(list(nx.common_neighbors(G, u, v))) for u, v in tqdm(node_pairs, **progress)]

    return [score for _, _, score in tqdm(batch_scorers[method](G, node_pairs), **progress)]


In [11]:
# Ranking-problem evaluation metrics: items are ranked by descending score and the
# binary (0/1) labels are read off in that order.
def _labels_by_score(y_true, y_scores):
    """Return ``y_true`` as a NumPy array reordered by descending ``y_scores``.

    Inputs go through ``np.asarray``, so lists and pandas Series are indexed by
    position. Fancy-indexing a Series with an integer index would otherwise be
    label-based. Ordering is ``np.argsort(y_scores)[::-1]``, including for ties.

    Raises
    ------
    ValueError
        If ``y_true`` and ``y_scores`` do not have the same shape.
    """
    y_true = np.asarray(y_true)
    y_scores = np.asarray(y_scores)
    if y_true.shape != y_scores.shape:
        raise ValueError("y_true and y_scores must have the same shape")
    return y_true[np.argsort(y_scores)[::-1]]


def _check_k(k):
    """Raise ValueError unless ``k`` is at least 1."""
    if k < 1:
        raise ValueError("k must be >= 1")


def precision_at_k(y_true, y_scores, k):
    """Fraction of positives among the ``k`` highest-scored items.

    If ``k`` exceeds the number of items, all items are used (the denominator is
    the item count, not ``k``). Returns 0.0 for empty input.
    """
    _check_k(k)
    ranked = _labels_by_score(y_true, y_scores)
    if ranked.size == 0:
        return 0.0
    return np.mean(ranked[:k])


def recall_at_k(y_true, y_scores, k):
    """Fraction of all positives that appear in the top ``k``; 0.0 if there are no positives."""
    _check_k(k)
    ranked = _labels_by_score(y_true, y_scores)
    total_positives = np.sum(ranked)
    if total_positives == 0:
        return 0.0
    return np.sum(ranked[:k]) / total_positives


def mean_average_precision(y_true, y_scores):
    """Average precision of a single ranking (mean of precision@i at each positive).

    Returns 0.0 if there are no positives.
    """
    ranked = _labels_by_score(y_true, y_scores)
    total_positives = np.sum(ranked)
    if total_positives == 0:
        return 0.0
    precision_at_i = np.cumsum(ranked) / np.arange(1, ranked.size + 1)
    return np.sum(precision_at_i * ranked) / total_positives


def f1_beta_at_k(y_true, y_scores, k, beta=1):
    """F-beta score combining precision@k and recall@k.

    Computes ``(1 + beta^2) * P * R / (beta^2 * P + R)``. Returns 0.0 when the
    denominator is zero.
    """
    precision_k = precision_at_k(y_true, y_scores, k)
    recall_k = recall_at_k(y_true, y_scores, k)

    beta_sq = beta ** 2
    denominator = (beta_sq * precision_k) + recall_k
    if denominator == 0:
        return 0.0
    return (1 + beta_sq) * (precision_k * recall_k) / denominator



## Basic ML Link Prediction

In [None]:
# NOTE(review): this whole cell is disabled (commented out). It calls
# `extract_features` and `feature_selection`, which are not defined in any visible
# cell, so uncommenting it as-is would raise NameError. Define those helpers
# (or delete this cell) before re-enabling it.
# When enabled, it builds labelled train/test pairs from the split cell, extracts
# pair features on G_train, fits Random Forest / Logistic Regression / Naive Bayes,
# and prints AUC-ROC, AP, Precision@k, Recall@k, MAP and F1@k on the test pairs.
# train_pairs = train_edges + train_non_edges
# train_labels = np.array([1] * len(train_edges) + [0] * len(train_non_edges))

# test_pairs = test_edges + test_non_edges
# test_labels = np.array([1] * len(test_edges) + [0] * len(test_non_edges))

# train_features = extract_features(G_train, train_pairs)
# test_features = extract_features(G_train, test_pairs)

# X_train = train_features.drop(columns=["node1", "node2"])
# X_test = test_features.drop(columns=["node1", "node2"])

# # X_train_selected, selected_features = feature_selection(X_train, train_labels, top_k=10)
# # selected_feature_names = X_train.columns[selected_features]
# # print("Fitur yang dipilih:", selected_feature_names.tolist())

# # X_test_selected = X_test.iloc[:, selected_features]

# models = {
#     "Random Forest": RandomForestClassifier(n_estimators=100, random_state=42),
#     "Logistic Regression": LogisticRegression(max_iter=1000, random_state=42),
#     "Naive Bayes": GaussianNB()
# }

# k = 100000

# print("{:<25} {:>10} {:>10} {:>15} {:>15} {:>10} {:>10}".format(
#     "Model", "AUC-ROC", "AP Score", f"Precision@{k}", f"Recall@{k}", "MAP", f"F1@{k}"
# ))
# print("=" * 105)

# for name, model in models.items():
#     model.fit(X_train, train_labels)

#     probabilities = model.predict_proba(X_test)[:, 1]

#     auc_roc = roc_auc_score(test_labels, probabilities)
#     ap_score = average_precision_score(test_labels, probabilities)

#     precision_at_k_ml = precision_at_k(test_labels, probabilities, k)
#     recall_at_k_ml = recall_at_k(test_labels, probabilities, k)
#     map_score = mean_average_precision(test_labels, probabilities)
#     f1_k_ml = f1_beta_at_k(test_labels, probabilities, k)

#     print("{:<25} {:>10.6f} {:>10.6f} {:>15.6f} {:>15.6f} {:>10.6f} {:>10.6f}".format(
#         name.upper(), auc_roc, ap_score, precision_at_k_ml, recall_at_k_ml, map_score, f1_k_ml
#     ))


## Graph Embedding Link Prediction

In [17]:
import locale
def getpreferredencoding(do_setlocale = True):
    return "UTF-8"
locale.getpreferredencoding = getpreferredencoding
!pip install pykeen

Collecting pykeen
  Downloading pykeen-1.11.0-py3-none-any.whl.metadata (85 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/85.5 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m85.5/85.5 kB[0m [31m3.2 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting dataclasses-json (from pykeen)
  Downloading dataclasses_json-0.6.7-py3-none-any.whl.metadata (25 kB)
Collecting click-default-group (from pykeen)
  Downloading click_default_group-1.2.4-py2.py3-none-any.whl.metadata (2.8 kB)
Collecting optuna>=2.0.0 (from pykeen)
  Downloading optuna-4.2.1-py3-none-any.whl.metadata (17 kB)
Collecting more-click (from pykeen)
  Downloading more_click-0.1.2-py3-none-any.whl.metadata (4.3 kB)
Collecting pystow>=0.4.3 (from pykeen)
  Downloading pystow-0.7.0-py3-none-any.whl.metadata (17 kB)
Collecting docdata (from pykeen)
  Downloading docdata-0.0.4-py3-none-any.whl.metadata (13 kB)
Collecting class-resolver>=0.5.1 (from pykeen)

In [29]:
import pandas as pd
import numpy as np
from pykeen.triples import TriplesFactory

# Build (head, relation, tail) triples from the *training* edges of the split cell,
# so the knowledge-graph model never sees the held-out test edges.
TRIPLE_FRACTION = 0.4  # fraction of training edges turned into triples

triples = np.array(train_edges)

# Single relation type: every co-purchase edge becomes (u, "bought_with", v).
relation_placeholder = np.full((triples.shape[0], 1), "bought_with", dtype=object)
triples = np.column_stack((triples[:, 0], relation_placeholder, triples[:, 1]))
triples = triples.astype(str)

# train_test_split shuffled train_edges, so this prefix is a random subset.
# NOTE(review): this reuses (overwrites) `num_samples` from the split cell.
num_samples = int(len(triples) * TRIPLE_FRACTION)

tf = TriplesFactory.from_labeled_triples(triples[:num_samples], create_inverse_triples=True)

# Seeded split so train/validation/test triples are identical on every run.
tf_train, tf_validation, tf_test = tf.split([0.7, 0.15, 0.15], random_state=42)

INFO:pykeen.triples.splitting:done splitting triples to groups of sizes [48724, 57952, 57953]


In [30]:
from pykeen.pipeline import pipeline

# Train TransE on the co-purchase triples, then run filtered rank-based evaluation
# on tf_test (validation triples are also supplied to the pipeline).
result = pipeline(
    training=tf_train,
    testing=tf_test,
    validation=tf_validation,
    model='TransE',
    epochs=30,
    model_kwargs={'embedding_dim': 200},
    optimizer='Adam',
    optimizer_kwargs={'lr': 0.01},
    loss='MarginRankingLoss',
    training_kwargs={'batch_size': 256},
    negative_sampler='basic',
    regularizer='LP',
    regularizer_kwargs={'weight': 0.01},
    evaluator_kwargs={
        'filtered': True,
        'batch_size': 64
    },
    # Fixed seed so initialisation, negative sampling and batching are reproducible;
    # otherwise a new seed is chosen on every run.
    random_seed=42,
)

# Evaluation metrics (per side / rank type) as a DataFrame.
result.metric_results.to_df()

INFO:pykeen.pipeline.api:Using device: None
INFO:pykeen.triples.triples_factory:Creating inverse triples.


Training epochs on cuda:0:   0%|          | 0/30 [00:00<?, ?epoch/s]

INFO:pykeen.triples.triples_factory:Creating inverse triples.


Training batches on cuda:0:   0%|          | 0/2113 [00:00<?, ?batch/s]

Training batches on cuda:0:   0%|          | 0/2113 [00:00<?, ?batch/s]

Training batches on cuda:0:   0%|          | 0/2113 [00:00<?, ?batch/s]

Training batches on cuda:0:   0%|          | 0/2113 [00:00<?, ?batch/s]

Training batches on cuda:0:   0%|          | 0/2113 [00:00<?, ?batch/s]

Training batches on cuda:0:   0%|          | 0/2113 [00:00<?, ?batch/s]

Training batches on cuda:0:   0%|          | 0/2113 [00:00<?, ?batch/s]

Training batches on cuda:0:   0%|          | 0/2113 [00:00<?, ?batch/s]

Training batches on cuda:0:   0%|          | 0/2113 [00:00<?, ?batch/s]

Training batches on cuda:0:   0%|          | 0/2113 [00:00<?, ?batch/s]

Training batches on cuda:0:   0%|          | 0/2113 [00:00<?, ?batch/s]

Training batches on cuda:0:   0%|          | 0/2113 [00:00<?, ?batch/s]

Training batches on cuda:0:   0%|          | 0/2113 [00:00<?, ?batch/s]

Training batches on cuda:0:   0%|          | 0/2113 [00:00<?, ?batch/s]

Training batches on cuda:0:   0%|          | 0/2113 [00:00<?, ?batch/s]

Training batches on cuda:0:   0%|          | 0/2113 [00:00<?, ?batch/s]

Training batches on cuda:0:   0%|          | 0/2113 [00:00<?, ?batch/s]

Training batches on cuda:0:   0%|          | 0/2113 [00:00<?, ?batch/s]

Training batches on cuda:0:   0%|          | 0/2113 [00:00<?, ?batch/s]

Training batches on cuda:0:   0%|          | 0/2113 [00:00<?, ?batch/s]

Training batches on cuda:0:   0%|          | 0/2113 [00:00<?, ?batch/s]

Training batches on cuda:0:   0%|          | 0/2113 [00:00<?, ?batch/s]

Training batches on cuda:0:   0%|          | 0/2113 [00:00<?, ?batch/s]

Training batches on cuda:0:   0%|          | 0/2113 [00:00<?, ?batch/s]

Training batches on cuda:0:   0%|          | 0/2113 [00:00<?, ?batch/s]

Training batches on cuda:0:   0%|          | 0/2113 [00:00<?, ?batch/s]

Training batches on cuda:0:   0%|          | 0/2113 [00:00<?, ?batch/s]

Training batches on cuda:0:   0%|          | 0/2113 [00:00<?, ?batch/s]

Training batches on cuda:0:   0%|          | 0/2113 [00:00<?, ?batch/s]

Training batches on cuda:0:   0%|          | 0/2113 [00:00<?, ?batch/s]

Evaluating on cuda:0:   0%|          | 0.00/58.0k [00:00<?, ?triple/s]

INFO:pykeen.evaluation.evaluator:Evaluation took 300.62s seconds


Unnamed: 0,Side,Rank_type,Metric,Value
0,head,optimistic,adjusted_arithmetic_mean_rank_index,0.015806
1,tail,optimistic,adjusted_arithmetic_mean_rank_index,0.022362
2,both,optimistic,adjusted_arithmetic_mean_rank_index,0.019084
3,head,realistic,adjusted_arithmetic_mean_rank_index,0.015805
4,tail,realistic,adjusted_arithmetic_mean_rank_index,0.022361
...,...,...,...,...
220,tail,realistic,adjusted_hits_at_k,-0.000007
221,both,realistic,adjusted_hits_at_k,0.000002
222,head,pessimistic,adjusted_hits_at_k,0.000010
223,tail,pessimistic,adjusted_hits_at_k,-0.000007


In [31]:
from pykeen.evaluation import RankBasedEvaluator

# Rank-based evaluation of the trained model on the test triples. True triples from
# the training and validation splits are passed as extra filter triples.
# NOTE(review): the pipeline cell already evaluated tf_test with filtering, and each
# evaluation took ~300 s here; consider reusing `result.metric_results` instead.
evaluator = RankBasedEvaluator()
results = evaluator.evaluate(
    model=result.model,
    mapped_triples=tf_test.mapped_triples,
    batch_size=64,
    additional_filter_triples=[tf_train.mapped_triples, tf_validation.mapped_triples],
)

# Display name -> pykeen metric key.
metrics_to_report = {
    "Hits@1": "hits@1",
    "Hits@3": "hits@3",
    "Hits@5": "hits@5",
    "Hits@10": "hits@10",
    "Mean Reciprocal Rank": "mean_reciprocal_rank",
}
for display_name, metric_key in metrics_to_report.items():
    print(f"{display_name}: {results.get_metric(metric_key)}")

Evaluating on cuda:0:   0%|          | 0.00/58.0k [00:00<?, ?triple/s]

INFO:pykeen.evaluation.evaluator:Evaluation took 299.18s seconds


<pykeen.evaluation.rank_based_evaluator.RankBasedMetricResults at 0x7e915d0aba10>

## Graph Neural Network (GNN) Link Prediction

In [None]:
# prompt: Using any libraries, try doing graph neural network.

!pip install torch-geometric

import torch
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
from torch_geometric.utils import negative_sampling
from sklearn.metrics import roc_auc_score

# GCN link predictor. Uses `G`, `train_edges`, `test_edges` and `test_non_edges`
# from the split cell.
from torch_geometric.utils import negative_sampling, to_undirected

torch.manual_seed(42)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# --- Node indexing -------------------------------------------------------------
# PyG treats edge indices as rows of `x`, i.e. contiguous ints 0..num_nodes-1, but G
# is labelled with raw product ids (1, 2, 4, ...). Every edge tensor is built through
# this one mapping; otherwise edges point at the wrong rows, or past the end of `x`.
node_list = list(G.nodes())
node_to_idx = {node: idx for idx, node in enumerate(node_list)}


def pairs_to_index(pairs):
    """Turn an iterable of (u, v) node labels into a 2 x E LongTensor of row indices."""
    flat = [(node_to_idx[u], node_to_idx[v]) for u, v in pairs]
    return torch.tensor(flat, dtype=torch.long).reshape(-1, 2).t().contiguous()


# --- Node features -------------------------------------------------------------
# The raw product id carries no signal (and its magnitude is ~1e5), so use each
# product's own metadata instead. The graph-derived attributes stored on G (in/out
# degree, centralities, community) appear to be computed on the full graph, test
# edges included, so they are left out to avoid leakage.
FEATURE_KEYS = ['salesrank', 'review_cnt', 'downloads', 'rating']
raw_features = np.array(
    [[float(G.nodes[node].get(key) or 0.0) for key in FEATURE_KEYS] for node in node_list],
    dtype=np.float64,
)
# NOTE(review): missing/None/NaN/negative values are treated as 0 — confirm this is sensible.
features = np.log1p(np.clip(np.nan_to_num(raw_features), 0.0, None))
features = (features - features.mean(axis=0)) / (features.std(axis=0) + 1e-8)
x = torch.tensor(features, dtype=torch.float)

# --- Message-passing graph -----------------------------------------------------
# Training edges only: passing all of G would let the model see the test edges.
# Made undirected to match G_train, on which the heuristic baselines are computed.
train_edge_index = pairs_to_index(train_edges)
edge_index = to_undirected(train_edge_index, num_nodes=len(node_list))
data = Data(x=x, edge_index=edge_index).to(device)


class GNN(torch.nn.Module):
    """Two-layer GCN encoder with a dot-product edge decoder."""

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def encode(self, x, edge_index):
        """Return one ``out_channels``-dim embedding per node."""
        x = self.conv1(x, edge_index)
        x = x.relu()
        return self.conv2(x, edge_index)

    def decode(self, z, edge_label_index):
        """Return one logit per column (u, v) of ``edge_label_index``: dot(z[u], z[v])."""
        return (z[edge_label_index[0]] * z[edge_label_index[1]]).sum(dim=-1)

    def decode_all(self, z):
        """Return every (u, v) with a positive logit as a 2 x K index tensor.

        Builds the dense num_nodes x num_nodes score matrix, so it is only feasible
        for small graphs (not for the full graph used here).
        """
        prob_adj = z @ z.t()
        return (prob_adj > 0).nonzero(as_tuple=False).t()


# --- Training ------------------------------------------------------------------
model = GNN(in_channels=data.num_features, hidden_channels=16, out_channels=16).to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)
criterion = torch.nn.BCEWithLogitsLoss()

# Positives for the loss are the (directed) training pairs. Negatives are resampled
# every epoch among pairs that are not linked in the message-passing graph.
pos_edge_index = train_edge_index.to(device)

for epoch in range(1, 101):
    model.train()
    optimizer.zero_grad()

    z = model.encode(data.x, data.edge_index)
    neg_edge_index = negative_sampling(
        edge_index=data.edge_index,
        num_nodes=data.num_nodes,
        num_neg_samples=pos_edge_index.size(1),
    )

    pos_out = model.decode(z, pos_edge_index)
    neg_out = model.decode(z, neg_edge_index)

    out = torch.cat([pos_out, neg_out], dim=0)
    pos_labels = torch.ones_like(pos_out)
    neg_labels = torch.zeros_like(neg_out)
    labels = torch.cat([pos_labels, neg_labels], dim=0)

    loss = criterion(out, labels)
    loss.backward()
    optimizer.step()

    if epoch % 10 == 0:
        print(f'Epoch: {epoch:03d}, Loss: {loss.item():.4f}')

# --- Evaluation on the held-out pairs ------------------------------------------
model.eval()
with torch.no_grad():
    z = model.encode(data.x, data.edge_index)

    test_edges_tensor = pairs_to_index(test_edges).to(device)
    test_non_edges_tensor = pairs_to_index(test_non_edges).to(device)

    test_pos_out = model.decode(z, test_edges_tensor)
    test_neg_out = model.decode(z, test_non_edges_tensor)

    out_test = torch.cat([test_pos_out, test_neg_out], dim=0)
    labels_test = torch.cat([torch.ones_like(test_pos_out), torch.zeros_like(test_neg_out)], dim=0)

    roc_auc = roc_auc_score(labels_test.cpu().numpy(), out_test.cpu().numpy())
    print(f'ROC AUC score: {roc_auc:.4f}')
