### Setup

In [1]:
%%capture
# Load the Kedro IPython extension (%%capture hides its start-up output).
# NOTE(review): the `catalog` object used in the "Load files" cell is presumably
# injected by this extension — confirm.
%load_ext kedro.ipython

In [2]:
%%capture
import logging
import os
import pickle

import pandas as pd
from dotenv import load_dotenv
from neo4j import GraphDatabase
from neo4j.exceptions import DriverError, Neo4jError
from sklearn.metrics import completeness_score, homogeneity_score, v_measure_score

# Load Neo4j credentials from the project's local .env file.
# os.path.join keeps the path portable: the previous backslash-separated raw
# string only resolved on Windows (elsewhere it named a single odd file and
# load_dotenv silently loaded nothing).
load_dotenv(os.path.join("..", "conf", "local", ".env"))

In [3]:
# Parameters
# Embedding model whose pre-computed vectors are loaded below; also used in
# the input/output file names.
MODEL_NAME: str = "nomic-embed-text-v1.5"
# Matched with `str.contains` against the "Owner" column of the ground truth.
CONTRIBUTOR: str = "Health Promotion Board"

# Content category: also used as the Neo4j database name and as the output
# sub-folder name. Input 'all' if running across all categories.
CONTENT_CATEGORY: str = "live-healthy-articles"

In [4]:
# Connect to neo4j. The URI defaults to a local instance but can be overridden
# with the `neo4j_uri` environment variable (e.g. in the same .env file).
uri = os.getenv("neo4j_uri", "neo4j://localhost:7687")
username = os.getenv("neo4j_username")
password = os.getenv("neo4j_password")

if username is None or password is None:
    print("Warning: neo4j_username / neo4j_password not set; check the .env file.")

NEO4J = {
    "uri": uri,
    "auth": (username, password),
    "database": CONTENT_CATEGORY,  # create this database in neo4j first
}

# Test connection. Errors are printed, not raised: later cells that open
# sessions will still fail if connectivity is broken.
with GraphDatabase.driver(**NEO4J) as driver:
    try:
        driver.verify_connectivity()
        print("Connection established.")
    except (DriverError, Neo4jError) as exception:
        print(exception)

Connection estabilished.


In [5]:
# Input / output locations, relative to the current working directory
# (normally the notebook's folder).
INPUT_GROUNDTRUTH_PATH = os.path.join(
    "..",
    "data",
    "01_raw",
    "Synapxe Content Prioritisation - Live Healthy_020724.xlsx",
)

# Per-category folder holding this workflow's model outputs.
DATA_FOLDER_PATH = os.path.join(
    "..",
    "data",
    "07_model_output",
    CONTENT_CATEGORY,
)

INPUT_EMBEDDING_NEO4J_PATH = os.path.join(
    DATA_FOLDER_PATH,
    f"{CONTENT_CATEGORY}_{MODEL_NAME}_embeddings_neo4j.pkl",
)

OUTPUT_CLUSTER_METRICS_PATH = os.path.join(
    DATA_FOLDER_PATH,
    f"{CONTENT_CATEGORY}_compiled_model_variation_metrics.csv",
)

NEO4J_FOLDER_PATH = os.path.join(
    DATA_FOLDER_PATH,
    "neo4j",
)

# exist_ok=True is idempotent and avoids the check-then-create race of a
# separate os.path.exists() test.
os.makedirs(NEO4J_FOLDER_PATH, exist_ok=True)

OUTPUT_PREDICTED_CLUSTER_PATH = os.path.join(
    NEO4J_FOLDER_PATH,
    f"{MODEL_NAME}_predicted_clusters.csv",
)

OUTPUT_CLUSTERED_NODES_PATH = os.path.join(
    NEO4J_FOLDER_PATH,
    f"{MODEL_NAME}_neo4j_clustered_data.csv",
)

OUTPUT_UNCLUSTERED_NODES_PATH = os.path.join(
    NEO4J_FOLDER_PATH,
    f"{MODEL_NAME}_neo4j_unclustered_data.csv",
)

