In [1]:
import os
import pandas as pd
import pickle

import hdbscan
import pandas as pd

from umap import UMAP
from hdbscan import HDBSCAN
# from sentence_transformers import SentenceTransformer
from sklearn.feature_extraction.text import CountVectorizer

from bertopic import BERTopic
from bertopic.representation import MaximalMarginalRelevance
from bertopic.vectorizers import ClassTfidfTransformer

# from sklearn.metrics.pairwise import cosine_similarity
import numpy as np


  from .autonotebook import tqdm as notebook_tqdm
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(


In [2]:
# Seed NumPy's global random number generator with a fixed value.
np.random.seed(1234)

In [3]:
# Input locations, all relative to this notebook: the experiment output folder and the
# precomputed weighted embeddings.
root_path = "../data"
experiment_path = os.path.join(root_path,"07_model_output","nomic_weighted_emb (0.7 body, 0.3 title) updated")
weighted_embeddings_path = os.path.join(root_path, "04_feature", "weighted_embeddings.pkl")
neo4j_predicted_cluster_pkl_path = os.path.join(experiment_path, "neo4j_predicted_clusters.pkl")

# NOTE(review): pickle.load can execute arbitrary code while unpickling — only load
# pickle files that come from a trusted source.
with open(weighted_embeddings_path, "rb") as f:
    weighted_embeddings = pickle.load(f)

with open(neo4j_predicted_cluster_pkl_path, "rb") as f:
    neo4j_predicted_cluster_pkl = pickle.load(f)

# First-level cluster assignment, one row per article.
pred_cluster_df = pd.read_csv(os.path.join(experiment_path, "predicted_cluster.csv"))

In [4]:
# Count articles per first-level cluster and keep only the clusters with more than 10
# articles; these are the ones that go through second-level clustering.
cluster_size_count = pred_cluster_df.cluster.value_counts()
to_keep = cluster_size_count[cluster_size_count >10].index
cluster_morethan10 = pred_cluster_df[pred_cluster_df.cluster.isin(to_keep)]
print('No. of cluster to do 2nd level clustering: ', cluster_morethan10.cluster.nunique())

No. of cluster to do 2nd level clustering:  11


In [5]:
# Attach each article's embedding vector to the large-cluster rows, matching on 'id'.
cluster_morethan10_embeddings = pd.merge(
    cluster_morethan10,
    weighted_embeddings[['id','vector_extracted_content_body']],
    how='left',
    on='id')

# Sanity check: prints True when the left merge neither added nor removed rows.
print(cluster_morethan10.shape[0] == cluster_morethan10_embeddings.shape[0])

True


In [6]:
# Inspect the merged frame: overall shape and number of articles per first-level cluster.
print(cluster_morethan10_embeddings.shape)
print(cluster_morethan10_embeddings.cluster.value_counts())

(337, 6)
cluster
137    68
581    58
555    43
58     38
115    35
438    30
503    20
537    12
150    11
219    11
605    11
Name: count, dtype: int64


In [7]:
def get_embeddings(cluster_df):
    """Extract vectors and article metadata from ``cluster_df`` and reduce the vectors with UMAP.

    Parameters:
    cluster_df (pd.DataFrame): Frame with the columns 'vector_extracted_content_body',
        'title', 'body_content' and 'id'.

    Returns:
    tuple: ``(embeddings, doc_titles, docs, ids, umap_embeddings)``. ``embeddings`` is the
        array built from the vector column; ``doc_titles``, ``docs`` and ``ids`` are plain
        lists taken from 'title', 'body_content' and 'id'; ``umap_embeddings`` is the result
        of ``UMAP.fit_transform`` on ``embeddings``.
    """
    embeddings = np.array(cluster_df['vector_extracted_content_body'].to_list())
    doc_titles, docs, ids = (
        cluster_df[column].to_list() for column in ('title', 'body_content', 'id')
    )

    # Fixed UMAP configuration: 8 output components, cosine metric, seeded via random_state.
    reducer = UMAP(
        n_neighbors=15,
        n_components=8,
        min_dist=0.0,
        metric='cosine',
        random_state=42,
    )
    umap_embeddings = reducer.fit_transform(embeddings)

    return embeddings, doc_titles, docs, ids, umap_embeddings

In [8]:
def hyperparameter_tuning(embeddings):
    """Grid-search HDBSCAN settings on ``embeddings`` and return the best-scoring combination.

    Every combination of ``min_cluster_size``, ``min_samples``, ``cluster_selection_method``
    and ``metric`` listed below is fitted, and the one with the highest
    ``relative_validity_`` (DBCV) score is kept. On ties the first combination tried wins.

    Parameters:
    embeddings (array-like): Points to cluster; passed unchanged to ``HDBSCAN.fit``.

    Returns:
    dict: The winning parameters under the keys 'min_cluster_size', 'min_samples',
        'cluster_selection_method' and 'metric'.

    Raises:
    ValueError: If no combination produced a comparable score (e.g. every score is NaN).
    """
    # Start below every possible score so that a negative-but-best DBCV score is still
    # selected; starting at 0 left `best_parameters` unbound whenever all scores were <= 0.
    best_score = float("-inf")
    best_parameters = None

    for min_cluster_size in [2, 3, 4, 5, 6]:
        for min_samples in [1, 2, 3, 4, 5, 6, 7]:
            for cluster_selection_method in ['leaf']:
                for metric in ['euclidean', 'manhattan']:
                    # gen_min_span_tree=True is required for relative_validity_ to be available.
                    hdb = hdbscan.HDBSCAN(
                        min_cluster_size=min_cluster_size,
                        min_samples=min_samples,
                        cluster_selection_method=cluster_selection_method,
                        metric=metric,
                        gen_min_span_tree=True,
                    ).fit(embeddings)
                    # DBCV score
                    score = hdb.relative_validity_
                    if score > best_score:
                        best_score = score
                        best_parameters = {
                            'min_cluster_size': min_cluster_size,
                            'min_samples': min_samples,
                            'cluster_selection_method': cluster_selection_method,
                            'metric': metric,
                        }

    if best_parameters is None:
        raise ValueError("No HDBSCAN parameter combination produced a valid DBCV score")

    print("Best DBCV score: {:.3f}".format(best_score))
    print("Best parameters: {}".format(best_parameters))
    return best_parameters

In [9]:
def topic_modelling(hyperparameters):
    """Assemble an unfitted BERTopic model whose clustering step uses the given HDBSCAN settings.

    Parameters:
    hyperparameters (dict): Must contain 'min_cluster_size', 'min_samples', 'metric' and
        'cluster_selection_method'; these are forwarded to ``HDBSCAN``.

    Returns:
    BERTopic: Model configured with the HDBSCAN clusterer, an English stop-word
        ``CountVectorizer``, a ``ClassTfidfTransformer`` and a ``MaximalMarginalRelevance``
        representation model (diversity 0.3). No embedding model, UMAP model or
        ``nr_topics`` is passed.
    """
    clusterer_keys = ('min_cluster_size', 'min_samples', 'metric', 'cluster_selection_method')

    # Cluster the (reduced) embeddings
    hdbscan_model = HDBSCAN(
        **{key: hyperparameters[key] for key in clusterer_keys},
        prediction_data=True,
        gen_min_span_tree=True,
    )

    return BERTopic(
        hdbscan_model=hdbscan_model,                                   # clustering
        vectorizer_model=CountVectorizer(stop_words="english"),        # topic tokenisation
        ctfidf_model=ClassTfidfTransformer(),                          # topic word extraction
        representation_model=MaximalMarginalRelevance(diversity=0.3),  # representation fine-tuning
    )

In [None]:
def create_topic_assigner(start_counter):
    """Build a stateful function that replaces each ``-1`` topic with a fresh topic number.

    Parameters:
    start_counter: First number to hand out; every later ``-1`` gets the next one.

    Returns:
    callable: ``assign_new_topic(x)`` returning ``x`` unchanged unless ``x == -1``, in which
        case it returns the next unused number.
    """
    # Single-element list used as a mutable cell shared with the inner function.
    next_free = [start_counter]

    def assign_new_topic(x):
        if x != -1:
            return x
        issued = next_free[0]
        next_free[0] = issued + 1
        return issued

    return assign_new_topic


def process_cluster(cluster_df):
    """Run second-level topic modelling on the articles of one first-level cluster.

    Parameters:
    cluster_df (pd.DataFrame): Rows of a single first-level cluster. Needs the columns read
        by ``get_embeddings`` ('vector_extracted_content_body', 'title', 'body_content',
        'id') plus 'cluster'.

    Returns:
    pd.DataFrame: One row per article with the columns 'id', 'Title', 'Assigned Topic' and
        'top_5_kws'. 'Assigned Topic' is a string of the form
        'Cluster_<cluster id>_<topic number>'. Articles the topic model labelled -1 each get
        their own new topic number and have NaN in 'top_5_kws'.
    """
    # Step 1: Extract embeddings and umap_embeddings
    embeddings, doc_titles, docs, ids, umap_embeddings = get_embeddings(cluster_df)

    # Step 2: Perform hyperparameter tuning for berttopic
    # NOTE(review): tuning runs on the UMAP projection from get_embeddings, while the topic
    # model below is fitted on the raw embeddings and topic_modelling passes no umap_model —
    # presumably BERTopic then applies its own reduction. Confirm this mismatch is intended.
    hyperparameters = hyperparameter_tuning(umap_embeddings)

    # Step 3: Create and fit topic model
    topic_model = topic_modelling(hyperparameters)
    topics, _ = topic_model.fit_transform(docs, embeddings)

    ###############
    # Visualisation
    ################

    # Uncomment and adjust as needed for visualization purposes

    # top_n = 50
    # top_topics = topic_model.get_topic_freq().head(top_n)['Topic'].tolist()

    # reduced_embeddings = topic_model.umap_model.embedding_
    # hover_data = [f"{title} - Topic {topic}" for title, topic in zip(doc_titles, topics)]
    # visualization = topic_model.visualize_documents(hover_data, reduced_embeddings=reduced_embeddings, topics=top_topics, title=f'Top {top_n} Topics')
    # visualization.show()

    # visualization_barchart = topic_model.visualize_barchart(top_n_topics=top_n)
    # visualization_barchart.show()

    # Step 4: Create a DataFrame with assigned topics, titles and ids.
    result_df = pd.DataFrame({"Assigned Topic": topics, "Title": doc_titles, "id": ids})

    # Step 5: Extract topic information and get top 5 keywords. For unclustered articles
    # (Topic == -1) the keywords are replaced by NaN.
    # .copy() gives an independent frame, so adding a column below does not trigger pandas'
    # chained-assignment warning on a slice of get_topic_info().
    topic_kws = topic_model.get_topic_info()[['Topic', 'Representation']].copy()
    topic_kws['top_5_kws'] = topic_kws.apply(
        lambda row: row['Representation'][:5] if row['Topic'] != -1 else np.nan, axis=1
    )

    # Step 6: Merge results with the top keywords
    result_df_kws = pd.merge(result_df, topic_kws, how='left', left_on='Assigned Topic', right_on='Topic')
    result_df_kws = result_df_kws.drop(['Representation', 'Topic'], axis=1)
    result_df_kws = result_df_kws[['id', 'Title', 'Assigned Topic', 'top_5_kws']]

    # Step 7: Give every -1 article its own topic number, continuing after the largest
    # topic number found in this cluster.
    max_topic = result_df_kws['Assigned Topic'].max()
    new_topic_counter = max_topic + 1
    assign_new_topic_func = create_topic_assigner(new_topic_counter)
    result_df_kws['Assigned Topic'] = result_df_kws['Assigned Topic'].apply(assign_new_topic_func)

    # Step 8: Prefix the topic with the first-level cluster id so topic labels do not repeat
    # across clusters. The id is taken from the first distinct value of 'cluster'.
    cluster_id = cluster_df['cluster'].unique()[0]
    result_df_kws['Assigned Topic'] = result_df_kws['Assigned Topic'].apply(lambda x: 'Cluster_' + str(cluster_id) + '_' + str(x))
    print(result_df_kws)

    return result_df_kws

def process_all_clusters(cluster_morethan10_embeddings):
    """Apply ``process_cluster`` to every first-level cluster and stack the results.

    Parameters:
    cluster_morethan10_embeddings (pd.DataFrame): Frame with a 'cluster' column plus
        whatever columns ``process_cluster`` needs.

    Returns:
    pd.DataFrame: Concatenation of the per-cluster results with a fresh 0..n-1 index.
        Clusters are handled in order of first appearance in the 'cluster' column.
    """
    cluster_column = cluster_morethan10_embeddings['cluster']
    per_cluster_results = []

    for cluster_id in cluster_column.unique():
        print(f"cluster id: {cluster_id}")
        members = cluster_morethan10_embeddings[cluster_column == cluster_id]
        per_cluster_results.append(process_cluster(members))

    return pd.concat(per_cluster_results, ignore_index=True)

def assign_unique_numbers_to_topics(final_result_df, pred_cluster_df):
    """
    Number every distinct 'Assigned Topic' label, continuing after the largest existing cluster.

    The first distinct label (in order of appearance) gets ``max(cluster) + 1``, the next
    one ``max(cluster) + 2``, and so on.

    Parameters:
    final_result_df (pd.DataFrame): Second-level results with an 'Assigned Topic' column.
        The frame is modified in place.
    pred_cluster_df (pd.DataFrame): First-level predictions with a 'cluster' column.

    Returns:
    pd.DataFrame: The same ``final_result_df`` object, now carrying an
        'Assigned Topic Number' column.
    """
    highest_existing_cluster = pred_cluster_df['cluster'].max()

    topic_number_mapping = {}
    for position, topic in enumerate(final_result_df['Assigned Topic'].unique(), start=1):
        topic_number_mapping[topic] = highest_existing_cluster + position

    final_result_df['Assigned Topic Number'] = final_result_df['Assigned Topic'].map(topic_number_mapping)
    return final_result_df

# Run second-level clustering on every large first-level cluster, then give each resulting
# topic label a number that continues after the largest first-level cluster number.
final_result_df = process_all_clusters(cluster_morethan10_embeddings)
final_result_df_with_numbers = assign_unique_numbers_to_topics(final_result_df, pred_cluster_df)

In [11]:
# Keep only the columns needed downstream, under the names used after the merge, and
# attach them to the first-level predictions by article id.
new_cluster_to_merge = final_result_df_with_numbers[['id', 'top_5_kws', 'Assigned Topic Number']].rename(
    columns={'top_5_kws': 'cluster_kws', 'Assigned Topic Number': 'new_cluster'}
)
updated_pred_cluster = pred_cluster_df.merge(new_cluster_to_merge, how='left', on='id')

In [12]:
# Rows with no second-level result have a missing 'new_cluster'; fall back to the original
# 'cluster' value for them and convert every value to int.
updated_pred_cluster['new_cluster'] = updated_pred_cluster['new_cluster'].fillna(updated_pred_cluster['cluster']).apply(int)
updated_pred_cluster

Unnamed: 0,id,title,url,body_content,cluster,cluster_kws,new_cluster
0,1437477,Hepatitis B,https://www.healthhub.sg/a-z/diseases-and-cond...,Hepatitis B Symptoms\nWhile some people who ha...,0,,0
1,1437465,Hepatitis A,https://www.healthhub.sg/a-z/diseases-and-cond...,Hepatitis is a generic term for inflammation o...,0,,0
2,1437303,Pneumonia,https://www.healthhub.sg/a-z/diseases-and-cond...,Pneumonia is a serious medical condition and m...,3,,3
3,1437301,Pneumococcal Disease,https://www.healthhub.sg/a-z/diseases-and-cond...,Update: You can book a pneumococcal vaccinatio...,3,,3
4,1437357,Colorectal Cancer,https://www.healthhub.sg/a-z/diseases-and-cond...,What is Colorectal Cancer?\nColorectal cancer ...,5,,5
...,...,...,...,...,...,...,...
663,1437509,"Herpes: Causes, Symptoms, and Treatment",https://www.healthhub.sg/a-z/diseases-and-cond...,What is Herpes?\nHerpes is a contagious viral ...,666,,666
664,1437405,"Genital Herpes: Symptoms, Causes and Treatments",https://www.healthhub.sg/a-z/diseases-and-cond...,Genital herpes is one of the most common sexua...,666,,666
665,1444590,5 Ways to Psych Yourself for a Mammogram,https://www.healthhub.sg/live-healthy/5-ways-t...,Have You Gone for Your Mammogram Screening?\nB...,667,,667
666,1435040,Breast Screening Subsidies in Singapore,https://www.healthhub.sg/a-z/costs-and-financi...,Breast cancer is the number one cancer among w...,667,,667


In [13]:
# Lookup from first-level cluster number to its keywords, built from the unpickled
# neo4j prediction data.
first_level_pred_cluster = pd.DataFrame(neo4j_predicted_cluster_pkl)
first_level_cluster_dict = dict(zip(first_level_pred_cluster['cluster'], first_level_pred_cluster['cluster_keywords']))

# For rows whose cluster number did not change (cluster == new_cluster), fill missing
# keywords with the first-level keywords of that cluster. Rows that changed are untouched.
mask = updated_pred_cluster['cluster'] == updated_pred_cluster['new_cluster']
updated_pred_cluster.loc[mask, 'cluster_kws'] = updated_pred_cluster.loc[mask, 'cluster_kws'].fillna(
    updated_pred_cluster['cluster'].map(first_level_cluster_dict)
)

# Formatting
# NOTE(review): the in-place rename makes this cell non-idempotent — re-running it after the
# columns have been renamed fails on the 'cluster' lookups above.
updated_pred_cluster.rename(columns={'cluster':'first_level_cluster','new_cluster':'second_level_cluster','cluster_kws':'second_level_cluster_kws'}, inplace=True)
updated_pred_cluster =updated_pred_cluster[['id','title','url','body_content','first_level_cluster','second_level_cluster','second_level_cluster_kws']]

In [14]:
# Display the reformatted frame with first- and second-level cluster columns.
updated_pred_cluster

Unnamed: 0,id,title,url,body_content,first_level_cluster,second_level_cluster,second_level_cluster_kws
0,1437477,Hepatitis B,https://www.healthhub.sg/a-z/diseases-and-cond...,Hepatitis B Symptoms\nWhile some people who ha...,0,0,"[hepatitis, liver, hav, infected, virus]"
1,1437465,Hepatitis A,https://www.healthhub.sg/a-z/diseases-and-cond...,Hepatitis is a generic term for inflammation o...,0,0,"[hepatitis, liver, hav, infected, virus]"
2,1437303,Pneumonia,https://www.healthhub.sg/a-z/diseases-and-cond...,Pneumonia is a serious medical condition and m...,3,3,"[pneumococcal, pneumonia, lung, vaccination, 65]"
3,1437301,Pneumococcal Disease,https://www.healthhub.sg/a-z/diseases-and-cond...,Update: You can book a pneumococcal vaccinatio...,3,3,"[pneumococcal, pneumonia, lung, vaccination, 65]"
4,1437357,Colorectal Cancer,https://www.healthhub.sg/a-z/diseases-and-cond...,What is Colorectal Cancer?\nColorectal cancer ...,5,5,
...,...,...,...,...,...,...,...
663,1437509,"Herpes: Causes, Symptoms, and Treatment",https://www.healthhub.sg/a-z/diseases-and-cond...,What is Herpes?\nHerpes is a contagious viral ...,666,666,"[herpes, genital, hsv, sore, blister]"
664,1437405,"Genital Herpes: Symptoms, Causes and Treatments",https://www.healthhub.sg/a-z/diseases-and-cond...,Genital herpes is one of the most common sexua...,666,666,"[herpes, genital, hsv, sore, blister]"
665,1444590,5 Ways to Psych Yourself for a Mammogram,https://www.healthhub.sg/live-healthy/5-ways-t...,Have You Gone for Your Mammogram Screening?\nB...,667,667,"[breast, cancer, mammogram, screening, 50]"
666,1435040,Breast Screening Subsidies in Singapore,https://www.healthhub.sg/a-z/costs-and-financi...,Breast cancer is the number one cancer among w...,667,667,"[breast, cancer, mammogram, screening, 50]"


In [15]:
# Articles whose second-level cluster differs from their first-level cluster, i.e. the
# ones that were re-assigned by second-level clustering.
adjusted_cluster = updated_pred_cluster[updated_pred_cluster['first_level_cluster'] != updated_pred_cluster['second_level_cluster']]
adjusted_cluster.head()

Unnamed: 0,id,title,url,body_content,first_level_cluster,second_level_cluster,second_level_cluster_kws
25,1437716,Diabetes (Pocket Guide),https://www.healthhub.sg/a-z/diseases-and-cond...,What is Type 2 Diabetes?\n\n \n Insulin and D...,58,668,"[diabetes, fat, glucose, blood, type]"
26,1442923,"If You Think Thin People Don’t Get Diabetes, T...",https://www.healthhub.sg/live-healthy/if-you-t...,Question: Which of these four body types is/ar...,58,668,"[diabetes, fat, glucose, blood, type]"
27,1445336,Make a Healthier Choice Today!,https://www.healthhub.sg/live-healthy/make_hea...,If you are a shopper looking for healthier foo...,58,669,"[sugar, wholegrain, healthier, wholegrains, food]"
28,1444565,Wholegrains—The Wise Choice!,https://www.healthhub.sg/live-healthy/whole_gr...,Grain Nutrients\n\nWhat qualifies as wholegrai...,58,669,"[sugar, wholegrain, healthier, wholegrains, food]"
29,1442686,Diabetes - Are You at Risk?,https://www.healthhub.sg/live-healthy/diabetes...,About one in three Singaporeans has a lifetime...,58,668,"[diabetes, fat, glucose, blood, type]"


In [73]:
# Persist the in-memory second-level assignment BEFORE reading it back. Previously this cell
# only read the CSV, which is written by a later cell, so a fresh Restart & Run All either
# failed (file missing) or silently summarised stale results from an earlier run.
# The round-trip through CSV is kept on purpose: it turns the keyword lists into strings,
# which is what the `nunique` aggregation below operates on.
second_level_csv_path = os.path.join(experiment_path, "predicted_cluster_2nd_level_clustering.csv")
updated_pred_cluster.to_csv(second_level_csv_path, index=False)
updated_pred_cluster = pd.read_csv(second_level_csv_path)

# Articles that were moved to a new second-level cluster. `.copy()` makes this an independent
# frame so the column assignment below does not act on a slice of updated_pred_cluster.
broken_down_groups = updated_pred_cluster[
    updated_pred_cluster['first_level_cluster'] != updated_pred_cluster['second_level_cluster']
].copy()
# "<cluster number> : <keywords>" label (only used by the commented-out aggregation below).
broken_down_groups['second_level_cluster:kws'] = (
    broken_down_groups['second_level_cluster'].astype(str)
    + ' : '
    + broken_down_groups['second_level_cluster_kws'].astype(str)
)

# Per first-level cluster: number of re-assigned articles, number of distinct keyword sets,
# sizes of the second-level clusters with more than one article, and number of articles
# without keywords (counted as single articles).
agg_result = broken_down_groups.groupby('first_level_cluster').agg(
    number_of_articles_in_first_level = ('first_level_cluster','size'),
    number_of_clusters=('second_level_cluster_kws', 'nunique'),
    second_level_cluster_article_counts=('second_level_cluster', lambda x: [v for v in x.value_counts().to_dict().values() if v > 1]),
    number_of_single_articles=('second_level_cluster_kws', lambda x: x.isna().sum()),
    # second_level_clusters = ('second_level_cluster:kws',set)
)
agg_result["OG_keywords"] = agg_result.index.map(first_level_cluster_dict)
agg_result['OG_cluster'] = agg_result.index.astype(str) + ' - ' + agg_result['OG_keywords'].astype(str)
agg_result = agg_result[['OG_cluster','number_of_articles_in_first_level','number_of_clusters','second_level_cluster_article_counts','number_of_single_articles']]
agg_result

Unnamed: 0_level_0,OG_cluster,number_of_articles_in_first_level,number_of_clusters,second_level_cluster_article_counts,number_of_single_articles
first_level_cluster,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
58,"58 - ['diabetes', 'sugar', 'wholegrain', 'gluc...",38,4,"[15, 11, 5, 5]",2
115,"115 - ['bmi', 'nutrition', 'school', 'serving'...",35,3,"[11, 9, 6]",9
137,"137 - ['toddler', 'infant', 'solid', 'feed', '...",68,5,"[18, 14, 12, 12, 9]",3
150,"150 - ['flu', 'influenza', 'vaccine', 'vaccina...",11,2,"[5, 3]",3
219,"219 - ['esteem', 'resilience', 'self', 'positi...",11,4,"[3, 3, 2, 2]",1
438,"438 - ['hawker', 'quarter', 'plate', 'dish', '...",30,4,"[8, 7, 3, 2]",10
503,"503 - ['teeth', 'tooth', 'dental', 'decay', 't...",20,6,"[3, 3, 3, 3, 3, 3]",2
537,"537 - ['trail', 'tiong', 'park', 'travel', 'ba...",12,4,"[5, 3, 2, 2]",0
555,"555 - ['quit', 'smoking', 'nicotine', 'quittin...",43,2,"[33, 10]",0
581,"581 - ['fitness', 'intensity', 'aerobic', 'wor...",58,7,"[11, 11, 8, 6, 4, 4, 3]",11


In [74]:
# Write the second-level assignment and the per-cluster summary into the experiment folder.
# NOTE(review): index=False also drops agg_result's 'first_level_cluster' index from the
# CSV; the cluster number only survives inside the 'OG_cluster' text — confirm that is intended.
updated_pred_cluster.to_csv(os.path.join(experiment_path,"predicted_cluster_2nd_level_clustering.csv"), index=False)
agg_result.to_csv(os.path.join(experiment_path,"agg_result_cluster_2nd_level_clustering.csv"), index=False)

In [75]:
def get_cluster_size(pred_cluster):
    """Summarise how many second-level clusters fall into each size band.

    Parameters:
    pred_cluster (pd.DataFrame): Frame with a 'second_level_cluster' column, one row per article.

    Returns:
    pd.DataFrame: Columns 'Cluster size' and 'Num of clusters'. The first row is the band
        '1' (single-article clusters), followed by '2-10', '11-20', ... up to the band that
        contains the largest cluster. Bands with no clusters are listed with a count of 0.
    """
    cluster_sizes = pred_cluster.groupby('second_level_cluster').size()
    multi_article_sizes = cluster_sizes[cluster_sizes != 1]
    single_nodes = int((cluster_sizes == 1).sum())

    rows = [{'Cluster size': '1', 'Num of clusters': single_nodes}]

    # Without any multi-article cluster there is nothing to bin (max() would be NaN).
    if not multi_article_sizes.empty:
        # Bands are half-open [lo, lo + 10), so the last edge must be strictly greater than
        # the largest size. Stopping at max + 11 guarantees that; max + 10 dropped the
        # largest cluster whenever its size was 11, 21, 31, ...
        bins = list(range(1, int(multi_article_sizes.max()) + 11, 10))
        labels = [f"{lo}-{lo + 9}" for lo in bins[:-1]]
        labels[0] = '2-10'  # size 1 is reported separately, so the first band starts at 2
        banded = pd.cut(multi_article_sizes, bins=bins, labels=labels, right=False)
        counts_by_label = {str(label): int(count) for label, count in banded.value_counts().items()}
        # Build the rows explicitly instead of renaming reset_index() output, whose column
        # names differ between pandas versions.
        rows.extend(
            {'Cluster size': label, 'Num of clusters': counts_by_label.get(label, 0)}
            for label in labels
        )

    return pd.DataFrame(rows, columns=['Cluster size', 'Num of clusters'])

# Size-band summary of the second-level clusters.
get_cluster_size(updated_pred_cluster)

Unnamed: 0,Cluster size,Num of clusters
0,1,212
1,2-10,84
2,11-20,9
3,21-30,0
4,31-40,1


In [95]:
# Number of articles per second-level cluster; clusters of size 1 are treated as single nodes.
grouped_counts = updated_pred_cluster.groupby('second_level_cluster').size()
filtered_grouped_counts = grouped_counts[grouped_counts != 1]
# Summing the value counts of the sizes gives the number of clusters whose size is not 1.
print(f"no. of clusters: {filtered_grouped_counts.value_counts().sum()}")
print(f"min, max cluster size: {filtered_grouped_counts.min()}, {filtered_grouped_counts.max()}")
print(f"no. of single nodes: {len(grouped_counts[grouped_counts == 1])}")

no. of clusters: 94
min, max cluster size: 2, 33
no. of single nodes: 212


## Updating neo_4j_clustered_data & neo_4j_unclustered_data files for visualization

In [96]:
# Load the existing neo4j export files (clustered node pairs and unclustered nodes) from
# the experiment folder.
neo4j_clustered_df = pd.read_csv(os.path.join(experiment_path, "neo_4j_clustered_data.csv"))
neo4j_unclustered_df = pd.read_csv(os.path.join(experiment_path, "neo_4j_unclustered_data.csv"))

In [117]:
# Display the unclustered-node export as loaded from disk.
neo4j_unclustered_df

Unnamed: 0,node_title,node_ground_truth,node_community,node_meta_desc
0,Colorectal Cancer,,5,"<span data-contrast=""auto"" class=""TextRun SCXW..."
1,"Measles: Symptoms, Treatment, and Prevention",,10,Is that measles or an innocent rash? Learn mor...
2,Molar Incisor Hypomineralisation (MIH),,14,Your Guide to Understanding Molar Incisor Hypo...
3,Asthma (Common Childhood Illnesses),,15,Asthma affects about one in five children in S...
4,Understanding Leong's Premolars (LP),,16,Your Guide to Understanding Leong's Premolars ...
...,...,...,...,...
166,Eat to Lower Blood Pressure,,652,Do you watch the mercury rising every time you...
167,Kashmiri Pulao,,654,A healthy brown rice dish with raisins seasone...
168,How to Study Difficult Subjects,,656,Your brain is a muscle. You can train it too!
169,Baby Friendly Hospital Initiative,,659,The Baby Friendly Hospital Initiative (BFHI) s...


In [129]:
# Second-level clusters that contain exactly one article are treated as unclustered nodes.
unique_new_clusters = updated_pred_cluster['second_level_cluster'].value_counts()
# Fixed: the filter previously referenced `unique_second_level_clusters`, a name that is
# never defined in this notebook and raises NameError on a fresh kernel.
single_article_cluster = unique_new_clusters[unique_new_clusters == 1].index
unclustered_df = updated_pred_cluster[updated_pred_cluster['second_level_cluster'].isin(single_article_cluster)]
# Rename to the node_* column names used by the visualisation and drop the columns it
# does not need.
unclustered_df = unclustered_df.rename(
    columns={
        "id": "node_id",
        "title": "node_title",
        "second_level_cluster": "node_community",
    }
).drop(columns=['url', 'body_content','first_level_cluster','second_level_cluster_kws'])
unclustered_df

Unnamed: 0,node_id,node_title,node_community
4,1437357,Colorectal Cancer,5
8,1437661,"Measles: Symptoms, Treatment, and Prevention",10
9,1437890,Molar Incisor Hypomineralisation (MIH),14
10,1437735,Asthma (Common Childhood Illnesses),15
11,1437884,Understanding Leong's Premolars (LP),16
...,...,...,...
656,1445677,Eat to Lower Blood Pressure,652
657,1445661,Kashmiri Pulao,654
658,1445358,How to Study Difficult Subjects,656
659,1439066,Baby Friendly Hospital Initiative,659


In [116]:
# Attach the second-level cluster number and keywords to both ends of every edge.
# The first merge (on node_1_id) adds 'id', 'second_level_cluster' and
# 'second_level_cluster_kws' unsuffixed; the second merge (on node_2_id) adds the same
# column names again, so pandas applies the '_1' / '_2' suffixes to the colliding columns.
# The drop and the rename below rely on those suffixed names.
neo4j_clustered_df_new = pd.merge(
    neo4j_clustered_df,
    updated_pred_cluster[["id","second_level_cluster","second_level_cluster_kws"]],
    left_on="node_1_id",
    right_on='id',
    how='left'
).merge(
    updated_pred_cluster[["id","second_level_cluster","second_level_cluster_kws"]],
    left_on='node_2_id',
    right_on='id',
    how='left',
    suffixes=('_1', '_2')
).drop(columns=['id_1', 'id_2'])

# Give the new columns node_1_* / node_2_* names with a '_new' suffix.
neo4j_clustered_df_new = neo4j_clustered_df_new.rename(
    columns={
        "second_level_cluster_1": "node_1_pred_cluster_new", 
        "second_level_cluster_kws_1": "node_1_cluster_kws_new", 
        "second_level_cluster_2": "node_2_pred_cluster_new", 
        "second_level_cluster_kws_2": "node_2_cluster_kws_new"
    }
)

neo4j_clustered_df_new.head(2)

Unnamed: 0,node_1_id,node_2_id,node_1_title,node_2_title,edge_weight,node_1_ground_truth,node_2_ground_truth,node_1_pred_cluster,node_2_pred_cluster,node_1_cluster_kws,node_2_cluster_kws,node_1_pred_cluster_new,node_1_cluster_kws_new,node_2_pred_cluster_new,node_2_cluster_kws_new
0,1437643,1437728,Childhood Illnesses: 10 Most Common Conditions...,Cough and the Common Cold In Children,0.8808,,,137,137,"['toddler', 'infant', 'solid', 'feed', 'feeding']","['toddler', 'infant', 'solid', 'feed', 'feeding']",688,"['baby', 'fever', 'doctor', 'child', 'common']",688,"['baby', 'fever', 'doctor', 'child', 'common']"
1,1437643,1442801,Childhood Illnesses: 10 Most Common Conditions...,Common infant problems and conditions,0.853328,,8.0,137,137,"['toddler', 'infant', 'solid', 'feed', 'feeding']","['toddler', 'infant', 'solid', 'feed', 'feeding']",688,"['baby', 'fever', 'doctor', 'child', 'common']",688,"['baby', 'fever', 'doctor', 'child', 'common']"


In [138]:
import pyvis

def visualize_result(clustered_df, unclustered_df, output_path="../data/07_model_output/neo4j_final_viz.html"):
    """Render the clustered article graph as an interactive pyvis HTML file.

    Parameters:
    clustered_df (pd.DataFrame): One row per edge. Needs 'edge_weight' and, for both
        'node_1' and 'node_2', the columns '<side>_title', '<side>_pred_cluster_new' and
        '<side>_cluster_kws_new'.
    unclustered_df (pd.DataFrame): One row per stand-alone node; needs 'node_title'.
    output_path (str): Where the HTML file is written. Defaults to the location that was
        previously hard-coded, so existing calls behave as before.

    Returns:
    None. The graph is written to ``output_path`` via ``Network.show(..., notebook=False)``.
    """
    visual_graph = pyvis.network.Network(select_menu=True, filter_menu=True)

    # Add nodes-nodes pair.
    # iterrows is kept on purpose: the node attributes are taken from the row exactly as
    # before, so what is handed to pyvis does not change.
    for _, row in clustered_df.iterrows():
        # Both endpoints are added the same way; titles are used as node ids.
        for side in ("node_1", "node_2"):
            node_title = row[f"{side}_title"]
            predicted_cluster = row[f"{side}_pred_cluster_new"]
            visual_graph.add_node(
                node_title,
                label=node_title,
                title=f"Predicted: {predicted_cluster}\nTitle: {node_title}",
                group=row[f"{side}_cluster_kws_new"],
                cluster_num=predicted_cluster,
            )

        # Add edge
        visual_graph.add_edge(
            row["node_1_title"],
            row["node_2_title"],
            title=f"Edge Weight: {row['edge_weight']}",
        )

    # Add solo nodes
    for _, row in unclustered_df.iterrows():
        visual_graph.add_node(
            row["node_title"],
            label=row["node_title"],
            title=f"Predicted: No Community\nTitle: {row['node_title']}",
        )

    visual_graph.show(output_path, notebook=False)

In [139]:
# Render the graph from the re-clustered edges and the single-article (unclustered) nodes.
visualize_result(neo4j_clustered_df_new,unclustered_df)

../data/07_model_output/neo4j_final_viz.html
