# Node Classification Using Node Embeddings

*By: Aman Barot*

This notebook performs node classification on the PubMed citation network data set using two sets of node embeddings, DeepWalk and node2vec, each loaded from a saved checkpoint. A logistic regression classifier is trained on each set of embeddings. In summary, both DeepWalk and node2vec have an **accuracy around 70% on test data** (70.9% and 69.7%, respectively).
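
For readers unfamiliar with the two methods, the sketch below illustrates the usual difference between them: DeepWalk samples uniform random walks, while node2vec biases each step with a return parameter `p` and an in-out parameter `q`. This is a generic illustration and not the code that produced the checkpoints used in this notebook; the function name `biased_walk` and the toy graph are hypothetical.

```python
import random

def biased_walk(adj, start, length, p=1.0, q=1.0, rng=None):
    """Sample one node2vec-style walk; p = q = 1 reduces to a uniform (DeepWalk-style) walk."""
    rng = rng or random.Random()
    walk = [start]
    while len(walk) < length:
        cur = walk[-1]
        nbrs = sorted(adj[cur])
        if not nbrs:
            break  # dead end: stop the walk early
        if len(walk) == 1:
            # The first step has no previous node, so it is always uniform
            walk.append(rng.choice(nbrs))
            continue
        prev = walk[-2]
        weights = []
        for nxt in nbrs:
            if nxt == prev:
                weights.append(1.0 / p)   # step back to the previous node
            elif nxt in adj[prev]:
                weights.append(1.0)       # stay at distance 1 from the previous node
            else:
                weights.append(1.0 / q)   # move farther away from the previous node
        walk.append(rng.choices(nbrs, weights=weights, k=1)[0])
    return walk

# Hypothetical toy graph as an adjacency dict (undirected)
toy_graph = {0: {1, 2}, 1: {0, 2, 3}, 2: {0, 1}, 3: {1, 4}, 4: {3}}

uniform_walk = biased_walk(toy_graph, start=0, length=8, rng=random.Random(0))
explore_walk = biased_walk(toy_graph, start=0, length=8, p=4.0, q=0.25, rng=random.Random(0))
print(uniform_walk)
print(explore_walk)
```

In both methods the sampled walks are then typically fed to a skip-gram model, whose input embedding matrix provides one vector per node.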

In [12]:
from torch_geometric.datasets import Planetoid
import torch
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, classification_report


In [13]:
# Select the compute device.
# MPS is for Apple Silicon Macs, CUDA is for NVIDIA GPUs, and CPU is the fallback.
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

## Data Set Description

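The PubMed data set is a citation network. Each node is a biomedical research paper, and each edge is a citation between two papers. The graph has 19,717 nodes in total, and each node is labeled with one of three classes. The default Planetoid split used here has 60 training nodes, 500 validation nodes, and 1,000 test nodes.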

In [14]:
pubmed_data = Planetoid(root='data/', name='PubMed')

## Node Classification Using DeepWalk Embeddings

In [15]:
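# Load the DeepWalk checkpoint and keep only the model weights.
# map_location="cpu" lets the checkpoint load on machines without the GPU it was saved on.
train_output = torch.load(
    "checkpoints/run_20250712_225510/ckpt_epoch_0_batch_269568.pt",
    map_location="cpu",
)['model_state_dict']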

In [16]:
# Get the train and validation masks and the node labels from the Planetoid dataset
train_mask = pubmed_data[0].train_mask.numpy()
val_mask = pubmed_data[0].val_mask.numpy()
labels_all = pubmed_data[0].y.numpy()

# Prepare train and validation sets using node embeddings
X = train_output['inp_emb.weight'].detach().cpu().numpy()
y = labels_all

X_train = X[train_mask]
y_train = y[train_mask]
X_val = X[val_mask]
y_val = y[val_mask]

### Performance on Validation and Test Sets

In [17]:

# Train logistic regression on train set
logreg = LogisticRegression(max_iter=1000, solver='lbfgs')
logreg.fit(X_train, y_train)

# Predict on validation set
y_val_pred_logreg = logreg.predict(X_val)

# Evaluate
acc_logreg = accuracy_score(y_val, y_val_pred_logreg)
report_logreg = classification_report(y_val, y_val_pred_logreg)
print(f"Validation Accuracy (Logistic Regression): {acc_logreg:.4f}")
print("Classification Report (Logistic Regression):\n", report_logreg)

Validation Accuracy (Logistic Regression): 0.7300
Classification Report (Logistic Regression):
               precision    recall  f1-score   support

           0       0.64      0.70      0.67        98
           1       0.72      0.79      0.75       194
           2       0.79      0.69      0.74       208

    accuracy                           0.73       500
   macro avg       0.72      0.73      0.72       500
weighted avg       0.74      0.73      0.73       500



In [18]:
# Predict on test set
test_mask = pubmed_data[0].test_mask.numpy()
X_test = X[test_mask]
y_test = y[test_mask]
y_test_pred_logreg = logreg.predict(X_test)
# Evaluate
acc_logreg = accuracy_score(y_test, y_test_pred_logreg)
report_logreg = classification_report(y_test, y_test_pred_logreg)
print(f"Test Accuracy (Logistic Regression): {acc_logreg:.4f}")
print("Classification Report (Logistic Regression):\n", report_logreg)

Test Accuracy (Logistic Regression): 0.7090
Classification Report (Logistic Regression):
               precision    recall  f1-score   support

           0       0.60      0.68      0.64       180
           1       0.69      0.78      0.74       413
           2       0.80      0.65      0.71       407

    accuracy                           0.71      1000
   macro avg       0.70      0.70      0.70      1000
weighted avg       0.72      0.71      0.71      1000
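
The averaged rows of the report can be checked by hand. The sketch below recomputes the macro and support-weighted recall from the per-class recall and support values printed in the DeepWalk test report above. The inputs are the rounded figures from that report, so the results are approximate.

```python
# Per-class recall and support, copied from the DeepWalk test report
recall = {0: 0.68, 1: 0.78, 2: 0.65}
support = {0: 180, 1: 413, 2: 407}

total = sum(support.values())

# Macro average: every class counts equally
macro_recall = sum(recall.values()) / len(recall)

# Weighted average: every class counts in proportion to its support
weighted_recall = sum(recall[c] * support[c] for c in recall) / total

print(f"macro recall:    {macro_recall:.2f}")     # 0.70
print(f"weighted recall: {weighted_recall:.2f}")  # 0.71
```

The support-weighted recall is the same quantity as overall accuracy, which is why it agrees with the reported test accuracy of 0.7090 up to rounding. The macro average is slightly lower because class 2, which has the lowest recall, counts for a full third of it.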



## Node Classification Using node2vec Embeddings

In [23]:
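# Load the node2vec checkpoint and keep only the model weights.
# map_location="cpu" lets the checkpoint load on machines without the GPU it was saved on.
train_output = torch.load(
    "checkpoints/run_20250712_151314/ckpt_epoch_0_batch_269568.pt",
    map_location="cpu",
)['model_state_dict']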

In [24]:
# Get the train and validation masks and the node labels from the Planetoid dataset
train_mask = pubmed_data[0].train_mask.numpy()
val_mask = pubmed_data[0].val_mask.numpy()
labels_all = pubmed_data[0].y.numpy()

# Prepare train and validation sets using node embeddings
X = train_output['inp_emb.weight'].detach().cpu().numpy()
y = labels_all

X_train = X[train_mask]
y_train = y[train_mask]
X_val = X[val_mask]
y_val = y[val_mask]

### Performance on Validation and Test Sets

In [25]:

# Train logistic regression on train set
logreg = LogisticRegression(max_iter=1000, solver='lbfgs')
logreg.fit(X_train, y_train)

# Predict on validation set
y_val_pred_logreg = logreg.predict(X_val)

# Evaluate
acc_logreg = accuracy_score(y_val, y_val_pred_logreg)
report_logreg = classification_report(y_val, y_val_pred_logreg)
print(f"Validation Accuracy (Logistic Regression): {acc_logreg:.4f}")
print("Classification Report (Logistic Regression):\n", report_logreg)

Validation Accuracy (Logistic Regression): 0.7000
Classification Report (Logistic Regression):
               precision    recall  f1-score   support

           0       0.58      0.78      0.66        98
           1       0.72      0.73      0.72       194
           2       0.77      0.63      0.70       208

    accuracy                           0.70       500
   macro avg       0.69      0.71      0.69       500
weighted avg       0.71      0.70      0.70       500



In [26]:
# Predict on test set
test_mask = pubmed_data[0].test_mask.numpy()
X_test = X[test_mask]
y_test = y[test_mask]
y_test_pred_logreg = logreg.predict(X_test)
# Evaluate
acc_logreg = accuracy_score(y_test, y_test_pred_logreg)
report_logreg = classification_report(y_test, y_test_pred_logreg)
print(f"Test Accuracy (Logistic Regression): {acc_logreg:.4f}")
print("Classification Report (Logistic Regression):\n", report_logreg)

Test Accuracy (Logistic Regression): 0.6970
Classification Report (Logistic Regression):
               precision    recall  f1-score   support

           0       0.61      0.71      0.66       180
           1       0.70      0.71      0.71       413
           2       0.74      0.68      0.71       407

    accuracy                           0.70      1000
   macro avg       0.68      0.70      0.69      1000
weighted avg       0.70      0.70      0.70      1000
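
To put the two test accuracies side by side, the sketch below estimates the sampling uncertainty of each one on the 1,000 test nodes. It assumes a normal approximation to the binomial and considers only test-set sampling noise, not variation between training runs. The helper name `std_error` is hypothetical.

```python
import math

n_test = 1000  # number of test nodes, from the reports above
acc = {"DeepWalk": 0.7090, "node2vec": 0.6970}  # test accuracies from the reports above

def std_error(p, n):
    """Standard error of a proportion p estimated from n samples."""
    return math.sqrt(p * (1.0 - p) / n)

for name, p in acc.items():
    half_width = 1.96 * std_error(p, n_test)  # approximate 95% interval
    print(f"{name}: {p:.3f} +/- {half_width:.3f}")

gap = acc["DeepWalk"] - acc["node2vec"]
print(f"gap between methods: {gap:.3f}")
```

Under these assumptions each accuracy carries an uncertainty of roughly ±0.03, which is larger than the 0.012 gap between the two methods. This single comparison therefore does not show that either embedding is better than the other.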

