### Setup

In [61]:
%%capture
%load_ext kedro.ipython

In [62]:
%%capture
import logging
import os
import pickle

import pandas as pd
from dotenv import load_dotenv
from neo4j import GraphDatabase
from neo4j.exceptions import DriverError, Neo4jError
from sklearn.metrics import completeness_score, homogeneity_score, v_measure_score

# Build the path with os.path.join so the .env file is also found on POSIX
# systems, where a backslash is an ordinary filename character rather than a
# path separator (the raw Windows-style string would silently load nothing).
load_dotenv(os.path.join("..", "conf", "local", ".env"))

## Parameters

In [233]:
# Parameters
# Embedding model name; it is part of the input/output file names and is
# recorded in the "Model" column of the metrics table.
MODEL_NAME: str = "all-mpnet-base-v2"
# Only ground-truth rows whose "Owner" value contains this text are kept.
CONTRIBUTOR: str = "Health Promotion Board"

# specify content_category. input 'all' if running across all categories
# (this value is also the output sub-folder name and the Neo4j database name)
CONTENT_CATEGORY: str = "live-healthy-articles"

# Cosine-similarity cut-off: a SIMILAR edge is created only when the similarity
# between two article bodies is strictly greater than this value.
# adjust accordingly
THRESHOLD: float = 0.7446

In [234]:
# Excel workbook holding the manually annotated ground-truth groupings.
INPUT_GROUNDTRUTH_PATH = os.path.join(
    "..",
    "data",
    "01_raw",
    "Synapxe Content Prioritisation - Live Healthy_020724.xlsx",
)

# Every model artefact for the selected content category lives in this folder.
DATA_FOLDER_PATH = os.path.join(
    "..",
    "data",
    "07_model_output",
    CONTENT_CATEGORY,
)

# Pickled article embeddings that are loaded further down.
INPUT_EMBEDDING_NEO4J_PATH = os.path.join(
    DATA_FOLDER_PATH,
    f"{CONTENT_CATEGORY}_{MODEL_NAME}_embeddings_neo4j.pkl",
)

# The threshold tag is "0" followed by int(THRESHOLD * 100); int() truncates,
# so 0.7446 becomes "074" and thresholds that differ only beyond the second
# decimal place map to the same file name.
OUTPUT_PREDICTED_CLUSTER_PATH = os.path.join(
    DATA_FOLDER_PATH,
    f"{CONTENT_CATEGORY}__{MODEL_NAME}_predicted_clusters_th{'0'+ str(int(THRESHOLD*100))}.csv",
)

# CSV that accumulates one metrics row per evaluated model/threshold run.
OUTPUT_CLUSTER_METRICS_PATH = os.path.join(
    DATA_FOLDER_PATH,
    f"{CONTENT_CATEGORY}_compiled_model_variation_metrics.csv",
)

# Destination for the node/edge exports used for visualisation.
NEO4J_FOLDER_PATH = os.path.join(
    DATA_FOLDER_PATH,
    "neo4j",
)

# exist_ok=True creates the folder (and missing parents) only when needed,
# without a separate existence check that could race with another process.
os.makedirs(NEO4J_FOLDER_PATH, exist_ok=True)

In [None]:
# Connection settings for the local Neo4j instance. Credentials are read from
# environment variables instead of being hard-coded in the notebook;
# os.getenv returns None for a variable that is not set.
uri = "neo4j://localhost:7687"
username = os.getenv("neo4j_username")
password = os.getenv("neo4j_password")

# Keyword arguments reused by every GraphDatabase.driver(**NEO4J) call below.
NEO4J = {
    "uri": uri,
    "auth": (username, password),
    "database": CONTENT_CATEGORY,  # create this database in neo4j first
}

# Test connection
with GraphDatabase.driver(**NEO4J) as driver:
    try:
        driver.verify_connectivity()
        print("Connection established.")
    except (DriverError, Neo4jError) as exception:
        # Best-effort probe: report the failure (with the exception type) and
        # let the notebook continue instead of raising.
        print(f"Connection failed: {exception!r}")