# NOTE: the two ".pkl" outputs below keep their "MONGODB" constant names
# (presumably for a downstream MongoDB loading step); the files are written
# inside the neo4j output folder.
OUTPUT_MONGODB_PREDICTED_CLUSTER_PATH = os.path.join(
    NEO4J_FOLDER_PATH,
    f"{MODEL_NAME}_neo4j_predicted_clusters.pkl",
)

OUTPUT_MONGODB_EDGES_PATH = os.path.join(
    NEO4J_FOLDER_PATH,
    f"{MODEL_NAME}_neo4j_edges.pkl",
)

## Load files

In [6]:
# Load merged_df data from the Kedro data catalog
merged_data_df = catalog.load("merged_data")  # noqa

# Load ground truth labels (sheet index 2 of the workbook).
ground_truth = pd.read_excel(INPUT_GROUNDTRUTH_PATH, sheet_name=2)

# Keep only rows owned by CONTRIBUTOR.
# na=False: rows with an empty "Owner" are dropped instead of putting NaN into
# the boolean mask (pandas refuses to index with a mask containing NaN).
# regex=False: match CONTRIBUTOR literally, even if it has regex metacharacters.
owner_matches = ground_truth["Owner"].str.contains(CONTRIBUTOR, na=False, regex=False)
ground_truth = ground_truth.loc[owner_matches, ["Page Title", "Combine Group ID", "URL"]]
ground_truth = ground_truth[ground_truth["Combine Group ID"].notna()]

# Attach the article id from merged_data_df by matching URLs, then keep and
# rename the columns used downstream (no in-place mutation).
ground_truth = (
    pd.merge(
        ground_truth, merged_data_df, how="inner", left_on="URL", right_on="full_url"
    )[["id", "Page Title", "URL", "Combine Group ID"]]
    .rename(columns={"Combine Group ID": "ground_truth_cluster"})
)

print(ground_truth.shape)
ground_truth.head(2)

(184, 4)


Unnamed: 0,id,Page Title,URL,ground_truth_cluster
0,1442828,Getting ready for solids,https://www.healthhub.sg/live-healthy/baby-get...,1.0
1,1445136,Getting Your Baby Started on Solids,https://www.healthhub.sg/live-healthy/getting-...,1.0


In [7]:
# Load the pre-computed embeddings for this model/category.
# pickle.load can execute arbitrary code, so only open trusted, locally
# produced files here.
with open(INPUT_EMBEDDING_NEO4J_PATH, "rb") as f:
    articles = pickle.load(f)

# Keep only the articles that also appear in the ground truth (inner join on id).
articles_df = pd.merge(articles, ground_truth, how="inner", on="id")

# For every model except tfidf, turn each vector cell into a plain Python list
# (presumably so the Neo4j driver can store it as a list property).
if MODEL_NAME != "tfidf":
    vector_columns = [col for col in articles_df.columns if "vector" in col]
    for col in vector_columns:
        articles_df[col] = articles_df[col].map(lambda x: x.tolist())

print(articles_df.shape)
articles_df.head(2)

(176, 12)


Unnamed: 0,id,title,full_url,content,meta_description,vector_title,vector_article_category_names,vector_category_description,vector_extracted_content_body,Page Title,URL,ground_truth_cluster
0,1443987,All You Need to Know About Childhood Immunisat...,https://www.healthhub.sg/live-healthy/all-you-...,Every child in Singapore is vaccinated accordi...,"To prevent diseases such as measles and mumps,...","[0.19928120076656342, 0.927337646484375, -3.30...","[0.911236584186554, 0.2581771910190582, -3.850...","[0.12509267032146454, 0.420778751373291, -3.49...","[0.2495705485343933, 0.5062752366065979, -2.47...",All You Need to Know About Childhood Immunisat...,https://www.healthhub.sg/live-healthy/all-you-...,16.0
1,1442828,Getting ready for solids,https://www.healthhub.sg/live-healthy/baby-get...,Weaning Tips\nThe process of switching an infa...,You have breastfed your baby for 6 months and ...,"[0.3498051166534424, 0.4663933515548706, -3.85...","[0.0, 0.0, -1.8538571014675015e-13, 1.44700809...","[1.1512057781219482, 0.9110057353973389, -3.02...","[1.1622024774551392, 1.2845065593719482, -2.55...",Getting ready for solids,https://www.healthhub.sg/live-healthy/baby-get...,1.0


