In [109]:
from dotenv import load_dotenv
import os
import numpy as np
import matplotlib.pyplot as plt
from neo4j import GraphDatabase

load_dotenv()

True

In [110]:
os.getenv("NEO4J_URI"), os.getenv("NEO4J_USER"), os.getenv("NEO4J_PASS")

('bolt://localhost:7687', 'neo4j', '123456789')

In [111]:
#OUTLINE:
#1. Connect to the database
#2. Get the data
#2. Check for the existing centroids (this can happen if partial clustering has been done)
#3. Initialize K-Means on all the data, setting the initial centroids to the existing centroids
#4. Run K-Means until convergence
#5. Save the centroids to the database
#6. For nodes that didn't have a relationship to a cluster before, add them. 
#7. For each centroid, compute its "title" property by taking the embeddings of the n-closest nodes and averaging them

In [112]:
driver = GraphDatabase.driver("neo4j://localhost:7687", auth=("neo4j", "password"))
driver.get_server_info()

<neo4j.api.ServerInfo at 0x28bb16750>

In [113]:
#import the papers to the database
import json
PAPERS_DATA_PATH = "extract/papers_data.json"

with open(PAPERS_DATA_PATH, 'r') as json_file:
        papers_data = json.load(json_file)
        #remove the papers with no abstract
        papers_data = [paper for paper in papers_data if len(paper['abstract']) > 0]

CIPHR = """
WITH $data as data
UNWIND data as paper
MERGE (p:Paper {title: paper.title})
SET p.abstract = paper.abstract
SET p.url = paper.url
SET p.embeddings = paper.abstract_embedding
"""

with driver.session() as session:
    result = session.run(CIPHR, data=papers_data)

In [114]:
#create a vector index on the embeddings property
#in neo4j you can do this with 
# CALL db.index.vector.createNodeIndex('paper-embeddings', 'Paper', 'embeddings', 1536, 'cosine')

In [115]:
from neo4j import GraphDatabase

# Assuming you have already created a Neo4j driver instance
# driver = GraphDatabase.driver(uri, auth=(user, password))

GET_PAPERS = """
MATCH (p:Paper)
WHERE p.embeddings IS NOT NULL
RETURN p.title as title, p.abstract as abstract, p.url as url, p.embeddings as embeddings
"""

with driver.session() as session:
    papers_result = session.run(GET_PAPERS)
    papers = [dict(record) for record in papers_result]
    print(f"Number of papers: {len(papers)}")

GET_CENTROIDS = """
MATCH (c:Centroid)
RETURN c.coordinates as coordinates, c.title as centroid_title
"""

with driver.session() as session:
    centroids_result = session.run(GET_CENTROIDS)
    centroids = [dict(record) for record in centroids_result]

# papers, centroids

Number of papers: 19


In [116]:
#perform k-means clustering
from sklearn.cluster import KMeans
K=3

#convert the embeddings to a numpy array
embeddings = np.array([paper['embeddings'] for paper in papers])

#convert the centroids to a numpy array
centroid_embeddings = np.array([centroid['coordinates'] for centroid in centroids])

# if there aren't enough centroids, initialize the rest randomly
if len(centroids) < K:
    # Initialize the missing centroids randomly
    if centroid_embeddings.shape[0] == 0:
        centroid_embeddings = "kmeans++"
    else:
        centroid_embeddings = np.vstack([centroid_embeddings, np.random.rand(K - len(centroids), 1536)])

#initialize the k-means algorithm
kmeans = KMeans(n_clusters=K, init=centroid_embeddings, n_init=1)

#fit the algorithm to the data
predictions = kmeans.fit_predict(embeddings)
predictions, kmeans.cluster_centers_


(array([2, 2, 2, 2, 1, 2, 1, 1, 1, 0, 2, 1, 2, 1, 0, 1, 2, 2, 2],
       dtype=int32),
 array([[ 0.00957119,  0.00659721,  0.01460397, ...,  0.00216391,
         -0.02703943, -0.01040091],
        [ 0.00425553,  0.00378272,  0.01124918, ..., -0.0180465 ,
         -0.01018701, -0.0188333 ],
        [-0.00924649,  0.0132581 ,  0.00046706, ..., -0.01001863,
         -0.01145077, -0.01652182]]))

In [117]:
#add the centroids to the database
ADD_CENTROIDS = """
UNWIND $centroids as centroid
MERGE (c:Centroid {id: centroid.id})
SET c.coordinates = centroid.coordinates
"""