## Load files

In [235]:
# Load merged_df data
merged_data_df = catalog.load("merged_data")  # noqa

# load ground truth data
ground_truth_sheet = pd.read_excel(INPUT_GROUNDTRUTH_PATH, sheet_name=2)

# Keep rows owned by the selected contributor that have a group id.
# na=False treats a blank "Owner" cell as "no match"; without it the boolean
# mask would contain NaN and the row selection would raise.
owner_matches = ground_truth_sheet["Owner"].str.contains(CONTRIBUTOR, na=False)
has_group_id = ground_truth_sheet["Combine Group ID"].notna()
ground_truth = ground_truth_sheet.loc[
    owner_matches & has_group_id, ["Page Title", "Combine Group ID", "URL"]
]

# Extract id from merged_data_df to ground truth
# (inner join on the URL, so ground-truth rows without a matching article are dropped)
ground_truth = (
    pd.merge(
        ground_truth, merged_data_df, how="inner", left_on="URL", right_on="full_url"
    )
    .loc[:, ["id", "Page Title", "URL", "Combine Group ID"]]
    .rename(columns={"Combine Group ID": "ground_truth_cluster"})
)

print(ground_truth.shape)
ground_truth.head(2)

(184, 4)


Unnamed: 0,id,Page Title,URL,ground_truth_cluster
0,1442828,Getting ready for solids,https://www.healthhub.sg/live-healthy/baby-get...,1.0
1,1445136,Getting Your Baby Started on Solids,https://www.healthhub.sg/live-healthy/getting-...,1.0


In [236]:
# load embeddings file
# NOTE: pickle.load can execute arbitrary code; only open files produced by
# this project's own pipeline.
with open(INPUT_EMBEDDING_NEO4J_PATH, "rb") as embeddings_file:
    articles = pickle.load(embeddings_file)

# merge with ground truth (inner join: only articles that have a ground-truth label remain)
articles_df = pd.merge(articles, ground_truth, how="inner", on="id")

# For every model except "tfidf", convert each value in the vector columns
# with .tolist() (presumably numpy arrays that must become plain lists before
# being sent to Neo4j - TODO confirm).
vector_columns = []
if MODEL_NAME != "tfidf":
    vector_columns = [name for name in articles_df.columns if "vector" in name]
for col in vector_columns:
    articles_df[col] = articles_df[col].apply(lambda x: x.tolist())

print(articles_df.shape)
articles_df.head(2)

(176, 12)


Unnamed: 0,id,title,full_url,content,meta_description,vector_title,vector_article_category_names,vector_category_description,vector_extracted_content_body,Page Title,URL,ground_truth_cluster
0,1443987,All You Need to Know About Childhood Immunisat...,https://www.healthhub.sg/live-healthy/all-you-...,Every child in Singapore is vaccinated accordi...,"To prevent diseases such as measles and mumps,...","[0.0075735547579824924, -0.047238513827323914,...","[0.06393319368362427, 0.023783838376402855, -0...","[-0.018412871286273003, 0.008014391176402569, ...","[0.02329820767045021, -0.026086032390594482, -...",All You Need to Know About Childhood Immunisat...,https://www.healthhub.sg/live-healthy/all-you-...,16.0
1,1442828,Getting ready for solids,https://www.healthhub.sg/live-healthy/baby-get...,Weaning Tips\nThe process of switching an infa...,You have breastfed your baby for 6 months and ...,"[0.035273078829050064, 0.03444071486592293, -0...","[-0.0363435372710228, -0.0007167053408920765, ...","[0.02195669896900654, 0.04195531830191612, 0.0...","[0.007204481866210699, 0.01930098980665207, -0...",Getting ready for solids,https://www.healthhub.sg/live-healthy/baby-get...,1.0


## Clustering