## Clustering

In [8]:
# One dict per article row; each becomes an :Article node in the clustering cell.
documents = articles_df.to_dict(orient="records")
# Fail early with a clear message instead of an IndexError on documents[0]
# when the embeddings / ground-truth join produced no rows.
assert documents, "No articles left after merging embeddings with ground truth; check the 'id' columns."
documents[0].keys()

[1;35mdict_keys[0m[1m([0m[1m[[0m[32m'id'[0m, [32m'title'[0m, [32m'full_url'[0m, [32m'content'[0m, [32m'meta_description'[0m, [32m'vector_title'[0m, [32m'vector_article_category_names'[0m, [32m'vector_category_description'[0m, [32m'vector_extracted_content_body'[0m, [32m'Page Title'[0m, [32m'URL'[0m, [32m'ground_truth_cluster'[0m[1m][0m[1m)[0m

In [9]:
# Show INFO-level messages emitted by the helper functions below.
# NOTE: logging.basicConfig does nothing if the root logger already has
# handlers (the Kedro IPython extension may install some — TODO confirm);
# pass force=True if these messages do not appear.
logging.basicConfig(level=logging.INFO)

def clear_db(tx):
    """Delete every node in the database together with its relationships.

    Args:
        tx: Neo4j transaction, as passed in by ``session.execute_write``.
    """
    wipe_everything = "MATCH (n) DETACH DELETE n"
    logging.info("Clearing database")
    tx.run(wipe_everything)

def create_graph_nodes(tx, doc):
    """Create one ``:Article`` node from a single article record.

    Args:
        tx: Neo4j transaction, as passed in by ``session.execute_write``.
        doc: mapping for one article. Keys read: ``id``, ``title``,
            ``full_url``, ``content``, ``meta_description``, ``vector_title``,
            ``vector_article_category_names``, ``vector_category_description``,
            ``vector_extracted_content_body`` and ``ground_truth_cluster``.
            A missing key raises ``KeyError`` before the query is sent.
    """
    create_article = """
    CREATE (d:Article {
        id: $id,
        title: $title,
        url: $url,
        content: $content,
        meta_desc: $meta_description,
        vector_body: $vector_body,
        vector_title: $vector_title,
        vector_category: $vector_category,
        vector_desc: $vector_desc,
        ground_truth: $ground_truth
    })"""
    # Cypher parameter name -> key of ``doc`` that supplies its value.
    source_keys = {
        "id": "id",
        "title": "title",
        "url": "full_url",
        "content": "content",
        "meta_description": "meta_description",
        "vector_title": "vector_title",
        "vector_category": "vector_article_category_names",
        "vector_desc": "vector_category_description",
        "vector_body": "vector_extracted_content_body",
        "ground_truth": "ground_truth_cluster",
    }
    params = {name: doc[key] for name, key in source_keys.items()}
    tx.run(create_article, **params)

def calculate_similarity(tx):
    """Compute cosine similarity between the body vectors of every article pair.

    Each unordered pair is returned once (``a.id < b.id``). The score is the
    GDS function ``gds.similarity.cosine`` applied to the ``vector_body``
    node property. The query only reads the graph; no edges are created.

    Args:
        tx: Neo4j transaction (read access is sufficient).

    Returns:
        list of result records with the fields ``node_1_id``, ``node_2_id``,
        ``node_1_title``, ``node_2_title``, ``node_1_ground_truth``,
        ``node_2_ground_truth`` and ``edge_weight``, in that order.
    """
    # The previous message ("Create edges") was misleading: nothing is created here.
    logging.info("Calculating pairwise article similarity")
    query = """
        MATCH (a:Article), (b:Article)
        WHERE a.id < b.id
        WITH a, b, gds.similarity.cosine(a.vector_body, b.vector_body) AS similarity
        RETURN a.id AS node_1_id,
            b.id AS node_2_id,
            a.title AS node_1_title,
            b.title AS node_2_title,
            a.ground_truth AS node_1_ground_truth,
            b.ground_truth AS node_2_ground_truth,
            similarity AS edge_weight
        """
    # Materialise the records inside the transaction function; the result
    # cursor cannot be consumed after the transaction has closed.
    return list(tx.run(query))

def median_threshold(sim_result):
    """Return the median similarity of article pairs sharing a ground-truth cluster.

    Args:
        sim_result: sequence of 7-field rows as returned by
            ``calculate_similarity`` (node ids, titles, ground truths, weight).

    Returns:
        The median ``edge_weight`` over pairs whose two ground-truth labels
        are equal. Used as the cut-off for creating SIMILAR edges.

    Raises:
        ValueError: if no pair shares a ground-truth cluster (or all such
            weights are missing). A NaN threshold would otherwise silently
            produce a graph with no edges at all.
    """
    columns = [
        "node_1_id",
        "node_2_id",
        "node_1_title",
        "node_2_title",
        "node_1_ground_truth",
        "node_2_ground_truth",
        "edge_weight",
    ]
    pairs = pd.DataFrame(sim_result, columns=columns)
    same_cluster = pairs["node_1_ground_truth"] == pairs["node_2_ground_truth"]
    threshold = pairs.loc[same_cluster, "edge_weight"].median()
    if pd.isna(threshold):
        raise ValueError(
            "Cannot derive a similarity threshold: no article pair shares a "
            "ground-truth cluster."
        )
    return threshold

def create_sim_edges(tx, threshold):
    """Link article pairs whose body-vector cosine similarity exceeds ``threshold``.

    Each qualifying pair (``a.id < b.id``) gets one directed ``SIMILAR``
    relationship from ``a`` to ``b``, with the score stored in its
    ``similarity`` property. The comparison is strict (``>``).

    Args:
        tx: Neo4j transaction, as passed in by ``session.execute_write``.
        threshold: minimum (exclusive) similarity for an edge.
    """
    link_similar = """
    MATCH (a:Article), (b:Article)
    WHERE a.id < b.id
    WITH a, b, gds.similarity.cosine(a.vector_body, b.vector_body) AS similarity
    WHERE similarity > $threshold
    CREATE (a)-[:SIMILAR {similarity: similarity}]->(b)
    """
    logging.info("Create edges")
    tx.run(link_similar, threshold=threshold)

def drop_graph_projection(tx):
    """Drop the in-memory GDS graph ``articleGraph`` if it is currently projected.

    The existence check is made first, so calling this when no projection
    exists does nothing.

    Args:
        tx: Neo4j transaction, as passed in by ``session.execute_write``.
    """
    exists_query = """
    CALL gds.graph.exists('articleGraph')
    YIELD exists
    RETURN exists
    """
    is_projected = tx.run(exists_query).single()["exists"]
    if not is_projected:
        return
    tx.run("CALL gds.graph.drop('articleGraph')")

def create_graph_proj(tx):
    """Project ``:Article`` nodes and ``SIMILAR`` edges into the GDS graph ``articleGraph``.

    The relationship property ``similarity`` is included in the projection.

    Args:
        tx: Neo4j transaction, as passed in by ``session.execute_write``.
    """
    projection_query = """
           CALL gds.graph.project(
            'articleGraph',
            'Article',
            {
                SIMILAR: {
                    properties: 'similarity'
                }
            }
           )
    """
    tx.run(projection_query)

def detect_community(tx, weight_property=None):
    """Run Louvain on ``articleGraph`` and write each node's community id back.

    The community id is stored in the ``community`` node property.

    Args:
        tx: Neo4j transaction, as passed in by ``session.execute_write``.
        weight_property: optional relationship property to use as the edge
            weight, e.g. ``"similarity"`` (the property ``create_graph_proj``
            projects). ``None`` (the default) keeps the previous unweighted run.
            NOTE(review): the projection loads ``similarity`` but the default
            call ignores it — confirm whether weighted Louvain was intended.
    """
    config = {"writeProperty": "community"}
    if weight_property is not None:
        config["relationshipWeightProperty"] = weight_property
    # The configuration map is sent as a query parameter rather than inlined.
    tx.run("CALL gds.louvain.write('articleGraph', $config)", config=config)

def return_pred_cluster(tx):
    """Return every article together with its predicted (Louvain) community.

    Args:
        tx: Neo4j transaction (read access is sufficient).

    Returns:
        pandas.DataFrame with columns ``id``, ``title``, ``url`` and
        ``cluster``, ordered by community id. Empty, with no columns, if
        there are no articles.
    """
    records = tx.run(
        """
        MATCH (a:Article)
        RETURN a.id AS id,
            a.title AS title, 
            a.url AS url, 
            a.community AS cluster
        ORDER BY a.community
        """
    ).data()
    return pd.DataFrame(records)

def get_clustered_nodes(tx):
    """Return one row per relationship, with both endpoints' labels and communities.

    Args:
        tx: Neo4j transaction (read access is sufficient).

    Returns:
        pandas.DataFrame with columns ``node_1_id``, ``node_2_id``,
        ``node_1_title``, ``node_2_title``, ``edge_weight`` (the
        relationship's ``similarity``), ``node_1_ground_truth``,
        ``node_2_ground_truth``, ``node_1_pred_cluster`` and
        ``node_2_pred_cluster``. Empty, with no columns, if there are no
        relationships.
    """
    edge_rows = tx.run(
        """
        MATCH (n)-[r]->(m)
        RETURN n.id AS node_1_id,
            m.id AS node_2_id,
            n.title AS node_1_title, 
            m.title AS node_2_title,
            r.similarity AS edge_weight,
            n.ground_truth AS node_1_ground_truth,
            m.ground_truth AS node_2_ground_truth,
            n.community AS node_1_pred_cluster,
            m.community AS node_2_pred_cluster
        """
    ).data()
    return pd.DataFrame(edge_rows)

def get_unclustered_nodes(tx):
    """Return the nodes that have no relationship in either direction.

    NOTE(review): newer Cypher documents the ``NOT EXISTS { (n)--() }``
    subquery form; consider switching if the server rejects this syntax.

    Args:
        tx: Neo4j transaction (read access is sufficient).

    Returns:
        pandas.DataFrame with columns ``node_title``, ``node_ground_truth``,
        ``node_community`` and ``node_meta_desc``. Empty, with no columns,
        if every node has at least one relationship.
    """
    isolated_nodes_query = """
        MATCH (n)
        WHERE NOT EXISTS ((n)--())
        RETURN n.title AS node_title,
            n.ground_truth AS node_ground_truth,
            n.community AS node_community,
            n.meta_desc AS node_meta_desc
        """
    return pd.DataFrame(tx.run(isolated_nodes_query).data())

def count_articles(tx):
    """Count the articles assigned to each predicted community.

    Args:
        tx: Neo4j transaction (read access is sufficient).

    Returns:
        pandas.DataFrame with columns ``cluster`` and ``article_count``,
        ordered by cluster id.
    """
    per_cluster_query = """
        MATCH (a:Article)
        RETURN a.community AS cluster, count(a) AS article_count
        ORDER BY cluster
        """
    return pd.DataFrame(tx.run(per_cluster_query).data())

def return_by_cluster(tx):
    """Return only clusters with more than one article.

    Nodes are grouped by their ``community`` property. Singleton groups are
    dropped.

    Args:
        tx: Neo4j transaction (read access is sufficient).

    Returns:
        pandas.DataFrame with columns ``cluster`` and ``titles`` (the list of
        member titles), ordered by cluster id.
    """
    multi_article_clusters = """
    MATCH (n)
    WITH n.community AS cluster, collect(n.title) AS titles, count(n) AS count
    WHERE count > 1
    RETURN cluster, titles
    ORDER BY cluster
        """
    rows = tx.run(multi_article_clusters).data()
    return pd.DataFrame(rows)

In [10]:
with GraphDatabase.driver(**NEO4J) as driver:
    with driver.session() as session:
        # Rebuild the graph from scratch so re-running this cell is idempotent.
        session.execute_write(clear_db)
        for doc in documents:
            session.execute_write(create_graph_nodes, doc)
        # The similarity query only reads the graph, so use a read transaction.
        sim_result = session.execute_read(calculate_similarity)
        threshold = median_threshold(sim_result)
        session.execute_write(create_sim_edges, threshold)
        # Re-project the in-memory GDS graph so it reflects the new edges.
        session.execute_write(drop_graph_projection)
        session.execute_write(create_graph_proj)
        session.execute_write(detect_community)
        pred_cluster = session.execute_read(return_pred_cluster)
        clustered_nodes = session.execute_read(get_clustered_nodes)
        unclustered_nodes = session.execute_read(get_unclustered_nodes)
        cluster_article_count = session.execute_read(count_articles)
        cluster_articles = session.execute_read(return_by_cluster)

# Singleton communities are reported as "not clustered". min_count is NaN when
# no community has more than one article.
multi_article = cluster_article_count["article_count"] > 1
min_count = cluster_article_count.loc[multi_article, "article_count"].min()
max_count = cluster_article_count["article_count"].max()
num_clusters = (cluster_article_count["article_count"] != 1).sum()
unclustered_count = (cluster_article_count["article_count"] == 1).sum()

cluster_articles_dict = cluster_articles.to_dict(orient="records")

# Keep only edges whose endpoints landed in the same predicted community.
# With no SIMILAR edges at all, get_clustered_nodes returns a frame without
# columns, so guard against the KeyError the column lookups would raise.
edge_columns = ["node_1_title", "node_2_title", "edge_weight"]
if clustered_nodes.empty:
    edges_in_same_cluster = clustered_nodes
    edges = pd.DataFrame(columns=edge_columns)
else:
    edges_in_same_cluster = clustered_nodes[
        clustered_nodes["node_1_pred_cluster"] == clustered_nodes["node_2_pred_cluster"]
    ]
    edges = edges_in_same_cluster[edge_columns]
edges_dict = edges.to_dict(orient="records")

In [11]:
# One-row summary of this clustering run.
data = pd.DataFrame(
    [
        {
            "Model": MODEL_NAME,
            "Threshold": threshold,
            "Number of clusters": num_clusters,
            "Min cluster size": min_count,
            "Max cluster size": max_count,
            "Number of articles not clustered": unclustered_count,
        }
    ]
)
data

Unnamed: 0,Model,Threshold,Number of clusters,Min cluster size,Max cluster size,Number of articles not clustered
0,nomic-embed-text-v1.5,0.827237,33,2,21,27


In [12]:
# Export the Neo4j query results as CSV (the DataFrame index is written too).
for frame, path in (
    (pred_cluster, OUTPUT_PREDICTED_CLUSTER_PATH),
    (clustered_nodes, OUTPUT_CLUSTERED_NODES_PATH),
    (unclustered_nodes, OUTPUT_UNCLUSTERED_NODES_PATH),
):
    frame.to_csv(path)

# Pickle the record lists (presumably consumed by a MongoDB loading step,
# judging by the path names).
for payload, path in (
    (edges_dict, OUTPUT_MONGODB_EDGES_PATH),
    (cluster_articles_dict, OUTPUT_MONGODB_PREDICTED_CLUSTER_PATH),
):
    with open(path, "wb") as file:
        pickle.dump(payload, file)