centroids = [{'id': i, 'coordinates': centroid} for i, centroid in enumerate(kmeans.cluster_centers_)]

with driver.session() as session:
    session.run(ADD_CENTROIDS, centroids=centroids)

#add the cluster relationships to the database
ADD_CLUSTER_RELATIONSHIPS = """
UNWIND $data as d
MATCH (p:Paper {title: d.title})
MATCH (c:Centroid {id: d.cluster})
MERGE (p)-[:PART_OF_CLUSTER]->(c)
"""

data = [{'title': paper['title'], 'cluster': prediction} for paper, prediction in zip(papers, predictions)]

with driver.session() as session:
    session.run(ADD_CLUSTER_RELATIONSHIPS, data=data)


In [118]:
#compute the title of each centroid
AVERAGE_EMBEDDINGS = """
match(c:Centroid)
call db.index.vector.queryNodes('paper-embeddings', 3, c.coordinates) YIELD node, score
RETURN c.id as cluster_id, COLLECT({title: node.title, score: score}) AS nearest_nodes
"""

with driver.session() as session:
    result = session.run(AVERAGE_EMBEDDINGS, centroids=centroids)
    centroids_with_nearest_nodes = [dict(record) for record in result]

centroids_with_nearest_nodes

[{'cluster_id': 0,
  'nearest_nodes': [{'title': 'Microbially Induced Carbonate Precipitation Using Microorganisms Enriched from Calcareous Materials in Marine Environments and Their Metabolites',
    'score': 0.989421010017395},
   {'title': 'Characteristics of bio-CaCO 3 from microbial bio-mineralization with different bacteria species',
    'score': 0.9894202947616577},
   {'title': 'Biocement Fabrication and Design Application for a Sustainable Urban Area',
    'score': 0.9602454900741577}]},
 {'cluster_id': 1,
  'nearest_nodes': [{'title': 'Getting into the groove: Opportunities to enhance the ecological value of hard coastal infrastructure using ﬁne-scale surface textures',
    'score': 0.9751214981079102},
   {'title': 'Learning from nature to enhance Blue engineering of marine infrastructure',
    'score': 0.9691004753112793},
   {'title': 'Availability of microhabitats explains a widespread pattern and informs theory on ecological engineering of boulder reefs',
    'score': 0.

In [119]:
from langchain.document_loaders import PyPDFLoader
from langchain.output_parsers import StructuredOutputParser, ResponseSchema
from langchain.prompts import PromptTemplate, ChatPromptTemplate, HumanMessagePromptTemplate
from langchain.llms import OpenAI
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings

#langchain configuration
response_schemas = [
    ResponseSchema(name="labels", description="An array of objects of schema {cluster_id: int, title: str} representing the labels of each cluster. Only give one title per cluster. You are given data of a handful of papers, as well as the cluster they belong to. The score is a number betweeo 0-1 giving the confidence of the model that the paper belongs to the cluster. Make your labels general, remember you are only given a subset of the papers, there are many more and the label should represent all of them"),
]
output_parser = StructuredOutputParser.from_response_schemas(response_schemas)

format_instructions = output_parser.get_format_instructions()
prompt = ChatPromptTemplate(
    messages=[
        HumanMessagePromptTemplate.from_template(
    "We are clustering research papers based on their embeddings. You are given the titles of some research papers, as well as their distance to the cluster. Give each cluster a unique name that is representative of the data it is close to. .\n{format_instructions}\n{clusters}")]
    ,
    input_variables=["clusters"],
    partial_variables={"format_instructions": format_instructions}
)

chat_model = ChatOpenAI(temperature=0.5)

_input = prompt.format_prompt(clusters=centroids_with_nearest_nodes)
output = chat_model(_input.to_messages())
labels = output_parser.parse(output.content)
labels

{'labels': [{'cluster_id': 0,
   'title': 'Microbially Induced Carbonate Precipitation'},
  {'cluster_id': 1,
   'title': 'Ecological Enhancement of Coastal Infrastructure'},
  {'cluster_id': 2, 'title': 'Engineered Living Materials'}]}

In [120]:
#using the labels, update the database
UPDATE_CENTROID_TITLES = """
UNWIND $labels as label
MATCH (c:Centroid {id: label.cluster_id})
SET c.title = label.title
"""

with driver.session() as session:
    result = session.run(UPDATE_CENTROID_TITLES, labels=labels.get('labels'))