In [237]:
# One dict per article row (column name -> value); these records are written
# to Neo4j as Article nodes in a later cell.
documents = articles_df.to_dict(orient="records")
# Show the keys of the first record as a quick check of the available fields.
documents[0].keys()

[1;35mdict_keys[0m[1m([0m[1m[[0m[32m'id'[0m, [32m'title'[0m, [32m'full_url'[0m, [32m'content'[0m, [32m'meta_description'[0m, [32m'vector_title'[0m, [32m'vector_article_category_names'[0m, [32m'vector_category_description'[0m, [32m'vector_extracted_content_body'[0m, [32m'Page Title'[0m, [32m'URL'[0m, [32m'ground_truth_cluster'[0m[1m][0m[1m)[0m

In [238]:
# Configure the root logger at INFO level so the logging.info calls in the
# helpers below are emitted.
logging.basicConfig(level=logging.INFO)


def clear_db(tx):
    """Delete every node, together with its relationships, from the database.

    Args:
        tx: Transaction-like object exposing ``run(query, **params)``.
    """
    delete_everything = "MATCH (n) DETACH DELETE n"
    logging.info("Clearing database")
    tx.run(delete_everything)


def create_graph_nodes(tx, doc):
    """Create one ``Article`` node from a single article record.

    Args:
        tx: Transaction-like object exposing ``run(query, **params)``.
        doc: Mapping with the keys ``id``, ``title``, ``full_url``, ``content``,
            ``meta_description``, ``vector_title``,
            ``vector_article_category_names``, ``vector_category_description``,
            ``vector_extracted_content_body`` and ``ground_truth_cluster``.
    """
    cypher = """
    CREATE (d:Article {
        id: $id,
        title: $title,
        url: $url,
        content: $content,
        meta_desc: $meta_description,
        vector_body: $vector_body,
        vector_title: $vector_title,
        vector_category: $vector_category,
        vector_desc: $vector_desc,
        ground_truth: $ground_truth
    })"""
    # Cypher parameter name -> key of the record that supplies its value.
    param_sources = {
        "id": "id",
        "title": "title",
        "url": "full_url",
        "content": "content",
        "meta_description": "meta_description",
        "vector_title": "vector_title",
        "vector_category": "vector_article_category_names",
        "vector_desc": "vector_category_description",
        "vector_body": "vector_extracted_content_body",
        "ground_truth": "ground_truth_cluster",
    }
    params = {name: doc[key] for name, key in param_sources.items()}
    tx.run(cypher, **params)


def create_sim_edges(tx, threshold):
    """Link pairs of articles whose body vectors are sufficiently similar.

    Each unordered pair is considered once (``a.id < b.id``); a ``SIMILAR``
    relationship carrying the cosine similarity is created when that
    similarity is strictly greater than ``threshold``.

    Args:
        tx: Transaction-like object exposing ``run(query, **params)``.
        threshold: Similarity cut-off passed to the query as ``$threshold``.
    """
    cypher = """
    MATCH (a:Article), (b:Article)
    WHERE a.id < b.id
    WITH a, b, gds.similarity.cosine(a.vector_body, b.vector_body) AS similarity
    WHERE similarity > $threshold
    CREATE (a)-[:SIMILAR {similarity: similarity}]->(b)
    """
    logging.info("Create edges")
    tx.run(cypher, threshold=threshold)


def drop_graph_projection(tx):
    """Drop the GDS graph named ``articleGraph`` if it currently exists.

    Args:
        tx: Transaction-like object whose ``run`` returns a result with
            ``single()``.
    """
    exists_query = """
    CALL gds.graph.exists('articleGraph')
    YIELD exists
    RETURN exists
    """
    exists_record = tx.run(exists_query).single()
    if not exists_record["exists"]:
        # Nothing to drop.
        return
    tx.run("CALL gds.graph.drop('articleGraph')")


def create_graph_proj(tx):
    """Project the article graph into GDS under the name ``articleGraph``.

    The projection contains ``Article`` nodes and ``SIMILAR`` relationships
    together with their ``similarity`` property.

    Args:
        tx: Transaction-like object exposing ``run(query, **params)``.
    """
    projection_query = """
           CALL gds.graph.project(
            'articleGraph',
            'Article',
            {
                SIMILAR: {
                    properties: 'similarity'
                }
            }
           )
    """
    tx.run(projection_query)


def detect_community(tx):
    """Run Louvain on ``articleGraph`` and store the result on each node.

    The community id is written to the node property ``community``.

    NOTE(review): the configuration map only sets ``writeProperty``; no
    ``relationshipWeightProperty`` is passed, so the projected ``similarity``
    values are presumably not used as weights - confirm this is intended.

    Args:
        tx: Transaction-like object exposing ``run(query, **params)``.
    """
    louvain_query = """
        CALL gds.louvain.write(
        'articleGraph',
        {
            writeProperty: 'community'
        }
        )
    """
    tx.run(louvain_query)


def return_community(tx):
    """Fetch, per community, the list of article titles it contains.

    Args:
        tx: Transaction-like object whose ``run`` returns an iterable of records.

    Returns:
        list: One record per community (``cluster``, ``articles``), ordered by
        cluster.
    """
    return list(
        tx.run(
            """
        MATCH (a:Article)
        RETURN a.community AS cluster, collect(a.title) AS articles
        ORDER BY cluster
        """
        )
    )


def return_pred_cluster(tx):
    """Fetch the predicted community of every article.

    Args:
        tx: Transaction-like object whose ``run`` returns an iterable of records.

    Returns:
        list: One record per article (id, title, url, ``cluster``), ordered by
        community.
    """
    return list(
        tx.run(
            """
        MATCH (a:Article)
        RETURN a.id, a.title, a.url, a.community AS cluster
        ORDER BY a.community
        """
        )
    )


def count_articles(tx):
    """Count how many articles fall into each community.

    Args:
        tx: Transaction-like object whose ``run`` returns an iterable of records.

    Returns:
        list: One record per community (``cluster``, ``articleCount``), ordered
        by cluster.
    """
    return list(
        tx.run(
            """
        MATCH (a:Article)
        RETURN a.community AS cluster, count(a) AS articleCount
        ORDER BY cluster
        """
        )
    )

In [239]:
# Rebuild the graph from scratch, detect communities and read the results
# back, all within a single driver/session.
with GraphDatabase.driver(**NEO4J) as driver, driver.session() as session:
    session.execute_write(clear_db)  # Clear the database

    # One write transaction per article node.
    for doc in documents:
        session.execute_write(create_graph_nodes, doc)

    session.execute_write(create_sim_edges, THRESHOLD)

    # Drop any existing projection, project the graph again, then run Louvain.
    for write_step in (drop_graph_projection, create_graph_proj, detect_community):
        session.execute_write(write_step)

    records = session.execute_read(return_community)
    pred_cluster = session.execute_read(return_pred_cluster)
    articles_count = session.execute_read(count_articles)

In [240]:
# Predicted community of every article, one row per article.
pred_cluster_columns = ["id", "title", "url", "pred_cluster"]
pred_cluster_df = pd.DataFrame(pred_cluster, columns=pred_cluster_columns)
# Optional: persist the predictions.
# pred_cluster_df.to_csv(OUTPUT_PREDICTED_CLUSTER_PATH)

# Number of articles in each predicted community.
cluster_count_columns = ["pred_cluster_number", "article_count"]
cluster_article_count = pd.DataFrame(articles_count, columns=cluster_count_columns)

In [241]:
# Join ground truth and predictions on the article id and keep only the
# columns needed for evaluation. Building the frame in a single chain avoids
# assigning a column on a column-subset of another frame, which is the pattern
# that triggers pandas' SettingWithCopyWarning.
results_df = (
    pd.merge(articles_df, pred_cluster_df, how="inner", on="id")
    .loc[:, ["id", "Page Title", "URL", "ground_truth_cluster", "pred_cluster"]]
    # Ground-truth ids are read from Excel as floats (e.g. 16.0); store them as ints.
    .astype({"ground_truth_cluster": int})
)

results_df.head(2)

Unnamed: 0,id,Page Title,URL,ground_truth_cluster,pred_cluster
0,1443987,All You Need to Know About Childhood Immunisat...,https://www.healthhub.sg/live-healthy/all-you-...,16,0
1,1442828,Getting ready for solids,https://www.healthhub.sg/live-healthy/baby-get...,1,2


### Cluster metrics

In [242]:
def get_exact_match(results_df):
    """Count predicted clusters whose members exactly equal a ground-truth cluster.

    Args:
        results_df: DataFrame with the columns ``id``, ``pred_cluster`` and
            ``ground_truth_cluster``. Rows whose cluster label is missing are
            left out of the corresponding grouping (pandas ``groupby`` drops
            NaN keys by default).

    Returns:
        int: Number of predicted clusters whose set of ids is identical to the
        set of ids of some ground-truth cluster.
    """

    def _id_sets(cluster_column):
        # One frozenset of article ids per cluster label.
        grouped_ids = results_df.groupby(cluster_column)["id"]
        return [frozenset(ids) for _, ids in grouped_ids]

    # Hashing the ground-truth sets turns each membership test into an O(1)
    # lookup instead of a scan over every ground-truth cluster.
    ground_truth_sets = set(_id_sets("ground_truth_cluster"))
    return sum(1 for ids in _id_sets("pred_cluster") if ids in ground_truth_sets)


def fill_single(series):
    """Give every missing label its own new, unique cluster id.

    Missing entries are replaced, in order of appearance, by ``max + 1``,
    ``max + 2``, ... where ``max`` is the largest existing label, so each
    unlabelled item ends up in a cluster of its own.

    Args:
        series: pandas Series of numeric cluster labels, possibly containing NaN.

    Returns:
        list: The labels with every missing value replaced. The input series
        is not modified.
    """
    missing = series.isna()
    filled_series = series.copy()
    missing_count = int(missing.sum())
    if missing_count == 0:
        return filled_series.to_list()

    max_val = series.max()
    # When every label is missing, max() is NaN and NaN + 1 would stay NaN;
    # start counting from 0 so the generated labels are 1, 2, ...
    start = 0 if pd.isna(max_val) else max_val

    # The boolean mask assigns by position, so rows that share an index label
    # still receive distinct ids.
    filled_series[missing] = [start + step for step in range(1, missing_count + 1)]
    return filled_series.to_list()


def compute_vmeasure(results_df):
    """Score the predicted clustering against the ground truth.

    Missing labels in either column are first replaced by unique ids via
    ``fill_single``.

    Args:
        results_df: DataFrame with the columns ``ground_truth_cluster`` and
            ``pred_cluster``.

    Returns:
        tuple: ``(homogeneity, completeness, v_measure)``.
    """
    labels_true = fill_single(results_df["ground_truth_cluster"])
    labels_pred = fill_single(results_df["pred_cluster"])
    scorers = (homogeneity_score, completeness_score, v_measure_score)
    return tuple(scorer(labels_true, labels_pred) for scorer in scorers)

In [252]:
# Cluster-size statistics; a community containing exactly one article is
# treated as "not clustered".
cluster_sizes = cluster_article_count["article_count"]
min_count = cluster_sizes[cluster_sizes > 1].min()
max_count = cluster_sizes.max()
num_clusters = (cluster_sizes != 1).sum()
unclustered_count = (cluster_sizes == 1).sum()

# Agreement between predicted and ground-truth clusters.
exact_match = get_exact_match(results_df)
homogeneity, completeness, v_measure = compute_vmeasure(results_df)

# Single-row summary of this run.
metric_values = {
    "Model": MODEL_NAME,
    "Threshold": THRESHOLD,
    "Exact cluster match": exact_match,
    "Homogeneity": round(homogeneity, 4),
    "Completeness": round(completeness, 4),
    "V-measure": round(v_measure, 4),
    "Number of clusters": num_clusters,
    "Min cluster size": min_count,
    "Max cluster size": max_count,
    "Number of articles not clustered": unclustered_count,
}
data = pd.DataFrame({column: [value] for column, value in metric_values.items()})

In [253]:
data

Unnamed: 0,Model,Threshold,Exact cluster match,Homogeneity,Completeness,V-measure,Number of clusters,Min cluster size,Max cluster size,Number of articles not clustered
0,all-mpnet-base-v2,0.7446,14,0.8027,0.9075,0.8519,27,2,26,27


In [None]:
# Append this run's metrics row to the compiled CSV, creating the file on the
# first run.
if os.path.exists(OUTPUT_CLUSTER_METRICS_PATH):
    previous_metrics_df = pd.read_csv(OUTPUT_CLUSTER_METRICS_PATH, index_col=0)
    # ignore_index renumbers the rows 0..n-1; otherwise every appended row
    # keeps index 0 and the saved index column is full of duplicates.
    metrics_df = pd.concat([previous_metrics_df, data], axis=0, ignore_index=True)
else:
    # Nothing saved yet: start from this run alone rather than concatenating
    # with an empty frame (deprecated behaviour in recent pandas versions).
    metrics_df = data.reset_index(drop=True)

metrics_df.to_csv(OUTPUT_CLUSTER_METRICS_PATH)
metrics_df

### Cluster visualisation

In [249]:
# every relationship together with the properties of both of its end nodes
query = """
MATCH (n)-[r]->(m)
RETURN n.id AS node_1_id,
    m.id AS node_2_id,
    n.title AS node_1_title,
    m.title AS node_2_title,
    r.similarity AS edge_weight,
    n.ground_truth AS node_1_ground_truth,
    m.ground_truth AS node_2_ground_truth,
    n.community AS node_1_pred_cluster,
    m.community AS node_2_pred_cluster
"""
# nodes with no relationship
query_2 = """MATCH (n)
WHERE NOT EXISTS ((n)--())
RETURN n.title AS node_title,
    n.ground_truth AS node_ground_truth,
    n.community AS node_community,
    n.meta_desc AS node_meta_desc
"""
# The frames get their own names instead of `data` / `data_2`, so this cell no
# longer overwrites the metrics row called `data` that the metrics-append cell
# reads. Passing the result keys as columns keeps the expected columns even
# when a query returns no rows, so the column accesses below cannot fail.
with GraphDatabase.driver(**NEO4J) as driver:
    with driver.session() as session:
        results = session.run(query)
        clustered_df = pd.DataFrame(results.data(), columns=list(results.keys()))
        results_2 = session.run(query_2)
        unclustered_df = pd.DataFrame(
            results_2.data(), columns=list(results_2.keys())
        )

clustered_df["node_1_title"] = clustered_df["node_1_title"].astype(str)
clustered_df["node_2_title"] = clustered_df["node_2_title"].astype(str)

# Isolated nodes are exported with an empty community value.
unclustered_df["node_community"] = ""

# Same threshold tag as the other output file names: "0" + int(THRESHOLD * 100).
threshold_tag = "0" + str(int(THRESHOLD * 100))

# save nodes and edges of clustered and unclustered (single nodes) data for visualisation
clustered_df.to_csv(
    os.path.join(
        NEO4J_FOLDER_PATH,
        f"{MODEL_NAME}_neo4j_clustered_data_th{threshold_tag}.csv",
    )
)
unclustered_df.to_csv(
    os.path.join(
        NEO4J_FOLDER_PATH,
        f"{MODEL_NAME}_neo4j_unclustered_data_th{threshold_tag}.csv",
    )
)

## End