Neo4j (Cypher) code for creating node2vec embeddings of authors

In [None]:
// Drop any previous projection so this cell can be re-run:
// gds.graph.project fails if 'authorGraph' already exists in the catalog.
// failIfMissing = false makes this a no-op on the first run.
CALL gds.graph.drop('authorGraph', false);

// Project Author nodes and co-author edges into an in-memory graph.
// Co-authorship is symmetric, so the relationship is projected UNDIRECTED;
// with the default (NATURAL) orientation, random walks could only follow the
// stored edge direction and authors without outgoing edges would be dead ends.
CALL gds.graph.project(
  'authorGraph',
  'Author',
  {
    COAUTHORS: {  // replace with your actual relationship type if different
      orientation: 'UNDIRECTED'
    }
  }
);

// Train node2vec on the projection and write a 64-dimensional embedding to
// each Author's `node2vec` property (read back by the Python cell).
// randomSeed makes the random walks repeatable.
// NOTE(review): fully deterministic embeddings may also require concurrency: 1
// depending on the GDS version — confirm in the Node2Vec docs.
CALL gds.node2vec.write(
  'authorGraph',
  {
    embeddingDimension: 64,
    randomSeed: 42,
    writeProperty: 'node2vec'
  }
)
YIELD nodePropertiesWritten, writeMillis;

// The embeddings are now persisted on the nodes; release the in-memory projection.
CALL gds.graph.drop('authorGraph', false);

Python code for cluster-based classification of the author embeddings

In [None]:
import pandas as pd
import numpy as np
from collections import Counter
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.multiclass import OneVsRestClassifier
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc
from sklearn.preprocessing import label_binarize
from common.neo4j_utils import run_query, close_driver, Neo4jOperation


# Fetch every author that has a `node2vec` embedding property.
# ORDER BY gives a stable row order (Cypher does not guarantee result order
# without it), so KMeans and the seeded train/test split below see the same
# rows in the same order on every run — assuming authorId is unique.
unsup_embedding_query = Neo4jOperation(query="""
MATCH (a:Author) WHERE a.node2vec IS NOT NULL
RETURN a.authorId AS authorId, a.node2vec AS embedding
ORDER BY authorId
""")
# NOTE(review): presumably .run() returns a pandas DataFrame with columns
# `authorId` and `embedding` (the code below indexes it that way) — confirm
# against common.neo4j_utils. `close_driver` is imported but never called in
# this cell; consider closing the driver once all queries are done.
embedding_df = unsup_embedding_query.run()

# Cluster the node2vec embeddings; each author's cluster id is stored in the
# `cluster` column and later used as the classification target.
X = np.array(embedding_df['embedding'].to_list())

N_CLUSTERS = 10

# Validate before fitting: KMeans cannot fit an empty matrix or produce more
# clusters than there are samples. Raise instead of calling exit(), which in a
# notebook stops the kernel (or raises SystemExit) rather than just this cell.
if X.shape[0] == 0:
    raise ValueError("X is empty. Aborting KMeans.")
if X.shape[0] < N_CLUSTERS:
    raise ValueError(
        f"Need at least {N_CLUSTERS} embeddings for KMeans, got {X.shape[0]}."
    )

kmeans = KMeans(n_clusters=N_CLUSTERS, random_state=0)
kmeans.fit(X)
embedding_df['cluster'] = kmeans.labels_

# Cluster-based classification: train a random forest to recover each author's
# KMeans cluster from its embedding.
# NOTE: KMeans was fit on every embedding, including rows that end up in the
# test split, so this measures how separable the clusters are in embedding
# space rather than generalization to unseen authors.
X = embedding_df['embedding'].to_list()
Y = embedding_df['cluster'].to_list()

# Stratify so each cluster keeps its share in both splits (the report below is
# per cluster). Stratification needs at least two members per class, so fall
# back to a plain random split when some cluster is a singleton.
stratify_labels = Y if min(Counter(Y).values(), default=0) >= 2 else None
x_train, x_test, y_train, y_test = train_test_split(
    X, Y, test_size=0.4, random_state=42, stratify=stratify_labels
)

# Fixed seed so the forest, and therefore the report, is reproducible.
clf = RandomForestClassifier(random_state=42)
clf.fit(x_train, y_train)
y_pred = clf.predict(x_test)

print("Cluster-Based Classification Report:")
print(classification_report(y_test, y_pred))



# ROC curve plotting: one-vs-rest ROC per cluster.
# predict_proba columns are ordered by clf.classes_, so binarize y_test
# against that same class order instead of assuming clusters 0..9 all
# appeared in the training split.
class_labels = clf.classes_
y_test_bin = label_binarize(y_test, classes=class_labels)
if len(class_labels) == 2:
    # label_binarize returns a single column for two classes; expand it so
    # column i always corresponds to class_labels[i].
    y_test_bin = np.hstack([1 - y_test_bin, y_test_bin])
y_pred_prob = clf.predict_proba(x_test)

fig, ax = plt.subplots(figsize=(10, 7))

for i, label in enumerate(class_labels):
    positives = y_test_bin[:, i]
    if positives.min() == positives.max():
        # ROC is undefined when the test split has no positives (or no
        # negatives) for this class; skip rather than plot NaNs.
        continue
    fpr, tpr, _ = roc_curve(positives, y_pred_prob[:, i])
    roc_auc = auc(fpr, tpr)
    ax.plot(fpr, tpr, label=f'Class {label} (AUC = {roc_auc:.2f})')

ax.plot([0, 1], [0, 1], 'k--', label='Random Guess (AUC = 0.5)')

ax.set_xlabel('False Positive Rate')
ax.set_ylabel('True Positive Rate')
ax.set_title('ROC Curve for Multi-Class Classification')
ax.legend(loc='lower right')
plt.show()